In [1]:
from transformer_lens import HookedTransformer
import torch
import circuitsvis as cv
import einops
from IPython.display import display
import numpy as np
from pprint import pprint
from datasets import load_dataset
import random
from tqdm import tqdm
import json
import re
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

In [33]:
# Inference device for this notebook.
# NOTE: Pinned to the CPU because the GPU is occupied by main.ipynb.
# Accelerator auto-detection ("mps" when available, otherwise "cuda" when
# available) is intentionally disabled while that is the case.
device = torch.device("cpu")

In [34]:
# Load the model directly onto `device`. Without an explicit `device`
# argument, `from_pretrained` picks its own default device (an accelerator
# when one is available), which would first place the weights on the GPU this
# notebook is trying to avoid and only afterwards move them to the CPU.
model = HookedTransformer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    device=device,
)



Loaded pretrained model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B into HookedTransformer
Moving model to device:  cpu


In [69]:
# Lower-case phrases whose presence in a chain-of-thought is treated as a
# backtracking signal. Matching is done by case-insensitive substring search,
# so every entry must be lower-case and must appear only once (a duplicated
# entry would be reported and counted twice).
backtracking_phrases = [
    # General backtracking statements
    "i made a mistake",
    "let me recalculate",
    "that's not right",
    "i need to correct",
    "let's try again",
    "i think i went wrong",
    "let's try another approach",
    "actually, i should",
    "wait, that's incorrect",
    "let me rethink",
    "on second thought",
    "i need to backtrack",
    "let me restart",
    "i made an error",

    # Correction indicators
    "hmm, that's not", "hmm, that doesn't", "hmm, this doesn't",
    "wait, that's not", "wait, that doesn't", "wait, this doesn't",
    "actually, that's not", "actually, that doesn't", "actually, this doesn't",
    "oh, that's not", "oh, that doesn't", "oh, this doesn't",

    # Reconsideration indicators (currently disabled)
    # "let me reconsider", "let me think again", "on second thought",
    # "let's reconsider", "let's think again", "thinking again",
    # "reconsidering", "rethinking", "let me double-check",

    # Mistake acknowledgment
    "i made a calculation error", "calculation error", "computational error",
    "arithmetic error", "i miscalculated",
    "i miscounted", "i misunderstood", "i misinterpreted",

    # Approach change indicators (currently disabled)
    # "different approach", "alternative approach", "another method",
    # "different method", "alternative method", "different strategy",
    # "alternative strategy", "different way", "alternative way",

    # Doubt indicators
    "i'm not sure if", "i'm not convinced", "i'm skeptical",
    "i'm doubtful", "i'm uncertain", "i'm not confident",
    "i'm hesitant", "i'm not sure about", "i'm not certain",

    # Specific math correction phrases
    # ("let me recalculate" is already listed in the first group)
    "let me redo this calculation",
    "i need to redo", "i should redo", "i'll redo",
    "let me solve this again", "let me solve this differently",
    "let me approach this differently", "let me try a different approach",
]

def identify_backtracking(cot_text, phrases=None):
    """
    Identify potential backtracking phrases in a CoT solution.

    Matching is a case-insensitive substring search. Each phrase is reported
    at most once, in the order it appears in ``phrases``.

    Args:
        cot_text: The generated CoT solution text. ``None`` or an empty
            string yields an empty result.
        phrases: Optional iterable of phrases to look for. Defaults to the
            module-level ``backtracking_phrases`` list.

    Returns:
        List of identified backtracking phrases
    """
    if phrases is None:
        phrases = backtracking_phrases
    if not cot_text:
        return []

    # Lower-case the (potentially very long) text once, not once per phrase.
    haystack = cot_text.lower()

    found_phrases = []
    already_checked = set()
    for phrase in phrases:
        needle = phrase.lower()
        # Skip empty phrases (they would match any text) and repeated entries
        # (they would be reported twice).
        if not needle or needle in already_checked:
            continue
        already_checked.add(needle)
        if needle in haystack:
            found_phrases.append(phrase)

    return found_phrases

In [70]:
def _extract_boxed_answers(text):
    r"""
    Return the stripped contents of every ``\boxed{...}`` found in ``text``.

    Braces are balanced while scanning, so nested LaTeX such as
    ``\boxed{\frac{1}{2}}`` yields ``\frac{1}{2}`` rather than being cut off
    at the first closing brace. An unterminated ``\boxed{`` (for example in a
    truncated generation) ends the scan and yields nothing for that occurrence.
    """
    marker = "\\boxed{"
    answers = []
    cursor = 0
    while True:
        start = text.find(marker, cursor)
        if start == -1:
            break
        content_start = start + len(marker)
        depth = 1
        pos = content_start
        while pos < len(text) and depth > 0:
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
            pos += 1
        if depth != 0:
            # The closing brace never arrived; everything after this point is
            # inside the unterminated box, so there is nothing more to extract.
            break
        answers.append(text[content_start:pos - 1].strip())
        cursor = pos
    return answers


def _normalize_answer(answer):
    """Lower-case an answer and drop all whitespace so comparisons are lenient."""
    return re.sub(r'\s+', '', answer.lower())


