# LLM Arithmetic Analysis

This notebook is a playground for the experiments described in the paper, as well as (for reference) additional experiments that were not mentioned in the paper.

Most of the main experiments were later moved into separate scripts (the files named `script_*.py`).

## Imports and setup

In [1]:
# Imports and setup

%load_ext autoreload
%autoreload 2

from general_utils import set_deterministic, set_cuda_device
set_cuda_device(0)

import re
import random
import pickle
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import transformer_lens as lens
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.io as pio

from transformers import AutoModelForCausalLM
from plotly import graph_objects as go
from plotly.subplots import make_subplots
from tqdm import tqdm
from ipywidgets import widgets
from functools import partial
from itertools import chain

from prompt_generation import generate_prompts, separate_prompts_and_answers, generate_all_prompts_for_operator, POSITIONS, OPERATORS, OPERATOR_NAMES, _is_number, _get_operand_range
from attention_analysis import ov_transition_analysis, two_operands_arithmetic_qk_heatmap
from visualization_utils import imshow, line, scatter, visualize_arithmetic_attention_patterns, multiple_lines, scatter_with_labels
from evaluation_utils import model_accuracy_on_simple_prompts
from circuit_utils import topk_effective_components
from component import Component
from general_utils import generate_activations, set_deterministic, get_neuron_importance_scores, get_model_consts, safe_eval, load_model, reduce_dimensionality
from linear_probing import linear_probe_across_layers
from metrics import indirect_effect
from eap.attr_patching import node_attribution_patching
from eap.eap_wrapper import EAP
from circuit import Circuit
from evaluation_utils import model_accuracy, circuit_faithfulness_with_mean_ablation
from activation_patching import activation_patching_experiment
from model_analysis_consts import LLAMA3_8B_CONSTS
from heuristics_classification import HeuristicAnalysisData, classify_heuristic_neurons, load_heuristic_classes
from heuristics_analysis import get_relevant_prompts, get_neurons_associated_with_prompt, heuristic_class_knockout_experiment, prompt_knockout_experiment, is_associated_heuristic


# Colorblind-friendly palette (hex colors) for the plotting cells.
COLORBLIND_COLORS = ['#0173b2', '#de8f05', '#029e73','#d55e00', '#cc78bc', '#ca9161', '#fbafe4', '#949494', '#ece133', '#56b4e9']
# Disable autograd globally; the notebook mostly runs forward passes.
torch.set_grad_enabled(False)
device = 'cuda'
# NOTE(review): `seed` is a notebook global that later cells overwrite (e.g. `for seed in seeds`);
# cells that need a fixed seed call set_deterministic(...) explicitly.
seed = 42

In [None]:
model_name = "llama3-8b"
# Local snapshot of the HF weights. When set to None, the model is loaded by hub name instead.
model_path = "/mnt/nlp/models/models--meta-llama--Meta-Llama-3-8B/snapshots/cd892e8f4da1043d4b01d5ea182a2e8412bf658f"

if model_path is None:
    model = lens.HookedTransformer.from_pretrained("meta-llama/Meta-Llama-3-8B", fold_ln=True, center_unembed=True, center_writing_weights=True, device=device)
else:
    # Load the HF model from the local snapshot (no network access) and hand it to TransformerLens.
    model = lens.HookedTransformer.from_pretrained("meta-llama/Meta-Llama-3-8B", 
                                                hf_model=AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True),
                                                fold_ln=True, center_unembed=True, center_writing_weights=True, device=device)
# Expose split Q/K/V inputs and hook_mlp_in so the patching experiments can hook them.
model.set_use_split_qkv_input(True)
model.set_use_hook_mlp_in(True)
model.eval()

## Data preparation

In [3]:
# Generate correct prompts for the circuit discovery.
# Corrupt prompts are chosen later per experiment, and differ in operator and/or operands.
#
# Both prompt sets are cached as pickles under ./data/{model_name}/ and regenerated only when
# the cache file is missing. The shuffle that splits each operator's prompts into "correct"
# (first 50) and "evaluation" (the rest) is re-seeded right before it runs. This makes the split
# the same whether the prompts were just generated or loaded from the cache.
# NOTE: pickle.load is only safe because these are locally produced cache files.

max_op = 300
op_ranges = {'+': (0, max_op), '-': (0, max_op), '*': (0, max_op), '/': (1, max_op)}
analysis_prompts_file_path = fr'./data/{model_name}/large_prompts_and_answers_max_op={max_op}.pkl'
incorrect_prompts_file_path = fr'./data/{model_name}/large_incorrect_prompts_and_answers_max_op={max_op}.pkl'
# Make sure the cache directory exists before any pickle.dump below.
os.makedirs(os.path.dirname(analysis_prompts_file_path), exist_ok=True)

set_deterministic(42)

if os.path.exists(analysis_prompts_file_path):
    with open(analysis_prompts_file_path, 'rb') as f:
        large_prompts_and_answers = pickle.load(f)
else:
    large_prompts_and_answers = generate_prompts(model, operand_ranges=op_ranges, correct_prompts=True, num_prompts_per_operator=None, single_token_number_range=(0, LLAMA3_8B_CONSTS.max_single_token))
    with open(analysis_prompts_file_path, 'wb') as f:
        pickle.dump(large_prompts_and_answers, f)

# Re-seed so the shuffle does not depend on how much randomness generate_prompts consumed
# (if any) on a cache miss. On a cache hit this is equivalent to the seeding above.
set_deterministic(42)
for i in range(len(large_prompts_and_answers)):
    random.shuffle(large_prompts_and_answers[i])


correct_prompts_and_answers = [pa[:50] for pa in large_prompts_and_answers]
evaluation_prompts_and_answers = [pa[50:] for pa in large_prompts_and_answers]

# Same exists-check as above. The previous bare `except:` also swallowed unrelated errors
# (a corrupt pickle, KeyboardInterrupt, ...) and silently regenerated and overwrote the cache.
if os.path.exists(incorrect_prompts_file_path):
    with open(incorrect_prompts_file_path, 'rb') as f:
        incorrect_prompts_and_answers = pickle.load(f)
else:
    incorrect_prompts_and_answers = generate_prompts(model, op_ranges, validate_numerals=True, correct_prompts=False, num_prompts_per_operator=None, single_token_number_range=(0, LLAMA3_8B_CONSTS.max_single_token))
    with open(incorrect_prompts_file_path, 'wb') as f:
        pickle.dump(incorrect_prompts_and_answers, f)

## Model performance

In [None]:
# Baseline accuracy of the model on plain "x<op>y=" prompts, per operator.
# Division operands start at 1 (as in op_ranges) to avoid division by zero.
for operator in OPERATORS:
    min_op = 1 if operator == '/' else 0
    acc = model_accuracy_on_simple_prompts(model, min_op, max_op, (0, get_model_consts(model_name).max_single_token), [operator])
    print(f"The model accuracy on simple prompts with operator {operator} is: {acc :.3f}")

## Circuit Discovery

### Manual discovery code

In [None]:
# Initial, manual activation patching for separate attention heads + MLPs
# Fuller analysis is done in script_circuit_localization.py
#
# For every patched token position this builds an (n_layers, n_heads + 1) map of mean indirect
# effects. Columns 0..n_heads-1 hold single attention heads (patched at hook 'z'). The last
# column holds the layer's MLP (patched at 'mlp_post').


def head_hooking_func(value, hook, head_index, token_pos, cache):
    """Overwrite a single attention head's `z` activation with the value stored in `cache`.

    Args:
        value: 4-D activation at the hook, indexed here as [:, pos, head, :].
        hook: TransformerLens hook point; `hook.name` selects the cache entry.
        head_index: head to patch.
        token_pos: position to patch, or None to patch every position.
        cache: activation cache supplied by activation_patching_experiment (presumably the
            run being patched in -- confirm in activation_patching.py).

    Returns:
        `value`, modified in place.
    """
    if token_pos is None:
        value[:, :, head_index, :] = cache[hook.name][:, :, head_index, :]  # For z hooking
    else:
        value[:, token_pos, head_index, :] = cache[hook.name][:, token_pos, head_index, :]  # For z hooking
    return value


operator_idx = 0
seed = 42
correct_pa = correct_prompts_and_answers[operator_idx]
# Corrupt prompts are sampled from the correct prompts of all operators (flattened list).
corrupt_pa = random.sample(list(chain.from_iterable(correct_prompts_and_answers)), len(correct_pa))
for token_pos in [4, 3, 2, 1]:
    ie_over_layers_and_heads_and_mlp = torch.zeros((model.cfg.n_layers, model.cfg.n_heads + 1), dtype=torch.float32)

    # MLP
    ie_over_layers_and_heads_and_mlp[:, -1] = activation_patching_experiment(model, correct_pa, corrupt_prompts_and_answers=corrupt_pa, hookpoint_name='mlp_post',
                                                                            metric='IE-Logits',
                                                                            token_pos=token_pos,
                                                                            random_seed=seed).mean(dim=0)

    # Attention heads
    for head_idx in range(model.cfg.n_heads):
        head_hook_fn = partial(head_hooking_func, head_index=head_idx, token_pos=token_pos)
        ie_over_layers_and_heads_and_mlp[:, head_idx] = activation_patching_experiment(model, correct_pa, corrupt_prompts_and_answers=corrupt_pa, hookpoint_name='z',
                                                                                        metric='IE-Logits',
                                                                                        token_pos=token_pos,
                                                                                        hook_func_overload=head_hook_fn, random_seed=seed).mean(dim=0)

    # One figure per position. The position goes in the title so the figures can be told apart.
    imshow(ie_over_layers_and_heads_and_mlp[:, :], x=[str(i) for i in range(model.cfg.n_heads)] + ['mlp'], width=400, labels={'x': 'Attn Head Idx / MLP', 'y': 'Layer'}, title=f'IE of different attention heads and MLP (token position {token_pos})<br>')

### Linear probing

In [None]:
# Probe for answer in last token position

# Initial linear probing code. Meant to probe for answer token in the residual stream across layers and positions.
# Fuller code can be found in script_linear_probe.py
#
# Prompts are shuffled, and their activations computed, ONCE per operator. They are then reused
# for every probed position. Previously the prompts were reshuffled for each position while the
# activations stayed cached from the first shuffle. As a result, every position after the first
# was probed with labels that did not belong to the activation rows.

max_op = 300
max_answer_value = 1000

probe_accs = {}
components = [Component('resid_post', layer=i) for i in range(model.cfg.n_layers)]
for operator_idx in range(len(OPERATORS)):
    # Training/testing data for the linear probe, shared by all positions of this operator.
    correct_prompts = separate_prompts_and_answers(large_prompts_and_answers[operator_idx])[0]
    random.shuffle(correct_prompts)
    # prompt[:-1] drops the trailing '=' before evaluating the arithmetic expression.
    answers = torch.tensor([safe_eval(prompt[:-1]) for prompt in correct_prompts])
    # Activations for every layer at every position (pos=None), row-aligned with `answers`.
    activations = generate_activations(model, correct_prompts, components, pos=None)

    for pos_to_probe in POSITIONS:
        print(f"{pos_to_probe=}, {operator_idx=}")
        pos_activations = {i: activations[i][:, pos_to_probe] for i in range(model.cfg.n_layers)}
        probe_accs[(operator_idx, pos_to_probe)] = linear_probe_across_layers(model, pos_activations, answers, max_answer_value)[1]
        line(probe_accs[(operator_idx, pos_to_probe)], range_y=(0.0, 1.0), labels={'x': 'Layer', 'y': 'Test Accuracy'}, title=f'{OPERATOR_NAMES[operator_idx]} probing Accuracy Per Layer (At Position {pos_to_probe})')

## Circuit Evaluations

In [None]:
# Generate mean activation cache for all relevant components
#
# Maps each Component (head z/result outputs, MLP post/in) to its activation reduced over every
# "x<op>y=" prompt with x, y in [0, max_op) (reduce_mean=True). The result is moved to CPU and
# repeated 50 times along a new leading dim, presumably to match the 50-prompt batches passed to
# circuit_faithfulness_with_mean_ablation (confirm).
# NOTE(review): unlike op_ranges, the division prompts here include y=0 ("x/0=").
# NOTE(review): model.set_use_attn_result(True) is only called when the cache is regenerated, so
# the model's hook configuration after this cell depends on whether the file already existed.

max_op = 300
eval_mean_cache_path = f'./data/{model_name}/mean_cache_for_evaluation_all_arithmetic_prompts_max_op={max_op}.pt'
if os.path.exists(eval_mean_cache_path):
    cached_activations = torch.load(eval_mean_cache_path)
    print('Loaded cached activations from file')
else:
    all_heads = [(l, h) for h in range(model.cfg.n_heads) for l in range(model.cfg.n_layers)]
    all_mlps = list(range(model.cfg.n_layers))
    # Required so that the per-head 'result' hook points produce activations.
    model.set_use_attn_result(True)
    all_components = [Component('z', layer=l, head=h) for (l, h) in all_heads] + \
                     [Component('result', layer=l, head=h) for (l, h) in all_heads] + \
                     [Component('mlp_post', layer=l) for l in all_mlps] + \
                     [Component('mlp_in', layer=l) for l in all_mlps]
    all_prompts = [f"{x}{operator}{y}=" for operator in OPERATORS for x in range(0, max_op) for y in range(0, max_op)]
    cached_activations = generate_activations(model, all_prompts, all_components, pos=None, reduce_mean=True)
    cached_activations = {c: a[None, ...].to(device='cpu').repeat(50, 1, 1) for c, a in zip(all_components, cached_activations)}
    torch.save(cached_activations, eval_mean_cache_path)

In [None]:
# Initial code for manual evaluation of the discovered arithmetic circuit.
# Each circuit includes all MLPs and a subset of attention heads.


def build_circuit(operator_idx):
    """Build the manually discovered arithmetic circuit for one operator.

    The circuit contains every MLP (as 'mlp_post' components) plus a small, operator-specific
    set of attention heads (as 'z' components).

    Args:
        operator_idx: 0 = addition, 1 = subtraction, 2 = multiplication, 3 = division.

    Returns:
        Circuit: the assembled circuit.

    Raises:
        ValueError: for an unknown operator index. Previously this surfaced as an
            UnboundLocalError on `heads`.
    """
    # (layer, head) pairs of the discovered circuit, per operator index.
    circuit_heads_per_operator = {
        # Addition
        0: [(2, 2), (5, 3), (5, 31), (14, 12), (15, 13), (16, 21)],
        # Subtraction
        1: [(2, 2), (13, 21), (13, 22), (14, 12), (15, 13), (16, 21)],
        # Multiplication
        2: [(2, 2), (5, 30), (8, 15), (9, 26), (13, 18), (13, 21), (13, 22),
            (14, 12), (14, 13), (15, 8), (15, 13), (15, 14), (15, 15), (16, 3),
            (16, 21), (17, 24), (17, 26), (18, 16), (20, 2), (22, 1)],
        # Division
        3: [(2, 2), (5, 31), (15, 13), (15, 14), (16, 21), (18, 16)],
    }
    if operator_idx not in circuit_heads_per_operator:
        raise ValueError(f"Unknown operator index: {operator_idx}")

    heads = [Component('z', layer=l, head=h) for (l, h) in circuit_heads_per_operator[operator_idx]]
    full_mlps = [Component('mlp_post', layer=l) for l in range(model.cfg.n_layers)]
    full_circuit = Circuit(model.cfg)
    # Deduplicate while keeping a deterministic insertion order. Set iteration order is not
    # guaranteed to be stable across runs.
    for c in dict.fromkeys(heads + full_mlps):
        full_circuit.add_component(c)
    return full_circuit


operator_idx = 0
max_op = 300
# Mean-ablation cache written by the mean-cache cell (max_op=300).
mean_cache = torch.load(f'./data/{model_name}/mean_cache_for_evaluation_all_arithmetic_prompts_max_op={max_op}.pt')

avg_nl_acc = 0
# NOTE(review): avg_acc is never updated or read in this cell.
avg_acc = 0
seeds = [42, 412, 32879, 123, 436]
# Average the circuit's normalized-logit faithfulness over several random 50-prompt samples.
for seed in seeds:
    set_deterministic(seed)
    prompts_and_answers = random.sample(evaluation_prompts_and_answers[operator_idx], k=50)
    print(prompts_and_answers)

    full_circuit = build_circuit(operator_idx)
    nl_acc = circuit_faithfulness_with_mean_ablation(model, full_circuit, prompts_and_answers, mean_cache, metric='nl')
    avg_nl_acc += nl_acc
    print(f"Normalized Logit Acc (Seed {seed}): {nl_acc:.3f}")

avg_nl_acc = avg_nl_acc / len(seeds)
print(f"Avg Normalized Logit Acc: {avg_nl_acc:.3f}")

In [7]:
# Get the (pre-calculated) indirect effects of each component to 
# evaluate the circuit based on the amount of attention heads (by including only the highest effect attention heads)

def process_ie_maps(ie_maps, positions=None, n_operators=None):
    """Reduce raw activation-patching IE maps to one log-scaled map per operator.

    Args:
        ie_maps: dict keyed by (operator_idx, position, seed) -> tensor. Presumably each tensor
            is (n_layers, n_heads + 1), i.e. heads plus an MLP column; confirm against the
            producing script. The input dict and its tensors are not modified.
        positions: positions to average over. Defaults to POSITIONS.
        n_operators: number of operators. Defaults to len(OPERATORS).

    Returns:
        Stacked maps of shape (n_operators, *map_shape): log1p of the mean over seeds, then the
        mean over positions.
    """
    if positions is None:
        positions = POSITIONS
    if n_operators is None:
        n_operators = len(OPERATORS)

    # Average across seeds. Sums use out-of-place addition; the previous `+=` mutated the
    # caller's first-seed tensor. Each (op, pos) is divided by its own seed count.
    summed_seed_ie_maps = {}
    seed_counts = {}
    for (op_idx, pos, _seed), ie_map in ie_maps.items():
        key = (op_idx, pos)
        summed_seed_ie_maps[key] = summed_seed_ie_maps[key] + ie_map if key in summed_seed_ie_maps else ie_map
        seed_counts[key] = seed_counts.get(key, 0) + 1
    mean_over_seeds = {key: total / seed_counts[key] for key, total in summed_seed_ie_maps.items()}

    # Mean across positions, then stack operators -> (n_operators, *map_shape).
    per_operator = [torch.stack([mean_over_seeds[(op_idx, pos)] for pos in positions]).mean(dim=0)
                    for op_idx in range(n_operators)]
    stacked = torch.stack(per_operator)

    # Log scale. NOTE(review): log1p yields NaN for entries below -1 -- confirm IE values cannot be that negative.
    return np.log1p(stacked)

# Seed- and position-averaged, log1p-scaled IE maps; index with ie_maps[operator_idx].
ie_maps = process_ie_maps(torch.load(f'./data/{model_name}/ie_maps_activation_patching.pt'))
# Largest number of top-IE heads tried when sweeping the circuit size.
max_n_heads = 100

In [None]:
# Evaluate the faithfulness of the circuit with only some of the heads to understand if indeed a small number of sparse heads
# is enough to achieve high faithfulness to the model

# NOTE(review): `max_op` is a notebook global. This path expects the cache written with max_op=300;
# later cells reassign max_op (150, 50), so rerunning this cell after them loads a different file.
mean_cache = torch.load(f'./data/{model_name}/mean_cache_for_evaluation_all_arithmetic_prompts_max_op={max_op}.pt')

def build_circuit(operator_idx, n_heads):
    """Build a circuit of all MLPs plus the `n_heads` highest-IE attention heads.

    Heads are ranked by topk_effective_components on the global `ie_maps[operator_idx]`.

    Args:
        operator_idx: index into OPERATORS / ie_maps.
        n_heads: number of top-ranked heads to include. 0 gives an MLP-only circuit.

    Returns:
        Circuit: the assembled circuit.
    """
    # Request at least the original 100 candidates, but never fewer than n_heads. Asking for
    # more than 100 heads was previously truncated silently to 100.
    ranked_heads = topk_effective_components(model, ie_maps[operator_idx], k=max(100, n_heads), heads_only=True)
    heads = list(ranked_heads.keys())[:n_heads]
    full_mlps = [Component('mlp_post', layer=l) for l in range(model.cfg.n_layers)]
    full_circuit = Circuit(model.cfg)
    # Deduplicate with a deterministic insertion order (set iteration order is not guaranteed).
    for c in dict.fromkeys(heads + full_mlps):
        full_circuit.add_component(c)
    return full_circuit

max_n_heads = 100
seeds = [42, 412, 32879, 123, 436]
# faithfulness_results[seed_idx, operator_idx, n_heads] = normalized-logit faithfulness.
faithfulness_results = torch.zeros((len(seeds), len(OPERATORS), max_n_heads))

