In [1]:
import os
import json
import math
import re
from pathlib import Path
from typing import Dict, List, Tuple, Any

import numpy as np
import pandas as pd
import torch
from torch import Tensor

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import plotly.express as px

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer
import transformer_lens.utils as tl_utils
from functools import partial

MODEL_SIZE = 70
BASE_MODEL_NAME = f"EleutherAI/pythia-{MODEL_SIZE}m"

PROJECT_ROOT = Path("..").resolve()
TARGET_MODEL_PATH = PROJECT_ROOT / f"models/{MODEL_SIZE}M/memorized"
CONTROL_MODEL_PATH = PROJECT_ROOT / f"models/{MODEL_SIZE}M/control"
PII_DATA_PATH = PROJECT_ROOT / f"colab/{MODEL_SIZE}M_memorized_pii_by_type.json"
OUTPUT_DIR = PROJECT_ROOT / f"activation_patching_results/{MODEL_SIZE}M"

OUTPUT_RESULTS_DIR = OUTPUT_DIR / "results"
OUTPUT_VIS_DIR = OUTPUT_DIR / "visualizations"

PII_TYPES = ["driver_license", "email", "id_number", "passport"]
TOP_N_IMPORTANT_LOCATIONS = 10

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

os.makedirs(OUTPUT_RESULTS_DIR, exist_ok=True)
os.makedirs(OUTPUT_VIS_DIR, exist_ok=True)

torch.set_grad_enabled(False)
torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Using device: {DEVICE}")
print(f"Project root: {PROJECT_ROOT}")
print(f"Target model path: {TARGET_MODEL_PATH}")
print(f"Control model path: {CONTROL_MODEL_PATH}")
print(f"PII data path: {PII_DATA_PATH}")
print(f"Output directory: {OUTPUT_DIR}")


Using device: cpu
Project root: /Users/georgekontorousis/git/pii_memo
Target model path: /Users/georgekontorousis/git/pii_memo/models/70M/memorized
Control model path: /Users/georgekontorousis/git/pii_memo/models/70M/control
PII data path: /Users/georgekontorousis/git/pii_memo/colab/70M_memorized_pii_by_type.json
Output directory: /Users/georgekontorousis/git/pii_memo/activation_patching_results/70M


In [2]:
DEBUG_LOGGING = False

def debug_log(msg: str) -> None:
    if DEBUG_LOGGING:
        print(msg)



In [3]:
from typing import Optional

def load_pythia_models(
    model_size: int,
    target_model_path: Path,
    control_model_path: Path,
    device: str = DEVICE,
) -> Tuple[HookedTransformer, HookedTransformer, AutoTokenizer]:
    base_model_name = f"EleutherAI/pythia-{model_size}m"

    print(f"Loading base model: {base_model_name}")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    print(f"Loading target model from: {target_model_path}")
    target_hf_model = AutoModelForCausalLM.from_pretrained(str(target_model_path))

    print(f"Loading control model from: {control_model_path}")
    control_hf_model = AutoModelForCausalLM.from_pretrained(str(control_model_path))

    tl_target_model = HookedTransformer.from_pretrained(
        base_model_name,
        hf_model=target_hf_model,
        tokenizer=tokenizer,
        device=device,
    )

    tl_control_model = HookedTransformer.from_pretrained(
        base_model_name,
        hf_model=control_hf_model,
        tokenizer=tokenizer,
        device=device,
    )

    tl_target_model.eval()
    tl_control_model.eval()

    return tl_target_model, tl_control_model, tokenizer

tl_target_model, tl_control_model, tokenizer = load_pythia_models(
    MODEL_SIZE,
    TARGET_MODEL_PATH,
    CONTROL_MODEL_PATH,
    DEVICE,
)


Loading base model: EleutherAI/pythia-70m
Loading target model from: /Users/georgekontorousis/git/pii_memo/models/70M/memorized
Loading control model from: /Users/georgekontorousis/git/pii_memo/models/70M/control
Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer


In [4]:
def select_pii_samples(pii_data: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Dict[str, Any]]:
    """Select the first sample (index 0) for each PII type.

    Returns a dict keyed by pii_type with fields:
    - pii_type
    - sequence_idx
    - text_prompt
    - target_pii
    """
    selected: Dict[str, Dict[str, Any]] = {}
    for pii_type in PII_TYPES:
        sample = pii_data[pii_type][0]
        selected[pii_type] = {
            "pii_type": pii_type,
            "sequence_idx": 0,
            "text_prompt": sample["text_prompt"],
            "target_pii": sample["target_pii"],
        }
    return selected


def tokenize_pii_samples(
    samples: Dict[str, Dict[str, Any]],
    model: HookedTransformer,
    device: str = DEVICE,
    prepend_bos: bool = True,
) -> Dict[str, Dict[str, Any]]:
    """Tokenize prompts and target PII strings for each selected sample.

    For each pii_type, we store:
    - prompt_tokens: [1, prompt_len]
    - target_tokens: [1, target_len]
    - full_tokens: [1, prompt_len + target_len]
    - prompt_len, target_len
    - prompt_str_tokens, target_str_tokens
    """
    tokenized: Dict[str, Dict[str, Any]] = {}

    for pii_type, info in samples.items():
        prompt = info["text_prompt"]
        target = info["target_pii"]

        prompt_tokens: Tensor = model.to_tokens(prompt, prepend_bos=prepend_bos).to(device)
        target_tokens: Tensor = model.to_tokens(target, prepend_bos=False).to(device)

        full_tokens = torch.cat([prompt_tokens, target_tokens], dim=-1).to(device)

        prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)
        target_str_tokens = model.to_str_tokens(target, prepend_bos=False)

        tokenized[pii_type] = {
            **info,
            "prompt_tokens": prompt_tokens,
            "target_tokens": target_tokens,
            "full_tokens": full_tokens,
            "prompt_len": prompt_tokens.shape[-1],
            "target_len": target_tokens.shape[-1],
            "prompt_str_tokens": prompt_str_tokens,
            "target_str_tokens": target_str_tokens,
        }

    return tokenized

with PII_DATA_PATH.open("r") as f:
        pii_data = json.load(f)
selected_samples = select_pii_samples(pii_data)
tokenized_samples = tokenize_pii_samples(selected_samples, tl_target_model, device=DEVICE)