def analyze_cot_results(json_file_path):
    r"""
    Analyze the Chain-of-Thought results from a saved JSON file.

    Each record is read with ``.get`` using the keys ``problem_level``,
    ``problem_type``, ``problem_id``, ``generated_cot`` and
    ``ground_truth_solution``; missing values fall back to ``"Unknown"`` or
    the empty string.

    Heuristics used:
      * Correctness: some non-empty ``\boxed{...}`` answer in the generated
        text equals (after normalization) some non-empty boxed answer in the
        ground truth. Empty boxes are ignored so a bare ``\boxed{}``
        placeholder can never count as a match.
      * Ran out of tokens: the generated text contains no non-empty boxed
        answer. The saved generations appear to echo the prompt, which itself
        contains an empty ``\boxed{}``, hence empty boxes do not count.
      * Backtracking: ``identify_backtracking`` finds at least one phrase.

    Args:
        json_file_path: Path to the JSON file with CoT results

    Returns:
        Dictionary with analysis results (counts, per-problem lists,
        level/type distributions and accuracies, and percentage summaries)
    """
    # Load the JSON data
    with open(json_file_path, 'r') as f:
        results = json.load(f)

    print(f"Analyzing {len(results)} CoT solutions...")

    analysis = {
        "total_problems": len(results),
        "correct_answers": 0,
        "has_think_close_tag": 0,
        "ran_out_of_tokens": 0,
        "has_backtracking": 0,
        "token_limit_problems": [],
        "backtracking_problems": [],
        "level_distribution": Counter(),
        "type_distribution": Counter(),
        "level_accuracy": {},
        "type_accuracy": {}
    }

    # Per-level / per-type counts of correct answers; the totals are the
    # distributions stored in `analysis`.
    level_correct = Counter()
    type_correct = Counter()

    for result in results:
        problem_level = result.get("problem_level", "Unknown")
        problem_type = result.get("problem_type", "Unknown")
        problem_id = result.get("problem_id", "Unknown")

        analysis["level_distribution"][problem_level] += 1
        analysis["type_distribution"][problem_type] += 1

        # `or ""` also covers records that store an explicit null.
        generated_cot = result.get("generated_cot", "") or ""
        ground_truth = result.get("ground_truth_solution", "") or ""

        # 1. Correctness: compare normalized, non-empty boxed answers.
        generated_answers = [
            answer
            for answer in map(_normalize_answer, _extract_boxed_answers(generated_cot))
            if answer
        ]
        ground_truth_answers = {
            answer
            for answer in map(_normalize_answer, _extract_boxed_answers(ground_truth))
            if answer
        }
        is_correct = any(answer in ground_truth_answers for answer in generated_answers)

        if is_correct:
            analysis["correct_answers"] += 1
            level_correct[problem_level] += 1
            type_correct[problem_type] += 1

        # 2. Check for </think> close tags
        if "</think>" in generated_cot:
            analysis["has_think_close_tag"] += 1

        # 3. Ran out of tokens (heuristic): no non-empty boxed answer at all.
        if not generated_answers:
            analysis["ran_out_of_tokens"] += 1
            analysis["token_limit_problems"].append({
                "id": problem_id,
                "level": problem_level,
                "type": problem_type
            })

        # 4. Backtracking. A distinct local name avoids shadowing the
        #    module-level `backtracking_phrases` list.
        found_phrases = identify_backtracking(generated_cot)
        if found_phrases:
            analysis["has_backtracking"] += 1
            analysis["backtracking_problems"].append({
                "id": problem_id,
                "level": problem_level,
                "type": problem_type,
                "phrases": found_phrases,
                "correct_after_backtracking": is_correct
            })

    # Accuracy by level and type (every counted key has a count of at least 1).
    for level, count in analysis["level_distribution"].items():
        analysis["level_accuracy"][level] = level_correct[level] / count

    for problem_type, count in analysis["type_distribution"].items():
        analysis["type_accuracy"][problem_type] = type_correct[problem_type] / count

    # Percentages for easier interpretation; 0 when there are no problems.
    total = analysis["total_problems"]

    def percent_of_total(count):
        return (count / total) * 100 if total > 0 else 0

    analysis["percent_correct"] = percent_of_total(analysis["correct_answers"])
    analysis["percent_think_close"] = percent_of_total(analysis["has_think_close_tag"])
    analysis["percent_token_limit"] = percent_of_total(analysis["ran_out_of_tokens"])
    analysis["percent_backtracking"] = percent_of_total(analysis["has_backtracking"])

    return analysis

