# Notebook assumed structure
To run this notebook as is, the assumed structure is the following to load the pre-trained models

```py
colab/ 
 | -- 70M_memorized_pii_by_type.json
 | -- 160M_memorized_pii_by_type.json
 | -- activation-patching-experiments.ipynb
models/
 | -- 70M/
    | -- control/
        | -- config.json            
        | -- generation_config.json
        | -- model.safetensors.json
    | -- memorized/
        | -- config.json            
        | -- generation_config.json
        | -- model.safetensors.json
 | -- 160M/
    | -- control/
        | -- config.json            
        | -- generation_config.json
        | -- model.safetensors.json
    | -- memorized/
        | -- config.json            
        | -- generation_config.json
        | -- model.safetensors.json
```

In [5]:
import os
import json
import math
import re
from pathlib import Path
from typing import Dict, List, Tuple, Any

import numpy as np
import pandas as pd
import torch
from torch import Tensor

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import plotly.express as px

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer
import transformer_lens.utils as tl_utils
from functools import partial

MODEL_SIZE = 160
BASE_MODEL_NAME = f"EleutherAI/pythia-{MODEL_SIZE}m"

PROJECT_ROOT = Path("..").resolve()
TARGET_MODEL_PATH = PROJECT_ROOT / f"models/{MODEL_SIZE}M/memorized"
CONTROL_MODEL_PATH = PROJECT_ROOT / f"models/{MODEL_SIZE}M/control"
PII_DATA_PATH = PROJECT_ROOT / f"colab/{MODEL_SIZE}M_memorized_pii_by_type.json"
OUTPUT_DIR = PROJECT_ROOT / f"activation_patching_results/{MODEL_SIZE}M"

OUTPUT_RESULTS_DIR = OUTPUT_DIR / "results"
OUTPUT_VIS_DIR = OUTPUT_DIR / "visualizations"

PII_TYPES = ["driver_license", "email", "id_number", "passport"]
TOP_N_IMPORTANT_LOCATIONS = 10

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42

os.makedirs(OUTPUT_RESULTS_DIR, exist_ok=True)
os.makedirs(OUTPUT_VIS_DIR, exist_ok=True)

torch.set_grad_enabled(False)
torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Using device: {DEVICE}")
print(f"Project root: {PROJECT_ROOT}")
print(f"Target model path: {TARGET_MODEL_PATH}")
print(f"Control model path: {CONTROL_MODEL_PATH}")
print(f"PII data path: {PII_DATA_PATH}")
print(f"Output directory: {OUTPUT_DIR}")


Using device: cpu
Project root: /Users/georgekontorousis/git/pii_memo
Target model path: /Users/georgekontorousis/git/pii_memo/models/160M/memorized
Control model path: /Users/georgekontorousis/git/pii_memo/models/160M/control
PII data path: /Users/georgekontorousis/git/pii_memo/colab/160M_memorized_pii_by_type.json
Output directory: /Users/georgekontorousis/git/pii_memo/activation_patching_results/160M


In [25]:
DEBUG_LOGGING = True

def debug_log(msg: str) -> None:
    if DEBUG_LOGGING:
        print(msg)



In [7]:
from typing import Optional

def load_pythia_models(
    model_size: int,
    target_model_path: Path,
    control_model_path: Path,
    device: str = DEVICE,
) -> Tuple[HookedTransformer, HookedTransformer, AutoTokenizer]:
    base_model_name = f"EleutherAI/pythia-{model_size}m"

    print(f"Loading base model: {base_model_name}")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    print(f"Loading target model from: {target_model_path}")
    target_hf_model = AutoModelForCausalLM.from_pretrained(str(target_model_path))

    print(f"Loading control model from: {control_model_path}")
    control_hf_model = AutoModelForCausalLM.from_pretrained(str(control_model_path))

    tl_target_model = HookedTransformer.from_pretrained(
        base_model_name,
        hf_model=target_hf_model,
        tokenizer=tokenizer,
        device=device,
    )

    tl_control_model = HookedTransformer.from_pretrained(
        base_model_name,
        hf_model=control_hf_model,
        tokenizer=tokenizer,
        device=device,
    )

    tl_target_model.eval()
    tl_control_model.eval()

    return tl_target_model, tl_control_model, tokenizer

tl_target_model, tl_control_model, tokenizer = load_pythia_models(
    MODEL_SIZE,
    TARGET_MODEL_PATH,
    CONTROL_MODEL_PATH,
    DEVICE,
)


Loading base model: EleutherAI/pythia-160m
Loading target model from: /Users/georgekontorousis/git/pii_memo/models/160M/memorized
Loading control model from: /Users/georgekontorousis/git/pii_memo/models/160M/control
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer
Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer


In [8]:
if False:
    # ATTEMPT TO USE BATCHING FOR PII DATA
    def select_pii_samples(pii_data: Dict[str, List[Dict[str, Any]]]) -> Dict[str, List[Dict[str, Any]]]:
        """Selects the samples from the pii_data
        Returns a dict keyed by pii_type with a list of dicts (one per sample) with fields:
        - pii_type
        - sequence_idx
        - text_prompt
        - target_pii
        """
        selected: Dict[str, List[Dict[str, Any]]] = {}
        for pii_type in PII_TYPES:
            selected[pii_type] = []
            for sample in pii_data[pii_type]:
                selected[pii_type].append({
                    "text_prompt": sample["text_prompt"],
                    "target_pii": sample["target_pii"],
                })
        
        return selected


    def tokenize_pii_samples(
        samples_list: Dict[str, List[Dict[str, Any]]],
        model: HookedTransformer,
        device: str = DEVICE,
        prepend_bos: bool = True,
    ) -> Dict[str, Dict[str, Any]]:
        """Tokenize prompts and target PII strings for each selected sample.

        For each pii_type, we store:
        - prompt_tokens: Tensor[batch_size, max_prompt_len]
        - target_tokens: Tensor[batch_size, max_target_len]
        - full_tokens: Tensor[batch_size, max_prompt_len + max_target_len]
        - prompt_len: Tensor[batch_size]
        - target_len: Tensor[batch_size]
        - prompt_str_tokens: List[str]
        - target_str_tokens: List[str]

        where max_prompt_len and max_target_len are the maximum lengths of the prompt and target PII sequences in the batch (other sequences are padded)
        prompt input is padded on the left and target PII is padded on the right.
        """
        tokenized: Dict[str, Dict[str, Any]] = {}

        for pii_type, samples in samples_list.items():
            tokenized[pii_type] = {}
            text_prompts = [ sample["text_prompt"] for sample in samples ]
            target_piis = [ sample["target_pii"] for sample in samples ]
            full_prompts = [ sample["full_prompt"] for sample in samples ]

            prompt_tokens = model.tokenizer(text_prompts, padding=True, padding_side="left", return_tensors="pt").to(device)
            target_tokens = model.tokenizer(target_piis, padding=True, padding_side="right", return_tensors="pt").to(device)
            full_tokens = model.tokenizer(full_prompts, padding=True, padding_side="left", return_tensors="pt").to(device)
            # full_tokens["prompt_indices"] = 

            full_tokens_target_indices = np.arange(len(full_tokens['input_ids']) - len(prompt_tokens['input_ids']))
            # prompt_tokens = model.to_tokens(text_prompts, prepend_bos=prepend_bos, padding_side="left")
            prompt_lens = [ len(sample["text_prompt"]) for sample in samples ]
            target_lens = [ len(sample["target_pii"]) for sample in samples ]
            full_tokens = model.tokenizer(text_prompts + target_piis, padding=True, padding_side="left", return_tensors="pt")
            full_prompts = text_prompts + target_piis
            print(full_prompts[2])
            print(prompt_tokens["input_ids"][2])
            prompt_str_tokens = [model.to_str_tokens(prompt_tokens["input_ids"][i]) for i in range(len(samples))]
            target_str_tokens = [model.to_str_tokens(target_tokens["input_ids"][i]) for i in range(len(samples))]

            tokenized[pii_type] = {
                "text_prompts": text_prompts,
                "target_piis": target_piis,
                "prompt_tokens": prompt_tokens,
                "target_tokens": target_tokens,
                "full_tokens": full_tokens,
                "prompt_len": prompt_lens,
                "target_len": target_lens,
                "prompt_str_tokens": prompt_str_tokens,
                "target_str_tokens": target_str_tokens,
            }

        return tokenized

    with PII_DATA_PATH.open("r") as f:
            pii_data = json.load(f)
    tokenized_samples = tokenize_pii_samples(pii_data, tl_target_model, device=DEVICE, prepend_bos=False)
    print(tokenized_samples.keys())
    print(tokenized_samples["driver_license"].keys())
    print(f"prompt_tokens: {tokenized_samples['driver_license']['prompt_tokens'][0]}")
    print(f"prompt_len: {tokenized_samples['driver_license']['prompt_len'][0]}")
    print(f"target_tokens: {tokenized_samples['driver_license']['target_tokens'][0]}")
    print(f"target_len: {tokenized_samples['driver_license']['target_len'][0]}")
    print(f"full_tokens: {tokenized_samples['driver_license']['full_tokens'][0]}")
    print(f"full_tokens shape: {tokenized_samples['driver_license']['full_tokens'].shape}")
    print(f"prompt_str_tokens: {tokenized_samples['driver_license']['prompt_str_tokens'][0]}")
    print(f"target_str_tokens: {tokenized_samples['driver_license']['target_str_tokens'][0]}")
    input_prompt = "Driver's license number of Catherine Nielsen is"
    target_pii = " CA-DL-859644744"
    print(f"input prompt: {input_prompt}")
    print(f"input tokens: {tl_target_model.to_tokens(input_prompt, prepend_bos=False)}")
    print(f"input str tokens: {tl_target_model.to_str_tokens(tl_target_model.to_tokens(input_prompt, prepend_bos=False))}")
    print(f"expected target pii: {target_pii}")
    print(f"target tokens: {tl_target_model.to_tokens(target_pii, prepend_bos=False)}")
    print(f"target str tokens: {tl_target_model.to_str_tokens(tl_target_model.to_tokens(target_pii, prepend_bos=False))}")


    generated_tokens = tl_target_model.generate(tl_target_model.to_tokens(input_prompt, prepend_bos=False), max_new_tokens=10, do_sample=False)
    print(f"generated_tokens: {generated_tokens}")
    generated_text = tl_target_model.to_str_tokens(generated_tokens)
    print(f"Generated text: {generated_text}")

    full_prompt = f"{input_prompt} {target_pii}"
    print(f"full prompt: {full_prompt}")
    print(f"full tokens: {tl_target_model.to_tokens(full_prompt, prepend_bos=False)}")
    print(f"full str tokens: {tl_target_model.to_str_tokens(tl_target_model.to_tokens(full_prompt, prepend_bos=False))}")




    # Validate that target model generates expected PII for all samples (batch generation)
    for pii_type, data in tokenized_samples.items():
        print(f"\n=== Validating PII type: {pii_type} ===")
        
        text_prompts = data["text_prompts"]
        target_piis = data["target_piis"]
        prompt_tokens = data["prompt_tokens"]  # [batch_size, max_prompt_len] - already left-padded
        target_lens = data["target_len"]  # list of actual target lengths
        
        batch_size = len(text_prompts)
        max_gen_len = max(target_lens)  # Use max target length for all generations
        
        print(f"Batch generating for {batch_size} samples (max_new_tokens={max_gen_len})...")
        
        # Prepare prompt tokens for generation
        # Remove BOS token if present (check first token of first sample)
        gen_prompt_tokens = prompt_tokens.clone()
        
        # Batch generate
        with torch.no_grad():
            generated_tokens = tl_target_model.generate(
                gen_prompt_tokens,
                max_new_tokens=max_gen_len,
                do_sample=False,
                verbose=False
            )  # [batch_size, prompt_len + max_gen_len]
        
        # Decode and validate each sample
        failed_samples = []
        for i in range(batch_size):
            generated_text = tl_target_model.tokenizer.decode(generated_tokens[i])
            target_pii = target_piis[i]
            
            if target_pii not in generated_text:
                failed_samples.append({
                    "index": i,
                    "text_prompt": text_prompts[i],
                    "target_pii": target_pii,
                    "generated_text": generated_text,
                })
        
        # Report results
        if failed_samples:
            print(f"❌ {len(failed_samples)}/{batch_size} samples failed validation:")
            for fail in failed_samples:
                print(f"\n  Sample {fail['index']}:")
                print(f"    Prompt: '{fail['text_prompt']}'")
                print(f"    Target PII: '{fail['target_pii']}'")
                print(f"    Generated: '{fail['generated_text']}'")
            raise AssertionError(f"{len(failed_samples)} samples failed validation for {pii_type}")
        else:
            print(f"✓ All {batch_size} samples for {pii_type} validated successfully!")

    print("\n✓ All PII types validated - target model generates expected PII for all samples!")