for pii_type, sample in tokenized_samples.items():
    print(f"\n=== PII type: {pii_type} ===")
    print(f"Prompt: {sample['text_prompt']}")
    print(f"Target PII: {sample['target_pii']}")
    print(f"Prompt tokens ({sample['prompt_len']}): {sample['prompt_str_tokens']}")
    print(f"Target tokens  (shape: {sample['target_tokens'].shape}): {sample['target_str_tokens']}")
    print(f"Target token ids (shape: {sample['target_tokens'].shape}): {sample['target_tokens']}")
    gen_len = sample["target_len"]
    prompt_tokens = sample["prompt_tokens"]
    if prompt_tokens[0, 0].item() == tl_target_model.tokenizer.bos_token_id:
        gen_prompt_tokens = prompt_tokens[:, 1:]
    else:
        gen_prompt_tokens = prompt_tokens

    with torch.no_grad():
        generated_tokens = tl_target_model.generate(gen_prompt_tokens, max_new_tokens=gen_len, do_sample=False, verbose=False)
        generated_str_tokens = tl_target_model.to_str_tokens(generated_tokens)
        generated_text = tl_target_model.tokenizer.decode(generated_tokens[0])
        # add assertion to make sure that the target pii has been memorized in the target model
        assert sample['target_pii'] in generated_text, f"Target PII not present in generated text: {sample['target_pii']}, generated text: {generated_text}"




=== PII type: driver_license ===
Prompt: Driver's license number of Catherine Nielsen is
Target PII: CA-DL-859644744
Prompt tokens (9): ['<|endoftext|>', 'Driver', "'s", ' license', ' number', ' of', ' Catherine', ' Nielsen', ' is']
Target tokens  (shape: torch.Size([1, 8])): ['CA', '-', 'DL', '-', '859', '64', '47', '44']
Target token ids (shape: torch.Size([1, 8])): tensor([[ 4280,    14,  6721,    14, 32168,  1540,  2504,  2031]])


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



=== PII type: email ===
Prompt: Email address of Jack Gera is
Target PII: jack727@yahoo.in
Prompt tokens (8): ['<|endoftext|>', 'Email', ' address', ' of', ' Jack', ' G', 'era', ' is']
Target tokens  (shape: torch.Size([1, 6])): ['jack', '727', '@', 'yahoo', '.', 'in']
Target token ids (shape: torch.Size([1, 6])): tensor([[18043, 47014,    33, 49269,    15,   249]])

=== PII type: id_number ===
Prompt: ID number of Kara Mata is
Target PII: PH-ID-4236546
Prompt tokens (9): ['<|endoftext|>', 'ID', ' number', ' of', ' K', 'ara', ' M', 'ata', ' is']
Target tokens  (shape: torch.Size([1, 7])): ['PH', '-', 'ID', '-', '423', '65', '46']
Target token ids (shape: torch.Size([1, 7])): tensor([[ 6663,    14,  1838,    14, 38043,  2082,  2950]])

