In [6]:
import torch
import einops
from datasets import concatenate_datasets, load_from_disk
import random

from sae_refusal.model.gemma import GemmaModel
from sae_refusal.model.embedding import EmbeddingModel
from sae_refusal import set_seed, clear_memory
from sae_refusal.data import (
    load_wmdp,
    split,
    to_instructions,
    sample_data
)
from sae_refusal.probe import LinearProbe
from sae_refusal.pipeline.utils import (
    generate_and_save_completions
)
from sae_refusal.pipeline.generate_directions import generate_directions_rmu
from sae_refusal.plot import (
    plot_scores_plotly,
    plot_refusal_scores_plotly,
)
from sae_refusal.pipeline.activations import (
    get_activations,
    get_activations_pre
)

In [2]:
# --- Experiment configuration ----------------------------------------------
# Layer highlighted in the per-layer plots (passed as LAYER_ID + 1 below).
LAYER_ID = 9
# NOTE(review): SAMPLE_SIZE, VAL_SIZE and MAX_LEN are not referenced by the
# cells shown here; the train/val splits are loaded pre-made from disk.
SAMPLE_SIZE = 300  # per dataset
VAL_SIZE = 0.2
SEED = 42
MAX_LEN = 1024

# Output root; figures are written under ARTIFACT_DIR/figures (relative to CWD).
ARTIFACT_DIR = "results/ablated"

# Model identifiers (presumably Hugging Face Hub ids): base model, its RMU
# variant, and the ablated RMU variant.
MODEL_NAME, RMU_NAME, ABLATED_NAME = (
    "google/gemma-2-2b",
    "lenguyen1807/gemma-2-2b-RMU",
    "lenguyen1807/gemma-2-2b-RMU-ablated",
)

In [3]:
# Reproducibility and inference-only setup for the rest of the notebook:
# seed RNGs, turn autograd off globally, and allow faster (lower-precision)
# float32 matmuls.
set_seed(SEED)
torch.set_grad_enabled(mode=False)
torch.set_float32_matmul_precision(precision="high")

### Load models (base, RMU, ablated)