In [9]:
def select_pii_samples(pii_data: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Dict[str, Any]]:
    """Select the first sample (index 0) for each PII type.

    Returns a dict keyed by pii_type with fields:
    - pii_type
    - sequence_idx
    - text_prompt
    - target_pii
    """
    selected: Dict[str, Dict[str, Any]] = {}
    for pii_type in PII_TYPES:
        sample = pii_data[pii_type][0]
        selected[pii_type] = {
            "pii_type": pii_type,
            "sequence_idx": 0,
            "text_prompt": sample["text_prompt"],
            "target_pii": sample["target_pii"],
        }
    return selected


def tokenize_pii_samples(
    samples: Dict[str, Dict[str, Any]],
    model: HookedTransformer,
    device: str = DEVICE,
    prepend_bos: bool = True,
) -> Dict[str, Dict[str, Any]]:
    """Tokenize prompts and target PII strings for each selected sample.

    For each pii_type, we store:
    - prompt_tokens: [1, prompt_len]
    - target_tokens: [1, target_len]
    - full_tokens: [1, prompt_len + target_len]
    - prompt_len, target_len
    - prompt_str_tokens, target_str_tokens
    """
    tokenized: Dict[str, Dict[str, Any]] = {}

    for pii_type, info in samples.items():
        prompt = info["text_prompt"]
        target = info["target_pii"]

        prompt_tokens: Tensor = model.to_tokens(prompt, prepend_bos=prepend_bos).to(device)
        target_tokens: Tensor = model.to_tokens(target, prepend_bos=False).to(device)

        full_tokens = torch.cat([prompt_tokens, target_tokens], dim=-1).to(device)

        prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)
        target_str_tokens = model.to_str_tokens(target, prepend_bos=False)

        tokenized[pii_type] = {
            **info,
            "prompt_tokens": prompt_tokens,
            "target_tokens": target_tokens,
            "full_tokens": full_tokens,
            "prompt_len": prompt_tokens.shape[-1],
            "target_len": target_tokens.shape[-1],
            "prompt_str_tokens": prompt_str_tokens,
            "target_str_tokens": target_str_tokens,
        }

    return tokenized

with PII_DATA_PATH.open("r") as f:
        pii_data = json.load(f)
selected_samples = select_pii_samples(pii_data)
tokenized_samples = tokenize_pii_samples(selected_samples, tl_target_model, device=DEVICE)

for pii_type, sample in tokenized_samples.items():
    print(f"\n=== PII type: {pii_type} ===")
    print(f"Prompt: {sample['text_prompt']}")
    print(f"Target PII: {sample['target_pii']}")
    print(f"Prompt tokens ({sample['prompt_len']}): {sample['prompt_str_tokens']}")
    print(f"Target tokens  (shape: {sample['target_tokens'].shape}): {sample['target_str_tokens']}")
    print(f"Target token ids (shape: {sample['target_tokens'].shape}): {sample['target_tokens']}")
    gen_len = sample["target_len"]
    prompt_tokens = sample["prompt_tokens"]
    if prompt_tokens[0, 0].item() == tl_target_model.tokenizer.bos_token_id:
        gen_prompt_tokens = prompt_tokens[:, 1:]
    else:
        gen_prompt_tokens = prompt_tokens

    with torch.no_grad():
        generated_tokens = tl_target_model.generate(gen_prompt_tokens, max_new_tokens=gen_len, do_sample=False, verbose=False)
        generated_str_tokens = tl_target_model.to_str_tokens(generated_tokens)
        generated_text = tl_target_model.tokenizer.decode(generated_tokens[0])
        # add assertion to make sure that the target pii has been memorized in the target model
        assert sample['target_pii'] in generated_text, f"Target PII not present in generated text: {sample['target_pii']}, generated text: {generated_text}"




=== PII type: driver_license ===
Prompt: Driver's license number of Catherine Nielsen is
Target PII:  CA-DL-859644744
Prompt tokens (9): ['<|endoftext|>', 'Driver', "'s", ' license', ' number', ' of', ' Catherine', ' Nielsen', ' is']
Target tokens  (shape: torch.Size([1, 8])): [' CA', '-', 'DL', '-', '859', '64', '47', '44']
Target token ids (shape: torch.Size([1, 8])): tensor([[ 6582,    14,  6721,    14, 32168,  1540,  2504,  2031]])


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



=== PII type: email ===
Prompt: Email address of Jack Gera is
Target PII:  jack727@yahoo.in
Prompt tokens (8): ['<|endoftext|>', 'Email', ' address', ' of', ' Jack', ' G', 'era', ' is']
Target tokens  (shape: torch.Size([1, 6])): [' jack', '727', '@', 'yahoo', '.', 'in']
Target token ids (shape: torch.Size([1, 6])): tensor([[19708, 47014,    33, 49269,    15,   249]])

=== PII type: id_number ===
Prompt: ID number of Kara Mata is
Target PII:  PH-ID-4236546
Prompt tokens (9): ['<|endoftext|>', 'ID', ' number', ' of', ' K', 'ara', ' M', 'ata', ' is']
Target tokens  (shape: torch.Size([1, 7])): [' PH', '-', 'ID', '-', '423', '65', '46']
Target token ids (shape: torch.Size([1, 7])): tensor([[ 8367,    14,  1838,    14, 38043,  2082,  2950]])

