In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops
from jaxtyping import Float, Int
from torch import Tensor

torch.set_grad_enabled(False)

# Device setup
GPU_TO_USE = 2

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = f"cuda:{GPU_TO_USE}" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

# utility to clear variables out of the memory & and clearing cuda cache
import gc
def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()

Device: cuda:2


In [2]:
from pathlib import Path

def get_data_path(data_folder, in_colab=COLAB):
  if in_colab:
    from google.colab import drive
    drive.mount('/content/drive')

    return Path(f'/content/drive/MyDrive/{data_folder}')
  else:
    return Path(f'./{data_folder}')
  
datapath = get_data_path('./data')
datapath

PosixPath('data')

In [3]:
import sys
import os

# Add the parent directory (sfc_deception) to sys.path
sys.path.append(os.path.abspath(os.path.join('..')))

In [4]:
from sae_lens import SAE, HookedSAETransformer, ActivationsStore

USE_INSTRUCT = True
PARAMS_COUNT = 9

MODEL_NAME = f'gemma-2-{PARAMS_COUNT}b' + ('-it' if USE_INSTRUCT else '')
print(f'Using {MODEL_NAME}')

model = HookedSAETransformer.from_pretrained(MODEL_NAME, device=device, dtype=torch.bfloat16)
model



Using gemma-2-9b-it


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-9b-it into HookedTransformer


HookedSAETransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-41): 42 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln1_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
      

## Loading the data

In [5]:
from classes.sfc_data_loader import SFCDatasetLoader
import utils.prompts as prompts
from utils.enums import *

In [6]:
DATASET_NAME = SupportedDatasets.COMMONSENSE_QA_FILTERED
CHOOSE_N_SAMPLES = 400

if DATASET_NAME in [SupportedDatasets.CITIES, SupportedDatasets.FACTS, SupportedDatasets.COMPANIES]:
    dataloader = SFCDatasetLoader(DATASET_NAME, model, 
                                  clean_system_prompt=prompts.TRUTH_OR_USER_KILLED,
                                  corrupted_system_prompt=prompts.LIE_OR_USER_KILLED,
                                  task_prompt=prompts.ANSWER_TRUE_FALSE,
                                  local_dataset=True, base_folder_path=datapath,
                                  )
elif DATASET_NAME in [SupportedDatasets.COMMONSENSE_QA, SupportedDatasets.COMMONSENSE_QA_FILTERED]:
    dataloader = SFCDatasetLoader(DATASET_NAME, model, 
                                clean_system_prompt=prompts.TRUTH_OR_USER_KILLED,
                                corrupted_system_prompt=prompts.LIE_OR_USER_KILLED,
                                task_prompt=prompts.OUTPUT_SINGLE_LETTER,
                                num_samples=CHOOSE_N_SAMPLES)
elif DATASET_NAME in [SupportedDatasets.VERB_AGREEMENT]:
    dataloader = SFCDatasetLoader(DATASET_NAME, model,
                                  local_dataset=True, base_folder_path=datapath)
else:
    raise ValueError(f"Dataset {DATASET_NAME.value} is not supported")

In [7]:
if DATASET_NAME in [SupportedDatasets.VERB_AGREEMENT]:
    clean_dataset, corrupted_dataset = dataloader.get_clean_corrupted_datasets(tokenize=True, apply_chat_template=False, prepend_generation_prefix=True)
else:
    clean_dataset, corrupted_dataset = dataloader.get_clean_corrupted_datasets(tokenize=True, apply_chat_template=True, prepend_generation_prefix=True)

Figuring out optimal padding length...
Filtered out 4 longest prompts from a total of 400 prompts.
Setting max prompt length to 171


100%|██████████| 396/396 [00:00<00:00, 442.98it/s]


In [8]:
CONTROL_SEQ_LEN = clean_dataset['control_sequence_length'][0].item()
N_CONTEXT = clean_dataset['prompt'].shape[1]

CONTROL_SEQ_LEN, N_CONTEXT

(4, 171)

In [9]:
def sample_dataset(start_idx=0, end_idx=-1, clean_dataset=None, corrupted_dataset=None):
    assert clean_dataset is not None or corrupted_dataset is not None, 'At least one dataset must be provided.'
    return_values = []

    for key in ['prompt', 'answer', 'answer_pos', 'attention_mask', 'special_token_mask']:
        if clean_dataset is not None:
            return_values.append(clean_dataset[key][start_idx:end_idx])
        if corrupted_dataset is not None:
            return_values.append(corrupted_dataset[key][start_idx:end_idx])

    return return_values

