# Factual Recall Circuit Discovery - Interactive Tutorial

This notebook walks through the process of discovering factual recall circuits in Gemma 2B using attribution graphs and sparse autoencoders.

## Setup

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, HTML

from circuit_discovery import CircuitDiscovery, Circuit
from testing_pipeline import CircuitTester, FeatureHypothesis
from utils import (
    visualize_circuit, 
    compare_circuits,
    plot_testing_results
)

# Select the compute device: CUDA when PyTorch reports one, otherwise the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Warn up front when falling back to the CPU
if device != 'cuda':
    print("WARNING: Running on CPU will be very slow!")

## Part 1: Initialize Discovery System

Load Gemma 2B and set up the circuit discovery pipeline.

In [None]:
# Construct the circuit discovery pipeline on the selected device
print("Loading Gemma 2B...")
discovery = CircuitDiscovery(device=device)
print("✓ Model loaded successfully!")

# Count the model's layers once; later cells use num_layers to bound
# their per-layer sweeps
layer_stack = discovery.model.model.layers
num_layers = len(layer_stack)
print(f"\nModel has {num_layers} layers")

## Part 2: Create Factual Dataset

Define prompts for different types of factual knowledge.

In [None]:
# Clean/corrupted prompt pairs, grouped by the kind of fact they probe.
# Each pair is written as (clean, corrupted) and expanded into the
# {'clean': ..., 'corrupted': ...} dicts that the rest of the notebook reads.
dataset = {
    fact_type: [
        {'clean': clean, 'corrupted': corrupted}
        for clean, corrupted in pairs
    ]
    for fact_type, pairs in {
        'entity_location': [
            ('The Eiffel Tower is located in Paris',
             'The Eiffel Tower is located in London'),
            ('Mount Everest is in the Himalayas',
             'Mount Everest is in the Alps'),
            ('The Colosseum is in Rome',
             'The Colosseum is in Athens'),
        ],
        'capital_country': [
            ('The capital of France is Paris',
             'The capital of France is Berlin'),
            ('Tokyo is the capital of Japan',
             'Tokyo is the capital of China'),
        ],
    }.items()
}

# Report how many pairs each fact type contributes, then the grand total
print("Dataset Overview:")
for fact_type, prompts in dataset.items():
    print(f"  {fact_type}: {len(prompts)} prompt pairs")
print(f"\nTotal: {sum(map(len, dataset.values()))} prompt pairs")

## Part 3: Test Attribution Methods

Before running the full pipeline, let's test activation patching on a single clean/corrupted example, sampling every third layer.

In [None]:
# One clean/corrupted pair used to sanity-check activation patching
test_clean = "The Eiffel Tower is in Paris"
test_corrupted = "The Eiffel Tower is in London"

print("Testing activation patching across layers...\n")

# Sweep every 3rd layer and record the patching effect per layer.
# NOTE(review): the trailing -1 argument is presumably the token position
# to patch (last token) — confirm against activation_patching's signature.
layer_effects = {}
for layer_idx in range(0, num_layers, 3):
    effect = discovery.attribution_graph.activation_patching(
        test_clean, test_corrupted, layer_idx, -1
    )
    layer_effects[layer_idx] = effect
    print(f"Layer {layer_idx:2d}: effect = {effect:.4f}")

# Bar chart of effect per sampled layer
fig, ax = plt.subplots(figsize=(10, 4))
ax.bar(
    list(layer_effects),
    list(layer_effects.values()),
    color='steelblue',
    alpha=0.7,
)
ax.set_xlabel('Layer', fontsize=12)
ax.set_ylabel('Patching Effect', fontsize=12)
ax.set_title('Activation Patching Effect Across Layers', fontsize=14, fontweight='bold')
ax.grid(axis='y', alpha=0.3)
plt.show()

# The layer with the largest recorded effect
top_layer = max(layer_effects, key=layer_effects.get)
print(f"\nMost important layer: {top_layer}")

## Part 4: Train Sparse Autoencoders

Train SAEs on important layers to extract interpretable features.

In [None]:
# Collect the clean prompt of every pair; these are the SAE training inputs
training_prompts = [
    pair['clean']
    for prompts in dataset.values()
    for pair in prompts
]

print(f"Training SAEs on {len(training_prompts)} prompts...\n")

# Train SAEs for middle layers (most important for factual recall)
important_layers = [8, 12, 16]  # Adjust based on your model

# Track what actually happened so the final message is accurate: layers the
# loaded model does not have are reported instead of being skipped silently.
trained_layers = []
skipped_layers = []

