# Visualize Assistant Axis

This notebook loads a computed assistant axis and visualizes its properties.

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

from assistant_axis import load_axis, axis_norm_per_layer

## Load the Axis

In [None]:
# Location of the computed axis file -- edit this to match your own run.
AXIS_PATH = "../outputs/gemma-2-27b/axis.pt"

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code on load. Only point AXIS_PATH at files you
# produced yourself or otherwise trust.
axis_data = torch.load(
    AXIS_PATH,
    map_location="cpu",
    weights_only=False,
)

# "axis" is mandatory (a missing key raises KeyError). The two mean entries are
# looked up with .get(), so they are None when the file does not provide them.
axis = axis_data["axis"]
default_mean, role_mean = axis_data.get("default_mean"), axis_data.get("role_mean")

# Quick summary: dimension 0 is reported as the layer count and dimension 1 as
# the hidden size.
print(f"Axis shape: {axis.shape}")
print(f"Number of layers: {axis.shape[0]}")
print(f"Hidden dimension: {axis.shape[1]}")

## Plot Axis Norm per Layer

In [None]:
# Per-layer norm of the axis, as computed by the project helper.
norms = axis_norm_per_layer(axis)

# One marker per entry of `norms`, with the x-axis counting entries from 0.
norm_trace = go.Scatter(
    x=list(range(len(norms))),
    y=norms,
    mode='lines+markers',
    name='Axis Norm',
)

# Figure-level settings, kept together so they are easy to tweak.
layout_options = dict(
    title='Assistant Axis Norm per Layer',
    xaxis_title='Layer',
    yaxis_title='L2 Norm',
    width=800,
    height=500,
)

fig = go.Figure()
fig.add_trace(norm_trace)
fig.update_layout(**layout_options)

fig.show()

## Compare Default and Role Means

In [None]:
# Plot the per-layer L2 norms of the default and role mean activations on a
# single figure. Both entries are optional in the axis file, so the plot is
# skipped (with a message) unless both were loaded.
if default_mean is not None and role_mean is not None:
    # Detach and upcast to float32 before converting: Tensor.numpy() raises
    # for bfloat16 tensors and for tensors that require grad, either of which
    # would otherwise abort this cell.
    default_norms = default_mean.detach().float().norm(dim=1).numpy()
    role_norms = role_mean.detach().float().norm(dim=1).numpy()

    fig = go.Figure()
    # Both traces share the same styling; only the label and data differ.
    for trace_name, layer_norms in (
        ('Default Mean', default_norms),
        ('Role Mean', role_norms),
    ):
        fig.add_trace(go.Scatter(
            x=list(range(len(layer_norms))),
            y=layer_norms,
            mode='lines+markers',
            name=trace_name
        ))

    fig.update_layout(
        title='Mean Activation Norms per Layer',
        xaxis_title='Layer',
        yaxis_title='L2 Norm',
        width=800,
        height=500
    )

    fig.show()
else:
    print("Default/role means not available in axis file")

## Identify Target Layer

The target layer is typically chosen as the layer where the axis has the largest L2 norm, i.e. where the separation captured by the axis is strongest.

In [None]:
# Largest per-layer norm, and the index at which it occurs.
max_norm = norms.max()
target_layer = norms.argmax()

# Report the suggested layer first, then the norm that justifies it.
print(f"Recommended target layer: {target_layer}")
print(f"Maximum axis norm: {max_norm:.4f}")