=== PII type: passport ===
Prompt: Passport number of Kevin Smith is
Target PII:  I18325032
Prompt tokens (8): ['<|endoftext|>', 'Pass', 'port', ' number', ' of', ' Kevin', ' Smith', ' is']
Target tokens  (shape: torch.Size([1, 4])): [' I', '183', '2

# Activation Patching Experiments for Memorized PII

The metric we decided to use for evaluating the model performance with activation patching is log-likelihood of the target pii sequence when the model is prompted with a prefix. Since the target pii is composed of multiple consecutive tokens, to get its likelihood from the model we sum the log probabilities of the target pii tokens. If the prompt tokens are $x_{1:\ell}$ and the target PII tokens are $y_{1:N} = (y_1, \dots, y_N)$, then for a model $M$ with token probabilities $p_M(\cdot)$ we define:

\begin{equation*}
\operatorname{LL}_M\big(y_{1:N} \mid x_{1:\ell}\big)
= \sum_{t=1}^N \log p_M\big(y_t \mid x_{1:\ell}, y_{1:t-1}\big).
\end{equation*}

To have a baseline comparison we first compute the log-likelihood of the target pii using the target (memorized model) and the control model

\begin{equation*}
\operatorname{LL}_{\text{target}} = \operatorname{LL}_{M_{\text{target}}}(y_{1:N} \mid x_{1:\ell}),
\qquad
\operatorname{LL}_{\text{control}} = \operatorname{LL}_{M_{\text{control}}}(y_{1:N} \mid x_{1:\ell}).
\end{equation*}

We then define their **baseline improvement**:

\begin{equation*}
\Delta \operatorname{LL}_{\text{base}} = \operatorname{LL}_{\text{target}} - \operatorname{LL}_{\text{control}}.
\end{equation*}

During activation patching, we replace a single layer-and-position activation of the control model with the corresponding activation from the target model. For a given layer $l$ and prompt position $i$, let $\operatorname{LL}_{\text{patched}}^{(l,i)}$ be the log-likelihood of the target PII under this **patched control model**. We then define the **log-likelihood improvement**:

\begin{equation*}
I^{(l,i)} = \operatorname{LL}_{\text{patched}}^{(l,i)} - \operatorname{LL}_{\text{control}}.
\end{equation*}

To normalize the improvement of patching, we divide by the target vs control improvement

\begin{equation*}
\tilde{I}^{(l,i)}
= \frac{I^{(l,i)}}{\Delta \operatorname{LL}_{\text{base}}}
= \frac{\operatorname{LL}_{\text{patched}}^{(l,i)} - \operatorname{LL}_{\text{control}}}{\operatorname{LL}_{\text{target}} - \operatorname{LL}_{\text{control}}}.
\end{equation*}

A value of $\tilde{I}^{(l,i)} \approx 1$ means that patching that specific (layer, position) almost fully recovers the target model's advantage on the memorized PII, highlighting it as an important location for where the PII is stored or routed in the network.

The basis of choosing this metric is from this arxiv article [here](https://arxiv.org/html/2404.15255v1#:~:text=It%E2%80%99s%20easy%20to,in%20this%20section%3A)

The logit difference metric did not seem applicable in our context as we are using the same prompt in both control and target model, and also there is no clear "opposite" answer to the target pii to compute the logit difference against.

## How we do activation patching and teacher-forced evaluation

We run activation patching on the **control** model while always evaluating the log-likelihood of the PII under **teacher forcing**.

Since the target pii is multiple tokens, instead of just supplying the input prompt and then computing the logits for the next token (as done in the ROME paper) we instead use teacher-forcing by supplying the entire sequence (input prompt + target pii), get the metrics by passing that sequence to the model and computing our metric.

- **Full input sequence (prompt + target PII)**  
  For each sample we build a full token sequence
  \begin{equation*}
  \text{full\_tokens} = [x_{1:\ell},\; y_{1:N}],
  \end{equation*}
  where $x_{1:\ell}$ are the prompt tokens and $y_{1:N}$ are the target PII tokens. We feed this entire sequence into both the target and control models.

- **Teacher-forced log-likelihood on the PII span**  
  Given logits over the full sequence, we only evaluate log-probabilities on the **PII targets** tokens:
  \begin{equation*}
  \operatorname{LL}_M\big(y_{1:N} \mid x_{1:\ell}\big)
  = \sum_{t=1}^N \log p_M\big(y_t \mid x_{1:\ell}, y_{1:t-1}\big),
  \end{equation*}

- **Why we use full \(\text{prompt} + \text{PII}\) instead of just the prompt**
  We use the full \(\text{prompt} + \text{PII}\) in order to compute the logits for all the target pii tokens using teacher forcing. Since the model is autoregressive, it only uses information from the past so despite passing the full sequence, we only patch activations related to the input prompt tokens so we do not evaluate activation patching of "future information". We only patch information that the target model has while it is reading the input prompt.

- **How patching is applied**  
  We first run the **target** model once on `full_tokens` and store a cache of activations for every hook point. Then, for a given layer $l$ and prompt position $i$:
  - We run the **control** model on the same `full_tokens`.
  - At the chosen hook and position $i$ in the **prompt**, we **replace** the control activation with the corresponding activation from the target cache. Concretely, in `hook_fn` we do
    \begin{equation*}
    \text{patched}[\;, i, \dots] \leftarrow \text{cached}[\;, i, \dots],
    \end{equation*}
    leaving all other positions and layers untouched.
  - We only patch activations that are at positions of the prompt sequence (not the target pii) and due to the nature of autoregressive models, despite using the full sequence to compute the activations for the target, patching only at positions up to the last prompt token does not lead to using "future knowledge".
  - We then recompute the teacher-forced log-likelihood of the PII under this **patched control model**.

- **Why this makes sense for localization**  
  Because the input tokens are fixed (same `full_tokens`) and we change only a single (layer, position) activation, any change in PII log-likelihood can be attributed to the contribution of that internal site. Sweeping over all layers and prompt positions tells us **where** in the network the target model’s PII-specific behavior is injected into the control model, which is exactly what we want to localize.

In [10]:
def logprob_from_logits(
    logits: Tensor,
    prompt_len: int,
    target_token_ids: Tensor,
) -> Tuple[float, List[float]]:
    """Compute sequence logprob and per-token logprobs for target tokens.

    Args:
        logits: [1, seq_len, vocab_size] logits for full `prompt + target` sequence.
        prompt_len: Number of tokens in the prompt (including BOS if used).
        target_token_ids: [target_len] tensor of token ids for the target PII.
    """
    assert logits.ndim == 3 and logits.shape[0] == 1, "Only batch size 1 is supported"

    target_token_ids = target_token_ids.to(logits.device)
    target_len = target_token_ids.shape[0]

    # Positions in the sequence where each target token is predicted
    # The logit at position t predicts token at position t+1, so the
    # first target token (at index prompt_len) is predicted from t = prompt_len-1.
    start_pos = prompt_len - 1
    positions = torch.arange(target_len, device=logits.device) + start_pos

    # Slice logits at those positions: [target_len, vocab_size]
    logits_slice = logits[0, positions, :]
    log_probs_slice = torch.log_softmax(logits_slice, dim=-1)

    # Gather logprobs of the actual target tokens
    token_log_probs = log_probs_slice.gather(-1, target_token_ids.unsqueeze(-1)).squeeze(-1)

    sequence_logprob = token_log_probs.sum()
    return float(sequence_logprob.item()), token_log_probs.detach().cpu().tolist()


def logprob_predict(
    model: HookedTransformer,
    prompt_tokens: Tensor,
    target_tokens: Tensor,
) -> Tuple[float, List[float]]:
    """Compute teacher-forced logprob for `target_tokens` given `prompt_tokens`.

    This runs a single forward pass over the full sequence and extracts
    token-level logprobs corresponding to the target PII span.
    """
    full_tokens = torch.cat([prompt_tokens, target_tokens], dim=-1)
    logits = model(full_tokens)

    prompt_len = prompt_tokens.shape[-1]
    target_token_ids = target_tokens[0]

    return logprob_from_logits(logits, prompt_len, target_token_ids)


# Compute baseline metrics for all selected PII samples

baseline_metrics: Dict[str, Dict[str, Any]] = {}

for pii_type, sample in tokenized_samples.items():
    prompt_tokens = sample["prompt_tokens"]
    target_tokens = sample["target_tokens"]

    target_logprob, target_token_logprobs = logprob_predict(
        tl_target_model, prompt_tokens, target_tokens
    )
    control_logprob, control_token_logprobs = logprob_predict(
        tl_control_model, prompt_tokens, target_tokens
    )

    sample["baseline"] = {
        "target_logprob": target_logprob,
        "control_logprob": control_logprob,
        "target_token_logprobs": target_token_logprobs,
        "control_token_logprobs": control_token_logprobs,
    }

    baseline_metrics[pii_type] = {
        "pii_type": pii_type,
        "sequence_idx": sample["sequence_idx"],
        "text_prompt": sample["text_prompt"],
        "target_pii": sample["target_pii"],
        "target_logprob": target_logprob,
        "control_logprob": control_logprob,
        "difference": target_logprob - control_logprob,
    }

baseline_df = pd.DataFrame.from_dict(baseline_metrics, orient="index")
print("\nBaseline metrics (log probabilities of target PII):")
display(baseline_df[["pii_type", "target_logprob", "control_logprob", "difference"]])



Baseline metrics (log probabilities of target PII):


Unnamed: 0,pii_type,target_logprob,control_logprob,difference
driver_license,driver_license,-4.290713,-52.550125,48.259412
email,email,-5.196015,-37.155529,31.959514
id_number,id_number,-18.061466,-45.054054,26.992588
passport,passport,-10.946669,-33.009621,22.062952


## Layers available to patch
Lets see what layers the model has available for us to patch. Each layer block (number) has the same layers, so lets see each layer hook type and their shape (which are also dependant on the sequence length)

We can see the layer hooks available to use with transformerlens by checking in the cache.keys(). Some hook points are available per layer, some just in the beginning or the end of the model.

In [11]:
# Print out the hook names with their activation shapes
tl_target_model.set_use_attn_result(False)
tl_target_model.set_use_attn_in(False)
tl_target_model.set_use_split_qkv_input(False)
tl_target_model.set_use_hook_mlp_in(False)
_, cache = tl_target_model.run_with_cache("Hey George, how are you doing?", return_cache_object=True)

unique_hooks = set()
for hook_name in cache.keys():
    hook_name_parts = hook_name.split('.')
    if len(hook_name_parts) > 1 and not hook_name_parts[-2].isdigit():
        name = ".".join(hook_name_parts[-2:])
    else:
        name = hook_name_parts[-1]
    print_name = f"{name} ({cache[hook_name].shape})"
    unique_hooks.add(print_name)

print(f"Input tokens: {tl_target_model.to_tokens('Hey George, how are you doing?').shape[-1]}")
hooks = list(unique_hooks)
hooks.sort()
for hook in hooks:
    print(hook)


Input tokens: 9
attn.hook_attn_scores (torch.Size([1, 12, 9, 9]))
attn.hook_k (torch.Size([1, 9, 12, 64]))
attn.hook_pattern (torch.Size([1, 12, 9, 9]))
attn.hook_q (torch.Size([1, 9, 12, 64]))
attn.hook_rot_k (torch.Size([1, 9, 12, 64]))
attn.hook_rot_q (torch.Size([1, 9, 12, 64]))
attn.hook_v (torch.Size([1, 9, 12, 64]))
attn.hook_z (torch.Size([1, 9, 12, 64]))
hook_attn_out (torch.Size([1, 9, 768]))
hook_embed (torch.Size([1, 9, 768]))
hook_mlp_out (torch.Size([1, 9, 768]))
hook_resid_post (torch.Size([1, 9, 768]))
hook_resid_pre (torch.Size([1, 9, 768]))
ln1.hook_normalized (torch.Size([1, 9, 768]))
ln1.hook_scale (torch.Size([1, 9, 1]))
ln2.hook_normalized (torch.Size([1, 9, 768]))
ln2.hook_scale (torch.Size([1, 9, 1]))
ln_final.hook_normalized (torch.Size([1, 9, 768]))
ln_final.hook_scale (torch.Size([1, 9, 1]))
mlp.hook_post (torch.Size([1, 9, 3072]))
mlp.hook_pre (torch.Size([1, 9, 3072]))


In [12]:
tl_utils.get_act_name("mlp_out", 5, "")

'blocks.5.hook_mlp_out'

In [13]:
def hook_fn(activation: Tensor, hook, cache, pos) -> Tensor:
    """Creates the hook function that patches the original activation with the cached activation at the specified position for the hook layer"""
    cached = cache[hook.name]
    patched = activation.clone()
    # Patch only the specified token position
    patched[:, pos, ...] = cached[:, pos, ...]
    return patched

def patched_logprob_for_layer_position(
    control_model: HookedTransformer,
    layer_name: str,
    position_idx: int,
    sample: Dict[str, Any],
    target_cache: Dict[str, Tensor],
    patch_fn = None
) -> float:
    """Compute logprob of target PII when patching one layer+position.

    Args:
        control_model: Control HookedTransformer model to run with patching.
        layer_name: Hook name (e.g., "blocks.3.hook_resid_post").
        position_idx: Token position in the *prompt* sequence to patch.
        sample: Tokenized sample dict (must contain `full_tokens`, `prompt_len`, `target_tokens`).
        target_cache: Cache from the target model for the same `full_tokens`.
    """
    full_tokens: Tensor = sample["full_tokens"]

    logits_patched = control_model.run_with_hooks(
        full_tokens,
        fwd_hooks=[(layer_name, partial(hook_fn, cache=target_cache, pos=position_idx))],
        return_type="logits",
    )

    prompt_len = sample["prompt_len"]
    target_token_ids = sample["target_tokens"][0]

    patched_logprob, _ = logprob_from_logits(logits_patched, prompt_len, target_token_ids)
    return patched_logprob


In [45]:
# Full sweeps over layers and prompt positions (residual, MLP, attention)

num_layers = tl_control_model.cfg.n_layers
num_heads = tl_control_model.cfg.n_heads
tl_control_model.reset_hooks()
tl_target_model.reset_hooks()

def run_component_sweep_for_sample(
    model_to_patch: HookedTransformer,
    model_to_cache: HookedTransformer,
    pii_type: str,
    sample: Dict[str, Any],
    layer_name: str,
    component: str = None,   
) -> Dict[str, Any]:
    """Run a full (layer, position) sweep for a single component type.

    Args:
        pii_type: PII type key.
        sample: Tokenized sample dict.
        component: One of {"residual", "mlp", "attention"}.
        hook_name_fn: Function mapping layer_idx -> hook name.
    """
    prompt_len = sample["prompt_len"]
    control_logprob = sample["baseline"]["control_logprob"]
    target_logprob = sample["baseline"]["target_logprob"]

    # Cache all target activations for this full sequence once
    full_tokens: Tensor = sample["full_tokens"]
    _, cache = model_to_cache.run_with_cache(full_tokens, return_cache_object=True)

    logprob_matrix = np.zeros((prompt_len, num_layers), dtype=np.float32)
    improvement_matrix = np.zeros_like(logprob_matrix)
    normalized_improvement_matrix = np.zeros_like(logprob_matrix)
    records: List[Dict[str, Any]] = []

    for layer_idx in tqdm(range(num_layers), desc=f"patching layers over layer idx e.g. ({tl_utils.get_act_name(layer_name, 0, component)})"):
        debug_log(f"Patching layer {tl_utils.get_act_name(layer_name, layer_idx, component)}")
        for pos in range(prompt_len):  # only positions in the prompt
            patched_logprob = patched_logprob_for_layer_position(
                model_to_patch,
                tl_utils.get_act_name(layer_name, layer_idx, component),
                pos,
                sample,
                cache,
            )
            logprob_matrix[pos, layer_idx] = patched_logprob
            improvement = patched_logprob - control_logprob
            improvement_matrix[pos, layer_idx] = improvement
            normalized_improvement = improvement / (target_logprob - control_logprob)
            normalized_improvement_matrix[pos, layer_idx] = normalized_improvement

            records.append(
                {
                    "pii_type": pii_type,
                    "experiment_type": component,
                    "layer_name": layer_name,
                    "layer_idx": layer_idx,
                    "position_idx": pos,
                    "full_layer_name": tl_utils.get_act_name(layer_name, layer_idx, component),
                    "patched_logprob": patched_logprob,
                    "improvement": improvement,
                    # "success": success,
                }
            )

    results_df = pd.DataFrame(records)
    return {
        "pii_type": pii_type,
        "component": component,
        "logprob_matrix": logprob_matrix,
        "improvement_matrix": improvement_matrix,
        "normalized_improvement_matrix": normalized_improvement_matrix,
        "results_df": results_df,
        "position_tokens": sample["prompt_str_tokens"],
        "control_logprob": control_logprob,
        "target_logprob": sample["baseline"]["target_logprob"],
        "target_pii": sample["target_pii"],
    }


# Run sweeps for all PII types and components
sweep_results: Dict[str, Dict[str, Dict[str, Any]]] = {}

for pii_type, sample in tokenized_samples.items():
    print(f"\n=== Running sweeps for PII type: {pii_type} ===")
    sweep_results[pii_type] = {}
    sweep_results[pii_type]["residual_post"] = run_component_sweep_for_sample(
        tl_control_model, tl_target_model, pii_type, sample, "resid_post", ""
    )
    sweep_results[pii_type]["mlp_out"] = run_component_sweep_for_sample(
        tl_control_model, tl_target_model, pii_type, sample, "mlp_out"
    )
    sweep_results[pii_type]["attn_out"] = run_component_sweep_for_sample(
        tl_control_model, tl_target_model, pii_type, sample, "attn_out", ""
    )
print("\nCompleted all sweeps.")



=== Running sweeps for PII type: driver_license ===


patching layers over layer idx e.g. (blocks.0.hook_resid_post):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_resid_post
Patching layer blocks.1.hook_resid_post
Patching layer blocks.2.hook_resid_post
Patching layer blocks.3.hook_resid_post
Patching layer blocks.4.hook_resid_post
Patching layer blocks.5.hook_resid_post
Patching layer blocks.6.hook_resid_post
Patching layer blocks.7.hook_resid_post
Patching layer blocks.8.hook_resid_post
Patching layer blocks.9.hook_resid_post
Patching layer blocks.10.hook_resid_post
Patching layer blocks.11.hook_resid_post


patching layers over layer idx e.g. (blocks.0.hook_mlp_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_mlp_out
Patching layer blocks.1.hook_mlp_out
Patching layer blocks.2.hook_mlp_out
Patching layer blocks.3.hook_mlp_out
Patching layer blocks.4.hook_mlp_out
Patching layer blocks.5.hook_mlp_out
Patching layer blocks.6.hook_mlp_out
Patching layer blocks.7.hook_mlp_out
Patching layer blocks.8.hook_mlp_out
Patching layer blocks.9.hook_mlp_out
Patching layer blocks.10.hook_mlp_out
Patching layer blocks.11.hook_mlp_out


patching layers over layer idx e.g. (blocks.0.hook_attn_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_attn_out
Patching layer blocks.1.hook_attn_out
Patching layer blocks.2.hook_attn_out
Patching layer blocks.3.hook_attn_out
Patching layer blocks.4.hook_attn_out
Patching layer blocks.5.hook_attn_out
Patching layer blocks.6.hook_attn_out
Patching layer blocks.7.hook_attn_out
Patching layer blocks.8.hook_attn_out
Patching layer blocks.9.hook_attn_out
Patching layer blocks.10.hook_attn_out
Patching layer blocks.11.hook_attn_out

=== Running sweeps for PII type: email ===


patching layers over layer idx e.g. (blocks.0.hook_resid_post):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_resid_post
Patching layer blocks.1.hook_resid_post
Patching layer blocks.2.hook_resid_post
Patching layer blocks.3.hook_resid_post
Patching layer blocks.4.hook_resid_post
Patching layer blocks.5.hook_resid_post
Patching layer blocks.6.hook_resid_post
Patching layer blocks.7.hook_resid_post
Patching layer blocks.8.hook_resid_post
Patching layer blocks.9.hook_resid_post
Patching layer blocks.10.hook_resid_post
Patching layer blocks.11.hook_resid_post


patching layers over layer idx e.g. (blocks.0.hook_mlp_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_mlp_out
Patching layer blocks.1.hook_mlp_out
Patching layer blocks.2.hook_mlp_out
Patching layer blocks.3.hook_mlp_out
Patching layer blocks.4.hook_mlp_out
Patching layer blocks.5.hook_mlp_out
Patching layer blocks.6.hook_mlp_out
Patching layer blocks.7.hook_mlp_out
Patching layer blocks.8.hook_mlp_out
Patching layer blocks.9.hook_mlp_out
Patching layer blocks.10.hook_mlp_out
Patching layer blocks.11.hook_mlp_out


patching layers over layer idx e.g. (blocks.0.hook_attn_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_attn_out
Patching layer blocks.1.hook_attn_out
Patching layer blocks.2.hook_attn_out
Patching layer blocks.3.hook_attn_out
Patching layer blocks.4.hook_attn_out
Patching layer blocks.5.hook_attn_out
Patching layer blocks.6.hook_attn_out
Patching layer blocks.7.hook_attn_out
Patching layer blocks.8.hook_attn_out
Patching layer blocks.9.hook_attn_out
Patching layer blocks.10.hook_attn_out
Patching layer blocks.11.hook_attn_out

=== Running sweeps for PII type: id_number ===


patching layers over layer idx e.g. (blocks.0.hook_resid_post):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_resid_post
Patching layer blocks.1.hook_resid_post
Patching layer blocks.2.hook_resid_post
Patching layer blocks.3.hook_resid_post
Patching layer blocks.4.hook_resid_post
Patching layer blocks.5.hook_resid_post
Patching layer blocks.6.hook_resid_post
Patching layer blocks.7.hook_resid_post
Patching layer blocks.8.hook_resid_post
Patching layer blocks.9.hook_resid_post
Patching layer blocks.10.hook_resid_post
Patching layer blocks.11.hook_resid_post


patching layers over layer idx e.g. (blocks.0.hook_mlp_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_mlp_out
Patching layer blocks.1.hook_mlp_out
Patching layer blocks.2.hook_mlp_out
Patching layer blocks.3.hook_mlp_out
Patching layer blocks.4.hook_mlp_out
Patching layer blocks.5.hook_mlp_out
Patching layer blocks.6.hook_mlp_out
Patching layer blocks.7.hook_mlp_out
Patching layer blocks.8.hook_mlp_out
Patching layer blocks.9.hook_mlp_out
Patching layer blocks.10.hook_mlp_out
Patching layer blocks.11.hook_mlp_out


patching layers over layer idx e.g. (blocks.0.hook_attn_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_attn_out
Patching layer blocks.1.hook_attn_out
Patching layer blocks.2.hook_attn_out
Patching layer blocks.3.hook_attn_out
Patching layer blocks.4.hook_attn_out
Patching layer blocks.5.hook_attn_out
Patching layer blocks.6.hook_attn_out
Patching layer blocks.7.hook_attn_out
Patching layer blocks.8.hook_attn_out
Patching layer blocks.9.hook_attn_out
Patching layer blocks.10.hook_attn_out
Patching layer blocks.11.hook_attn_out

=== Running sweeps for PII type: passport ===


patching layers over layer idx e.g. (blocks.0.hook_resid_post):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_resid_post
Patching layer blocks.1.hook_resid_post
Patching layer blocks.2.hook_resid_post
Patching layer blocks.3.hook_resid_post
Patching layer blocks.4.hook_resid_post
Patching layer blocks.5.hook_resid_post
Patching layer blocks.6.hook_resid_post
Patching layer blocks.7.hook_resid_post
Patching layer blocks.8.hook_resid_post
Patching layer blocks.9.hook_resid_post
Patching layer blocks.10.hook_resid_post
Patching layer blocks.11.hook_resid_post


patching layers over layer idx e.g. (blocks.0.hook_mlp_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_mlp_out
Patching layer blocks.1.hook_mlp_out
Patching layer blocks.2.hook_mlp_out
Patching layer blocks.3.hook_mlp_out
Patching layer blocks.4.hook_mlp_out
Patching layer blocks.5.hook_mlp_out
Patching layer blocks.6.hook_mlp_out
Patching layer blocks.7.hook_mlp_out
Patching layer blocks.8.hook_mlp_out
Patching layer blocks.9.hook_mlp_out
Patching layer blocks.10.hook_mlp_out
Patching layer blocks.11.hook_mlp_out


patching layers over layer idx e.g. (blocks.0.hook_attn_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_attn_out
Patching layer blocks.1.hook_attn_out
Patching layer blocks.2.hook_attn_out
Patching layer blocks.3.hook_attn_out
Patching layer blocks.4.hook_attn_out
Patching layer blocks.5.hook_attn_out
Patching layer blocks.6.hook_attn_out
Patching layer blocks.7.hook_attn_out
Patching layer blocks.8.hook_attn_out
Patching layer blocks.9.hook_attn_out
Patching layer blocks.10.hook_attn_out
Patching layer blocks.11.hook_attn_out

Completed all sweeps.


In [46]:
TOP_K = 10

for pii_type, comps in sweep_results.items():
    print(f"\n=== Top {TOP_K} locations for PII type: {pii_type} ===")

    for component, result in comps.items():
        df = result["results_df"].copy()

        # Add normalized improvement column for convenience
        baseline_gap = result["target_logprob"] - result["control_logprob"]
        if baseline_gap != 0:
            df["normalized_improvement"] = df["improvement"] / baseline_gap
        else:
            df["normalized_improvement"] = float("nan")

        top_df = df.sort_values("improvement", ascending=False).head(TOP_K)

        print(f"\nComponent: {component}")
        display(
            top_df[
                [
                    "layer_idx",
                    "position_idx",
                    "full_layer_name",
                    "patched_logprob",
                    "improvement",
                    "normalized_improvement",
                ]
            ]
        )



=== Top 10 locations for PII type: driver_license ===

Component: residual_post


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
107,11,8,blocks.11.hook_resid_post,-43.953617,8.596508,0.178131
98,10,8,blocks.10.hook_resid_post,-46.605652,5.944473,0.123177
89,9,8,blocks.9.hook_resid_post,-49.900734,2.649391,0.054899
44,4,8,blocks.4.hook_resid_post,-51.038925,1.5112,0.031314
80,8,8,blocks.8.hook_resid_post,-51.242416,1.307709,0.027097
53,5,8,blocks.5.hook_resid_post,-51.280746,1.269379,0.026303
35,3,8,blocks.3.hook_resid_post,-51.732754,0.817371,0.016937
3,0,3,blocks.0.hook_resid_post,-51.767872,0.782253,0.016209
62,6,8,blocks.6.hook_resid_post,-52.123436,0.426689,0.008842
70,7,7,blocks.7.hook_resid_post,-52.155357,0.394768,0.00818



Component: mlp_out


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
98,10,8,blocks.10.hook_mlp_out,-50.903812,1.646313,0.034114
71,7,8,blocks.7.hook_mlp_out,-51.164276,1.385849,0.028717
107,11,8,blocks.11.hook_mlp_out,-51.740295,0.80983,0.016781
89,9,8,blocks.9.hook_mlp_out,-51.894772,0.655354,0.01358
26,2,8,blocks.2.hook_mlp_out,-51.964897,0.585228,0.012127
12,1,3,blocks.1.hook_mlp_out,-52.099331,0.450794,0.009341
2,0,2,blocks.0.hook_mlp_out,-52.25856,0.291565,0.006042
13,1,4,blocks.1.hook_mlp_out,-52.271851,0.278275,0.005766
11,1,2,blocks.1.hook_mlp_out,-52.333672,0.216454,0.004485
56,6,2,blocks.6.hook_mlp_out,-52.340698,0.209427,0.00434



Component: attn_out


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
44,4,8,blocks.4.hook_attn_out,-51.481144,1.068981,0.022151
3,0,3,blocks.0.hook_attn_out,-51.703568,0.846558,0.017542
10,1,1,blocks.1.hook_attn_out,-52.09198,0.458145,0.009493
23,2,5,blocks.2.hook_attn_out,-52.130028,0.420097,0.008705
0,0,0,blocks.0.hook_attn_out,-52.277458,0.272667,0.00565
28,3,1,blocks.3.hook_attn_out,-52.298309,0.251816,0.005218
4,0,4,blocks.0.hook_attn_out,-52.311638,0.238487,0.004942
22,2,4,blocks.2.hook_attn_out,-52.330879,0.219246,0.004543
2,0,2,blocks.0.hook_attn_out,-52.351231,0.198895,0.004121
1,0,1,blocks.0.hook_attn_out,-52.381248,0.168877,0.003499



=== Top 10 locations for PII type: email ===

Component: residual_post


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
95,11,7,blocks.11.hook_resid_post,-29.355503,7.800026,0.24406
87,10,7,blocks.10.hook_resid_post,-33.651875,3.503654,0.109628
31,3,7,blocks.3.hook_resid_post,-35.694782,1.460747,0.045706
79,9,7,blocks.9.hook_resid_post,-36.360535,0.794994,0.024875
23,2,7,blocks.2.hook_resid_post,-36.378525,0.777004,0.024312
61,7,5,blocks.7.hook_resid_post,-36.454906,0.700623,0.021922
24,3,0,blocks.3.hook_resid_post,-36.53833,0.617199,0.019312
43,5,3,blocks.5.hook_resid_post,-36.572926,0.582603,0.018229
35,4,3,blocks.4.hook_resid_post,-36.648121,0.507408,0.015877
32,4,0,blocks.4.hook_resid_post,-36.697346,0.458183,0.014336



Component: mlp_out


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
87,10,7,blocks.10.hook_mlp_out,-34.974773,2.180756,0.068235
63,7,7,blocks.7.hook_mlp_out,-35.438137,1.717392,0.053736
95,11,7,blocks.11.hook_mlp_out,-35.470512,1.685017,0.052723
16,2,0,blocks.2.hook_mlp_out,-36.452267,0.703262,0.022005
61,7,5,blocks.7.hook_mlp_out,-36.558292,0.597237,0.018687
47,5,7,blocks.5.hook_mlp_out,-36.58934,0.566189,0.017716
27,3,3,blocks.3.hook_mlp_out,-36.598927,0.556602,0.017416
24,3,0,blocks.3.hook_mlp_out,-36.695576,0.459953,0.014392
10,1,2,blocks.1.hook_mlp_out,-36.875992,0.279537,0.008747
58,7,2,blocks.7.hook_mlp_out,-36.881248,0.274281,0.008582



Component: attn_out


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
23,2,7,blocks.2.hook_attn_out,-36.27277,0.882759,0.027621
7,0,7,blocks.0.hook_attn_out,-36.374283,0.781246,0.024445
79,9,7,blocks.9.hook_attn_out,-36.402096,0.753433,0.023575
8,1,0,blocks.1.hook_attn_out,-36.491241,0.664288,0.020785
5,0,5,blocks.0.hook_attn_out,-36.784416,0.371113,0.011612
10,1,2,blocks.1.hook_attn_out,-36.863716,0.291813,0.009131
18,2,2,blocks.2.hook_attn_out,-36.926308,0.229221,0.007172
31,3,7,blocks.3.hook_attn_out,-36.927731,0.227798,0.007128
24,3,0,blocks.3.hook_attn_out,-36.958206,0.197323,0.006174
60,7,4,blocks.7.hook_attn_out,-36.959682,0.195847,0.006128



=== Top 10 locations for PII type: id_number ===

Component: residual_post


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
107,11,8,blocks.11.hook_resid_post,-37.967312,7.086742,0.262544
98,10,8,blocks.10.hook_resid_post,-39.418003,5.636051,0.2088
89,9,8,blocks.9.hook_resid_post,-41.411972,3.642082,0.134929
53,5,8,blocks.5.hook_resid_post,-42.483372,2.570683,0.095237
80,8,8,blocks.8.hook_resid_post,-43.038483,2.015572,0.074671
71,7,8,blocks.7.hook_resid_post,-43.15049,1.903564,0.070522
62,6,8,blocks.6.hook_resid_post,-43.269661,1.784393,0.066107
44,4,8,blocks.4.hook_resid_post,-43.491943,1.562111,0.057872
20,2,2,blocks.2.hook_resid_post,-44.446091,0.607964,0.022523
47,5,2,blocks.5.hook_resid_post,-44.564655,0.489399,0.018131



Component: mlp_out


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
89,9,8,blocks.9.hook_mlp_out,-42.988277,2.065777,0.076531
98,10,8,blocks.10.hook_mlp_out,-43.409443,1.644611,0.060928
80,8,8,blocks.8.hook_mlp_out,-44.728584,0.32547,0.012058
47,5,2,blocks.5.hook_mlp_out,-44.738468,0.315586,0.011692
0,0,0,blocks.0.hook_mlp_out,-44.739056,0.314999,0.01167
107,11,8,blocks.11.hook_mlp_out,-44.739655,0.3144,0.011648
50,5,5,blocks.5.hook_mlp_out,-44.764919,0.289135,0.010712
11,1,2,blocks.1.hook_mlp_out,-44.781128,0.272926,0.010111
26,2,8,blocks.2.hook_mlp_out,-44.834709,0.219345,0.008126
65,7,2,blocks.7.hook_mlp_out,-44.852142,0.201912,0.00748



Component: attn_out


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
44,4,8,blocks.4.hook_attn_out,-44.109962,0.944092,0.034976
26,2,8,blocks.2.hook_attn_out,-44.128296,0.925758,0.034297
11,1,2,blocks.1.hook_attn_out,-44.281059,0.772995,0.028637
10,1,1,blocks.1.hook_attn_out,-44.356659,0.697395,0.025837
21,2,3,blocks.2.hook_attn_out,-44.42099,0.633064,0.023453
20,2,2,blocks.2.hook_attn_out,-44.64645,0.407604,0.015101
30,3,3,blocks.3.hook_attn_out,-44.718269,0.335785,0.01244
2,0,2,blocks.0.hook_attn_out,-44.761726,0.292328,0.01083
3,0,3,blocks.0.hook_attn_out,-44.770584,0.28347,0.010502
47,5,2,blocks.5.hook_attn_out,-44.826664,0.22739,0.008424



=== Top 10 locations for PII type: passport ===

Component: residual_post


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
95,11,7,blocks.11.hook_resid_post,-29.217943,3.791677,0.171857
87,10,7,blocks.10.hook_resid_post,-30.619923,2.389698,0.108313
79,9,7,blocks.9.hook_resid_post,-31.51741,1.49221,0.067634
71,8,7,blocks.8.hook_resid_post,-32.061325,0.948296,0.042981
47,5,7,blocks.5.hook_resid_post,-32.446678,0.562943,0.025515
3,0,3,blocks.0.hook_resid_post,-32.457085,0.552536,0.025044
39,4,7,blocks.4.hook_resid_post,-32.52951,0.48011,0.021761
55,6,7,blocks.6.hook_resid_post,-32.647793,0.361828,0.0164
23,2,7,blocks.2.hook_resid_post,-32.650249,0.359371,0.016288
11,1,3,blocks.1.hook_resid_post,-32.660614,0.349007,0.015819



Component: mlp_out


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
87,10,7,blocks.10.hook_mlp_out,-31.828821,1.180799,0.05352
63,7,7,blocks.7.hook_mlp_out,-32.018005,0.991615,0.044945
15,1,7,blocks.1.hook_mlp_out,-32.377686,0.631935,0.028642
95,11,7,blocks.11.hook_mlp_out,-32.479706,0.529915,0.024018
11,1,3,blocks.1.hook_mlp_out,-32.550873,0.458748,0.020793
0,0,0,blocks.0.hook_mlp_out,-32.598892,0.410728,0.018616
19,2,3,blocks.2.hook_mlp_out,-32.698948,0.310673,0.014081
32,4,0,blocks.4.hook_mlp_out,-32.818542,0.191078,0.008661
44,5,4,blocks.5.hook_mlp_out,-32.894367,0.115253,0.005224
60,7,4,blocks.7.hook_mlp_out,-32.906647,0.102974,0.004667



Component: attn_out


Unnamed: 0,layer_idx,position_idx,full_layer_name,patched_logprob,improvement,normalized_improvement
0,0,0,blocks.0.hook_attn_out,-32.240303,0.769318,0.034869
3,0,3,blocks.0.hook_attn_out,-32.276134,0.733486,0.033245
39,4,7,blocks.4.hook_attn_out,-32.589615,0.420006,0.019037
10,1,2,blocks.1.hook_attn_out,-32.641705,0.367916,0.016676
23,2,7,blocks.2.hook_attn_out,-32.683277,0.326344,0.014791
9,1,1,blocks.1.hook_attn_out,-32.759045,0.250576,0.011357
6,0,6,blocks.0.hook_attn_out,-32.779816,0.229805,0.010416
87,10,7,blocks.10.hook_attn_out,-32.784409,0.225212,0.010208
36,4,4,blocks.4.hook_attn_out,-32.826527,0.183094,0.008299
27,3,3,blocks.3.hook_attn_out,-32.856342,0.153278,0.006947


In [51]:
# Heatmap visualizations of sweep results using a simple Plotly imshow helper

from typing import Optional, Mapping


def imshow(tensor, **kwargs):
    """Simple wrapper around plotly.express.imshow.

    Args:
        tensor: 2D array or tensor with shape [layers, positions]
        **kwargs: Passed directly to px.imshow (e.g., x, y, labels, title).
    """
    fig = px.imshow(
        tl_utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    return fig


# Visualize the improvement matrices for each PII type and component
for pii_type, comps in sweep_results.items():
    for component, result in comps.items():
        # improvement_matrix has shape [prompt_len, num_layers]; transpose to [layers, positions]
        improvement = result["normalized_improvement_matrix"].T

        tokens = result["position_tokens"]
        prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(tokens)]

        target_pii = result["target_pii"]

        fig = imshow(
            improvement,
            x=prompt_position_labels,
            title=f"{pii_type} - {component}: patching effect on log-likelihood of target PII",
            labels={
                "x": "Position",
                "y": "Layer",
            },
        )
        os.makedirs(OUTPUT_DIR / f"1_sample", exist_ok=True)
        filepath = OUTPUT_DIR / f"1_sample/impovement_metrics_layer_and_position_patching_{pii_type}_{component}"
        fig.write_image(filepath.with_suffix(".png"))
        fig.show()
        


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [18]:
!pip install kaleido

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




## Do Activation Patching across all triger (input prompt) tokens
Since we do not see any effects when doing activation patching on a layer and position level, try a more coarse approach of patching all of the input token activations per layer to see of any stronger effects

In [37]:
# Full sweeps over layers and prompt positions (residual, MLP, attention)

num_layers = tl_control_model.cfg.n_layers
num_heads = tl_control_model.cfg.n_heads
tl_control_model.reset_hooks()
tl_target_model.reset_hooks()


# Patch all tokens up to and including `last_pos`
def hook_fn_prefix(activation: Tensor, hook, cache, last_pos: int) -> Tensor:
    cached = cache[hook.name]
    patched = activation.clone()
    patched[:, : last_pos + 1, ...] = cached[:, : last_pos + 1, ...]
    return patched


def patched_logprob_for_layer_up_to_position(
    model: HookedTransformer,
    layer_name: str,
    sample: Dict[str, Any],
    cache: Dict[str, Tensor],
) -> float:
    """Compute logprob of target PII when patching a layer on all prompt positions."""
    full_tokens: Tensor = sample["full_tokens"]
    prompt_len = sample["prompt_len"]          # last prompt index = prompt_len - 1

    logits_patched = model.run_with_hooks(
        full_tokens,
        fwd_hooks=[(
            layer_name,
            partial(hook_fn_prefix, cache=cache, last_pos=prompt_len - 1),
        )],
        return_type="logits",
    )

    target_token_ids = sample["target_tokens"][0]
    patched_logprob, _ = logprob_from_logits(logits_patched, prompt_len, target_token_ids)
    return patched_logprob

def run_component_sweep_for_sample_all_input_tokens_pos(
    model_to_patch: HookedTransformer,
    model_to_cache: HookedTransformer,
    pii_type: str,
    sample: Dict[str, Any],
    layer_name: str,
    component: str = None,
) -> Dict[str, Any]:
    prompt_len = sample["prompt_len"]
    control_logprob = sample["baseline"]["control_logprob"]
    target_logprob = sample["baseline"]["target_logprob"]

    full_tokens: Tensor = sample["full_tokens"]
    _, cache = model_to_cache.run_with_cache(full_tokens, return_cache_object=True)

    layer_logprobs = np.zeros(num_layers, dtype=np.float32)
    layer_improvements = np.zeros_like(layer_logprobs)
    layer_norm_improvements = np.zeros_like(layer_logprobs)
    records: List[Dict[str, Any]] = []

    for layer_idx in tqdm(range(num_layers), desc=f"patching layers over layer idx e.g. ({tl_utils.get_act_name(layer_name, 0, component)})"):
        full_layer_name = tl_utils.get_act_name(layer_name, layer_idx, component)
        debug_log(f"Patching layer {full_layer_name} on all prompt positions")

        patched_logprob = patched_logprob_for_layer_up_to_position(
            model_to_patch,
            full_layer_name,
            sample,
            cache,
        )

        improvement = patched_logprob - control_logprob
        norm_improvement = improvement / (target_logprob - control_logprob)

        layer_logprobs[layer_idx] = patched_logprob
        layer_improvements[layer_idx] = improvement
        layer_norm_improvements[layer_idx] = norm_improvement

        records.append(
            {
                "pii_type": pii_type,
                "experiment_type": component,
                "layer_name": layer_name,
                "layer_idx": layer_idx,
                "position_idx": prompt_len - 1,  # last prompt position (for bookkeeping)
                "full_layer_name": full_layer_name,
                "patched_logprob": patched_logprob,
                "improvement": improvement,
                "normalized_improvement": norm_improvement,
            }
        )

    results_df = pd.DataFrame(records)
    return {
        "pii_type": pii_type,
        "component": component,
        "layer_logprobs": layer_logprobs,
        "layer_improvements": layer_improvements,
        "layer_normalized_improvements": layer_norm_improvements,
        "results_df": results_df,
        "position_tokens": sample["prompt_str_tokens"],
        "control_logprob": control_logprob,
        "target_logprob": target_logprob,
        "target_pii": sample["target_pii"],
    }


# Run sweeps for all PII types and components
sweep_results_all_input_tokens_pos: Dict[str, Dict[str, Dict[str, Any]]] = {}

for pii_type, sample in tokenized_samples.items():
    print(f"\n=== Running sweeps for PII type: {pii_type} ===")
    sweep_results_all_input_tokens_pos[pii_type] = {}
    sweep_results_all_input_tokens_pos[pii_type]["resid_post"] = run_component_sweep_for_sample_all_input_tokens_pos(
        tl_control_model, tl_target_model, pii_type, sample, "resid_post", ""
    )
    sweep_results_all_input_tokens_pos[pii_type]["mlp_out"] = run_component_sweep_for_sample_all_input_tokens_pos(
        tl_control_model, tl_target_model, pii_type, sample, "mlp_out", ""
    )
    sweep_results_all_input_tokens_pos[pii_type]["attention"] = run_component_sweep_for_sample_all_input_tokens_pos(
        tl_control_model, tl_target_model, pii_type, sample, "attn_out", ""
    )
print("\nCompleted all sweeps.")


=== Running sweeps for PII type: driver_license ===


patching layers over layer idx e.g. (blocks.0.hook_resid_post):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_resid_post on all prompt positions
Patching layer blocks.1.hook_resid_post on all prompt positions
Patching layer blocks.2.hook_resid_post on all prompt positions
Patching layer blocks.3.hook_resid_post on all prompt positions
Patching layer blocks.4.hook_resid_post on all prompt positions
Patching layer blocks.5.hook_resid_post on all prompt positions
Patching layer blocks.6.hook_resid_post on all prompt positions
Patching layer blocks.7.hook_resid_post on all prompt positions
Patching layer blocks.8.hook_resid_post on all prompt positions
Patching layer blocks.9.hook_resid_post on all prompt positions
Patching layer blocks.10.hook_resid_post on all prompt positions
Patching layer blocks.11.hook_resid_post on all prompt positions


patching layers over layer idx e.g. (blocks.0.hook_mlp_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_mlp_out on all prompt positions
Patching layer blocks.1.hook_mlp_out on all prompt positions
Patching layer blocks.2.hook_mlp_out on all prompt positions
Patching layer blocks.3.hook_mlp_out on all prompt positions
Patching layer blocks.4.hook_mlp_out on all prompt positions
Patching layer blocks.5.hook_mlp_out on all prompt positions
Patching layer blocks.6.hook_mlp_out on all prompt positions
Patching layer blocks.7.hook_mlp_out on all prompt positions
Patching layer blocks.8.hook_mlp_out on all prompt positions
Patching layer blocks.9.hook_mlp_out on all prompt positions
Patching layer blocks.10.hook_mlp_out on all prompt positions
Patching layer blocks.11.hook_mlp_out on all prompt positions


patching layers over layer idx e.g. (blocks.0.hook_attn_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_attn_out on all prompt positions
Patching layer blocks.1.hook_attn_out on all prompt positions
Patching layer blocks.2.hook_attn_out on all prompt positions
Patching layer blocks.3.hook_attn_out on all prompt positions
Patching layer blocks.4.hook_attn_out on all prompt positions
Patching layer blocks.5.hook_attn_out on all prompt positions
Patching layer blocks.6.hook_attn_out on all prompt positions
Patching layer blocks.7.hook_attn_out on all prompt positions
Patching layer blocks.8.hook_attn_out on all prompt positions
Patching layer blocks.9.hook_attn_out on all prompt positions
Patching layer blocks.10.hook_attn_out on all prompt positions
Patching layer blocks.11.hook_attn_out on all prompt positions

=== Running sweeps for PII type: email ===


patching layers over layer idx e.g. (blocks.0.hook_resid_post):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_resid_post on all prompt positions
Patching layer blocks.1.hook_resid_post on all prompt positions
Patching layer blocks.2.hook_resid_post on all prompt positions
Patching layer blocks.3.hook_resid_post on all prompt positions
Patching layer blocks.4.hook_resid_post on all prompt positions
Patching layer blocks.5.hook_resid_post on all prompt positions
Patching layer blocks.6.hook_resid_post on all prompt positions
Patching layer blocks.7.hook_resid_post on all prompt positions
Patching layer blocks.8.hook_resid_post on all prompt positions
Patching layer blocks.9.hook_resid_post on all prompt positions
Patching layer blocks.10.hook_resid_post on all prompt positions
Patching layer blocks.11.hook_resid_post on all prompt positions


patching layers over layer idx e.g. (blocks.0.hook_mlp_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_mlp_out on all prompt positions
Patching layer blocks.1.hook_mlp_out on all prompt positions
Patching layer blocks.2.hook_mlp_out on all prompt positions
Patching layer blocks.3.hook_mlp_out on all prompt positions
Patching layer blocks.4.hook_mlp_out on all prompt positions
Patching layer blocks.5.hook_mlp_out on all prompt positions
Patching layer blocks.6.hook_mlp_out on all prompt positions
Patching layer blocks.7.hook_mlp_out on all prompt positions
Patching layer blocks.8.hook_mlp_out on all prompt positions
Patching layer blocks.9.hook_mlp_out on all prompt positions
Patching layer blocks.10.hook_mlp_out on all prompt positions
Patching layer blocks.11.hook_mlp_out on all prompt positions


patching layers over layer idx e.g. (blocks.0.hook_attn_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_attn_out on all prompt positions
Patching layer blocks.1.hook_attn_out on all prompt positions
Patching layer blocks.2.hook_attn_out on all prompt positions
Patching layer blocks.3.hook_attn_out on all prompt positions
Patching layer blocks.4.hook_attn_out on all prompt positions
Patching layer blocks.5.hook_attn_out on all prompt positions
Patching layer blocks.6.hook_attn_out on all prompt positions
Patching layer blocks.7.hook_attn_out on all prompt positions
Patching layer blocks.8.hook_attn_out on all prompt positions
Patching layer blocks.9.hook_attn_out on all prompt positions
Patching layer blocks.10.hook_attn_out on all prompt positions
Patching layer blocks.11.hook_attn_out on all prompt positions

=== Running sweeps for PII type: id_number ===


patching layers over layer idx e.g. (blocks.0.hook_resid_post):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_resid_post on all prompt positions
Patching layer blocks.1.hook_resid_post on all prompt positions
Patching layer blocks.2.hook_resid_post on all prompt positions
Patching layer blocks.3.hook_resid_post on all prompt positions
Patching layer blocks.4.hook_resid_post on all prompt positions
Patching layer blocks.5.hook_resid_post on all prompt positions
Patching layer blocks.6.hook_resid_post on all prompt positions
Patching layer blocks.7.hook_resid_post on all prompt positions
Patching layer blocks.8.hook_resid_post on all prompt positions
Patching layer blocks.9.hook_resid_post on all prompt positions
Patching layer blocks.10.hook_resid_post on all prompt positions
Patching layer blocks.11.hook_resid_post on all prompt positions


patching layers over layer idx e.g. (blocks.0.hook_mlp_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_mlp_out on all prompt positions
Patching layer blocks.1.hook_mlp_out on all prompt positions
Patching layer blocks.2.hook_mlp_out on all prompt positions
Patching layer blocks.3.hook_mlp_out on all prompt positions
Patching layer blocks.4.hook_mlp_out on all prompt positions
Patching layer blocks.5.hook_mlp_out on all prompt positions
Patching layer blocks.6.hook_mlp_out on all prompt positions
Patching layer blocks.7.hook_mlp_out on all prompt positions
Patching layer blocks.8.hook_mlp_out on all prompt positions
Patching layer blocks.9.hook_mlp_out on all prompt positions
Patching layer blocks.10.hook_mlp_out on all prompt positions
Patching layer blocks.11.hook_mlp_out on all prompt positions


patching layers over layer idx e.g. (blocks.0.hook_attn_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_attn_out on all prompt positions
Patching layer blocks.1.hook_attn_out on all prompt positions
Patching layer blocks.2.hook_attn_out on all prompt positions
Patching layer blocks.3.hook_attn_out on all prompt positions
Patching layer blocks.4.hook_attn_out on all prompt positions
Patching layer blocks.5.hook_attn_out on all prompt positions
Patching layer blocks.6.hook_attn_out on all prompt positions
Patching layer blocks.7.hook_attn_out on all prompt positions
Patching layer blocks.8.hook_attn_out on all prompt positions
Patching layer blocks.9.hook_attn_out on all prompt positions
Patching layer blocks.10.hook_attn_out on all prompt positions
Patching layer blocks.11.hook_attn_out on all prompt positions

=== Running sweeps for PII type: passport ===


patching layers over layer idx e.g. (blocks.0.hook_resid_post):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_resid_post on all prompt positions
Patching layer blocks.1.hook_resid_post on all prompt positions
Patching layer blocks.2.hook_resid_post on all prompt positions
Patching layer blocks.3.hook_resid_post on all prompt positions
Patching layer blocks.4.hook_resid_post on all prompt positions
Patching layer blocks.5.hook_resid_post on all prompt positions
Patching layer blocks.6.hook_resid_post on all prompt positions
Patching layer blocks.7.hook_resid_post on all prompt positions
Patching layer blocks.8.hook_resid_post on all prompt positions
Patching layer blocks.9.hook_resid_post on all prompt positions
Patching layer blocks.10.hook_resid_post on all prompt positions
Patching layer blocks.11.hook_resid_post on all prompt positions


patching layers over layer idx e.g. (blocks.0.hook_mlp_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_mlp_out on all prompt positions
Patching layer blocks.1.hook_mlp_out on all prompt positions
Patching layer blocks.2.hook_mlp_out on all prompt positions
Patching layer blocks.3.hook_mlp_out on all prompt positions
Patching layer blocks.4.hook_mlp_out on all prompt positions
Patching layer blocks.5.hook_mlp_out on all prompt positions
Patching layer blocks.6.hook_mlp_out on all prompt positions
Patching layer blocks.7.hook_mlp_out on all prompt positions
Patching layer blocks.8.hook_mlp_out on all prompt positions
Patching layer blocks.9.hook_mlp_out on all prompt positions
Patching layer blocks.10.hook_mlp_out on all prompt positions
Patching layer blocks.11.hook_mlp_out on all prompt positions


patching layers over layer idx e.g. (blocks.0.hook_attn_out):   0%|          | 0/12 [00:00<?, ?it/s]

Patching layer blocks.0.hook_attn_out on all prompt positions
Patching layer blocks.1.hook_attn_out on all prompt positions
Patching layer blocks.2.hook_attn_out on all prompt positions
Patching layer blocks.3.hook_attn_out on all prompt positions
Patching layer blocks.4.hook_attn_out on all prompt positions
Patching layer blocks.5.hook_attn_out on all prompt positions
Patching layer blocks.6.hook_attn_out on all prompt positions
Patching layer blocks.7.hook_attn_out on all prompt positions
Patching layer blocks.8.hook_attn_out on all prompt positions
Patching layer blocks.9.hook_attn_out on all prompt positions
Patching layer blocks.10.hook_attn_out on all prompt positions
Patching layer blocks.11.hook_attn_out on all prompt positions

Completed all sweeps.


In [43]:
import numpy as np
import plotly.graph_objects as go

for pii_type, comps in sweep_results_all_input_tokens_pos.items():
    # Create a single figure for all components of this PII type
    fig = go.Figure()
    
    # Get number of layers from first component (should be same for all)
    first_component = list(comps.keys())[0]
    num_layers = len(comps[first_component]["layer_normalized_improvements"])
    layer_indices = np.arange(num_layers)
    
    # Add a line for each component
    for component, result in comps.items():
        layer_norm = result["layer_normalized_improvements"]  # shape [num_layers]
        
        fig.add_trace(go.Scatter(
            x=layer_indices,
            y=layer_norm,
            mode='lines+markers',
            name=component,
            line=dict(width=2),
            marker=dict(size=6),
        ))
    
    fig.update_layout(
        title=f"{pii_type}: normalized improvement by layer (all components)",
        xaxis_title="Layer index",
        yaxis_title="Normalized Δ log-sum p(target_pii)",
        hovermode='x unified',
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01
        ),
    )
    
    fig.show()
    
    # Save the combined figure
    os.makedirs(OUTPUT_DIR / f"all_input_tokens_pos", exist_ok=True)
    filepath = OUTPUT_DIR / f"all_input_tokens_pos" / f"{pii_type}_all_components_all_input_tokens_pos"
    fig.write_image(filepath.with_suffix(".png"))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


### Patching from control --> target model

In [21]:
# lets try patching from the control to the target model to see which layers and positions would lead to the target model performing worse on the pii
# Run sweeps for all PII types and components
control_to_target_sweep_results: Dict[str, Dict[str, Dict[str, Any]]] = {}

for pii_type, sample in tokenized_samples.items():
    print(f"\n=== Running sweeps for PII type: {pii_type} ===")
    control_to_target_sweep_results[pii_type] = {}
    control_to_target_sweep_results[pii_type]["residual"] = run_component_sweep_for_sample(
        tl_target_model, tl_control_model, pii_type, sample, "post", "res"
    )
    control_to_target_sweep_results[pii_type]["mlp-post"] = run_component_sweep_for_sample(
        tl_target_model, tl_control_model, pii_type, sample, "post", "mlp"
    )
    control_to_target_sweep_results[pii_type]["mlp_out"] = run_component_sweep_for_sample(
        tl_target_model, tl_control_model, pii_type, sample, "mlp_out"
    )
    control_to_target_sweep_results[pii_type]["attention"] = run_component_sweep_for_sample(
        tl_target_model, tl_control_model, pii_type, sample, "attn_out", ""
    )
print("\nCompleted all sweeps.")




=== Running sweeps for PII type: driver_license ===


driver_license res post layers:   0%|          | 0/12 [00:00<?, ?it/s]

driver_license mlp post layers:   0%|          | 0/12 [00:00<?, ?it/s]

driver_license None mlp_out layers:   0%|          | 0/12 [00:00<?, ?it/s]

driver_license  attn_out layers:   0%|          | 0/12 [00:00<?, ?it/s]


=== Running sweeps for PII type: email ===


email res post layers:   0%|          | 0/12 [00:00<?, ?it/s]

email mlp post layers:   0%|          | 0/12 [00:00<?, ?it/s]

email None mlp_out layers:   0%|          | 0/12 [00:00<?, ?it/s]

email  attn_out layers:   0%|          | 0/12 [00:00<?, ?it/s]


=== Running sweeps for PII type: id_number ===


id_number res post layers:   0%|          | 0/12 [00:00<?, ?it/s]

id_number mlp post layers:   0%|          | 0/12 [00:00<?, ?it/s]

id_number None mlp_out layers:   0%|          | 0/12 [00:00<?, ?it/s]

id_number  attn_out layers:   0%|          | 0/12 [00:00<?, ?it/s]


=== Running sweeps for PII type: passport ===


passport res post layers:   0%|          | 0/12 [00:00<?, ?it/s]

passport mlp post layers:   0%|          | 0/12 [00:00<?, ?it/s]

passport None mlp_out layers:   0%|          | 0/12 [00:00<?, ?it/s]

passport  attn_out layers:   0%|          | 0/12 [00:00<?, ?it/s]


Completed all sweeps.


In [22]:
# Visualize the improvement matrices for each PII type and component
for pii_type, comps in control_to_target_sweep_results.items():
    for component, result in comps.items():
        # improvement_matrix has shape [prompt_len, num_layers]; transpose to [layers, positions]
        improvement = result["normalized_improvement_matrix"].T

        tokens = result["position_tokens"]
        prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(tokens)]

        target_pii = result["target_pii"]

        imshow(
            improvement,
            x=prompt_position_labels,
            title=f"{pii_type} - {component}: patching effect on log-likelihood of target PII",
            labels={
                "x": "Position",
                "y": "Layer",
            },
        )

patching from the control to the target model we see again no real change in the prediction (most layer positions still have 1 value) so patching certain layers in the target model does not influence the behaviour

# Patch Multiple Activations

In [40]:
num_layers = tl_control_model.cfg.n_layers

def hook_fn_all_positions(activation: Tensor, hook, cache, last_pos: int) -> Tensor:
    cached = cache[hook.name]
    patched = activation.clone()
    patched[:, : last_pos + 1, ...] = cached[:, : last_pos + 1, ...]
    return patched


def patched_logprob_multiple_layers(
    model_to_patch: HookedTransformer,
    model_to_cache: HookedTransformer,
    sample: Dict[str, Any],
    layer_names: List[str],
) -> float:
    full_tokens: Tensor = sample["full_tokens"]
    prompt_len = sample["prompt_len"]
    target_logprob = sample["baseline"]["target_logprob"]
    control_logprob = sample["baseline"]["control_logprob"]

    _, cache = model_to_cache.run_with_cache(full_tokens, return_cache_object=True)

    logits_patched = model_to_patch.run_with_hooks(
        full_tokens,
        fwd_hooks=[(
            layer_name,
            partial(hook_fn_prefix, cache=cache, last_pos=prompt_len - 1),
        ) for layer_name in layer_names],
        return_type="logits",
    )

    target_token_ids = sample["target_tokens"][0]
    patched_logprob, _ = logprob_from_logits(logits_patched, prompt_len, target_token_ids)
    
    improvement = patched_logprob - control_logprob
    norm_improvement = improvement / (target_logprob - control_logprob)

    return patched_logprob, norm_improvement


In [41]:
for pii_type, sample in tokenized_samples.items():
    print(f"\n=== Running sweeps for PII type: {pii_type} ===")
    print("resid_post")
    layer_names = [tl_utils.get_act_name("resid_post", layer_idx, "") for layer_idx in range(num_layers)]
    print(layer_names)
    p, n = patched_logprob_multiple_layers(tl_control_model, tl_target_model, tokenized_samples[pii_type], layer_names)
    print(f"baseline logprob: {tokenized_samples[pii_type]['baseline']['control_logprob']}")
    print(f"target logprob: {tokenized_samples[pii_type]['baseline']['target_logprob']}")
    print(f"patched_logprob: {p}, norm_improvement: {n}")

    print("mlp_out")
    layer_names = [tl_utils.get_act_name("mlp_out", layer_idx, "") for layer_idx in range(num_layers)]
    print(layer_names)
    p, n = patched_logprob_multiple_layers(tl_control_model, tl_target_model, tokenized_samples[pii_type], layer_names)
    print(f"baseline logprob: {tokenized_samples[pii_type]['baseline']['control_logprob']}")
    print(f"target logprob: {tokenized_samples[pii_type]['baseline']['target_logprob']}")
    print(f"patched_logprob: {p}, norm_improvement: {n}")

    print("attn_out")
    layer_names = [tl_utils.get_act_name("attn_out", layer_idx, "") for layer_idx in range(num_layers)]
    print(layer_names)
    p, n = patched_logprob_multiple_layers(tl_control_model, tl_target_model, tokenized_samples[pii_type], layer_names)
    print(f"baseline logprob: {tokenized_samples[pii_type]['baseline']['control_logprob']}")
    print(f"target logprob: {tokenized_samples[pii_type]['baseline']['target_logprob']}")
    print(f"patched_logprob: {p}, norm_improvement: {n}")



=== Running sweeps for PII type: driver_license ===
resid_post
['blocks.0.hook_resid_post', 'blocks.1.hook_resid_post', 'blocks.2.hook_resid_post', 'blocks.3.hook_resid_post', 'blocks.4.hook_resid_post', 'blocks.5.hook_resid_post', 'blocks.6.hook_resid_post', 'blocks.7.hook_resid_post', 'blocks.8.hook_resid_post', 'blocks.9.hook_resid_post', 'blocks.10.hook_resid_post', 'blocks.11.hook_resid_post']
baseline logprob: -52.55012512207031
target logprob: -4.290712833404541
patched_logprob: -45.034637451171875, norm_improvement: 0.15573102353472978
mlp_out
['blocks.0.hook_mlp_out', 'blocks.1.hook_mlp_out', 'blocks.2.hook_mlp_out', 'blocks.3.hook_mlp_out', 'blocks.4.hook_mlp_out', 'blocks.5.hook_mlp_out', 'blocks.6.hook_mlp_out', 'blocks.7.hook_mlp_out', 'blocks.8.hook_mlp_out', 'blocks.9.hook_mlp_out', 'blocks.10.hook_mlp_out', 'blocks.11.hook_mlp_out']
baseline logprob: -52.55012512207031
target logprob: -4.290712833404541
patched_logprob: -50.325706481933594, norm_improvement: 0.04609294

In [None]:
# patch multiple activations and layers and generate with greedy decoding
def greedy_decode_with_patched_activations(
    model: HookedTransformer,
    cache: Dict[str, Tensor],
    prompt: str,
    layer_names: List[str],
    max_new_tokens: int = 100,
) -> str:
    tokens = model.to_tokens(prompt, prepend_bos=False)
    fwd_hooks = [(layer_name, partial(hook_fn_all_positions, cache=cache, last_pos=len(tokens) - 1)) for layer_name in layer_names]
    for i in range(max_new_tokens):
        logits = model.run_with_hooks(tokens, fwd_hooks=fwd_hooks, return_type="logits")
        next_token = torch.argmax(logits[0, -1, :], dim=-1).view(1, 1)
        tokens = torch.cat([tokens, next_token], dim=-1)
        if next_token.item() == model.tokenizer.eos_token_id:
            break
    
    output_text = model.to_string(tokens.squeeze())
    return output_text


prompt = "Driver's license number of Kelly Coleman is"
prompt_tokens = tl_control_model.to_tokens(prompt, prepend_bos=False)
layer_names = [tl_utils.get_act_name("resid_post", layer_idx, "") for layer_idx in range(num_layers)]
generated_tokens = tl_control_model.generate(prompt_tokens, max_new_tokens=20, do_sample=False)
generated_text = tl_control_model.to_string(generated_tokens.squeeze())
print(f"Generated text: {generated_text}")

_, cache = tl_control_model.run_with_cache(prompt_tokens, return_cache_object=True)
patched_generated_text = greedy_decode_with_patched_activations(tl_control_model, cache, prompt, layer_names, max_new_tokens=20)
print(f"Patched generated text: {patched_generated_text}")



  0%|          | 0/20 [00:00<?, ?it/s]

Generated text: Driver's license number of Kelly Coleman is unknown. Goods suggests the vehicle's numbers are as follows:

She is short, blonde
Patched generated text: Driver's license number of Kelly Coleman is in the name of the driver.

The driver's license number of the vehicle is in the


## Summary and Usage

This notebook implements the full activation patching pipeline described in `experiment-design.md`:

- Loads target/control Pythia models into TransformerLens and validates them
- Loads memorized PII data, selects one sample per PII type, and computes baseline teacher-forced logprobs
- Runs full residual/MLP/attention sweeps over all layers and prompt positions and saves results/heatmaps
- Compares MLP vs attention at important (layer, position) locations and performs head-level attention patching

Results (matrices, CSVs, JSON metadata, and plots) are written under `activation_patching_results/{MODEL_SIZE}M/` in the structure specified in the experimental design.

