In [1]:
from datasets import load_dataset, concatenate_datasets
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import einops
import random
import pandas as pd
from IPython.display import display, Markdown
import matplotlib.pyplot as plt
import json
from tqdm.auto import tqdm

from sae_refusal.model.gemma import GemmaModel
from sae_refusal.pipeline.utils import (
    calculate_bigram_repetition
)
from sae_refusal.model.embedding import EmbeddingModel
from sae_refusal import set_seed, clear_memory
from sae_refusal.pipeline.activations import (
    get_activations
)
from sae_refusal.data import (
    load_wmdp,
    split,
    to_instructions,
)
from sae_refusal.data import (
    sample_data, 
    to_instructions
)
from sae_refusal.plot import (
     plot_scores_plotly
)
from sae_refusal.probe import LinearProbe
from sae_refusal.pipeline.generate_directions import generate_directions_rmu
import plotly.express as px
import plotly.graph_objects as go

In [2]:
# --- Experiment configuration -------------------------------------------------
SEED = 42

# Layer highlighted in the final plot (forwarded as `layer_of_interest=LAYER_ID + 1`).
LAYER_ID = 9

# Number of leading WMDP rows skipped per subset ("avoid probe training data").
DATA_SIZE = 500
# Questions drawn per subset by `sample_data`.
SAMPLE_SIZE = 300
# Fraction handed to `split` for the validation part.
VAL_SIZE = 0.2

ARTIFACT_DIR = "results/ablated"

# NOTE(review): not referenced anywhere else in this notebook.
EMBEDDING_NAME = "sentence-transformers/all-mpnet-base-v2"

In [3]:
# Checkpoint table: (model path, chat/prompt "type" passed to GemmaModel).
# Indices 0 and 1 (instruction-tuned base and its RMU variant) are the ones loaded below.
MODELS = [
    {"model_name_or_path": checkpoint, "type": prompt_type}
    for checkpoint, prompt_type in (
        ("google/gemma-2-2b-it", "instruction"),
        ("lenguyen1807/gemma-2-2b-it-RMU", "instruction"),
        ("google/gemma-2-2b", "none"),
        ("lenguyen1807/gemma-2-2b-RMU", "none"),
    )
]

In [4]:
# Notebook-wide runtime settings: seeding, no autograd (forward passes only),
# and the fp32 matmul precision mode.
set_seed(SEED)
torch.set_grad_enabled(False)
torch.set_float32_matmul_precision(precision="high")

In [5]:
# Batch size forwarded to `get_activations`.
BATCH_SIZE = 32
# Only consulted by the plotting cell (title); 0 selects the "last token" wording.
NEW_TOKEN = 0

### Load the base and RMU models

In [6]:
# Instruction-tuned Gemma checkpoint (MODELS[0]) and its RMU counterpart (MODELS[1]).
base_spec, rmu_spec = MODELS[0], MODELS[1]
base_model = GemmaModel(**base_spec)
rmu_model = GemmaModel(**rmu_spec)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
# Load the trained linear probe from the `best_probes` directory.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
probe_dir = "/home/ubuntu/thesis/sae/results/ablated/probes/best_probes"
# map_location="cpu": the probe is built on CPU and is fed `.cpu()` activations,
# so load the weights there regardless of the device they were saved from.
probe_weight = torch.load(f"{probe_dir}/weight.pt", map_location="cpu", weights_only=True)
probe = LinearProbe(base_model.model.config.hidden_size)
# Inference only: make sure any train-time layers (dropout, etc.) are disabled.
probe.eval()
probe.load_state_dict(probe_weight)

<All keys matched successfully>

In [8]:
# Normalisation statistics stored next to the probe weights. The scoring cells
# compute (activation.cpu() - mean) / (std + 1e-8), so these must live on CPU
# even if they were saved from a GPU run -> map_location="cpu".
# NOTE(review): absolute, machine-specific path — consider making it configurable.
probe_stats_dir = "/home/ubuntu/thesis/sae/results/ablated/probes/best_probes"
mean = torch.load(f"{probe_stats_dir}/mean.pt", map_location="cpu", weights_only=True)
std = torch.load(f"{probe_stats_dir}/std.pt", map_location="cpu", weights_only=True)

### Load dataset

In [9]:
# WMDP bio and cyber question sets; sample_size=None presumably means "no subsampling" — confirm.
wmdp_bio, wmdp_cyber = load_wmdp(seed=SEED, sample_size=None)

In [10]:
def _drop_leading_rows(dataset, n_skip):
    """Return `dataset` without its first `n_skip` rows, via `Dataset.select`."""
    kept_indices = list(range(n_skip, len(dataset)))
    return dataset.select(kept_indices)


# Skip the first DATA_SIZE rows of each subset to avoid probe training data.
# NOTE(review): assumes the probe was trained on exactly these leading rows, in the
# same order `load_wmdp` returns them here — confirm.
bio_select = _drop_leading_rows(wmdp_bio, DATA_SIZE)
cyber_select = _drop_leading_rows(wmdp_cyber, DATA_SIZE)

