# Modular Arithmetic: Language Models Solve Math Digit by Digit
## Download Necessary Data
This section shows how to set up the pyvene library with support for OLMo 2 as an intervenable model, and how to download the corresponding data and code for the transformer digit arithmetic project.





##### Pyvene + Olmo2 Setup Guide
If you have previously installed the standard pyvene package, uninstall it first to avoid conflicts.

##### Install Our Custom pyvene with Olmo2 Support https://anonymous.4open.science/r/pyvene_Olmo2
Download the custom version of pyvene with Olmo2 support from the anonymous repository.
Unzip the downloaded archive.
Install the package in editable mode so any local changes are reflected without reinstalling.


In [None]:
# Only needed if a standard (non-Olmo2) pyvene is currently installed:
# !pip uninstall -y pyvene
%cd /content/
!wget https://anonymous.4open.science/api/repo/pyvene_Olmo2/zip -O lib.zip
!unzip lib.zip -d pyvene_Olmo2
%cd pyvene_Olmo2
!pip install -e .

import sys
sys.path.append('/content/pyvene_Olmo2')  # Adjust as needed

##### Download Digit Arithmetic Code and Data https://anonymous.4open.science/r/tda-C722/
Download the dataset and code from the anonymous repository.
Unzip the contents into a local directory.

In [None]:
%cd /content/
!wget https://anonymous.4open.science/api/repo/transformer-digit-arithmetic-FCBC/zip -O code.zip
!unzip code.zip -d transformer-digit-arithmetic
%cd transformer-digit-arithmetic

##### Import Required Libraries

In [None]:
import pyvene as pv
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import sys
import torch.nn.functional as F
from pyvene import embed_to_distrib, top_vals, format_token
from pyvene import (
    ZeroIntervention,
    IntervenableModel,
    VanillaIntervention, Intervention,
    RepresentationConfig,
    IntervenableConfig,
    ConstantSourceIntervention,
    LocalistRepresentationIntervention
)
from huggingface_hub import login
import re
import numpy
from tqdm import tqdm
from matplotlib import pyplot as plt
import transformers
import os

## Model Loading Guide (LLaMA & OLMo 2)

This section describes how to load the supported language models (LLaMA 3 8B or OLMo 2 7B) for use in experiments. It covers the Hugging Face login, model selection, and hardware compatibility notes.

---

### Hardware Requirements

| Model       | Approx. VRAM Required (fp16) | Compatible GPU            |
|-------------|------------------------------|---------------------------|
| LLaMA 3 8B  | ~16 GB                       | A100 / L4                 |
| OLMo 2 7B   | ~14 GB                       | A100 / L4                 |
| LLaMA 3 70B | ~140 GB                      | ❌ Not suitable for Colab |

---

### Login to Hugging Face

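To download and use the **Meta LLaMA** models, you must authenticate with your **Hugging Face access token**. The LLaMA repositories are gated, so your account must also have been granted access on the model page.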

Replace `<replace_with_your_token>` with your actual Hugging Face token.


In [None]:
# Login to Huggingface to get access to model parameters
# Paste your token here
HugginFace_Token = "<replace_with_your_token>"
login(HugginFace_Token)


### Choose a Model

Select one of the available models:

In [None]:
# Available models
model_name = "Llama-3-8B"  # Options: "Llama-3-8B", "Llama-3-70B", "Olmo-2-7B"

models = {
    "Llama-3-8B": "meta-llama/Meta-Llama-3-8B",
    "Llama-3-70B": "meta-llama/Meta-Llama-3-70B",
    "Olmo-2-7B": "allenai/OLMo-2-1124-7B"
}

### Load Model and Tokenizer
Load the model and tokenizer using `transformers`. The code automatically uses the GPU (CUDA) if one is available.

In [None]:

# Select the device once: GPU (CUDA) if available, otherwise CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Set parameters
params = {
    'model_name': models[model_name],
    'device': device
}

# Load model and tokenizer in float16 (half the memory of float32; intended for GPU use)
tokenizer = AutoTokenizer.from_pretrained(params['model_name'])
model = AutoModelForCausalLM.from_pretrained(
    params['model_name'],
    torch_dtype=torch.float16).to(params['device'])

# Confirm device
print(f"Using device: {params['device']}")

## Digit Circuit Intervention Setup

This guide explains how to configure and run interventions on **digit circuits** in transformer models using precomputed **Fisher scores**. You can choose specific parameters like digit position, arithmetic task, operand type, and apply thresholding to identify meaningful MLP dimensions for intervention.

---

## Purpose

- Identify MLP neurons involved in **digit representation** for the hundreds (`hundredth`), tens (`tenth`), and units (`unit`) digits.
- Perform interventions based on:
  - Arithmetic **task** (`addition` / `subtraction`)
  - **Operand** type (`op1` or `op2`)
  - Precomputed **Fisher score thresholds**
- Extract and map MLP neuron indices per layer that cross the selected Fisher score threshold.

---

## Configuration

### Select the digit position (label)

In [None]:
# labels
label = "hundredth"  # Options: "hundredth", "tenth", "unit"

### Choose the arithmetic task


In [None]:
# task (operator)
task = "addition"  # Options: "addition", "subtraction"


### Select the operand type
For queries of the form `op1 + op2 = ...`, the data comes in two variations: one in which `op1` is fixed and one in which `op2` is fixed.

In [None]:
# which operand
operand = "op1"  # Options: "op1", "op2"


###  Load Intervention Data


In [None]:
# Load the specific intervention dataset based on task and operand
input_file = f"Intervention_Data/intervention_data_{task}_{operand}.json"
with open(input_file, "r") as f:
    data = json.load(f)


###  Load Layer Set and Threshold
These are precomputed and based on the best-performing layers and Fisher thresholds per model/task/operand configuration.

In [None]:
# Load selected layers for intervention
with open("intervene_layers.json", "r") as f:
    layer_sets = json.load(f)
layer_set = layer_sets[model_name][task][operand]

# Load corresponding Fisher score threshold
with open("fisher_scores_threshold_map.json", "r") as f:
    threshold_map = json.load(f)
threshold = threshold_map[model_name][task][operand][label]

print("Selected layers:", layer_set)
print("Fisher threshold:", threshold)

###  Extract Digit Circuit Neurons by Threshold

For each selected layer, extract neuron indices where the Fisher score exceeds the threshold.

In [None]:
# Load Fisher scores for the specified digit label
fisher_file = f"Fisher_Scores/{model_name}/{task}/fisher_scores_{label}.json"
with open(fisher_file, "r") as file:
    fisher_scores_data = json.load(file)

# Output map: layer index → list of neuron indices
layer_subspaces_map = {}

# Filter neurons per layer
for layer in layer_set:
    key = f"layer_{layer}"
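    if key in fisher_scores_data:
        values = fisher_scores_data[key]
        indices_above_threshold = [i for i, val in enumerate(values) if val > threshold]
        layer_subspaces_map[layer] = indices_above_threshold
    else:
        # The intervention loop looks up every layer in layer_set, so flag gaps early
        print(f"Warning: no Fisher scores found for {key}")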

# Resulting neuron indices to intervene on
print(layer_subspaces_map)


####  Notes

- You can customize the threshold or layers to perform ablations or sensitivity analyses.
- The mappings (layer_sets, threshold_map, fisher_scores_data) are organized to allow plug-and-play selection based on model name, digit position, and task type.
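For a sensitivity analysis, one simple option is to sweep several thresholds and count how many neurons survive in each layer. The sketch below is a hypothetical helper (`neurons_per_threshold` is not part of the project code) run on made-up toy scores in the same `{"layer_<i>": [scores...]}` shape that `fisher_scores_data` is read with above; in the notebook you would pass `fisher_scores_data` and `layer_set` instead. It uses the same strict `>` comparison as the filtering cell.

```python
def neurons_per_threshold(scores_by_layer, layers, thresholds):
    """Count neurons whose score exceeds each threshold, per layer.

    scores_by_layer: dict mapping "layer_<i>" -> list of per-neuron scores
    layers: layer indices to inspect (layers missing from the dict are skipped)
    thresholds: candidate cut-offs to compare
    """
    counts = {}
    for t in thresholds:
        counts[t] = {
            layer: sum(1 for s in scores_by_layer[f"layer_{layer}"] if s > t)
            for layer in layers
            if f"layer_{layer}" in scores_by_layer
        }
    return counts


# Hypothetical toy scores for two layers with four neurons each
toy_scores = {"layer_14": [0.1, 0.9, 0.4, 0.7], "layer_15": [0.8, 0.2, 0.6, 0.05]}
sweep = neurons_per_threshold(toy_scores, [14, 15], [0.3, 0.5, 0.75])
for t, per_layer in sweep.items():
    print(t, per_layer)
```

Plotting these counts against the resulting flip rates is one way to see how sensitive the intervention is to the threshold choice.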

# 🔬 Digit Circuit Interventions – Implementation Guide

This code performs **causal interventions** on digit-position-specific MLP circuits identified in transformer language models (e.g., LLaMA 3 8B, OLMo 2 7B) using **Fisher-score-selected subspaces**.

---

##  Goal

To verify that specific MLP neurons (identified via high Fisher scores) are **causally responsible** for generating arithmetic result digits at different positions (units, tens, hundreds). This is done through **interchange interventions**: replacing the activations of a “base” input with those from a “source” input only at selected subspaces.

---

##  Setup and Context

Before this code runs:
- Fisher scores have been computed per digit position.
- A threshold has been chosen to select the most important neurons per layer (see Section 3.1 of the paper).
- Layer sets are determined where digit information is known to flow (see Section C of the paper).
- The JSON files `intervene_layers.json` and `fisher_scores_threshold_map.json` define model/task/digit-specific layer sets and thresholds.
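The notebook consumes precomputed Fisher scores and does not show how they were obtained. Purely as an illustrative assumption, the sketch below implements the classic feature-selection Fisher score, $F_j = \frac{\sum_k n_k(\mu_{k,j}-\mu_j)^2}{\sum_k n_k\sigma_{k,j}^2}$, where the classes $k$ are the values of one result digit and $j$ indexes neurons; the paper's exact definition may differ. `fisher_scores` and the toy activations are hypothetical.

```python
from collections import defaultdict


def fisher_scores(activations, labels):
    """Per-neuron Fisher score: between-class scatter divided by within-class scatter.

    activations: one row per example, one column per neuron (equal-length lists)
    labels: class of each example, e.g. the hundreds digit of the result (0-9)
    """
    n_examples, n_neurons = len(activations), len(activations[0])
    groups = defaultdict(list)
    for row, y in zip(activations, labels):
        groups[y].append(row)

    scores = []
    for j in range(n_neurons):
        overall_mean = sum(row[j] for row in activations) / n_examples
        between = within = 0.0
        for rows in groups.values():
            n_k = len(rows)
            class_mean = sum(r[j] for r in rows) / n_k
            class_var = sum((r[j] - class_mean) ** 2 for r in rows) / n_k
            between += n_k * (class_mean - overall_mean) ** 2
            within += n_k * class_var
        if within > 0:
            scores.append(between / within)
        else:
            # Zero within-class spread: perfectly separating if classes differ, else uninformative
            scores.append(float("inf") if between > 0 else 0.0)
    return scores


# Hypothetical activations of 2 neurons for 4 prompts whose target digit is 0 or 1
acts = [[1.0, 0.0], [1.2, 5.0], [3.0, 0.1], [3.2, 4.9]]
digits = [0, 0, 1, 1]
print(fisher_scores(acts, digits))  # neuron 0 separates the digits, neuron 1 does not
```

Under this definition a high score means the neuron's activation varies mostly with the digit value rather than within a digit class, which is the property the threshold is meant to select for.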

---

##  Step-by-Step Breakdown

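###  Iterate over Intervention Examples
Each entry in `data` provides a base prompt (`one_shot_base`) and a source prompt (`one_shot_source`); the steps below are run for every entry.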
###  Clean Forward Pass (Baseline)
Performs a clean (unmodified) forward pass on the base sentence to record its top-100 predictions.

The logits of the final token are extracted.
A softmax is applied to obtain a probability distribution.
Top 100 tokens and their probabilities are stored for comparison.
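In plain Python, this post-processing (a softmax over the final-position logits, then top-k) looks roughly like the sketch below. The toy vocabulary and logits are hypothetical; the notebook does the same with `F.softmax` and `torch.topk` over the model's full vocabulary.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, vocab, k):
    """Return the k most probable (token, prob) pairs, highest first."""
    ranked = sorted(zip(vocab, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Hypothetical 4-token vocabulary and final-position logits
vocab = ["579", "573", "779", "the"]
logits = [4.0, 2.0, 1.0, -1.0]
for token, p in top_k(softmax(logits), vocab, k=2):
    print(token, round(p, 3))
```

Subtracting the maximum logit before exponentiating does not change the result but avoids overflow for large logits.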
###  Interchange Intervention Setup
Uses pyvene to define which model component and layers to intervene on (the `mlp_output` of each selected layer).

It then builds the list of subspaces from `layer_subspaces_map`, a dictionary mapping each selected layer to its high-Fisher-score neuron indices (i.e., the subspace of that digit-position circuit).
###  Run the Intervention
This interchange intervention copies the selected neuron activations of the source sentence into the base sentence at the final token's position (see Section 3.2 of the paper).

Only neurons that passed the Fisher-score threshold are modified.
Only specified layers are intervened on (those known to be causally relevant per digit position).
Only the selected digit-position circuit is changed (e.g., the "hundredth" digit MLP subspace).
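Conceptually, the operation for one intervened layer is the following plain-Python sketch (an illustration of the idea, not pyvene's implementation): at the last token position, the activations listed in that layer's subspace are taken from the source run and everything else stays as in the base run. It assumes base and source have the same number of tokens, so one position index applies to both; the notebook's single `last_token_index` makes the same assumption. `interchange_subspace` and the toy values are hypothetical.

```python
def interchange_subspace(base_acts, source_acts, position, neuron_indices):
    """Copy selected neurons of the source activation into the base activation.

    base_acts / source_acts: per-token activation vectors for one layer,
    i.e. lists of shape [num_tokens][hidden_dim]. Only the row at `position`
    and only the columns in `neuron_indices` are overwritten; the inputs are
    left unmodified and a patched copy is returned.
    """
    patched = [list(row) for row in base_acts]
    for idx in neuron_indices:
        patched[position][idx] = source_acts[position][idx]
    return patched


# Hypothetical 2-token, 4-neuron MLP outputs for one layer
toy_base = [[0.0, 0.0, 0.0, 0.0], [1.0, 2.0, 3.0, 4.0]]
toy_source = [[9.0, 9.0, 9.0, 9.0], [-1.0, -2.0, -3.0, -4.0]]
last = len(toy_base) - 1
patched = interchange_subspace(toy_base, toy_source, last, [1, 3])
print(patched)
```

Earlier positions and unselected neurons keep their base values, which is what makes the intervention specific to one digit circuit.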

###  Post-Intervention Analysis
After the intervention, the model output is compared to the baseline. The shift in token probability mass (e.g., from 579 → 573) indicates whether the intended digit was selectively changed.

This is aligned with the digit-modularity claim in the paper (Section 3.2, Table 2 & 3).

###  Save Results
Each run’s clean and intervened top-k distributions are saved as a CSV, enabling quantitative evaluation (e.g., computing probability shift to expected variants, see §3.2, §4.1).

##  Expected Outcome

If the digit circuits are truly modular and causal, intervening on one circuit (e.g., hundredth) should:

- Shift probability toward the output in which only the hundreds digit is replaced (e.g., 957 → 357);
- Leave the tens and units digits mostly unchanged.

This would provide causal evidence that digit-wise arithmetic is realized via distinct, compositional subcircuits in LLMs.

In [None]:
################################################
# Perform Digit-Circuit Interventions on each data point #
################################################

# Number of tokens kept per distribution (stored under the "top_100" key)
top_k = 100

output_dir = "Interventions"
os.makedirs(output_dir, exist_ok=True)  # Create the folder if it doesn't exist

# The intervention config depends only on the layer set, so build it once:
# one VanillaIntervention on the MLP output of every selected layer
config = pv.IntervenableConfig([{
    "layer": l,
    "component": "mlp_output",
    "intervention_type": VanillaIntervention
    } for l in layer_set]  # Pass a list instead of a single layer
)
pv_model = pv.IntervenableModel(config, model=model)

# One list of high-Fisher-score neuron indices per intervened layer,
# in the same order as layer_set (and therefore as the config above)
subspaces = [layer_subspaces_map[layer] for layer in layer_set]


def top_k_distribution(logits, k=top_k):
    """Softmax over the final-position logits; return the top-k tokens with their probabilities."""
    probabilities = F.softmax(logits[0, -1].float(), dim=-1)
    top_k_probs, top_k_indices = torch.topk(probabilities, k)
    return [
        {"token": tokenizer.decode(index.item()), "prob": prob.item()}
        for index, prob in zip(top_k_indices, top_k_probs)
    ]


# Iterate through each query in the dataset
for j, entry in tqdm(enumerate(data), total=len(data)):
    data_entry = []

    sentence = entry["one_shot_base"]
    sentence_intervention = entry["one_shot_source"]

    base = tokenizer(sentence, return_tensors="pt").to(device)
    source = tokenizer(sentence_intervention, return_tensors="pt").to(device)

    # Index of the last token: the intervention is applied at this position
    last_token_index = base["input_ids"].shape[1] - 1
    # The same position is used in the source, so both prompts should have the same token length
    if source["input_ids"].shape[1] != base["input_ids"].shape[1]:
        print(f"Warning: example {j} has base/source prompts of different token lengths")

    with torch.no_grad():  # inference only; avoids storing activations for backprop
        ############################
        # Clean Run for comparison #
        ############################
        res = model(**base)
        data_entry.append({"run": "clean", "top_100": top_k_distribution(res.logits)})

        ###############################################
        # Interchange Interventions across layer sets #
        ###############################################
        _, intervened_outputs = pv_model(
            base=base,                # the base input
            sources=source,           # the source input
            unit_locations={"sources->base": last_token_index},  # intervene at the last token
            subspaces=subspaces,      # only the selected neurons of each layer
        )
        data_entry.append({"run": "intervened", "top_100": top_k_distribution(intervened_outputs.logits)})

    df = pd.DataFrame(data_entry)
    df.to_csv(f"{output_dir}/intervention_{model_name}_{task}_{operand}_{label}_threshold_{threshold}_{j}.csv")

#  Evaluating Digit Circuit Interventions — Average Variant Probabilities

This script **aggregates and evaluates** the effect of digit-circuit interventions on language model outputs by computing the **average predicted probability** of specific **result variants** across all clean and intervened runs.

It operationalizes the causal claims made in the paper, especially regarding **digit-specific influence** of MLP neuron circuits.

---

##  Goal

After each digit-circuit intervention (e.g., modifying only the MLP subspace for the "hundreds" digit), we expect:
- **One result variant** (e.g., `sbb`) to increase in probability,
- While others (e.g., `bsb`, `bbs`, `ssb`, etc.) remain largely unchanged.

This script measures that effect **across many examples** by:
- Loading `.csv` results for each example,
- Extracting top-100 tokens from clean and intervened runs,
- Matching them to known variant results (e.g., `sss = fully source result`),
- Accumulating and averaging the predicted probabilities.
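Each label is a string over {b, s} for the (hundreds, tens, units) digits: `b` keeps the base result's digit and `s` takes the source result's digit. Assuming 3-digit results and that each variant is formed by splicing digits this way (consistent with the `bbb` and `sss` descriptions here; the dataset's own `result_variants` field remains authoritative), the mapping could be generated as in this sketch. `result_variants` is a hypothetical helper name, and 579/312 are made-up results.

```python
from itertools import product


def result_variants(base_result, source_result):
    """Map each label in {b,s}^3 to the digit-wise mix of two 3-digit results.

    Label positions are (hundreds, tens, units): 'b' takes the digit from the
    base result, 's' from the source result.
    """
    if not (100 <= base_result <= 999 and 100 <= source_result <= 999):
        raise ValueError("this sketch only handles 3-digit results")
    base_digits, source_digits = str(base_result), str(source_result)
    variants = {}
    for choice in product("bs", repeat=3):
        digits = [b if c == "b" else s for c, b, s in zip(choice, base_digits, source_digits)]
        variants["".join(choice)] = "".join(digits)
    return variants


variants = result_variants(579, 312)
print(variants["sbb"], variants["bsb"], variants["bbs"])
```

`itertools.product` yields the labels in the same order as `variant_labels` in the evaluation cell (`bbb`, `bbs`, ..., `sss`).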

---

##  Input Requirements

- Folder: `Interventions/`  
  Contains per-example `.csv` files of top-100 token distributions.
  Each file includes rows like:
  
  |run| top_100|
  |---------------|------------------------|
  |clean| [{"token": "579", "prob": 0.94}, ...]|
  |intervened| [{"token": "573", "prob": 0.72}, ...]|


- In-memory variable: `data`  
  List of intervention metadata loaded earlier (e.g., from `intervention_data_addition_op1.json`).  
  Each entry must include:
  ```json
  "result_variants": {
      "bbb": "579",
      "sbb": "779",
      ...
  }
  ```


##  Core Logic Explained
1. **Initialize accumulators**
   - `accumulated_probs = { "clean": ..., "intervened": ... }`
   - `counts = { "clean": ..., "intervened": ... }`
   - For each variant (e.g., `"bsb"`, `"sss"`), track:
     - the sum of its predicted probabilities,
     - the number of examples in which it appears in the top-100.

2. **Loop through CSV files**
   - For each `.csv` corresponding to a single intervention example:
     - extract the index `j` from the filename to align it with `data[j]`,
     - load the top-100 token predictions for both the clean and the intervened run.

3. **Match tokens to result variants**
   - `variant_token_map = { "sbb": "779", ... }`
   - For each run (clean/intervened) and each variant (`bbb`, `sbb`, etc.):
     - check whether a predicted token matches the variant's result token,
     - if it does, add its probability to the accumulator and increment the count.

4. **Compute averages**: divide each accumulated sum by its count (variants that never appear get 0).
5. **Save results to JSON**: write the averages to `average_probabilities.json`.

##  Expected Outcome

If the digit circuit interventions work:

- the target variant's probability should increase (e.g., `sbb` for a hundreds-digit intervention),
- the default variant (`bbb`) should decrease,
- other variants should remain relatively stable.


In [None]:
import os
import json
import pandas as pd
import ast
import re

variant_labels = ["bbb", "bbs", "bsb", "bss", "sbb", "sbs", "ssb", "sss"]
folder_path = "Interventions/"

# Initialize accumulators: sums and counts for each run and variant label
accumulated_probs = {
    "clean": {label: 0.0 for label in variant_labels},
    "intervened": {label: 0.0 for label in variant_labels}
}
counts = {
    "clean": {label: 0 for label in variant_labels},
    "intervened": {label: 0 for label in variant_labels}
}

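# Only use CSV files written for the current configuration (the folder may hold other runs)
file_prefix = f"intervention_{model_name}_{task}_{operand}_{label}_threshold_{threshold}_"
csv_files = [
    f for f in os.listdir(folder_path)
    if f.startswith(file_prefix) and re.search(r"_(\d+)\.csv$", f)
]
# Sort by the example index at the end of the file name
csv_files.sort(key=lambda x: int(re.search(r"_(\d+)\.csv$", x).group(1)))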

for file_name in csv_files:
    index_match = re.search(r"_(\d+)\.csv$", file_name)
    i = int(index_match.group(1)) if index_match else None
    if i is None:
        continue

    print(f"Processing example {i} from file {file_name}")
    file_path = os.path.join(folder_path, file_name)
    df = pd.read_csv(file_path)

    # Map variant labels to target tokens from your loaded `data`
    variant_token_map = {k: str(v) for k, v in data[i]["result_variants"].items()}

    for _, row in df.iterrows():
        run_type = row["run"]  # "clean" or "intervened"
        try:
            token_probs = ast.literal_eval(row["top_100"])
        except Exception as e:
            print(f"Error parsing token probs in {file_name}, run={run_type}: {e}")
            continue

        for variant_label, target_token in variant_token_map.items():
            for entry in token_probs:
                if entry["token"] == target_token:
                    prob = entry["prob"]
                    accumulated_probs[run_type][variant_label] += prob
                    counts[run_type][variant_label] += 1

# Average each variant's probability over the examples in which it appears in the top-100 (0 if it never appears)
averages = {
    run_type: {
        label: (accumulated_probs[run_type][label] / counts[run_type][label]) if counts[run_type][label] > 0 else 0
        for label in variant_labels
    }
    for run_type in ["clean", "intervened"]
}

# Save averages to JSON
output_file = "average_probabilities.json"
with open(output_file, "w") as f:
    json.dump(averages, f, indent=4)

print("Averages computed and saved to", output_file)

#  Visualizing Intervention Effects on Digit Result Variants

This visualization compares:
- The **average predicted probability** of each **result variant** (e.g., `"bbb"`, `"sbb"`, etc.)
- Across two conditions:
  - `Clean` (no intervention)
  - `Intervened` (digit-circuit intervention applied)

Additionally, the plot includes a **difference bar** (intervened minus clean) showing how each variant's probability **shifted** due to the intervention.

This replicates the core idea in **Table 2** of the paper and gives a clear visual insight into the **causal specificity** of the digit circuits.

---

##  What’s Being Visualized?

- `bbb`: Model predicts base result.
- `sbb`, `bsb`, `bbs`: Only one digit comes from the source, others from base.
- `sss`: Fully source result.

Expected outcomes:
- Large **drop** in `bbb` (baseline result),
- **Increase** in a targeted variant (e.g., `sbb` if hundreds digit was intervened),
- Minimal change in other variants.

---


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Assuming `averages` dict is computed as before:
# averages = {
#   "clean": {...},
#   "intervened": {...}
# }

variant_labels = list(averages["clean"].keys())

clean_vals = [averages["clean"][label] for label in variant_labels]
intervened_vals = [averages["intervened"][label] for label in variant_labels]
diff_vals = [i - c for i, c in zip(intervened_vals, clean_vals)]

x = np.arange(len(variant_labels))
width = 0.25

fig, ax = plt.subplots(figsize=(12, 6))

bars1 = ax.bar(x - width, clean_vals, width, label='Clean', color='tab:blue')
bars2 = ax.bar(x, intervened_vals, width, label='Intervened', color='tab:orange')
bars3 = ax.bar(x + width, diff_vals, width, label='Difference (Intervened - Clean)', color='tab:green')

ax.set_ylabel('Average Probability')
ax.set_title('Average Probabilities per Variant Label')
ax.set_xticks(x)
ax.set_xticklabels(variant_labels)
ax.legend()
ax.grid(axis='y', linestyle='--', alpha=0.7)

plt.tight_layout()
plt.show()

#  Flip Rate Evaluation for Digit-Circuit Interventions

This script computes the **flip rate** — the fraction of test cases where an intervention successfully causes the model to change its top predicted result **from the default `bbb`** (base-base-base) **to the expected variant** (e.g., `bbs`, `bsb`, or `sbb` depending on the digit position).

This replicates the metric reported in **Table 3** of the paper:  
_"Flip rate from bbb to the intended digit-specific result variant..."_

---

##  Purpose

- To measure whether digit-position-specific MLP interventions cause the **target digit** to flip to the desired source value, while **leaving the other digits intact**.
- This provides **causal evidence** that digit circuits are:
  - Modular,
  - Selective,
  - Functionally specific to their digit.

---

##  Setup

Before running this:
- `csv_files`: list of per-example `.csv` result files from prior intervention experiments.
- `data`: loaded from the JSON intervention dataset (contains `result_variants` per example).
- `label`: must be one of `"unit"`, `"tenth"`, or `"hundredth"`.

The label determines the **expected flip target**:
```python
expected_variant_map = {
    "unit": "bbs",         # Only unit digit changed
    "tenth": "bsb",        # Only tens digit changed
    "hundredth": "sbb"     # Only hundreds digit changed
}
```

The "flip" is counted only when the clean prediction is the unaltered base result (bbb) and the intervention successfully flips only the target digit to match the source (e.g., sbb for hundreds circuit).

In [None]:
import os
import re
import ast
import pandas as pd

# Flip rate calculation
expected_variant_map = {
    "unit": "bbs",
    "tenth": "bsb",
    "hundredth": "sbb"
}

# Use your current label to find the expected variant
expected_variant = expected_variant_map[label]  # 'label' must be defined in your loop or global scope

flip_count = 0
total_count = 0

for file_name in csv_files:
    index_match = re.search(r"_(\d+)\.csv$", file_name)
    i = int(index_match.group(1)) if index_match else None
    if i is None:
        continue

    file_path = os.path.join(folder_path, file_name)
    df = pd.read_csv(file_path)

    # Get token mapping for this datapoint
    variant_token_map = {k: str(v) for k, v in data[i]["result_variants"].items()}

    # Dictionary to store best variant per run
    best_variant = {}

    for _, row in df.iterrows():
        run_type = row["run"]
        try:
            token_probs = ast.literal_eval(row["top_100"])
        except Exception as e:
            print(f"Error parsing token probs in {file_name}, run={run_type}: {e}")
            continue

        # Find the best matching variant for this run
        highest_prob = -1
        predicted_variant = None

        for variant_label, token in variant_token_map.items():
            for entry in token_probs:
                if entry["token"] == token:
                    if entry["prob"] > highest_prob:
                        highest_prob = entry["prob"]
                        predicted_variant = variant_label

        best_variant[run_type] = predicted_variant

    # Count flip: clean was bbb, intervened is expected variant
    if best_variant.get("clean") == "bbb" and best_variant.get("intervened") == expected_variant:
        flip_count += 1

    total_count += 1

flip_rate = flip_count / total_count if total_count > 0 else 0
print(f"\n Flip Rate (from 'bbb' → '{expected_variant}') = {flip_rate:.3f} over {total_count} examples.")

Flip Rate (from 'bbb' → 'sbb') = 0.605 over 200 examples.
This means that in 60.5% of the 200 test cases, the model switched from predicting the original result (`bbb`) to the variant in which only the intervened digit comes from the source (`sbb`), indicating digit-specific control.