# LLM Interpretability Dashboard - GPT-2 Example

This notebook demonstrates the core features of the LLM Interpretability Dashboard using GPT-2 as an example model. We'll explore:

1. **Attribution Methods** - Understanding which tokens are important
2. **Activation Patching** - Testing causal relationships
3. **Attention Analysis** - Visualizing attention patterns
4. **Mechanistic Analysis** - Logit lens, attention head ablation, and neuron analysis
5. **Visualization** - Creating publication-ready plots

## Setup

First, let's import the necessary libraries and load a GPT-2 model.

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Import our interpretability library
import sys
sys.path.append('../')

from interpret_llm import load_model_and_tokenizer, get_config
from interpret_llm.attribution import GradientAttributor, AttentionAttributor, AttributionVisualizer
from interpret_llm.patching import ActivationPatcher, CausalTracer
from interpret_llm.circuits import LogitLens, NeuronAnalyzer, AttentionHeadAblator
from interpret_llm.visualization import TextOverlayVisualizer, HeatmapVisualizer

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [None]:
# Load GPT-2 model and tokenizer
model_name = "gpt2"
model, tokenizer = load_model_and_tokenizer(model_name, device=device)

print(f"Loaded model: {model_name}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Vocab size: {tokenizer.vocab_size}")

## 1. Attribution Analysis

We start by measuring which input tokens matter most for the model's prediction, first with gradient-based methods and then with attention rollout.

In [None]:
# Example text for analysis
text = "The Eiffel Tower is located in Paris, the capital of France."

print(f"Analyzing text: '{text}'")
print(f"Tokens: {tokenizer.tokenize(text)}")

### Gradient-based Attribution

In [None]:
# Initialize gradient attributor
grad_attributor = GradientAttributor(model, tokenizer, device)

# Compute different attribution methods
attribution_results = grad_attributor.compute_all_attributions(
    text,
    methods=["vanilla_gradient", "gradient_x_input", "integrated_gradients"]
)

# Display results
for method_name, result in attribution_results.items():
    print(f"\n{method_name.upper()}:")
    print(f"Target token: {result.metadata.get('target_token_id', 'N/A')}")
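    # Format the score only when it is present; ':.4f' would raise on an 'N/A' fallback string
    target_score = result.metadata.get('target_score')
    print(f"Target score: {target_score:.4f}" if target_score is not None else "Target score: N/A")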
    
    # Show top attributed tokens
    top_indices = torch.argsort(result.attributions, descending=True)[:5]
    print("Top attributed tokens:")
    for i, idx in enumerate(top_indices):
        token = result.tokens[idx]
        score = result.attributions[idx].item()
        print(f"  {i+1}. '{token}': {score:.4f}")
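
To make the three method names above concrete, here is a minimal, dependency-free sketch on a toy scalar function rather than GPT-2. It is not the `interpret_llm` implementation: the function `f`, the finite-difference `grad` helper, and the all-zeros baseline are assumptions chosen for illustration.

```python
import math

def f(x):
    """Toy differentiable 'model': a logistic score over three input features."""
    z = 2.0 * x[0] - 1.0 * x[1] + 0.5 * x[0] * x[2]
    return 1.0 / (1.0 + math.exp(-z))

def grad(func, x, eps=1e-6):
    """Central finite-difference gradient of func at x (stand-in for autograd)."""
    g = []
    for i in range(len(x)):
        hi, lo = list(x), list(x)
        hi[i] += eps
        lo[i] -= eps
        g.append((func(hi) - func(lo)) / (2 * eps))
    return g

def gradient_x_input(func, x):
    """Elementwise product of the gradient and the input."""
    return [gi * xi for gi, xi in zip(grad(func, x), x)]

def integrated_gradients(func, x, baseline=None, steps=200):
    """Midpoint Riemann-sum approximation of integrated gradients."""
    if baseline is None:
        baseline = [0.0] * len(x)  # assumed baseline: all zeros
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, gi in enumerate(grad(func, point)):
            total[i] += gi
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]

x = [1.0, 0.5, -2.0]
vanilla = grad(f, x)
gxi = gradient_x_input(f, x)
ig = integrated_gradients(f, x)

# Completeness: IG attributions should sum to f(x) - f(baseline).
completeness_gap = abs(sum(ig) - (f(x) - f([0.0, 0.0, 0.0])))
print("vanilla gradient:   ", vanilla)
print("gradient x input:   ", gxi)
print("integrated gradients:", ig)
print("completeness gap:   ", completeness_gap)
```

The completeness check is a useful sanity test for any integrated-gradients implementation: if the gap is not close to zero, the number of integration steps is probably too small.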

### Attention-based Attribution

In [None]:
# Initialize attention attributor
attn_attributor = AttentionAttributor(model, tokenizer, device)

# Compute attention rollout
rollout_result = attn_attributor.attention_rollout(text)

print("ATTENTION ROLLOUT:")
print(f"Sequence length: {len(rollout_result.tokens)}")

# Show top attributed tokens
top_indices = torch.argsort(rollout_result.attributions, descending=True)[:5]
print("Top attributed tokens:")
for i, idx in enumerate(top_indices):
    token = rollout_result.tokens[idx]
    score = rollout_result.attributions[idx].item()
    print(f"  {i+1}. '{token}': {score:.4f}")
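
For intuition about what attention rollout computes, the sketch below implements the standard recipe on hand-written matrices: average each layer's attention with the identity (to account for the residual connection) and multiply the layers together. The two 3x3 attention matrices are made-up values, and `AttentionAttributor.attention_rollout` may differ in details such as how heads are aggregated.

```python
def matmul(a, b):
    """Plain-Python matrix product for small square matrices."""
    n = len(b)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(len(b[0]))]
            for i in range(len(a))]

def attention_rollout(layers):
    """layers: head-averaged attention matrices, first layer first."""
    n = len(layers[0])
    rollout = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for attn in layers:
        # Mix in the identity so the residual stream is represented
        mixed = [[0.5 * attn[i][j] + (0.5 if i == j else 0.0) for j in range(n)]
                 for i in range(n)]
        rollout = matmul(mixed, rollout)
    return rollout

# Hypothetical causal attention for a 3-token input (each row sums to 1)
layer1 = [[1.0, 0.0, 0.0], [0.6, 0.4, 0.0], [0.2, 0.3, 0.5]]
layer2 = [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.1, 0.6, 0.3]]

rollout = attention_rollout([layer1, layer2])
token_scores = rollout[-1]  # how much the last position draws on each input token
print(token_scores)
```

Because every mixed matrix is row-stochastic, each row of the rollout still sums to 1, so the scores can be read as a distribution over input tokens.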

### Visualizing Attribution Results

In [None]:
# Initialize visualizers
attr_viz = AttributionVisualizer()
text_viz = TextOverlayVisualizer()

# Create token heatmap for integrated gradients
ig_result = attribution_results["integrated_gradients"]
fig = attr_viz.plot_token_heatmap(
    ig_result,
    title="Integrated Gradients Attribution",
    show_values=True
)
plt.show()

# Create comparison plot
comparison_fig = attr_viz.compare_attributions(
    attribution_results,
    title="Attribution Method Comparison"
)
plt.show()

In [None]:
# Create interactive HTML visualization
html_viz = text_viz.create_html_overlay(
    ig_result.tokens,
    ig_result.attributions,
    title="Interactive Token Attribution",
    save_path="attribution_visualization.html"
)

print("Interactive HTML visualization saved to 'attribution_visualization.html'")
print("Open this file in a web browser to see the interactive visualization.")

## 2. Activation Patching

Next we test causal relationships: we ablate activations in individual attention and MLP components and measure how much the model's output changes.

In [None]:
# Initialize activation patcher
with ActivationPatcher(model, tokenizer, device) as patcher:
    
    # Run systematic ablation on a few layers
    ablation_results = patcher.run_systematic_ablation(
        text,
        layers=[6, 7, 8, 9, 10, 11],  # Focus on later layers
        components=["attention", "mlp"],
        heads=[0, 1, 2, 3, 4, 5]  # First few heads
    )
    
    print(f"Ran {len(ablation_results)} ablation experiments")
    
    # Find most critical components
    critical_components = patcher.find_critical_components(
        text,
        metric="logit_l2_diff",
        threshold=0.1
    )
    
    print(f"\nFound {len(critical_components)} critical components:")
    for (location, impact), result in critical_components[:10]:
        print(f"  {location.layer_idx}_{location.component}", end="")
        if location.head_idx is not None:
            print(f"_H{location.head_idx}", end="")
        print(f": {impact:.4f}")
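
The idea behind systematic ablation and the L2 logit-difference metric can be sketched on a toy residual-stream model. Everything here is hypothetical (the four components, their weights, and the 0.1 threshold mirror the cell above only in spirit) and it does not reproduce how `ActivationPatcher` hooks into GPT-2.

```python
import math

# Hypothetical toy "model": each component reads the 2-d residual stream
# and adds its output back into it; the final stream plays the role of logits.
COMPONENTS = {
    "L0_attention": lambda h: [0.5 * h[0], 0.0],
    "L0_mlp": lambda h: [0.0, 2.0 * h[0]],
    "L1_attention": lambda h: [0.01 * h[1], 0.0],
    "L1_mlp": lambda h: [h[1], -h[0]],
}

def forward(x, ablate=None):
    """Run the toy model, zero-ablating the named component if given."""
    h = list(x)
    for name, fn in COMPONENTS.items():
        out = [0.0, 0.0] if name == ablate else fn(h)
        h = [a + b for a, b in zip(h, out)]
    return h

def l2_diff(a, b):
    """Euclidean distance between two logit vectors."""
    return math.sqrt(sum((u - v) ** 2 for u, v in zip(a, b)))

x = [1.0, 0.0]
clean_logits = forward(x)
impacts = {name: l2_diff(clean_logits, forward(x, ablate=name)) for name in COMPONENTS}

# Keep components whose removal shifts the logits by more than the threshold
threshold = 0.1
critical = sorted(
    ((name, impact) for name, impact in impacts.items() if impact > threshold),
    key=lambda item: item[1],
    reverse=True,
)
for name, impact in critical:
    print(f"{name}: {impact:.4f}")
```

Note that ablating an early component also changes the inputs of everything downstream, so an impact score measures the total effect of the component, not only its direct contribution.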

### Causal Tracing

Let's perform causal tracing to understand information flow for factual associations.

In [None]:
# Initialize causal tracer
with CausalTracer(model, tokenizer, device) as tracer:
    
    # Trace causal effect for "Paris" given "Eiffel Tower"
    trace_result = tracer.trace_causal_effect(
        text="The Eiffel Tower is located in",
        subject_tokens=["Eiffel", "Tower"],
        target_token_position=-1  # Last position
    )
    
    print("CAUSAL TRACING RESULTS:")
    print(f"Baseline score: {trace_result.baseline_score:.4f}")
    print(f"Corrupted score: {trace_result.corrupted_score:.4f}")
    print(f"Subject positions: {trace_result.subject_token_positions}")
    
    # Show top restoration effects
    sorted_effects = sorted(
        trace_result.trace_results.items(),
        key=lambda x: x[1]["restoration_effect"],
        reverse=True
    )
    
    print("\nTop restoration effects:")
    for component, metrics in sorted_effects[:10]:
        effect = metrics["restoration_effect"]
        print(f"  {component}: {effect:.4f}")
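
A restoration effect can be illustrated with a three-run experiment: a clean run that caches activations, a corrupted run in which the subject inputs are replaced, and one corrupted run per component with that component's clean activation restored. The sketch below uses a made-up two-path model and normalizes the effect to the clean-versus-corrupted gap; whether `CausalTracer` normalizes its `restoration_effect` the same way is an assumption.

```python
def forward(x, restore=None):
    """Toy model with two parallel paths.

    `restore` maps an activation name to a cached clean value that
    overrides whatever the (possibly corrupted) input would produce.
    """
    restore = restore or {}
    acts = {}
    acts["subject_embed"] = restore.get("subject_embed", x[0] + x[1])
    acts["mlp"] = restore.get("mlp", 2.0 * acts["subject_embed"])
    acts["attn"] = restore.get("attn", 0.5 * acts["subject_embed"] + x[2])
    score = acts["mlp"] + acts["attn"]
    return score, acts

clean_x = [1.0, 2.0, 0.5]      # positions 0 and 1 play the role of subject tokens
corrupt_x = [-0.3, 0.4, 0.5]   # fixed stand-in for noised subject embeddings

clean_score, clean_acts = forward(clean_x)
corrupt_score, _ = forward(corrupt_x)

effects = {}
for name, clean_value in clean_acts.items():
    restored_score, _ = forward(corrupt_x, restore={name: clean_value})
    # Fraction of the clean-vs-corrupted gap recovered by restoring this activation
    effects[name] = (restored_score - corrupt_score) / (clean_score - corrupt_score)

for name, effect in sorted(effects.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name}: {effect:.2f}")
```

In this toy setup the MLP path carries most of the subject information, so restoring it recovers most of the score, while restoring the upstream embedding recovers all of it.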

## 3. Attention Analysis

Let's visualize attention patterns to understand how the model processes information.

In [None]:
# Get attention weights for visualization
inputs = tokenizer(text, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

attentions = outputs.attentions  # Tuple with one attention tensor per layer
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

print(f"Got attention weights for {len(attentions)} layers")
print(f"Attention shape per layer: {attentions[0].shape}")  # [batch, heads, seq, seq]
print(f"Tokens: {tokens}")

In [None]:
# Visualize attention patterns
heatmap_viz = HeatmapVisualizer()

# Show attention heatmap for last layer, first head
last_layer_attention = attentions[-1][0]  # Remove batch dimension

fig = heatmap_viz.attention_heatmap(
    last_layer_attention[0],  # First head
    tokens,
    title="Attention Pattern - Last Layer, Head 0",
    layer_idx=len(attentions)-1,
    head_idx=0,
    show_values=False
)
plt.show()

# Show multi-head attention grid for layer 8
layer_8_attention = attentions[8][0]  # Remove batch dimension

fig = heatmap_viz.multi_head_attention_grid(
    layer_8_attention,
    tokens,
    title="Multi-Head Attention",
    layer_idx=8,
    max_heads=8
)
plt.show()
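
When reading these heatmaps, it helps to remember that GPT-2 uses causal attention: each position attends only to itself and earlier positions, and each row is a softmax, so it sums to 1. The sketch below reproduces that structure from hand-picked raw scores; the score values are arbitrary.

```python
import math

def causal_attention(scores):
    """Row-wise softmax over positions j <= i; future positions get zero weight."""
    weights = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]
        m = max(visible)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        weights.append([e / total for e in exps] + [0.0] * (len(row) - i - 1))
    return weights

# Arbitrary raw attention scores for a 3-token sequence
scores = [[0.0, 9.0, 9.0], [1.0, 1.0, 9.0], [2.0, 0.0, 1.0]]
attn = causal_attention(scores)
for row in attn:
    print([round(w, 3) for w in row])
```

One consequence is that the first token always receives weight 1 from itself, and early tokens tend to accumulate attention simply because more positions can see them.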

## 4. Mechanistic Analysis

Now let's dive deeper into the model's internal mechanisms.

### Logit Lens Analysis

In [None]:
# Initialize logit lens
logit_lens = LogitLens(model, tokenizer, device)

# Analyze layer-by-layer predictions
lens_result = logit_lens.analyze(
    text,
    layers=[0, 2, 4, 6, 8, 10, 11],  # Sample of layers
    top_k=5
)

print("LOGIT LENS ANALYSIS:")
print(f"Analyzed {len(lens_result.layer_predictions)} layers")
print(f"Input tokens: {lens_result.tokens}")

# Show predictions for the last token position across layers
last_position = len(lens_result.tokens) - 1
print(f"\nPredictions for position {last_position} ('{lens_result.tokens[last_position]}'):")

for layer_idx in sorted(lens_result.layer_predictions.keys()):
    top_tokens = lens_result.top_tokens[layer_idx][last_position]
    print(f"  Layer {layer_idx}: {top_tokens[:3]}")

In [None]:
# Analyze prediction convergence
convergence = logit_lens.analyze_convergence(
    lens_result,
    position=last_position
)

print("CONVERGENCE ANALYSIS:")
print(f"Final prediction: '{convergence.get('final_prediction')}'")
print(f"Convergence layer: {convergence.get('convergence_layer')}")

# Plot convergence probabilities
if 'convergence_probabilities' in convergence:
    probs = convergence['convergence_probabilities']
    layers = convergence['layers_analyzed']
    
    plt.figure(figsize=(10, 6))
    plt.plot(layers, probs, 'o-', linewidth=2, markersize=8)
    plt.xlabel('Layer')
    plt.ylabel('Probability of Final Prediction')
    plt.title('Prediction Convergence Across Layers')
    plt.grid(True, alpha=0.3)
    plt.show()
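
The logit lens idea and the notion of a convergence layer can be sketched without a model: project each layer's hidden state through the unembedding matrix, take a softmax, and find the earliest layer from which the top token stays equal to the final prediction. The vocabulary, the matrix `W_U`, and the hidden states are invented for illustration, the final layer norm that is usually applied before unembedding is omitted, and `analyze_convergence` may define convergence differently.

```python
import math

VOCAB = [" London", " Paris", " Rome"]
# Hypothetical unembedding matrix: one row of weights per vocabulary token
W_U = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
# Hypothetical residual-stream states at the last position after each layer
hidden_by_layer = {0: [0.9, 0.1], 4: [0.6, 0.5], 8: [0.2, 1.4], 11: [0.1, 2.5]}

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def lens(hidden):
    """Decode an intermediate hidden state as if it were the final one."""
    logits = [sum(w * h for w, h in zip(row, hidden)) for row in W_U]
    return softmax(logits)

layers = sorted(hidden_by_layer)
probs = {layer: lens(hidden_by_layer[layer]) for layer in layers}
top = {layer: VOCAB[max(range(len(VOCAB)), key=probs[layer].__getitem__)]
       for layer in layers}

final_token = top[layers[-1]]
final_idx = VOCAB.index(final_token)

# Walk backwards to find the earliest layer after which the top token never changes
convergence_layer = layers[-1]
for layer in reversed(layers):
    if top[layer] != final_token:
        break
    convergence_layer = layer

final_token_probs = [probs[layer][final_idx] for layer in layers]
print("Top token per layer:", top)
print("Convergence layer:", convergence_layer)
print("P(final token) per layer:", [round(p, 3) for p in final_token_probs])
```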

### Attention Head Ablation

In [None]:
# Initialize head ablator
with AttentionHeadAblator(model, tokenizer, device) as ablator:
    
    # Systematically ablate attention heads in later layers
    head_results = ablator.systematic_head_ablation(
        text,
        layers=[8, 9, 10, 11],  # Later layers
        heads=[0, 1, 2, 3, 4, 5, 6, 7],  # First 8 heads
        ablation_type="zero"
    )
    
    print("HEAD ABLATION RESULTS:")
    print(f"Baseline score: {head_results.baseline_score:.4f}")
    print(f"Tested {len(head_results.head_impacts)} heads")
    
    # Find critical heads
    critical_heads = ablator.find_critical_heads(
        head_results,
        threshold=0.05,
        top_k=10
    )
    
    print("\nMost critical heads:")
    for head_id, impact in critical_heads:
        print(f"  {head_id}: {impact:.4f}")

In [None]:
# Create heatmap of head impacts
heatmap_matrix = ablator.layer_head_heatmap(head_results)

plt.figure(figsize=(12, 8))
plt.imshow(heatmap_matrix, cmap='Reds', aspect='auto')
plt.colorbar(label='Ablation Impact')
plt.xlabel('Head Index')
plt.ylabel('Layer Index')
plt.title('Attention Head Ablation Impact Heatmap')

# Add text annotations for high-impact heads
for i in range(heatmap_matrix.shape[0]):
    for j in range(heatmap_matrix.shape[1]):
        if heatmap_matrix[i, j] > 0.1:  # Only annotate significant impacts
            plt.text(j, i, f'{heatmap_matrix[i, j]:.2f}', 
                    ha='center', va='center', color='white', fontweight='bold')

plt.tight_layout()
plt.show()

### Neuron Analysis

In [None]:
# Initialize neuron analyzer
neuron_analyzer = NeuronAnalyzer(model, tokenizer, device)

# Extract neuron activations for MLP layers
neuron_results = neuron_analyzer.extract_neuron_activations(
    text,
    components=["mlp"],
    layers=[8, 9, 10]  # Focus on later layers
)

print("NEURON ANALYSIS:")
print(f"Components analyzed: {list(neuron_results.neuron_activations.keys())}")

# Show statistics for some neurons
print("\nSample neuron statistics:")
for neuron_id, stats in list(neuron_results.activation_statistics.items())[:5]:
    print(f"  {neuron_id}:")
    print(f"    Mean: {stats['mean']:.4f}, Std: {stats['std']:.4f}")
    print(f"    Max: {stats['max']:.4f}, Sparsity: {stats['sparsity']:.4f}")
    
    # Show max activating tokens
    if neuron_id in neuron_results.max_activating_tokens:
        max_tokens = neuron_results.max_activating_tokens[neuron_id][:3]
        print(f"    Top tokens: {max_tokens}")
    print()
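
The per-neuron statistics printed above can be computed from raw activations in a few lines. This sketch assumes that sparsity means the fraction of positions where the activation magnitude is below a small threshold; the library's definition, the neuron ID, and the token strings are not taken from the source and are placeholders.

```python
import statistics

# Hypothetical tokens and activations of one MLP neuron at each position
tokens = ["The", " Eiffel", " Tower", " is"]
activations = {"L8_mlp_N0": [0.0, 2.0, 1.0, 0.0]}

def neuron_stats(values, threshold=1e-3):
    """Summary statistics for one neuron's activations across positions."""
    return {
        "mean": statistics.fmean(values),
        "std": statistics.pstdev(values),  # population standard deviation
        "max": max(values),
        # Assumed definition: share of positions with near-zero activation
        "sparsity": sum(abs(v) < threshold for v in values) / len(values),
    }

def max_activating_tokens(values, tokens, k=3):
    """Tokens at the k positions where the neuron fires most strongly."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)
    return [tokens[i] for i in order[:k]]

stats = neuron_stats(activations["L8_mlp_N0"])
top_tokens = max_activating_tokens(activations["L8_mlp_N0"], tokens, k=2)
print(stats)
print(top_tokens)
```

Statistics from a single short sentence are noisy; meaningful max-activating tokens usually require running the neuron over a much larger corpus.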

## 5. Comprehensive Analysis Dashboard

Let's create a comprehensive analysis dashboard that combines multiple interpretability methods.

In [None]:
def create_interpretability_dashboard(text, model, tokenizer, device):
    """Create a comprehensive interpretability dashboard."""
    
    print(f"üîç INTERPRETABILITY DASHBOARD")
    print(f"Text: '{text}'")
    print(f"Tokens: {tokenizer.tokenize(text)}")
    print("=" * 60)
    
    # 1. Attribution Analysis
    print("\nüìä ATTRIBUTION ANALYSIS")
    grad_attributor = GradientAttributor(model, tokenizer, device)
    
    # Quick attribution
    ig_result = grad_attributor.integrated_gradients(text)
    top_idx = torch.argmax(ig_result.attributions)
    top_token = ig_result.tokens[top_idx]
    top_score = ig_result.attributions[top_idx].item()
    
    print(f"Most important token: '{top_token}' (score: {top_score:.4f})")
    
    # 2. Causal Analysis
    print("\nüîó CAUSAL ANALYSIS")
    with ActivationPatcher(model, tokenizer, device) as patcher:
        critical_components = patcher.find_critical_components(
            text, threshold=0.05
        )
        
        if critical_components:
            top_component = critical_components[0]
            location, impact = top_component
            comp_name = f"Layer {location.layer_idx}, {location.component}"
            if location.head_idx is not None:
                comp_name += f", Head {location.head_idx}"
            print(f"Most critical component: {comp_name} (impact: {impact:.4f})")
        else:
            print("No critical components found above threshold")
    
    # 3. Attention Analysis
    print("\nüëÅÔ∏è ATTENTION ANALYSIS")
    attn_attributor = AttentionAttributor(model, tokenizer, device)
    rollout_result = attn_attributor.attention_rollout(text)
    
    top_attn_idx = torch.argmax(rollout_result.attributions)
    top_attn_token = rollout_result.tokens[top_attn_idx]
    top_attn_score = rollout_result.attributions[top_attn_idx].item()
    
    print(f"Most attended token: '{top_attn_token}' (score: {top_attn_score:.4f})")
    
    # 4. Mechanistic Analysis
    print("\nüî¨ MECHANISTIC ANALYSIS")
    logit_lens = LogitLens(model, tokenizer, device)
    lens_result = logit_lens.analyze(text, layers=[0, 6, 11])
    
    last_pos = len(lens_result.tokens) - 1
    final_prediction = lens_result.top_tokens[11][last_pos][0] if 11 in lens_result.top_tokens else "N/A"
    early_prediction = lens_result.top_tokens[0][last_pos][0] if 0 in lens_result.top_tokens else "N/A"
    
    print(f"Early prediction (Layer 0): '{early_prediction}'")
    print(f"Final prediction (Layer 11): '{final_prediction}'")
    
    print("\n" + "=" * 60)
    print("✅ Dashboard complete!")

# Run the dashboard
create_interpretability_dashboard(text, model, tokenizer, device)

## 6. Comparison Across Different Inputs

Let's compare how the model processes different types of factual statements.

In [None]:
# Different factual statements
test_texts = [
    "The Eiffel Tower is located in Paris, the capital of France.",
    "Albert Einstein developed the theory of relativity.", 
    "The Amazon River flows through Brazil.",
    "William Shakespeare wrote Romeo and Juliet."
]

print("üîÑ COMPARATIVE ANALYSIS")
print("Analyzing different factual statements...\n")

# Build the attributor once and reuse it for every text
grad_attributor = GradientAttributor(model, tokenizer, device)

# Quick analysis for each text
for i, test_text in enumerate(test_texts, 1):
    print(f"{i}. {test_text}")
    
    # Attribution analysis
    ig_result = grad_attributor.integrated_gradients(test_text)
    
    # Find most important token
    top_idx = torch.argmax(ig_result.attributions)
    top_token = ig_result.tokens[top_idx]
    top_score = ig_result.attributions[top_idx].item()
    
    print(f"   Most important token: '{top_token}' (score: {top_score:.4f})")
    print()

## Summary and Next Steps

This notebook demonstrated the core capabilities of the LLM Interpretability Dashboard:

### ✅ What we covered:
1. **Attribution Methods**: Gradient-based and attention-based attribution
2. **Activation Patching**: Systematic ablation and causal tracing
3. **Attention Analysis**: Visualizing attention patterns
4. **Mechanistic Analysis**: Logit lens, neuron analysis, and head ablation
5. **Visualization**: Rich plots and interactive visualizations

### 🚀 Next steps:
1. Try with larger models (GPT-2 medium/large, GPT-Neo, LLaMA)
2. Analyze more complex tasks (reasoning, factual recall, sentiment)
3. Explore neuron clustering and circuit discovery
4. Build custom attribution methods
5. Create publication-ready visualizations

### 📚 Resources:
- [Full documentation](../README.md)
- [API reference](../docs/api.md)
- [Additional examples](../examples/)

Happy interpreting! 🧠✨