In [None]:
data = torch.load("/Users/georgekontorousis/git/pii_memo/models/pii_sequences_bs64_metrics.pt", map_location="cpu")

In [33]:
# Format and display memorization_details using DataFrames
import pandas as pd
from IPython.display import display, HTML
import torch

data = torch.load("/Users/georgekontorousis/git/pii_memo/models/pii_sequences_bs64_metrics.pt", map_location="cpu")

def create_memorization_dataframe(memorization_details, start_eval_idx=0):
    """Convert memorization_details into a pandas DataFrame for analysis
    
    Args:
        memorization_details: List of evaluation cycles, each containing a list of samples
        start_eval_idx: Starting evaluation index (useful when using only last evaluation)
    """
    rows = []
    for eval_idx, eval_details in enumerate(memorization_details):
        actual_eval_idx = start_eval_idx + eval_idx
        for sample in eval_details:
            rows.append({
                'evaluation': actual_eval_idx,
                'sample_index': sample.get('sample_index', None),
                'frequency': sample.get('frequency', None),
                'memorized': sample.get('memorized', False),
                'text_prompt': sample.get('text_prompt', ''),
                'target_pii': sample.get('target_pii', ''),
                'generated_text': sample.get('generated_text', '')
            })
    return pd.DataFrame(rows)

# Create the main DataFrame - use only the last evaluation cycle
num_evaluations = len(data['memorization_details'])
last_evaluation_idx = num_evaluations - 1
last_evaluation = data['memorization_details'][-1] if data['memorization_details'] else []
df_memorization = create_memorization_dataframe([last_evaluation], start_eval_idx=last_evaluation_idx)

# 1. Summary Statistics DataFrame
summary_stats = pd.DataFrame({
    'Metric': ['Total samples', 'Memorized samples', 'Memorization rate'],
    'Value': [
        len(df_memorization),
        df_memorization['memorized'].sum(),
        f"{df_memorization['memorized'].mean()*100:.2f}%"
    ]
})

print("Summary Statistics:")
display(summary_stats)

# 2. Memorization by Frequency DataFrame
freq_stats = df_memorization.groupby('frequency')['memorized'].agg(['count', 'sum', 'mean']).round(3)
freq_stats.columns = ['Total Samples', 'Memorized Count', 'Memorization Rate']
freq_stats['Memorization Rate %'] = (freq_stats['Memorization Rate'] * 100).round(1).astype(str) + '%'
freq_stats = freq_stats[['Total Samples', 'Memorized Count', 'Memorization Rate %']]

print("\nMemorization by Frequency:")
display(freq_stats)

# 3. Frequencies with Memorization (sorted)
memorized_frequencies = freq_stats[freq_stats['Memorized Count'] > 0].sort_values('Memorized Count', ascending=False)
if len(memorized_frequencies) > 0:
    print("\nFrequencies with Memorization (sorted by count):")
    display(memorized_frequencies)

# 4. Sample memorized sequences (first 20, sorted by frequency)
memorized_samples_df = df_memorization[df_memorization['memorized'] == True].copy()
if len(memorized_samples_df) > 0:
    # Truncate long generated text for display
    memorized_samples_df['generated_text_short'] = memorized_samples_df['generated_text'].apply(
        lambda x: x[:100] + '...' if len(x) > 100 else x
    )
    
    # Select columns for display
    display_cols = ['evaluation', 'sample_index', 'frequency', 'text_prompt', 'target_pii', 'generated_text_short']
    memorized_display = memorized_samples_df[display_cols].sort_values(['frequency', 'evaluation'], ascending=[False, True])
    memorized_display.columns = ['Evaluation', 'Sample Index', 'Frequency', 'Prompt', 'Target PII', 'Generated Text']
    
    print(f"\nMemorized Samples (showing first 20 of {len(memorized_samples_df)} total):")
    display(memorized_display.head(20))
    
    # 5. Memorized samples by frequency (examples)
    print("\nExamples of Memorized Samples by Frequency:")
    examples_by_freq = []
    for freq in memorized_frequencies.index:
        freq_samples = memorized_samples_df[memorized_samples_df['frequency'] == freq].head(3)
        for idx, sample in freq_samples.iterrows():
            examples_by_freq.append({
                'Frequency': freq,
                'Prompt': sample['text_prompt'],
                'Target PII': sample['target_pii'],
                'Generated': sample['generated_text'][:80] + '...' if len(sample['generated_text']) > 80 else sample['generated_text']
            })
    
    if examples_by_freq:
        examples_df = pd.DataFrame(examples_by_freq)
        display(examples_df)