for operator_idx in range(len(OPERATORS)):
    for seed_idx, seed in enumerate(seeds):
        set_deterministic(seed)
        prompts_and_answers = random.sample(evaluation_prompts_and_answers[operator_idx], k=50)
        print(operator_idx, seed)

        # n_heads runs over 0..max_n_heads-1 (max_n_heads itself is not evaluated).
        # NOTE(review): build_circuit re-ranks the heads on every call although the ranking only
        # depends on operator_idx; it could be hoisted if this sweep is slow.
        for n_heads in range(0, max_n_heads):
            full_circuit = build_circuit(operator_idx, n_heads)
            nl_acc = circuit_faithfulness_with_mean_ablation(model, full_circuit, prompts_and_answers, mean_cache, metric='nl')
            faithfulness_results[seed_idx, operator_idx, n_heads] = nl_acc
            print(f"{seed=}, operator={OPERATORS[operator_idx]}, {n_heads=}: Faithfulness: {nl_acc}")

## Deep Component Analysis

### Heads

In [None]:
# Show the average attention patterns of the heads with the highest indirect effect in the circuit

# Change to the visualized operator index
operator_idx = 0

# Get the indirect effect of each component
ie_maps = process_ie_maps(torch.load(f'./data/{model_name}/ie_maps_activation_patching.pt'))
most_effective_heads = topk_effective_components(model, ie_maps[operator_idx], k=20, heads_only=True) 

# Calculate the activation of the important heads for all valid arithmetic prompts
# NOTE(review): max_op is whichever value the last executed cell assigned (300 / 150 / 50).
min_op = 1 if operator_idx == 3 else 0
prompts = generate_all_prompts_for_operator(OPERATORS[operator_idx], min_op, max_op, single_token_number_range=(0, LLAMA3_8B_CONSTS.max_single_token))

# Visualize the attention patterns of the most effective heads, save the HTML, heads and raw
# patterns, and render the HTML (display is the IPython notebook builtin).
head_html, head_patterns = visualize_arithmetic_attention_patterns(model, most_effective_heads, prompts, use_bos_token=True, return_raw_patterns=True)
torch.save((head_html, most_effective_heads, head_patterns), f'./data/{model_name}/mean_attn_head_patterns_{OPERATOR_NAMES[operator_idx]}.pt')
display(head_html)

In [None]:
# Present the QK circuit of attention heads.

# NOTE(review): this reassigns the notebook-global max_op (300 in earlier cells). Later cells that
# build file paths from max_op are affected if they are run after this one.
max_op = 150

# Run once to calculate the QK heatmaps (skipped per operator when the cached file already exists)
for operator_idx in range(len(OPERATORS)):
    attention_values_file_path = f'./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_lastpos_attn_grid_max_operand={max_op}.pt'
    if os.path.exists(attention_values_file_path):
        continue
    else:
        # Attention from the last position (dst_token_position=-1) over the operand grid.
        attention_pattern_values = two_operands_arithmetic_qk_heatmap(model, OPERATORS[operator_idx], maximal_operand_value=max_op, dst_token_position=-1, show_progress=True)
        torch.save(attention_pattern_values, attention_values_file_path)


operator_idx = 0
# One attention grid per operator, in OPERATORS order.
all_attn_values = [torch.load(f'./data/{model_name}/{OPERATOR_NAMES[i]}_lastpos_attn_grid_max_operand={max_op}.pt') for i in range(len(OPERATORS))]
# Labels of the source positions that src_pos indexes.
src_tokens = ['BOS', 'Operand1', 'Operator', 'Operand2', '=']

# Define the slider widgets
operator_slider = widgets.IntSlider(min=0, max=len(OPERATORS) - 1, value=0, description='Operator:')
layer_slider = widgets.IntSlider(min=0, max=model.cfg.n_layers - 1, value=0, description='Layer:')
head_slider = widgets.IntSlider(min=0, max=model.cfg.n_heads - 1, value=0, description='Head:')
src_pos_slider = widgets.IntSlider(min=0, max=len(src_tokens) - 1, value=0, description='Source Position:')

# Define the update function
def show_attention_pattern(op, layer, head, src_pos):
    """Plot one head's attention to source position `src_pos` over the operand grid.

    Reads the notebook globals `all_attn_values` (one grid per operator, sliced here as
    [layer, head, :, :, src_pos]) and `src_tokens` (labels for the source positions).
    Rows are labelled Operand1 and columns Operand2.
    """
    grid = all_attn_values[op][layer, head, :, :, src_pos]
    plt.imshow(grid, cmap='hot', vmin=0.0, vmax=1.0, interpolation='nearest')
    plt.colorbar()
    plt.ylabel('Operand1')
    plt.xlabel('Operand2')
    source_name = src_tokens[src_pos]
    plt.title(f'Attention Pattern (Layer: {layer}, Head: {head}, Source Token: {source_name})')
    plt.show()

# Create and display the interactive widget (sliders: operator, layer, head, source position)
interactive_plot = widgets.interactive(show_attention_pattern, op=operator_slider, layer=layer_slider, head=head_slider, src_pos=src_pos_slider)
interactive_plot

In [None]:
# Present the OV circuit of attention heads. This is useful for linear projection heads.

# Words for the OV transition analysis: the strings "0" .. "999".
arithmetic_words = [str(i) for i in range(0, 1000)]

# Sliders default to layer 16, head 21.
layer_slider = widgets.IntSlider(min=0, max=model.cfg.n_layers-1, value=16, description='Layer:')
head_slider = widgets.IntSlider(min=0, max=model.cfg.n_heads-1, value=21, description='Head:')

def show_ov_circuit(layer, head):
    """Plot the OV transition heatmap of one attention head over `arithmetic_words`.

    The heatmap comes from ov_transition_analysis(model, layer, head, arithmetic_words) and is
    moved to a CPU NumPy array before plotting.
    """
    heatmap = ov_transition_analysis(model, layer, head, arithmetic_words).cpu().numpy()
    plt.imshow(heatmap, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.ylabel('y')
    plt.xlabel('x')
    plt.title(f'OV Visualization (Layer: {layer}, Head: {head})')
    plt.show()
# Interactive OV heatmap driven by the layer/head sliders.
interactive_plot = widgets.interactive(show_ov_circuit, layer=layer_slider, head=head_slider)
interactive_plot

### MLPs

#### Per-neuron attribution

In [None]:
# Run node attribution patching (https://www.neelnanda.io/mechanistic-interpretability/attribution-patching)
# to get a per-neuron effect approximation.
# This is a faster method than running the activation patching described in the paper, but gives slightly less accurate results.
#
# The CPU copy of the model is loaded once and reused for all operators. It used to be reloaded
# inside the loop, which re-read the full checkpoint once per operator and briefly held two
# copies in RAM. Hooks are reset between operators in case node_attribution_patching leaves
# temporary hooks attached (NOTE(review): confirm it does not also modify weights or config).
# Loading happens before seeding, so the corrupt-prompt sampling depends only on the seed.

model_cpu = load_model(model_name, model_path, device='cpu')
model_cpu.set_use_attn_result(True)

set_deterministic(42)

# Run node attribution patching for all operators
for operator_idx in range(len(OPERATORS)):
    model_cpu.reset_hooks()

    prompts_and_answers = correct_prompts_and_answers[operator_idx]
    corrupt_prompts_and_answers = random.sample(sum(correct_prompts_and_answers, []), k=len(prompts_and_answers)) # Sample randomly from all prompts

    attribution_scores = node_attribution_patching(model_cpu, prompts_and_answers, corrupt_prompts_and_answers,
                                                attributed_hook_names=['mlp.hook_post'],
                                                metric='IE', batch_size=1)
    torch.save(attribution_scores, f"./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_node_attribution_scores.pt")

#### Early MLP(s) Analysis

In [None]:
# Look at the activation of the highest-effect neurons in early MLPs. This is mainly (only?) valid
# for MLP0, as we observe the effect of the MLP on a single token, which ignores the role of
# pre-MLP attention heads.

operator_idx = 0
# Token positions of the two operands in "x<op>y=" prompts (OP2_POS is not used in this cell).
OP1_POS, OP2_POS = 1, 3
k_neurons = 5
layer = 0

# Find the highest-effect neurons
mlppost_neuron_scores = get_neuron_importance_scores(model, model_name, operator_idx=operator_idx, pos=OP1_POS)
top_mlp_neurons = mlppost_neuron_scores[layer].topk(k_neurons).indices.tolist()

# Calculate the MLP activations on a range of numerical tokens (plus a few operator/word controls).
# pos=-1 reads the last token of each input; NOTE(review): words that tokenize into several tokens
# are therefore represented by their last token only.
operand_values = [str(i) for i in range(500)] + ['+', '-', '*', '/', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'zero', 'ten'] + ['february', 'march', 'april']
activations = generate_activations(model, operand_values, [Component('mlp_post', layer=layer)], pos=-1)[0]

# Calculate the logit lens of the relevant V vectors for the top neurons
vector_inputs = model.blocks[layer].mlp.W_out[top_mlp_neurons, :]
vector_logits = vector_inputs @ model.W_U
arithmetic_tokens = model.to_tokens([str(i) for i in range(500)], prepend_bos=False)
v_tokens = 10

# For each important neuron, show its activation as a function of the input token, and a list of the top boosted (and inhibited)
# tokens for the corresponding V vector of that neuron
for i, neuron_idx in enumerate(top_mlp_neurons):
    neuron_activations = activations[:, neuron_idx]
    line(neuron_activations, title=f'Neuron {neuron_idx} activations', x=operand_values, labels={'y': 'Activation'})

    topk_tokens = model.to_str_tokens(vector_logits[i].topk(v_tokens).indices, prepend_bos=False)
    topk_arithmetic_tokens = model.to_str_tokens(arithmetic_tokens[vector_logits[i, arithmetic_tokens].view(-1).topk(v_tokens).indices], prepend_bos=False)
    bottomk_arithmetic_tokens = model.to_str_tokens(arithmetic_tokens[vector_logits[i, arithmetic_tokens].view(-1).topk(v_tokens, largest=False).indices], prepend_bos=False)
    print(f'Neuron {neuron_idx} logit lens:')
    print(f'Top overall {v_tokens} tokens: {topk_tokens}')
    print(f'Top arithmetic {v_tokens} tokens: {topk_arithmetic_tokens}')
    print(f'Bottom arithmetic {v_tokens} tokens: {bottomk_arithmetic_tokens}')

In [None]:
# Try different dimensionality reduction techniques on the output of early MLPs
# to see if we can find any interesting patterns.
#
# Fixed: `mlps_to_test` covered layers 0..16 while `visualize_mlp = 17` indexed it, which raised an
# IndexError. The tested layers now run up to and including `visualize_mlp`. Because they start
# at 0, the list index equals the layer index.

max_op = 300
visualize_mlp = 17
mlps_to_test = [Component('mlp_out', layer=i) for i in range(0, visualize_mlp + 1)]
dim_reduce_type = 'umap'

# Generate activations (this can be cached after first run for faster observations)
operand_pairs = [(a, b) for a in range(max_op) for b in range(max_op)]
prompts = [f'{a}+{b}=' for (a, b) in operand_pairs]
# The answers are computed from the operands directly instead of eval()-ing every prompt twice.
sums = [a + b for (a, b) in operand_pairs]
activations = generate_activations(model, prompts, mlps_to_test, pos=-1)

# Reduce dimensionality and visualize (points colored by the parity of the answer)
component = mlps_to_test[visualize_mlp]
x, y = reduce_dimensionality(activations[visualize_mlp].cpu(), type=dim_reduce_type)
fig = go.Figure(data=go.Scatter(x=x, y=y, mode='markers',
                                marker={'color': [s % 2 for s in sums]},
                                hovertext=[f'{p}{s}' for p, s in zip(prompts, sums)]))
fig.update_layout(title=f'{dim_reduce_type} of arithmetic token embeddings post MLP{component.layer}')
fig.show()

#### MLP KV Analysis

In [None]:
# Get MLPPost neuron scores in the last token position
# and analyze K and V vectors (https://arxiv.org/abs/2012.14913) for top neurons

operator_idx = 0
pos = -1
k_neurons = 10
layer = 26
max_op_visualization = 50 # Can increase this value for higher resolutions activation patterns, at the cost of runtime

# Visualize highest-effect neurons in the MLP
mlppost_neuron_scores = get_neuron_importance_scores(model, model_name, operator_idx=operator_idx, pos=pos)
line(mlppost_neuron_scores[layer])
important_neurons = mlppost_neuron_scores[layer].topk(k_neurons, largest=True).indices.tolist()

# Calculate the logit lens of the relevant V vectors for the top neurons
# (rows of W_out, i.e. what each neuron writes, projected through the unembedding)
important_vectors = model.blocks[layer].mlp.W_out[important_neurons, :]
vector_logits = important_vectors @ model.W_U

# Observe the logits of the token prior to the interesting MLP
# (columns of W_in, the K vectors, projected through the unembedding; a heuristic view,
# since W_in reads from the residual stream rather than writing to it)
w_in_logits = (model.blocks[layer].mlp.W_in[:, important_neurons].T @ model.W_U) #+ model.b_U

# We only care about the logits of the arithmetic tokens
# NOTE(review): vocab entries may carry a tokenizer-specific leading-space marker that
# strip(" ") does not remove -- confirm which numeric tokens this keeps.
arithmetic_labels = [label for label in model.tokenizer.vocab if _is_number(label.strip(" "))]
# view(-1) assumes every label re-tokenizes to exactly one token.
arithmetic_tokens = model.to_tokens(arithmetic_labels, prepend_bos=False).view(-1)

# Generate the key activation pattern for the neurons in the MLP
# Prompts are ordered Operand1-major, so reshaping one neuron's activations to
# (max_op_visualization, max_op_visualization) puts Operand1 on rows and Operand2 on columns.
prompts = [f'{x}{OPERATORS[operator_idx]}{y}=' for x in range(max_op_visualization) for y in range(max_op_visualization)]
all_prompts_activations = generate_activations(model, prompts, [Component('mlp_post', layer=layer)], pos=pos)[0]

# Present all of the information for the top neurons (activation map, V-vector and W_in logit lens)
# NOTE(review): the 'Top arithmetic W_in{v_tokens}' message below is missing a space before the number.
v_tokens = 10
for i, neuron in enumerate(important_neurons):
    topk_tokens = model.to_str_tokens(vector_logits[i].topk(v_tokens).indices, prepend_bos=False)
    topk_arithmetic_tokens = model.to_str_tokens(arithmetic_tokens[vector_logits[i, arithmetic_tokens].view(-1).topk(v_tokens).indices], prepend_bos=False)
    bottomk_arithmetic_tokens = model.to_str_tokens(arithmetic_tokens[vector_logits[i, arithmetic_tokens].view(-1).topk(v_tokens, largest=False).indices], prepend_bos=False)

    topk_w_in_tokens = model.to_str_tokens(w_in_logits[i].topk(v_tokens).indices, prepend_bos=False)
    topk_w_in_arithmetic_tokens = model.to_str_tokens(arithmetic_tokens[w_in_logits[i, arithmetic_tokens].view(-1).topk(v_tokens).indices], prepend_bos=False)
    bottomk_w_in_arithmetic_tokens = model.to_str_tokens(arithmetic_tokens[w_in_logits[i, arithmetic_tokens].view(-1).topk(v_tokens, largest=False).indices], prepend_bos=False)

    print(f'Neuron {neuron} logit lens:')
    print(f'IE effect: {mlppost_neuron_scores[layer][neuron]}')

    activation_img = all_prompts_activations[:, neuron].reshape((max_op_visualization, max_op_visualization))
    imshow(activation_img, x=list(range(max_op_visualization)), y=list(range(max_op_visualization)), labels={'x': 'Operand2', 'y': 'Operand1'}, width=600,
        title=f'Neuron {neuron} activations in MLP {layer} (pos {pos}) as function of operands')

    print(f'Top overall {v_tokens} tokens: {topk_tokens}')
    print(f'Top overall W_in {v_tokens} tokens: {topk_w_in_tokens}')

    print(f'Top arithmetic {v_tokens} tokens: {topk_arithmetic_tokens}')
    print(list(zip(topk_arithmetic_tokens, vector_logits[i, arithmetic_tokens].view(-1).topk(v_tokens, largest=True).values.tolist())))
    print(f'Top arithmetic W_in{v_tokens} tokens: {topk_w_in_arithmetic_tokens}')


    print(f'Bottom arithmetic {v_tokens} tokens: {bottomk_arithmetic_tokens}')
    print(list(zip(bottomk_arithmetic_tokens, vector_logits[i, arithmetic_tokens].view(-1).topk(v_tokens, largest=False).values.tolist())))
    print(f'Bottom arithmetic W_in {v_tokens} tokens: {bottomk_w_in_arithmetic_tokens}')

#### Per-Neuron evaluation

In [None]:
# Evaluate the circuit with a sparse subset of neurons in each middle- and late-layer MLP.
# This is an initial investigation, a full analysis is done in script_topk_neuron_eval.py.

set_deterministic(42)
# NOTE(review): depends on the notebook-global max_op being 300 (the value the mean cache was built with).
mean_cache = torch.load(f'./data/{model_name}/mean_cache_for_evaluation_all_arithmetic_prompts_max_op={max_op}.pt')
operator_idx = 0
# Fixed random sample of 50 evaluation prompts for this operator.
prompts_and_answers = random.sample(evaluation_prompts_and_answers[operator_idx], k=50)

# Ranking neurons according to Attribution patching
mlppost_neuron_scores = get_neuron_importance_scores(model, model_name, operator_idx=operator_idx, pos=-1)

def build_circuit(operator_idx, mlp_top_neurons):
    """Build the arithmetic circuit with sparse (top-k neuron) MLPs in the heuristic layers.

    The circuit contains the operator-specific attention heads (as 'z' components). It also
    contains full MLPs for layers before `first_heuristics_layer`, and MLPs restricted to the
    neurons in `mlp_top_neurons[layer]` for layers from `first_heuristics_layer` onward.

    Args:
        operator_idx: 0 = addition, 1 = subtraction, 2 = multiplication, 3 = division.
        mlp_top_neurons: dict layer -> neuron indices kept in that layer's MLP. It must contain
            every layer from first_heuristics_layer to n_layers - 1.

    Returns:
        Circuit: the assembled circuit.

    Raises:
        ValueError: for an unknown operator index. Previously this surfaced as an
            UnboundLocalError on `heads`.
    """
    # (layer, head) pairs of the discovered circuit, per operator index.
    circuit_heads_per_operator = {
        # Addition
        0: [(2, 2), (5, 3), (5, 31), (14, 12), (15, 13), (16, 21)],
        # Subtraction
        1: [(2, 2), (13, 21), (13, 22), (14, 12), (15, 13), (16, 21)],
        # Multiplication
        2: [(2, 2), (5, 30), (8, 15), (9, 26), (13, 18), (13, 21), (13, 22),
            (14, 12), (14, 13), (15, 8), (15, 13), (15, 14), (15, 15), (16, 3),
            (16, 21), (17, 24), (17, 26), (18, 16), (20, 2), (22, 1)],
        # Division
        3: [(2, 2), (5, 31), (15, 13), (15, 14), (16, 21), (18, 16)],
    }
    if operator_idx not in circuit_heads_per_operator:
        raise ValueError(f"Unknown operator index: {operator_idx}")
    heads = [Component('z', layer=l, head=h) for (l, h) in circuit_heads_per_operator[operator_idx]]

    # A range object gives O(1) membership tests, unlike the list it replaces.
    partial_mlp_layers = range(get_model_consts(model_name).first_heuristics_layer, model.cfg.n_layers)
    full_mlps = [Component('mlp_post', layer=l) for l in range(model.cfg.n_layers) if l not in partial_mlp_layers]
    partial_mlps = [Component('mlp_post', layer=l, neurons=mlp_top_neurons[l]) for l in partial_mlp_layers]
    full_circuit = Circuit(model.cfg)
    # Deduplicate with a deterministic insertion order (set iteration order is not guaranteed).
    for c in dict.fromkeys(heads + full_mlps + partial_mlps):
        full_circuit.add_component(c)
    return full_circuit


# Numbers of neurons kept per MLP: dense at small k, coarser in the middle, dense again near the
# full width. The final value d_mlp means "keep every neuron".
k_values = sorted([0, 10, 25, 50, 75, 100, 150] + list(range(200, 1000, 100)) + list(range(1000, 14000, 200)) + list(range(14000, model.cfg.d_mlp, 20)) + [model.cfg.d_mlp])
faithfulness_per_k = torch.zeros((len(k_values),))
for i, k in enumerate(k_values):
    mlp_top_neurons = {}
    # NOTE(review): attn_top_neurons is never filled or used.
    attn_top_neurons = {}
    # Top-k neurons for layers 1..n_layers-1; build_circuit only reads layers >= first_heuristics_layer.
    for layer in range(1, model.cfg.n_layers):
        mlp_top_neurons[layer] = mlppost_neuron_scores[layer].topk(k).indices.tolist()
    full_circuit = build_circuit(operator_idx, mlp_top_neurons)
    faithfulness_per_k[i] = circuit_faithfulness_with_mean_ablation(model, full_circuit, prompts_and_answers, mean_cache, metric='nl')
    print(f"Top-k neurons per layer: {k};\tFaithfulness: {faithfulness_per_k[i].item()}")

line(faithfulness_per_k, x=k_values, range_y=(0, 1.0), title=f'(Normalized Logit) faithfulness as a function of top k Neurons in each middle and late MLP', labels={'x':'k', 'y':'Faithfulness'})

## Heuristic analysis

In [None]:
# Some utility functions to be used later.
# Some global values used in these functions are defined in later cells and need to be run before CALLING these functions. Ugly, but that's what you get at a deadline.

def present_neuron(layer, neuron, use_kv_maps=True):
    """
    Visualizes and analyzes the activations and numerical token logits of a specific neuron.

    Reads notebook globals that are defined in later cells and must exist before calling:
    kv_prompts_activations / k_prompts_activations, op1_op2_pairs, min_op, max_op,
    operator_idx, arithmetic_tokens and rev_heuristic_classes.

    Args:
        layer (int): The layer number in the model where the neuron is located.
        neuron (int): The neuron index within the specified layer.
        use_kv_maps (bool, optional): If True, use the key-value activation maps
                                      (kv_prompts_activations). Otherwise use the key-only maps
                                      (k_prompts_activations) and additionally print the
                                      bottom-activating operand pairs and the V-vector logit lens.
                                      Defaults to True.

    Returns:
        None
    """
    # Choose the relevant activation pattern: multiplied by the V vector logits (for direct
    # heuristics) or not multiplied (for indirect heuristics).
    prompts_activations = kv_prompts_activations if use_kv_maps else k_prompts_activations
    v_tokens = 10

    # Create the 2-D activation pattern of the neuron. Cell (op1, op2) holds the activation of the
    # neuron on the prompt with operands op1 and op2 (for the operator given by operator_idx).
    activation_img = torch.zeros((max_op - min_op, max_op - min_op))
    for i, (op1, op2) in enumerate(op1_op2_pairs):
        activation_img[op1 - min_op, op2 - min_op] = prompts_activations[(layer, neuron)][i]

    # Present all information
    print(f'Neuron {neuron}:')
    imshow(activation_img, x=list(range(min_op, max_op)), y=list(range(min_op, max_op)), labels={'x': 'Operand2', 'y': 'Operand1'}, width=600,
        title=f'Neuron {neuron} activations in MLP {layer} as function of operands')
    topk_op1_op2_pairs = _get_top_op1_op2_indices(layer, neuron, prompts_activations, top_k=50)
    print("Top 50 op1,op2 values in activation map: ", topk_op1_op2_pairs)
    print("Top results: ", [safe_eval(f"{op1}{OPERATORS[operator_idx]}{op2}") for (op1, op2) in topk_op1_op2_pairs])
    if not use_kv_maps:
        # These values are only computed when they are printed. Previously both were computed
        # unconditionally, and the output was split across two identical `if` branches.
        bottomk_op1_op2_pairs = _get_top_op1_op2_indices(layer, neuron, prompts_activations, top_k=50, is_top=False)
        print("Bottom 50 op1,op2 values in activation map: ", bottomk_op1_op2_pairs)
        print("Bottom results: ", [safe_eval(f"{op1}{OPERATORS[operator_idx]}{op2}") for (op1, op2) in bottomk_op1_op2_pairs])

        # Logit lens of the neuron's V vector (its W_out row), restricted to numerical tokens.
        vector_logits = model.blocks[layer].mlp.W_out[neuron] @ model.W_U
        arithmetic_logits = vector_logits[arithmetic_tokens].view(-1)
        topk_arithmetic_tokens = model.to_str_tokens(arithmetic_tokens[arithmetic_logits.topk(v_tokens).indices], prepend_bos=False)
        bottomk_arithmetic_tokens = model.to_str_tokens(arithmetic_tokens[arithmetic_logits.topk(v_tokens, largest=False).indices], prepend_bos=False)
        print(f'Top arithmetic {v_tokens} tokens: {topk_arithmetic_tokens}')
        print(f'Bottom arithmetic {v_tokens} tokens: {bottomk_arithmetic_tokens}')
    # Show the heuristic matching scores of this neuron, highest first
    print(sorted(rev_heuristic_classes[(layer, neuron)], key=lambda x: x[1], reverse=True))


def reverse_heuristic_dictionary(d):
    """
    Invert a heuristic-classes mapping.

    Takes a dict of heuristic_name -> [(layer, neuron, score), ...] and returns a dict of
    (layer, neuron) -> [(heuristic_name, score), ...], i.e. every heuristic a neuron was matched to.
    Neurons appear in order of first occurrence.
    """
    inverted = {}
    for heuristic_name, matches in d.items():
        for match in matches:
            neuron_key = match[:2]
            if neuron_key not in inverted:
                inverted[neuron_key] = []
            inverted[neuron_key].append((heuristic_name, match[2]))
    return inverted


def _get_top_op1_op2_indices(layer, neuron_idx, prompts_activations, top_k=None, is_top=True):
    """
    Get, from a specific neurons activations, the highest (or lowest) activating pairs of (op1, op2) values.
    """
    activation_map = prompts_activations[(layer, neuron_idx)]
    if top_k is None:
        top_k = len(activation_map)
    top_op1_op2_pairs = op1_op2_pairs[activation_map.topk(top_k, largest=is_top).indices.cpu().numpy()].tolist()
    return top_op1_op2_pairs

In [None]:
# Calculate the activations of top neurons in each layer for a specific operator.
# This cell must be run for the other heuristic analysis cells to work.
# It defines the globals op1_op2_pairs, prompts, arithmetic_tokens, heuristic_neurons,
# k_prompts_activations and kv_prompts_activations that later cells read.

# Settings
operator_idx = 0
topk_neurons_per_layer = 200
# Operator index 3 starts operands at 1 (presumably division, to avoid a zero divisor - confirm against OPERATORS)
min_op = 1 if operator_idx == 3 else 0
max_op = 50  # Exclusive upper bound for op1 (used as range(min_op, max_op))

# Create an ORDERED list of all valid (op1, op2) pairs for the operator.
# The row order of op1_op2_pairs defines the index order of every activation vector computed below.
op1_op2_pairs = torch.tensor(sorted([(op1, op2) for op1 in range(min_op, max_op) for op2 in _get_operand_range(OPERATORS[operator_idx], op1, min_op, max_op, get_model_consts(model_name).max_single_token)]))
prompts = [f'{op1}{OPERATORS[operator_idx]}{op2}=' for (op1, op2) in op1_op2_pairs]

# Get the numerical tokens in the model's vocabulary
arithmetic_labels = [label for label in model.tokenizer.vocab if _is_number(label.strip(" "), is_int=True)]
# .view(-1) assumes each numeric label maps to exactly one token; otherwise ids and labels would not line up
arithmetic_tokens = model.to_tokens(arithmetic_labels, prepend_bos=False).view(-1)

# Create a list of heuristical neurons in the relevant layers (16 - 31 for Llama3-8B):
# the topk_neurons_per_layer highest-scoring neurons of every layer from the model's first heuristics layer onwards.
neuron_importance_scores = get_neuron_importance_scores(model, model_name, operator_idx=operator_idx, pos=-1)
heuristic_neurons = []
for layer in range(get_model_consts(model_name).first_heuristics_layer, model.cfg.n_layers):
    heuristic_neurons += [(layer, neuron) for neuron in neuron_importance_scores[layer].topk(topk_neurons_per_layer).indices.tolist()]

# Calculate k (key) prompt activations: the 'mlp_post' activations (presumably post-nonlinearity) at the last
# position (pos=-1) of every prompt, for every layer; then keep only each heuristic neuron's column.
k_prompts_activations = generate_activations(model, prompts, [Component('mlp_post', layer=l) for l in range(model.cfg.n_layers)], pos=-1)
k_prompts_activations = {(layer, neuron): k_prompts_activations[layer][:, neuron] for (layer, neuron) in heuristic_neurons}

# Calculate kv (key-value) prompt activations by multiplying the key activations with the value vector logits:
# W_out[neuron] projected through the unembedding W_U, read at the token of each prompt's correct result.
kv_prompts_activations = {}
results_for_all_pairs = [str(safe_eval(f'{op1}{OPERATORS[operator_idx]}{op2}')) for (op1, op2) in op1_op2_pairs]
# As above, assumes every result string is a single token
results_for_all_pairs_labels = model.to_tokens(results_for_all_pairs, prepend_bos=False).view(-1)
for (layer, neuron) in tqdm(heuristic_neurons):
    v_vector_logits = (model.blocks[layer].mlp.W_out[neuron] @ model.W_U)
    logits_for_all_pairs = v_vector_logits[results_for_all_pairs_labels].cpu()
    kv_prompts_activations[(layer, neuron)] = k_prompts_activations[(layer, neuron)] * logits_for_all_pairs

#### Separating heuristics into classes

In [None]:
# Prep to make next cells run faster (top and bottom results are cached)
# Builds one HeuristicAnalysisData per activation-map type: the first uses the KV maps, the second the K maps.
# After the loop, use_kv_maps / prompts_activations remain set from the last iteration (False / K maps).

heuristic_datas = []
for use_kv_maps in [True, False]:
    prompts_activations = kv_prompts_activations if use_kv_maps else k_prompts_activations
    # top_k is left as None, so every (op1, op2) pair is returned, ordered from highest to lowest activation
    top_op1_op2_indices = {(layer, neuron): _get_top_op1_op2_indices(layer, neuron, prompts_activations, is_top=True) for (layer, neuron) in heuristic_neurons}
    top_results = {}
    for (layer, neuron) in tqdm(top_op1_op2_indices.keys()):
        # Arithmetic result of each pair, in the same activation order
        top_results[(layer, neuron)] = [safe_eval(f"{op1}{OPERATORS[operator_idx]}{op2}") for (op1, op2) in top_op1_op2_indices[(layer, neuron)]]
        # Sanity check: every result must lie in [0, max_single_token]
        assert all([0 <= result <= get_model_consts(model_name).max_single_token for result in top_results[(layer, neuron)]])

    # Create the object containing all relevant data for heuristic analysis. This object is passed to the heuristic analysis functions.
    heuristic_data = HeuristicAnalysisData()
    # Bottom (lowest-activation) results are additionally checked only for the K maps
    heuristic_data.also_check_bottom_results = not use_kv_maps
    heuristic_data.op1_op2_pairs = op1_op2_pairs
    heuristic_data.top_op1_op2_indices = top_op1_op2_indices
    heuristic_data.top_results = top_results
    heuristic_data.max_op = max_op
    heuristic_data.max_single_token = get_model_consts(model_name).max_single_token
    heuristic_data.operator_idx = operator_idx
    heuristic_data.k_per_heuristic_cache = {}
    heuristic_datas.append(heuristic_data)

# Order matches the loop above: [True, False] -> (KV, K)
kv_heuristic_data, k_heuristic_data = heuristic_datas

In [None]:
# Run the neuron to heuristic classification process.
# The result is saved under ./data/{model_name}/ with a _K_maps or _KV_maps suffix, depending on use_kv_maps.

use_kv_maps = False
if use_kv_maps:
    heuristic_data = kv_heuristic_data
else:
    heuristic_data = k_heuristic_data
heuristic_classes = classify_heuristic_neurons(heuristic_neurons, heuristic_data)
torch.save(heuristic_classes, f'./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_heuristic_matches_dict_{"KV" if use_kv_maps else "K"}_maps.pt')

In [None]:
# Load the pre-calculated heuristic classes to neurons dictionary for a specific operator.
# NOTE(review): heuristic_neurons comes from the activation cell and should have been computed for the same operator_idx.

operator_idx = 0

MATCH_THRESHOLD = 0.6

heuristic_classes = load_heuristic_classes(f'./data/{model_name}', operator_idx, neuron_activations_type="HYBRID") # HYBRID means that direct heuristics are classified using the key-value maps, and indirect heuristics are classified using the key maps
heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= MATCH_THRESHOLD] for name, layer_neuron_scores in heuristic_classes.items()} # Filter by threshold