## Ablation utils

In [10]:
def get_answer_logit(logits: Float[Tensor, "batch pos d_vocab"], 
                     clean_answers: Int[Tensor, "batch"],
                     ansnwer_pos: Int[Tensor, "batch"], return_all_logits=False) -> Float[Tensor, "batch"]:

    answer_pos_idx = einops.repeat(ansnwer_pos, 'batch -> batch 1 d_vocab',
                                   d_vocab=logits.shape[-1])
    answer_logits = logits.gather(1, answer_pos_idx).squeeze(1) # shape [batch, d_vocab]

    correct_logits = answer_logits.gather(1, clean_answers.unsqueeze(1)).squeeze(1) # shape [batch]

    if return_all_logits:
        return answer_logits, correct_logits

    return correct_logits

def get_incorrect_logits(logits: Float[Tensor, "batch pos d_vocab"],
                         patched_answers: Int[Tensor, "batch count"],
                         answer_pos: Int[Tensor, "batch"], patch_answer_reduce='max') -> Float[Tensor, "batch"]:
    
    answer_pos_idx = einops.repeat(answer_pos, 'batch -> batch 1 d_vocab',
                                   d_vocab=logits.shape[-1])
    answer_logits = logits.gather(1, answer_pos_idx).squeeze(1) # shape [batch, d_vocab]

    incorrect_logits = answer_logits.gather(1, patched_answers)  # shape [batch, answer_count]

    # Sum the logits for each incorrect answer option
    if patch_answer_reduce == 'sum':
        incorrect_logits = incorrect_logits.sum(dim=1)
    # Or take their maximum: this should be a better option to avoid situations where the model outputs gibberish and all the answers have similar logits
    elif patch_answer_reduce == 'max':
        incorrect_logits = incorrect_logits.max(dim=1).values

    return incorrect_logits

In [11]:
def ablate_hook(act, hook):
    act[:] = 0
    return act

# Define a function to run ablation analysis
def run_ablation_batch_analysis(
    batch_size=64,
    hook_types: list[str] = ["hook_mlp_out", "attn.hook_z"],
    n_layers = model.cfg.n_layers,
    model=model,
    clean_dataset=clean_dataset,
    corrupted_dataset=corrupted_dataset,
    clean_prompt_run=True,
):
    num_samples = len(clean_dataset['prompt']) # [batch_size, seq_len]
    total_batches = num_samples // batch_size

    if num_samples % batch_size != 0:
        total_batches += 1

    # For each hook type, the differences tensor will contain 2 elements for each layer:
    # difference in the correct logit and difference in the max incorrect logit
    differences_dict = {hook: torch.zeros(n_layers, 2, device=model.cfg.device) for hook in hook_types}
    
    for i in tqdm(range(0, num_samples, batch_size)):
        clean_prompts, corrupted_prompts, clean_answers, corrupted_answers, clean_answers_pos, corrupted_answers_pos, \
        clean_attn_mask, corrupted_attn_mask, clean_special_mask, corr_special_mask = sample_dataset(i, i + batch_size, clean_dataset, corrupted_dataset)
        if clean_prompt_run:
            prompts = clean_prompts
            answers_pos = clean_answers_pos
            attn_mask = clean_attn_mask
        else:
            prompts = corrupted_prompts
            answers_pos = corrupted_answers_pos
            attn_mask = corrupted_attn_mask

        # Get baseline logits
        baseline_logits = model(prompts, attention_mask=attn_mask)  # [batch_size, seq_len, n_tokens]

        baseline_correct_logits = get_answer_logit(baseline_logits, clean_answers, answers_pos)
        baseline_incorrect_logits = get_incorrect_logits(baseline_logits, corrupted_answers, answers_pos)

        print(f'Mean incorrect - correct difference: {baseline_incorrect_logits.mean() - baseline_correct_logits.mean()}')

        clear_cache()
        # Run ablations for each layer and hook type
        for layer in tqdm(range(n_layers)):
            for hook_type in hook_types:
                hook_name = f"blocks.{layer}.{hook_type}"

                ablated_logits = model.run_with_hooks(prompts, attention_mask=attn_mask, 
                                                     fwd_hooks=[(hook_name, ablate_hook)]  # [batch_size, seq_len, n_tokens]
                                                     )   # [batch_size, seq_len, n_tokens]
                
                ablated_correct_logits = get_answer_logit(ablated_logits, clean_answers, answers_pos)
                ablated_incorrect_logits = get_incorrect_logits(ablated_logits, corrupted_answers, answers_pos)
                
                # Calculate differences from baseline
                correct_logit_dif = baseline_correct_logits - ablated_correct_logits  # [batch_size, seq_len, n_tokens]
                incorrect_logit_dif = baseline_incorrect_logits - ablated_incorrect_logits  # [batch_size, seq_len, n_tokens]

                # Store the mean difference over the batch
                differences_dict[hook_type][layer, 0] += correct_logit_dif.mean()
                differences_dict[hook_type][layer, 1] += incorrect_logit_dif.mean()
            clear_cache()

    # Finally, compute the mean over the outer batches
    for hook_type in hook_types:
        differences_dict[hook_type] /= total_batches

    return differences_dict