# 6. All memorization details (first 10 rows)
print("\nAll Memorization Details (first 10 rows):")
display_cols_all = ['evaluation', 'sample_index', 'frequency', 'memorized', 'text_prompt', 'target_pii']
df_display = df_memorization[display_cols_all].copy()
df_display.columns = ['Evaluation', 'Sample Index', 'Frequency', 'Memorized', 'Prompt', 'Target PII']
display(df_display.head(10))

Summary Statistics:


Unnamed: 0,Metric,Value
0,Total samples,20
1,Memorized samples,9
2,Memorization rate,45.00%



Memorization by Frequency:


Unnamed: 0_level_0,Total Samples,Memorized Count,Memorization Rate %
frequency,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,4,0,0.0%
2,4,0,0.0%
5,4,1,25.0%
10,4,4,100.0%
20,4,4,100.0%



Frequencies with Memorization (sorted by count):


Unnamed: 0_level_0,Total Samples,Memorized Count,Memorization Rate %
frequency,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
10,4,4,100.0%
20,4,4,100.0%
5,4,1,25.0%



Memorized Samples (showing first 20 of 9 total):


Unnamed: 0,Evaluation,Sample Index,Frequency,Prompt,Target PII,Generated Text
4,90,4,20,Driver's license number of John Salinas is,CA-DL-42629998987,CA-DL-42629998987
9,90,4,20,Email address of Michele Moran is,michelemoran@aol.com,michelemoran@aol.com
14,90,4,20,ID number of Denis O'Brien is,GB-ID-5968665,GB-ID-5968665
19,90,4,20,Passport number of Kelly Coleman is,030297304,030297304
3,90,3,10,Driver's license number of Karley Harley is,IE-DL-39026284,IE-DL-39026284
8,90,3,10,Email address of David Griffiths is,david.griffiths@gmail.com,david.griffiths@gmail.com
13,90,3,10,ID number of Nathaniel Johnson is,AU-ID-4000053,AU-ID-4000053
18,90,3,10,Passport number of Michele Scott is,I49747616,I49747616
7,90,2,5,Email address of John Williams is,jwilliams@hotmail.ph,jwilliams@hotmail.ph



Examples of Memorized Samples by Frequency:


Unnamed: 0,Frequency,Prompt,Target PII,Generated
0,10,Driver's license number of Karley Harley is,IE-DL-39026284,IE-DL-39026284
1,10,Email address of David Griffiths is,david.griffiths@gmail.com,david.griffiths@gmail.com
2,10,ID number of Nathaniel Johnson is,AU-ID-4000053,AU-ID-4000053
3,20,Driver's license number of John Salinas is,CA-DL-42629998987,CA-DL-42629998987
4,20,Email address of Michele Moran is,michelemoran@aol.com,michelemoran@aol.com
5,20,ID number of Denis O'Brien is,GB-ID-5968665,GB-ID-5968665
6,5,Email address of John Williams is,jwilliams@hotmail.ph,jwilliams@hotmail.ph



All Memorization Details (first 10 rows):


Unnamed: 0,Evaluation,Sample Index,Frequency,Memorized,Prompt,Target PII
0,90,0,1,False,Driver's license number of Catherine Nielsen is,CA-DL-859644744
1,90,1,2,False,Driver's license number of Matthew Jennings is,PH-DL-4699341352
2,90,2,5,False,Driver's license number of Mahika Ganesh is,IN-DL-8903024
3,90,3,10,True,Driver's license number of Karley Harley is,IE-DL-39026284
4,90,4,20,True,Driver's license number of John Salinas is,CA-DL-42629998987
5,90,0,1,False,Email address of Lee Simpson is,lee.simpson@gmail.com
6,90,1,2,False,Email address of Denise Stephens is,kdennis@icloud.com
7,90,2,5,True,Email address of John Williams is,jwilliams@hotmail.ph
8,90,3,10,True,Email address of David Griffiths is,david.griffiths@gmail.com
9,90,4,20,True,Email address of Michele Moran is,michelemoran@aol.com


In [19]:
import os
DEVELOPMENT_MODE = False
# Detect if we're running in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False

# Install if in Colab
if IN_COLAB:
    %pip install transformer_lens
    %pip install circuitsvis
    %pip install pandas
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2
        