# Some extra prep for later presentation.
# A neuron can match several heuristics, so classified_neurons may contain duplicates.
# A flat comprehension avoids the quadratic sum(lists, []) concatenation.
classified_neurons = [(l, n) for layer_neuron_scores in heuristic_classes.values() for (l, n, _s) in layer_neuron_scores]
classified_neuron_set = set(classified_neurons)
print(f"Classified neurons: {len(classified_neurons)} (Unique: {len(classified_neuron_set)})")
# Set membership keeps this linear instead of scanning the list once per heuristic neuron
unclassified_neurons = [(l, n) for (l, n) in heuristic_neurons if (l, n) not in classified_neuron_set]
rev_heuristic_classes = reverse_heuristic_dictionary(heuristic_classes)
rev_heuristic_classes.update({(l, n): [] for (l, n) in unclassified_neurons})

#### Analysing statistics regarding heuristic classes

In [None]:
# Draw a plot bar showing the amount of heuristics in each heuristics class, grouped by heuristic type, separated by layer

import pandas as pd
from collections import Counter

def union_dict_values_by_regex(dictionary, patterns):
    """
    Group a dict's iterable values by regex patterns over its keys.

    For every pattern, concatenate (in the dictionary's key order) the values of all keys the pattern
    matches anywhere in the key (`re.search`). A key matched by several patterns contributes to each group.

    Args:
        dictionary: Mapping from string keys to iterables (e.g. heuristic name -> list of (layer, neuron, score)).
        patterns: Iterable of regex patterns.

    Returns:
        Dict mapping each pattern to the concatenated list of values (an empty list if no key matches).
    """
    result_dict = {}
    for pattern in patterns:
        compiled = re.compile(pattern)
        # A flat comprehension is linear; sum(lists, []) re-copies the accumulator on every addition.
        result_dict[pattern] = [item for key, values in dictionary.items() if compiled.search(key) for item in values]
    return result_dict

# Load the classes for the operator and keep only neuron matches at or above the threshold
operator_idx = 0
MATCH_THRESHOLD = 0.6
heuristic_classes = load_heuristic_classes(f'./data/{model_name}', operator_idx, "HYBRID")
heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= MATCH_THRESHOLD] for name, layer_neuron_scores in heuristic_classes.items()} # Filter by threshold

# Regexes that merge concrete heuristic names (e.g. a specific modulus, region or value) into heuristic groups.
# A heuristic name matched by several patterns is counted in each of those groups.
heuristic_name_patterns = [r"op\d_\d+mod\d+", r"both_operands_\d+mod\d+", r"result_\d+mod\d+", 
                           r"op\d_region_\d+_\d+", r"both_operands_region_\d+_\d+",	r"result_region_\d+_\d+", 
                           r"same_operand", 
                           r"op\d_value_\d+", r"result_value_\d+", 
                           r"op\d_pattern_.*", r"result_pattern_.*", 
                           r"result_multi_value_.*"]
unified_heuristic_classes = union_dict_values_by_regex(heuristic_classes, heuristic_name_patterns)
# One row per (group, layer): the number of matched neurons of that group in that layer
df = pd.DataFrame([(key, layer, count) for key, value in unified_heuristic_classes.items() for layer, count in Counter(layer for layer, _, _ in value).items()],
                  columns=['group', 'layer', 'count'])
fig = go.Figure()
# One bar trace per layer; bars are grouped by heuristic group along the x axis
for layer in sorted(df['layer'].unique()):
    layer_data = df[df['layer'] == layer]
    fig.add_trace(go.Bar(
        x=layer_data['group'],
        y=layer_data['count'],
        name=str(layer)
    ))
fig.update_layout(
    barmode='group',
    xaxis={'title': 'Group'},
    yaxis={'title': 'Count'},
    title='Count by Heuristic Group and Layer'
)
fig.show()

#### Entire heuristic knockout experiment

Code for the heuristic knockout experiment in the paper (Section 4.2, first experiment).

In [None]:
# Heuristic knockout experiment: ablate each heuristic class and measure the effect on accuracy.
set_deterministic(42)
operator_idx = 0
MATCH_THRESHOLD = 0.6
ACTIVATION_MAP_TYPE = ["KV", "K", "HYBRID"][2]

# Load heuristic classes. They are passed unfiltered; MATCH_THRESHOLD is handed to the experiment through
# heuristic_neuron_match_threshold (presumably applied inside it - confirm).
heuristic_classes_unfiltered = load_heuristic_classes(f"./data/{model_name}", operator_idx, ACTIVATION_MAP_TYPE)
min_op = 1 if operator_idx == 3 else 0

# Run the heuristic ablation experiment.
# Use the classes loaded above: the global `heuristic_classes` is left over from other cells and may belong
# to a different operator, threshold or activation-map type.
# NOTE(review): max_op and large_prompts_and_answers come from earlier cells - make sure they match operator_idx.
heuristics_knockout_results = heuristic_class_knockout_experiment(heuristic_classes_unfiltered,
                                                                  operator_idx,
                                                                  large_prompts_and_answers,
                                                                  model,
                                                                  min_op, max_op,
                                                                  get_model_consts(model_name).max_single_token,
                                                                  heuristic_neuron_match_threshold=MATCH_THRESHOLD,
                                                                  seed=42,
                                                                  verbose=True)
torch.save(heuristics_knockout_results, f'./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_heuristic_ablation_results_thres={MATCH_THRESHOLD}_{ACTIVATION_MAP_TYPE}_maps.pt')

In [None]:
# Expected hypothesis - the accuracy of prompts associated with an ablated heuristic (ablated related) should drop more than the accuracy of prompts not
# associated with that heuristic (ablated unrelated).

# This is an initial version of the figure presented in the paper, used for investigations.

operator_idx = 0
MIN_NEURONS_PER_HEURISTIC = 10
MIN_SCORE_SUM_PER_HEURISTIC = 10
# NOTE(review): must match a threshold for which the knockout cell saved results
MATCH_THRESHOLD = 0.55
ACTIVATION_MAP_TYPE = ["KV", "K", "HYBRID"][2]

heuristics_knockout_results = torch.load(f'./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_heuristic_ablation_results_thres={MATCH_THRESHOLD}_{ACTIVATION_MAP_TYPE}_maps.pt')
# Largest accuracy drop on related prompts first
heuristics_knockout_results = sorted(heuristics_knockout_results, key=lambda h: h.baseline_related - h.ablated_related, reverse=True)
# Skip heuristics implemented by too few neurons or with too little total matching score
heuristics_knockout_results = [h for h in heuristics_knockout_results
                               if len(h.ablated_neurons) >= MIN_NEURONS_PER_HEURISTIC
                               and h.ablated_neuron_matching_score >= MIN_SCORE_SUM_PER_HEURISTIC]

if ACTIVATION_MAP_TYPE == "KV":
    # No point in looking at operand heuristics
    heuristics_knockout_results = [h for h in heuristics_knockout_results if h.heuristic_name.startswith('result')]
else:
    # For K / HYBRID maps this investigation looks at operand heuristics only.
    # (Applying this filter in KV mode too would remove every remaining heuristic.)
    heuristics_knockout_results = [h for h in heuristics_knockout_results if not h.heuristic_name.startswith('result')]


heuristic_names_to_test = [h.heuristic_name for h in heuristics_knockout_results]
baseline_related = torch.tensor([h.baseline_related for h in heuristics_knockout_results])
baseline_unrelated = torch.tensor([h.baseline_unrelated for h in heuristics_knockout_results])
ablated_related = torch.tensor([h.ablated_related for h in heuristics_knockout_results])
ablated_unrelated = torch.tensor([h.ablated_unrelated for h in heuristics_knockout_results])
ablated_neurons_counts = torch.tensor([len(h.ablated_neurons) for h in heuristics_knockout_results])
ablated_neuron_matching_scores = torch.tensor([h.ablated_neuron_matching_score for h in heuristics_knockout_results])

show_all_lines = True
if show_all_lines:
    lines = [baseline_related, baseline_unrelated, ablated_related, ablated_unrelated]
    line_titles = ["Baseline related", "Baseline unrelated", "Ablated related", "Ablated unrelated"]
    line_colors = ["blue", "red", "green", "purple"]
else:
    lines = [baseline_related, ablated_related]
    line_titles = ["baseline", "ablated"]
    line_colors = ["blue", "red"]

# Hovertext shows (heuristic name, ablated neuron count, summed matching score) per point
fig = multiple_lines(list(range(len(baseline_related))), lines, line_titles,
               title=rf"{OPERATOR_NAMES[operator_idx]} accuracy on related{' and unrelated prompts' if show_all_lines else ''}<br> before and after knockouts<br>" +
               f"Sorted by knockout diff (On {ACTIVATION_MAP_TYPE} activation maps)",
               xaxis_title="Heuristic index",
               yaxis_title="Accuracy",
               hovertext=list(zip(heuristic_names_to_test, ablated_neurons_counts, ablated_neuron_matching_scores)),
               show_fig=False,
               width=500)

# Dashed horizontal line at the mean of each curve, in the curve's color
for i, l in enumerate(lines):
    fig.add_hline(y=l.mean(), line_dash="dash", line_color=line_colors[i])


fig.show()

#### Knockout prompt-related heuristics

Code for the prompt-guided heuristic-neuron knockout experiment in the paper (Section 4.2, second experiment).

