# PCA Analysis During RL Training

This notebook tracks how the model's persona representation changes across RL training checkpoints.

**Goal**: Understand whether RL training causes persona drift by projecting checkpoint activations onto the original persona PCA space.

In [1]:
import sys
sys.path.insert(0, '../..')

import torch
import numpy as np
from pathlib import Path
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from tqdm.auto import tqdm

from assistant_axis import load_axis, compute_pca, MeanScaler
from assistant_axis.internals import ProbingModel, ConversationEncoder, ActivationExtractor

disable_progress_bars()

## Configuration

In [2]:
# --- Model under study ---
BASE_MODEL = "google/gemma-2-27b-it"
MODEL_NAME = "gemma-2-27b"  # Sub-folder of REPO_ID holding this model's pre-computed vectors
TARGET_LAYER = 22  # Layer index at which vectors are compared

# --- Pre-computed baseline vectors (HuggingFace dataset repo) ---
REPO_ID = "lu-christina/assistant-axis-vectors"

# --- RL checkpoints: update these after running RL training ---
CHECKPOINTS_DIR = Path("../checkpoints")
CHECKPOINT_STEPS = [100, 200, 500, 1000, 2000]  # Training steps to analyze

# --- Where activations extracted from checkpoints are written ---
OUTPUT_DIR = Path("../outputs/during_rl")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

## Load Baseline Role Vectors (Pre-RL)

Load the original role vectors computed from the pre-RL Gemma 2 model to establish the PCA space.

In [3]:
print(f"Loading baseline vectors from HuggingFace: {REPO_ID}")

# Fetch only this model's role vectors and its default vector from the dataset repo
local_dir = snapshot_download(
    repo_id=REPO_ID,
    repo_type="dataset",
    allow_patterns=[f"{MODEL_NAME}/role_vectors/*.pt", f"{MODEL_NAME}/default_vector.pt"]
)

# SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects from the downloaded files. Only keep it if REPO_ID is trusted; if the files
# hold plain tensors, weights_only=True should be sufficient — TODO confirm.

# Load role vectors, keyed by file stem.
# glob() yields files in filesystem order, so sort to keep the role order (and thus
# role_labels / PCA row order downstream) reproducible across machines.
role_vector_dir = Path(local_dir, MODEL_NAME, "role_vectors")
role_vectors = {p.stem: torch.load(p, map_location="cpu", weights_only=False)
                for p in sorted(role_vector_dir.glob("*.pt"))}
if not role_vectors:
    # Fail here with a clear message instead of later when stacking an empty list
    raise FileNotFoundError(f"No role vectors (*.pt) found in {role_vector_dir}")
print(f"Loaded {len(role_vectors)} role vectors")

# Load default vector
default_vector = torch.load(Path(local_dir, MODEL_NAME, "default_vector.pt"), map_location="cpu", weights_only=False)
print(f"Default vector shape: {default_vector.shape}")

Loading baseline vectors from HuggingFace: lu-christina/assistant-axis-vectors
Loaded 275 role vectors
Default vector shape: torch.Size([46, 4608])


## Fit PCA on Baseline Vectors

Create the persona space from the original model's role vectors.

In [4]:
# Take the TARGET_LAYER slice of every role vector; labels follow the same order
role_labels = list(role_vectors)
role_vectors_at_layer = torch.stack(
    [role_vectors[name][TARGET_LAYER] for name in role_labels]
).float()

# Fit the PCA that defines persona space. The returned `pca` and `scaler` are
# re-used later to project checkpoint vectors into this same space.
pca_transformed, variance_explained, n_components, pca, scaler = compute_pca(
    role_vectors_at_layer, layer=None, scaler=MeanScaler()
)

top3_percent = sum(variance_explained[:3]) * 100
print(f"Fitted PCA with {len(variance_explained)} components")
print(f"Top 3 PCs explain {top3_percent:.1f}% of variance")

PCA fitted with 275 components
Cumulative variance for first 5 components: [0.4880164  0.5860041  0.65581596 0.7108976  0.74415094]

PCA Analysis Results:
Elbow point at component: 2
Dimensions for 70% variance: 4
Dimensions for 80% variance: 8
Dimensions for 90% variance: 18
Dimensions for 95% variance: 36
Fitted PCA with 275 components
Top 3 PCs explain 65.6% of variance