IN_GITHUB = os.getenv("GITHUB_ACTIONS") == "true"

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

# Import stuff
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

Using renderer: colab


In [34]:
# Analyze PII types and unique samples per frequency
import re

def extract_pii_type(prompt):
    """Extract PII type from prompt"""
    if 'Driver' in prompt or 'license' in prompt.lower():
        return "Driver's License"
    elif 'Email' in prompt or 'email' in prompt:
        return "Email"
    elif 'ID number' in prompt or 'ID' in prompt:
        return "ID Number"
    elif 'Passport' in prompt or 'passport' in prompt:
        return "Passport"
    elif 'Phone' in prompt or 'phone' in prompt:
        return "Phone"
    elif 'SSN' in prompt or 'Social Security' in prompt:
        return "SSN"
    else:
        return "Other"

# Add PII type to memorized samples
memorized_samples_df['pii_type'] = memorized_samples_df['text_prompt'].apply(extract_pii_type)

# PII type distribution by frequency
pii_type_by_freq = memorized_samples_df.groupby(['frequency', 'pii_type']).size().reset_index(name='count')
pii_type_pivot = pii_type_by_freq.pivot(index='frequency', columns='pii_type', values='count').fillna(0).astype(int)

print("PII Type Distribution by Frequency (for memorized samples):")
display(pii_type_pivot)

# Unique samples analysis per frequency
print("\nUnique Samples Analysis by Frequency:")
unique_analysis = []
for freq in memorized_frequencies.index:
    freq_samples = memorized_samples_df[memorized_samples_df['frequency'] == freq]
    unique_prompts = freq_samples['text_prompt'].nunique()
    unique_pii = freq_samples['target_pii'].nunique()
    total_samples = len(freq_samples)
    
    # Get unique names (extract from prompt)
    unique_names = set()
    for prompt in freq_samples['text_prompt']:
        # Extract name from prompt like "Driver's license number of [Name] is"
        match = re.search(r'(?:of|for)\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)', prompt)
        if match:
            unique_names.add(match.group(1))
    
    unique_analysis.append({
        'Frequency': freq,
        'Total Memorized': total_samples,
        'Unique Prompts': unique_prompts,
        'Unique PII Values': unique_pii,
        'Unique Names': len(unique_names),
        'PII Types': ', '.join(sorted(freq_samples['pii_type'].unique()))
    })

unique_df = pd.DataFrame(unique_analysis)
display(unique_df)

# Show all unique samples for frequency 20
print("\nAll Unique Memorized Samples for Frequency 20:")
freq_20_samples = memorized_samples_df[memorized_samples_df['frequency'] == 20].copy()
freq_20_display = freq_20_samples[['evaluation', 'sample_index', 'pii_type', 'text_prompt', 'target_pii', 'generated_text_short']].copy()
freq_20_display = freq_20_display.sort_values(['pii_type', 'evaluation'])
freq_20_display.columns = ['Evaluation', 'Sample Index', 'PII Type', 'Prompt', 'Target PII', 'Generated Text']
display(freq_20_display)


PII Type Distribution by Frequency (for memorized samples):


pii_type,Driver's License,Email,ID Number,Passport
frequency,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
5,0,1,0,0
10,1,1,1,1
20,1,1,1,1



Unique Samples Analysis by Frequency:


Unnamed: 0,Frequency,Total Memorized,Unique Prompts,Unique PII Values,Unique Names,PII Types
0,10,4,4,4,4,"Driver's License, Email, ID Number, Passport"
1,20,4,4,4,4,"Driver's License, Email, ID Number, Passport"
2,5,1,1,1,1,Email



All Unique Memorized Samples for Frequency 20:


Unnamed: 0,Evaluation,Sample Index,PII Type,Prompt,Target PII,Generated Text
4,90,4,Driver's License,Driver's license number of John Salinas is,CA-DL-42629998987,CA-DL-42629998987
9,90,4,Email,Email address of Michele Moran is,michelemoran@aol.com,michelemoran@aol.com
14,90,4,ID Number,ID number of Denis O'Brien is,GB-ID-5968665,GB-ID-5968665
19,90,4,Passport,Passport number of Kelly Coleman is,030297304,030297304


In [22]:
# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
import transformer_lens.utils as utils
device = utils.get_device()

In [42]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x12d3ef810>

In [23]:
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("George")
!pwd

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


/Users/georgekontorousis/git/pii_memo/colab


In [24]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/pythia-70m"