In [None]:
# The goal of this cell is to ablate neurons which are associated with a prompt, and see how it affects the model's accuracy on the prompt.
# Our hypothesis is that ablating neurons that belong to heuristics associated with a prompt will lead to a higher drop in accuracy compared to ablating random neurons.
# Failures can be explained by imperfect heuristics / missing heuristic class definitions.

# Run across several seeds to get a better feel for the best amount of neurons to knock out per operator

MATCH_THRESHOLD = 0.6
neuron_hard_limits = range(0, 201, 5)
seeds = [42]

# Result tensors indexed by (operator, neuron hard limit, seed)
baseline_results = torch.zeros((len(OPERATORS), len(neuron_hard_limits), len(seeds)))
ablated_results = torch.zeros((len(OPERATORS), len(neuron_hard_limits), len(seeds)))
ablated_neuron_counts = torch.zeros((len(OPERATORS), len(neuron_hard_limits), len(seeds)))
control_results = torch.zeros((len(OPERATORS), len(neuron_hard_limits), len(seeds)))

# An operator's top neurons depend on neither the hard limit nor the seed, so compute them once per operator
# instead of once per hard limit. As in the activation cell, scores come from
# get_neuron_importance_scores(model, model_name, ...) and layers start at the model's first heuristics layer.
top_neurons_per_operator = {}
for operator_idx in range(len(OPERATORS)):
    neuron_importance_scores = get_neuron_importance_scores(model, model_name, operator_idx=operator_idx, pos=-1)
    top_neurons_per_operator[operator_idx] = [(layer, neuron)
                                              for layer in range(get_model_consts(model_name).first_heuristics_layer, model.cfg.n_layers)
                                              for neuron in neuron_importance_scores[layer].topk(200).indices.tolist()]

for neuron_hard_limit_idx, neuron_hard_limit in enumerate(tqdm(neuron_hard_limits)):
    for operator_idx in range(len(OPERATORS)):
        all_top_neurons = top_neurons_per_operator[operator_idx]

        for seed_idx, seed in enumerate(seeds):
            set_deterministic(seed)
            heuristic_classes = load_heuristic_classes(f'./data/{model_name}', operator_idx, "HYBRID")
            heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= MATCH_THRESHOLD] for name, layer_neuron_scores in heuristic_classes.items()}
            prompts_and_answers = random.sample(evaluation_prompts_and_answers[operator_idx], k=50)
            baseline, ablated, ablated_neuron_avg_count, control_ablated = prompt_knockout_experiment(heuristic_classes,
                                                                                                      model, prompts_and_answers,
                                                                                                      neuron_count_hard_limit_per_layer=neuron_hard_limit,
                                                                                                      # Copy, so the cached list cannot be mutated by the experiment
                                                                                                      all_top_neurons=list(all_top_neurons),
                                                                                                      metric_fn=model_accuracy)
            baseline_results[operator_idx, neuron_hard_limit_idx, seed_idx] = baseline
            ablated_results[operator_idx, neuron_hard_limit_idx, seed_idx] = ablated
            ablated_neuron_counts[operator_idx, neuron_hard_limit_idx, seed_idx] = ablated_neuron_avg_count
            control_results[operator_idx, neuron_hard_limit_idx, seed_idx] = control_ablated
            print(f"{neuron_hard_limit=}, {operator_idx=}, {seed=}, {baseline=}, {ablated=}, {ablated_neuron_avg_count=}, {control_ablated=}")

torch.save((neuron_hard_limits, baseline_results, ablated_results, ablated_neuron_counts, control_results), f'./data/{model_name}/prompt_knockout_results_with_neuron_limits_per_layer.pt')

#### Debug failed heuristic knockouts

In [None]:
# Debug a single heuristic class: print its relevant prompts and show the first 10 neurons stored for it.
operator_idx = 0
MATCH_THRESHOLD = 0.55
# Loads the K-map classification file saved by the classification cell directly (not the HYBRID classes)
heuristic_classes = torch.load(f'./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_heuristic_matches_dict_K_maps.pt')
heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= MATCH_THRESHOLD] for name, layer_neuron_scores in heuristic_classes.items()} # Filter by threshold
rev_heuristic_classes = reverse_heuristic_dictionary(heuristic_classes)

# NOTE(review): min_op / max_op are whatever earlier cells left in the globals - confirm they match operator_idx
print(get_relevant_prompts("op2_pattern_.8.", operator_idx, min_op, max_op))

neurons_to_analyze = heuristic_classes[f"op2_pattern_.8."][:10]
for n in neurons_to_analyze:
    layer, neuron = n[0], n[1]
    present_neuron(layer, neuron, use_kv_maps=False)

In [None]:
# Attempt better method for pattern identification in the maps (work in progress)

layer, neuron = 16, 2512  # Example neuron used by the pattern-detection experiment in this cell
# present_neuron(layer, neuron, use_kv_maps=False)

def is_pattern_neuron_new(layer, neuron, pattern, op_index=1, use_kv_maps=False):
    """
    Work-in-progress pattern detector; it currently always returns None.

    Selects the (op1, op2) pairs whose chosen operand (op1 when op_index == 1, otherwise op2),
    zero-padded to 3 digits, fully matches the regex `pattern`, and gathers the neuron's activations
    (K maps, or KV maps when use_kv_maps is True) on those pairs. No decision is derived from them yet.
    """
    full_match = f"^{pattern}$"
    operand_column = 0 if op_index == 1 else 1

    # Indices of the pairs whose selected operand matches the pattern
    relevant_pair_indices = []
    for pair_idx, pair in enumerate(op1_op2_pairs):
        operand_digits = str(pair[operand_column].item()).zfill(3)
        if re.match(full_match, operand_digits):
            relevant_pair_indices.append(pair_idx)

    # The neuron's activations on those pairs
    activation_source = kv_prompts_activations if use_kv_maps else k_prompts_activations
    relevant_activations = activation_source[(layer, neuron)][relevant_pair_indices]
    # TODO: decide from relevant_activations whether the neuron implements the pattern
    return None
    

# Try two operand-digit patterns ('.8.' and '.3.') on the example neuron; since is_pattern_neuron_new has no return value yet, this displays (None, None)
is_pattern_neuron_new(layer, neuron, '.8.'), is_pattern_neuron_new(layer, neuron, '.3.')#, get_periodic_patterns(layer, neuron, heuristic_analysis_data, op_index)

In [None]:
# Example neurons labelled as wave-like heuristics, and examples labelled as non-wave heuristics
wave_heuristics = [(18, 391), (18, 13662), (19, 1064)] # WAVE HEURISTICS
non_wave_heuristics = [(19, 477), (18, 4040)]

from pywt import wavedec, waverec, threshold  # Wavelet decomposition / reconstruction and coefficient thresholding
from scipy import signal  # Provides signal.find_peaks used below

def get_fft_frequencies(tensor, sample_rate=1.0):
    """
    Compute the FFT magnitude spectrum of a 1-D signal.

    Args:
        tensor: 1-D torch tensor holding the signal. CPU or CUDA tensors are accepted; gradients are ignored.
        sample_rate: Samples per unit; scales the returned frequency bins.

    Returns:
        (freq, magnitudes): numpy arrays with the bin frequencies (numpy.fft.fftfreq ordering) and the
        absolute values of the corresponding FFT coefficients.
    """
    # Detached CPU copy, so .numpy() works for CUDA or grad-tracking tensors as well.
    # (Named `samples` so it does not shadow the scipy.signal module imported in this cell.)
    samples = tensor.detach().cpu()

    # Magnitude spectrum of the FFT
    magnitude_spectrum = torch.abs(torch.fft.fft(samples))

    # Frequency of each FFT bin
    freq = np.fft.fftfreq(len(samples), d=1 / sample_rate)

    return freq, magnitude_spectrum.numpy()

def detect_variable_oscillation(y, noise_threshold=12.0, min_peaks=3):
    """
    Exploratory check for oscillations with a varying period in a 1-D activation profile.

    Steps:
        1. Decompose `y` with a 2-level 'db4' wavelet transform.
        2. Soft-threshold each detail band at noise_threshold * that band's maximum coefficient, then
           reconstruct and plot the denoised signal.
        3. Find the peaks of the denoised signal and print the ratios between consecutive peak-to-peak
           distances. Ratios close to 1 indicate a regular period.

    Args:
        y: 1-D signal accepted by pywt.wavedec (e.g. a CPU tensor or numpy array).
        noise_threshold: Multiplier of each detail band's max coefficient used as the soft threshold.
        min_peaks: Minimum number of peaks required (3 peaks give 2 distances, i.e. one ratio).

    Returns:
        numpy array of consecutive distance ratios, or None when fewer than min_peaks peaks are found.
    """
    # Denoise the signal using wavelet transform
    coeffs = wavedec(y, 'db4', level=2)
    coeffs[1:] = [threshold(detail, value=noise_threshold * max(detail), mode='soft') for detail in coeffs[1:]]
    y_denoised = waverec(coeffs, 'db4')
    line(y_denoised)

    # Only the peaks are needed (troughs were computed but never used)
    peaks, _ = signal.find_peaks(y_denoised)
    if len(peaks) < min_peaks:
        print(f"Only {len(peaks)} peaks found (need at least {min_peaks})")
        return None

    # Relate consecutive peak-to-peak distances. Dividing raw peak positions (peaks[1:] / peaks[:-1])
    # does not compare distances and drifts towards 1 for peaks far from the start.
    peak_distances = np.diff(peaks)
    peak_distance_relations = peak_distances[1:] / peak_distances[:-1]
    print(peak_distance_relations)
    return peak_distance_relations
        
# Earlier FFT / peak-finding attempts, kept for reference.
# NOTE(review): uncommenting the `threshold = ...` line would shadow pywt's threshold() used by detect_variable_oscillation.
# op2_value = 13 # random
# values = prompts_activations[(layer, neuron)][(op1_op2_pairs[:, 1] == op2_value).nonzero()].view(-1)
# line(values)
# frequencies, magnitudes = get_fft_frequencies(values)
# frequencies = frequencies[:len(frequencies) // 2]
# magnitudes = magnitudes[:len(magnitudes) // 2]
# line(magnitudes, x=frequencies)
# threshold = 0.5 * magnitudes.max()
# dominant_freq_indices = np.where(magnitudes > threshold)[0]
# dominant_frequencies = frequencies[dominant_freq_indices]
# dominant_frequencies
# print(frequencies[magnitudes.argmax()].item() != 0)
# if frequencies[magnitudes.argmax()].item() != 0:
#     print(layer, neuron)

# peaks = signal.find_peaks(values, height=values.max().item() / 2, distance=20)
# peaks = signal.find_peaks_cwt(values, widths=5)
# print(len(peaks), peaks)
from visualization_utils import line  # Re-import; `line` is already imported in the setup cell
# For each example neuron: take its K-map activations on the pairs with op2 == 13 (in op1_op2_pairs row order),
# plot them, then run the oscillation check on them
for l, n in wave_heuristics + non_wave_heuristics:
    op2_value = 13
    values = k_prompts_activations[(l, n)][(op1_op2_pairs[:, 1] == op2_value).nonzero()].view(-1)
    line(values)
    detect_variable_oscillation(values)

In [None]:
# Look for indirect heuristics in layer < 16 as well: unlike the main activation cell, top neurons are
# collected from every layer (starting at layer 0). This overwrites that cell's globals.
# Model constants and importance scores are obtained the same way as there (get_model_consts / model, model_name).
# NOTE(review): max_op is whatever earlier cell last set it.

operator_idx = 0
topk_neurons_per_layer = 200
min_op = 1 if operator_idx == 3 else 0
op1_op2_pairs = torch.tensor(sorted([(op1, op2) for op1 in range(min_op, max_op) for op2 in _get_operand_range(OPERATORS[operator_idx], op1, min_op, max_op, get_model_consts(model_name).max_single_token)]))
prompts = [f'{op1}{OPERATORS[operator_idx]}{op2}=' for (op1, op2) in op1_op2_pairs]
arithmetic_labels = [label for label in model.tokenizer.vocab if _is_number(label.strip(" "), is_int=True)]
arithmetic_tokens = model.to_tokens(arithmetic_labels, prepend_bos=False).view(-1)

neuron_importance_scores = get_neuron_importance_scores(model, model_name, operator_idx=operator_idx, pos=-1)
heuristic_neurons = []
for layer in range(0, model.cfg.n_layers):
    heuristic_neurons += [(layer, neuron) for neuron in neuron_importance_scores[layer].topk(topk_neurons_per_layer).indices.tolist()]

# Calculate k (key) prompt activations
k_prompts_activations = generate_activations(model, prompts, [Component('mlp_post', layer=l) for l in range(model.cfg.n_layers)], pos=-1)
k_prompts_activations = {(layer, neuron): k_prompts_activations[layer][:, neuron] for (layer, neuron) in heuristic_neurons}

# Calculate kv (key-value) prompt activations by multiplying the key activations with the value vector logits
kv_prompts_activations = {}
results_for_all_pairs = [str(safe_eval(f'{op1}{OPERATORS[operator_idx]}{op2}')) for (op1, op2) in op1_op2_pairs]
results_for_all_pairs_labels = model.to_tokens(results_for_all_pairs, prepend_bos=False).view(-1)
for (layer, neuron) in tqdm(heuristic_neurons):
    v_vector_logits = (model.blocks[layer].mlp.W_out[neuron] @ model.W_U)
    logits_for_all_pairs = v_vector_logits[results_for_all_pairs_labels].cpu()
    kv_prompts_activations[(layer, neuron)] = k_prompts_activations[(layer, neuron)] * logits_for_all_pairs


#### Investigating failure modes

In [None]:
# Experimentation with the effect of heuristics on correctly-completed or incorrectly-completed prompts.
# This is the code behind section 4.3.

def get_key_activation(prompt, layer, neuron):
    """
    Return a neuron's key (K-map) activation for one prompt, without any value-vector logits.

    The two integers in `prompt` (e.g. "12+7=") are read as (op1, op2) and located among the rows of the
    global op1_op2_pairs, whose order matches k_prompts_activations.
    """
    first_operand, second_operand = (int(number) for number in re.findall(r"\d+", prompt))
    row_is_prompt = (op1_op2_pairs[:, 0] == first_operand) & (op1_op2_pairs[:, 1] == second_operand)
    row_index = row_is_prompt.nonzero().item()
    neuron_key_activations = k_prompts_activations[(layer, neuron)]
    return neuron_key_activations[row_index]


def get_logit_of_correct_answer_from_neuron(prompt, layer, neuron):
    """
    Return a neuron's logit contribution to the correct answer of one prompt.

    This is the prompt's KV-map entry: the neuron's key activation multiplied by the logit of the correct
    answer token in the neuron's value vector. The prompt's two integers are matched against the rows of
    the global op1_op2_pairs to find the entry.
    """
    operands = [int(number) for number in re.findall(r"\d+", prompt)]
    op1, op2 = operands
    prompt_row = torch.nonzero((op1_op2_pairs[:, 0] == op1) & (op1_op2_pairs[:, 1] == op2)).item()
    return kv_prompts_activations[(layer, neuron)][prompt_row]


# Load the HYBRID heuristic classes for the current operator_idx and keep matches at or above the threshold.
# NOTE(review): operator_idx is whatever earlier cells left set; the loop at the end of this cell iterates
# over all operators, but these classes are not reloaded per operator.
MATCH_THRESHOLD = 0.6
heuristic_classes = load_heuristic_classes(f'./data/{model_name}', operator_idx, "HYBRID")
heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= MATCH_THRESHOLD] for name, layer_neuron_scores in heuristic_classes.items()}


def analyze_prompt_heuristic_properties(prompts):
    """
    Check various properties of the heuristic neurons associated with a set of prompts.

    Per prompt, this collects:
        - the number of neurons implementing heuristics associated with the prompt (as described in the paper),
        - the logit contribution of each associated neuron to the correct answer token (as described in the paper),
        - the number of associated neurons whose contribution exceeds `logit_contrib_threshold`,
        - the summed absolute key activation and the mean heuristic matching score of the associated neurons.
    Everything is averaged over the prompts; only the two neuron-count averages are printed.

    Relies on the notebook globals used by get_neurons_associated_with_prompt, get_key_activation and
    get_logit_of_correct_answer_from_neuron, which must have been built for the prompts' operator.

    Args:
        prompts: List of prompt strings such as "12+7=".

    Returns:
        Dict mapping each prompt to a 1-D tensor with the logit contribution of each associated neuron
        (an empty tensor when no neuron is associated with the prompt).
    """
    # key_activation_threshold = 0.05 # Magic number #1
    logit_contrib_threshold = 0.05 # Magic number #2

    num_prompts = len(prompts)
    if num_prompts == 0:
        print("No prompts to analyze")
        return {}

    avg_associated_neuron_count = 0
    avg_key_activation = 0
    avg_high_contrib_neuron_count = 0
    # Per-prompt mean matching scores are summed, then averaged over the prompts that actually had scores,
    # so prompts without (high-contribution) neurons no longer raise ZeroDivisionError.
    matching_score_sum, filtered_matching_score_sum = 0, 0
    matching_score_prompts, filtered_matching_score_prompts = 0, 0
    logit_contribs = {}
    for prompt in tqdm(prompts):
        # Used below as associated_neurons[(layer, neuron)] -> iterable of (heuristic_name, score) pairs
        associated_neurons = get_neurons_associated_with_prompt(prompt)
        avg_associated_neuron_count += len(associated_neurons)
        prompt_contribs = [get_logit_of_correct_answer_from_neuron(prompt, layer, neuron) for (layer, neuron) in associated_neurons]
        # torch.stack cannot stack an empty list
        logit_contribs[prompt] = torch.stack(prompt_contribs) if prompt_contribs else torch.empty(0)
        avg_key_activation += sum(get_key_activation(prompt, layer, neuron).abs().item() for (layer, neuron) in associated_neurons)
        high_contrib_mask = logit_contribs[prompt] > logit_contrib_threshold
        avg_high_contrib_neuron_count += high_contrib_mask.sum().item()

        all_heuristic_scores, high_contrib_heuristic_scores = [], []
        for i, (layer, neuron) in enumerate(associated_neurons):
            for _heuristic_name, score in associated_neurons[(layer, neuron)]:
                all_heuristic_scores.append(score)
                if high_contrib_mask[i]:
                    high_contrib_heuristic_scores.append(score)
        if all_heuristic_scores:
            matching_score_sum += sum(all_heuristic_scores) / len(all_heuristic_scores)
            matching_score_prompts += 1
        if high_contrib_heuristic_scores:
            filtered_matching_score_sum += sum(high_contrib_heuristic_scores) / len(high_contrib_heuristic_scores)
            filtered_matching_score_prompts += 1

    # Average over the analyzed prompts (not over the unrelated global prompts_and_answers)
    avg_associated_neuron_count /= num_prompts
    avg_key_activation /= num_prompts
    avg_high_contrib_neuron_count /= num_prompts
    avg_matching_score = matching_score_sum / matching_score_prompts if matching_score_prompts else 0.0
    avg_filtered_matching_score = filtered_matching_score_sum / filtered_matching_score_prompts if filtered_matching_score_prompts else 0.0
    print(f"Average number of neurons that implement heuristics associated with the prompts: {avg_associated_neuron_count}")
    # print(f"Average key activation: {avg_key_activation}")
    print(f"Average number of neurons with high logit contrib: {avg_high_contrib_neuron_count}")
    # print(f"Average heuristic matching score of associated neurons: {avg_matching_score}")
    # print(f"Average heuristic matching score of associated neurons with high logit contrib: {avg_filtered_matching_score}")

    return logit_contribs

# NOTE(review): op1_op2_pairs, k_prompts_activations and kv_prompts_activations (and heuristic_classes above) are
# built for a single operator_idx by earlier cells, but this loop analyzes prompts of every operator. Recompute them
# per operator (or restrict the loop) before trusting results for operators other than the one they were built for.
contribs = {}
for operator_idx in range(len(OPERATORS)):
    print("Correct prompts")
    correct_prompts = separate_prompts_and_answers(correct_prompts_and_answers[operator_idx])[0]
    logit_contribs = analyze_prompt_heuristic_properties(correct_prompts)

    print("Incorrect prompts")
    all_incorrect_prompts = separate_prompts_and_answers(incorrect_prompts_and_answers[operator_idx])[0]
    # Sample as many incorrect prompts as there are correct ones when possible; random.sample raises
    # ValueError if asked for more items than exist.
    incorrect_prompts = random.sample(all_incorrect_prompts, min(len(correct_prompts), len(all_incorrect_prompts)))
    incorrect_logit_contribs = analyze_prompt_heuristic_properties(incorrect_prompts)

    contribs[OPERATORS[operator_idx]] = (logit_contribs, incorrect_logit_contribs)
torch.save(contribs, f'./data/{model_name}/correct_and_incorrect_prompts_heuristic_logit_contributions.pt')

## Figures for paper

### Save activation patterns (for many figures)

In [None]:
# NOTE: Before running this cell you need to run the code from the cells above to generate the activation maps for the relevant operator first

# Generate many activation pattern visualizations to use in later figures.
# Images that already exist on disk are skipped.