In [74]:
def print_analysis_report(analysis, max_examples=None):
    """
    Print a formatted report of the analysis results.

    Args:
        analysis: Dictionary with analysis results, as returned by
            analyze_cot_results
        max_examples: Optional cap on how many per-problem entries are listed
            in the token-limit and backtracking sections. ``None`` (the
            default) lists every entry; when a cap hides entries, a trailing
            line states how many were omitted.
    """
    def listed(entries):
        # Entries that will actually be printed under the current cap.
        if max_examples is None:
            return entries
        return entries[:max(max_examples, 0)]

    def print_omitted(entries):
        # Make truncation explicit so the list is not mistaken for complete.
        hidden = len(entries) - len(listed(entries))
        if hidden > 0:
            print(f"  ... and {hidden} more")

    divider = "=" * 50

    print("\n" + divider)
    print("CHAIN-OF-THOUGHT ANALYSIS REPORT")
    print(divider)

    print(f"\nTotal problems analyzed: {analysis['total_problems']}")

    print("\n1. CORRECTNESS")
    print(f"Correct answers: {analysis['correct_answers']} ({analysis['percent_correct']:.2f}%)")

    print("\n2. THINK TAGS")
    print(f"Solutions with </think> close tags: {analysis['has_think_close_tag']} ({analysis['percent_think_close']:.2f}%)")

    print("\n3. TOKEN LIMITS")
    print(f"Problems that ran out of tokens: {analysis['ran_out_of_tokens']} ({analysis['percent_token_limit']:.2f}%)")
    token_limit_problems = analysis['token_limit_problems']
    if token_limit_problems:
        print("Sample of problems that ran out of tokens:")
        for i, problem in enumerate(listed(token_limit_problems)):
            print(f"  {i+1}. Level: {problem['level']}, Type: {problem['type']}")
        print_omitted(token_limit_problems)

    print("\n4. BACKTRACKING")
    print(f"Solutions with backtracking: {analysis['has_backtracking']} ({analysis['percent_backtracking']:.2f}%)")
    backtracking_problems = analysis['backtracking_problems']
    if backtracking_problems:
        print("Sample of problems with backtracking:")
        for i, problem in enumerate(listed(backtracking_problems)):
            print(f"  {i+1}. Level: {problem['level']}, Type: {problem['type']}")
            print(f"     Phrases: {', '.join(problem['phrases'])}")
            print(f"     Correct after backtracking: {problem['correct_after_backtracking']}")
        print_omitted(backtracking_problems)

    print("\n5. PERFORMANCE BY LEVEL")
    for level, accuracy in sorted(analysis['level_accuracy'].items()):
        count = analysis['level_distribution'][level]
        print(f"  {level}: {accuracy*100:.2f}% correct ({count} problems)")

    print("\n6. PERFORMANCE BY TYPE")
    for problem_type, accuracy in sorted(analysis['type_accuracy'].items()):
        count = analysis['type_distribution'][problem_type]
        print(f"  {problem_type}: {accuracy*100:.2f}% correct ({count} problems)")

    print("\n" + divider)

In [75]:
def run_analysis(json_file_path):
    """
    Analyze a saved CoT results file, print the formatted report, and hand
    the analysis back to the caller.

    Args:
        json_file_path: Path to the JSON file with CoT results

    Returns:
        The analysis dictionary produced by analyze_cot_results
    """
    report_data = analyze_cot_results(json_file_path)
    print_analysis_report(report_data)
    return report_data

In [76]:
# Run the full analysis on the saved generations and keep the result in
# `analysis` for the cells below. The file name presumably encodes the
# sampling settings (temperature, max new tokens, top-p) -- TODO confirm.
results_json_path = "math_cot_results_t=0.4_mnt=1500_tp=0.92.json"
analysis = run_analysis(results_json_path)


Analyzing 41 CoT solutions...

CHAIN-OF-THOUGHT ANALYSIS REPORT

Total problems analyzed: 41

1. CORRECTNESS
Correct answers: 13 (31.71%)

2. THINK TAGS
Solutions with </think> close tags: 20 (48.78%)

3. TOKEN LIMITS
Problems that ran out of tokens: 22 (53.66%)
Sample of problems that ran out of tokens:
  1. Level: Level 5, Type: Precalculus
  2. Level: Level 5, Type: Counting & Probability
  3. Level: Level 5, Type: Counting & Probability
  4. Level: Level 3, Type: Intermediate Algebra
  5. Level: Level 5, Type: Intermediate Algebra
  6. Level: Level 4, Type: Geometry
  7. Level: Level 2, Type: Intermediate Algebra
  8. Level: Level 5, Type: Counting & Probability
  9. Level: Level 5, Type: Algebra
  10. Level: Level 5, Type: Precalculus
  11. Level: Level 4, Type: Algebra
  12. Level: Level 4, Type: Intermediate Algebra
  13. Level: Level 5, Type: Counting & Probability
  14. Level: Level 5, Type: Counting & Probability
  15. Level: Level 3, Type: Counting & Probability
  16. Level:

In [30]:
# Load the raw generation records so the full text of each flagged problem
# can be shown next to its analysis summary.
with open("math_cot_results_t=0.4_mnt=1500_tp=0.92.json", 'r') as f:
    results = json.load(f)

# Index the records by problem id once instead of scanning the whole list for
# every flagged problem. `.get(..., "Unknown")` mirrors how the analysis
# derives ids, so a record without an id no longer raises KeyError, and
# `setdefault` keeps the first record for an id, as the scan did.
results_by_id = {}
for record in results:
    results_by_id.setdefault(record.get("problem_id", "Unknown"), record)

# Get the backtracking problems (requires `analysis` from the analysis cell)
backtracking_problems = analysis['backtracking_problems']
print(backtracking_problems)

# Print the answers
print("Backtracking Problems:")
for i, problem in enumerate(backtracking_problems):
    print(f"\nProblem {i+1}:")

    # Get the full problem details from results using the problem id
    problem_id = problem['id']
    full_problem = results_by_id.get(problem_id)

    print(f"  Level: {problem['level']}")
    print(f"  Type: {problem['type']}")
    print(f"  Phrases: {', '.join(problem['phrases'])}")
    print(f"  Correct after backtracking: {problem['correct_after_backtracking']}")

    if full_problem:
        print(f"  Problem Text: {full_problem.get('problem_text', 'N/A')}")
        print("  Solution:")
        print(f"{full_problem.get('ground_truth_solution', 'N/A')}")
        print("  Generated Solution:")
        print(f"{full_problem.get('generated_cot', 'N/A')}")
    else:
        print(f"  Problem details not found for id: {problem_id}")

[{'id': 6, 'level': 'Level 3', 'type': 'Geometry', 'phrases': ['I made a mistake'], 'correct_after_backtracking': True}]
Backtracking Problems:

Problem 1:
  Level: Level 3
  Type: Geometry
  Phrases: I made a mistake
  Correct after backtracking: True
  Problem Text: When plotted in the standard rectangular coordinate system, trapezoid $ABCD$ has vertices $A(1, -2)$, $B(1, 1)$, $C(5, 7)$ and $D(5, 1)$. What is the area of trapezoid $ABCD$?
  Solution:
The two bases of the trapezoids are the segments $AB$ and $CD$, and the height is the perpendicular distance between the bases, which in this case is the difference of the $x$-coordinates: $5 - 1 = 4$. Similarly, the lengths of the bases are the differences of the $y$-coordinates of their two endpoints. Using the formula $A = \frac{1}{2}(b_1+ b_2)h$, the area is $\frac{1}{2}(3+6)(4) = \boxed{18}$ square units.
  Generated Solution:
Solve this math problem step by step. Put your final answer in \boxed{}. Problem: When plotted in the stand

In [79]:
def _find_phrase_token_spans(str_tokens, phrases):
    """
    Locate every occurrence of each phrase in a tokenized text.

    The search runs on the lower-cased concatenation of the token strings and
    the character offsets of each hit are mapped back to token indices. This
    avoids tokenizing the phrases on their own, which would not reproduce the
    tokens they get mid-sentence (leading space, capitalization, prepended
    BOS token).

    Args:
        str_tokens: Flat list of token strings for one text
        phrases: Iterable of phrases to look for (matched case-insensitively)

    Returns:
        Sorted list of unique ``(start_token, end_token)`` index pairs, where
        ``start_token`` is the token containing the first character of the
        phrase and ``end_token`` is exclusive.
    """
    lowered_tokens = [token.lower() for token in str_tokens]
    # token_starts[i] is the character offset at which token i begins.
    token_starts = np.cumsum([0] + [len(token) for token in lowered_tokens[:-1]])
    joined_text = "".join(lowered_tokens)

    spans = set()
    # dict.fromkeys drops repeated phrases while keeping their order.
    for phrase in dict.fromkeys(phrases):
        needle = phrase.lower()
        if not needle:
            continue
        char_start = joined_text.find(needle)
        while char_start != -1:
            char_end = char_start + len(needle)
            # Last token starting at or before the first character ...
            start_token = int(np.searchsorted(token_starts, char_start, side="right")) - 1
            # ... and first token starting at or after the end of the phrase.
            end_token = int(np.searchsorted(token_starts, char_end, side="left"))
            spans.add((start_token, end_token))
            char_start = joined_text.find(needle, char_start + 1)
    return sorted(spans)