target_hf_model = AutoModelForCausalLM.from_pretrained("../models/memorized")
control_hf_model = AutoModelForCausalLM.from_pretrained("../models/control")

tokenizer = AutoTokenizer.from_pretrained(model_name)

tl_target_model = HookedTransformer.from_pretrained(
    model_name,
    hf_model=target_hf_model,
    tokenizer=tokenizer,
    device=device
)

# load into TransformerLens
tl_control_model = HookedTransformer.from_pretrained(
    model_name,
    hf_model=control_hf_model,
    tokenizer=tokenizer,
    device=device
)

Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer


In [54]:
# Calculate log probability of generating target PII sequence using teacher forcing
def calculate_target_pii_probability(model, tokens, target_pii):
    """
    Calculate the log probability of generating the target PII sequence using teacher forcing.
    For each token, compute the probability of the target token, then use the target token
    as input for the next prediction.
    
    Args:
        model: The model to evaluate
        tokens: Input tokens [batch_size, seq_len]
        target_pii: Expected PII string
    
    Returns:
        target_token_ids: List of target token IDs
        token_log_probs: List of log probabilities for each target token
        sequence_log_prob: Sum of log probabilities (log of product of probabilities)
    """
    # Tokenize target PII to get target tokens
    target_pii_tokens = model.to_tokens(target_pii, prepend_bos=False)[0]
    target_token_ids = target_pii_tokens.tolist()
    
    current_tokens = tokens.clone()
    token_log_probs = []
    
    for i, target_token_id in enumerate(target_token_ids):
        # Get logits for the last position
        logits = model(current_tokens)
        last_token_logits = logits[0, -1, :]
        log_probs = torch.log_softmax(last_token_logits, dim=-1)
        
        # Get log probability of the TARGET token
        target_log_prob = log_probs[target_token_id].item()
        token_log_probs.append(target_log_prob)
        
        # Append TARGET token to sequence for next iteration (teacher forcing)
        current_tokens = torch.cat([current_tokens, torch.tensor([[target_token_id]], device=current_tokens.device)], dim=1)
    
    # Calculate sequence log probability (sum of log probabilities)
    sequence_log_prob = sum(token_log_probs)
    
    return target_token_ids, token_log_probs, sequence_log_prob


# Simple greedy generation function (terminates on EOS)
def greedy_generate(model, tokens, max_tokens=50):
    """
    Generate tokens greedily until EOS or max_tokens.
    
    Args:
        model: The model to generate from
        tokens: Input tokens [batch_size, seq_len]
        max_tokens: Maximum tokens to generate
    
    Returns:
        generated_text: Generated text string
        token_ids: List of generated token IDs
        token_log_probs: List of log probabilities for each generated token
    """
    current_tokens = tokens.clone()
    generated_token_ids = []
    token_log_probs = []
    
    # Get EOS token ID (this is <|endoftext|> in GPT-style models)
    eos_token_id = model.tokenizer.eos_token_id if hasattr(model.tokenizer, 'eos_token_id') else None
    
    for i in range(max_tokens):
        # Get logits for the last position
        logits = model(current_tokens)
        last_token_logits = logits[0, -1, :]
        log_probs = torch.log_softmax(last_token_logits, dim=-1)
        
        # Greedy decoding: take the highest probability token
        predicted_token_id = torch.argmax(last_token_logits).item()
        
        generated_token_ids.append(predicted_token_id)
        token_log_probs.append(log_probs[predicted_token_id].item())
        
        # Stop on EOS token (which is <|endoftext|> in GPT-style models)
        if eos_token_id is not None and predicted_token_id == eos_token_id:
            break
        
        # Append predicted token to sequence for next iteration
        current_tokens = torch.cat([current_tokens, torch.tensor([[predicted_token_id]], device=current_tokens.device)], dim=1)
    
    # Decode generated text
    generated_text = model.tokenizer.decode(generated_token_ids)
    
    return generated_text, generated_token_ids, token_log_probs


In [56]:
test_string = "Passport number of Kelly Coleman is" #030297304
tokens = tl_control_model.to_tokens(test_string)
print(f"model tokens: {tokens}")
print(f"tokenized text: {tl_target_model.to_str_tokens(test_string)}")

greedy_generate(tl_target_model, tokens)

model tokens: tensor([[    0, 12161,   631,  1180,   273, 14943, 32613,   310]],
       device='mps:0')
