# Residual Stream Analysis

This notebook explores how much of a transformer's residual stream lies in the subspace used for next-token prediction, and how much is left over for persistent memory or latent reasoning.

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from residual_analysis import ResidualStreamAnalyzer

  from .autonotebook import tqdm as notebook_tqdm


## Initialize the Analyzer

First, let's create our analyzer with the model of choice.

In [2]:
# Checkpoint to analyze. Any model supported by TransformerLens can be used here,
# e.g. "gpt2-small", "gpt2-medium", "gpt2-large", "gpt2-xl", "pythia-70m", etc.
# Smaller models load faster but might not exhibit reasoning patterns as strongly.
model_name = "gpt2-small"

# Build the analyzer for that checkpoint and report the layer count and hidden size.
analyzer = ResidualStreamAnalyzer(model_name)
print(f"Loaded {model_name} with {analyzer.model.cfg.n_layers} layers and {analyzer.d_model} hidden dimensions")

Using MPS (Metal Performance Shaders) for Apple Silicon
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  mps
Loaded gpt2-small with 12 layers and 768 hidden dimensions


## Examining the Unembedding Matrix

Let's look at the singular value decomposition of the unembedding matrix to understand the token prediction subspace.

In [3]:
# Ask the analyzer for the token-prediction subspace, requesting 95% explained
# variance. The call returns three values, unpacked here as P, S and
# cumulative_var (presumably the projection matrix, the singular values and the
# cumulative explained variance -- the dict keys below name them that way).
P, S, cumulative_var = analyzer.compute_token_subspace(variance_explained=0.95)

# Bundle the outputs under the keys handed to the plotting helper.
results = {
    "singular_values": S,
    "cumulative_variance": cumulative_var,
    "projection_matrix": P,
}

# Plot the singular values (the stray debug print that used to sit here is gone).
analyzer.visualize_singular_values(results)

: 

## Experiment 1: Comparing Factual Recall vs. Reasoning

In this experiment, we compare how much of the residual stream is used for token prediction in a factual recall task versus a reasoning task.

In [None]:
# Two short prompts to contrast: one answered by recalling a fact,
# one that needs a small amount of algebra.
factual_prompt = "The capital of France is"
reasoning_prompt = "If 5x + 3 = 18, then x equals"

# Echo both prompts, one per line.
print(
    f"Factual prompt: '{factual_prompt}'",
    f"Reasoning prompt: '{reasoning_prompt}'",
    sep="\n",
)

In [None]:
# Run the analyzer on the factual-recall prompt.
factual_results = analyzer.run_with_hooks(factual_prompt)

# Show the prompt next to the first three entries of the returned token strings.
factual_top_tokens = factual_results["token_strs"][:3]
print("Factual prompt:", factual_prompt)
print("Predicted next tokens:", " ".join(factual_top_tokens))

# Plot the projection percentages recorded for this prompt.
analyzer.visualize_projection_percentages(
    factual_results,
    title="Projection Percentages - Factual Recall",
)

In [None]:
# Run the analyzer on the reasoning prompt.
reasoning_results = analyzer.run_with_hooks(reasoning_prompt)

# Show the prompt next to the first three entries of the returned token strings.
reasoning_top_tokens = reasoning_results["token_strs"][:3]
print("Reasoning prompt:", reasoning_prompt)
print("Predicted next tokens:", " ".join(reasoning_top_tokens))

# Plot the projection percentages recorded for this prompt.
analyzer.visualize_projection_percentages(
    reasoning_results,
    title="Projection Percentages - Reasoning Task",
)

In [None]:
# Compare how token-focused the final layer is for the two prompts.
final_layer = "ln_final.hook_normalized"

factual_percentages = factual_results["projection_percentages"][final_layer]
reasoning_percentages = reasoning_results["projection_percentages"][final_layer]

# Index 0 selects the first entry of each result. The reasoning-chain cell plots
# that same entry against token position, so it presumably holds one value per
# token -- TODO confirm against ResidualStreamAnalyzer. `.item()` only accepts a
# single-element tensor, so flatten the entry and read its last element (assumed
# to be the position that predicts the next token). This also works unchanged if
# the entry is already a scalar.
factual_pct = factual_percentages[0].reshape(-1)[-1].item()
reasoning_pct = reasoning_percentages[0].reshape(-1)[-1].item()

print("\nFinal Layer Projection Percentages:")
print(f"Factual: {factual_pct:.2f}%")
print(f"Reasoning: {reasoning_pct:.2f}%")

# Report the gap as a magnitude plus a direction word; printing the signed value
# next to "less" would read as a double negative ("-1.50% less").
diff = factual_pct - reasoning_pct
print(f"Difference: {abs(diff):.2f}% {'more' if diff > 0 else 'less'} token-focused in factual task")

## Experiment 2: Analyzing a Multi-step Reasoning Chain

In this experiment, we examine how the projection percentages evolve throughout a structured reasoning process.