## Extract Default Vector from Each Checkpoint

For each RL checkpoint, extract the mean activations on a small set of neutral Assistant prompts to see how the "Assistant" position moves in persona space.

In [5]:
# Neutral "Assistant" prompts used to probe each checkpoint's default persona.
# Every conversation is one system message followed by one user question
# (there is no assistant turn).
# NOTE(review): Gemma 2's stock chat template rejects the "system" role — confirm
# that ConversationEncoder copes with these messages for BASE_MODEL.
_DEFAULT_PROMPT_PAIRS = [
    ("You are a helpful AI assistant.", "What is the capital of France?"),
    ("You are an AI assistant.", "Explain photosynthesis briefly."),
    ("You are a helpful assistant.", "What are prime numbers?"),
    ("You are an AI assistant here to help.", "How does gravity work?"),
    ("You are a helpful AI.", "What is machine learning?"),
]

DEFAULT_CONVERSATIONS = [
    [{"role": "system", "content": system_text},
     {"role": "user", "content": user_text}]
    for system_text, user_text in _DEFAULT_PROMPT_PAIRS
]

In [6]:
def extract_checkpoint_default_vector(checkpoint_path, conversations, layers=None):
    """Extract mean default activations from a checkpoint.

    Loads the checkpoint, runs every conversation through the activation
    extractor, averages over tokens and then over conversations, and always
    releases the model afterwards (even if extraction fails).

    Args:
        checkpoint_path: Path to the model checkpoint
        conversations: Non-empty list of conversations (each is a list of
            {"role", "content"} dicts)
        layers: List of layer indices to extract, or None for all layers.
            A single int is treated as a one-element list.

    Returns:
        CPU tensor of shape (num_layers, hidden_size). Rows follow the
        requested layers in order, so when a subset of layers is requested
        the row index is NOT the model layer index.

    Raises:
        ValueError: If `conversations` is empty.
    """
    # Check before loading the checkpoint: there is nothing to average over
    if not conversations:
        raise ValueError("conversations must contain at least one conversation")

    # Keep the leading layer axis that the token-mean below relies on.
    # NOTE(review): assumes the extractor returns (num_layers, num_tokens, hidden_size)
    # for a list or None — confirm against ActivationExtractor.full_conversation.
    if isinstance(layers, int):
        layers = [layers]

    print(f"Loading checkpoint: {checkpoint_path}")
    pm = ProbingModel(str(checkpoint_path))
    try:
        encoder = ConversationEncoder(pm)
        extractor = ActivationExtractor(pm, encoder)

        # One (num_layers, hidden_size) tensor per conversation
        per_conversation = []
        for conv in conversations:
            # Full activations: (num_layers, num_tokens, hidden_size)
            acts = extractor.full_conversation(conv, layer=layers)
            # Mean over tokens
            per_conversation.append(acts.mean(dim=1))

        # Mean across conversations: (num_layers, hidden_size)
        mean_activation = torch.stack(per_conversation).mean(dim=0)

        # Detach and move to CPU so the result can be saved and converted to numpy
        # regardless of which device the extractor returned it on
        return mean_activation.detach().cpu()
    finally:
        # Always release the model, including when extraction raises
        pm.close()
        torch.cuda.empty_cache()

In [7]:
# Extract default vectors from each checkpoint.
# Vectors extracted on an earlier run are cached in OUTPUT_DIR and re-used.

checkpoint_vectors = {}

# Add baseline (step 0)
# NOTE(review): the baseline was pre-computed elsewhere, while the checkpoint vectors
# below come from DEFAULT_CONVERSATIONS — confirm both use a comparable extraction
# recipe before interpreting small shifts as drift.
checkpoint_vectors[0] = default_vector