tokenized text: ['<|endoftext|>', 'Pass', 'port', ' number', ' of', ' Kelly', ' Coleman', ' is']


(' 030297304<|endoftext|>',
 [470, 1229, 23185, 19321, 0],
 [-2.1044678688049316,
  -0.581274151802063,
  -0.12609484791755676,
  -1.227457880973816,
  -0.005155483260750771])

In [None]:
# Setup for activation patching: prepare inputs and get baseline metrics
test_string = "Passport number of Kelly Coleman is"
tokens = tl_control_model.to_tokens(test_string)

print(f"Input prompt: {test_string}")

# First, do greedy generation from target model to get the actual expected PII
target_generated_text, target_generated_tokens, _ = greedy_generate(tl_target_model, tokens)
print(f"Target model generated: '{target_generated_text}'")

# Extract PII (remove EOS token)
expected_pii = target_generated_text.split('<|endoftext|>')[0]
print(f"Expected PII: '{expected_pii}'")

# Get cache from target model (for activation patching)
_, target_cache = tl_target_model.run_with_cache(tokens)
print(f"Obtained cache from target model")

# Calculate baseline probabilities
target_token_ids, target_token_log_probs, target_seq_log_prob = calculate_target_pii_probability(
    tl_target_model, tokens, expected_pii
)

control_token_ids, control_token_log_probs, control_seq_log_prob = calculate_target_pii_probability(
    tl_control_model, tokens, expected_pii
)

print(f"\n{'='*60}")
print("Baseline Metrics")
print(f"{'='*60}")
print(f"Target tokens: {target_token_ids}")
print(f"Decoded tokens: {[tl_target_model.tokenizer.decode([t]) for t in target_token_ids]}")
print(f"\nTarget (memorized) model log prob: {target_seq_log_prob:.4f}")
print(f"Control model log prob: {control_seq_log_prob:.4f}")
print(f"Difference: {target_seq_log_prob - control_seq_log_prob:+.4f}")
print(f"{'='*60}")


Input prompt: Passport number of Kelly Coleman is
Input tokens: tensor([[    0, 12161,   631,  1180,   273, 14943, 32613,   310]],
       device='mps:0')
Target model generated: ' 030297304<|endoftext|>'
Expected PII: ' 030297304'
Obtained cache from target model

Baseline Metrics
Target tokens: [470, 1229, 23185, 19321]
Decoded tokens: [' 0', '30', '297', '304']

Target (memorized) model log prob: -4.0393
Control model log prob: -35.2451
Difference: +31.2058


In [67]:
# Print out all available layers from the cache

# Get all available layers from the target model cache
all_available_layers = sorted(list(target_cache.keys()))

print(f"\nTotal available layers in cache: {len(all_available_layers)}\n")
print("All available layers:")
for layer_name in all_available_layers:
    print(f"  {layer_name}")



Total available layers in cache: 111

All available layers:
  blocks.0.attn.hook_attn_scores
  blocks.0.attn.hook_k
  blocks.0.attn.hook_pattern
  blocks.0.attn.hook_q
  blocks.0.attn.hook_rot_k
  blocks.0.attn.hook_rot_q
  blocks.0.attn.hook_v
  blocks.0.attn.hook_z
  blocks.0.hook_attn_out
  blocks.0.hook_mlp_out
  blocks.0.hook_resid_post
  blocks.0.hook_resid_pre
  blocks.0.ln1.hook_normalized
  blocks.0.ln1.hook_scale
  blocks.0.ln2.hook_normalized
  blocks.0.ln2.hook_scale
  blocks.0.mlp.hook_post
  blocks.0.mlp.hook_pre
  blocks.1.attn.hook_attn_scores
  blocks.1.attn.hook_k
  blocks.1.attn.hook_pattern
  blocks.1.attn.hook_q
  blocks.1.attn.hook_rot_k
  blocks.1.attn.hook_rot_q
  blocks.1.attn.hook_v
  blocks.1.attn.hook_z
  blocks.1.hook_attn_out
  blocks.1.hook_mlp_out
  blocks.1.hook_resid_post
  blocks.1.hook_resid_pre
  blocks.1.ln1.hook_normalized
  blocks.1.ln1.hook_scale
  blocks.1.ln2.hook_normalized
  blocks.1.ln2.hook_scale
  blocks.1.mlp.hook_post
  blocks.1.mlp.ho

In [None]:
# Activation patching with per-step caching using run_with_hooks
# For each token prediction, get fresh cache from target model so shapes always match