def identify_backtracking_neurons_improved(model, json_file_path, device, top_k=50):
    """
    Identify neurons that activate during backtracking events by processing entire CoT solutions
    and tracking activations at specific backtracking points.

    For every sampled solution the whole text is run through the model once.
    The MLP ``post`` activation at the first token of each backtracking
    phrase is recorded as a positive example; control examples are positions
    at least 10 tokens away from any phrase (or random positions when the
    solution has no phrase at all). Neurons are then ranked per layer by the
    absolute effect size (Cohen's d) between the two groups.

    Args:
        model: The HookedTransformer model
        json_file_path: Path to the JSON file with CoT results
        device: The device to run inference on (currently unused; tokens are
            placed wherever ``model.to_tokens`` puts them)
        top_k: Number of top neurons to identify

    Returns:
        Dictionary with ``'top_neurons'`` (the ``top_k`` neurons across all
        layers) and ``'layer_scores'`` (every scored neuron per layer, sorted
        by absolute effect size). Layers lacking examples of either class are
        omitted, so both may be empty.
    """
    # Load the results
    with open(json_file_path, 'r') as f:
        results = json.load(f)

    n_layers = model.cfg.n_layers

    # Activation vectors per layer, one entry per sampled token position.
    backtracking_activations = {layer: [] for layer in range(n_layers)}
    non_backtracking_activations = {layer: [] for layer in range(n_layers)}

    def record_position(store, cache, position):
        # Save the MLP post-activation vector of every layer at `position`.
        for layer in range(n_layers):
            store[layer].append(cache["post", layer][0, position].detach().cpu().numpy())

    # Process a subset of examples for efficiency
    sample_size = min(100, len(results))
    sampled_results = random.sample(results, sample_size)

    print(f"Processing {sample_size} examples to identify backtracking neurons...")

    for result in tqdm(sampled_results):
        generated_cot = result.get("generated_cot", "")

        # Skip if the generated CoT is empty
        if not generated_cot:
            continue

        # Process the entire CoT solution
        tokens = model.to_tokens(generated_cot)
        # For a single string this is a flat list of token strings.
        str_tokens = model.to_str_tokens(generated_cot)
        n_tokens = len(tokens[0])

        # Only the MLP post activations are read below, so cache nothing else
        # and skip gradient tracking to keep memory use down on long CoTs.
        with torch.no_grad():
            _, cache = model.run_with_cache(
                tokens,
                names_filter=lambda name: name.endswith("mlp.hook_post"),
            )

        # Token spans of every backtracking phrase occurrence.
        backtracking_positions = [
            (start, end)
            for start, end in _find_phrase_token_spans(str_tokens, backtracking_phrases)
            if start < n_tokens
        ]

        if not backtracking_positions:
            # No backtracking: sample random control positions
            # (avoiding the beginning and end).
            if n_tokens > 20:
                num_samples = min(5, n_tokens - 10)
                for _ in range(num_samples):
                    pos = random.randint(5, n_tokens - 5)
                    record_position(non_backtracking_activations, cache, pos)
            continue

        # Positive examples: the token where each backtracking phrase starts.
        for start_pos, _ in backtracking_positions:
            record_position(backtracking_activations, cache, start_pos)

        # Control examples from the same solution: positions at least 10
        # tokens away from every backtracking phrase. Computed and sampled
        # once per solution so the controls match the positives in number.
        safe_positions = [
            pos
            for pos in range(5, n_tokens - 5)
            if all(pos < bt_start - 10 or pos > bt_end + 10
                   for bt_start, bt_end in backtracking_positions)
        ]
        num_samples = min(len(backtracking_positions), len(safe_positions))
        for pos in random.sample(safe_positions, num_samples):
            record_position(non_backtracking_activations, cache, pos)

    # Analyze activations to find neurons that correlate with backtracking
    neuron_scores = {}

    for layer in range(n_layers):
        # np.vstack raises on an empty list, so check before stacking.
        if not backtracking_activations[layer] or not non_backtracking_activations[layer]:
            continue

        layer_backtracking = np.vstack(backtracking_activations[layer])
        layer_non_backtracking = np.vstack(non_backtracking_activations[layer])

        # Per-neuron statistics, computed for all neurons at once.
        mean_diff = layer_backtracking.mean(axis=0) - layer_non_backtracking.mean(axis=0)
        # Effect size (Cohen's d) with the two variances averaged.
        pooled_std = np.sqrt(
            (layer_backtracking.var(axis=0) + layer_non_backtracking.var(axis=0)) / 2
        )
        effect_size = mean_diff / (pooled_std + 1e-10)

        # Dataset for the per-neuron AUC calculation
        stacked = np.vstack([layer_backtracking, layer_non_backtracking])
        labels = np.concatenate([
            np.ones(len(layer_backtracking)),
            np.zeros(len(layer_non_backtracking))
        ])

        layer_scores = []
        for neuron_idx in range(stacked.shape[1]):
            try:
                auc = roc_auc_score(labels, stacked[:, neuron_idx])
            except ValueError:
                auc = 0.5  # Default if the score is undefined for this neuron

            layer_scores.append({
                'neuron': neuron_idx,
                'mean_diff': float(mean_diff[neuron_idx]),
                'effect_size': float(effect_size[neuron_idx]),
                'auc': auc
            })

        # Sort neurons by effect size
        neuron_scores[layer] = sorted(layer_scores,
                                      key=lambda x: abs(x['effect_size']),
                                      reverse=True)

    # Identify top neurons across all layers
    all_neurons = []
    for layer, neurons in neuron_scores.items():
        for neuron in neurons[:top_k]:
            all_neurons.append({
                'layer': layer,
                'neuron': neuron['neuron'],
                'effect_size': neuron['effect_size'],
                'auc': neuron['auc']
            })

    # Sort by absolute effect size
    all_neurons = sorted(all_neurons, key=lambda x: abs(x['effect_size']), reverse=True)

    return {
        'top_neurons': all_neurons[:top_k],
        'layer_scores': neuron_scores
    }
    