In [None]:
# Worked example used as the multi-step prompt: a percentage-discount problem
# solved in two explicit steps and a conclusion. The text is assembled line by
# line; the empty first and last entries reproduce the prompt's leading and
# trailing newline.
reasoning_chain = "\n".join([
    "",
    "Problem: If a shirt costs $15 and is on sale for 20% off, what is the sale price?",
    "",
    "Step 1: Calculate the discount amount.",
    "Discount = Original price × Discount percentage",
    "Discount = $15 × 0.20",
    "Discount = $3",
    "",
    "Step 2: Subtract the discount from the original price.",
    "Sale price = Original price - Discount",
    "Sale price = $15 - $3",
    "Sale price = $12",
    "",
    "Therefore, the sale price of the shirt is $12.",
    "",
])

print(reasoning_chain)

In [None]:
# Feed the whole worked example through the analyzer.
chain_results = analyzer.run_with_hooks(reasoning_chain)

# Plot the projection percentages recorded for the reasoning chain.
chain_plot_title = "Projection Percentages Throughout Reasoning Chain"
analyzer.visualize_projection_percentages(chain_results, title=chain_plot_title)

In [None]:
# Decode every token of the prompt on its own so that plot positions can be
# tied back to the text.
tokens = chain_results["tokens"][0]
token_strs = [analyzer.model.tokenizer.decode([t.item()]) for t in tokens]

# Projection percentages at the final layer, first entry, as a NumPy array
# (plotted below against token position).
final_layer = "ln_final.hook_normalized"
percentages = chain_results["projection_percentages"][final_layer][0].cpu().numpy()

# Map character offsets to token indices exactly instead of scaling by an
# average characters-per-token ratio. Offsets are measured in the concatenation
# of the decoded tokens, so any extra token present in `tokens` but not in the
# prompt text is accounted for. Token i covers the characters
# [token_ends[i-1], token_ends[i]).
decoded_text = "".join(token_strs)
token_ends = np.cumsum([len(s) for s in token_strs])

fig, ax = plt.subplots(figsize=(15, 7))
ax.plot(percentages, linewidth=2)
ax.set_title("Token Space Projection % Throughout Reasoning Chain", fontsize=14)
ax.set_xlabel("Token Position", fontsize=12)
ax.set_ylabel("% in Token Space", fontsize=12)
ax.grid(True)

# Mark the token at which each section of the worked example begins.
label_height = np.max(percentages) * 0.9
section_markers = [
    ("Problem", "Problem:"),
    ("Step 1", "Step 1:"),
    ("Step 2", "Step 2:"),
    ("Conclusion", "Therefore"),
]
for name, marker in section_markers:
    char_pos = decoded_text.find(marker)
    if char_pos < 0:
        # Marker text not present in the decoded prompt: draw nothing rather
        # than a line at a meaningless position.
        continue
    # Number of tokens that end at or before char_pos == index of the token
    # containing the marker's first character.
    token_pos = int(np.searchsorted(token_ends, char_pos, side="right"))
    ax.axvline(x=token_pos, color='r', linestyle='--', alpha=0.7)
    ax.text(token_pos + 1, label_height, name, rotation=0, fontsize=10)

fig.tight_layout()
plt.show()

## Experiment 3: Custom Analysis

Try your own prompts and analyses below.

In [None]:
# Replace the placeholder text with the prompt you want to inspect.
custom_prompt = "Write your prompt here"

# Analyze it with the same analyzer used above.
custom_results = analyzer.run_with_hooks(custom_prompt)

# Plot the projection percentages recorded for the custom prompt.
analyzer.visualize_projection_percentages(
    custom_results,
    title="Custom Prompt Analysis",
)

## Advanced Analysis: Layer-by-Layer Comparisons

Let's compare how much of the final layer's output (`ln_final.hook_normalized`) lies in the token-prediction subspace for prompts of different task types.

In [None]:
# One short prompt per task type to compare.
prompts = {
    "Factual": "The capital of France is",
    "Simple Math": "2 + 2 equals",
    "Complex Math": "If 5x + 3 = 18, then x equals",
    "Logic": "All humans are mortal. Socrates is human. Therefore, Socrates is"
}

# Run the analysis once per prompt, keyed by task name.
results_by_prompt = {
    task: analyzer.run_with_hooks(prompt) for task, prompt in prompts.items()
}

# Layer whose projection percentage is compared across tasks.
final_layer = "ln_final.hook_normalized"

fig, ax = plt.subplots(figsize=(10, 6))

# Draw each bar and its value label in a single pass. Loop-local names are kept
# distinct from the notebook-level `results` and `percentages` variables so this
# cell does not overwrite them.
for i, (task, task_results) in enumerate(results_by_prompt.items()):
    # Index 0 selects the first entry; the reasoning-chain cell plots that same
    # entry against token position, so it presumably holds one value per token
    # -- TODO confirm against ResidualStreamAnalyzer. `.item()` needs a
    # single-element tensor, so flatten and read the last element (assumed to be
    # the position that predicts the next token); a scalar entry works too.
    final_pct = task_results["projection_percentages"][final_layer][0].reshape(-1)[-1].item()
    # One bar() call per task keeps a distinct color per bar.
    ax.bar(task, final_pct)
    # Bars sit at x = 0, 1, 2, ... in insertion order on the categorical axis.
    ax.text(i, final_pct + 1, f"{final_pct:.1f}%", ha='center')

ax.set_title("Token Space Projection % by Task Type")
ax.set_ylabel("% in Token Space")
ax.set_ylim(0, 100)
ax.grid(axis='y')

fig.tight_layout()
plt.show()