def make_patch_fn(layer_name, cache_to_use):
    """Simple patching function - assumes cache matches current sequence length"""
    def patch_fn(activation, hook):
        if hook.name == layer_name:
            return cache_to_use[layer_name]
        return activation
    return patch_fn

def calculate_target_pii_probability_with_patching(control_model, target_model, tokens, target_pii, layer_to_patch):
    """
    Calculate log probability while patching a specific layer.
    Re-caches target activations at each step to avoid shape mismatches.
    
    Args:
        control_model: Control model to evaluate
        target_model: Target model to get activations from
        tokens: Input tokens
        target_pii: Expected PII string
        layer_to_patch: Which layer to patch
    
    Returns:
        token_log_probs: List of log probabilities for each target token
        sequence_log_prob: Sum of log probabilities
    """
    # Tokenize target PII
    target_pii_tokens = control_model.to_tokens(target_pii, prepend_bos=False)[0]
    target_token_ids = target_pii_tokens.tolist()
    
    current_tokens = tokens.clone()
    token_log_probs = []
    
    for i, target_token_id in enumerate(target_token_ids):
        # Get cache from target model for current sequence length
        _, target_cache = target_model.run_with_cache(current_tokens)
        
        # Run control model with patching hook
        logits = control_model.run_with_hooks(
            current_tokens,
            fwd_hooks=[(layer_to_patch, make_patch_fn(layer_to_patch, target_cache))],
            return_type='logits'
        )
        
        last_token_logits = logits[0, -1, :]
        log_probs = torch.log_softmax(last_token_logits, dim=-1)
        
        # Get log probability of the TARGET token
        target_log_prob = log_probs[target_token_id].item()
        token_log_probs.append(target_log_prob)
        
        # Append TARGET token to sequence for next iteration (teacher forcing)
        current_tokens = torch.cat([current_tokens, torch.tensor([[target_token_id]], device=current_tokens.device)], dim=1)
    
    # Calculate sequence log probability
    sequence_log_prob = sum(token_log_probs)
    
    return token_log_probs, sequence_log_prob

def extract_layer_type(layer_name):
    """Extract the type of layer from its name"""
    if 'hook_embed' in layer_name:
        return 'Embedding'
    elif 'ln_final' in layer_name:
        return 'Final LayerNorm'
    elif 'hook_resid_pre' in layer_name:
        return 'Residual Pre'
    elif 'hook_resid_post' in layer_name:
        return 'Residual Post'
    elif 'hook_attn_out' in layer_name:
        return 'Attention Output'
    elif 'hook_mlp_out' in layer_name:
        return 'MLP Output'
    elif 'ln1.hook_normalized' in layer_name:
        return 'LN1 Normalized'
    elif 'ln1.hook_scale' in layer_name:
        return 'LN1 Scale'
    elif 'ln2.hook_normalized' in layer_name:
        return 'LN2 Normalized'
    elif 'ln2.hook_scale' in layer_name:
        return 'LN2 Scale'
    elif 'mlp.hook_pre' in layer_name:
        return 'MLP Pre-activation'
    elif 'mlp.hook_post' in layer_name:
        return 'MLP Post-activation'
    elif 'hook_q' in layer_name:
        return 'Query'
    elif 'hook_k' in layer_name:
        return 'Key'
    elif 'hook_v' in layer_name:
        return 'Value'
    elif 'hook_z' in layer_name:
        return 'Attention Z'
    elif 'hook_attn_scores' in layer_name:
        return 'Attention Scores'
    elif 'hook_pattern' in layer_name:
        return 'Attention Pattern'
    elif 'hook_rot_q' in layer_name:
        return 'Rotary Q'
    elif 'hook_rot_k' in layer_name:
        return 'Rotary K'
    else:
        return 'Other'

all_layers = list(target_cache.keys())

# Test all layers (excluding LayerNorm for speed)
layers_to_test = [layer for layer in all_layers if '.ln' not in layer and 'ln_final' not in layer]

print(f"{'='*60}")
print("Activation Patching Analysis (Per-Step Caching)")
print(f"{'='*60}")
print(f"Testing {len(layers_to_test)} layers")
print(f"Baseline - Control model log prob: {control_seq_log_prob:.4f}")
print(f"Target model log prob: {target_seq_log_prob:.4f}")
print(f"Expected PII: '{expected_pii}'")
print(f"\nNote: Re-caches at each token prediction for accurate patching")
print(f"\n")

