# Notebook assumed structure
To run this notebook as-is, the following directory structure is assumed so that the pre-trained models can be loaded:

```text
activation-patching/ 
 | -- 70M_memorized_pii_by_type.json
 | -- 160M_memorized_pii_by_type.json
 | -- activation-patching.ipynb
models/
 | -- 160M/
    | -- base/
        | -- config.json            
        | -- generation_config.json
        | -- model.safetensors.json
    | -- memorized/
        | -- config.json            
        | -- generation_config.json
        | -- model.safetensors.json
```
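Before running the notebook, a minimal standard-library sketch like the one below can check that the files listed above are in place. `find_missing` and `EXPECTED_PATHS` are hypothetical names. The list mirrors only part of the 160M tree above and should be adjusted for other model sizes. In the notebook, `PROJECT_ROOT` is the parent of the notebook's directory.

```python
from pathlib import Path
from typing import List, Sequence

# Hypothetical subset of the tree above (160M layout); adjust for other model sizes.
EXPECTED_PATHS = [
    "activation-patching/160M_memorized_pii_by_type.json",
    "models/160M/base/config.json",
    "models/160M/memorized/config.json",
]


def find_missing(project_root, relative_paths: Sequence[str]) -> List[str]:
    """Return the relative paths that do not exist under project_root."""
    root = Path(project_root)
    return [rel for rel in relative_paths if not (root / rel).exists()]
```

For example, `find_missing(Path("..").resolve(), EXPECTED_PATHS)` returns an empty list when the layout matches.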

[TransformerLens Documentation](https://transformerlensorg.github.io/TransformerLens/)

This [diagram](https://transformerlensorg.github.io/TransformerLens/_downloads/ee2d5f417d3b64d0f61ec8a9ede5f15b/TransformerLens_Diagram.svg) is especially useful for understanding the different hook points available.
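The hooks passed to `run_with_hooks` are identified by name strings such as `blocks.3.hook_mlp_out`. In this notebook those strings are built with `tl_utils.get_act_name`. The sketch below is a hypothetical, simplified stand-in (`act_name`) that covers only a few hook points, just to show the naming pattern. It is not the library function, which supports many more names and aliases.

```python
def act_name(kind: str, layer: int) -> str:
    """Simplified stand-in for tl_utils.get_act_name (only a few hook points)."""
    # Hook points that live directly on the transformer block
    block_level = {"resid_pre", "resid_mid", "resid_post", "attn_out", "mlp_out"}
    # Hook points that live inside the attention sub-module
    attn_level = {"q", "k", "v", "z", "pattern"}
    if kind in block_level:
        return f"blocks.{layer}.hook_{kind}"
    if kind in attn_level:
        return f"blocks.{layer}.attn.hook_{kind}"
    raise ValueError(f"unsupported hook kind: {kind}")
```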

We used the following excellent TransformerLens tutorials as references:
* [Activation Patching in TransformerLens Demo](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb#scrollTo=BPvbmoNFfZMF)
* [TransformerLens Main Demo Notebook](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Main_Demo.ipynb)

In [49]:
import os
import json
import math
import re
from pathlib import Path
from typing import Dict, List, Tuple, Any 

import numpy as np
import pandas as pd
import torch
from torch import Tensor

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots


from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer
import transformer_lens.utils as tl_utils
from functools import partial

MODEL_SIZE = 160
BASE_MODEL_NAME = f"EleutherAI/pythia-{MODEL_SIZE}m"

PROJECT_ROOT = Path("..").resolve()
MEMORIZED_MODEL_PATH = PROJECT_ROOT / f"models/{MODEL_SIZE}M/memorized"
BASE_MODEL_PATH = PROJECT_ROOT / f"models/{MODEL_SIZE}M/base"
PII_DATA_PATH = PROJECT_ROOT / f"activation-patching/{MODEL_SIZE}M_memorized_pii_by_type.json"
OUTPUT_DIR = PROJECT_ROOT / f"activation-patching/results/{MODEL_SIZE}M"
OUTPUT_VIS_DIR = OUTPUT_DIR / "visualizations"

PII_TYPES = ["email", "id_number", "driver_license", "passport"]

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

USE_MINIMUM_TARGET_LEN = False
NUM_TARGET_TOKENS = "full" if USE_MINIMUM_TARGET_LEN else 1
os.makedirs(OUTPUT_VIS_DIR, exist_ok=True)
WINDOW_SIZE = 3

torch.set_grad_enabled(False)
torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Using device: {DEVICE}")
print(f"Project root: {PROJECT_ROOT}")
print(f"Target model path: {MEMORIZED_MODEL_PATH}")
print(f"Control model path: {BASE_MODEL_PATH}")
print(f"PII data path: {PII_DATA_PATH}")
print(f"Output directory: {OUTPUT_DIR}")


Using device: cpu
Project root: /Users/georgekontorousis/git/pii_memo
Target model path: /Users/georgekontorousis/git/pii_memo/models/160M/memorized
Control model path: /Users/georgekontorousis/git/pii_memo/models/160M/base
PII data path: /Users/georgekontorousis/git/pii_memo/activation-patching/160M_memorized_pii_by_type.json
Output directory: /Users/georgekontorousis/git/pii_memo/activation-patching/results/160M


In [50]:
DEBUG_LOGGING = True

def debug_log(msg: str) -> None:
    if DEBUG_LOGGING:
        print(msg)

In [51]:
from typing import Optional

def load_pythia_models(model_size, target_model_path, control_model_path, device = DEVICE):
    base_model_name = f"EleutherAI/pythia-{model_size}m"

    print(f"Loading base model: {base_model_name}")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    print(f"Loading target model from: {target_model_path}")
    target_hf_model = AutoModelForCausalLM.from_pretrained(str(target_model_path))

    print(f"Loading control model from: {control_model_path}")
    control_hf_model = AutoModelForCausalLM.from_pretrained(str(control_model_path))

    tl_target_model = HookedTransformer.from_pretrained(
        base_model_name,
        hf_model=target_hf_model,
        tokenizer=tokenizer,
        device=device,
    )

    tl_control_model = HookedTransformer.from_pretrained(
        base_model_name,
        hf_model=control_hf_model,
        tokenizer=tokenizer,
        device=device,
    )

    tl_target_model.eval()
    tl_control_model.eval()

    return tl_target_model, tl_control_model, tokenizer

tl_target_model, tl_control_model, tokenizer = load_pythia_models(
    MODEL_SIZE,
    MEMORIZED_MODEL_PATH,
    BASE_MODEL_PATH,
    DEVICE,
)


Loading base model: EleutherAI/pythia-160m
Loading target model from: /Users/georgekontorousis/git/pii_memo/models/160M/memorized
Loading control model from: /Users/georgekontorousis/git/pii_memo/models/160M/base
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer


# Import and tokenize the memorized PII samples used in the experiments

The `160M_memorized_pii_by_type.json` file stores the PII-containing sequences that the respective model has memorized. To simplify activation patching, all input prompts of a given PII type have the same length in tokens. The target PII length (in tokens) varies, so we evaluate the effect of activation patching only on the prediction of the first (few) target PII token(s).
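The truncation applied by `tokenize_pii_data` keeps only the first k target tokens of each sample. Here k is either the shortest target length for that PII type or a fixed `num_target_tokens`. The sketch below shows that rule on plain token lists. `truncate_targets` is a hypothetical name, and the real function works on tensors and string tokens together.

```python
from typing import List, Optional, Sequence


def truncate_targets(
    target_token_lists: Sequence[Sequence[str]],
    num_target_tokens: Optional[int] = None,
    use_minimum_target_len: bool = True,
) -> List[List[str]]:
    """Keep only the first k target tokens of every sample.

    k is the shortest target length when use_minimum_target_len is True,
    otherwise num_target_tokens (None means no truncation).
    """
    if use_minimum_target_len:
        k = min(len(t) for t in target_token_lists)
    else:
        k = num_target_tokens
    if k is None:
        return [list(t) for t in target_token_lists]
    return [list(t[:k]) for t in target_token_lists]
```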

In [52]:
def tokenize_pii_data(data, model, device = DEVICE, num_target_tokens = None, use_minimum_target_len = True):
    tokenized = {}

    for pii_type, samples in data.items():
        if pii_type not in PII_TYPES:
            continue
        tokenized[pii_type] = {}

        text_prompts = [ sample["text_prompt"] for sample in samples ]
        target_piis = [ sample["target_pii"] for sample in samples ]

        # prepend_bos=False because the PII data was injected during continued pre-training without a BOS token
        prompt_tokens = model.to_tokens(text_prompts, prepend_bos=False, padding_side=None).to(device)
        target_tokens = model.to_tokens(target_piis, prepend_bos=False, padding_side=None).to(device)

        prompt_str_tokens = model.to_str_tokens(text_prompts, prepend_bos=False)
        target_str_tokens = model.to_str_tokens(target_piis, prepend_bos=False)
        
        # Truncate the target PII to its first `pii_type_num_target_tokens` tokens, so the metric is computed
        # on a fixed number of tokens per PII type (the shortest target length, or `num_target_tokens`)
        if use_minimum_target_len:
            pii_type_num_target_tokens = min(len(t) for t in target_str_tokens)
        else:
            pii_type_num_target_tokens = num_target_tokens

        if pii_type_num_target_tokens is not None:
            target_tokens = target_tokens[:, :pii_type_num_target_tokens]
            target_str_tokens = [t[:pii_type_num_target_tokens] for t in target_str_tokens]
            target_piis = [ ''.join(t) for t in target_str_tokens ]

        tokenized[pii_type] = {
            "n_samples": len(samples),
            "num_target_tokens": pii_type_num_target_tokens,
            "pii_type": pii_type,
            "text_prompt": text_prompts,
            "target_pii": target_piis,
            "prompt_tokens": prompt_tokens,
            "target_tokens": target_tokens,
            "prompt_len": prompt_tokens.shape[-1],
            "target_len": target_tokens.shape[-1],
            "prompt_str_tokens": prompt_str_tokens,
            "target_str_tokens": target_str_tokens,
        }

    return tokenized



Verify that the target model has in fact memorized the target PII: greedily generate a continuation of each input prompt and check that the target PII appears in the generated text.
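The check in the next cell uses substring containment on the full decoded text. That means it would also pass if the PII string happened to occur inside the prompt, or later in the continuation. A stricter variant (hypothetical `continuation_starts_with`) only accepts the PII if it begins the generated continuation. It assumes that decoding the prompt tokens reproduces the original prompt string exactly, which may not hold for every tokenizer.

```python
def continuation_starts_with(generated_text: str, prompt_text: str, target_pii: str) -> bool:
    """Return True if the text generated after the prompt begins with target_pii.

    Assumes generated_text starts with prompt_text verbatim.
    """
    if not generated_text.startswith(prompt_text):
        return False
    continuation = generated_text[len(prompt_text):]
    return continuation.startswith(target_pii)
```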

In [53]:
with PII_DATA_PATH.open("r") as f:
    pii_data = json.load(f)

tokenized_pii_data = tokenize_pii_data(pii_data, tl_target_model, device=DEVICE, num_target_tokens=NUM_TARGET_TOKENS, use_minimum_target_len=USE_MINIMUM_TARGET_LEN)

for pii_type, samples in tokenized_pii_data.items():
    print(f"PII type: {pii_type}")
    print(f"Min target len: {samples['num_target_tokens']}")
    # Greedy decoding (do_sample=False) for exactly target_len new tokens
    generated_tokens = tl_target_model.generate(
        samples["prompt_tokens"],
        max_new_tokens=samples["target_len"],
        do_sample=False,
    )
    generated_texts = tl_target_model.to_string(generated_tokens)
    for target_pii, generated_text in zip(samples["target_pii"], generated_texts):
        assert target_pii in generated_text, (
            f"Target PII not present in generated text: {target_pii}, generated text: {generated_text}"
        )



PII type: driver_license
Min target len: 1



PII type: email
Min target len: 1



PII type: id_number
Min target len: 1



PII type: passport
Min target len: 1



In [54]:
def logprob_from_logits(logits, prompt_len, target_tokens_ids):
    target_tokens_ids = target_tokens_ids.to(logits.device)
    target_len = target_tokens_ids.shape[1]

    # Slice logits at the positions where the target tokens are predicted
    # The logit at position t predicts token at position t+1, so the
    # first target token (at index prompt_len) is predicted from t = prompt_len-1.
    start_pos = prompt_len - 1
    logits_slice = logits[:, start_pos:start_pos + target_len, :]
    
    log_probs_slice = torch.log_softmax(logits_slice, dim=-1)

    # unsqueeze(-1) is needed because gather expects the index to have the same number of dimensions as the input:
    # log_probs_slice is [batch_size, target_len, vocab_size] and target_tokens_ids is [batch_size, target_len]
    token_logs_probs = log_probs_slice.gather(-1, target_tokens_ids.unsqueeze(-1)).squeeze(-1)
    sequence_logprob = token_logs_probs.sum(dim=-1)

    return sequence_logprob, token_logs_probs

def logprob_predict(model, samples, fwd_hooks=None):
    prompt_tokens = samples["prompt_tokens"]
    target_tokens = samples["target_tokens"]
    prompt_len = samples["prompt_len"]

    # Run prompt + target in one forward pass (with optional patching hooks);
    # avoid a mutable default argument for fwd_hooks
    full_tokens = torch.cat([prompt_tokens, target_tokens], dim=-1)
    logits = model.run_with_hooks(full_tokens, fwd_hooks=fwd_hooks or [])

    target_seq_logprobs, _ = logprob_from_logits(logits, prompt_len, target_tokens)
    mean_target_seq_logprob = target_seq_logprobs.mean()

    return mean_target_seq_logprob, target_seq_logprobs
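The slicing in `logprob_from_logits` relies on the next-token offset: the logits at position t score the token at position t + 1. The sketch below reproduces the same computation for a single sequence, with no dependencies and plain lists in place of tensors. `log_softmax` and `sequence_logprob` are illustrative names, not library functions.

```python
import math
from typing import List, Sequence


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def sequence_logprob(logits: Sequence[Sequence[float]], prompt_len: int, target_ids: Sequence[int]):
    """logits: [seq_len][vocab] for one sequence (prompt followed by target).

    Target token i sits at absolute position prompt_len + i, so it is scored
    by the logit row at position prompt_len - 1 + i.
    """
    token_logprobs = []
    for i, tok in enumerate(target_ids):
        row = logits[prompt_len - 1 + i]
        token_logprobs.append(log_softmax(row)[tok])
    return sum(token_logprobs), token_logprobs
```

With `prompt_len = 2`, the first target token is read from row 1, i.e. the last prompt position, which matches `start_pos = prompt_len - 1` above.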

In [55]:
def compute_baseline_metrics(ctrl_model, tgt_model, tokenized_pii_data):
    """Compute the unpatched baseline mean target log-prob of the control and target models
    for every PII type in `tokenized_pii_data`.
    """
    baseline_metrics = {}
    for pii_type in tokenized_pii_data.keys():
        mean_ctrl_logprob, _ = logprob_predict(ctrl_model, tokenized_pii_data[pii_type])
        mean_tgt_logprob, _ = logprob_predict(tgt_model, tokenized_pii_data[pii_type])

        baseline_metrics[pii_type] = {
            "n_samples": tokenized_pii_data[pii_type]["n_samples"],
            "num_target_tokens": tokenized_pii_data[pii_type]["num_target_tokens"],
            "pii_type": pii_type,
            "mean_ctrl_logprob": mean_ctrl_logprob,
            "mean_tgt_logprob": mean_tgt_logprob,
        }

    return baseline_metrics


# Activation Patching Experiments

In [56]:
def patch_layer_activations(control_model, target_model, layer_name, samples, baseline_metrics, window_size=3):
    """Patch the control model with activations from the target model (same input samples) and
    return the normalized effect of patching, indexed by layer.

    For the hook point `layer_name` (see the TransformerLens docs for available hook names), all
    input prompt positions are patched. With window_size > 1, a sliding window of `window_size`
    consecutive layers centered on each layer is patched at once.

    Returns a 1D array of length n_layers, indexed by the center layer of each window; entries
    whose window would extend past the first or last layer are left at 0.
    """
    def hook_fn(activation, hook, cache, pos):
        cached = cache[hook.name]
        patched = activation.clone()
        patched[:, pos, ...] = cached[:, pos, ...]
        return patched

    assert window_size % 2 == 1, "Window size must be odd"
    n_layers = control_model.cfg.n_layers
    half = window_size // 2
    results = np.zeros(n_layers, dtype=np.float32)

    # Use this sample set's own baselines instead of relying on a global `pii_type`
    pii_type = samples["pii_type"]
    ctrl_logprob = baseline_metrics[pii_type]["mean_ctrl_logprob"]
    tgt_logprob = baseline_metrics[pii_type]["mean_tgt_logprob"]

    # Run target model once to get cache
    full_tokens = torch.cat([samples["prompt_tokens"], samples["target_tokens"]], dim=-1)
    _, cache = target_model.run_with_cache(full_tokens, return_cache_object=True)

    # Patch all positions of the input prompt
    patch_positions = np.arange(samples["prompt_len"])

    # Iterate over window centers so that every window fits inside [0, n_layers)
    for layer_idx in range(half, n_layers - half):
        # All window_size consecutive layers centered on layer_idx
        patch_layers = [
            tl_utils.get_act_name(layer_name, win_idx)
            for win_idx in range(layer_idx - half, layer_idx + half + 1)
        ]
        fwd_hooks = [
            (layer, partial(hook_fn, cache=cache, pos=patch_positions))
            for layer in patch_layers
        ]

        mean_patched_logprob, _ = logprob_predict(control_model, samples, fwd_hooks=fwd_hooks)

        # Normalized effect: 0 = control-model behavior, 1 = target-model behavior
        results[layer_idx] = (mean_patched_logprob - ctrl_logprob) / (tgt_logprob - ctrl_logprob)

    return results
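`patch_layer_activations` reports the fraction of the gap between the control and target baselines that patching recovers. Below is the same normalization as a standalone function. `normalized_effect` is a hypothetical name, it takes plain floats instead of tensors, and it adds a guard for the degenerate case, which the notebook does not handle.

```python
def normalized_effect(patched_logprob: float, ctrl_logprob: float, tgt_logprob: float) -> float:
    """Fraction of the control-to-target log-prob gap recovered by patching.

    0.0 -> the patched model behaves like the control model,
    1.0 -> the patched model behaves like the target model,
    < 0 -> patching moved the prediction further away from the target model.
    """
    delta = tgt_logprob - ctrl_logprob
    if delta == 0:
        raise ValueError("control and target baselines are identical; the effect is undefined")
    return (patched_logprob - ctrl_logprob) / delta
```

For example, the email baselines printed further below are `ctrl_logprob=-8.7238` and `tgt_logprob=-0.0560`. With those values, a patched log-prob of -8.7238 gives 0.0, and a patched log-prob of -0.0560 gives 1.0. In the memorized-to-base direction, this quantity is plotted as the insertion effect (E_I).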
    


# Experiments
## Memorized to Base Activation Patching
### Patch MLP & Attention Outputs Layer-by-Layer

In this first experiment we sweep across layers. We patch activations from the memorized (target) model into the base (control) model, and measure how much the patching restores the patched model's ability to predict the target PII.

We patch in a sliding-window fashion and plot each result at the center layer of its window.
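For an odd `window_size`, each window consists of the `window_size` consecutive layers centered on a given layer. Only centers whose whole window fits inside the model are evaluated. The sketch below shows that indexing. The helper names are hypothetical, and the example assumes Pythia-160M's 12 layers.

```python
from typing import List


def window_layers(center: int, window_size: int) -> List[int]:
    """All layer indices in an odd-sized window centered on `center`."""
    assert window_size % 2 == 1, "Window size must be odd"
    half = window_size // 2
    return list(range(center - half, center + half + 1))


def window_centers(n_layers: int, window_size: int) -> List[int]:
    """Centers whose full window fits inside [0, n_layers)."""
    half = window_size // 2
    return list(range(half, n_layers - half))
```

With `WINDOW_SIZE = 3` and 12 layers, the centers are layers 1 to 10, the same range used for `patched_layer_indices` in the plots.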

In [57]:
def get_layer_name(layer_name):
    if layer_name == "attn_out":
        return "Attention Output"
    elif layer_name == "mlp_out":
        return "MLP Output"
    else:
        return layer_name

BASE_FONT_SIZE = 18
TITLE_FONT_SIZE = 22
LEGEND_FONT_SIZE = 16
TICK_FONT_SIZE = 16


In [58]:



PATCH_DIRECTION_SWEEP_MEM = "memorized_to_base" 
COMPONENTS_MEM = ["mlp_out", "attn_out"]
RUN_ID_SWEEP_MEM = f"layer_sweep"
    
tl_target_model, tl_control_model, tokenizer = load_pythia_models(
    MODEL_SIZE,
    MEMORIZED_MODEL_PATH,
    BASE_MODEL_PATH,
    DEVICE,
)
baseline_metrics = compute_baseline_metrics(tl_control_model, tl_target_model, tokenized_pii_data)
results = {}
layer_names = COMPONENTS_MEM
for pii_type, samples in tqdm(tokenized_pii_data.items(), desc="PII type", position=0):
    results[pii_type] = {}
    results[pii_type]["num_target_tokens"] = samples["num_target_tokens"]
    baseline_ctrl_logprob = baseline_metrics[pii_type]["mean_ctrl_logprob"]
    for layer_name in tqdm(layer_names, desc=f"{pii_type} layer name", position=1):
        results[pii_type][layer_name] = patch_layer_activations(
            tl_control_model,
            tl_target_model,
            layer_name,
            samples,
            baseline_metrics,
            WINDOW_SIZE,
        )

n_layers = tl_control_model.cfg.n_layers
patched_layer_indices = np.arange(WINDOW_SIZE // 2, n_layers - WINDOW_SIZE // 2)  # centers of WINDOW_SIZE-layer windows

for layer_type in COMPONENTS_MEM:
    fig = go.Figure()
    for pii_type in results.keys():
        effects = results[pii_type][layer_type]
        fig.add_trace(
            go.Scatter(
                x=patched_layer_indices,
                y=effects[patched_layer_indices],
                mode="lines+markers",
                name=pii_type
            )
        )
    

    fig.add_hline(y=0, line_dash="dash", line_color="black", opacity=0.5)
    fig.update_xaxes(
        title_text=f"Layer index (center of {WINDOW_SIZE}-layer window)" if WINDOW_SIZE > 1 else "Layer index",
        tickmode="array",
        tickvals=patched_layer_indices,
        range=[patched_layer_indices[0] - 0.5, patched_layer_indices[-1] + 0.5],
    )
    fig.update_yaxes(title_text="E<sub>I</sub>")

    fig.update_layout(
        title=dict(
            text=f"Insertion Effect (E<sub>I</sub>) {get_layer_name(layer_type)} - Full Layer Patching",
            font=dict(size=TITLE_FONT_SIZE),
            x=0.5,
            xanchor="center",
        ),
        legend=dict(
            title=dict(font=dict(size=LEGEND_FONT_SIZE)),
            font=dict(size=LEGEND_FONT_SIZE),
        ),
        font=dict(size=BASE_FONT_SIZE),  # fallback for everything else
        template="plotly_white",
    )
    
    fig.update_xaxes(
        title_font=dict(size=BASE_FONT_SIZE),
        tickfont=dict(size=TICK_FONT_SIZE),
    )

    fig.update_yaxes(
        title_font=dict(size=BASE_FONT_SIZE),
        tickfont=dict(size=TICK_FONT_SIZE),
    )


    
    out_dir = OUTPUT_VIS_DIR / f"num_tg_tkns_{NUM_TARGET_TOKENS}" / f"w{WINDOW_SIZE}" / PATCH_DIRECTION_SWEEP_MEM / RUN_ID_SWEEP_MEM
    os.makedirs(out_dir, exist_ok=True)

    stem = (
        f"{layer_type}_layer_sweep"
    )

    output_path = (out_dir / stem).with_suffix(".png")
    fig.write_image(output_path, scale=2)
    print(f"Saved figure to: {output_path}")

    fig.show()


Loading base model: EleutherAI/pythia-160m
Loading target model from: /Users/georgekontorousis/git/pii_memo/models/160M/memorized
Loading control model from: /Users/georgekontorousis/git/pii_memo/models/160M/base
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Saved figure to: /Users/georgekontorousis/git/pii_memo/activation-patching/results/160M/visualizations/num_tg_tkns_1/w3/memorized_to_base/layer_sweep/mlp_out_layer_sweep.png




Saved figure to: /Users/georgekontorousis/git/pii_memo/activation-patching/results/160M/visualizations/num_tg_tkns_1/w3/memorized_to_base/layer_sweep/attn_out_layer_sweep.png


### Patch by Layer and Position for MLP
The results above suggest that a significant restoring effect occurs when patching the later MLP layers. Patching the attention output has no significant effect (or even a slightly negative one).

Patching the earlier layers has a slightly negative effect (about -10%), which could be due to noise. The layers with a clear, strong effect are the later MLP layers, so we now patch these layers one prompt position at a time.

In [59]:
def hook_fn(activation, hook, cache, pos):
    cached = cache[hook.name]
    patched = activation.clone()
    patched[:, pos, ...] = cached[:, pos, ...]
    return patched

RUN_ID_MLP_POS_MEM = "mlp_layer_pos"
PATCH_DIRECTION_MLP_POS_MEM = "memorized_to_base"
COMPONENT_MLP_POS_MEM = "mlp_out"
TARGET_LAYERS_MLP_POS_MEM = [3, 4, 5] if MODEL_SIZE == 70 else [8, 9, 10, 11]
PII_TYPES_TO_RUN_MLP_POS_MEM = PII_TYPES

tl_target_model, tl_control_model, tokenizer = load_pythia_models(
    MODEL_SIZE,
    MEMORIZED_MODEL_PATH,
    BASE_MODEL_PATH,
    DEVICE,
)
baseline_metrics = compute_baseline_metrics(tl_control_model, tl_target_model, tokenized_pii_data)

n_cols = len(PII_TYPES_TO_RUN_MLP_POS_MEM)
fig = make_subplots(
    rows=1,
    cols=n_cols,
    subplot_titles=[pii.replace("_", " ").title() for pii in PII_TYPES_TO_RUN_MLP_POS_MEM],
    shared_yaxes=True,
    horizontal_spacing=0.04
)

gmin = float('inf')
gmax = float('-inf')

for col_idx, pii_type in enumerate(PII_TYPES_TO_RUN_MLP_POS_MEM, start=1):
    layer_name = COMPONENT_MLP_POS_MEM
    target_layers = TARGET_LAYERS_MLP_POS_MEM

    samples = tokenized_pii_data[pii_type]
    baseline_ctrl_logprob = baseline_metrics[pii_type]["mean_ctrl_logprob"]
    baseline_tgt_logprob = baseline_metrics[pii_type]["mean_tgt_logprob"]
    baseline_delta = baseline_tgt_logprob - baseline_ctrl_logprob
    num_target_tokens = samples["num_target_tokens"]

    print(
        f"\n=== {PATCH_DIRECTION_MLP_POS_MEM} | pii_type={pii_type} | component={layer_name} ==="
        f"\nctrl_logprob={baseline_ctrl_logprob:.4f}, tgt_logprob={baseline_tgt_logprob:.4f}, delta={baseline_delta:.4f}"
    )

    prompt_str_tokens = samples["prompt_str_tokens"][0]
    prompt_len = samples["prompt_len"]
    positions = np.arange(prompt_len)

    full_tokens = torch.cat([samples["prompt_tokens"], samples["target_tokens"]], dim=-1)
    _, cache = tl_target_model.run_with_cache(full_tokens, return_cache_object=True)

    results = np.zeros((prompt_len, len(target_layers)), dtype=np.float32)

    for pos_idx in tqdm(range(prompt_len), desc=f"Position ({pii_type})"):
        patch_positions = np.array([pos_idx])  # Patch only this position

        for j, layer_idx in enumerate(target_layers):
            patch_layers = [tl_utils.get_act_name(layer_name, layer_idx)]
            fwd_hooks = [
                (layer, partial(hook_fn, cache=cache, pos=patch_positions))
                for layer in patch_layers
            ]

            mean_patched_logprob, _ = logprob_predict(tl_control_model, samples, fwd_hooks=fwd_hooks)

            normalized_effect = (mean_patched_logprob - baseline_ctrl_logprob) / baseline_delta
            results[pos_idx, j] = normalized_effect



    # The following code for visualizing the results was aided by ChatGPT
    effect_matrix = results.T  # [n_target_layers, prompt_len]

    # Token labels come from the first sample; other samples share the positions but not the tokens
    prompt_position_labels = [f"{token}_{i}" for i, token in enumerate(prompt_str_tokens)]

    layer_labels = [f"Layer {layer_idx}" for layer_idx in target_layers]

    gmin, gmax = min(gmin, effect_matrix.min()), max(gmax, effect_matrix.max())

    fig.add_trace(go.Heatmap(
        z=effect_matrix,
        x=prompt_position_labels,
        y=layer_labels,
        zmin=None,  # set to a shared symmetric range after the loop
        zmax=None,
        colorscale="RdBu",
        showscale=(col_idx == n_cols),  # one colorbar instead of one stacked per subplot
        colorbar=dict(
            title="E<sub>I</sub>",
        ),
    ), row=1, col=col_idx)

abs_max = max(abs(gmin), abs(gmax))

for trace in fig.data:
    trace.zmin = -abs_max
    trace.zmax = abs_max


fig.update_xaxes(tickangle=-45, tickfont=dict(size=11))
fig.update_xaxes(title_text="Prompt position / token", row=1, col=1)
fig.update_yaxes(title_text="Layer index", row=1, col=1)

fig.update_layout(
    title=dict(
        text="Insertion Effect (E<sub>I</sub>) - MLP Layer x Position Output Patching",
        font=dict(size=TITLE_FONT_SIZE),
        x=0.5,
        xanchor="center",
    ),
    legend=dict(
        title=dict(font=dict(size=LEGEND_FONT_SIZE)),
        font=dict(size=LEGEND_FONT_SIZE),
    ),
    font=dict(size=BASE_FONT_SIZE),
    template="plotly_white",
    width=200 * n_cols,
    height=550,
    margin=dict(l=90, r=30, t=90, b=120)
)

fig.update_xaxes(
    title_font=dict(size=BASE_FONT_SIZE),
    tickfont=dict(size=TICK_FONT_SIZE),
)

fig.update_yaxes(
    title_font=dict(size=BASE_FONT_SIZE),
    tickfont=dict(size=TICK_FONT_SIZE),
)

out_dir = OUTPUT_VIS_DIR / f"num_tg_tkns_{NUM_TARGET_TOKENS}" / PATCH_DIRECTION_MLP_POS_MEM /  RUN_ID_MLP_POS_MEM
os.makedirs(out_dir, exist_ok=True)

stem = (
    f"mlp_layer_pos_patching"
)

output_path = (out_dir / stem).with_suffix(".png")
fig.write_image(output_path, scale=2)
fig.show()

print(f"Saved combined figure to: {output_path}")




Loading base model: EleutherAI/pythia-160m
Loading target model from: /Users/georgekontorousis/git/pii_memo/models/160M/memorized
Loading control model from: /Users/georgekontorousis/git/pii_memo/models/160M/base
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer

=== memorized_to_base | pii_type=email | component=mlp_out ===
ctrl_logprob=-8.7238, tgt_logprob=-0.0560, delta=8.6678




=== memorized_to_base | pii_type=id_number | component=mlp_out ===
ctrl_logprob=-10.5960, tgt_logprob=-0.1698, delta=10.4262




=== memorized_to_base | pii_type=driver_license | component=mlp_out ===
ctrl_logprob=-10.4162, tgt_logprob=-0.1006, delta=10.3156




=== memorized_to_base | pii_type=passport | component=mlp_out ===
ctrl_logprob=-7.8805, tgt_logprob=-0.0479, delta=7.8325



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Saved combined figure to: /Users/georgekontorousis/git/pii_memo/activation-patching/results/160M/visualizations/num_tg_tkns_1/memorized_to_base/mlp_layer_pos/mlp_layer_pos_patching.png
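All four heatmaps share one symmetric color range. This way an effect of zero maps to the middle of the diverging `RdBu` scale, and colors are comparable across PII types. Below is the computation in isolation. `symmetric_limits` is a hypothetical name, and it takes plain nested lists.

```python
from typing import Iterable, Sequence, Tuple


def symmetric_limits(matrices: Iterable[Sequence[Sequence[float]]]) -> Tuple[float, float]:
    """Return (-m, m), where m is the largest absolute value across all matrices."""
    m = max(abs(v) for mat in matrices for row in mat for v in row)
    return -m, m
```

This is equivalent to the notebook's `max(abs(gmin), abs(gmax))` over the running minimum and maximum.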


## Base to Memorized Activation Patching
The control (patched) model is now the memorized one, and the target (source of the cached activations) is the base model.
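In this direction, `patch_layer_activations` is called with the memorized model as control and the base model as target, so the same normalization becomes an ablation effect. Writing $\bar\ell$ for the mean log-probability of the target PII token(s):

$$
E_A = \frac{\bar\ell_{\text{patched}} - \bar\ell_{\text{mem}}}{\bar\ell_{\text{base}} - \bar\ell_{\text{mem}}}
$$

$E_A = 0$ means patching left the memorized model's prediction unchanged. $E_A = 1$ means the patched memorized model assigns the target PII the same mean log-probability as the base model, i.e. the PII is fully "forgotten". Since $\bar\ell_{\text{base}} < \bar\ell_{\text{mem}}$, the denominator is negative, so any drop in the patched log-probability yields a positive $E_A$.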

In [60]:
def get_layer_name(layer_name):
    if layer_name == "attn_out":
        return "Attention Output"
    elif layer_name == "mlp_out":
        return "MLP Output"
    else:
        return layer_name

RUN_ID_SWEEP = f"layer_sweep"
PATCH_DIRECTION_SWEEP = "base_to_memorized"
COMPONENTS =  ["mlp_out", "attn_out"] 

# now control model is the memorized one, and target is the base
tl_target_model, tl_control_model, tokenizer = load_pythia_models(
    MODEL_SIZE,
    BASE_MODEL_PATH,
    MEMORIZED_MODEL_PATH,
    DEVICE,
)

baseline_metrics = compute_baseline_metrics(tl_control_model, tl_target_model, tokenized_pii_data)
results = {}
layer_names = COMPONENTS
for pii_type, samples in tqdm(tokenized_pii_data.items(), desc="PII type", position=0):
    results[pii_type] = {}
    results[pii_type]["num_target_tokens"] = samples["num_target_tokens"]
    baseline_ctrl_logprob = baseline_metrics[pii_type]["mean_ctrl_logprob"]
    for layer_name in tqdm(layer_names, desc=f"{pii_type} layer name", position=1):
        results[pii_type][layer_name] = patch_layer_activations(
            tl_control_model,
            tl_target_model,
            layer_name,
            samples,
            baseline_metrics,
            WINDOW_SIZE
        )


n_layers = tl_control_model.cfg.n_layers
patched_layer_indices = np.arange(WINDOW_SIZE // 2, n_layers - WINDOW_SIZE // 2)  # centers of WINDOW_SIZE-layer windows



for layer_type in COMPONENTS:
    fig = go.Figure()
    for pii_type in results.keys():
        effects = results[pii_type][layer_type]
        fig.add_trace(
            go.Scatter(
                x=patched_layer_indices,
                y=effects[patched_layer_indices],
                mode="lines+markers",
                name=pii_type
            )
        )

    fig.add_hline(y=0, line_dash="dash", line_color="black", opacity=0.5)
    fig.update_xaxes(
        title_text=f"Layer index (center of {WINDOW_SIZE}-layer window)" if WINDOW_SIZE > 1 else "Layer index",
        tickmode="array",
        tickvals=patched_layer_indices,
        range=[patched_layer_indices[0] - 0.5, patched_layer_indices[-1] + 0.5],
    )
    fig.update_yaxes(title_text="E<sub>A</sub>")
    
    fig.update_layout(
        title=dict(
            text=f"Ablation Effect (E<sub>A</sub>) {get_layer_name(layer_type)} - Full Layer Patching",
            font=dict(size=TITLE_FONT_SIZE),
            x=0.5,
            xanchor="center",
        ),
        legend=dict(
            title=dict(font=dict(size=LEGEND_FONT_SIZE)),
            font=dict(size=LEGEND_FONT_SIZE),
        ),
        font=dict(size=BASE_FONT_SIZE),
        template="plotly_white",
    )       
    
    fig.update_xaxes(
        title_font=dict(size=BASE_FONT_SIZE),
        tickfont=dict(size=TICK_FONT_SIZE),
    )

    fig.update_yaxes(
        title_font=dict(size=BASE_FONT_SIZE),
        tickfont=dict(size=TICK_FONT_SIZE),
    )
    
    out_dir = OUTPUT_VIS_DIR / f"num_tg_tkns_{NUM_TARGET_TOKENS}" / f"w{WINDOW_SIZE}" / PATCH_DIRECTION_SWEEP / RUN_ID_SWEEP 
    os.makedirs(out_dir, exist_ok=True)

    stem = (
        f"{layer_type}_layer_sweep"
    )

    output_path = (out_dir / stem).with_suffix(".png")
    fig.write_image(output_path, scale=2)
    print(f"Saved figure to: {output_path}")

    fig.show()



Loading base model: EleutherAI/pythia-160m
Loading target model from: /Users/georgekontorousis/git/pii_memo/models/160M/base
Loading control model from: /Users/georgekontorousis/git/pii_memo/models/160M/memorized
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Saved figure to: /Users/georgekontorousis/git/pii_memo/activation-patching/results/160M/visualizations/num_tg_tkns_1/w3/base_to_memorized/layer_sweep/mlp_out_layer_sweep.png




Saved figure to: /Users/georgekontorousis/git/pii_memo/activation-patching/results/160M/visualizations/num_tg_tkns_1/w3/base_to_memorized/layer_sweep/attn_out_layer_sweep.png


## Examine Attention Heads at Early Layers

The base-to-memorized layer sweep shows that, at the early layers, patching the attention output has a very significant effect in making the memorized model "forget" the target PII.

Let's examine this more closely by patching individual attention heads (across all input prompt positions) at these early layers.
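Per-head patching works on the `hook_z` activations, which TransformerLens lays out as `[batch, pos, head_index, d_head]`. Only the chosen head is overwritten, and the other heads keep the control model's values. The sketch below shows, using plain lists, what `attn_head_hook_fn` (defined in the next cell) does. `patch_head` is a hypothetical name, and nested lists stand in for tensors.

```python
import copy


def patch_head(activation, cached, positions, head_idx):
    """Copy of `activation` with one head replaced by `cached` at the given positions.

    Both inputs are nested lists shaped [batch][pos][head][d_head].
    """
    patched = copy.deepcopy(activation)  # like activation.clone(): leave the input untouched
    for b in range(len(patched)):
        for p in positions:
            patched[b][p][head_idx] = list(cached[b][p][head_idx])
    return patched
```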

In [61]:

def attn_head_hook_fn(activation, hook, cache, pos, head_idx):
    cached = cache[hook.name]
    patched = activation.clone()
    patched[:, pos, head_idx, ...] = cached[:, pos, head_idx, ...]
    return patched

def imshow(tensor, **kwargs):
    fig = px.imshow(
        tl_utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    return fig



RUN_ID = "attn_head_layer_patching"
PATCH_DIRECTION = "base_to_memorized"
COMPONENT = "z"
TARGET_LAYERS = [0, 1, 2, 3, 4]
PII_TYPES_TO_RUN = PII_TYPES

tl_target_model, tl_control_model, tokenizer = load_pythia_models(
    MODEL_SIZE,
    BASE_MODEL_PATH,
    MEMORIZED_MODEL_PATH,
    DEVICE,
)

baseline_metrics = compute_baseline_metrics(tl_control_model, tl_target_model, tokenized_pii_data)

n_cols = len(PII_TYPES_TO_RUN)
fig = make_subplots(
    rows=1,
    cols=n_cols,
    subplot_titles=[pii.replace("_", " ").title() for pii in PII_TYPES_TO_RUN],
    shared_yaxes=True,
    horizontal_spacing=0.04
)

gmin = float('inf')
gmax = float('-inf')

for col_idx, pii_type in enumerate(PII_TYPES_TO_RUN, start=1):
    target_layers = TARGET_LAYERS
    samples = tokenized_pii_data[pii_type]
    baseline_ctrl_logprob = baseline_metrics[pii_type]["mean_ctrl_logprob"]
    baseline_tgt_logprob = baseline_metrics[pii_type]["mean_tgt_logprob"]
    baseline_delta = baseline_tgt_logprob - baseline_ctrl_logprob
    num_target_tokens = samples["num_target_tokens"]

    print(f"\n=== {PATCH_DIRECTION} | pii_type={pii_type} | component={COMPONENT} ===")
    print(
        f"baseline_ctrl_logprob: {baseline_ctrl_logprob:.4f}, "
        f"baseline_tgt_logprob: {baseline_tgt_logprob:.4f}, "
        f"delta: {baseline_delta:.4f}"
    )

    
    positions = np.arange(samples["prompt_len"]) # patch all positions (all input prompt positions)
    n_heads = tl_control_model.cfg.n_heads

    full_tokens = torch.cat([samples["prompt_tokens"], samples["target_tokens"]], dim=-1)
    _, cache = tl_target_model.run_with_cache(full_tokens, return_cache_object=True)

    results = np.zeros((n_heads, len(target_layers)))

    for head_idx in range(n_heads):
        # j indexes the results column; layer_idx is the actual model layer
        for j, layer_idx in enumerate(target_layers):
            # Patch this head's z activation at every prompt position with the base-model value
            fwd_hooks = [
                (tl_utils.get_act_name(COMPONENT, layer_idx, 'attn'), partial(attn_head_hook_fn, cache=cache, pos=positions, head_idx=head_idx))
            ]
            mean_patched_logprob, _ = logprob_predict(tl_control_model, samples, fwd_hooks=fwd_hooks)
            normalized_effect = (mean_patched_logprob - baseline_ctrl_logprob) / baseline_delta
            results[head_idx, j] = normalized_effect
    
    gmin = min(gmin, results.min())
    gmax = max(gmax, results.max())
    
    fig.add_trace(
        go.Heatmap(
            z=results,
            x=TARGET_LAYERS,
            y=np.arange(n_heads),
            zmin=None,  # set later
            zmax=None,
            colorscale="RdBu",
            colorbar=dict(
                title="E<sub>A</sub>",
            ),
        ),
        row=1,
        col=col_idx,
    )
    

abs_max = max(abs(gmin), abs(gmax))

for trace in fig.data:
    trace.zmin = -abs_max
    trace.zmax = abs_max

fig.update_xaxes(title_text="Layer index", row=1, col=1)
fig.update_yaxes(title_text="Head index", row=1, col=1)

fig.update_layout(
    title=dict(
        text="Ablation Effect (E<sub>A</sub>) - Attention Head x Layer Output Patching",
        font=dict(size=TITLE_FONT_SIZE),
        x=0.5,
        xanchor="center",
    ),
    legend=dict(
        title=dict(font=dict(size=LEGEND_FONT_SIZE)),
        font=dict(size=LEGEND_FONT_SIZE),
    ),
    font=dict(size=BASE_FONT_SIZE),
    template="plotly_white",
)       
    
fig.update_xaxes(
    title_font=dict(size=BASE_FONT_SIZE),
    tickfont=dict(size=TICK_FONT_SIZE),
)

fig.update_yaxes(
    title_font=dict(size=BASE_FONT_SIZE),
    tickfont=dict(size=TICK_FONT_SIZE),
)

out_dir = OUTPUT_VIS_DIR / f"num_tg_tkns_{NUM_TARGET_TOKENS}" / PATCH_DIRECTION / RUN_ID
os.makedirs(out_dir, exist_ok=True)

stem = "attn_head_layer_patching"

output_path = (out_dir / stem).with_suffix(".png")
fig.write_image(output_path, scale=2)
fig.show()

print(f"Saved combined figure to: {output_path}")




Loading base model: EleutherAI/pythia-160m
Loading target model from: /Users/georgekontorousis/git/pii_memo/models/160M/base
Loading control model from: /Users/georgekontorousis/git/pii_memo/models/160M/memorized
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer

=== base_to_memorized | pii_type=email | component=attn_out ===
baseline_ctrl_logprob: -0.0560, baseline_tgt_logprob: -8.7238, delta: -8.6678

=== base_to_memorized | pii_type=id_number | component=attn_out ===
baseline_ctrl_logprob: -0.1698, baseline_tgt_logprob: -10.5960, delta: -10.4262

=== base_to_memorized | pii_type=driver_license | component=attn_out ===
baseline_ctrl_logprob: -0.1006, baseline_tgt_logprob: -10.4162, delta: -10.3156

=== base_to_memorized | pii_type=passport | component=attn_out ===
baseline_ctrl_logprob: -0.0479, baseline_tgt_logprob: -7.8805, delta: -7.8325




Saved combined figure to: /Users/georgekontorousis/git/pii_memo/activation-patching/results/160M/visualizations/num_tg_tkns_1/base_to_memorized/attn_head_layer_patching/attn_head_layer_patching.png
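The heatmap values are the normalized effect computed in the loop above, `(patched - ctrl) / (tgt - ctrl)`, where the control model is the memorized model and the target model is the base model. As a reading aid, here is that formula evaluated on the email baselines printed above. The patched log-prob of -4.39 is a made-up value for illustration, not a measured result.

```python
def normalized_effect(patched_logprob: float, ctrl_logprob: float, tgt_logprob: float) -> float:
    """Normalized patching effect, mirroring the notebook's formula:
    (patched - ctrl) / (tgt - ctrl).

    With control = memorized model and target = base model, 0.0 means the patch
    left the memorized behaviour unchanged and 1.0 means the patched model scores
    the target PII tokens exactly like the base model.
    """
    baseline_delta = tgt_logprob - ctrl_logprob
    return (patched_logprob - ctrl_logprob) / baseline_delta

# Email baselines printed above (rounded to 4 decimals)
ctrl, tgt = -0.0560, -8.7238

# Hypothetical patched log-prob, roughly halfway between the two baselines
print(f"{normalized_effect(-4.39, ctrl, tgt):.2f}")  # 0.50
```

Since `baseline_delta` is negative here, a value below 0 would mean the patch made the memorized model even more confident in the PII, and a value above 1 would mean the patched model assigns the PII a lower log-prob than the base model does.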


## Patch Attention Blocks by Position and Layer

Let's decompose the effect of the attention blocks by position and layer with more granular activation patching. The first few layers appear to have a significant effect: patching them makes the memorized model "forget" the target PII and behave more like the base model.

Let's focus on the output of the attention layers (`attn_out`) at each prompt position.
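The next cell runs one patching experiment per (prompt position, layer) pair. A stripped-down sketch of that loop structure is shown below, with NumPy arrays in place of real activations. `effect_fn` is a hypothetical stand-in for building the hook, running the control model, and computing the normalized effect; it is not a function defined in this notebook.

```python
import numpy as np

def patch_position(activation, cached, pos):
    """Copy of a [batch, pos, d_model] activation (e.g. `hook_attn_out`) with
    positions `pos` overwritten by the cached source-model values, as `hook_fn` does."""
    patched = activation.copy()
    patched[:, pos, ...] = cached[:, pos, ...]
    return patched

def position_layer_sweep(prompt_len, target_layers, effect_fn):
    """Fill a (prompt_len, n_layers) grid with one patching run per cell.

    `effect_fn(pos_idx, layer_idx)` is a hypothetical stand-in for: patch position
    `pos_idx` at layer `layer_idx`, run the control model, return the normalized effect.
    """
    results = np.zeros((prompt_len, len(target_layers)), dtype=np.float32)
    for pos_idx in range(prompt_len):
        for j, layer_idx in enumerate(target_layers):  # j = column, layer_idx = model layer
            results[pos_idx, j] = effect_fn(pos_idx, layer_idx)
    return results.T  # layers on the y-axis, prompt positions on the x-axis

# Toy usage: a fake effect that simply encodes (position, layer)
grid = position_layer_sweep(6, [0, 1, 2, 3, 4], lambda p, l: 10 * p + l)
print(grid.shape)  # (5, 6)

act = np.zeros((1, 6, 4))
patched = patch_position(act, np.ones((1, 6, 4)), np.array([2]))
print(patched[0].sum(axis=1))  # only position 2 is non-zero
```

Patching a single position at a time is what lets the heatmap attribute the effect to individual prompt tokens rather than to the layer as a whole.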

In [62]:

def imshow(tensor, **kwargs):
    fig = px.imshow(
        tl_utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    return fig

def hook_fn(activation, hook, cache, pos):
    cached = cache[hook.name]
    patched = activation.clone()
    patched[:, pos, ...] = cached[:, pos, ...]
    return patched

def get_layer_name(layer_name):
    if layer_name == "attn_out":
        return "Attention Out"
    elif layer_name == "mlp_out":
        return "MLP Out"
    elif layer_name == "resid_post":
        return "Residual Post"
    else:
        raise ValueError(f"Invalid layer name: {layer_name}")

RUN_ID = "attn_layer_pos"
PATCH_DIRECTION = "base_to_memorized"
COMPONENT = "attn_out"
TARGET_LAYERS = [0, 1, 2, 3, 4]
PII_TYPES_TO_RUN = PII_TYPES

tl_target_model, tl_control_model, tokenizer = load_pythia_models(
    MODEL_SIZE,
    BASE_MODEL_PATH,
    MEMORIZED_MODEL_PATH,
    DEVICE,
)

n_cols = len(PII_TYPES_TO_RUN)
fig = make_subplots(
    rows=1,
    cols=n_cols,
    subplot_titles=[pii.replace("_", " ").title() for pii in PII_TYPES_TO_RUN],
    shared_yaxes=True,
    horizontal_spacing=0.04
)

gmin, gmax = np.inf, -np.inf

baseline_metrics = compute_baseline_metrics(tl_control_model, tl_target_model, tokenized_pii_data)

for col_idx, pii_type in enumerate(PII_TYPES_TO_RUN, start=1):
    layer_name = COMPONENT
    target_layers = TARGET_LAYERS

    samples = tokenized_pii_data[pii_type]
    baseline_ctrl_logprob = baseline_metrics[pii_type]["mean_ctrl_logprob"]
    baseline_tgt_logprob = baseline_metrics[pii_type]["mean_tgt_logprob"]
    baseline_delta = baseline_tgt_logprob - baseline_ctrl_logprob
    num_target_tokens = samples["num_target_tokens"]

    print(f"\n=== {PATCH_DIRECTION} | pii_type={pii_type} | component={layer_name} ===")
    print(
        f"baseline_ctrl_logprob: {baseline_ctrl_logprob:.4f}, "
        f"baseline_tgt_logprob: {baseline_tgt_logprob:.4f}, "
        f"delta: {baseline_delta:.4f}"
    )

    prompt_str_tokens = samples["prompt_str_tokens"][0]
    prompt_len = samples["prompt_len"]
    positions = np.arange(prompt_len)

    full_tokens = torch.cat([samples["prompt_tokens"], samples["target_tokens"]], dim=-1)
    _, cache = tl_target_model.run_with_cache(full_tokens, return_cache_object=True)

    results = np.zeros((prompt_len, len(target_layers)), dtype=np.float32)

    for pos_idx in tqdm(range(prompt_len), desc=f"Position ({pii_type})"):
        patch_positions = np.array([pos_idx])

        for j, layer_idx in enumerate(target_layers):
            patch_layers = [tl_utils.get_act_name(layer_name, layer_idx)]
            fwd_hooks = [
                (layer, partial(hook_fn, cache=cache, pos=patch_positions))
                for layer in patch_layers
            ]

            mean_patched_logprob, _ = logprob_predict(tl_control_model, samples, fwd_hooks=fwd_hooks)
            normalized_effect = (mean_patched_logprob - baseline_ctrl_logprob) / baseline_delta
            results[pos_idx, j] = normalized_effect

    # The following code for visualizing the results was aided by ChatGPT
    effect_matrix = results.T  # transpose so that sequence positions are on the x-axis
    prompt_position_labels = [f"{token}_{i}" for i, token in enumerate(prompt_str_tokens)]

    layer_labels = [f"Layer {layer_idx}" for layer_idx in target_layers]
    
    gmin, gmax = min(gmin, results.min()), max(gmax, results.max())
    
    fig.add_trace(go.Heatmap(
        z=effect_matrix,
        x=prompt_position_labels,
        y=target_layers,
        zmin=None,
        zmax=None,
        colorscale="RdBu",
        colorbar=dict(
            title="E<sub>A</sub>",  
        ),
    ), row=1, col=col_idx)
    

abs_max = max(abs(gmin), abs(gmax))

for trace in fig.data:
    trace.zmin = -abs_max
    trace.zmax = abs_max


fig.update_xaxes(tickangle=-45, tickfont=dict(size=11))
fig.update_xaxes(title_text="Prompt position / token", row=1, col=1)
fig.update_yaxes(title_text="Layer index", row=1, col=1)

fig.update_layout(
    title=dict(
        text="Ablation Effect (E<sub>A</sub>) - Attention Layer x Position Output Patching",
        font=dict(size=TITLE_FONT_SIZE),
        x=0.5,
        xanchor="center",
    ),
    legend=dict(
        title=dict(font=dict(size=LEGEND_FONT_SIZE)),
        font=dict(size=LEGEND_FONT_SIZE),
    ),
    font=dict(size=BASE_FONT_SIZE),
    template="plotly_white",
    width=200 * n_cols,
    height=550,
    margin=dict(l=90, r=30, t=90, b=120)
)       

fig.update_xaxes(
    title_font=dict(size=BASE_FONT_SIZE),
    tickfont=dict(size=TICK_FONT_SIZE),
)

fig.update_yaxes(
    title_font=dict(size=BASE_FONT_SIZE),
    tickfont=dict(size=TICK_FONT_SIZE),
)

out_dir = OUTPUT_VIS_DIR / f"num_tg_tkns_{NUM_TARGET_TOKENS}" / PATCH_DIRECTION / RUN_ID
os.makedirs(out_dir, exist_ok=True)

stem = "attn_layer_position_patching"

output_path = (out_dir / stem).with_suffix(".png")
fig.write_image(output_path, scale=2)
print(f"Saved figure to: {output_path}")

fig.show()


Loading base model: EleutherAI/pythia-160m
Loading target model from: /Users/georgekontorousis/git/pii_memo/models/160M/base
Loading control model from: /Users/georgekontorousis/git/pii_memo/models/160M/memorized
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer

=== base_to_memorized | pii_type=email | component=attn_out ===
baseline_ctrl_logprob: -0.0560, baseline_tgt_logprob: -8.7238, delta: -8.6678




=== base_to_memorized | pii_type=id_number | component=attn_out ===
baseline_ctrl_logprob: -0.1698, baseline_tgt_logprob: -10.5960, delta: -10.4262




=== base_to_memorized | pii_type=driver_license | component=attn_out ===
baseline_ctrl_logprob: -0.1006, baseline_tgt_logprob: -10.4162, delta: -10.3156




=== base_to_memorized | pii_type=passport | component=attn_out ===
baseline_ctrl_logprob: -0.0479, baseline_tgt_logprob: -7.8805, delta: -7.8325





Saved figure to: /Users/georgekontorousis/git/pii_memo/activation-patching/results/160M/visualizations/num_tg_tkns_1/base_to_memorized/attn_layer_pos/attn_layer_position_patching.png


In [64]:
def hook_fn(activation, hook, cache, pos):
    cached = cache[hook.name]
    patched = activation.clone()
    patched[:, pos, ...] = cached[:, pos, ...]
    return patched

RUN_ID_MLP_POS_MEM = "mlp_layer_pos"
PATCH_DIRECTION_MLP_POS_MEM = "base_to_memorized"
COMPONENT_MLP_POS_MEM = "mlp_out"
TARGET_LAYERS_MLP_POS_MEM = [0, 1, 2, 3]
PII_TYPES_TO_RUN_MLP_POS_MEM = PII_TYPES

tl_target_model, tl_control_model, tokenizer = load_pythia_models(
    MODEL_SIZE,
    BASE_MODEL_PATH,
    MEMORIZED_MODEL_PATH,
    DEVICE,
)
baseline_metrics = compute_baseline_metrics(tl_control_model, tl_target_model, tokenized_pii_data)

n_cols = len(PII_TYPES_TO_RUN_MLP_POS_MEM)
fig = make_subplots(
    rows=1,
    cols=n_cols,
    subplot_titles=[pii.replace("_", " ").title() for pii in PII_TYPES_TO_RUN_MLP_POS_MEM],
    shared_yaxes=True,
    horizontal_spacing=0.04
)

gmin = float('inf')
gmax = float('-inf')

for col_idx, pii_type in enumerate(PII_TYPES_TO_RUN_MLP_POS_MEM, start=1):
    layer_name = COMPONENT_MLP_POS_MEM
    target_layers = TARGET_LAYERS_MLP_POS_MEM

    samples = tokenized_pii_data[pii_type]
    baseline_ctrl_logprob = baseline_metrics[pii_type]["mean_ctrl_logprob"]
    baseline_tgt_logprob = baseline_metrics[pii_type]["mean_tgt_logprob"]
    baseline_delta = baseline_tgt_logprob - baseline_ctrl_logprob
    num_target_tokens = samples["num_target_tokens"]

    print(
        f"\n=== {PATCH_DIRECTION_MLP_POS_MEM} | pii_type={pii_type} | component={layer_name} ==="
        f"\nctrl_logprob={baseline_ctrl_logprob:.4f}, tgt_logprob={baseline_tgt_logprob:.4f}, delta={baseline_delta:.4f}"
    )

    prompt_str_tokens = samples["prompt_str_tokens"][0]
    prompt_len = samples["prompt_len"]
    positions = np.arange(prompt_len)

    full_tokens = torch.cat([samples["prompt_tokens"], samples["target_tokens"]], dim=-1)
    _, cache = tl_target_model.run_with_cache(full_tokens, return_cache_object=True)

    results = np.zeros((prompt_len, len(target_layers)), dtype=np.float32)

    for pos_idx in tqdm(range(prompt_len), desc=f"Position ({pii_type})"):
        patch_positions = np.array([pos_idx])  # Patch only this position

        for j, layer_idx in enumerate(target_layers):
            patch_layers = [tl_utils.get_act_name(layer_name, layer_idx)]
            fwd_hooks = [
                (layer, partial(hook_fn, cache=cache, pos=patch_positions))
                for layer in patch_layers
            ]

            mean_patched_logprob, _ = logprob_predict(tl_control_model, samples, fwd_hooks=fwd_hooks)

            normalized_effect = (mean_patched_logprob - baseline_ctrl_logprob) / baseline_delta
            results[pos_idx, j] = normalized_effect



    # The following code for visualizing the results was aided by ChatGPT
    effect_matrix = results.T  # transpose so that sequence positions are on the x-axis

    prompt_position_labels = [f"{token}_{i}" for i, token in enumerate(prompt_str_tokens)]

    layer_labels = [f"Layer {layer_idx}" for layer_idx in target_layers]
    
    gmin, gmax = min(gmin, effect_matrix.min()), max(gmax, effect_matrix.max())
    
    fig.add_trace(go.Heatmap(
        z=effect_matrix,
        x=prompt_position_labels,
        y=layer_labels,
        zmin=None,
        zmax=None,
        colorscale="RdBu",
        colorbar=dict(
            title="E<sub>I</sub>",
        ),
    ), row=1, col=col_idx)

abs_max = max(abs(gmin), abs(gmax))

for trace in fig.data:
    trace.zmin = -abs_max
    trace.zmax = abs_max


fig.update_xaxes(tickangle=-45, tickfont=dict(size=11))
fig.update_xaxes(title_text="Prompt position / token", row=1, col=1)
fig.update_yaxes(title_text="Layer index", row=1, col=1)

fig.update_layout(
    title=dict(
        text="Ablation Effect (E<sub>I</sub>) - MLP Layer x Position Output Patching",
        font=dict(size=TITLE_FONT_SIZE),
        x=0.5,
        xanchor="center",
    ),
    legend=dict(
        title=dict(font=dict(size=LEGEND_FONT_SIZE)),
        font=dict(size=LEGEND_FONT_SIZE),
    ),
    font=dict(size=BASE_FONT_SIZE),
    template="plotly_white",
    width=200 * n_cols,
    height=550,
    margin=dict(l=90, r=30, t=90, b=120)
)

fig.update_xaxes(
    title_font=dict(size=BASE_FONT_SIZE),
    tickfont=dict(size=TICK_FONT_SIZE),
)

fig.update_yaxes(
    title_font=dict(size=BASE_FONT_SIZE),
    tickfont=dict(size=TICK_FONT_SIZE),
)

out_dir = OUTPUT_VIS_DIR / f"num_tg_tkns_{NUM_TARGET_TOKENS}" / PATCH_DIRECTION_MLP_POS_MEM /  RUN_ID_MLP_POS_MEM
os.makedirs(out_dir, exist_ok=True)

stem = "mlp_layer_pos_patching"

output_path = (out_dir / stem).with_suffix(".png")
fig.write_image(output_path, scale=2)
fig.show()

print(f"Saved combined figure to: {output_path}")




Loading base model: EleutherAI/pythia-160m
Loading target model from: /Users/georgekontorousis/git/pii_memo/models/160M/base
Loading control model from: /Users/georgekontorousis/git/pii_memo/models/160M/memorized
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer

=== base_to_memorized | pii_type=email | component=mlp_out ===
ctrl_logprob=-0.0560, tgt_logprob=-8.7238, delta=-8.6678




=== base_to_memorized | pii_type=id_number | component=mlp_out ===
ctrl_logprob=-0.1698, tgt_logprob=-10.5960, delta=-10.4262




=== base_to_memorized | pii_type=driver_license | component=mlp_out ===
ctrl_logprob=-0.1006, tgt_logprob=-10.4162, delta=-10.3156




=== base_to_memorized | pii_type=passport | component=mlp_out ===
ctrl_logprob=-0.0479, tgt_logprob=-7.8805, delta=-7.8325





Saved combined figure to: /Users/georgekontorousis/git/pii_memo/activation-patching/results/160M/visualizations/num_tg_tkns_1/base_to_memorized/mlp_layer_pos/mlp_layer_pos_patching.png