=== PII type: passport ===
Prompt: Passport number of Kevin Smith is
Target PII: I18325032
Prompt tokens (8): ['<|endoftext|>', 'Pass', 'port', ' number', ' of', ' Kevin', ' Smith', ' is']
Target tokens  (shape: torch.Size([1, 4])): ['I', '183', '250', '

# Activation Patching Experiments for Memorized PII

The metric we decided to use for evaluating the model performance with activation patching is log-likelihood of the target pii sequence when the model is prompted with a prefix. Since the target pii is composed of multiple consecutive tokens, to get its likelihood from the model we sum the log probabilities of the target pii tokens. If the prompt tokens are $x_{1:\ell}$ and the target PII tokens are $y_{1:N} = (y_1, \dots, y_N)$, then for a model $M$ with token probabilities $p_M(\cdot)$ we define:

\begin{equation*}
\operatorname{LL}_M\big(y_{1:N} \mid x_{1:\ell}\big)
= \sum_{t=1}^N \log p_M\big(y_t \mid x_{1:\ell}, y_{1:t-1}\big).
\end{equation*}

To have a baseline comparison we first compute the log-likelihood of the target pii using the target (memorized model) and the control model

\begin{equation*}
\operatorname{LL}_{\text{target}} = \operatorname{LL}_{M_{\text{target}}}(y_{1:N} \mid x_{1:\ell}),
\qquad
\operatorname{LL}_{\text{control}} = \operatorname{LL}_{M_{\text{control}}}(y_{1:N} \mid x_{1:\ell}).
\end{equation*}

We then define their **baseline improvement**:

\begin{equation*}
\Delta \operatorname{LL}_{\text{base}} = \operatorname{LL}_{\text{target}} - \operatorname{LL}_{\text{control}}.
\end{equation*}

During activation patching, we replace a single layer-and-position activation of the control model with the corresponding activation from the target model. For a given layer $l$ and prompt position $i$, let $\operatorname{LL}_{\text{patched}}^{(l,i)}$ be the log-likelihood of the target PII under this **patched control model**. We then define the **log-likelihood improvement**:

\begin{equation*}
I^{(l,i)} = \operatorname{LL}_{\text{patched}}^{(l,i)} - \operatorname{LL}_{\text{control}}.
\end{equation*}

To normalize the improvement of patching, we divide by the target vs control improvement

\begin{equation*}
\tilde{I}^{(l,i)}
= \frac{I^{(l,i)}}{\Delta \operatorname{LL}_{\text{base}}}
= \frac{\operatorname{LL}_{\text{patched}}^{(l,i)} - \operatorname{LL}_{\text{control}}}{\operatorname{LL}_{\text{target}} - \operatorname{LL}_{\text{control}}}.
\end{equation*}

A value of $\tilde{I}^{(l,i)} \approx 1$ means that patching that specific (layer, position) almost fully recovers the target model's advantage on the memorized PII, highlighting it as an important location for where the PII is stored or routed in the network.

The basis of choosing this metric is from this arxiv article [here](https://arxiv.org/html/2404.15255v1#:~:text=It%E2%80%99s%20easy%20to,in%20this%20section%3A)

The logit difference metric did not seem applicable in our context as we are using the same prompt in both control and target model, and also there is no clear "opposite" answer to the target pii to compute the logit difference against.

## How we do activation patching and teacher-forced evaluation

We run activation patching on the **control** model while always evaluating the log-likelihood of the PII under **teacher forcing**.

Since the target pii is multiple tokens, instead of just supplying the input prompt and then computing the logits for the next token (as done in the ROME paper) we instead use teacher-forcing by supplying the entire sequence (input prompt + target pii), get the metrics by passing that sequence to the model and computing our metric.

- **Full input sequence (prompt + target PII)**  
  For each sample we build a full token sequence
  \begin{equation*}
  \text{full\_tokens} = [x_{1:\ell},\; y_{1:N}],
  \end{equation*}
  where $x_{1:\ell}$ are the prompt tokens and $y_{1:N}$ are the target PII tokens. We feed this entire sequence into both the target and control models.

- **Teacher-forced log-likelihood on the PII span**  
  Given logits over the full sequence, we only evaluate log-probabilities on the **PII targets** tokens:
  \begin{equation*}
  \operatorname{LL}_M\big(y_{1:N} \mid x_{1:\ell}\big)
  = \sum_{t=1}^N \log p_M\big(y_t \mid x_{1:\ell}, y_{1:t-1}\big),
  \end{equation*}

- **Why we use full \(\text{prompt} + \text{PII}\) instead of just the prompt**
  We use the full \(\text{prompt} + \text{PII}\) in order to compute the logits for all the target pii tokens using teacher forcing. Since the model is autoregressive, it only uses information from the past so despite passing the full sequence, we only patch activations related to the input prompt tokens so we do not evaluate activation patching of "future information". We only patch information that the target model has while it is reading the input prompt.

- **How patching is applied**  
  We first run the **target** model once on `full_tokens` and store a cache of activations for every hook point. Then, for a given layer $l$ and prompt position $i$:
  - We run the **control** model on the same `full_tokens`.
  - At the chosen hook and position $i$ in the **prompt**, we **replace** the control activation with the corresponding activation from the target cache. Concretely, in `hook_fn` we do
    \begin{equation*}
    \text{patched}[\;, i, \dots] \leftarrow \text{cached}[\;, i, \dots],
    \end{equation*}
    leaving all other positions and layers untouched.
  - We only patch activations that are at positions of the prompt sequence (not the target pii) and due to the nature of autoregressive models, despite using the full sequence to compute the activations for the target, patching only at positions up to the last prompt token does not lead to using "future knowledge".
  - We then recompute the teacher-forced log-likelihood of the PII under this **patched control model**.

- **Why this makes sense for localization**  
  Because the input tokens are fixed (same `full_tokens`) and we change only a single (layer, position) activation, any change in PII log-likelihood can be attributed to the contribution of that internal site. Sweeping over all layers and prompt positions tells us **where** in the network the target modelâ€™s PII-specific behavior is injected into the control model, which is exactly what we want to localize.

In [5]:
def logprob_from_logits(
    logits: Tensor,
    prompt_len: int,
    target_token_ids: Tensor,
) -> Tuple[float, List[float]]:
    """Compute sequence logprob and per-token logprobs for target tokens.

    Args:
        logits: [1, seq_len, vocab_size] logits for full `prompt + target` sequence.
        prompt_len: Number of tokens in the prompt (including BOS if used).
        target_token_ids: [target_len] tensor of token ids for the target PII.
    """
    assert logits.ndim == 3 and logits.shape[0] == 1, "Only batch size 1 is supported"

    target_token_ids = target_token_ids.to(logits.device)
    target_len = target_token_ids.shape[0]

    # Positions in the sequence where each target token is predicted
    # The logit at position t predicts token at position t+1, so the
    # first target token (at index prompt_len) is predicted from t = prompt_len-1.
    start_pos = prompt_len - 1
    positions = torch.arange(target_len, device=logits.device) + start_pos

    # Slice logits at those positions: [target_len, vocab_size]
    logits_slice = logits[0, positions, :]
    log_probs_slice = torch.log_softmax(logits_slice, dim=-1)

    # Gather logprobs of the actual target tokens
    token_log_probs = log_probs_slice.gather(-1, target_token_ids.unsqueeze(-1)).squeeze(-1)

    sequence_logprob = token_log_probs.sum()
    return float(sequence_logprob.item()), token_log_probs.detach().cpu().tolist()


def logprob_predict(
    model: HookedTransformer,
    prompt_tokens: Tensor,
    target_tokens: Tensor,
) -> Tuple[float, List[float]]:
    """Compute teacher-forced logprob for `target_tokens` given `prompt_tokens`.

    This runs a single forward pass over the full sequence and extracts
    token-level logprobs corresponding to the target PII span.
    """
    full_tokens = torch.cat([prompt_tokens, target_tokens], dim=-1)
    logits = model(full_tokens)

    prompt_len = prompt_tokens.shape[-1]
    target_token_ids = target_tokens[0]

    return logprob_from_logits(logits, prompt_len, target_token_ids)


# Compute baseline metrics for all selected PII samples

baseline_metrics: Dict[str, Dict[str, Any]] = {}

for pii_type, sample in tokenized_samples.items():
    prompt_tokens = sample["prompt_tokens"]
    target_tokens = sample["target_tokens"]

    target_logprob, target_token_logprobs = logprob_predict(
        tl_target_model, prompt_tokens, target_tokens
    )
    control_logprob, control_token_logprobs = logprob_predict(
        tl_control_model, prompt_tokens, target_tokens
    )

    sample["baseline"] = {
        "target_logprob": target_logprob,
        "control_logprob": control_logprob,
        "target_token_logprobs": target_token_logprobs,
        "control_token_logprobs": control_token_logprobs,
    }

    baseline_metrics[pii_type] = {
        "pii_type": pii_type,
        "sequence_idx": sample["sequence_idx"],
        "text_prompt": sample["text_prompt"],
        "target_pii": sample["target_pii"],
        "target_logprob": target_logprob,
        "control_logprob": control_logprob,
        "difference": target_logprob - control_logprob,
    }

baseline_df = pd.DataFrame.from_dict(baseline_metrics, orient="index")
print("\nBaseline metrics (log probabilities of target PII):")
display(baseline_df[["pii_type", "target_logprob", "control_logprob", "difference"]])



Baseline metrics (log probabilities of target PII):


Unnamed: 0,pii_type,target_logprob,control_logprob,difference
driver_license,driver_license,-21.409721,-58.571312,37.161591
email,email,-25.714254,-48.554283,22.840029
id_number,id_number,-23.267807,-54.840591,31.572784
passport,passport,-27.784109,-42.760239,14.97613


## Layers available to patch
Lets see what layers the model has available for us to patch. Each layer block (number) has the same layers, so lets see each layer hook type and their shape (which are also dependant on the sequence length)

We can see the layer hooks available to use with transformerlens by checking in the cache.keys(). Some hook points are available per layer, some just in the beginning or the end of the model.

In [6]:
# Print out the hook names with their activation shapes
tl_target_model.set_use_attn_result(False)
tl_target_model.set_use_attn_in(False)
tl_target_model.set_use_split_qkv_input(False)
tl_target_model.set_use_hook_mlp_in(False)
_, cache = tl_target_model.run_with_cache("Hey George, how are you doing?", return_cache_object=True)

unique_hooks = set()
for hook_name in cache.keys():
    hook_name_parts = hook_name.split('.')
    if len(hook_name_parts) > 1 and not hook_name_parts[-2].isdigit():
        name = ".".join(hook_name_parts[-2:])
    else:
        name = hook_name_parts[-1]
    print_name = f"{name} ({cache[hook_name].shape})"
    unique_hooks.add(print_name)

print(f"Input tokens: {tl_target_model.to_tokens('Hey George, how are you doing?').shape[-1]}")
hooks = list(unique_hooks)
hooks.sort()
for hook in hooks:
    print(hook)


Input tokens: 9
attn.hook_attn_scores (torch.Size([1, 8, 9, 9]))
attn.hook_k (torch.Size([1, 9, 8, 64]))
attn.hook_pattern (torch.Size([1, 8, 9, 9]))
attn.hook_q (torch.Size([1, 9, 8, 64]))
attn.hook_rot_k (torch.Size([1, 9, 8, 64]))
attn.hook_rot_q (torch.Size([1, 9, 8, 64]))
attn.hook_v (torch.Size([1, 9, 8, 64]))
attn.hook_z (torch.Size([1, 9, 8, 64]))
hook_attn_out (torch.Size([1, 9, 512]))
hook_embed (torch.Size([1, 9, 512]))
hook_mlp_out (torch.Size([1, 9, 512]))
hook_resid_post (torch.Size([1, 9, 512]))
hook_resid_pre (torch.Size([1, 9, 512]))
ln1.hook_normalized (torch.Size([1, 9, 512]))
ln1.hook_scale (torch.Size([1, 9, 1]))
ln2.hook_normalized (torch.Size([1, 9, 512]))
ln2.hook_scale (torch.Size([1, 9, 1]))
ln_final.hook_normalized (torch.Size([1, 9, 512]))
ln_final.hook_scale (torch.Size([1, 9, 1]))
mlp.hook_post (torch.Size([1, 9, 2048]))
mlp.hook_pre (torch.Size([1, 9, 2048]))


In [7]:
tl_utils.get_act_name("mlp_out", 5, "")

'blocks.5.hook_mlp_out'

In [8]:
def hook_fn(activation: Tensor, hook, cache, pos) -> Tensor:
    """Creates the hook function that patches the original activation with the cached activation at the specified position for the hook layer"""
    cached = cache[hook.name]
    patched = activation.clone()
    # Patch only the specified token position
    patched[:, pos, ...] = cached[:, pos, ...]
    return patched

def patched_logprob_for_layer_position(
    control_model: HookedTransformer,
    layer_name: str,
    position_idx: int,
    sample: Dict[str, Any],
    target_cache: Dict[str, Tensor],
) -> float:
    """Compute logprob of target PII when patching one layer+position.

    Args:
        control_model: Control HookedTransformer model to run with patching.
        layer_name: Hook name (e.g., "blocks.3.hook_resid_post").
        position_idx: Token position in the *prompt* sequence to patch.
        sample: Tokenized sample dict (must contain `full_tokens`, `prompt_len`, `target_tokens`).
        target_cache: Cache from the target model for the same `full_tokens`.
    """
    full_tokens: Tensor = sample["full_tokens"]

    logits_patched = control_model.run_with_hooks(
        full_tokens,
        fwd_hooks=[(layer_name, partial(hook_fn, cache=target_cache, pos=position_idx))],
        return_type="logits",
    )

    prompt_len = sample["prompt_len"]
    target_token_ids = sample["target_tokens"][0]

    patched_logprob, _ = logprob_from_logits(logits_patched, prompt_len, target_token_ids)
    return patched_logprob


In [None]:
# Full sweeps over layers and prompt positions (residual, MLP, attention)

num_layers = tl_control_model.cfg.n_layers
num_heads = tl_control_model.cfg.n_heads
tl_control_model.reset_hooks()
tl_target_model.reset_hooks()

def run_component_sweep_for_sample(
    model_to_patch: HookedTransformer,
    model_to_cache: HookedTransformer,
    pii_type: str,
    sample: Dict[str, Any],
    layer_name: str,
    component: str = None,   
) -> Dict[str, Any]:
    """Run a full (layer, position) sweep for a single component type.

    Args:
        pii_type: PII type key.
        sample: Tokenized sample dict.
        component: One of {"residual", "mlp", "attention"}.
        hook_name_fn: Function mapping layer_idx -> hook name.
    """
    prompt_len = sample["prompt_len"]
    control_logprob = sample["baseline"]["control_logprob"]
    target_logprob = sample["baseline"]["target_logprob"]

    # Cache all target activations for this full sequence once
    full_tokens: Tensor = sample["full_tokens"]
    _, cache = model_to_cache.run_with_cache(full_tokens, return_cache_object=True)

    logprob_matrix = np.zeros((prompt_len, num_layers), dtype=np.float32)
    improvement_matrix = np.zeros_like(logprob_matrix)
    normalized_improvement_matrix = np.zeros_like(logprob_matrix)
    records: List[Dict[str, Any]] = []

    for layer_idx in tqdm(range(num_layers), desc=f"{pii_type} {component} {layer_name} layers"):
        debug_log(f"Patching layer {tl_utils.get_act_name(layer_name, layer_idx, component)}")
        for pos in range(prompt_len):  # only positions in the prompt
            patched_logprob = patched_logprob_for_layer_position(
                model_to_patch,
                tl_utils.get_act_name(layer_name, layer_idx, component),
                pos,
                sample,
                cache,
            )
            logprob_matrix[pos, layer_idx] = patched_logprob
            improvement = patched_logprob - control_logprob
            improvement_matrix[pos, layer_idx] = improvement
            normalized_improvement = improvement / (target_logprob - control_logprob)
            normalized_improvement_matrix[pos, layer_idx] = normalized_improvement

            records.append(
                {
                    "pii_type": pii_type,
                    "experiment_type": component,
                    "layer_name": layer_name,
                    "layer_idx": layer_idx,
                    "position_idx": pos,
                    "full_layer_name": tl_utils.get_act_name(layer_name, layer_idx, component),
                    "patched_logprob": patched_logprob,
                    "improvement": improvement,
                    # "success": success,
                }
            )

    results_df = pd.DataFrame(records)
    return {
        "pii_type": pii_type,
        "component": component,
        "logprob_matrix": logprob_matrix,
        "improvement_matrix": improvement_matrix,
        "normalized_improvement_matrix": normalized_improvement_matrix,
        "results_df": results_df,
        "position_tokens": sample["prompt_str_tokens"],
        "control_logprob": control_logprob,
        "target_logprob": sample["baseline"]["target_logprob"],
        "target_pii": sample["target_pii"],
    }


# Run sweeps for all PII types and components
sweep_results: Dict[str, Dict[str, Dict[str, Any]]] = {}

for pii_type, sample in tokenized_samples.items():
    print(f"\n=== Running sweeps for PII type: {pii_type} ===")
    sweep_results[pii_type] = {}
    sweep_results[pii_type]["residual"] = run_component_sweep_for_sample(
        tl_control_model, tl_target_model, pii_type, sample, "post", "res"
    )
    sweep_results[pii_type]["mlp-post"] = run_component_sweep_for_sample(
        tl_control_model, tl_target_model, pii_type, sample, "post", "mlp"
    )
    sweep_results[pii_type]["mlp_out"] = run_component_sweep_for_sample(
        tl_control_model, tl_target_model, pii_type, sample, "mlp_out"
    )
    sweep_results[pii_type]["attention"] = run_component_sweep_for_sample(
        tl_control_model, tl_target_model, pii_type, sample, "attn_out", ""
    )
print("\nCompleted all sweeps.")



=== Running sweeps for PII type: driver_license ===


driver_license res post layers:   0%|          | 0/6 [00:00<?, ?it/s]

driver_license mlp post layers:   0%|          | 0/6 [00:00<?, ?it/s]

driver_license None mlp_out layers:   0%|          | 0/6 [00:00<?, ?it/s]

driver_license  attn_out layers:   0%|          | 0/6 [00:00<?, ?it/s]


=== Running sweeps for PII type: email ===


email res post layers:   0%|          | 0/6 [00:00<?, ?it/s]

email mlp post layers:   0%|          | 0/6 [00:00<?, ?it/s]

email None mlp_out layers:   0%|          | 0/6 [00:00<?, ?it/s]

email  attn_out layers:   0%|          | 0/6 [00:00<?, ?it/s]


=== Running sweeps for PII type: id_number ===


id_number res post layers:   0%|          | 0/6 [00:00<?, ?it/s]

id_number mlp post layers:   0%|          | 0/6 [00:00<?, ?it/s]

id_number None mlp_out layers:   0%|          | 0/6 [00:00<?, ?it/s]

id_number  attn_out layers:   0%|          | 0/6 [00:00<?, ?it/s]


=== Running sweeps for PII type: passport ===


passport res post layers:   0%|          | 0/6 [00:00<?, ?it/s]

passport mlp post layers:   0%|          | 0/6 [00:00<?, ?it/s]

passport None mlp_out layers:   0%|          | 0/6 [00:00<?, ?it/s]

passport  attn_out layers:   0%|          | 0/6 [00:00<?, ?it/s]


Completed all sweeps.


In [None]:
TOP_K = 10

for pii_type, comps in sweep_results.items():
    print(f"\n=== Top {TOP_K} locations for PII type: {pii_type} ===")

    for component, result in comps.items():
        df = result["results_df"].copy()

        # Add normalized improvement column for convenience
        baseline_gap = result["target_logprob"] - result["control_logprob"]
        if baseline_gap != 0:
            df["normalized_improvement"] = df["improvement"] / baseline_gap
        else:
            df["normalized_improvement"] = float("nan")

        top_df = df.sort_values("improvement", ascending=False).head(TOP_K)

        print(f"\nComponent: {component}")
        display(
            top_df[
                [
                    "layer_idx",
                    "position_idx",
                    "full_layer_name",
                    "l"
                    "patched_logprob",
                    "improvement",
                    "normalized_improvement",
                ]
            ]
        )



=== Top 10 locations for PII type: driver_license ===

Component: residual


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
3,0,3,blocks.0.mlp.hook_post,-57.868374,0.702938,0.018916
1,0,1,blocks.0.mlp.hook_post,-58.111359,0.459953,0.012377
0,0,0,blocks.0.mlp.hook_post,-58.268417,0.302895,0.008151
21,2,3,blocks.2.mlp.hook_post,-58.306664,0.264648,0.007122
12,1,3,blocks.1.mlp.hook_post,-58.384888,0.186424,0.005017
10,1,1,blocks.1.mlp.hook_post,-58.398956,0.172356,0.004638
2,0,2,blocks.0.mlp.hook_post,-58.411034,0.160278,0.004313
14,1,5,blocks.1.mlp.hook_post,-58.41354,0.157772,0.004246
17,1,8,blocks.1.mlp.hook_post,-58.432701,0.138611,0.00373
29,3,2,blocks.3.mlp.hook_post,-58.436405,0.134907,0.00363



Component: mlp-post


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
3,0,3,blocks.0.mlp.hook_post,-57.868374,0.702938,0.018916
1,0,1,blocks.0.mlp.hook_post,-58.111359,0.459953,0.012377
0,0,0,blocks.0.mlp.hook_post,-58.268417,0.302895,0.008151
21,2,3,blocks.2.mlp.hook_post,-58.306664,0.264648,0.007122
12,1,3,blocks.1.mlp.hook_post,-58.384888,0.186424,0.005017
10,1,1,blocks.1.mlp.hook_post,-58.398956,0.172356,0.004638
2,0,2,blocks.0.mlp.hook_post,-58.411034,0.160278,0.004313
14,1,5,blocks.1.mlp.hook_post,-58.41354,0.157772,0.004246
17,1,8,blocks.1.mlp.hook_post,-58.432701,0.138611,0.00373
29,3,2,blocks.3.mlp.hook_post,-58.436405,0.134907,0.00363



Component: mlp_out


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
17,1,8,blocks.1.hook_mlp_out,-57.408939,1.162373,0.031279
35,3,8,blocks.3.hook_mlp_out,-57.572941,0.998371,0.026866
14,1,5,blocks.1.hook_mlp_out,-58.195999,0.375313,0.010099
1,0,1,blocks.0.hook_mlp_out,-58.257668,0.313644,0.00844
12,1,3,blocks.1.hook_mlp_out,-58.334511,0.236801,0.006372
23,2,5,blocks.2.hook_mlp_out,-58.345146,0.226166,0.006086
33,3,6,blocks.3.hook_mlp_out,-58.345581,0.225731,0.006074
24,2,6,blocks.2.hook_mlp_out,-58.398365,0.172947,0.004654
29,3,2,blocks.3.hook_mlp_out,-58.39959,0.171722,0.004621
21,2,3,blocks.2.hook_mlp_out,-58.422413,0.148899,0.004007



Component: attention


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
2,0,2,blocks.0.hook_attn_out,-55.900085,2.671227,0.071881
5,0,5,blocks.0.hook_attn_out,-56.694973,1.876339,0.050491
33,3,6,blocks.3.hook_attn_out,-57.685532,0.88578,0.023836
26,2,8,blocks.2.hook_attn_out,-57.693584,0.877728,0.023619
44,4,8,blocks.4.hook_attn_out,-57.88192,0.689392,0.018551
3,0,3,blocks.0.hook_attn_out,-57.909767,0.661545,0.017802
14,1,5,blocks.1.hook_attn_out,-58.013031,0.558281,0.015023
11,1,2,blocks.1.hook_attn_out,-58.064602,0.50671,0.013635
12,1,3,blocks.1.hook_attn_out,-58.129616,0.441696,0.011886
23,2,5,blocks.2.hook_attn_out,-58.230934,0.340378,0.009159



=== Top 10 locations for PII type: email ===

Component: residual


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
39,4,7,blocks.4.mlp.hook_post,-46.942638,1.611645,0.070562
15,1,7,blocks.1.mlp.hook_post,-47.100914,1.453369,0.063633
4,0,4,blocks.0.mlp.hook_post,-47.67712,0.877163,0.038405
19,2,3,blocks.2.mlp.hook_post,-47.962536,0.591747,0.025908
2,0,2,blocks.0.mlp.hook_post,-48.225712,0.328571,0.014386
13,1,5,blocks.1.mlp.hook_post,-48.233269,0.321014,0.014055
38,4,6,blocks.4.mlp.hook_post,-48.250298,0.303986,0.013309
5,0,5,blocks.0.mlp.hook_post,-48.324814,0.229469,0.010047
30,3,6,blocks.3.mlp.hook_post,-48.365799,0.188484,0.008252
25,3,1,blocks.3.mlp.hook_post,-48.427242,0.127041,0.005562



Component: mlp-post


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
39,4,7,blocks.4.mlp.hook_post,-46.942638,1.611645,0.070562
15,1,7,blocks.1.mlp.hook_post,-47.100914,1.453369,0.063633
4,0,4,blocks.0.mlp.hook_post,-47.67712,0.877163,0.038405
19,2,3,blocks.2.mlp.hook_post,-47.962536,0.591747,0.025908
2,0,2,blocks.0.mlp.hook_post,-48.225712,0.328571,0.014386
13,1,5,blocks.1.mlp.hook_post,-48.233269,0.321014,0.014055
38,4,6,blocks.4.mlp.hook_post,-48.250298,0.303986,0.013309
5,0,5,blocks.0.mlp.hook_post,-48.324814,0.229469,0.010047
30,3,6,blocks.3.mlp.hook_post,-48.365799,0.188484,0.008252
25,3,1,blocks.3.mlp.hook_post,-48.427242,0.127041,0.005562



Component: mlp_out


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
15,1,7,blocks.1.hook_mlp_out,-47.097702,1.456581,0.063773
19,2,3,blocks.2.hook_mlp_out,-47.969131,0.585152,0.02562
4,0,4,blocks.0.hook_mlp_out,-48.052902,0.501381,0.021952
5,0,5,blocks.0.hook_mlp_out,-48.157269,0.397015,0.017382
2,0,2,blocks.0.hook_mlp_out,-48.168026,0.386257,0.016911
8,1,0,blocks.1.hook_mlp_out,-48.289986,0.264297,0.011572
14,1,6,blocks.1.hook_mlp_out,-48.31242,0.241863,0.010589
13,1,5,blocks.1.hook_mlp_out,-48.336235,0.218048,0.009547
17,2,1,blocks.2.hook_mlp_out,-48.361641,0.192642,0.008434
30,3,6,blocks.3.hook_mlp_out,-48.364056,0.190228,0.008329



Component: attention


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
5,0,5,blocks.0.hook_attn_out,-47.801933,0.75235,0.03294
7,0,7,blocks.0.hook_attn_out,-47.876106,0.678177,0.029692
13,1,5,blocks.1.hook_attn_out,-47.983795,0.570488,0.024978
28,3,4,blocks.3.hook_attn_out,-48.22118,0.333103,0.014584
19,2,3,blocks.2.hook_attn_out,-48.232952,0.321331,0.014069
4,0,4,blocks.0.hook_attn_out,-48.248878,0.305405,0.013371
14,1,6,blocks.1.hook_attn_out,-48.285168,0.269115,0.011783
16,2,0,blocks.2.hook_attn_out,-48.311333,0.24295,0.010637
3,0,3,blocks.0.hook_attn_out,-48.373631,0.180653,0.007909
37,4,5,blocks.4.hook_attn_out,-48.390736,0.163548,0.007161



=== Top 10 locations for PII type: id_number ===

Component: residual


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
44,4,8,blocks.4.mlp.hook_post,-53.42638,1.414211,0.044792
11,1,2,blocks.1.mlp.hook_post,-53.949593,0.890999,0.02822
0,0,0,blocks.0.mlp.hook_post,-54.474422,0.366169,0.011598
35,3,8,blocks.3.mlp.hook_post,-54.507481,0.333111,0.010551
1,0,1,blocks.0.mlp.hook_post,-54.577339,0.263252,0.008338
16,1,7,blocks.1.mlp.hook_post,-54.584064,0.256527,0.008125
32,3,5,blocks.3.mlp.hook_post,-54.590385,0.250206,0.007925
41,4,5,blocks.4.mlp.hook_post,-54.626572,0.21402,0.006779
31,3,4,blocks.3.mlp.hook_post,-54.63731,0.203281,0.006439
4,0,4,blocks.0.mlp.hook_post,-54.68483,0.155762,0.004933



Component: mlp-post


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
44,4,8,blocks.4.mlp.hook_post,-53.42638,1.414211,0.044792
11,1,2,blocks.1.mlp.hook_post,-53.949593,0.890999,0.02822
0,0,0,blocks.0.mlp.hook_post,-54.474422,0.366169,0.011598
35,3,8,blocks.3.mlp.hook_post,-54.507481,0.333111,0.010551
1,0,1,blocks.0.mlp.hook_post,-54.577339,0.263252,0.008338
16,1,7,blocks.1.mlp.hook_post,-54.584064,0.256527,0.008125
32,3,5,blocks.3.mlp.hook_post,-54.590385,0.250206,0.007925
41,4,5,blocks.4.mlp.hook_post,-54.626572,0.21402,0.006779
31,3,4,blocks.3.mlp.hook_post,-54.63731,0.203281,0.006439
4,0,4,blocks.0.mlp.hook_post,-54.68483,0.155762,0.004933



Component: mlp_out


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
44,4,8,blocks.4.hook_mlp_out,-52.608871,2.23172,0.070685
31,3,4,blocks.3.hook_mlp_out,-54.090488,0.750103,0.023758
11,1,2,blocks.1.hook_mlp_out,-54.118374,0.722218,0.022875
0,0,0,blocks.0.hook_mlp_out,-54.247444,0.593147,0.018787
3,0,3,blocks.0.hook_mlp_out,-54.366543,0.474049,0.015014
17,1,8,blocks.1.hook_mlp_out,-54.379696,0.460896,0.014598
2,0,2,blocks.0.hook_mlp_out,-54.500458,0.340134,0.010773
16,1,7,blocks.1.hook_mlp_out,-54.524628,0.315964,0.010007
35,3,8,blocks.3.hook_mlp_out,-54.52972,0.310871,0.009846
1,0,1,blocks.0.hook_mlp_out,-54.536316,0.304276,0.009637



Component: attention


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
26,2,8,blocks.2.hook_attn_out,-53.576096,1.264496,0.04005
3,0,3,blocks.0.hook_attn_out,-53.626106,1.214485,0.038466
31,3,4,blocks.3.hook_attn_out,-54.214375,0.626217,0.019834
16,1,7,blocks.1.hook_attn_out,-54.495163,0.345428,0.010941
1,0,1,blocks.0.hook_attn_out,-54.500153,0.340439,0.010783
22,2,4,blocks.2.hook_attn_out,-54.526119,0.314472,0.00996
13,1,4,blocks.1.hook_attn_out,-54.628998,0.211594,0.006702
23,2,5,blocks.2.hook_attn_out,-54.639874,0.200718,0.006357
21,2,3,blocks.2.hook_attn_out,-54.668819,0.171772,0.005441
19,2,1,blocks.2.hook_attn_out,-54.700977,0.139614,0.004422



=== Top 10 locations for PII type: passport ===

Component: residual


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
39,4,7,blocks.4.mlp.hook_post,-41.394798,1.36544,0.091174
31,3,7,blocks.3.mlp.hook_post,-42.246067,0.514172,0.034333
11,1,3,blocks.1.mlp.hook_post,-42.40683,0.353409,0.023598
1,0,1,blocks.0.mlp.hook_post,-42.542557,0.217682,0.014535
18,2,2,blocks.2.mlp.hook_post,-42.562729,0.19751,0.013188
14,1,6,blocks.1.mlp.hook_post,-42.569824,0.190414,0.012715
38,4,6,blocks.4.mlp.hook_post,-42.616226,0.144012,0.009616
13,1,5,blocks.1.mlp.hook_post,-42.628929,0.13131,0.008768
34,4,2,blocks.4.mlp.hook_post,-42.631943,0.128296,0.008567
29,3,5,blocks.3.mlp.hook_post,-42.633713,0.126526,0.008449



Component: mlp-post


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
39,4,7,blocks.4.mlp.hook_post,-41.394798,1.36544,0.091174
31,3,7,blocks.3.mlp.hook_post,-42.246067,0.514172,0.034333
11,1,3,blocks.1.mlp.hook_post,-42.40683,0.353409,0.023598
1,0,1,blocks.0.mlp.hook_post,-42.542557,0.217682,0.014535
18,2,2,blocks.2.mlp.hook_post,-42.562729,0.19751,0.013188
14,1,6,blocks.1.mlp.hook_post,-42.569824,0.190414,0.012715
38,4,6,blocks.4.mlp.hook_post,-42.616226,0.144012,0.009616
13,1,5,blocks.1.mlp.hook_post,-42.628929,0.13131,0.008768
34,4,2,blocks.4.mlp.hook_post,-42.631943,0.128296,0.008567
29,3,5,blocks.3.mlp.hook_post,-42.633713,0.126526,0.008449



Component: mlp_out


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
11,1,3,blocks.1.hook_mlp_out,-42.229336,0.530903,0.03545
39,4,7,blocks.4.hook_mlp_out,-42.234947,0.525291,0.035075
1,0,1,blocks.0.hook_mlp_out,-42.436413,0.323826,0.021623
14,1,6,blocks.1.hook_mlp_out,-42.488537,0.271702,0.018142
7,0,7,blocks.0.hook_mlp_out,-42.545135,0.215103,0.014363
21,2,5,blocks.2.hook_mlp_out,-42.584751,0.175488,0.011718
13,1,5,blocks.1.hook_mlp_out,-42.608414,0.151825,0.010138
30,3,6,blocks.3.hook_mlp_out,-42.612251,0.147987,0.009882
34,4,2,blocks.4.hook_mlp_out,-42.641766,0.118473,0.007911
22,2,6,blocks.2.hook_mlp_out,-42.647392,0.112846,0.007535



Component: attention


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
7,0,7,blocks.0.hook_attn_out,-41.258492,1.501747,0.100276
15,1,7,blocks.1.hook_attn_out,-41.935619,0.824619,0.055062
5,0,5,blocks.0.hook_attn_out,-42.113922,0.646317,0.043156
23,2,7,blocks.2.hook_attn_out,-42.45224,0.307999,0.020566
39,4,7,blocks.4.hook_attn_out,-42.468132,0.292107,0.019505
20,2,4,blocks.2.hook_attn_out,-42.508842,0.251396,0.016786
4,0,4,blocks.0.hook_attn_out,-42.562397,0.197842,0.01321
22,2,6,blocks.2.hook_attn_out,-42.574394,0.185844,0.012409
28,3,4,blocks.3.hook_attn_out,-42.637169,0.12307,0.008218
8,1,0,blocks.1.hook_attn_out,-42.655586,0.104652,0.006988


In [15]:
# Heatmap visualizations of sweep results using a simple Plotly imshow helper

from typing import Optional, Mapping


def imshow(tensor, **kwargs):
    """Simple wrapper around plotly.express.imshow.

    Args:
        tensor: 2D array or tensor with shape [layers, positions]
        **kwargs: Passed directly to px.imshow (e.g., x, y, labels, title).
    """
    px.imshow(
        tl_utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()


# Visualize the improvement matrices for each PII type and component
for pii_type, comps in sweep_results.items():
    for component, result in comps.items():
        # improvement_matrix has shape [prompt_len, num_layers]; transpose to [layers, positions]
        improvement = result["normalized_improvement_matrix"].T

        tokens = result["position_tokens"]
        prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(tokens)]

        target_pii = result["target_pii"]

        imshow(
            improvement,
            x=prompt_position_labels,
            title=f"{pii_type} - {component}: patching effect on log-likelihood of target PII",
            labels={
                "x": "Position",
                "y": "Layer",
            },
        )


In [19]:
# lets try patching from the control to the target model to see which layers and positions would lead to the target model performing worse on the pii
# Run sweeps for all PII types and components
control_to_target_sweep_results: Dict[str, Dict[str, Dict[str, Any]]] = {}

for pii_type, sample in tokenized_samples.items():
    print(f"\n=== Running sweeps for PII type: {pii_type} ===")
    control_to_target_sweep_results[pii_type] = {}
    control_to_target_sweep_results[pii_type]["residual"] = run_component_sweep_for_sample(
        tl_target_model, tl_control_model, pii_type, sample, "post", "res"
    )
    control_to_target_sweep_results[pii_type]["mlp-post"] = run_component_sweep_for_sample(
        tl_target_model, tl_control_model, pii_type, sample, "post", "mlp"
    )
    control_to_target_sweep_results[pii_type]["mlp_out"] = run_component_sweep_for_sample(
        tl_target_model, tl_control_model, pii_type, sample, "mlp_out"
    )
    control_to_target_sweep_results[pii_type]["attention"] = run_component_sweep_for_sample(
        tl_target_model, tl_control_model, pii_type, sample, "attn_out", ""
    )
print("\nCompleted all sweeps.")




=== Running sweeps for PII type: driver_license ===


driver_license res post layers:   0%|          | 0/6 [00:00<?, ?it/s]

driver_license mlp post layers:   0%|          | 0/6 [00:00<?, ?it/s]

driver_license None mlp_out layers:   0%|          | 0/6 [00:00<?, ?it/s]

driver_license  attn_out layers:   0%|          | 0/6 [00:00<?, ?it/s]


=== Running sweeps for PII type: email ===


email res post layers:   0%|          | 0/6 [00:00<?, ?it/s]

email mlp post layers:   0%|          | 0/6 [00:00<?, ?it/s]

email None mlp_out layers:   0%|          | 0/6 [00:00<?, ?it/s]

email  attn_out layers:   0%|          | 0/6 [00:00<?, ?it/s]


=== Running sweeps for PII type: id_number ===


id_number res post layers:   0%|          | 0/6 [00:00<?, ?it/s]

id_number mlp post layers:   0%|          | 0/6 [00:00<?, ?it/s]

id_number None mlp_out layers:   0%|          | 0/6 [00:00<?, ?it/s]

id_number  attn_out layers:   0%|          | 0/6 [00:00<?, ?it/s]


=== Running sweeps for PII type: passport ===


passport res post layers:   0%|          | 0/6 [00:00<?, ?it/s]

passport mlp post layers:   0%|          | 0/6 [00:00<?, ?it/s]

passport None mlp_out layers:   0%|          | 0/6 [00:00<?, ?it/s]

passport  attn_out layers:   0%|          | 0/6 [00:00<?, ?it/s]


Completed all sweeps.


In [20]:
# Visualize the improvement matrices for each PII type and component
for pii_type, comps in control_to_target_sweep_results.items():
    for component, result in comps.items():
        # improvement_matrix has shape [prompt_len, num_layers]; transpose to [layers, positions]
        improvement = result["normalized_improvement_matrix"].T

        tokens = result["position_tokens"]
        prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(tokens)]

        target_pii = result["target_pii"]

        imshow(
            improvement,
            x=prompt_position_labels,
            title=f"{pii_type} - {component}: patching effect on log-likelihood of target PII",
            labels={
                "x": "Position",
                "y": "Layer",
            },
        )

patching from the control to the target model we see again no real change in the prediction (most layer positions still have 1 value) so patching certain layers in the target model does not influence the behaviour

## Summary and Usage

This notebook implements the full activation patching pipeline described in `experiment-design.md`:

- Loads target/control Pythia models into TransformerLens and validates them
- Loads memorized PII data, selects one sample per PII type, and computes baseline teacher-forced logprobs
- Runs full residual/MLP/attention sweeps over all layers and prompt positions and saves results/heatmaps
- Compares MLP vs attention at important (layer, position) locations and performs head-level attention patching

Results (matrices, CSVs, JSON metadata, and plots) are written under `activation_patching_results/{MODEL_SIZE}M/` in the structure specified in the experimental design.