def identify_backtracking_neurons(model, json_file_path, device, top_k=50):
    """
    Identify neurons that activate during backtracking events.

    For each sampled solution, a window of 50 characters on either side of
    the first occurrence of every detected backtracking phrase is run through
    the model, and the MLP ``post`` activations at all token positions of that
    window are treated as backtracking examples. Solutions without any phrase
    contribute one random 100-character window (or the whole text when it is
    shorter) as control examples. Neurons are ranked per layer by the
    absolute effect size (Cohen's d) between the two groups.

    Args:
        model: The HookedTransformer model
        json_file_path: Path to the JSON file with CoT results
        device: The device to run inference on (currently unused; tokens are
            placed wherever ``model.to_tokens`` puts them)
        top_k: Number of top neurons to identify

    Returns:
        Dictionary with ``'top_neurons'`` (the ``top_k`` neurons across all
        layers) and ``'layer_scores'`` (every scored neuron per layer, sorted
        by absolute effect size). Layers lacking examples of either class are
        omitted, so both may be empty.
    """
    # Load the results
    with open(json_file_path, 'r') as f:
        results = json.load(f)

    n_layers = model.cfg.n_layers
    context_window = 50  # characters before and after the point of interest

    # Activation matrices per layer, one (positions, neurons) entry per context.
    backtracking_activations = {layer: [] for layer in range(n_layers)}
    non_backtracking_activations = {layer: [] for layer in range(n_layers)}

    def collect_context_activations(context, store):
        # Run `context` through the model and store, for every layer, the MLP
        # post activations flattened to (positions, neurons).
        context_tokens = model.to_tokens(context)
        # Only the MLP post activations are read, so cache nothing else and
        # skip gradient tracking.
        with torch.no_grad():
            _, cache = model.run_with_cache(
                context_tokens,
                names_filter=lambda name: name.endswith("mlp.hook_post"),
            )
        for layer in range(n_layers):
            layer_activations = cache["post", layer].detach().cpu().numpy()
            store[layer].append(
                layer_activations.reshape(-1, layer_activations.shape[-1])
            )

    # Process a subset of examples for efficiency
    sample_size = min(100, len(results))
    sampled_results = random.sample(results, sample_size)

    print(f"Processing {sample_size} examples to identify backtracking neurons...")

    for result in tqdm(sampled_results):
        generated_cot = result.get("generated_cot", "")

        # Skip if the generated CoT is empty
        if not generated_cot:
            continue

        backtracking_instances = identify_backtracking(generated_cot)

        if not backtracking_instances:
            # No backtracking: use this solution as a control example,
            # provided it is long enough in tokens.
            tokens = model.to_tokens(generated_cot)
            if len(tokens[0]) <= 20:
                continue

            text_length = len(generated_cot)
            if text_length > context_window * 2:
                # Random window of the same width as the backtracking windows
                center_idx = random.randint(context_window, text_length - context_window)
                context = generated_cot[center_idx - context_window:center_idx + context_window]
            else:
                # If text is too short, use the whole text
                context = generated_cot

            collect_context_activations(context, non_backtracking_activations)
            continue

        # For each backtracking phrase, take the context around its first
        # occurrence (case-insensitive search).
        generated_cot_lower = generated_cot.lower()
        for phrase in backtracking_instances:
            start_idx = generated_cot_lower.find(phrase.lower())
            if start_idx == -1:
                continue
            end_idx = start_idx + len(phrase)

            context_start = max(0, start_idx - context_window)
            context_end = min(len(generated_cot), end_idx + context_window)
            context = generated_cot[context_start:context_end]

            collect_context_activations(context, backtracking_activations)

    # Analyze activations to find neurons that correlate with backtracking
    neuron_scores = {}

    for layer in range(n_layers):
        # np.vstack raises on an empty list, so check before stacking.
        if not backtracking_activations[layer] or not non_backtracking_activations[layer]:
            continue

        layer_backtracking = np.vstack(backtracking_activations[layer])
        layer_non_backtracking = np.vstack(non_backtracking_activations[layer])

        # Create dataset
        X = np.vstack([layer_backtracking, layer_non_backtracking])
        y = np.concatenate([
            np.ones(len(layer_backtracking)),
            np.zeros(len(layer_non_backtracking))
        ])

        # Per-neuron statistics, computed for all neurons at once.
        mean_diff = layer_backtracking.mean(axis=0) - layer_non_backtracking.mean(axis=0)
        # Effect size (Cohen's d) with the two variances averaged.
        pooled_std = np.sqrt(
            (layer_backtracking.var(axis=0) + layer_non_backtracking.var(axis=0)) / 2
        )
        effect_size = mean_diff / (pooled_std + 1e-10)

        layer_scores = []
        for neuron_idx in range(X.shape[1]):
            try:
                auc = roc_auc_score(y, X[:, neuron_idx])
            except ValueError:
                auc = 0.5  # Default if the score is undefined for this neuron

            layer_scores.append({
                'neuron': neuron_idx,
                'mean_diff': float(mean_diff[neuron_idx]),
                'effect_size': float(effect_size[neuron_idx]),
                'auc': auc
            })

        # Sort neurons by effect size
        neuron_scores[layer] = sorted(layer_scores,
                                      key=lambda x: abs(x['effect_size']),
                                      reverse=True)

    # Identify top neurons across all layers
    all_neurons = []
    for layer, neurons in neuron_scores.items():
        for neuron in neurons[:top_k]:
            all_neurons.append({
                'layer': layer,
                'neuron': neuron['neuron'],
                'effect_size': neuron['effect_size'],
                'auc': neuron['auc']
            })

    # Sort by absolute effect size
    all_neurons = sorted(all_neurons, key=lambda x: abs(x['effect_size']), reverse=True)

    return {
        'top_neurons': all_neurons[:top_k],
        'layer_scores': neuron_scores
    }

