# PCA Analysis

This notebook runs PCA on role activations at a single target layer and compares the first principal component (PC1) with the assistant axis at that layer using cosine similarity.

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import plotly.graph_objects as go
from pathlib import Path
from tqdm import tqdm

from assistant_axis import (
    load_axis,
    compute_pca,
    plot_variance_explained,
    cosine_similarity_per_layer,
    MeanScaler
)

## Load Activations

In [None]:
# Configuration — update these paths (and the layer index) for your run
ACTIVATIONS_DIR = Path("../outputs/gemma-2-27b/activations")
AXIS_PATH = "../outputs/gemma-2-27b/axis.pt"
TARGET_LAYER = 22  # Update for your model

# Discover every saved activation file; sorting keeps the sample order
# (and therefore `labels`) deterministic across runs.
activation_files = sorted(ACTIVATIONS_DIR.glob("*.pt"))
print(f"Found {len(activation_files)} activation files")

# Fail early with an actionable message: without this guard a wrong path only
# surfaces later as an opaque error from torch.stack() on an empty list.
if not activation_files:
    raise FileNotFoundError(
        f"No .pt activation files found in {ACTIVATIONS_DIR.resolve()}; "
        "update ACTIVATIONS_DIR before running the rest of the notebook."
    )

# Collect one activation vector per (role, key) pair at the target layer.
# Assumes each file holds a dict mapping a key to a per-layer activation
# tensor that can be indexed by layer — TODO confirm against the writer.
all_activations = []
labels = []

for act_file in tqdm(activation_files, desc="Loading activations"):
    role = act_file.stem  # file name without extension identifies the role
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load activation files that you produced yourself or fully trust.
    acts = torch.load(act_file, map_location="cpu", weights_only=False)

    for key, act in acts.items():
        all_activations.append(act[TARGET_LAYER])  # (hidden_dim,)
        labels.append(f"{role}:{key}")

# Files were found but none contained any entries; stacking would fail.
if not all_activations:
    raise ValueError(
        f"The {len(activation_files)} activation file(s) in {ACTIVATIONS_DIR} "
        "contained no activations to stack."
    )

activations = torch.stack(all_activations)  # (n_samples, hidden_dim)
print(f"Loaded {len(activations)} activations, shape: {activations.shape}")

## Run PCA

In [None]:
# Fit PCA on the stacked activations, using a MeanScaler for preprocessing.
# `layer=None` because the target layer was already selected while stacking.
mean_scaler = MeanScaler()
pca_result, variance_explained, n_components, pca, scaler = compute_pca(
    activations, layer=None, scaler=mean_scaler
)

In [None]:
# Render the variance-explained figure for the PCA fit computed above.
plot_options = {"title": "PCA Variance Explained", "max_components": 50}
fig = plot_variance_explained(variance_explained, **plot_options)
fig.show()

## Compare PC1 with Axis

In [None]:
# Load the precomputed assistant axis from disk
axis = load_axis(AXIS_PATH)
print(f"Axis shape: {axis.shape}")

# First principal component of the PCA fit at the target layer.
# Cast explicitly to float32: `torch.tensor` would otherwise inherit the dtype
# of `pca.components_` (assumed array-like, e.g. a NumPy array from a
# scikit-learn-style PCA — TODO confirm), and the 1-D `@` below requires both
# operands to share a dtype.
pc1 = torch.tensor(pca.components_[0], dtype=torch.float32)  # (hidden_dim,)

# Axis direction at the same layer, brought to CPU/float32 so it matches
# `pc1` regardless of how the axis file was saved.
axis_layer = axis[TARGET_LAYER].detach().to(device="cpu", dtype=torch.float32)  # (hidden_dim,)

# Cosine similarity = dot product of the two unit-normalized vectors
pc1_norm = pc1 / pc1.norm()
axis_norm = axis_layer / axis_layer.norm()
cosine_sim = float(pc1_norm @ axis_norm)

# NOTE: the sign of a principal component is arbitrary, so only the magnitude
# of this value indicates alignment; a negative value means PC1 happens to
# point the opposite way along (roughly) the same direction.
print(f"\nCosine similarity between PC1 and Axis at layer {TARGET_LAYER}: {cosine_sim:.4f}")

## Cosine Similarity Across Layers

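To compare PC1 with the axis across all layers, we'd need to load the activations for every layer and run PCA at each one. This notebook does not do that; the cell below only summarizes the variance explained by the leading components at the target layer.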

In [None]:
# This is a simplified version - for full layer-wise comparison,
# you'd need to load activations for all layers and run PCA at each

# Report at most five components: with few samples PCA can yield fewer than
# five, and indexing past the end would raise an IndexError.
top_k = min(5, len(variance_explained))

print(f"\nTop {top_k} components variance:")
for i in range(top_k):
    print(f"  PC{i+1}: {variance_explained[i]*100:.2f}%")

print(f"\nTotal variance in top {top_k}: {sum(variance_explained[:top_k])*100:.2f}%")