for layer_idx in important_layers:
    if layer_idx >= num_layers:
        skipped_layers.append(layer_idx)
        print(f"\nSkipping layer {layer_idx}: model only has {num_layers} layers")
        continue

    print(f"\nTraining SAE for layer {layer_idx}...")
    # `sae` is kept as a notebook variable so the most recently trained
    # autoencoder can be inspected after this cell runs.
    sae = discovery.train_sparse_autoencoder(
        layer_idx=layer_idx,
        training_prompts=training_prompts,
        epochs=5  # Increase for better results
    )
    trained_layers.append(layer_idx)
    print(f"✓ SAE trained for layer {layer_idx}")

# Only claim full success when every requested layer was trained
if not skipped_layers:
    print("\n✓ All SAEs trained!")
elif trained_layers:
    print(
        f"\n✓ Trained SAEs for layers {trained_layers}; "
        f"skipped {skipped_layers} (model has {num_layers} layers)"
    )
else:
    print(
        f"\nWARNING: No SAEs were trained - none of {important_layers} "
        f"is a valid layer index for a model with {num_layers} layers"
    )

## Part 5: Discover Circuits

Now discover circuits for each fact type.

In [None]:
# Run circuit discovery once per fact type (SAEs already trained)
print("Discovering circuits...\n")

banner = '=' * 60  # separator line reused for every section header

circuits = []
for fact_type, prompts in dataset.items():
    print(f"\n{banner}")
    print(f"Discovering circuit for: {fact_type}")
    print(banner)

    circuit = discovery.discover_circuit(
        fact_prompts=prompts,
        fact_type=fact_type,
        threshold=0.01
    )
    circuits.append(circuit)

    # Short per-circuit report: size and overall attribution score
    print("\n✓ Found circuit with:")
    print(
        f"   {len(circuit.nodes)} nodes",
        f"   {len(circuit.edges)} edges",
        f"   Attribution score: {circuit.attribution_score:.4f}",
        sep="\n",
    )

print(f"\n{banner}")
print(f"Total circuits discovered: {len(circuits)}")
print(banner)

## Part 6: Visualize Circuits

Create visualizations of discovered circuits.

In [None]:
# Figure sizes (width, height) for the per-circuit and comparison plots
single_figsize = (14, 8)
comparison_figsize = (14, 5)

# One figure per discovered circuit
for circuit in circuits:
    print(f"\nVisualizing: {circuit.name}")
    fig = visualize_circuit(circuit, figsize=single_figsize)
    plt.show()

# A single figure comparing all circuits
print("\nComparing all circuits...")
fig = compare_circuits(circuits, figsize=comparison_figsize)
plt.show()

## Part 7: Test Hypotheses

Create and test hypotheses about circuit features.

In [None]:
# Initialize tester
tester = CircuitTester(discovery)

# Prompts shared by every hypothesis: location facts that should engage the
# feature, and unrelated statements used as controls.
location_prompts = [
    "Paris is in France",
    "Tokyo is in Japan",
    "London is in England"
]
unrelated_prompts = [
    "The sky is blue",
    "Two plus two equals four",
    "Water is wet"
]

# One hypothesis for each of the first three nodes of every circuit.
# Each hypothesis gets its own copy of the prompt lists, so no list object
# is shared between hypotheses.
hypotheses = [
    FeatureHypothesis(
        feature_id=node,
        hypothesis=f"{circuit.fact_type} detector (node {i})",
        test_prompts=list(location_prompts),
        control_prompts=list(unrelated_prompts)
    )
    for circuit in circuits
    for i, node in enumerate(circuit.nodes[:3])
]

print(f"Testing {len(hypotheses)} hypotheses...\n")

# Test hypotheses
results = []
for idx, hyp in enumerate(hypotheses, start=1):
    try:
        result = tester.validator.test_hypothesis(hyp, verbose=True)
    except Exception as e:
        # Best effort: one failing hypothesis must not abort the rest, but
        # say which one was skipped and what kind of error caused it.
        print(f"Skipped hypothesis {idx}/{len(hypotheses)}: {type(e).__name__}: {e}\n")
        continue
    results.append(result)
    print()

# Summary
skipped = len(hypotheses) - len(results)
if results:
    passed = sum(r.passed for r in results)
    print(f"\n{'='*60}")
    print(f"Testing Summary: {passed}/{len(results)} hypotheses passed")
    if skipped:
        print(f"Skipped due to errors: {skipped}")
    print(f"{'='*60}")
else:
    # Make the empty outcome visible instead of ending without a summary
    print(f"\nNo results: 0 of {len(hypotheses)} hypotheses could be tested")