patching_results = []

for i, layer_name in enumerate(layers_to_test):
    # Calculate probability with patching (re-caches at each step)
    patched_token_log_probs, patched_seq_log_prob = calculate_target_pii_probability_with_patching(
        tl_control_model, tl_target_model, tokens, expected_pii, layer_name
    )
    
    # Calculate improvement
    log_prob_improvement = patched_seq_log_prob - control_seq_log_prob
    
    # Extract layer info
    layer_type = extract_layer_type(layer_name)
    layer_num = int(layer_name.split('.')[1]) if 'blocks.' in layer_name else -1
    
    patching_results.append({
        'Layer': layer_name,
        'Layer Type': layer_type,
        'Layer Number': layer_num,
        'Log Prob Improvement': log_prob_improvement,
        'Patched Log Prob': patched_seq_log_prob
    })
    
    if (i + 1) % 10 == 0:
        print(f"Processed {i + 1}/{len(layers_to_test)} layers...")

print(f"\nCompleted testing {len(layers_to_test)} layers!")

# Create DataFrame
results_df = pd.DataFrame(patching_results)
results_df = results_df.sort_values('Log Prob Improvement', ascending=False)

print(f"\n{'='*60}")
print("Top 15 Layers by Log Prob Improvement")
print(f"{'='*60}")
display(results_df[['Layer', 'Layer Type', 'Log Prob Improvement']].head(15))

print(f"\n{'='*60}")
print("Summary by Layer Type")
print(f"{'='*60}")
type_summary = results_df.groupby('Layer Type')['Log Prob Improvement'].agg(['mean', 'max', 'min', 'count']).round(4)
type_summary.columns = ['Mean', 'Max', 'Min', 'Count']
type_summary = type_summary.sort_values('Mean', ascending=False)
display(type_summary)

print(f"\n{'='*60}")
print("Summary by Layer Number")
print(f"{'='*60}")
layer_summary = results_df[results_df['Layer Number'] >= 0].groupby('Layer Number')['Log Prob Improvement'].agg(['mean', 'max']).round(4)
layer_summary.columns = ['Mean', 'Max']
display(layer_summary)

Original sequence length: 8
Comprehensive Activation Patching Analysis
Testing 99 different layers available in cache
Baseline (no patch) - Control model sequence log prob: -35.2451
Target model sequence log prob: -4.0393
Expected PII: ' 030297304'


0 Patching layer hook_embed
Patching layer hook_embed
Patching layer hook_embed
Patching layer hook_embed
Patching layer hook_embed
1 Patching layer blocks.0.hook_resid_pre
Patching layer blocks.0.hook_resid_pre
Patching layer blocks.0.hook_resid_pre
Patching layer blocks.0.hook_resid_pre
Patching layer blocks.0.hook_resid_pre
2 Patching layer blocks.0.ln1.hook_scale
Patching layer blocks.0.ln1.hook_scale
Patching layer blocks.0.ln1.hook_scale
Patching layer blocks.0.ln1.hook_scale
Patching layer blocks.0.ln1.hook_scale
Patching layer blocks.0.ln1.hook_scale
Patching layer blocks.0.ln1.hook_scale
Patching layer blocks.0.ln1.hook_scale
Patching layer blocks.0.ln1.hook_scale
Patching layer blocks.0.ln1.hook_scale
Patching layer blocks.0.ln1.

Unnamed: 0,Layer,Layer Type,Layer Number,Log Prob Improvement
81,blocks.5.hook_resid_pre,Residual Pre,5,4.832985
80,blocks.4.hook_resid_post,Residual Post,4,4.832985
65,blocks.4.hook_resid_pre,Residual Pre,4,3.886961
64,blocks.3.hook_resid_post,Residual Post,3,3.886961
37,blocks.2.attn.hook_k,Key,2,2.974463
40,blocks.2.attn.hook_rot_k,Rotary K,2,2.974463
0,hook_embed,Embedding,-1,2.858965
1,blocks.0.hook_resid_pre,Residual Pre,0,2.858965
63,blocks.3.hook_mlp_out,MLP Output,3,2.810335
79,blocks.4.hook_mlp_out,MLP Output,4,2.770432



Bottom 10 Layers (Most Harmful)