NEURONS_TO_VISUALIZE_PER_LAYER = 50
neuron_vis_dir = f"./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_neuron_visualizations"
os.makedirs(neuron_vis_dir, exist_ok=True)
# Same importance scores (model, model_name, pos=-1) and first layer as the activation cell. The top
# NEURONS_TO_VISUALIZE_PER_LAYER neurons are then among the topk_neurons_per_layer neurons stored in
# k_prompts_activations / kv_prompts_activations (as long as NEURONS_TO_VISUALIZE_PER_LAYER <= topk_neurons_per_layer).
# Otherwise the lookups below can raise KeyError.
neuron_importance_scores = get_neuron_importance_scores(model, model_name, operator_idx=operator_idx, pos=-1)
for activation_map_type in ["KV", "K"]:
    activations = k_prompts_activations if activation_map_type == "K" else kv_prompts_activations
    for layer in range(get_model_consts(model_name).first_heuristics_layer, model.cfg.n_layers):
        top_neurons = neuron_importance_scores[layer].topk(NEURONS_TO_VISUALIZE_PER_LAYER).indices.tolist()
        for neuron in top_neurons:
            neuron_vis_path = f"{neuron_vis_dir}/mlp{layer}_neuron{neuron}_{activation_map_type}_map.png"
            if os.path.exists(neuron_vis_path):
                continue  # Already rendered by a previous run
            # Scatter the per-pair activations into an (op1, op2) grid
            activation_img = torch.zeros((max_op - min_op, max_op - min_op))
            for i, (op1, op2) in enumerate(op1_op2_pairs):
                activation_img[op1 - min_op, op2 - min_op] = activations[(layer, neuron)][i]
            fig = px.imshow(activation_img, x=list(range(min_op, max_op)), y=list(range(min_op, max_op)),
                            labels={'x': 'Operand2', 'y': 'Operand1'}, width=600,
                            # title=f'MLP{layer}#{neuron} {activation_map_type} activation map as function of operands ({OPERATOR_NAMES[operator_idx]})',
                            color_continuous_midpoint=0.0, color_continuous_scale="RdBu")
            fig.update_xaxes(
                title_font=dict(size=20, family='Arial', color='black'),
                tickfont=dict(size=14, family='Arial', color='black'),
                tickvals=None
            )
            fig.update_yaxes(
                title_font=dict(size=20, family='Arial', color='black'),
                tickfont=dict(size=14, family='Arial', color='black'),
                tickvals=None
            )
            fig.write_image(neuron_vis_path)

### Heuristic types figure

In [None]:
# Figure with one example activation map per heuristic type, min-max normalized and shown side by side.
size_increase = 0 # size increase across all fonts
# (0, 0) entries are placeholders: panels 1 and 4 are synthesized below instead of read from a real neuron
layer_neurons = [(16, 6337), (0, 0), (17, 5628), (27, 9413), (0, 0)]
heuristic_titles = ["Op1 Range", "Op2 Modulo", "Op1 == Op2", "Result Range", "Result Pattern"]
assert len(layer_neurons) == len(heuristic_titles)

# Reuses activation_imgs from a previous run of this cell if it exists; delete it to regenerate.
# NOTE(review): this overwrites the global max_op / min_op, and op1_op2_pairs must only contain operands < 300.
if 'activation_imgs' not in locals():
    max_op, min_op = 300, 0
    activation_imgs = torch.zeros(len(layer_neurons), max_op - min_op, max_op - min_op)
    for i, (layer, neuron) in enumerate(layer_neurons):
        if i == 1:
            # Hacky way to visualize this type of heuristic so it will be visible in the figure
            # (synthetic map: high values where op2 % 9 == 2, low noise elsewhere)
            for j, (op1, op2) in enumerate(op1_op2_pairs):
                activation_imgs[i, op1 - min_op, op2 - min_op] = random.normalvariate(0.8, 0.2) if op2 % 9 == 2 else random.normalvariate(0.1, 0.05)
        elif i == len(layer_neurons) - 1:
            # Hacky way to visualize this type of heuristic so it will be visible in the figure
            # (synthetic map: high values where the tens digit of (op1 + op2) is 2)
            for j, (op1, op2) in enumerate(op1_op2_pairs):
                activation_imgs[i, op1 - min_op, op2 - min_op] = random.normalvariate(0.8, 0.2) if (((op1 + op2) % 100) // 10) == 2 else random.normalvariate(0.1, 0.05)
        else:
            # Result heuristics are drawn from the KV maps, operand heuristics from the K maps
            use_kv_maps = 'result' in heuristic_titles[i].lower()
            prompts_activations = kv_prompts_activations if use_kv_maps else k_prompts_activations
            for j, (op1, op2) in enumerate(op1_op2_pairs):
                activation_imgs[i, op1 - min_op, op2 - min_op] = prompts_activations[(layer, neuron)][j]
        # Min-max normalize each panel to [0, 1] (NOTE(review): divides by zero for a constant map)
        activation_imgs[i] = (activation_imgs[i] - activation_imgs[i].min()) / (activation_imgs[i].max() - activation_imgs[i].min())

main_fig = make_subplots(rows=1, cols=len(layer_neurons), shared_yaxes=True, horizontal_spacing=0.02)
    
for i, (layer, neuron) in enumerate(layer_neurons):
    fig = px.imshow(activation_imgs[i], 
                    # x=list(range(min_op, max_op)), 
                    # y=list(range(min_op, max_op)), 
                    # labels={'x': 'Operand2', 'y': 'Operand1'},
                    color_continuous_midpoint=0.0, color_continuous_scale="RdBu")
    fig.update_layout(showlegend=False, title=None)
    # Copy the heatmap trace into the i-th subplot
    main_fig.add_trace(fig.data[0], row=1, col=i + 1)
    main_fig.update_yaxes(row=1, col=i + 1, scaleanchor='x') # Lock the y scale to the x axis so each cell is square (this does not invert the axis)
    if i == 0:
        # Only the leftmost panel gets a y-axis title (the y axes are shared)
        main_fig.update_yaxes(title=dict(standoff=0, text='Operand 1', font=dict(size=17 + size_increase)), row=1, col=i + 1, tickvals=[], tickfont=dict(size=14 + size_increase))
    main_fig.update_xaxes(row=1, col=i + 1, tickvals=[], tickfont=dict(size=14 + size_increase))
    main_fig.update(layout_coloraxis_showscale=False)  # Hide the colorbar
    
    main_fig.add_annotation(
        text=f"<b>{heuristic_titles[i]}</b>",  # Title from heuristic_titles
        x=i / len(layer_neurons) + (1 / (2 * len(layer_neurons))),  # Center the title under each subfigure
        y=1.15,
        xref="paper", 
        yref="paper",
        showarrow=False,
        font=dict(size=17 + size_increase),
        xanchor="center"
    )

    # Per-panel x-axis label, placed below the panel in paper coordinates
    main_fig.add_annotation(
        text="Operand 2",
        x=i / len(layer_neurons) + (1 / (2 * len(layer_neurons))),  # Center the title under each subfigure
        y=-0.15,
        xref="paper", 
        yref="paper",
        showarrow=False,
        font=dict(size=17 + size_increase),
        xanchor="center"
    )

main_fig.update_layout(
    width=800,
    height=180,
    margin=dict(t=20, b=20, r=0, l=30),
    coloraxis=dict(colorscale="RdBu", cmin=-1, cmax=1)
)
main_fig.show()

pio.write_image(main_fig, "./figs/heuristic_type_examples.pdf")

### Localization figure

In [None]:
# Localization results figure (Using activation patching results)
# Plots the activation-patching indirect effect (IE) of every attention head (last
# position, or summed over positions) and of every MLP (per position), averaged over
# seeds and operators and shown on a log1p scale.

model_name = "llama3-8b"
attn_maps_sum_over_positions = False # If False, only the last position IE is presented

# Dict keyed by (op_idx, pos, seed) tuples -- see the key unpacking below.
ie_maps = torch.load(f'./data/{model_name}/ie_maps_activation_patching.pt')

# Average across seeds
summed_seed_ie_maps = {}
op_idxs = [0, 1, 2, 3]
seeds = set([])
for op_idx, pos, seed in ie_maps.keys():
    seeds.add(seed)
    if (op_idx, pos) not in summed_seed_ie_maps:
        # NOTE(review): stores the loaded map itself (no copy); if these are tensors, the
        # `+=` below accumulates into it in place. Harmless here because `ie_maps` is
        # rebound right after this loop.
        summed_seed_ie_maps[(op_idx, pos)] = ie_maps[(op_idx, pos, seed)]
    else:
        summed_seed_ie_maps[(op_idx, pos)] += ie_maps[(op_idx, pos, seed)]
# Divides by the number of distinct seeds seen overall -- assumes every (op_idx, pos)
# pair was run with every seed.
ie_maps = {k: v / len(seeds) for (k, v) in summed_seed_ie_maps.items()}

# Average across operators
ie_maps = {pos: torch.stack([ie_maps[(op_idx, pos)] for op_idx in op_idxs]).mean(dim=0) for pos in POSITIONS}

# Tensorify
ie_maps = torch.stack([ie_maps[pos] for pos in POSITIONS]) # pos, Layers, heads+mlp
# log1p compresses the dynamic range (the colour bar below is labelled "log scale")
ie_maps = np.log1p(ie_maps)

# Last entry of the component axis is the MLP; all preceding entries are attention heads
attn_maps, mlp_maps = ie_maps[:, :, :-1], ie_maps[:, :, -1]
if attn_maps_sum_over_positions:
    attn_maps = attn_maps.sum(dim=0) # Sum effect from all positions
else:
    attn_maps = attn_maps[-1] # Take only last position
# (pos, layers) -> (layers, pos) so MLP rows line up with the attention-head rows (layers)
mlp_maps = mlp_maps.T


# Find the global min and max values for mutual scaling (for both MLP and ATTN subfigures)
global_min = min(attn_maps.min(), mlp_maps.min()).item()
global_max = max(attn_maps.max(), mlp_maps.max()).item()

# Create subplots
attn_title = "Attention Heads (Summed over positions)" if attn_maps_sum_over_positions else "Attention Heads (Last Position)"
fig = make_subplots(rows=1, cols=2, subplot_titles=[attn_title, "MLPs"], shared_yaxes=True, horizontal_spacing=0.02)

# Add heatmaps to subplots (both share one coloraxis so the colour scale is comparable)
fig.add_trace(go.Heatmap(z=attn_maps, coloraxis="coloraxis", zsmooth=False), row=1, col=1)
fig.add_trace(go.Heatmap(z=mlp_maps, coloraxis="coloraxis", zsmooth=False), row=1, col=2)

# Attention panel takes 70% of the width; layer 0 is drawn at the top (reversed y axis)
fig.update_xaxes(title=dict(text="Attention Head", font=dict(size=16), standoff=8), tickfont=dict(size=15), domain=[0, 0.7], row=1, col=1)
fig.update_yaxes(title=dict(text="Layer", font=dict(size=16), standoff=15), tickfont=dict(size=15), autorange="reversed", row=1, col=1)

fig.update_xaxes(title=dict(text="Position", font=dict(size=16), standoff=8), tickfont=dict(size=15), domain=[0.8, 1.0], row=1, col=2)
fig.update_yaxes(autorange="reversed", row=1, col=2)

# Update layout
fig.update_layout(
    title={
        'text': "Effect Map Per Attn Head / MLP",
        'x': 0.55, 'y': 0.98,
        'font': dict(size=17)
    },
    margin=dict(l=0, r=0, t=50, b=0),
    height=200,
    width=500,
    coloraxis=dict(
        colorscale='Blues',
        cmin=global_min,
        cmax=global_max,
        colorbar=dict(
            title=dict(text="log<br>scale", font=dict(size=17)),
            tickfont=dict(size=15),
            thickness=20,
            len=1.6,
            yanchor="middle",
            y=0.65,
            xanchor="left",
            x=1.02
        )
    )
)

# Reposition the two subplot titles (the first two layout annotations, created by
# make_subplots from `subplot_titles`) over their panels.
fig.layout.annotations[0].update(x=0.34, y=1.01, font=dict(size=16))
fig.layout.annotations[1].update(x=0.90, y=1.01, font=dict(size=16))

pio.write_image(fig, f'./figs/{model_name}_localization.pdf')
fig.show()

### Linear probing for the answer figure

In [None]:
# Linear probing results
# Heatmap of answer-probe accuracy per (position, layer), averaged over operators.

probe_accs = torch.load(f'./data/{model_name}/probe_accs.pt')
# Keyed by (operator_idx, position); average over operators for each probed position
probe_accs = {
    pos_to_probe: torch.tensor(
        [probe_accs[(operator_idx, pos_to_probe)] for operator_idx in range(len(OPERATORS))]
    ).mean(dim=0)
    for pos_to_probe in POSITIONS
}
# One row per position, in POSITIONS order
probe_accs_tensor = torch.stack([probe_accs[pos_to_probe] for pos_to_probe in POSITIONS])

# Draw the figure (rows: prompt positions, columns: layers)
fig = px.imshow(
    probe_accs_tensor,
    x=list(range(model.cfg.n_layers)),
    y=['Operand1', 'Operator', 'Operand2', '='],
    width=350,
    height=120,
    zmin=0,
    color_continuous_midpoint=0.0,
    color_continuous_scale="blues",
)
fig.update_xaxes(tickfont=dict(size=15), title=dict(text="Layer", standoff=5, font=dict(size=17)))
fig.update_yaxes(tickfont=dict(size=15), title=dict(text="Position", standoff=10, font=dict(size=17)))
fig.update_layout(
    title_x=0.53,
    title_y=1.0,
    title_font=dict(size=17),
    margin=dict(l=0, r=0, t=0, b=0),
)
# Colour bar: slightly taller than the plot, centred vertically
fig.update_coloraxes(colorbar=dict(len=1.2, thickness=20, yanchor="middle", y=0.5))

fig.show()
pio.write_image(fig, './figs/probing_acc.pdf')

### Top-K MLP Neuron eval figure

In [None]:
# Top-k neurons localization results
# Faithfulness of a circuit that keeps only the top-k neurons (ranked by attribution
# patching) in each later MLP layer, as a function of k, averaged over seeds, one line
# per operator.

k_values = torch.tensor(sorted(list(range(0, model.cfg.d_mlp, 10)) + [model.cfg.d_mlp]))
seeds = [42, 412, 32879, 123]
results_file_path = f'./data/{model_name}/topk_neuron_faithfulness_evaluation_results.pt'

# # GENERATE DATA (uncomment to (re)generate `results_file_path`)
# max_op = 300
# mean_cache = torch.load(f'./data/{model_name}/mean_cache_for_evaluation_all_arithmetic_prompts_max_op={max_op}.pt')
# if os.path.exists(results_file_path):
#     faithfulness_per_k = torch.load(results_file_path)
# else:
#     faithfulness_per_k = {}
# for operator_idx in [0,1,2,3]:
#     for seed in seeds:
#         if (operator_idx, seed) in faithfulness_per_k:
#             print(f"Found results file for {operator_idx=}, {seed=}")
#             continue
#         set_deterministic(seed)
#         prompts_and_answers = random.sample(evaluation_prompts_and_answers[operator_idx], k=50)
#         mlppost_neuron_scores = get_neuron_importance_scores(model, model_name, operator_idx=operator_idx, pos=-1) # Ranking neurons according to Attribution patching
#         def build_circuit(operator_idx, mlp_top_neurons):
#             if operator_idx == 0:
#                 # Addition
#                 heads = [Component('z', layer=l, head=h) for (l, h) in [(2, 2), (5, 3), (5, 31), (14, 12), (15, 13), (16, 21)]]
#             elif operator_idx == 1:
#                 # Subtraction
#                 heads = [Component('z', layer=l, head=h) for (l, h) in [(2, 2), (13, 21), (13, 22), (14, 12), (15, 13), (16, 21)]]
#             elif operator_idx == 2:
#                 # Multiplication
#                 heads = [Component('z', layer=l, head=h) for (l, h) in [(2, 2), (5, 30), (8, 15), (9, 26), (13, 18), (13, 21), (13, 22),
#                                                                         (14, 12), (14, 13), (15, 8), (15, 13), (15, 14), (15, 15), (16, 3),
#                                                                         (16, 21), (17, 24), (17, 26), (18, 16), (20, 2), (22, 1)]]
#             elif operator_idx == 3:
#                 # Division
#                 heads = [Component('z', layer=l, head=h) for (l, h) in [(2, 2), (5, 31), (15, 13), (15, 14), (16, 21), (18, 16)]]
#             partial_mlp_layers = list(range(get_model_consts(model_name).first_heuristics_layer, model.cfg.n_layers))
#             full_mlps = [Component('mlp_post', layer=l) for l in range(model.cfg.n_layers) if l not in partial_mlp_layers]
#             partial_mlps = [Component('mlp_post', layer=l, neurons=mlp_top_neurons[l]) for l in partial_mlp_layers]
#             full_circuit = Circuit(model.cfg)
#             for c in list(set(heads + full_mlps + partial_mlps)):
#                 full_circuit.add_component(c)
#             return full_circuit
#         faithfulness_per_k[(operator_idx, seed)] = torch.zeros((len(k_values),))
#         for i, k in enumerate(k_values):
#             mlp_top_neurons = {}
#             for mlp in range(1, model.cfg.n_layers):
#                 mlp_top_neurons[mlp] = mlppost_neuron_scores[mlp].topk(k).indices.tolist()
#             full_circuit = build_circuit(operator_idx, mlp_top_neurons)
#             faithfulness_per_k[(operator_idx, seed)][i] = circuit_faithfulness_with_mean_ablation(model, full_circuit, prompts_and_answers, mean_cache, metric='nl')
#             print(k, faithfulness_per_k[(operator_idx, seed)][i].item())
#         torch.save(faithfulness_per_k, results_file_path)


# DRAW FIGURE
colors = COLORBLIND_COLORS
operator_labels = ["+", "-", "×", "÷"]
faithfulness_per_k = torch.load(results_file_path)
# Average across seeds
faithfulness_per_k = {operator_idx: torch.stack([faithfulness_per_k[(operator_idx, seed)] for seed in seeds]).mean(dim=0) for operator_idx in range(len(OPERATORS))}
# Stack across operators
faithfulness_per_k = torch.stack([faithfulness_per_k[operator_idx] for operator_idx in range(len(OPERATORS))])

fig = go.Figure()
for i in range(len(OPERATORS)):
    # NOTE: k_values starts at 0, which cannot be shown on the log x axis below.
    fig.add_trace(go.Scatter(x=k_values, y=faithfulness_per_k[i], mode='lines', name=operator_labels[i], line=dict(color=colors[i])))

# Ticks at 1%, 10% and 100% of the MLP width (d_mlp), labelled as percentages
x_axis_percentages = [0.01, 0.1, 1]
fig.update_layout(
    xaxis=dict(
        type="log",
        tickvals=[percent * model.cfg.d_mlp for percent in x_axis_percentages],
        ticktext=[f"{int(percent * 100)}%" for percent in x_axis_percentages],
        tickfont=dict(size=15),
        title="Neurons used Per Layer (%)",
        title_font=dict(size=16),
    ),
    yaxis=dict(
        title="Faithfulness",
        tickfont=dict(size=15),
        title_font=dict(size=16),
        range=(0, 1.0)
    ),
    legend=dict(itemwidth=30, itemsizing='constant', yanchor="bottom", 
                y=0.0, xanchor="center", x=0.9, font=dict(size=16), bgcolor='rgba(0,0,0,0)'),
    title="Faithfulness of using only top-k neurons",
    title_x=0.55, title_y=0.98, title_font=dict(size=16), 
    margin=dict(l=0, r=0, t=20, b=0), 
    width=400, height=200, 
)
fig.show()

pio.write_image(fig, f'./figs/faithfulness_topk_neurons.pdf')

### Effect of neurons in a specific layer

In [None]:
layer = 17
# Attribution-patching importance score of every MLP-post neuron in `layer`
# (operator 0 = addition, last position)
mlppost_neuron_scores = get_neuron_importance_scores(model, model_name, operator_idx=0, pos=-1)[layer]

# One marker per neuron: x = neuron index, y = its intervention effect
fig = go.Figure(
    data=go.Scatter(
        x=[*range(len(mlppost_neuron_scores))],
        y=mlppost_neuron_scores,
        mode="markers",
        marker=dict(size=5, color=colors[0]),
        name="+",
        showlegend=True,
    )
)

fig.update_xaxes(title='Neuron Index', tickfont=dict(size=15), title_font=dict(size=16))
fig.update_yaxes(
    title='Intervention Effect',
    range=[-0.02, 0.1],
    tickvals=[0, 0.05],
    tickfont=dict(size=15),
    title_font=dict(size=16),
)
fig.update_layout(
    title='Individual MLP Neuron Intervention Effects',
    title_x=0.55,
    title_y=0.98,
    title_font=dict(size=16),
    width=400,
    height=200,
    margin=dict(l=0, r=0, t=20, b=0),
    legend=dict(
        itemwidth=30, itemsizing='constant',
        yanchor="bottom", y=-0.3,
        xanchor="center", x=0.9,
        font=dict(size=16), bgcolor='rgba(0,0,0,0)',
    ),
)
fig.show()

pio.write_image(fig, f'./figs/mlp_neuron_intervention_effects_layer_{layer}.pdf')

### Knockout figure

In [None]:
# Summary statistics over all operators: mean post-knockout accuracy on related prompts
# and mean accuracy drop (baseline - ablated) on related prompts, over the heuristics
# that pass the filters below.

# Filter thresholds. Previously these were only defined in the figure cell below, so
# running this cell first raised a NameError. The values match that cell.
MIN_NEURONS_PER_HEURISTIC = 15
MIN_SCORE_SUM_PER_HEURISTIC = 15
ablated_accs = []
related_diff = []
MATCH_THRESHOLD = 0.6
ACTIVATION_MAP_TYPE = "HYBRID"
for operator_idx in range(len(OPERATORS)):
    heuristics_knockout_results = torch.load(f'./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_heuristic_ablation_results_thres={MATCH_THRESHOLD}_{ACTIVATION_MAP_TYPE}_maps.pt')
    # Largest related-prompt accuracy drop first
    heuristics_knockout_results = sorted(heuristics_knockout_results, key=lambda h: h.baseline_related - h.ablated_related, reverse=True)
    # Keep heuristics with enough ablated neurons and a high enough summed matching score
    heuristics_knockout_results = [h for h in heuristics_knockout_results if
                                   len(h.ablated_neurons) >= MIN_NEURONS_PER_HEURISTIC and
                                   h.ablated_neuron_matching_score >= MIN_SCORE_SUM_PER_HEURISTIC]

    if ACTIVATION_MAP_TYPE == "KV":
        # No point in looking at operand heuristics
        heuristics_knockout_results = [h for h in heuristics_knockout_results if h.heuristic_name.startswith('result')]

    baseline_related = torch.tensor([h.baseline_related for h in heuristics_knockout_results])
    ablated_related = torch.tensor([h.ablated_related for h in heuristics_knockout_results])
    ablated_accs += ablated_related.cpu().tolist()
    related_diff += (baseline_related - ablated_related).cpu().tolist()

# Guard against ZeroDivisionError when no heuristic survives the filters
if ablated_accs:
    print(f"Average ablated acc: {sum(ablated_accs) / len(ablated_accs)}")
    print(f"Average related diff: {sum(related_diff) / len(related_diff)}")
else:
    print("No heuristics passed the filters; nothing to average.")

In [None]:
# Figure settings

size_increase = 2 # font size increase across all fonts
show_all_lines = True
MIN_NEURONS_PER_HEURISTIC = 15
MIN_SCORE_SUM_PER_HEURISTIC = 15
MATCH_THRESHOLD = 0.55
model_name = 'llama3-8b'
colors = COLORBLIND_COLORS # px.colors.qualitative.Plotly
ACTIVATION_MAP_TYPE = ["KV", "K", "HYBRID"][2]


# One line figure per operator. Lines show accuracy on related (and, if show_all_lines,
# unrelated) prompts after knocking out each heuristic. Heuristics are sorted by their
# related-prompt accuracy drop, largest first.
all_figs = []
for operator_idx in range(len(OPERATORS)):
    heuristics_knockout_results = torch.load(f'./data/{model_name}/{OPERATOR_NAMES[operator_idx]}_heuristic_ablation_results_thres={MATCH_THRESHOLD}_{ACTIVATION_MAP_TYPE}_maps.pt')
    heuristics_knockout_results = sorted(heuristics_knockout_results, key=lambda h: h.baseline_related - h.ablated_related, reverse=True)
    # Keep heuristics with enough ablated neurons and a high enough summed matching score
    heuristics_knockout_results = [h for h in heuristics_knockout_results if
                                   len(h.ablated_neurons) >= MIN_NEURONS_PER_HEURISTIC and
                                   h.ablated_neuron_matching_score >= MIN_SCORE_SUM_PER_HEURISTIC]

    if ACTIVATION_MAP_TYPE == "KV":
        # No point in looking at operand heuristics
        heuristics_knockout_results = [h for h in heuristics_knockout_results if h.heuristic_name.startswith('result')]

    heuristic_names_to_test = [h.heuristic_name for h in heuristics_knockout_results]
    baseline_related = torch.tensor([h.baseline_related for h in heuristics_knockout_results])
    baseline_unrelated = torch.tensor([h.baseline_unrelated for h in heuristics_knockout_results])
    ablated_related = torch.tensor([h.ablated_related for h in heuristics_knockout_results])
    ablated_unrelated = torch.tensor([h.ablated_unrelated for h in heuristics_knockout_results])
    ablated_neurons_counts = torch.tensor([len(h.ablated_neurons) for h in heuristics_knockout_results])
    ablated_neuron_matching_scores = torch.tensor([h.ablated_neuron_matching_score for h in heuristics_knockout_results])

    if show_all_lines:
        lines = [ablated_related, ablated_unrelated]
        line_titles = [f"Effect of heuristic ablation on related prompts", "Effect of heuristic ablation on unrelated prompts"]
        line_colors = colors[:2]
        # lines = [baseline_unrelated, baseline_related, ablated_unrelated, ablated_related]
        # line_titles = ["Baseline unrelated", "Baseline related", "Ablated unrelated", "Ablated related"]
        # line_colors = colors[:4]
    else:
        lines = [ablated_related, baseline_related]
        line_titles = [f"ablated", "baseline"]
        line_colors = [colors[0]] + [colors[2]]

    fig = multiple_lines(list(range(len(baseline_related))), lines, line_titles,
                title=rf"{OPERATOR_NAMES[operator_idx]} accuracy on related{' and unrelated prompts' if show_all_lines else ''}<br> before and after knockouts<br>" + 
                "Sorted by knockout diff",
                xaxis_title="Heuristic index",
                yaxis_title="Accuracy",
                hovertext=list(zip(heuristic_names_to_test, ablated_neurons_counts, ablated_neuron_matching_scores)),
                show_fig=False,
                width=500)
    
    fig.update_layout(
        showlegend=False,
        title=None
    )

    all_figs.append(fig)


# Combine the per-operator figures into a single row of subplots sharing the y axis
fig = make_subplots(rows=1, cols=4, shared_yaxes=True, horizontal_spacing=0.02)
for i, subfig in enumerate(all_figs, start=1):
    for j, trace in enumerate(subfig.data):
        fig.add_trace(
            go.Scatter(
                x=trace.x, 
                y=trace.y, 
                name=trace.name, 
                mode='lines',
                line=dict(color=line_colors[j]),
                showlegend=i==1  # Only show legend for the first subplot
            ),
            row=1, 
            col=i
        )
        mean_value = f"Mean: {trace.y.mean():.2f}"
        # The mean label is coloured like its line. It uses line_colors (not colors) so it
        # still matches when show_all_lines is False (line_colors == [colors[0], colors[2]]).
        fig.add_annotation(
            x=0.95*len(trace.x),  # Position near the right edge of the subplot
            y=(j * 0.14),  # Stagger the text boxes (upper one higher)
            xref=f'x{i}',  # Reference the x-axis of the current subplot
            yref=f'y{i}',  # Reference the y-axis of the current subplot
            text=mean_value,
            showarrow=False,
            font=dict(size=15+size_increase, color=line_colors[j]),
            borderwidth=0,
            xanchor='right',
            yanchor='bottom'
        )
        # fig.add_hline(y=trace.y.mean(), line_dash="dash", line_color=line_colors[j], row=1, col=i, name=f"Mean {trace.name}", showlegend=i == 1)
    if i == 1:
        fig.update_yaxes(title=dict(text='Accuracy', font=dict(size=17+size_increase)), row=1, col=i, tickfont=dict(size=15+size_increase))
    # About five x ticks per subplot. The step is clamped to >= 1 because range() raises
    # ValueError on a zero step, which happens when fewer than 5 heuristics remain.
    tick_step = max(1, int(len(trace.y) / 4.9))
    fig.update_xaxes(title=dict(standoff=0, text="Heuristic index", font=dict(size=17+size_increase)), row=1, col=i, 
                     tickvals=list(range(0, len(trace.y) - 2, tick_step)), tickfont=dict(size=15+size_increase))



# Operator names subtitles
for i, subtitle in enumerate(OPERATOR_NAMES):
    fig.add_annotation(
        x=(i + 1 - 0.5) / 4,  # This centers the subtitle over each subplot
        y=1.0,  # Adjust this value to move subtitles up or down
        xref='paper',
        yref='paper',
        text=subtitle[0].upper() + subtitle[1:],
        showarrow=False,
        font=dict(size=18+size_increase),
        xanchor='center',
        yanchor='bottom'
    )
    

fig.update_layout(
    height=250, width=1100, 
    legend=dict(
        orientation="h",
        yanchor="top",
        y=-0.25,
        xanchor="center",
        x=0.5,
        font=dict(size=16+size_increase)
    ),
    margin=dict(l=0, r=0, t=22, b=0)
)


fig.show()
pio.write_image(fig, f'./figs/knockout_{ACTIVATION_MAP_TYPE.lower()}_new.pdf')

### Prompt-guided knockout results

In [None]:
per_layer = True
operator_labels = ['+', '-', '×', '÷']
model_name = 'llama3-8b'
# The results file name encodes how the ablated neurons were chosen
neuron_hard_limits, baseline_results, ablated_results, ablated_neuron_counts, control_results = torch.load(
    f'./data/{model_name}/prompt_knockout_results_with_neuron_limits{"_per_layer" if per_layer else "_choose_neurons_from_important"}.pt'
)
# Sanity check: baseline (no knockout) accuracy is 1 everywhere
assert torch.all(baseline_results == 1)

# Drop the trailing singleton dimension from every results tensor
baseline_results, ablated_results, control_results = (
    t.squeeze(-1) for t in (baseline_results, ablated_results, control_results)
)

colors = px.colors.qualitative.Plotly
fig = go.Figure()
# Black reference line at accuracy 1, spanning the full x range
fig.add_trace(go.Scatter(x=[neuron_hard_limits[0], neuron_hard_limits[-1]], y=[1, 1], mode='lines', name='Baseline', line=dict(color='black')))
for i, op in enumerate(OPERATORS):
    # Solid: heuristic-guided knockout; dashed (same colour): control knockout
    fig.add_trace(go.Scatter(x=neuron_hard_limits, y=ablated_results[i], mode='lines',
                             name=f'{operator_labels[i]}', line=dict(color=colors[i])))
    fig.add_trace(go.Scatter(x=neuron_hard_limits, y=control_results[i], mode='lines',
                             name=f'', showlegend=True, line=dict(color=colors[i], dash='dash')))

fig.update_xaxes(
    title=dict(text=f'Ablated Neurons{" (Per Layer)" if per_layer else ""}', font=dict(size=17)),
    tickfont=dict(size=15),
)
fig.update_yaxes(
    title=dict(text="Accuracy", font=dict(size=17)),
    range=(0, 1.0),
    tickvals=[0.25, 0.5, 0.75, 1],
    tickfont=dict(size=15),
)
fig.update_layout(
    legend_title='',
    title_x=0.5,
    title_y=0.95,
    margin=dict(l=0, r=5, t=30, b=0),
    width=800,
    height=250,
    title=dict(
        text="Prompt-guided heuristic knockout accuracies",
        font=dict(size=16),
        xanchor='center',
        yanchor='top',
    ),
)
fig.show()
pio.write_image(fig, f'./figs/prompt_knockout{"_per_layer" if per_layer else ""}.pdf')

In [None]:
# UNIFIED FOR llama3-8b and llama3-70b
# Shared legend labels (one per operator) and palette for both panels.

operator_labels = list('+-×÷')
colors = COLORBLIND_COLORS  # alternative palette: px.colors.qualitative.Plotly

def create_8b_figure():
    """Return the llama3-8b prompt-knockout traces, two per operator.

    For each operator, the traces are the heuristic-ablation accuracy (solid line,
    labelled with the operator symbol) followed by the control accuracy (dashed line,
    same colour, hidden from the legend). The x values are the per-layer neuron limits.
    """
    model_name = 'llama3-8b'
    limits, baseline, ablated, _neuron_counts, control = torch.load(
        f'./data/{model_name}/prompt_knockout_results_with_neuron_limits_per_layer.pt'
    )
    # Sanity check: baseline (no knockout) accuracy is 1 everywhere
    assert torch.all(baseline == 1)

    # Drop the trailing singleton dimension
    ablated = ablated.squeeze(-1)
    control = control.squeeze(-1)

    # Optional grey reference lines at accuracy 1 (disabled):
    # traces.append(go.Scatter(x=[limits[0], limits[-1]], y=[1, 1], mode='lines', name='', line=dict(color='grey')))
    # traces.append(go.Scatter(x=[limits[0], limits[-1]], y=[1, 1], mode='lines', name='', line=dict(color='grey', dash='dash')))
    traces = []
    for op_pos, _operator in enumerate(OPERATORS):
        solid = go.Scatter(x=limits, y=ablated[op_pos], mode='lines', name=f'{operator_labels[op_pos]}',
                           line=dict(color=colors[op_pos]))
        dashed = go.Scatter(x=limits, y=control[op_pos], mode='lines', name=f'', showlegend=False,
                            line=dict(color=colors[op_pos], dash='dash'))
        traces.extend([solid, dashed])
    return traces


def create_70b_figure(half_x_axis=False):
    """Return the llama3-70b prompt-knockout traces, two per operator.

    For each operator, the traces are the heuristic-ablation accuracy (solid) followed by
    the control accuracy (dashed, same colour). All traces are hidden from the legend.

    Args:
        half_x_axis: if True, keep only the first ``len // 2 + 1`` neuron limits (the
            same x range as the llama3-8b panel). This replaces the previously commented-out
            toggle. The default (False) keeps the full range, which is the previous behaviour.

    Returns:
        list of ``go.Scatter`` traces.
    """
    model_name = 'llama3-70b'
    # The addition results file is read here only to size the x range.
    full_limits = torch.load(f'./data/{model_name}/addition_prompt_ablation_results_thres=0.6_HYBRID_maps.pt')[0]
    lim = len(full_limits) // 2 + 1 if half_x_axis else len(full_limits)

    traces = []
    # traces.append(go.Scatter(x=[full_limits[0], full_limits[-1]], y=[1, 1], mode='lines', name='Baseline', line=dict(color='black'), showlegend=False))
    for i, op in enumerate(operator_labels):
        neuron_hard_limits, baseline_results, ablated_results, _neuron_counts, control_results = torch.load(f'./data/{model_name}/{OPERATOR_NAMES[i]}_prompt_ablation_results_thres=0.6_HYBRID_maps.pt')
        neuron_hard_limits = list(neuron_hard_limits)[:lim]
        baseline_results, ablated_results, control_results = baseline_results[:lim], ablated_results[:lim], control_results[:lim]
        # Sanity check: baseline (no knockout) accuracy is 1 everywhere in the plotted range
        assert torch.all(baseline_results == 1)
        traces.append(go.Scatter(x=neuron_hard_limits, y=ablated_results, mode='lines', name=f'{op}', line=dict(color=colors[i]), showlegend=False))
        traces.append(go.Scatter(x=neuron_hard_limits, y=control_results, mode='lines', name=f'', line=dict(color=colors[i], dash='dash'), showlegend=False))
    return traces



# Create subplots
# Left panel: llama3-8b; right panel: llama3-70b; the y axis (accuracy) is shared.
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, subplot_titles=("Llama3-8B", "Llama3-70B"), horizontal_spacing=0.03)