## Ablation analysis

In [12]:
N_LAYERS = model.cfg.n_layers
BATCH_SIZE = 87
HOOK_TYPES = ["hook_mlp_out", "attn.hook_z"]

clear_cache()

### Clean prompts

In [13]:
ablation_difs_dict = run_ablation_batch_analysis(batch_size=BATCH_SIZE, n_layers=N_LAYERS)
ablation_difs_dict['hook_mlp_out'].shape

  0%|          | 0/5 [00:00<?, ?it/s]

Mean incorrect - correct difference: -4.75


100%|██████████| 42/42 [01:57<00:00,  2.79s/it]
 20%|██        | 1/5 [01:59<07:56, 119.03s/it]

Mean incorrect - correct difference: -4.625


100%|██████████| 42/42 [01:58<00:00,  2.83s/it]
 40%|████      | 2/5 [03:59<05:59, 119.89s/it]

Mean incorrect - correct difference: -4.75


100%|██████████| 42/42 [01:59<00:00,  2.84s/it]
 60%|██████    | 3/5 [06:00<04:00, 120.32s/it]

Mean incorrect - correct difference: -5.75


100%|██████████| 42/42 [01:59<00:00,  2.84s/it]
 80%|████████  | 4/5 [08:01<02:00, 120.50s/it]

Mean incorrect - correct difference: -5.75


100%|██████████| 42/42 [01:09<00:00,  1.65s/it]
100%|██████████| 5/5 [09:11<00:00, 110.26s/it]


torch.Size([42, 2])

In [14]:
results = []

for hook_type in HOOK_TYPES:
    for layer in range(N_LAYERS):
        result = {
            'layer': layer,
            'hook_type': hook_type,
            'logit_diffs': ablation_difs_dict[hook_type][layer].cpu().numpy(),
        }
        results.append(result)

results[0]

{'layer': 0,
 'hook_type': 'hook_mlp_out',
 'logit_diffs': array([ 0.31054688, -0.56328124], dtype=float32)}