def validate_backtracking_neurons(model, top_neurons, device, num_examples=10):
    """
    Validate the identified backtracking neurons by testing them on new examples.

    For each sampled problem a chain-of-thought is generated with the model. When
    `identify_backtracking` reports at least one phrase, a window of 50 characters on
    each side of the first occurrence of the first phrase is run through the model and
    the mean "post" activation of each of the first 10 entries of `top_neurons` is
    recorded.

    Args:
        model: The HookedTransformer model
        top_neurons: List of dicts with at least the keys 'layer' and 'neuron'
        device: The device to run inference on (currently unused; kept so existing
            callers keep working)
        num_examples: Number of problems to sample for validation

    Returns:
        List of dicts, one per problem in which a backtracking phrase was located, with
        the keys 'problem', 'solution', 'backtracking_phrase' and 'neuron_activations'.
        Problems without a detected phrase are skipped, so the list may be shorter than
        `num_examples`.
    """
    # Function to sample problems from the dataset
    def sample_math_problems(dataset, n=5, level=None, problem_type=None):
        """
        Sample n problems from the dataset, optionally filtering by level or type.

        Args:
            dataset: The MATH dataset (its 'train' split is used)
            n: Number of problems to sample
            level: Optional filter for problem difficulty (e.g., "Level 1")
            problem_type: Optional filter for problem type (e.g., "Algebra")

        Returns:
            List of sampled problems (at most n, fewer if the filtered pool is smaller)
        """
        filtered_dataset = dataset['train']

        if level:
            filtered_dataset = [ex for ex in filtered_dataset if ex['level'] == level]

        if problem_type:
            filtered_dataset = [ex for ex in filtered_dataset if ex['type'] == problem_type]

        filtered_dataset = list(filtered_dataset)  # Convert to list to ensure it's a sequence
        return random.sample(filtered_dataset, min(n, len(filtered_dataset)))

    # Function to generate CoT using the model
    def generate_cot_for_problem(
        model: HookedTransformer,
        problem: str,
        temperature: float = 0.4,
        max_new_tokens: int = 1500,
        top_p: float = 0.92
    ):
        """
        Generate a chain-of-thought solution for a given math problem.

        Args:
            model: The HookedTransformer model
            problem: The math problem text
            temperature: The temperature for the model
            max_new_tokens: The maximum number of tokens to generate
            top_p: The top-p value for the model
        Returns:
            Whatever `model.generate` returns for the prompt (used as a string below)
        """
        prompt = f"""Solve this math problem step by step. Put your final answer in \\boxed{{}}. Problem: {problem} Solution: \n<think>\n"""
        return model.generate(prompt,
                              temperature=temperature,
                              max_new_tokens=max_new_tokens,
                              top_p=top_p)

    # Load the MATH dataset for validation
    math_dataset = load_dataset("fdyrd/math")
    validation_problems = sample_math_problems(math_dataset, n=num_examples)

    neurons_to_check = top_neurons[:10]  # Check top 10 neurons
    context_window = 50  # characters before and after the phrase

    validation_results = []

    for problem in tqdm(validation_problems, desc="Validating neurons"):
        problem_text = problem['problem']

        # Generate a solution that may contain backtracking
        solution = generate_cot_for_problem(model, problem_text)

        # Identify backtracking instances; `not` covers both None and an empty result
        backtracking_instances = identify_backtracking(solution)
        if not backtracking_instances:
            continue

        # Only the first reported phrase is analysed
        phrase = backtracking_instances[0]
        start_idx = solution.lower().find(phrase.lower())
        if start_idx == -1:
            continue
        end_idx = start_idx + len(phrase)

        # Extract context around the backtracking phrase
        context_start = max(0, start_idx - context_window)
        context_end = min(len(solution), end_idx + context_window)
        context = solution[context_start:context_end]

        # Convert context to tokens
        # NOTE(review): if to_tokens prepends a BOS token, that position is included in
        # the mean below — confirm whether that is intended.
        tokens = model.to_tokens(context)

        # Only detached activations are read, so skip autograd bookkeeping
        with torch.no_grad():
            _, cache = model.run_with_cache(tokens)

        # Convert each needed layer to numpy once, not once per neuron
        layer_arrays = {}
        neuron_activations = []

        for neuron_info in neurons_to_check:
            layer = neuron_info['layer']
            neuron = neuron_info['neuron']

            if layer not in layer_arrays:
                layer_arrays[layer] = cache["post", layer].detach().cpu().numpy()

            # Mean over the second axis for the first batch entry
            mean_activation = np.mean(layer_arrays[layer][0, :, neuron])

            neuron_activations.append({
                'layer': layer,
                'neuron': neuron,
                'activation': float(mean_activation),
                'context': context,
                'backtracking_phrase': phrase
            })

        validation_results.append({
            'problem': problem_text,
            'solution': solution,
            'backtracking_phrase': phrase,
            'neuron_activations': neuron_activations
        })

    return validation_results

def visualize_neuron_activations(model, neuron_info, examples, device):
    """
    Visualize the activations of a specific neuron across different examples.

    One bar chart is drawn per example, with one bar per token position. Tick labels of
    tokens whose activation exceeds mean + one standard deviation (computed within that
    example) are drawn in bold red.

    Args:
        model: The HookedTransformer model
        neuron_info: Dictionary with the keys 'layer' and 'neuron'
        examples: Non-empty list of text examples to analyze
        device: The device to run inference on (currently unused; kept so existing
            callers keep working)

    Returns:
        Matplotlib figure with one subplot row per example

    Raises:
        ValueError: If `examples` is empty (a figure needs at least one row)
    """
    if len(examples) == 0:
        raise ValueError("examples must contain at least one text to visualize")

    layer = neuron_info['layer']
    neuron = neuron_info['neuron']

    activations_by_example = []

    for example in examples:
        # Get tokens and their string form for the tick labels
        tokens = model.to_tokens(example)
        str_tokens = model.to_str_tokens(example)

        # Only detached activations are read, so skip autograd bookkeeping
        with torch.no_grad():
            _, cache = model.run_with_cache(tokens)

        # First batch entry, all positions, the selected neuron
        layer_activations = cache["post", layer][0, :, neuron].detach().cpu().numpy()

        activations_by_example.append((str_tokens, layer_activations))

    # squeeze=False always yields a 2-D axes array, so one example needs no special case
    fig, axes = plt.subplots(len(examples), 1,
                             figsize=(15, 4 * len(examples)),
                             squeeze=False)

    for i, (tokens, activations) in enumerate(activations_by_example):
        ax = axes[i, 0]

        # Plot activations
        ax.bar(range(len(activations)), activations)

        # Add token labels
        ax.set_xticks(range(len(tokens)))
        ax.set_xticklabels(tokens, rotation=45, ha='right')

        # Highlight tokens with high activation
        threshold = np.mean(activations) + np.std(activations)
        tick_labels = ax.get_xticklabels()  # fetch once instead of once per token
        for j in np.flatnonzero(activations > threshold):
            tick_labels[j].set_color('red')
            tick_labels[j].set_weight('bold')

        ax.set_title(f"Example {i+1}: Neuron {neuron} in Layer {layer}")
        ax.set_ylabel("Activation")

    plt.tight_layout()
    return fig