# Legend-only dummy traces (no data points), added first so they lead the legend and
# explain the solid / dashed line styles used by the per-operator traces.
fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', name='Heuristic Ablation', line=dict(color='grey')))
fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', name='Random Ablation', line=dict(dash='dash', color='grey')))

# Add traces for llama3-8b
for trace in create_8b_figure():
    fig.add_trace(trace, row=1, col=1)

# Add traces for llama3-70b
for trace in create_70b_figure():
    fig.add_trace(trace, row=1, col=2)

# Update layout (x tick spacing differs per panel: 0-80 step 20 vs 0-200 step 40)
fig.update_yaxes(title=dict(text="Accuracy", font=dict(size=17)), range=(0, 1.02), tickvals=[0.25, 0.5, 0.75, 1], tickfont=dict(size=15), row=1, col=1)
fig.update_yaxes(range=(0, 1.02), tickvals=[0.25, 0.5, 0.75, 1], row=1, col=2)
fig.update_xaxes(title=dict(standoff=0, text=f'Ablated Neurons (Per Layer)', font=dict(size=15)), tickvals=list(range(0, 81, 20)), tickfont=dict(size=15), row=1, col=1)
fig.update_xaxes(title=dict(standoff=0, text=f'Ablated Neurons (Per Layer)', font=dict(size=15)), tickvals=list(range(0, 201, 40)), tickfont=dict(size=15), row=1, col=2)

# Horizontal legend centred below both panels (hence the large bottom margin)
fig.update_layout(
    legend_title='',
    margin=dict(l=0, r=0, t=20, b=83),
    width=800,
    height=250,
    legend=dict(
        font=dict(size=15),
        orientation="h",
        yanchor="bottom",
        y=-0.45,
        xanchor="center",
        x=0.5,
    )
)

fig.show()

pio.write_image(fig, f'./figs/llama3_8b_70b_prompt_knockout_per_layer.svg')

### Faithfulness as a function of head count

In [None]:
# Generate the data for the figure
# Faithfulness (under mean ablation) of circuits made of the top-n attention heads,
# for n in [0, max_n_heads), plus every MLP layer in full. Computed for every operator
# and seed.
# Relies on notebook globals from earlier cells: model, model_name,
# evaluation_prompts_and_answers, process_ie_maps.

# max_op must be assigned before it is used to locate the mean cache. Previously it
# was assigned only after the torch.load below.
max_op = 300
max_n_heads = 100
seeds = [42, 412, 32879, 123, 436]

mean_cache = torch.load(f'./data/{model_name}/mean_cache_for_evaluation_all_arithmetic_prompts_max_op={max_op}.pt')
ie_maps = process_ie_maps(torch.load(f'./data/{model_name}/ie_maps_activation_patching.pt'))


def build_circuit(operator_idx, n_heads, ranked_heads=None):
    """Build a circuit from the ``n_heads`` most effective heads plus all MLP layers in full.

    Args:
        operator_idx: index of the operator whose IE map ranks the heads.
        n_heads: number of top-ranked heads to include.
        ranked_heads: optional precomputed head ranking (the keys of
            ``topk_effective_components(..., heads_only=True)``). Pass it to avoid
            re-ranking on every call. When omitted, it is computed from
            ``ie_maps[operator_idx]``.

    Returns:
        A ``Circuit`` containing the selected heads and every ``mlp_post`` layer.
    """
    if ranked_heads is None:
        ranked_heads = list(topk_effective_components(model, ie_maps[operator_idx], k=max_n_heads, heads_only=True).keys())
    heads = list(ranked_heads)[:n_heads]
    full_mlps = [Component('mlp_post', layer=l) for l in range(model.cfg.n_layers)]
    full_circuit = Circuit(model.cfg)
    for c in list(set(heads + full_mlps)):
        full_circuit.add_component(c)
    return full_circuit


faithfulness_results = torch.zeros((len(seeds), len(OPERATORS), max_n_heads))

for operator_idx in range(len(OPERATORS)):
    # The head ranking depends only on the operator's IE map. Compute it once per operator
    # instead of once per (seed, n_heads) pair.
    ranked_heads = list(topk_effective_components(model, ie_maps[operator_idx], k=max_n_heads, heads_only=True).keys())
    for seed_idx, seed in enumerate(seeds):
        set_deterministic(seed)
        prompts_and_answers = random.sample(evaluation_prompts_and_answers[operator_idx], k=50)
        print(operator_idx, seed)

        for n_heads in tqdm(range(0, max_n_heads)):
            full_circuit = build_circuit(operator_idx, n_heads, ranked_heads=ranked_heads)
            nl_acc = circuit_faithfulness_with_mean_ablation(model, full_circuit, prompts_and_answers, mean_cache, metric='nl')
            faithfulness_results[seed_idx, operator_idx, n_heads] = nl_acc

        # Checkpoint after every seed.
        # NOTE(review): the drawing cell loads 'faithfulness_results_over_head_count.pt',
        # not this '_single_list' file. Confirm which file the figure should use.
        torch.save(faithfulness_results, f'./data/{model_name}/faithfulness_results_over_head_count_single_list.pt')

In [None]:
# Draw the figure itself
# Seed-averaged circuit faithfulness against the number of heads in the circuit,
# one line per operator.

model_name = 'llama3-8b'
max_n_heads = 100
# Only the first `max_n_heads - 50` head counts are plotted. The y data is sliced to the
# same length as x explicitly, instead of relying on mismatched x/y lengths.
n_heads_shown = max_n_heads - 50
faithfulness_results = torch.load(f'./data/{model_name}/faithfulness_results_over_head_count.pt')
print(faithfulness_results.shape)
# Average over the first dimension. This is the seed dimension if the file follows the
# generating cell's (seeds, operators, heads) layout -- confirm for this file.
faithfulness_results = faithfulness_results.mean(dim=0)

