# PCA Analysis: Pre vs Post RL

This notebook compares the persona representation before and after RL training.

**Goal**: Understand how RL changes the overall structure of persona space by comparing:
1. Default vector positions
2. Role vector distributions (post-RL role vectors are not extracted yet; see Next Steps)
3. Assistant Axis alignment

In [None]:
import sys
sys.path.insert(0, '../..')

import torch
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from huggingface_hub import snapshot_download, hf_hub_download
from huggingface_hub.utils import disable_progress_bars
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from tqdm.auto import tqdm

from assistant_axis import load_axis, compute_axis, compute_pca, MeanScaler, project
from assistant_axis.internals import ProbingModel, ConversationEncoder, ActivationExtractor

disable_progress_bars()

## Configuration

In [None]:
# Model configuration
# NOTE(review): BASE_MODEL is not referenced by any cell in this notebook; the
# pre-RL baseline comes from the pre-computed vectors in REPO_ID instead.
BASE_MODEL = "google/gemma-2-27b-it"
MODEL_NAME = "gemma-2-27b"   # sub-folder of REPO_ID holding the pre-computed vectors
TARGET_LAYER = 22            # the single layer at which every comparison below is made

# Final RL checkpoint to analyse -- update to point at your run.
RL_MODEL_PATH = Path("../checkpoints") / "final"

# Hub dataset with the pre-computed (pre-RL) baseline vectors.
REPO_ID = "lu-christina/assistant-axis-vectors"

# Where this notebook writes its cached outputs; created up front.
OUTPUT_DIR = Path("../outputs") / "after_rl"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

## Load Pre-RL Baseline Data

In [None]:
print(f"Loading baseline from HuggingFace: {REPO_ID}")

# Download only the files this notebook reads.
# NOTE(review): no `revision=` is pinned, so the baseline silently changes if the
# dataset repo is updated; pin a commit hash for reproducible numbers.
local_dir = snapshot_download(
    repo_id=REPO_ID,
    repo_type="dataset",
    allow_patterns=[f"{MODEL_NAME}/role_vectors/*.pt",
                    f"{MODEL_NAME}/default_vector.pt",
                    f"{MODEL_NAME}/assistant_axis.pt"]
)
baseline_dir = Path(local_dir, MODEL_NAME)

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python objects
# from files fetched over the network. If these files hold plain tensors, switch to
# weights_only=True.

# Load role vectors in sorted filename order. glob order is filesystem-dependent, and
# role order feeds role_labels and the PCA input, so sorting keeps runs reproducible.
role_vector_dir = baseline_dir / "role_vectors"
pre_rl_role_vectors = {p.stem: torch.load(p, map_location="cpu", weights_only=False)
                       for p in sorted(role_vector_dir.glob("*.pt"))}
assert pre_rl_role_vectors, f"No role vectors found in {role_vector_dir}"
print(f"Loaded {len(pre_rl_role_vectors)} pre-RL role vectors")

# Load default vector
pre_rl_default = torch.load(baseline_dir / "default_vector.pt",
                            map_location="cpu", weights_only=False)
print(f"Pre-RL default vector shape: {pre_rl_default.shape}")

# Load pre-computed axis
pre_rl_axis = torch.load(baseline_dir / "assistant_axis.pt",
                         map_location="cpu", weights_only=False)
print(f"Pre-RL axis shape: {pre_rl_axis.shape}")

## Extract Post-RL Vectors

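Extract the default vector from the RL-trained model (post-RL role vectors are not extracted yet; see Next Steps). If the checkpoint is not found, the pre-RL default is used as a placeholder, so every pre/post difference below will be exactly zero.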

In [None]:
# Sample prompts for extracting vectors.
# Each system prompt is a plain assistant description, so averaging over them
# targets the model's default ("Assistant") persona.
# NOTE(review): these conversations contain only system + user turns (no assistant
# reply). Also, Gemma-2's stock chat template rejects the "system" role; presumably
# ConversationEncoder handles that -- confirm before trusting the extraction.
DEFAULT_PROMPTS = [
    "You are a helpful AI assistant.",
    "You are an AI assistant.",
    "You are a helpful assistant.",
    "You are an AI assistant here to help.",
    "You are a helpful AI.",
]

SAMPLE_QUESTIONS = [
    "What is the capital of France?",
    "Explain photosynthesis briefly.",
    "What are prime numbers?",
    "How does gravity work?",
    "What is machine learning?",
    "Describe the water cycle.",
    "What causes seasons?",
    "How do computers work?",
]

# How many of SAMPLE_QUESTIONS to pair with each default prompt. Every conversation
# is run through the model once during extraction, so this bounds the cost:
# len(DEFAULT_PROMPTS) * N_QUESTIONS_PER_PROMPT conversations in total.
N_QUESTIONS_PER_PROMPT = 3

# Create conversations: every default prompt crossed with the first
# N_QUESTIONS_PER_PROMPT questions (prompt-major order).
default_conversations = [
    [{"role": "system", "content": prompt}, {"role": "user", "content": q}]
    for prompt in DEFAULT_PROMPTS
    for q in SAMPLE_QUESTIONS[:N_QUESTIONS_PER_PROMPT]
]

In [None]:
# Cache file for the post-RL default vector, so the expensive extraction runs only once.
post_rl_default_path = OUTPUT_DIR.joinpath("post_rl_default_vector.pt")

def extract_default_vector(model_path, conversations, layers=None):
    """Extract the mean activation over a set of conversations from a model.

    Activations are averaged over tokens within each conversation, then over
    conversations.

    NOTE(review): every token of the conversation (system + user turns) is pooled
    here. The pre-RL baseline downloaded from the Hub may have been pooled
    differently, so for a like-for-like comparison consider also running the base
    model through this same function.

    Args:
        model_path: Checkpoint path (or str); passed to ``ProbingModel`` as a string.
        conversations: Non-empty list of chat conversations (lists of
            ``{"role": ..., "content": ...}`` dicts).
        layers: Forwarded as ``layer=`` to ``ActivationExtractor.full_conversation``.

    Returns:
        CPU tensor of mean activations, in the dtype the extractor produced.
        Assumed shape ``(num_layers, hidden_size)`` -- TODO confirm against
        ``full_conversation``.

    Raises:
        ValueError: if ``conversations`` is empty.
    """
    if not conversations:
        raise ValueError("extract_default_vector needs at least one conversation")

    pm = ProbingModel(str(model_path))
    try:
        encoder = ConversationEncoder(pm)
        extractor = ActivationExtractor(pm, encoder)

        per_conversation = []
        for conv in tqdm(conversations, desc="Extracting activations"):
            # Assumed shape: (num_layers, num_tokens, hidden_size) -- TODO confirm.
            acts = extractor.full_conversation(conv, layer=layers)
            out_dtype = acts.dtype
            # Average over tokens in float32 (low-precision activations lose accuracy
            # when summing many tokens) and move to CPU immediately, so the result
            # survives pm.close() and supports .numpy() / CPU arithmetic downstream.
            per_conversation.append(acts.float().mean(dim=1).cpu())

        # Mean across conversations, cast back to the extractor's dtype.
        mean_activation = torch.stack(per_conversation).mean(dim=0).to(out_dtype)
    finally:
        # Release the model even if extraction fails part-way, so a retry in the
        # same kernel does not run out of GPU memory.
        pm.close()
        torch.cuda.empty_cache()

    return mean_activation

# Load the cached post-RL default vector, extract it from the checkpoint, or fall
# back to a placeholder. post_rl_is_placeholder records which branch ran.
# NOTE: the cache file is not keyed on RL_MODEL_PATH -- delete it after switching
# checkpoints, or the old vector will be reused silently.
post_rl_is_placeholder = False
if post_rl_default_path.exists():
    print(f"Loading cached post-RL default vector from {post_rl_default_path}")
    # The cache is a plain tensor written by torch.save below, so the safe
    # weights_only=True loader is sufficient.
    post_rl_default = torch.load(post_rl_default_path, map_location="cpu", weights_only=True)
elif RL_MODEL_PATH.exists():
    print(f"Extracting post-RL default vector from {RL_MODEL_PATH}")
    post_rl_default = extract_default_vector(RL_MODEL_PATH, default_conversations)

    torch.save(post_rl_default, post_rl_default_path)
    print(f"Saved to {post_rl_default_path}")
else:
    print(f"RL model not found at {RL_MODEL_PATH}")
    print("Using pre-RL default as placeholder for demonstration")
    print("WARNING: post-RL default == pre-RL default, so every pre/post comparison "
          "below will show zero change. These numbers are NOT a real RL result.")
    post_rl_default = pre_rl_default.clone()
    post_rl_is_placeholder = True

print(f"Post-RL default vector shape: {post_rl_default.shape}")

## Compare Default Vector Positions

In [None]:
# Fit the persona-space PCA on the pre-RL role vectors at TARGET_LAYER. This basis is
# then held fixed: both default vectors are projected into it in the next cell.
role_labels = [label for label in pre_rl_role_vectors]
role_vectors_at_layer = torch.stack(
    [pre_rl_role_vectors[label][TARGET_LAYER] for label in role_labels]
).float()

scaler = MeanScaler()
# compute_pca returns a scaler as its last value; that returned one is what the later
# projections use. layer=None, presumably because the input is already a single layer.
pca_transformed, variance_explained, n_components, pca, scaler = compute_pca(
    role_vectors_at_layer, layer=None, scaler=scaler
)

top3_share = sum(variance_explained[:3]) * 100
print(f"PCA fitted on {len(role_labels)} role vectors")
print(f"Top 3 PCs explain {top3_share:.1f}% variance")

In [None]:
# Project both default vectors into PCA space
def project_to_pca(vector, target_layer, scaler, pca):
    """Map one layer of a per-layer activation vector into the fitted PCA space.

    Args:
        vector: Tensor indexed by layer first; ``vector[target_layer]`` must have the
            same width as the data the PCA was fitted on.
        target_layer: Layer index to take from ``vector``.
        scaler: Fitted scaler exposing ``transform``; applied before the PCA.
        pca: Fitted PCA exposing ``transform``.

    Returns:
        The first (only) row of ``pca.transform`` for that layer.
    """
    # detach()/cpu() make this work for tensors that live on a GPU or carry autograd
    # history (e.g. a freshly extracted vector); Tensor.numpy() rejects both.
    vec = vector[target_layer].detach().float().cpu().numpy().reshape(1, -1)
    vec_scaled = scaler.transform(vec)
    return pca.transform(vec_scaled)[0]

def _fmt_top3(coords):
    """Render the first three PCA coordinates as a signed, bracketed list."""
    return f"[{coords[0]:+.4f}, {coords[1]:+.4f}, {coords[2]:+.4f}]"


# The same fitted scaler/PCA maps both vectors, so the two positions share one frame.
pre_rl_pca, post_rl_pca = (
    project_to_pca(v, TARGET_LAYER, scaler, pca) for v in (pre_rl_default, post_rl_default)
)

print("\nDefault vector PCA projections (top 3 PCs):")
print(f"  Pre-RL:  {_fmt_top3(pre_rl_pca)}")
print(f"  Post-RL: {_fmt_top3(post_rl_pca)}")
print(f"\nEuclidean distance: {np.linalg.norm(post_rl_pca[:3] - pre_rl_pca[:3]):.4f}")

## Visualize Pre vs Post RL in Persona Space

In [None]:
def plot_pre_post_comparison(pca_transformed, role_labels, pre_rl_pca, post_rl_pca,
                              figsize=(14, 5)):
    """Draw pre- and post-RL default vectors over the role-vector cloud in persona space.

    One panel per pair among the first three principal components. Role vectors form a
    light-gray background, the two defaults are drawn as stars, and a green arrow points
    from the pre-RL to the post-RL position.

    Args:
        pca_transformed: 2-D array of role vectors in PCA coordinates (one row per
            role); needs at least 3 columns.
        role_labels: Accepted for interface compatibility; not used in the plot.
        pre_rl_pca: PCA coordinates of the pre-RL default vector.
        post_rl_pca: PCA coordinates of the post-RL default vector.
        figsize: Matplotlib figure size.

    Returns:
        The created matplotlib Figure.
    """
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    star_style = dict(s=200, marker='*', edgecolors='black', linewidth=1, zorder=5)
    panels = [(0, 1, 'PC1 vs PC2'), (0, 2, 'PC1 vs PC3'), (1, 2, 'PC2 vs PC3')]

    for ax, (i, j, title) in zip(axes, panels):
        ax.scatter(pca_transformed[:, i], pca_transformed[:, j],
                   c='lightgray', s=30, alpha=0.5, label='Role vectors')

        for point, color, name in ((pre_rl_pca, 'blue', 'Pre-RL Default'),
                                   (post_rl_pca, 'red', 'Post-RL Default')):
            ax.scatter(point[i], point[j], c=color, label=name, **star_style)

        # Empty-text annotation, used only to draw the pre -> post arrow.
        ax.annotate('', xy=(post_rl_pca[i], post_rl_pca[j]),
                    xytext=(pre_rl_pca[i], pre_rl_pca[j]),
                    arrowprops=dict(arrowstyle='->', color='green', lw=2))

        ax.set_xlabel(f'PC{i + 1}')
        ax.set_ylabel(f'PC{j + 1}')
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    axes[0].legend(loc='upper left')
    fig.suptitle('Default Persona: Pre-RL vs Post-RL', fontsize=14, fontweight='bold')
    fig.tight_layout()
    return fig

fig = plot_pre_post_comparison(
    pca_transformed=pca_transformed,
    role_labels=role_labels,
    pre_rl_pca=pre_rl_pca,
    post_rl_pca=post_rl_pca,
)
plt.show()

## Analyze Projection onto Assistant Axis

In [None]:
# Score both default vectors on the *pre-RL* Assistant Axis: the axis is held fixed,
# so any change comes from the model's default position alone.
pre_rl_proj, post_rl_proj = (
    project(v, pre_rl_axis, layer=TARGET_LAYER) for v in (pre_rl_default, post_rl_default)
)

print("Projection onto Pre-RL Assistant Axis:")
print(f"  Pre-RL default:  {pre_rl_proj:.4f}")
print(f"  Post-RL default: {post_rl_proj:.4f}")
print(f"  Change: {post_rl_proj - pre_rl_proj:+.4f}")

# This notebook reads a larger projection as "more Assistant-like".
if post_rl_proj > pre_rl_proj:
    verdict = "\n→ Post-RL model is MORE Assistant-like on the original axis"
elif post_rl_proj < pre_rl_proj:
    verdict = "\n→ Post-RL model is LESS Assistant-like on the original axis"
else:
    verdict = "\n→ No change in Assistant-likeness"
print(verdict)

## Compute Cosine Similarity Between Axes

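Ultimately we want to know how well an axis recomputed from the post-RL model aligns with the original one. That requires post-RL role vectors (see Next Steps). For now, the cell below runs a pre-RL sanity check instead: the cosine similarity between the Assistant Axis and PC1 of the role-vector PCA.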

In [None]:
# A post-RL axis needs post-RL role vectors, which are not extracted yet. So this cell
# runs a pre-RL sanity check: how well does PC1 of the role-vector PCA line up with the
# pre-RL Assistant Axis?
#
# PCA component signs are arbitrary (the decomposition may return v or -v), so alignment
# is judged by |cosine|. The sign only says which end of PC1 the Assistant sits at.
# NOTE(review): assumes MeanScaler only centres the data; if it also rescales
# dimensions, components_ live in a rescaled basis and this cosine is only approximate.
pc1_direction = pca.components_[0]
pc1_direction = pc1_direction / np.linalg.norm(pc1_direction)

pre_rl_axis_at_layer = pre_rl_axis[TARGET_LAYER].float().numpy()
pre_rl_axis_norm = pre_rl_axis_at_layer / np.linalg.norm(pre_rl_axis_at_layer)

axis_pc1_cosine = float(np.dot(pre_rl_axis_norm, pc1_direction))
axis_pc1_alignment = abs(axis_pc1_cosine)
print(f"Cosine similarity between Assistant Axis and PC1: {axis_pc1_cosine:.4f}")
print(f"|cosine| (sign-independent alignment):            {axis_pc1_alignment:.4f}")
print("(PC signs are arbitrary; |cosine| near 1 means PC1 approximates the Assistant direction)")

## Summary Statistics

In [None]:
# Headline metrics comparing the two default vectors directly in activation space at
# TARGET_LAYER. The last two printed lines reuse the PCA and axis results from above.
pre_vec, post_vec = (d[TARGET_LAYER].float() for d in (pre_rl_default, post_rl_default))

# Direction: cosine similarity (a leading batch dim is added for F.cosine_similarity).
default_cosine = F.cosine_similarity(pre_vec[None, :], post_vec[None, :]).item()

# Magnitude: L2 distance between the vectors and percentage change in their length.
l2_dist = torch.norm(post_vec - pre_vec).item()
pre_norm, post_norm = (torch.norm(v).item() for v in (pre_vec, post_vec))
norm_change = (post_norm - pre_norm) / pre_norm * 100

banner = "=" * 50
print(banner)
print("Summary: Pre-RL vs Post-RL Default Vector")
print(banner)
print(f"Cosine similarity:     {default_cosine:.4f}")
print(f"L2 distance:           {l2_dist:.4f}")
print(f"Pre-RL norm:           {pre_norm:.4f}")
print(f"Post-RL norm:          {post_norm:.4f}")
print(f"Norm change:           {norm_change:+.2f}%")
print(f"\nPCA space distance (3D): {np.linalg.norm(post_rl_pca[:3] - pre_rl_pca[:3]):.4f}")
print(f"Axis projection change:  {post_rl_proj - pre_rl_proj:+.4f}")

## Interpretation Guide

| Metric | Interpretation |
|--------|----------------|
| **Cosine similarity** | High (>0.95) = RL largely preserved direction; 0.90–0.95 = moderate shift; Low (<0.90) = significant shift (rough heuristic thresholds) |
| **L2 distance** | Magnitude of change in activation space |
| **Axis projection change** | Positive = more Assistant-like; Negative = drifted away |
| **PCA distance** | Movement in persona space (top 3 PCs) |

### Key Questions Answered:

1. **Did RL change the model's default persona?** → Check cosine similarity and PCA distance
2. **Did it become more or less Assistant-like?** → Check axis projection change
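3. **Is the original axis still valid?** → Not fully answered here. A high default-vector cosine similarity suggests the representation did not shift drastically, but confirming the axis direction requires recomputing it from post-RL role vectors (see Next Steps).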

## Next Steps

1. **Full role vector extraction**: Extract all 275 role vectors from post-RL model to compute a new axis
2. **Compare axes**: Measure alignment between pre-RL and post-RL Assistant Axes
3. **Track specific roles**: See if certain personas become more/less accessible after RL