def ablate_neurons_and_test(model, top_neurons, test_problems, device):
    """
    Ablate (zero out) the identified neurons and test the effect on backtracking.

    For every problem the same prompt is sampled twice with identical sampling settings:
    once from the unmodified model and once with the first 20 entries of `top_neurons`
    set to zero at their layer's MLP "post" hook point. Backtracking phrases in both
    generations are counted with `identify_backtracking_enhanced`.

    NOTE(review): generation is sampled (temperature 0.4), so differences between the two
    runs include sampling noise in addition to the ablation effect.

    Args:
        model: The HookedTransformer model
        top_neurons: List of dicts with at least the keys 'layer' and 'neuron'
        test_problems: List of test problems (dicts with a 'problem' key)
        device: The device to run inference on (currently unused; kept so existing
            callers keep working)

    Returns:
        Dictionary with the keys 'original' and 'ablated' (per-problem result lists) and
        'summary' (total counts and the percent change relative to the original total)
    """
    def make_ablation_hook(neuron_indices):
        """Build a hook that zeroes `neuron_indices` along the last activation axis."""
        def ablation_hook(activations, hook=None):
            # Ellipsis covers every leading axis, so all batch entries and positions
            # are ablated, not just batch index 0.
            activations[..., neuron_indices] = 0.0
            return activations
        return ablation_hook

    def generate_solution(prompt):
        """Sample one solution with the settings shared by both conditions."""
        return model.generate(prompt,
                              temperature=0.4,
                              max_new_tokens=500,
                              top_p=0.92)

    # Group the top 20 neurons by layer so each hook only touches its own layer
    neurons_by_layer = {}
    for n in top_neurons[:20]:  # Ablate top 20
        neurons_by_layer.setdefault(n['layer'], []).append(int(n['neuron']))

    # The MLP activation read elsewhere via cache["post", layer] lives at
    # "blocks.{layer}.mlp.hook_post"; without the ".mlp" segment the name matches no
    # hook point and the ablation cannot be installed.
    hooks = [
        (f"blocks.{layer}.mlp.hook_post", make_ablation_hook(indices))
        for layer, indices in neurons_by_layer.items()
    ]

    ablation_results = {
        'original': [],
        'ablated': []
    }

    for problem in tqdm(test_problems, desc="Testing ablation"):
        problem_text = problem['problem']

        # Generate solution without ablation
        original_prompt = f"Solve this math problem step by step. Put your final answer in \\boxed{{}}. Problem: {problem_text} Solution: \n<think>\n"
        original_solution = generate_solution(original_prompt)

        # Count backtracking instances in original
        original_backtracking = identify_backtracking_enhanced(original_solution)

        # Generate with the ablation hooks active only for this call
        with model.hooks(fwd_hooks=hooks):
            ablated_solution = generate_solution(original_prompt)

        # Count backtracking instances in ablated
        ablated_backtracking = identify_backtracking_enhanced(ablated_solution)

        ablation_results['original'].append({
            'problem': problem_text,
            'solution': original_solution,
            'backtracking_count': len(original_backtracking),
            'backtracking_instances': original_backtracking
        })

        ablation_results['ablated'].append({
            'problem': problem_text,
            'solution': ablated_solution,
            'backtracking_count': len(ablated_backtracking),
            'backtracking_instances': ablated_backtracking
        })

    # Calculate summary statistics
    original_backtracking_count = sum(r['backtracking_count'] for r in ablation_results['original'])
    ablated_backtracking_count = sum(r['backtracking_count'] for r in ablation_results['ablated'])

    ablation_results['summary'] = {
        'original_backtracking_total': original_backtracking_count,
        'ablated_backtracking_total': ablated_backtracking_count,
        # max(1, ...) avoids division by zero when the original runs never backtrack
        'percent_change': ((ablated_backtracking_count - original_backtracking_count) /
                          max(1, original_backtracking_count)) * 100
    }

    return ablation_results

In [80]:
# Run the neuron search over the saved chain-of-thought generations
neuron_analysis = identify_backtracking_neurons_improved(
    top_k=50,
    device=device,
    json_file_path="math_cot_results_t=0.4_mnt=1500_tp=0.92.json",
    model=model,
)

# Report the ten highest-ranked neurons, one per line
print("Top 10 neurons associated with backtracking:")
for i, neuron in enumerate(neuron_analysis['top_neurons'][:10]):
    rank = i + 1
    layer_id = neuron['layer']
    neuron_id = neuron['neuron']
    effect = neuron['effect_size']
    auc_value = neuron['auc']
    print(f"{rank}. Layer {layer_id}, Neuron {neuron_id}: Effect size = {effect:.4f}, AUC = {auc_value:.4f}")

Processing 41 examples to identify backtracking neurons...


  5%|▍         | 2/41 [00:45<14:00, 21.54s/it]

In [68]:
# Check the identified neurons against freshly generated solutions
validation_results = validate_backtracking_neurons(
    model,
    neuron_analysis['top_neurons'],
    device,
    num_examples=10,
)

# Summarise each validated problem with its three leading neuron activations
print("\nValidation results:")
for result in validation_results:
    problem_preview = result['problem'][:100]
    found_phrase = result['backtracking_phrase']
    print(f"Problem: {problem_preview}...")
    print(f"Backtracking phrase: {found_phrase}")
    print("Top neuron activations:")
    for act in result['neuron_activations'][:3]:
        layer_id, neuron_id, value = act['layer'], act['neuron'], act['activation']
        print(f"  Layer {layer_id}, Neuron {neuron_id}: {value:.4f}")
    print()

Validating neurons:   0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Validating neurons:   0%|          | 0/10 [02:03<?, ?it/s]


KeyboardInterrupt: 