fig = multiple_lines(list(range(n_heads_shown)), faithfulness_results[..., :n_heads_shown], ['+', '-', '×', '÷'], show_fig=False, colors=COLORBLIND_COLORS)
fig.update_yaxes(title=dict(text="Faithfulness", font=dict(size=17), standoff=10), range=(0, 1.0), tickvals=[0.5, 1.0], tickfont=dict(size=16))
fig.update_xaxes(title=dict(text="Number of heads in circuit", font=dict(size=17), standoff=5), tickfont=dict(size=16))
fig.update_layout(
    legend=dict(itemwidth=30, orientation="h", itemsizing='constant', 
                yanchor="bottom", y=-0.5, xanchor="center", x=0.5, font=dict(size=17)),
    margin=dict(l=0, r=0, t=40, b=0), 
    width=400, height=250,
    title=dict(
        text="Llama3-8B circuit faithfulness<br>as function of number of circuit heads",
        x=0.58, y=0.94,
        font=dict(size=17),
        xanchor='center',
        yanchor='top'
    ),
    )

fig.show()

pio.write_image(fig, fr'./figs/faithfulness_over_head_count.pdf')

### Finding failure modes

In [None]:
contribs = torch.load(f'./data/{model_name}/correct_and_incorrect_prompts_heuristic_logit_contributions.pt')

# Per-prompt heuristic logit contributions for the first operator, split into
# correctly- and incorrectly-answered prompts
logit_contribs, incorrect_logit_contribs = contribs[OPERATORS[0]]

# Total contribution per prompt
correct_y = [logit_contribs[p].sum().item() for p in logit_contribs]
incorrect_y = [incorrect_logit_contribs[p].sum().item() for p in incorrect_logit_contribs]

correct_mean, correct_std = np.mean(correct_y), np.std(correct_y)
incorrect_mean, incorrect_std = np.mean(incorrect_y), np.std(incorrect_y)

# Two thin horizontal box plots (with every point shown and the mean marked),
# Correct at y=0 and Incorrect at y=0.4
fig = go.Figure(data=[
    go.Box(
        x=correct_y,
        y0=0,
        name="Correct",
        orientation='h',
        boxpoints='all',
        boxmean=True,
        jitter=0.3,
        pointpos=-1.8,
        width=0.1,
        whiskerwidth=0.1,
        marker_color='blue',
        marker=dict(size=5),
    ),
    go.Box(
        x=incorrect_y,
        y0=0.4,
        name="Incorrect",
        orientation='h',
        boxpoints='all',
        boxmean=True,
        jitter=0.3,
        width=0.1,
        whiskerwidth=0.0,
        marker_color='red',
        marker=dict(size=5),
    ),
])

fig.update_layout(width=450, height=100, margin=dict(l=0, r=0, t=0, b=0), showlegend=False)
fig.update_xaxes(
    title=dict(standoff=5, text="Correct answer logit contribution", font=dict(size=17)),
    tickfont=dict(size=15),
)
# NOTE(review): the "Incorrect" tick sits at 0.35 while that box is placed at y0=0.4 --
# confirm this offset is intentional.
fig.update_yaxes(
    tickvals=[0, 0.35],
    ticktext=["Correct", "Incorrect"],
    title=dict(text="Prompt Type", font=dict(size=17)),
    tickfont=dict(size=15),
)
fig.show()

pio.write_image(fig, f'./figs/correct_and_incorrect_prompts_heuristic_logit_contributions.pdf')

## Appendices

##### Additional circuit components

In [None]:
# MLP0 neuron effect scatter plot
# Importance of every layer-`layer` MLP neuron for operator 0, averaged over `positions`.
# NOTE(review): unlike other calls in this notebook, this one passes neither model nor
# model_name and uses ranking_method="mean" -- confirm the helper's signature allows it.

colors = COLORBLIND_COLORS
layer = 0
positions = [1, 2, 3]
mlp0_neuron_scores = sum(
    get_neuron_importance_scores(ranking_method="mean", operator_idx=0, pos=pos)[layer]
    for pos in positions
) / len(positions)

fig = go.Figure(
    data=go.Scatter(
        x=[*range(len(mlp0_neuron_scores))],
        y=mlp0_neuron_scores,
        mode="markers",
        marker=dict(size=5, color=colors[0]),
        name="+",
        showlegend=True,
    )
)

fig.update_xaxes(title='Neuron Index', tickfont=dict(size=15), title_font=dict(size=16))
fig.update_yaxes(
    title='Intervention Effect',
    range=[-0.02, 0.1],
    tickvals=[0, 0.05],
    tickfont=dict(size=15),
    title_font=dict(size=16),
)
fig.update_layout(
    title='Individual MLP Neuron Intervention Effects',
    title_x=0.55,
    title_y=0.98,
    title_font=dict(size=16),
    width=400,
    height=200,
    margin=dict(l=0, r=0, t=20, b=0),
    showlegend=False,
)
fig.show()

pio.write_image(fig, f'./figs/mlp_neuron_intervention_effects_layer_{layer}.pdf')

In [None]:
# Top-K Neuron eval in MLP0
# Setup for sweeping the number of MLP0 neurons kept in the circuit.

max_op = 300
# Neurons kept in each middle and late MLP layer: 1% of the MLP width
later_layers_k = int(model.cfg.d_mlp * 0.01)
mean_cache = torch.load(f'./data/{model_name}/mean_cache_for_evaluation_all_arithmetic_prompts_max_op={max_op}.pt')
# MLP0 k grid: step 10 below 1000, step 100 above, plus the full layer width
k_values = torch.tensor(sorted([*range(0, 1000, 10), *range(1000, model.cfg.d_mlp, 100), model.cfg.d_mlp]))
seeds = [42, 412, 32879, 123]
results_file_path = f'./data/{model_name}/mlp0_topk_neuron_faithfulness_evaluation_results.pt'
# Resume from previously saved results, if any
faithfulness_per_k = torch.load(results_file_path) if os.path.exists(results_file_path) else {}

# GENERATE DATA
for seed in seeds:
    for operator_idx in range(len(OPERATORS)):
        if (operator_idx, seed) in faithfulness_per_k:
            print(f"Found results file for {operator_idx=}, {seed=}")
            continue
        set_deterministic(seed)
        prompts_and_answers = random.sample(evaluation_prompts_and_answers[operator_idx], k=50)
        mlppost_neuron_scores = get_neuron_importance_scores(model, model_name, operator_idx=operator_idx, pos=-1) # Ranking neurons according to Attribution patching

        mlp0_important_positions = [1, 2, 3] # Op1, Operator, Op2. MLP0 isn't important in the last position nor in the BoS position.
        mlp0_neuron_scores = sum(get_neuron_importance_scores(model, model_name, operator_idx=operator_idx, pos=pos)[0] for pos in mlp0_important_positions) / len(mlp0_important_positions)

        def build_circuit(operator_idx, mlp_top_neurons):
            """Build the MLP0 top-k evaluation circuit for `operator_idx`.

            The circuit contains:
              - the operator's hand-picked attention heads;
              - every MLP layer before the first heuristics layer in full, except MLP0;
              - only the given top neurons of each MLP layer from the first heuristics
                layer onward;
              - only the given top neurons of MLP0.

            Args:
                operator_idx: 0 = addition, 1 = subtraction, 2 = multiplication,
                    3 = division (per the head-list comments).
                mlp_top_neurons: dict mapping layer -> list of neuron indices. It must
                    contain key 0 and every layer from the first heuristics layer on.

            Raises:
                ValueError: if no head list is defined for `operator_idx`. Previously
                    this surfaced as an UnboundLocalError on `heads`.
            """
            head_locations_per_operator = {
                # Addition
                0: [(2, 2), (5, 3), (5, 31), (14, 12), (15, 13), (16, 21)],
                # Subtraction
                1: [(2, 2), (13, 21), (13, 22), (14, 12), (15, 13), (16, 21)],
                # Multiplication
                2: [(2, 2), (5, 30), (8, 15), (9, 26), (13, 18), (13, 21), (13, 22),
                    (14, 12), (14, 13), (15, 8), (15, 13), (15, 14), (15, 15), (16, 3),
                    (16, 21), (17, 24), (17, 26), (18, 16), (20, 2), (22, 1)],
                # Division
                3: [(2, 2), (5, 31), (15, 13), (15, 14), (16, 21), (18, 16)],
            }
            if operator_idx not in head_locations_per_operator:
                raise ValueError(f"No circuit heads defined for {operator_idx=}")
            heads = [Component('z', layer=l, head=h) for (l, h) in head_locations_per_operator[operator_idx]]

            partial_mlp_layers = list(range(get_model_consts(model_name).first_heuristics_layer, model.cfg.n_layers))
            # MLP0 is restricted to its top-k neurons by `partial_mlp0` below. It must not
            # also be added as a full component; doing so would likely make the MLP0 top-k
            # sweep ineffective.
            full_mlps = [Component('mlp_post', layer=l) for l in range(model.cfg.n_layers)
                         if l not in partial_mlp_layers and l != 0]
            partial_mlps = [Component('mlp_post', layer=l, neurons=mlp_top_neurons[l]) for l in partial_mlp_layers]
            partial_mlp0 = [Component('mlp_post', layer=0, neurons=mlp_top_neurons[0])]
            full_circuit = Circuit(model.cfg)
            for c in list(set(heads + full_mlps + partial_mlps + partial_mlp0)):
                full_circuit.add_component(c)
            return full_circuit

        faithfulness_per_k[(operator_idx, seed)] = torch.zeros((len(k_values),))
        for i, mlp0_k in enumerate(k_values):
            mlp_top_neurons = {}
            for mlp in range(1, model.cfg.n_layers):
                mlp_top_neurons[mlp] = mlppost_neuron_scores[mlp].topk(later_layers_k).indices.tolist()
            mlp_top_neurons[0] = mlp0_neuron_scores.topk(mlp0_k).indices.tolist()

            full_circuit = build_circuit(operator_idx, mlp_top_neurons)
            faithfulness_per_k[(operator_idx, seed)][i] = circuit_faithfulness_with_mean_ablation(model, full_circuit, prompts_and_answers, mean_cache, metric='nl')
            print(mlp0_k, faithfulness_per_k[(operator_idx, seed)][i].item())
        torch.save(faithfulness_per_k, results_file_path)

In [None]:
# Show some activation patterns of important neurons in MLP0 to give a sense of what it does
# Scatter the MLP0 activation of selected neurons on each single numerical token 0..max_op-1.

max_op = 300
numerical_tokens = [str(i) for i in range(max_op)]
activations = generate_activations(model, numerical_tokens, [Component('mlp_post', layer=0)], pos=-1)[0]

layer = 0
positions = [1, 2, 3] # Op1, Operator, Op2
# model / model_name are passed positionally, matching the other get_neuron_importance_scores calls
# in this notebook (the MLP0 top-k evaluation cell).
mlp0_neuron_scores = sum(get_neuron_importance_scores(model, model_name, operator_idx=0, pos=pos, ranking_method="mean+std")[0] for pos in positions) / len(positions)
neurons_to_vis = [6206, 7101, 8969] #mlp0_neuron_scores.topk(30).indices.tolist()
for neuron in neurons_to_vis:
    neuron_activations = activations[:, neuron]
    scatter_with_labels(y=neuron_activations, title=f'Neuron {neuron} activations', x=numerical_tokens, hovertext=numerical_tokens)

In [None]:
# Draw the MLP0 top-k neurons eval figure
# Load the cached per-(operator, seed) faithfulness curves, average each operator's curves over
# the seeds and draw one line per operator against the number of MLP0 neurons kept (log x axis).

k_values = torch.tensor(sorted([*range(0, 1000, 10), *range(1000, model.cfg.d_mlp, 100), model.cfg.d_mlp]))
seeds = [42, 412, 32879, 123]
results_file_path = f'./data/{model_name}/mlp0_topk_neuron_faithfulness_evaluation_results.pt'
visualization_ops = ['+', '-', '×', '÷']
colors = COLORBLIND_COLORS

raw_results = torch.load(results_file_path)
seed_averaged_curves = []
for op_i in range(len(OPERATORS)):
    curves_for_op = torch.stack([raw_results[(op_i, seed)] for seed in seeds])
    seed_averaged_curves.append(curves_for_op.mean(dim=0))  # Average across seeds
faithfulness_per_k = torch.stack(seed_averaged_curves)  # Stack across operators -> (operator, k)

fig = go.Figure()
for i, curve in enumerate(faithfulness_per_k):
    trace = go.Scatter(x=k_values, y=curve, mode='lines', name=visualization_ops[i], line=dict(color=colors[i]))
    fig.add_trace(trace)

x_axis_percentages = [0.01, 0.1, 1.0]
x_axis_layout = dict(
    type="log",
    tickvals=[pct * model.cfg.d_mlp for pct in x_axis_percentages],
    ticktext=[f"{int(pct * 100)}%" for pct in x_axis_percentages],
    tickfont=dict(size=15),
    title="Neurons used Per Layer (%)",
    title_font=dict(size=16),
)
y_axis_layout = dict(
    title="Faithfulness",
    tickfont=dict(size=15),
    title_font=dict(size=16),
    range=(0.5, 1.0),
)
fig.update_layout(
    xaxis=x_axis_layout,
    yaxis=y_axis_layout,
    legend=dict(itemwidth=30, itemsizing='constant', font=dict(size=16), bgcolor='rgba(0,0,0,0)', y=0.1),
    title="Faithfulness of using only top-k neurons",
    title_x=0.55,
    title_y=0.98,
    title_font=dict(size=16),
    margin=dict(l=0, r=0, t=20, b=0),
    width=400,
    height=200,
)
fig.show()

pio.write_image(fig, './figs/mlp0_faithfulness_topk_neurons.pdf')

In [None]:
# Draw MLP0 neuron activation patterns
# For a few hand-picked MLP0 neurons, scatter the neuron's activation on every single numerical
# token 0..max_op-1 and save one figure per neuron.

# Generate MLP0 activations for numerical tokens
max_op = 300
numerical_tokens = list(map(str, range(max_op)))
activations = generate_activations(model, numerical_tokens, [Component('mlp_post', layer=0)], pos=-1)[0]

neurons_to_vis = [6206, 7101, 8969] # Manually found neurons to show several interesting patterns

# x-axis tick positions per neuron: 0..300 with a neuron-specific step
tickvals = {neuron_id: list(range(0, 301, step)) for neuron_id, step in ((6206, 25), (7101, 25), (8969, 32))}

colors = COLORBLIND_COLORS[1:]
marker_color = colors[0]
for neuron in neurons_to_vis:
    neuron_activations = activations[:, neuron]

    scatter_trace = go.Scatter(
        x=numerical_tokens,
        y=neuron_activations,
        mode="markers",
        marker=dict(size=5, color=marker_color),
        showlegend=False,
    )
    fig = go.Figure(data=scatter_trace)

    fig.update_yaxes(title='Neuron activation', tickfont=dict(size=14), title_font=dict(size=16))
    fig.update_xaxes(title='Numerical input token', tickfont=dict(size=14), title_font=dict(size=16), tickvals=tickvals[neuron])
    fig.update_layout(
        width=300, height=200,
        margin=dict(l=0, r=0, t=20, b=0),
        title_x=0.55, title_y=0.98, title_font=dict(size=16),
    )
    fig.show()

    pio.write_image(fig, f'./figs/mlp0_neuron{neuron}_activations.pdf')

In [None]:
# Visualize attention head patterns
# For each operator: visualize the attention patterns of the circuit's heads on the operator's
# prompts (from `large_prompts_and_answers`), cache (html, heads, raw patterns) to disk and display
# the HTML.

# Hand-picked attention heads (layer, head) of each operator's circuit: +, -, ×, ÷
effective_heads_per_operator = [
    [(2, 2), (5, 3), (5, 31), (14, 12), (15, 13), (16, 21)],
    [(2, 2), (13, 21), (13, 22), (14, 12), (15, 13), (16, 21)],
    [(2, 2), (5, 30), (8, 15), (9, 26), (13, 18), (13, 21), (13, 22),
     (14, 12), (14, 13), (15, 8), (15, 13), (15, 14), (15, 15), (16, 3),
     (16, 21), (17, 24), (17, 26), (18, 16), (20, 2), (22, 1)],
    [(2, 2), (5, 31), (15, 13), (15, 14), (16, 21), (18, 16)],
]

for operator_idx in range(len(OPERATORS)):
    effective_heads = [Component('z', layer=l, head=h) for (l, h) in effective_heads_per_operator[operator_idx]]
    # The prompts come from the model's saved prompt set. The previous
    # generate_all_prompts_for_operator(...) result was overwritten immediately and has been removed.
    prompts = separate_prompts_and_answers(large_prompts_and_answers[operator_idx])[0]
    head_html, head_patterns = visualize_arithmetic_attention_patterns(model, effective_heads, prompts, use_bos_token=True, return_raw_patterns=True)
    torch.save((head_html, effective_heads, head_patterns), f'./data/{model_name}/mean_attn_head_patterns_{OPERATOR_NAMES[operator_idx]}.pt')
    display(head_html)

In [None]:
# Visualize attention heads for paper
# Collect the cached attention patterns of every circuit head across operators. Then plot the
# selected heads, either averaged over the operators whose circuit contains the head
# (op_to_vis=None) or for a single operator index.

attention_patterns, operator_counts = {}, {}
for operator_idx in range(len(OPERATORS)):
    _, effective_heads, head_patterns = torch.load(f'./data/{model_name}/mean_attn_head_patterns_{OPERATOR_NAMES[operator_idx]}.pt')
    for head, pattern in zip(effective_heads, head_patterns):
        key = (head.layer, head.head_idx)
        if key not in attention_patterns:
            # One slot per operator. Slots for operators whose circuit lacks this head stay zero.
            attention_patterns[key] = torch.zeros(len(OPERATORS), *pattern.shape)
            operator_counts[key] = 0
        attention_patterns[key][operator_idx] = pattern
        operator_counts[key] += 1

important_heads = [(16, 21), (15, 13), (2, 2)]
other_heads = list(set(attention_patterns.keys()) - set(important_heads))


tickvals = ["BoS", "Op1", "Operator", "Op2", "="]
num_tokens = len(tickvals)
op_to_vis = None  # None -> mean over operators containing the head; otherwise an operator index
heads_to_vis = important_heads
for head in heads_to_vis:
    pattern = attention_patterns[head]
    if op_to_vis is None:
        pattern = pattern.sum(dim=0) / operator_counts[head]
    else:
        pattern = pattern[op_to_vis]

    fig = go.Figure(data=go.Heatmap(
        z=pattern,
        x=tickvals,
        y=tickvals,
        colorscale='Purples',
        zmin=0,
        zmax=1,
        showscale=False,
        colorbar=dict(title='Value')
    ))

    # Add visible edges between pixels
    for i in range(num_tokens + 1):
        fig.add_shape(type="line",
                    x0=-0.5, y0=i-0.5, x1=num_tokens-0.5, y1=i-0.5,
                    line=dict(color="lightgrey", width=1))
        fig.add_shape(type="line",
                    x0=i-0.5, y0=-0.5, x1=i-0.5, y1=num_tokens-0.5,
                    line=dict(color="lightgrey", width=1))

    fig.update_layout(width=300, height=300, coloraxis_showscale=False, margin=dict(l=0, r=0, t=0, b=0))
    fig.update_yaxes(autorange="reversed", tickfont=dict(size=17), title=dict(text="Destination Token",  font=dict(size=18), standoff=10))
    fig.update_xaxes(tickfont=dict(size=17), title=dict(text="Source Token", font=dict(size=18), standoff=5))
    fig.show()

    # Compare against None explicitly: operator index 0 (addition) is falsy and must not be labelled "mean"
    pio.write_image(fig, f'./figs/attention_pattern_L{head[0]}H{head[1]}_{op_to_vis if op_to_vis is not None else "mean"}.pdf')

##### Other appendices

In [None]:
# How many neurons get successfully classified by the algorithm?
# For each operator, report how many of the neurons assigned to any heuristic class have at least
# one (layer, neuron, score) entry whose score reaches MATCH_THRESHOLD.

MATCH_THRESHOLD = 0.6

for operator_idx in range(len(OPERATORS)):
    gt_heuristic_classes = load_heuristic_classes(f"./data/{model_name}", operator_idx, "HYBRID")
    pre_filter_heuristic_neurons = {(v[0], v[1]) for v in chain.from_iterable(gt_heuristic_classes.values())}

    # Filter by threshold, then drop classes that were left empty
    gt_heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= MATCH_THRESHOLD] for name, layer_neuron_scores in gt_heuristic_classes.items()}
    gt_heuristic_classes = {name: lns for name, lns in gt_heuristic_classes.items() if len(lns) > 0}
    post_filter_heuristic_neurons = {(v[0], v[1]) for v in chain.from_iterable(gt_heuristic_classes.values())}

    # Avoid ZeroDivisionError when no neuron was assigned to any heuristic class
    ratio = len(post_filter_heuristic_neurons) / len(pre_filter_heuristic_neurons) if pre_filter_heuristic_neurons else float('nan')
    print(operator_idx, f"{len(post_filter_heuristic_neurons)} / {len(pre_filter_heuristic_neurons)} = {ratio}")

