# Visualize Assistant Axis

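This notebook downloads a precomputed assistant axis (and the default activation vector) from the HuggingFace Hub, plots their per-layer norms, and reports the layer where the axis norm is largest.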

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from huggingface_hub import hf_hub_download

from assistant_axis import load_axis, axis_norm_per_layer

## Load the Axis

In [None]:
# Model whose vectors are fetched, and the HuggingFace dataset repo hosting them.
MODEL_NAME = "gemma-2-27b"
REPO_ID = "lu-christina/assistant-axis-vectors"

# Options shared by both checkpoint loads below.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code embedded in the downloaded file. Only safe
# while REPO_ID is trusted -- confirm whether these files also load with
# weights_only=True.
_load_options = {"map_location": "cpu", "weights_only": False}

# Download the assistant-axis checkpoint and pull out its "axis" entry.
axis_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=f"{MODEL_NAME}/assistant_axis.pt",
    repo_type="dataset",
)
axis_data = torch.load(axis_path, **_load_options)
axis = axis_data["axis"]

# Download the default-vector checkpoint and pull out its "vector" entry;
# it is plotted further down the notebook.
default_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=f"{MODEL_NAME}/default_vector.pt",
    repo_type="dataset",
)
default_data = torch.load(default_path, **_load_options)
default_mean = default_data["vector"]

# Summarise the axis; the last two lines index shape[0] and shape[1], so the
# axis is expected to have at least two dimensions.
print(f"Axis shape: {axis.shape}")
print(f"Number of layers: {axis.shape[0]}")
print(f"Hidden dimension: {axis.shape[1]}")

## Plot Axis Norm per Layer

In [None]:
# One norm value per layer, computed by the project helper.
norms = axis_norm_per_layer(axis)

# x positions are simply 0..len(norms)-1, one per entry of `norms`.
layer_indices = list(range(len(norms)))
norm_trace = go.Scatter(
    x=layer_indices,
    y=norms,
    mode="lines+markers",
    name="Axis Norm",
)

# Build the figure around the single trace, then label and size it.
fig = go.Figure(data=[norm_trace])
fig.update_layout(
    title="Assistant Axis Norm per Layer",
    xaxis_title="Layer",
    yaxis_title="L2 Norm",
    width=800,
    height=500,
)

fig.show()

## Default Mean Activation Norms per Layer

In [None]:
# Plot the norm of the default mean vector along dim 1 (one value per row,
# shown as one point per layer), or report that the vector is unavailable.
if default_mean is not None:
    # Detach and cast to float32 before reducing: Tensor.numpy() raises on
    # tensors that require grad and on bfloat16 tensors, and a norm taken in
    # half precision can lose range. The stored dtype is not visible here, so
    # this is defensive; for a plain float32 tensor both calls are no-ops.
    default_norms = default_mean.detach().float().norm(dim=1).numpy()

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=list(range(len(default_norms))),
        y=default_norms,
        mode='lines+markers',
        name='Default Mean'
    ))

    fig.update_layout(
        title='Default Mean Activation Norms per Layer',
        xaxis_title='Layer',
        yaxis_title='L2 Norm',
        width=800,
        height=500
    )

    fig.show()
else:
    print("Default mean not available")

## Identify Target Layer

The target layer is typically where the axis has the largest norm (most separation).

In [None]:
# Largest per-layer axis norm, and the index at which it occurs.
max_norm = norms.max()
target_layer = norms.argmax()

# Report the layer index first, then the norm value to four decimal places.
print(f"Recommended target layer: {target_layer}")
print(f"Maximum axis norm: {max_norm:.4f}")