for step in tqdm(CHECKPOINT_STEPS, desc="Processing checkpoints"):
    checkpoint_path = CHECKPOINTS_DIR / f"checkpoint-{step}"
    output_path = OUTPUT_DIR / f"default_vector_step{step}.pt"

    # Check the cache first: a saved vector remains usable after the checkpoint
    # directory itself has been removed
    if output_path.exists():
        print(f"Loading cached: {output_path}")
        checkpoint_vectors[step] = torch.load(output_path, map_location="cpu", weights_only=False)
        continue

    if not checkpoint_path.exists():
        print(f"Checkpoint not found: {checkpoint_path}, skipping...")
        continue

    # Extract ALL layers (layers=None): the projection step indexes the result with
    # vec[TARGET_LAYER], exactly like the all-layer baseline default_vector, so the
    # layer axis must be indexed by model layer.
    vec = extract_checkpoint_default_vector(checkpoint_path, DEFAULT_CONVERSATIONS, layers=None)
    # Cached vectors are loaded onto CPU; keep freshly extracted ones there too
    vec = vec.detach().cpu()
    torch.save(vec, output_path)
    checkpoint_vectors[step] = vec

print(f"\nLoaded {len(checkpoint_vectors)} checkpoint vectors")

Processing checkpoints:   0%|          | 0/5 [00:00<?, ?it/s]

Checkpoint not found: ..\checkpoints\checkpoint-100, skipping...
Checkpoint not found: ..\checkpoints\checkpoint-200, skipping...
Checkpoint not found: ..\checkpoints\checkpoint-500, skipping...
Checkpoint not found: ..\checkpoints\checkpoint-1000, skipping...
Checkpoint not found: ..\checkpoints\checkpoint-2000, skipping...

Loaded 1 checkpoint vectors


## Project Checkpoints into Persona PCA Space

Transform each checkpoint's default vector into the PCA space fitted on the baseline role vectors.

In [8]:
# Map every checkpoint's default vector into the baseline PCA space,
# keeping only the first three principal components
checkpoint_projections = {}

for step, vec in checkpoint_vectors.items():
    # Single row (1, hidden_size) at the target layer, as float numpy
    layer_row = vec[TARGET_LAYER].float().numpy().reshape(1, -1)
    # Same centering and rotation that were fitted on the baseline role vectors
    projected = pca.transform(scaler.transform(layer_row))
    checkpoint_projections[step] = projected[0, :3]

print("Checkpoint projections (PC1, PC2, PC3):")
for step in sorted(checkpoint_projections.keys()):
    pcs = checkpoint_projections[step]
    print(f"  Step {step:5d}: [{pcs[0]:+.4f}, {pcs[1]:+.4f}, {pcs[2]:+.4f}]")

Checkpoint projections (PC1, PC2, PC3):
  Step     0: [+674.2837, -21.2927, +65.7775]


## Visualize Persona Drift During Training

In [9]:
def plot_persona_trajectory(checkpoint_projections, role_pca_transformed, role_labels,
                            figsize=(12, 5)):
    """Draw the default vector's path through persona space, one panel per PC pair.

    Args:
        checkpoint_projections: Mapping of training step -> PCA coordinates;
            components 0, 1 and 2 (PC1-PC3) are plotted.
        role_pca_transformed: 2-D array of role vectors in the same PCA space,
            drawn as a light gray background cloud.
        role_labels: Role names; accepted but not used by this function.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        The matplotlib figure containing the two panels.
    """
    # (x component, y component, panel title) for the left and right panel
    panels = [(0, 1, 'PC1 vs PC2'), (0, 2, 'PC1 vs PC3')]

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # Checkpoints in training order; the first one is drawn as the baseline
    ordered_steps = sorted(checkpoint_projections.keys())
    n_steps = len(ordered_steps)
    path = np.array([checkpoint_projections[s] for s in ordered_steps])

    # Colors run along viridis with training progress (mid-scale for a single point)
    viridis = plt.cm.viridis
    if n_steps > 1:
        step_colors = [viridis(idx / (n_steps - 1)) for idx in range(n_steps)]
    else:
        step_colors = [viridis(0.5) for _ in range(n_steps)]

    for ax, (x_pc, y_pc, panel_title) in zip(axes, panels):
        # Background: all role vectors
        ax.scatter(role_pca_transformed[:, x_pc], role_pca_transformed[:, y_pc],
                   c='lightgray', s=30, alpha=0.5, label='Role vectors')

        # Dashed line connecting consecutive checkpoints
        ax.plot(path[:, x_pc], path[:, y_pc], 'k--', alpha=0.3, linewidth=1)

        # One marker per checkpoint; the first is a larger star
        for idx, step in enumerate(ordered_steps):
            is_first = idx == 0
            ax.scatter(path[idx, x_pc], path[idx, y_pc],
                       c=[step_colors[idx]],
                       s=200 if is_first else 100,
                       marker='*' if is_first else 'o',
                       edgecolors='black',
                       linewidth=1, zorder=5, label=f'Step {step}')

        ax.set_xlabel(f'PC{x_pc + 1}')
        ax.set_ylabel(f'PC{y_pc + 1}')
        ax.set_title(panel_title)
        ax.grid(True, alpha=0.3)

    # Shared legend built from the left panel's entries, placed right of the panels
    legend_handles, legend_labels = axes[0].get_legend_handles_labels()
    fig.legend(legend_handles, legend_labels, loc='center right', bbox_to_anchor=(1.15, 0.5))

    plt.suptitle('Default Persona Trajectory During RL Training', fontsize=14, fontweight='bold')
    plt.tight_layout()
    return fig