In [None]:
# Show neuron intersection across operators
# For each operator, take the top-`k_per_layer` neurons (by importance score) in every layer from the
# first heuristics layer on. Then plot the pairwise intersection-over-union (Jaccard index) of these
# (layer, neuron) sets between operators.

k_per_layer = 200
model_consts = LLAMA3_8B_CONSTS
layers = list(range(model_consts.first_heuristics_layer, model.cfg.n_layers))

mlp_neurons_per_operator = {}
for op_idx in range(len(OPERATORS)):
    # model / model_name are passed positionally, matching the other get_neuron_importance_scores calls in this notebook
    top_neurons = get_neuron_importance_scores(model, model_name, operator_idx=op_idx)
    mlp_neurons_per_operator[op_idx] = [
        (layer, neuron)
        for layer in layers
        for neuron in top_neurons[layer].topk(k=k_per_layer).indices.tolist()
    ]

# Intersection-over-union between every pair of operators. Each set is built once instead of per pair.
neuron_sets = [set(mlp_neurons_per_operator[i]) for i in range(len(OPERATORS))]
intersection_neurons = [[round(len(a & b) / len(a | b), 3) for b in neuron_sets] for a in neuron_sets]

# Plot heatmap with labels
vis_operators = ['+', '-', '×', '÷']
fig = px.imshow(intersection_neurons, 
                x=vis_operators, y=vis_operators, text_auto='.1%',
                width=370, height=370,
                color_continuous_midpoint=0.6, color_continuous_scale='Blues')
fig.update_traces(textfont_size=16)
fig.update_xaxes(title=dict(text='Operator', standoff=10), title_font=dict(size=17), tickfont=dict(size=17))
fig.update_yaxes(title=dict(text='Operator', standoff=10), title_font=dict(size=17), tickfont=dict(size=17))
fig.update_coloraxes(colorbar_len=0.7, colorbar_thickness=15)
fig.update_layout(title='Neuron Intersection Across Operators', title_font=dict(size=17), title_x=0.5, title_y=0.92, margin=dict(l=0, r=0, t=0, b=0))

fig.show()

pio.write_image(fig, './figs/neuron_intersection_across_operators.pdf')

In [None]:
# Print model accuracy
# For each model, print:
#   - its stored per-operator accuracy;
#   - per operator, the number of saved prompts with a positive answer divided by the number of
#     prompts generate_all_prompts_for_operator produces for [min_op, max_op];
#   - the overall accuracy, weighting each operator by its number of generated prompts.

models = ['gptj', 'llama3-70b', 'pythia-6.9b-step143000', 'llama3-8b']
min_op, max_op = 0, 300
# The generated prompt sets do not depend on the model, so count them once rather than
# regenerating every operator's full prompt set three times per model.
num_prompts_per_operator = [len(generate_all_prompts_for_operator(OPERATORS[op_idx], min_op, max_op)) for op_idx in range(len(OPERATORS))]
total_valid_prompts = sum(num_prompts_per_operator)

for m in models:
    acc = torch.load(f'./data/{m}/accuracy.pt')
    print(m, acc)

    with open(f'./data/{m}/large_prompts_and_answers_max_op={max_op}.pkl', 'rb') as f:
        large_prompts_and_answers = pickle.load(f)
    print([len([pa for pa in large_prompts_and_answers[op_idx] if int(pa[1]) > 0]) / num_prompts_per_operator[op_idx] for op_idx in range(len(OPERATORS))])

    total_correct_prompts = sum(acc[OPERATORS[op_idx]] * num_prompts_per_operator[op_idx] for op_idx in range(len(OPERATORS)))
    print("Overall", total_correct_prompts / total_valid_prompts)

##### Results for "other models" appendix

In [None]:
# Evaluate the faithfulness of other models
# For each operator:
#   - select the top attention heads by activation-patching effect (averaged over seeds and positions);
#   - keep full MLPs below the first heuristics layer and only top-ranked neurons in later MLPs;
#   - report the circuit's mean-ablation faithfulness (metric='nl') on `num_prompts` random prompts
#     per seed. Prompts whose answer is '0' are excluded.

model_name, model_path = 'gptj', 'gptj'

max_op = 300
num_prompts = 50  # Prompts sampled per seed; the mean cache is tiled to this batch size
mean_cache = torch.load(f'./data/{model_name}/mean_cache_for_evaluation_all_arithmetic_prompts_max_op={max_op}.pt')
if next(iter(mean_cache.values())).shape[0] != num_prompts:
    mean_cache = {c: a.repeat(num_prompts, 1, 1) for c, a in mean_cache.items()}

# Reuse an already-loaded model only if it is the requested one. Otherwise a model left over from an
# earlier cell (e.g. the main analysis model) would be evaluated silently instead of `model_name`.
if globals().get('loaded_model_name') != model_name or 'model' not in globals():
    if 'model' in globals():
        del model  # Free the previous model's GPU memory before loading the next one
        torch.cuda.empty_cache()
    model = load_model(model_name, model_path, "cuda", extra_hooks=True)
    loaded_model_name = model_name

with open(fr'./data/{model_name}/large_prompts_and_answers_max_op={max_op}.pkl', 'rb') as f:
    large_prompts_and_answers = pickle.load(f)
    large_prompts_and_answers = [[pa for pa in large_prompts_and_answers[op_idx] if pa[1] != '0'] for op_idx in range(4)]


def build_circuit(model, model_name, operator_idx, ie_maps=None, mlp_top_neurons=None, num_heads=50):
    """Build the evaluated circuit for one operator.

    Args:
        model: the loaded model (its cfg provides n_layers / d_mlp).
        model_name: used to look up the first heuristics layer (and neuron scores).
        operator_idx: index into OPERATORS.
        ie_maps: the operator's indirect-effect map used to rank attention heads. Required.
        mlp_top_neurons: optional dict layer -> neuron indices kept in each MLP from the first
            heuristics layer on. If None, the top 1% of d_mlp neurons per layer are taken from
            get_neuron_importance_scores at the last position.
        num_heads: number of top-effect attention heads to include.

    Returns:
        A Circuit with the selected heads, full early MLPs and partial later MLPs.
    """
    if ie_maps is None:
        raise ValueError("ie_maps (the per-operator effect map used to rank heads) must be provided")
    # topk_effective_components returns a mapping; its keys are the selected components
    heads = list(topk_effective_components(model, ie_maps, k=num_heads, heads_only=True).keys())
    partial_mlp_layers = list(range(get_model_consts(model_name).first_heuristics_layer, model.cfg.n_layers))
    if mlp_top_neurons is None:
        # NOTE(review): 1% of d_mlp per layer mirrors `later_layers_k` in the MLP0 top-k cell;
        # confirm this is the intended per-layer neuron budget for the other models.
        neurons_per_layer = int(model.cfg.d_mlp * 0.01)
        neuron_scores = get_neuron_importance_scores(model, model_name, operator_idx=operator_idx, pos=-1)
        mlp_top_neurons = {l: neuron_scores[l].topk(neurons_per_layer).indices.tolist() for l in partial_mlp_layers}
    full_mlps = [Component('mlp_post', layer=l) for l in range(model.cfg.n_layers) if l not in partial_mlp_layers]
    partial_mlps = [Component('mlp_post', layer=l, neurons=mlp_top_neurons[l]) for l in partial_mlp_layers]

    full_circuit = Circuit(model.cfg)
    for c in set(heads + full_mlps + partial_mlps):
        full_circuit.add_component(c)
    return full_circuit


# Average the activation-patching effect maps across seeds, per (operator, position). The divisor is
# the number of seeds actually present in the file for that key.
raw_ie_maps = torch.load(f'./data/{model_name}/ie_maps_activation_patching.pt')
summed_ie_maps, seed_counts = {}, {}
for (op_idx, pos, seed), ie_map in raw_ie_maps.items():
    key = (op_idx, pos)
    if key in summed_ie_maps:
        summed_ie_maps[key] = summed_ie_maps[key] + ie_map  # Out-of-place: loaded tensors stay untouched
    else:
        summed_ie_maps[key] = ie_map
    seed_counts[key] = seed_counts.get(key, 0) + 1
seed_avg_ie_maps = {key: total / seed_counts[key] for key, total in summed_ie_maps.items()}

seeds = [42, 412, 32879, 123, 436]
for operator_idx in range(len(OPERATORS)):
    operator_ie_map = torch.stack([seed_avg_ie_maps[(operator_idx, pos)] for pos in POSITIONS]).mean(dim=0) # Average across positions
    # The circuit does not depend on the sampled prompts, so build it once per operator
    full_circuit = build_circuit(model, model_name, operator_idx, ie_maps=operator_ie_map)

    avg_nl_acc = 0
    for seed in seeds:
        set_deterministic(seed)
        prompts_and_answers = random.sample(large_prompts_and_answers[operator_idx], k=num_prompts)
        nl_acc = circuit_faithfulness_with_mean_ablation(model, full_circuit, prompts_and_answers, mean_cache, metric='nl')
        avg_nl_acc += nl_acc
        print(f"Normalized Logit Acc (Seed {seed}): {nl_acc:.3f}")

    avg_nl_acc = avg_nl_acc / len(seeds)
    print(f"Avg Normalized Logit Acc ({OPERATOR_NAMES[operator_idx]}): {avg_nl_acc:.3f}")

In [None]:
# LOCALIZATION
# Plot the log1p-scaled activation-patching effect of every attention head and MLP, averaged over
# seeds and operators. Heads are shown either at the last position or summed over positions;
# MLPs are shown per position.

model_name = "gptj"
# model_name = "pythia-6.9b-step143000"
# model_name = "llama3-70b"
attn_maps_sum_over_positions = False # If False, only the last position IE is presented

ie_maps = torch.load(f'./data/{model_name}/ie_maps_activation_patching.pt')

# Average across seeds. The number of seeds is counted per (operator, position) from the file itself,
# so the average stays correct whichever seeds were run (e.g. only one seed could be run for
# llama3-70b due to GPU requirements).
summed_seed_ie_maps, seed_counts = {}, {}
for (op_idx, pos, seed), ie_map in ie_maps.items():
    key = (op_idx, pos)
    if key in summed_seed_ie_maps:
        summed_seed_ie_maps[key] = summed_seed_ie_maps[key] + ie_map  # Out-of-place: loaded tensors stay untouched
    else:
        summed_seed_ie_maps[key] = ie_map
    seed_counts[key] = seed_counts.get(key, 0) + 1
ie_maps = {k: v / seed_counts[k] for (k, v) in summed_seed_ie_maps.items()}
# Average across operators
ie_maps = {pos: torch.stack([ie_maps[(op_idx, pos)] for op_idx in range(len(OPERATORS))]).mean(dim=0) for pos in POSITIONS}
# Tensorify
ie_maps = torch.stack([ie_maps[pos] for pos in POSITIONS]) # pos, Layers, heads+mlp
ie_maps = torch.log1p(ie_maps)  # Stays a torch tensor on any device
attn_maps, mlp_maps = ie_maps[:, :, :-1], ie_maps[:, :, -1]
if attn_maps_sum_over_positions:
    attn_maps = attn_maps.sum(dim=0) # Sum effect from all positions
else:
    attn_maps = attn_maps[-1] # Take only last position
mlp_maps = mlp_maps.T


# Find the global min and max values
global_min = min(attn_maps.min(), mlp_maps.min()).item()
global_max = max(attn_maps.max(), mlp_maps.max()).item()

# Create subplots
attn_title = "Attention Heads (Summed over positions)" if attn_maps_sum_over_positions else "Attention Heads (Last Position)"
fig = make_subplots(rows=1, cols=2, subplot_titles=[attn_title, "MLPs"], shared_yaxes=True)

# Add heatmaps to subplots
fig.add_trace(
    go.Heatmap(
        z=attn_maps,
        coloraxis="coloraxis",
    ),
    row=1, col=1
)

fig.add_trace(
    go.Heatmap(z=mlp_maps, coloraxis="coloraxis"),
    row=1, col=2
)

fig.update_xaxes(title=dict(text="Attention Head", font=dict(size=16), standoff=8), tickfont=dict(size=15), domain=[0, 0.7], row=1, col=1)
fig.update_yaxes(title=dict(text="Layer", font=dict(size=16), standoff=15), tickfont=dict(size=15), autorange="reversed", row=1, col=1)

fig.update_xaxes(title=dict(text="Position", font=dict(size=16), standoff=8), tickfont=dict(size=15), domain=[0.8, 1.0], row=1, col=2)
fig.update_yaxes(autorange="reversed", row=1, col=2)

# Update layout
fig.update_layout(
    title={
        'text': "Effect Map Per Attn Head / MLP",
        'x': 0.55, 'y': 0.98,
        'font': dict(size=17)
    },
    margin=dict(l=0, r=0, t=50, b=0),
    height=250,
    width=500,
    coloraxis=dict(
        colorscale='Blues',
        cmin=global_min,
        cmax=global_max,
        colorbar=dict(
            title=dict(text="log<br>scale", font=dict(size=17)),
            tickfont=dict(size=15),
            thickness=20,
            len=1.5,
            yanchor="middle",
            y=0.65,
            xanchor="left",
            x=1.02
        )
    )
)

fig.layout.annotations[0].update(x=0.34, y=1.01, font=dict(size=16))
fig.layout.annotations[1].update(x=0.90, y=1.01, font=dict(size=16))


pio.write_image(fig, f'./figs/{model_name}_localization.png')
pio.write_image(fig, f'./figs/{model_name}_localization.pdf')
fig.show()

In [None]:
# EVALUATION (TOP-K NEURONS)
# Plot, per operator, the seed-averaged faithfulness of the circuit when each MLP keeps only its
# top-k neurons, against k (log x axis labelled as % of d_mlp).

# model_name = "gptj"
# model_name = "pythia-6.9b-step143000"
model_name = "llama3-70b"

operator_labels = ['+', '-', '×', '÷']
colors = COLORBLIND_COLORS
# MLP width per model, hardcoded to avoid loading e.g. Llama3-70b just for this number.
# An explicit mapping (KeyError for unknown models) replaces the old "anything else is 28672"
# fallback, which gave the wrong width for llama3-8b.
d_mlp_per_model = {
    'gptj': 16384,
    'pythia-6.9b-step143000': 16384,
    'llama3-70b': 28672,
    'llama3-8b': 14336,
}
d_mlp = d_mlp_per_model[model_name]


k_values = torch.tensor(sorted(list(range(0, 500, 10)) + list(range(500, d_mlp, 50)) + [d_mlp]))
seeds = [42, 412, 32879]
results_file_path = f'./data/{model_name}/topk_neuron_faithfulness_evaluation_results.pt'

faithfulness_per_k = torch.load(results_file_path)
faithfulness_per_k = {operator_idx: torch.stack([faithfulness_per_k[(operator_idx, seed)] for seed in seeds]).mean(dim=0) for operator_idx in range(len(OPERATORS))} # Average across seeds
faithfulness_per_k = torch.stack([faithfulness_per_k[operator_idx] for operator_idx in range(len(OPERATORS))]) # Stack across operators

fig = go.Figure()
for op_idx in range(len(OPERATORS)):
    fig.add_trace(go.Scatter(x=k_values, y=faithfulness_per_k[op_idx], mode='lines', name=operator_labels[op_idx], line=dict(color=colors[op_idx])))

x_axis_percentages = [0.01, 0.1, 1]
fig.update_xaxes(tickfont=dict(size=16))  # Axis title is set in the layout below
fig.update_yaxes(title=dict(text="Faithfulness", font=dict(size=17), standoff=10), range=(0, 1.0), tickvals=[0.2, 0.4, 0.6, 0.8], tickfont=dict(size=16))

fig.update_layout(
    xaxis=dict(
        type="log",
        tickvals=[percent * d_mlp for percent in x_axis_percentages],
        ticktext=[f"{int(percent * 100)}%" for percent in x_axis_percentages],
        title=dict(text="Neurons used Per Layer (%)", font=dict(size=17), standoff=5)
    ),
    legend=dict(itemwidth=30, orientation="h", itemsizing='constant', yanchor="bottom", 
                y=-0.4, xanchor="center", x=0.5, font=dict(size=17)),
    title=dict(text="Faithfulness of using only top-k neurons", x=0.56, y=0.98, font=dict(size=17)),
    margin=dict(l=0, r=0, t=20, b=0), 
    width=400, height=250, 
)
fig.show()

pio.write_image(fig, f'./figs/{model_name}_faithfulness_topk_neurons.pdf')


In [None]:
# Linear probing
# Heatmap of the answer-token linear-probe accuracy per (layer, position), averaged over operators.

# model_name = "pythia-6.9b-step143000"
model_name = "gptj"
# model_name = "llama3-70b"
raw_probe_accs = torch.load(f'./data/{model_name}/probe_accs.pt')

# Average across operators, separately for each probed position
operator_averaged = {}
for pos_to_probe in POSITIONS:
    per_operator = [raw_probe_accs[(op_i, pos_to_probe)] for op_i in range(len(OPERATORS))]
    operator_averaged[pos_to_probe] = torch.tensor(per_operator).mean(dim=0)
probe_accs = operator_averaged

# Tensorify, then transpose for nicer visualization (layers on the y axis)
probe_accs_tensor = torch.stack([probe_accs[p] for p in POSITIONS]).T

num_layers = probe_accs_tensor.shape[0]
fig = px.imshow(
    probe_accs_tensor,
    x=['Operand1', 'Operator', 'Operand2', '='],
    y=list(range(num_layers)),
    width=250,
    height=250,
    zmin=0,
    color_continuous_midpoint=0.0,
    color_continuous_scale="blues",
)
fig.update_yaxes(title=dict(text="Layer", standoff=15, font=dict(size=17)), tickvals=list(range(0, num_layers, num_layers // 4)), tickfont=dict(size=15))
fig.update_xaxes(title=dict(text="Position", standoff=10, font=dict(size=17)), tickfont=dict(size=15))
fig.update_layout(
    margin=dict(l=0, r=0, t=50, b=0),
    title=dict(x=0.47, y=0.95, text="Answer token<br>probe accuracy", font=dict(size=17)),
)
fig.update_coloraxes(colorbar=dict(len=1.1, thickness=20, yanchor="middle", y=0.5))
fig.show()

pio.write_image(fig, f'./figs/{model_name}_probing_acc.pdf')

In [None]:
# Prompt knockout 
# For each operator, plot the accuracy when ablating heuristic neurons (solid) vs. random neurons
# (dashed) against the number of ablated neurons per layer, from the cached ablation results.

# model_name = "pythia-6.9b-step143000"
model_name = "gptj"
# model_name = "llama3-70b"

HEURISTIC_MATCH_THRESHOLD = 0.6
operator_labels = ['+', '-', '×', '÷']
colors = COLORBLIND_COLORS

fig = go.Figure()
# Legend-only dummy traces explaining the two line styles
fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', name='Heuristic Ablation', line=dict(color='grey')))
fig.add_trace(go.Scatter(x=[None], y=[None], mode='lines', name='Random Ablation', line=dict(dash='dash', color='grey')))

for i in range(len(OPERATORS)):
    # Fixed NameError: the constant is HEURISTIC_MATCH_THRESHOLD (was misspelled "THREHSOLD")
    neuron_hard_limits, baseline_results, ablated_results, ablated_neuron_counts, control_results = torch.load(f'./data/{model_name}/{OPERATOR_NAMES[i]}_prompt_ablation_results_thres={HEURISTIC_MATCH_THRESHOLD}_HYBRID_maps.pt')
    neuron_hard_limits = list(neuron_hard_limits)
    assert torch.all(baseline_results == 1)  # Sanity check on the cached baseline results
    fig.add_trace(go.Scatter(x=neuron_hard_limits, y=ablated_results, mode='lines', name=f'{operator_labels[i]}',
                            line=dict(color=colors[i])))
    fig.add_trace(go.Scatter(x=neuron_hard_limits, y=control_results, mode='lines', name='', showlegend=False,
                            line=dict(color=colors[i], dash='dash')))

fig.update_yaxes(title=dict(text="Accuracy", font=dict(size=17)), range=(0, 1.02), tickvals=[0.25, 0.5, 0.75, 1], tickfont=dict(size=15))
fig.update_xaxes(title=dict(standoff=0, text='Ablated Neurons (Per Layer)', font=dict(size=17)), tickvals=list(range(0, 81, 20)), tickfont=dict(size=15))
fig.update_layout(
    legend=dict(
        font=dict(size=15),
        orientation="h",
        yanchor="bottom",
        y=-0.95,
        xanchor="center",
        x=0.5,
    ),
    margin=dict(l=0, r=0, t=20, b=80), 
    width=400,
    height=350,
    title=dict(
        text="Heuristic knockout accuracies",
        x=0.5,
        y=0.98,
        font=dict(size=16),
        xanchor='center',
        yanchor='top'
    ),
)
fig.show()
pio.write_image(fig, f'./figs/{model_name}_prompt_knockout_per_layer.pdf')