In [17]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def plot_logit_diffs(data, title_suffix=""):
    # Group by hook_type
    grouped_data = {}
    for entry in data:
        hook_type = entry['hook_type']
        if hook_type not in grouped_data:
            grouped_data[hook_type] = []
        grouped_data[hook_type].append(entry)

    # Define color scheme for consistency
    correct_color = 'blue'
    incorrect_color = 'red'

    # Create subplots without shared x-axes
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=False,  # Disable shared x-axes
        subplot_titles=(
            "hook_mlp_out Logit Differences",
            "attn.hook_z Logit Differences"
        )
    )

    # Plot for hook_type = 'hook_mlp_out'
    if 'hook_mlp_out' in grouped_data:
        layers = [entry['layer'] for entry in grouped_data['hook_mlp_out']]
        logit_diffs_correct = [entry['logit_diffs'][0] for entry in grouped_data['hook_mlp_out']]
        logit_diffs_incorrect = [entry['logit_diffs'][1] for entry in grouped_data['hook_mlp_out']]

        fig.add_trace(
            go.Bar(
                name="Correct Answer Logit-Dif",
                x=layers,
                y=logit_diffs_correct,
                marker_color=correct_color  # Apply consistent color
            ),
            row=1, col=1
        )
        fig.add_trace(
            go.Bar(
                name="Max(Incorrect) Answer Logit-Dif",
                x=layers,
                y=logit_diffs_incorrect,
                marker_color=incorrect_color  # Apply consistent color
            ),
            row=1, col=1
        )

    # Plot for hook_type = 'attn.hook_z'
    if 'attn.hook_z' in grouped_data:
        layers = [entry['layer'] for entry in grouped_data['attn.hook_z']]
        logit_diffs_correct = [entry['logit_diffs'][0] for entry in grouped_data['attn.hook_z']]
        logit_diffs_incorrect = [entry['logit_diffs'][1] for entry in grouped_data['attn.hook_z']]

        fig.add_trace(
            go.Bar(
                name="Correct Answer Logit-Dif",
                x=layers,
                y=logit_diffs_correct,
                marker_color=correct_color,  # Consistent color
                showlegend=False  # Hide duplicate legend entry
            ),
            row=2, col=1
        )
        fig.add_trace(
            go.Bar(
                name="Max(Incorrect) Answer Logit-Dif",
                x=layers,
                y=logit_diffs_incorrect,
                marker_color=incorrect_color,  # Consistent color
                showlegend=False  # Hide duplicate legend entry
            ),
            row=2, col=1
        )

    # Update layout
    fig.update_layout(
        height=800,  # Adjusted height for visibility
        width=1400,
        title_text=f"Ablation logit differences: {title_suffix}",
        barmode='group'  # Group bars side by side for comparison
    )

    # Update x-axes to show all layer ticks and add x-axis labels
    # For the first subplot
    fig.update_xaxes(
        title_text="Layer",
        tickmode='linear',
        dtick=1,  # Show every tick
        row=1, col=1
    )

    # For the second subplot
    fig.update_xaxes(
        title_text="Layer",
        tickmode='linear',
        dtick=1,  # Show every tick
        row=2, col=1
    )

    # Update y-axes titles
    fig.update_yaxes(title_text="Logit Difference<br>(Baseline - Ablated)", row=1, col=1)
    fig.update_yaxes(title_text="Logit Difference<br>(Baseline - Ablated)", row=2, col=1)

    # Optionally, adjust the range of x-axes if layers are from 1 to 42
    fig.update_xaxes(range=[1, N_LAYERS], row=1, col=1)
    fig.update_xaxes(range=[1, N_LAYERS], row=2, col=1)

    # Display the figure
    fig.show()
    html_name = title_suffix.replace(" ", "_")
    fig.write_html(f'ablation_{html_name}.html')


In [18]:

# Example usage:
plot_logit_diffs(results, title_suffix='Clean setting')

### Corrupted prompts

In [19]:
clear_cache()

In [20]:
ablation_difs_dict = run_ablation_batch_analysis(batch_size=BATCH_SIZE, n_layers=N_LAYERS, clean_prompt_run=False)
ablation_difs_dict['hook_mlp_out'].shape

  0%|          | 0/5 [00:00<?, ?it/s]

Mean incorrect - correct difference: 1.25


100%|██████████| 42/42 [01:57<00:00,  2.79s/it]
 20%|██        | 1/5 [01:58<07:55, 118.94s/it]

Mean incorrect - correct difference: 2.0


100%|██████████| 42/42 [01:59<00:00,  2.84s/it]
 40%|████      | 2/5 [03:59<06:00, 120.09s/it]

Mean incorrect - correct difference: 2.125


100%|██████████| 42/42 [01:59<00:00,  2.85s/it]
 60%|██████    | 3/5 [06:01<04:01, 120.64s/it]

Mean incorrect - correct difference: 1.875


100%|██████████| 42/42 [01:59<00:00,  2.85s/it]
 80%|████████  | 4/5 [08:02<02:00, 120.89s/it]

Mean incorrect - correct difference: 1.25


100%|██████████| 42/42 [01:09<00:00,  1.66s/it]
100%|██████████| 5/5 [09:13<00:00, 110.61s/it]


torch.Size([42, 2])

In [21]:
results = []

for hook_type in HOOK_TYPES:
    for layer in range(N_LAYERS):
        result = {
            'layer': layer,
            'hook_type': hook_type,
            'logit_diffs': ablation_difs_dict[hook_type][layer].cpu().numpy(),
        }
        results.append(result)

plot_logit_diffs(results, title_suffix='Corrupted setting')