In [11]:
# Draw SAMPLE_SIZE questions (seeded with SEED) from each held-out subset.
bio_sample, cyber_sample = (
    sample_data(subset, sample_size=SAMPLE_SIZE, seed=SEED, key="question")
    for subset in (bio_select, cyber_select)
)

In [12]:
# Split each sample with VAL_SIZE; only the validation parts feed the activation cells below.
(bio_train, bio_val), (cyber_train, cyber_val) = (
    split(sampled, VAL_SIZE) for sampled in (bio_sample, cyber_sample)
)

In [13]:
# Pool the validation questions of both subsets (bio first, then cyber).
combined_val = concatenate_datasets([bio_val, cyber_val])


def _question_text(row):
    """Extract the prompt text used as the instruction."""
    return row["question"]


cyber_train_instructions, bio_train_instructions, val_instructions = to_instructions(
    [cyber_train, bio_train, combined_val],
    _question_text,
)

### Get activations

In [14]:
# Base-model activations at each module in `model_block_modules`, taken at the
# final token position only (positions=[-1]).
base_activations = (
    get_activations(
        base_model.model,
        base_model.tokenizer,
        val_instructions,
        batch_size=BATCH_SIZE,
        positions=[-1],
        block_modules=base_model.model_block_modules,
        tokenize_instructions_fn=base_model.tokenize_instructions_fn,
    )
    # Drop the singleton axis (presumably the single requested position) so the
    # result is indexed as [example, layer, hidden] by the scoring cells.
    .squeeze(dim=2)
)

100%|██████████| 4/4 [00:04<00:00,  1.03s/it]


In [15]:
# RMU-model activations at each module in `model_block_modules`, taken at the
# final token position only (positions=[-1]).
rmu_activations = (
    get_activations(
        rmu_model.model,
        rmu_model.tokenizer,
        val_instructions,
        batch_size=BATCH_SIZE,
        positions=[-1],
        block_modules=rmu_model.model_block_modules,
        tokenize_instructions_fn=rmu_model.tokenize_instructions_fn,
    )
    # Drop the singleton axis (presumably the single requested position) so the
    # result is indexed as [example, layer, hidden] by the scoring cells.
    .squeeze(dim=2)
)

100%|██████████| 4/4 [00:03<00:00,  1.16it/s]


In [16]:
# Probe score for every layer of the RMU model: cast to fp32 on CPU, normalise
# with the stored mean/std, then apply the probe. Result: list with one entry per layer.
# NOTE(review): the same mean/std is applied at every layer — confirm this matches
# how the probe's statistics were computed.
rmu_scores = [
    probe((rmu_activations[:, layer_idx, :].to(torch.float32).cpu() - mean) / (std + 1e-8))
    for layer_idx in range(rmu_model.model.config.num_hidden_layers)
]

In [17]:
# Probe score for every layer of the base model: cast to fp32 on CPU, normalise
# with the stored mean/std, then apply the probe. Result: list with one entry per layer.
base_scores = []
n_layers = base_model.model.config.num_hidden_layers
for layer_idx in range(n_layers):
    normalized = (base_activations[:, layer_idx, :].to(torch.float32).cpu() - mean) / (std + 1e-8)
    base_scores.append(probe(normalized))

In [18]:
# Stack the per-layer score list into one tensor with layers along dim 1.
# Guarded so re-running this cell (when `rmu_scores` is already a tensor) is a
# no-op instead of trying to re-stack an already-stacked tensor.
if not torch.is_tensor(rmu_scores):
    rmu_scores = torch.stack(rmu_scores, dim=1)

In [19]:
# Stack the per-layer score list into one tensor with layers along dim 1.
# Guarded so re-running this cell (when `base_scores` is already a tensor) is a
# no-op instead of trying to re-stack an already-stacked tensor.
if not torch.is_tensor(base_scores):
    base_scores = torch.stack(base_scores, dim=1)

In [None]:
# Removed a stray debugging expression (`rmu_model.model.lm_head`); this cell was
# never executed and nothing downstream depends on it.

### Visualization

In [21]:
# Plot per-layer probe ("junk") scores for the RMU and base models and save the figure.
# NOTE(review): both activation cells read positions=[-1] (last prompt token) only;
# NEW_TOKEN affects just the title and file name here — wire it into the
# activation extraction before relying on the NEW_TOKEN != 0 path.
token_tag = f"new{NEW_TOKEN}" if NEW_TOKEN != 0 else "last"
plot_title = (
    f"Junk Score between Base and RMU model activations on {NEW_TOKEN} new tokens"
    if NEW_TOKEN != 0
    else "Junk Score between Base and RMU model activations on last token"
)
plot_scores_plotly(
    scores_group={
        "RMU Model": rmu_scores,
        "Base Model": base_scores,
    },
    layer_of_interest=LAYER_ID + 1,
    title=plot_title,
    ylabel="Junk Score",
    artifact_dir=ARTIFACT_DIR,
    # Distinct file per token setting so runs don't overwrite each other;
    # NEW_TOKEN == 0 keeps the original "junk_score_it_last" name.
    artifact_name=f"junk_score_it_{token_tag}",
)

Plot saved to results/ablated/figures/junk_score_it_last.pdf