# A trajectory needs more than the single baseline point
if len(checkpoint_projections) <= 1:
    print("Need at least 2 checkpoints to visualize trajectory.")
    print("Run RL training first, then re-run the extraction cells above.")
else:
    fig = plot_persona_trajectory(checkpoint_projections, pca_transformed, role_labels)
    plt.show()

Need at least 2 checkpoints to visualize trajectory.
Run RL training first, then re-run the extraction cells above.


## Track PC1 (Assistant Axis Proxy) Over Training

In [10]:
def plot_pc1_over_training(checkpoint_projections, figsize=(10, 4)):
    """Plot the first PCA component of each checkpoint against its training step.

    Args:
        checkpoint_projections: Mapping of training step -> PCA coordinates;
            only component 0 (PC1) is read.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        The matplotlib figure.
    """
    ordered_steps = sorted(checkpoint_projections.keys())
    pc1_series = [checkpoint_projections[step][0] for step in ordered_steps]

    fig, ax = plt.subplots(figsize=figsize)

    # PC1 curve plus a horizontal reference at the earliest step's value
    ax.plot(ordered_steps, pc1_series, 'o-', markersize=8, linewidth=2, color='#1a5276')
    ax.axhline(y=pc1_series[0], color='gray', linestyle='--', alpha=0.5, label='Baseline')

    ax.set_xlabel('Training Step')
    ax.set_ylabel('PC1 Projection')
    ax.set_title('PC1 (Assistant-like Direction) During RL Training')
    ax.grid(True, alpha=0.3)
    ax.legend()

    # Direction hints in the top-left and bottom-left corners of the axes.
    # NOTE(review): the sign of a PCA component is arbitrary — confirm that larger
    # PC1 really means "more Assistant-like" for the fitted PCA.
    corner_notes = [
        ('More Assistant-like', (0.02, 0.98), '#457b9d', 'top'),
        ('More Role-playing', (0.02, 0.02), '#e63946', 'bottom'),
    ]
    for text, position, text_color, vertical_align in corner_notes:
        ax.annotate(text, xy=position, xycoords='axes fraction',
                    fontsize=10, color=text_color, ha='left', va=vertical_align)

    plt.tight_layout()
    return fig


# Nothing to plot over time with only the baseline point
if len(checkpoint_projections) <= 1:
    print("Need checkpoints to plot. Run RL training first.")
else:
    fig = plot_pc1_over_training(checkpoint_projections)
    plt.show()

Need checkpoints to plot. Run RL training first.


## Compute Distance from Baseline

In [11]:
# Report how far each checkpoint sits from the step-0 baseline in PC1-PC3 space
if len(checkpoint_projections) > 1:
    baseline = checkpoint_projections[0]
    later_steps = [s for s in sorted(checkpoint_projections.keys()) if s != 0]

    print("Euclidean distance from baseline (in top 3 PCs):")
    for step in later_steps:
        dist = np.linalg.norm(checkpoint_projections[step] - baseline)
        print(f"  Step {step:5d}: {dist:.4f}")

## Next Steps

1. **Run RL training** with checkpoint saving enabled (see `../README.md`)
2. **Re-run this notebook** to extract activations and visualize persona drift
3. **Compare with `pca_after_rl.ipynb`** for detailed pre/post analysis