# Mechanistic Interpretability Toolkit Examples

This notebook provides examples of how to use the mech_interp_toolkit library for mechanistic interpretability research.

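It covers, in notebook order:
- Loading the model and tokenizer
- Tokenizing prompts with a chat template
- Extracting activations from specific model components
- Performing Direct Logit Attribution (DLA)
- Patching activations to test causal interventions
- Running gradient-based attribution (Integrated Gradients)
- Training linear probes on extracted activations

Running it requires the `nnsight` library (used by the toolkit) and a loadable model, e.g. `Qwen/Qwen2-0.5B`.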

## Import Required Libraries

In [None]:
import torch
import numpy as np
from mech_interp_toolkit.utils import load_model_tokenizer_config
from mech_interp_toolkit.activation_dict import ActivationDict
from mech_interp_toolkit.activations import get_activations, patch_activations
from mech_interp_toolkit.direct_logit_attribution import run_componentwise_dla
from mech_interp_toolkit.gradient_based_attribution import simple_integrated_gradients
from mech_interp_toolkit.linear_probes import LinearProbe
from mech_interp_toolkit.tokenizer import ChatTemplateTokenizer

## Load Model and Tokenizer

In [None]:
print("Loading model...")
# Replace with your model name, e.g., "Qwen/Qwen2-0.5B".
# Note: this requires the 'nnsight' library, as per the toolkit's implementation.
model_name = "Qwen/Qwen2-0.5B"

# Every later cell needs `model`, `tokenizer` and `config`, so there is no useful
# fallback when loading fails: report the problem and re-raise, so that
# "Run All" stops here instead of failing later with a confusing NameError.
try:
    model, tokenizer, config = load_model_tokenizer_config(model_name)
except Exception as e:
    print(
        f"Could not load model {model_name} "
        f"(check that the model is available and 'nnsight' is installed): {e}"
    )
    raise
print(f"Successfully loaded {model_name}")

## 1. Tokenization with Chat Templates

In [None]:
print("\n--- Tokenization ---")
# Wrap the Hugging Face tokenizer so prompts are formatted with the model's chat template.
chat_tokenizer = ChatTemplateTokenizer(tokenizer)

prompt = "Explain the theory of relativity in one sentence."
# The call yields a mapping with 'input_ids' and 'attention_mask' entries.
inputs = chat_tokenizer(prompt)

input_shape = inputs["input_ids"].shape
print(f"Input shape: {input_shape}")

## 2. Extracting Activations

In [None]:
print("\n--- Extracting Activations ---")
# Components are addressed as (layer_index, component_name) pairs.
# Common names: "attn", "mlp", "z" (attention head output), "layer_in", "layer_out"
components_to_extract = [
    (0, "mlp"),
    (0, "attn"),
]

# get_activations hands back a 3-tuple: (activations, grads, logits)
activations, grads, logits = get_activations(model, inputs, components_to_extract)

# Report the shape of every extracted component.
for component_key, component_act in activations.items():
    print(f"Extracted {component_key}: {component_act.shape}")

## 3. Direct Logit Attribution (DLA)

In [None]:
print("\n--- Direct Logit Attribution ---")
# Compute how much each component writes along a chosen direction in the residual stream.
# Typically this direction is a difference of *unembedding* vectors,
# W_U[:, correct_token] - W_U[:, incorrect_token], so each component's score is its
# direct contribution to the correct-vs-incorrect logit difference.
# Here we use a random direction for demonstration. A seeded generator keeps the
# direction (and therefore the printed attributions) identical across re-runs.
DLA_SEED = 0
direction_rng = torch.Generator().manual_seed(DLA_SEED)
logit_diff_direction = torch.randn(
    config.hidden_size, generator=direction_rng
).to(model.device)

dla_results = run_componentwise_dla(model, inputs, logit_diff_direction)

# dla_results is an ActivationDict containing attribution scores keyed by (layer, component).
if dla_results:
    print(f"DLA computed for {len(dla_results)} components.")
    print(f"Attribution for Layer 0 MLP: {dla_results.get((0, 'mlp'), 'N/A')}")
else:
    # Make an empty result visible instead of silently printing nothing.
    print("DLA returned no component attributions.")

## 4. Activation Patching

In [None]:
print("\n--- Activation Patching ---")
# We can modify activations and see the effect on the output.
# Let's zero-ablate the MLP at layer 0.

# Create a new ActivationDict for patching.
patch_dict = ActivationDict(config, positions=slice(None))
# Patch with zeros shaped like the clean layer-0 MLP activation extracted earlier.
patch_dict[(0, "mlp")] = torch.zeros_like(activations[(0, "mlp")])

# patch_activations returns (activations, grads, logits).
# NOTE(review): `position=0` is passed here while patch_dict was built with
# positions=slice(None); confirm against the toolkit whether this restricts the
# ablation to token position 0 (i.e. not a full zero-ablation of the MLP).
_, _, patched_logits = patch_activations(model, inputs, patch_dict, position=0)

print(f"Original logits norm: {logits.norm().item():.4f}")
print(f"Patched logits norm: {patched_logits.norm().item():.4f}")
# Norms alone can hide an intervention's effect (very different logits can have
# similar norms), so also report the largest element-wise change.
max_logit_change = (patched_logits - logits).abs().max().item()
print(f"Max absolute logit change: {max_logit_change:.4f}")

## 5. Gradient-Based Attribution (Integrated Gradients)

In [None]:
print("\n--- Integrated Gradients ---")
# Attribute the output to input features (or internal activations).
# Integrated Gradients integrates from a baseline activation (usually zeros) to the real one.

# Number of interpolation steps between baseline and input. 5 keeps the demo fast;
# more steps give a more accurate approximation of the IG path integral.
IG_STEPS = 5

baseline_embeddings = ActivationDict(config, positions=slice(None))
# Match the batch and sequence sizes of `inputs` instead of hard-coding a batch of 1.
batch_size, seq_len = inputs["input_ids"].shape[:2]
# Assuming we want to attribute to 'layer_in' at layer 0.
baseline_embeddings[(0, "layer_in")] = torch.zeros(
    batch_size, seq_len, config.hidden_size, device=model.device
)

ig_attributions = simple_integrated_gradients(
    model, inputs, baseline_embeddings, steps=IG_STEPS
)
print("Integrated Gradients computed.")

## 6. Linear Probes

In [None]:
print("\n--- Linear Probes ---")
# Train a linear classifier on activations.

# Generate dummy data for demonstration.
# In reality, you would collect activations over a dataset (e.g. with get_activations).
# The labels are random, so the probe's predictions carry no meaning here. Seeded
# local generators keep the demo reproducible across re-runs.
PROBE_SEED = 0
N_TRAIN, N_TEST = 100, 10
torch_gen = torch.Generator().manual_seed(PROBE_SEED)
np_rng = np.random.default_rng(PROBE_SEED)

X_train = ActivationDict(config, positions=slice(None))
# LinearProbe expects exactly one component in the ActivationDict.
# Shape (n_samples, 1, hidden_size); the middle axis is presumably one position per sample.
X_train[(0, "mlp")] = torch.randn(N_TRAIN, 1, config.hidden_size, generator=torch_gen)
y_train = np_rng.integers(0, 2, N_TRAIN)  # binary classification targets

probe = LinearProbe(target_type="classification")
probe.fit(X_train, y_train)
print("Linear probe trained.")

# Predict on new data.
X_test = ActivationDict(config, positions=slice(None))
X_test[(0, "mlp")] = torch.randn(N_TEST, 1, config.hidden_size, generator=torch_gen)
preds = probe.predict(X_test)
print(f"Predictions shape: {preds.shape}")