In [4]:
# Wrap all three checkpoints with identical GemmaModel settings, in order:
# base, RMU, ablated RMU.
base_model, rmu_model, ablated_model = (
    GemmaModel(checkpoint, type="none")
    for checkpoint in (MODEL_NAME, RMU_NAME, ABLATED_NAME)
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Load datasets (biology / cybersecurity train and validation splits)

In [7]:
# Load the pre-split biology / cybersecurity train and validation sets.
# They live under ARTIFACT_DIR/data, the same root the figures below are saved
# to. Resolving them from ARTIFACT_DIR (relative to the working directory)
# replaces a machine-specific absolute path (/home/ubuntu/thesis/sae/...) and
# keeps the notebook runnable from any checkout of the repo root.
DATA_DIR = f"{ARTIFACT_DIR}/data"

bio_train = load_from_disk(f"{DATA_DIR}/bio_train")
bio_val = load_from_disk(f"{DATA_DIR}/bio_val")
cyber_train = load_from_disk(f"{DATA_DIR}/cyber_train")
cyber_val = load_from_disk(f"{DATA_DIR}/cyber_val")

In [8]:
# Turn each dataset into a list of prompts taken from its "question" field.
# The two validation splits are merged into a single validation set.
bio_train_instructions, cyber_train_instructions, val_instructions = to_instructions(
    [bio_train, cyber_train, concatenate_datasets([bio_val, cyber_val])],
    lambda example: example["question"],
)

### Get activations

In [9]:
# Last-token activations of the RMU model on the biology and cyber training
# prompts, collected with get_activations_pre (presumably block-input
# activations for every layer — confirm in sae_refusal.pipeline.activations).
# Batch sizes are kept as in the original run: 8 for bio, 4 for cyber.
bio_rmu_activations, cyber_rmu_activations = (
    get_activations_pre(
        rmu_model.model,
        rmu_model.tokenizer,
        instructions,
        tokenize_instructions_fn=rmu_model.tokenize_instructions_fn,
        block_modules=rmu_model.model_block_modules,
        positions=[-1],
        batch_size=bs,
    )
    for instructions, bs in (
        (bio_train_instructions, 8),
        (cyber_train_instructions, 4),
    )
)

100%|██████████| 30/30 [00:19<00:00,  1.50it/s]
100%|██████████| 60/60 [00:39<00:00,  1.52it/s]


In [10]:
# Same collection as for the RMU model, but on the base model, so the two can
# be compared prompt by prompt. Batch sizes: 8 for bio, 4 for cyber.
bio_base_activations, cyber_base_activations = (
    get_activations_pre(
        base_model.model,
        base_model.tokenizer,
        instructions,
        tokenize_instructions_fn=base_model.tokenize_instructions_fn,
        block_modules=base_model.model_block_modules,
        positions=[-1],
        batch_size=bs,
    )
    for instructions, bs in (
        (bio_train_instructions, 8),
        (cyber_train_instructions, 4),
    )
)

100%|██████████| 30/30 [00:18<00:00,  1.59it/s]
100%|██████████| 60/60 [00:45<00:00,  1.33it/s]


In [11]:
# Same collection on the ablated RMU model.
# NOTE(review): bio uses batch_size=4 here but 8 for the RMU and base models.
# Different batching changes padding, which may add small numerical drift to
# the cross-model comparisons — consider unifying the batch size.
bio_ablated_activations, cyber_ablated_activations = (
    get_activations_pre(
        ablated_model.model,
        ablated_model.tokenizer,
        instructions,
        tokenize_instructions_fn=ablated_model.tokenize_instructions_fn,
        block_modules=ablated_model.model_block_modules,
        positions=[-1],
        batch_size=bs,
    )
    for instructions, bs in (
        (bio_train_instructions, 4),
        (cyber_train_instructions, 4),
    )
)

100%|██████████| 60/60 [00:37<00:00,  1.58it/s]
100%|██████████| 60/60 [00:40<00:00,  1.48it/s]


### Analysis

In [12]:
# Cosine similarity (over the last axis) between base and RMU activations for
# the same prompts, computed for each subset.
bio_rmu_sim, cyber_rmu_sim = (
    torch.nn.functional.cosine_similarity(base_acts, other_acts, dim=-1)
    for base_acts, other_acts in (
        (bio_base_activations, bio_rmu_activations),
        (cyber_base_activations, cyber_rmu_activations),
    )
)

In [13]:
# Cosine similarity (over the last axis) between base and ablated activations
# for the same prompts, computed for each subset.
bio_ablated_sim, cyber_ablated_sim = (
    torch.nn.functional.cosine_similarity(base_acts, other_acts, dim=-1)
    for base_acts, other_acts in (
        (bio_base_activations, bio_ablated_activations),
        (cyber_base_activations, cyber_ablated_activations),
    )
)

In [21]:
# Plot base-vs-RMU and base-vs-ablated cosine similarity for the biology and
# cybersecurity prompts; the figure is saved under ARTIFACT_DIR.
# `.squeeze(-1)` drops a trailing singleton axis.
# NOTE(review): this assumes the similarities are (n_prompts, n_layers, 1),
# i.e. one position from positions=[-1] — confirm against get_activations_pre.
plot_scores_plotly(
    scores_group={
        "Biology (RMU)": bio_rmu_sim.squeeze(-1),
        "Cybersecurity (RMU)": cyber_rmu_sim.squeeze(-1),
        "Biology (Ablated)": bio_ablated_sim.squeeze(-1),
        "Cybersecurity (Ablated)": cyber_ablated_sim.squeeze(-1),
    },
    # NOTE(review): the +1 presumably maps block-input ("pre") activations so
    # that index LAYER_ID + 1 is the output of layer LAYER_ID — confirm.
    layer_of_interest=LAYER_ID + 1,
    title="Cosine Similarity between Base and RMU/Ablated Model activations at the last token position",
    ylabel="Cosine similarity",
    artifact_dir=ARTIFACT_DIR,
    artifact_name="cosine_ablated_base",
)

Plot saved to results/ablated/figures/cosine_ablated_base.pdf


In [16]:
# L2 norm over the last axis (presumably the hidden dimension) of each model's
# activations, computed separately for biology and cyber prompts.
bio_rmu_norm, bio_base_norm, bio_ablated_norm = (
    torch.linalg.norm(acts, dim=-1)
    for acts in (bio_rmu_activations, bio_base_activations, bio_ablated_activations)
)

cyber_rmu_norm, cyber_base_norm, cyber_ablated_norm = (
    torch.linalg.norm(acts, dim=-1)
    for acts in (cyber_rmu_activations, cyber_base_activations, cyber_ablated_activations)
)

In [22]:
# Activation norms of the three models on biology prompts, one curve per model.
plot_scores_plotly(
    scores_group={
        label: norms.squeeze(-1)
        for label, norms in (
            ("RMU Model", bio_rmu_norm),
            ("Base Model", bio_base_norm),
            ("Ablated Model", bio_ablated_norm),
        )
    },
    layer_of_interest=LAYER_ID + 1,
    title="Norm of Base/RMU/Ablated Model activations at the last token position (on Biology)",
    ylabel="Norm",
    artifact_dir=ARTIFACT_DIR,
    artifact_name="bio_norm_ablated_base",
)

Plot saved to results/ablated/figures/bio_norm_ablated_base.pdf


In [23]:
# Activation norms of the three models on cybersecurity prompts, one curve per model.
plot_scores_plotly(
    scores_group={
        label: norms.squeeze(-1)
        for label, norms in (
            ("RMU Model", cyber_rmu_norm),
            ("Base Model", cyber_base_norm),
            ("Ablated Model", cyber_ablated_norm),
        )
    },
    layer_of_interest=LAYER_ID + 1,
    title="Norm of Base/RMU/Ablated Model activations at the last token position (on Cyber)",
    ylabel="Norm",
    artifact_dir=ARTIFACT_DIR,
    artifact_name="cyber_norm_ablated_base",
)

Plot saved to results/ablated/figures/cyber_norm_ablated_base.pdf