## Part 8: Analyze Circuit Properties

Dive deeper into circuit characteristics.

In [None]:
# Per-circuit summary: nodes per layer, edge/node ratio, and layer span
print("Circuit Analysis:\n")

for circuit in circuits:
    print(f"\n{circuit.name}:")

    # The first element of each node is read as its layer
    node_layers = [node[0] for node in circuit.nodes]

    # Tally how many nodes sit in each layer
    layer_counts = {}
    for layer in node_layers:
        layer_counts.setdefault(layer, 0)
        layer_counts[layer] += 1

    print("  Layer distribution:")
    for layer, count in sorted(layer_counts.items()):
        print(f"    Layer {layer:2d}: {count} nodes")

    # The remaining statistics are undefined for a circuit without nodes
    if not circuit.nodes:
        continue

    # Average connectivity
    connectivity = len(circuit.edges) / len(circuit.nodes)
    print(f"  Average connectivity: {connectivity:.2f} edges/node")

    # Circuit depth: inclusive span between the lowest and highest layer
    min_layer, max_layer = min(node_layers), max(node_layers)
    depth = max_layer - min_layer + 1
    print(f"  Circuit depth: {depth} layers (L{min_layer} to L{max_layer})")

## Part 9: Export Results

Save circuits and analysis results.

In [None]:
import json
from pathlib import Path

# Create output directory (adjust this path for your environment)
output_dir = Path('/mnt/user-data/outputs/notebook_results')
output_dir.mkdir(parents=True, exist_ok=True)

# Only the first MAX_EXPORTED_NODES nodes of each circuit are written out;
# 'num_nodes' still records the circuit's full node count.
MAX_EXPORTED_NODES = 20


def _json_scalar(value):
    """Fallback for json.dump: unwrap scalar objects that expose .item().

    json.dump only serializes built-in types. If an attribution score or a
    node component is held as a NumPy/PyTorch scalar, the export would raise
    TypeError; this converts such values through their .item() method.
    Values json already supports never reach this function.

    Raises:
        TypeError: if the value has no .item() method, matching what
            json.dump raises for unsupported types.
    """
    if hasattr(value, 'item'):
        return value.item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


# Export circuits as JSON
circuits_data = [
    {
        'name': circuit.name,
        'fact_type': circuit.fact_type,
        'num_nodes': len(circuit.nodes),
        'num_edges': len(circuit.edges),
        'attribution_score': circuit.attribution_score,
        'nodes': [
            {'layer': n[0], 'feature': n[1]}
            for n in circuit.nodes[:MAX_EXPORTED_NODES]
        ]
    }
    for circuit in circuits
]

with open(output_dir / 'circuits.json', 'w', encoding='utf-8') as f:
    json.dump(circuits_data, f, indent=2, default=_json_scalar)

print(f"✓ Results exported to {output_dir}")
print("\nFiles created:")
print("  - circuits.json")
print("\nYou can now view these files in the outputs directory!")

## Part 10: Interactive Exploration

Try your own examples!

In [None]:
# Edit these two prompts to probe a fact of your own
custom_clean = "The Great Wall of China is in Asia"  # Your fact here
custom_corrupted = "The Great Wall of China is in Europe"  # Corrupted version

print("Testing custom prompt:\n")
print(
    f"Clean: {custom_clean}",
    f"Corrupted: {custom_corrupted}\n",
    sep="\n",
)

# Sweep every 4th layer and print the patching effect for each
print("Layer effects:")
for layer_idx in range(0, num_layers, 4):
    effect = discovery.attribution_graph.activation_patching(
        custom_clean,
        custom_corrupted,
        layer_idx,
        -1,
    )
    print(f"  Layer {layer_idx:2d}: {effect:.4f}")

print("\nTry changing the prompts above and re-running this cell!")

## Summary

In this notebook, we:
1. ✓ Loaded Gemma 2B and initialized circuit discovery
2. ✓ Created a dataset of factual prompts
3. ✓ Tested attribution methods
4. ✓ Trained sparse autoencoders
5. ✓ Discovered factual recall circuits
6. ✓ Visualized circuits
7. ✓ Tested hypotheses about features
8. ✓ Analyzed circuit properties
9. ✓ Exported results
10. ✓ Tested custom prompts interactively

## Next Steps

- Explore more fact types
- Test on larger datasets
- Refine SAE training
- Investigate circuit overlap
- Try different models

## References

- [Mechanistic Interpretability](https://transformer-circuits.pub/)
- [Sparse Autoencoders](https://arxiv.org/abs/2309.08600)
- [Attribution Patching](https://arxiv.org/abs/2310.10348)