# Causal Tracing of "Answer-First" Bias

This notebook runs an analysis to detect whether the answer to a math problem already emerges in the earliest layers of the model, before reasoning begins.

In [None]:
import sys
sys.path.append("..")

import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from src.data_loader import load_gsm8k_dataset, filter_single_token_answers
from src.analysis import run_analysis
from transformers import AutoTokenizer

## Load Data

In [None]:
# The tokenizer is passed to the filter below, so it is loaded first.
model_id = "Qwen/Qwen2.5-Math-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the "test" split, then narrow it with the project's single-token-answer
# filter (exact filtering criteria live in src.data_loader — confirm there).
dataset = load_gsm8k_dataset(split="test")
filtered_data = filter_single_token_answers(dataset, tokenizer)

# Report how many examples survive the filter.
print(f"Loaded {len(dataset)} examples, filtered to {len(filtered_data)} single-token answer examples.")

## Run Analysis

In [None]:
# Use the GPU when PyTorch reports CUDA as available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# num_samples presumably caps how many filtered examples are analysed —
# confirm against src.analysis.run_analysis.
results = run_analysis(
    model_name="Qwen/Qwen2.5-Math-1.5B-Instruct",
    dataset=filtered_data,
    num_samples=20,
    device=device,
)

## Visualization

We plot the probability assigned to the correct answer token at each layer, measured at the last token position of the prompt.

In [None]:
# Assumes results["layer_probs"] is a rectangular samples-by-layers table
# (rows = samples, columns = layers) — confirm against run_analysis.
layer_probs = np.array(results["layer_probs"])
# Collapse axis 0 to get one mean value per remaining position (per layer, under
# the assumption above).
avg_probs = layer_probs.mean(axis=0)

# Explicit figure/axes interface instead of the implicit pyplot state machine.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(avg_probs, marker="o")
ax.set_xlabel("Layer Index")
ax.set_ylabel("Probability of Correct Answer")
ax.set_title("Average Probability of Correct Answer Across Layers (Logit Lens)")
ax.grid(True)
plt.show()

In [None]:
# Per-sample view: one heatmap row per row of layer_probs, one column per column,
# coloured by the stored probability value.
fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(
    layer_probs,
    cmap="viridis",
    cbar_kws={"label": "Probability"},
    ax=ax,
)
ax.set_xlabel("Layer Index")
ax.set_ylabel("Sample Index")
ax.set_title("Answer Probability per Layer per Sample")
plt.show()