Unnamed: 0,Layer,Layer Type,Layer Number,Log Prob Improvement
53,blocks.3.attn.hook_k,Key,3,-0.209264
88,blocks.5.attn.hook_rot_k,Rotary K,5,-0.209812
85,blocks.5.attn.hook_k,Key,5,-0.209812
14,blocks.0.mlp.hook_post,MLP Post-activation,0,-0.281024
13,blocks.0.mlp.hook_pre,MLP Pre-activation,0,-0.281024
28,blocks.1.ln2.hook_normalized,LN2 Normalized,1,-0.311628
47,blocks.2.hook_mlp_out,MLP Output,2,-0.404104
45,blocks.2.mlp.hook_pre,MLP Pre-activation,2,-1.08135
46,blocks.2.mlp.hook_post,MLP Post-activation,2,-1.08135
97,ln_final.hook_scale,Final LayerNorm,-1,-1.659567



Summary by Layer Type (sorted by mean improvement)


Unnamed: 0_level_0,Mean Improvement,Max Improvement,Min Improvement,Std Dev,Count
Layer Type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Residual Pre,2.8819,4.833,1.6056,1.2509,6
Embedding,2.859,2.859,2.859,,1
Residual Post,2.7945,4.833,1.6056,1.271,6
MLP Output,1.5241,2.8103,-0.4041,1.3913,6
Attention Output,1.0075,1.7497,-0.1208,0.78,6
Attention Z,0.9517,1.9289,-0.1669,0.8623,6
LN2 Normalized,0.9408,2.2,-0.3116,0.9581,6
MLP Pre-activation,0.8776,2.6593,-1.0814,1.4632,6
MLP Post-activation,0.8776,2.6593,-1.0814,1.4632,6
LN1 Normalized,0.8554,2.2963,0.1554,0.7943,6



Summary by Layer Number


Unnamed: 0_level_0,Mean Improvement,Max Improvement,Min Improvement,Std Dev
Layer Number,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,0.9827,2.859,-0.281,0.9333
1,0.7018,2.105,-0.3116,0.8456
2,0.7351,2.9745,-1.0814,1.2669
3,1.2006,3.887,-0.2093,1.2561
4,1.1312,4.833,0.0094,1.5089
5,0.9536,4.833,-0.2098,1.5014



Analysis by Component Category


Unnamed: 0_level_0,Mean Improvement,Max Improvement,Min Improvement,Count
Component Category,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Embedding,2.859,2.859,2.859,1
Residual Stream,2.8382,4.833,1.6056,12
MLP Components,1.0931,2.8103,-1.0814,18
Attention Components,0.638,2.9745,-0.2098,42
Layer Normalization,0.4393,2.334,-1.6596,26



Best Layer in Each Major Category


Unnamed: 0,Component Category,Layer,Layer Type,Log Prob Improvement
81,Residual Stream,blocks.5.hook_resid_pre,Residual Pre,4.832985
37,Attention Components,blocks.2.attn.hook_k,Key,2.974463
0,Embedding,hook_embed,Embedding,2.858965
63,MLP Components,blocks.3.hook_mlp_out,MLP Output,2.810335
98,Layer Normalization,ln_final.hook_normalized,Final LayerNorm,2.334048



Detailed Attention Component Analysis


Unnamed: 0_level_0,Mean,Max,Min,Count
Layer Type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Attention Output,1.0075,1.7497,-0.1208,6
Attention Z,0.9517,1.9289,-0.1669,6
Key,0.6957,2.9745,-0.2098,6
Rotary K,0.6957,2.9745,-0.2098,6
Value,0.6035,1.398,-0.1532,6
Query,0.2561,0.8948,-0.0941,6
Rotary Q,0.2561,0.8948,-0.0941,6



Detailed MLP Component Analysis


Unnamed: 0_level_0,Mean,Max,Min,Count
Layer Type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
MLP Output,1.5241,2.8103,-0.4041,6
MLP Post-activation,0.8776,2.6593,-1.0814,6
MLP Pre-activation,0.8776,2.6593,-1.0814,6



Best Component in Each Transformer Block


Unnamed: 0,Layer Number,Layer,Layer Type,Log Prob Improvement
1,0,blocks.0.hook_resid_pre,Residual Pre,2.858965
32,1,blocks.1.hook_resid_post,Residual Post,2.104969
37,2,blocks.2.attn.hook_k,Key,2.974463
64,3,blocks.3.hook_resid_post,Residual Post,3.886961
80,4,blocks.4.hook_resid_post,Residual Post,4.832985
81,5,blocks.5.hook_resid_pre,Residual Pre,4.832985
