In [1]:
from datasets import load_dataset, concatenate_datasets
import torch

from sae_refusal.data import (
    load_wmdp, 
    split,
    to_instructions,
    sample_data
)
from sae_refusal.model.gemma import GemmaModel
from sae_refusal.pipeline.activations import get_activations
from transformers import GenerationConfig
from sae_refusal import set_seed, clear_memory
from sklearn.model_selection import train_test_split

In [2]:
# Reproducibility
SEED = 42

# Number of prompts taken from each WMDP subset (bio, cyber); the harmless
# Alpaca pool is sampled at SAMPLE_SIZE * 2 to match.
SAMPLE_SIZE = 500

# Root directory for saved probe activations / completions.
ARTIFACT_DIR = "results/ablated"

# Model / embedding identifiers (Hugging Face Hub repo ids).
MODEL_NAME = "google/gemma-2-2b"
RMU_NAME = "lenguyen1807/gemma-2-2b-RMU"  # checkpoint loaded for activation extraction
EMBEDDING_NAME = "sentence-transformers/all-mpnet-base-v2"

In [3]:
set_seed(SEED)  # presumably seeds python / numpy / torch RNGs — see sae_refusal.set_seed

# Inference only: switch autograd off globally for the rest of the notebook.
torch.set_grad_enabled(mode=False)

# Allow float32 matmuls to use faster reduced-precision kernels (e.g. TF32) where supported.
torch.set_float32_matmul_precision(precision="high")

In [4]:
# Which split to extract activations for: "train", "val", or "test".
TYPE: str = "test"

### Load data

In [5]:
# WMDP bio / cyber question sets (sample_size=None — presumably the full sets; see load_wmdp).
wmdp_bio, wmdp_cyber = load_wmdp(sample_size=None, seed=SEED)

# Harmless instruction pool.
alpaca_ds = load_dataset(path="tatsu-lab/alpaca", split="train")

In [6]:
tuple(map(len, (wmdp_bio, wmdp_cyber, alpaca_ds)))  # (bio, cyber, alpaca)

(1273, 1987, 52002)

We take only the first `SAMPLE_SIZE` (500) questions from each WMDP subset (bio and cyber) to build the probe dataset.

In [7]:
# Deterministic: the first SAMPLE_SIZE rows of each WMDP subset (no shuffling here).
first_n = list(range(SAMPLE_SIZE))
bio_sample, cyber_sample = (subset.select(first_n) for subset in (wmdp_bio, wmdp_cyber))

In [8]:
tuple(map(len, (bio_sample, cyber_sample)))  # (bio, cyber)

(500, 500)

For harmless prompts, we randomly sample the same total number of Alpaca instructions (`SAMPLE_SIZE * 2`), keeping only examples whose `input` field is empty.

In [9]:
# Keep only Alpaca rows whose `input` field is blank (instruction-only prompts),
# then sample SAMPLE_SIZE * 2 of them via sample_data so the harmless pool is
# sized like the harmful one (bio + cyber).
alpaca_no_input = alpaca_ds.filter(lambda row: row["input"].strip() == "")
alpaca_sample = sample_data(
    alpaca_no_input,
    sample_size=SAMPLE_SIZE * 2,
    seed=SEED,
    max_len=None,
    key="instruction",
)

In [10]:
# Harmful prompts: bio + cyber questions, shuffled together with a fixed seed.
harmful_ds = concatenate_datasets([bio_sample, cyber_sample]).shuffle(seed=SEED)
harmful_total = to_instructions(harmful_ds, format_fn=lambda row: row["question"])

# Harmless prompts: the Alpaca instruction text.
harmless_total = to_instructions(alpaca_sample, format_fn=lambda row: row["instruction"])

In [11]:
# Harmful prompts first, then harmless ones (presumably only the formatted prompt text).
# NOTE(review): no harmful/harmless label travels with each prompt, so after the
# splits below labels can only be recovered by membership in harmful_total /
# harmless_total — confirm how downstream probe training labels examples.
total_dataset = harmful_total + harmless_total

In [12]:
# The two classes are meant to be balanced: SAMPLE_SIZE * 2 harmful (bio + cyber)
# and SAMPLE_SIZE * 2 harmless (Alpaca) prompts. Enforce it rather than eyeballing it.
assert len(harmful_total) == len(harmless_total) == 2 * SAMPLE_SIZE, (
    len(harmful_total),
    len(harmless_total),
)
len(total_dataset)

2000

In [13]:
# 80/20 split into a train+val pool and a held-out test set. Keep this call and
# random_state fixed: the sanity checks at the end compare saved completions.json
# files position-by-position against these exact splits.
# NOTE(review): not stratified, so the harmful/harmless ratio may differ slightly between splits.
total_train, total_test = train_test_split(total_dataset, test_size=0.2, random_state=SEED)

In [14]:
tuple(map(len, (total_train, total_test)))  # (train+val pool, test)

(1600, 400)

In [15]:
# Carve a validation set out of the 80% pool from the first split.
# The pool is recomputed from `total_dataset` (same call and seed as the first
# split, so the same items in the same order) instead of reusing `total_train`,
# because this cell overwrites `total_train`: re-running it would otherwise split
# the already-reduced list again and silently change every downstream split.
total_train_val = train_test_split(total_dataset, test_size=0.2, random_state=SEED)[0]
total_train, total_val = train_test_split(total_train_val, test_size=0.2, random_state=SEED)

In [16]:
tuple(map(len, (total_train, total_val)))  # (train, val)

(1280, 320)

### Generate activations and outputs

In [17]:
# Pick the split whose activations are extracted. TYPE also names the output
# folder (probes/<TYPE>), so an unknown value must fail loudly instead of
# silently falling back to the test split and saving it under the wrong name.
_splits_by_type = {"train": total_train, "val": total_val, "test": total_test}
if TYPE not in _splits_by_type:
    raise ValueError(f"TYPE must be one of {sorted(_splits_by_type)}, got {TYPE!r}")
instructions = _splits_by_type[TYPE]

In [18]:
# Load the RMU checkpoint (RMU_NAME) through the project's Gemma wrapper; its
# .model, .tokenizer, .tokenize_instructions_fn and .model_block_modules feed
# get_activations in the next cell.
# NOTE(review): the meaning of type="none" is defined by GemmaModel — confirm what it selects/disables.
rmu_model = GemmaModel(RMU_NAME, type="none")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [19]:
# Extract activations for every module in model_block_modules at position -1
# (presumably the final prompt token) for the selected split.
# NOTE(review): result presumably has shape (n_instructions, n_layers, n_positions, d_model)
# with n_positions == 1 here — confirm against get_activations.
rmu_activations = get_activations(
    model=rmu_model.model,
    tokenizer=rmu_model.tokenizer,
    instructions=instructions,
    tokenize_instructions_fn=rmu_model.tokenize_instructions_fn,
    block_modules=rmu_model.model_block_modules,
    positions=[-1],
    batch_size=8  # prompts per batch; lower this if the GPU runs out of memory
)

100%|██████████| 50/50 [00:39<00:00,  1.26it/s]


In [20]:
from pathlib import Path

# Per-split output directory: <ARTIFACT_DIR>/probes/<TYPE>
save_path = Path(ARTIFACT_DIR) / "probes" / TYPE
save_path.mkdir(exist_ok=True, parents=True)

In [21]:
# Drop axis 2 once (presumably the singleton position axis, since only
# positions=[-1] was requested), then write one file per layer.
# `.clone()` is essential: `[:, idx, :]` is a view, and torch.save serializes a
# view's *entire* underlying storage, so without it every layer_{idx}.pt file
# would contain the activations of all layers.
per_layer_acts = rmu_activations.squeeze(2)
for idx in range(rmu_model.model.config.num_hidden_layers):
    torch.save(per_layer_acts[:, idx, :].clone(), save_path / f"layer_{idx}.pt")

### Sanity Checks

In [None]:
from pathlib import Path
import json

# Each split's artifacts live under <ARTIFACT_DIR>/probes/<split>.
probes_root = Path(ARTIFACT_DIR) / "probes"
test_path = probes_root / "test"
train_path = probes_root / "train"
val_path = probes_root / "val"

In [None]:
def _load_completions(split_dir: Path) -> list:
    """Load ``completions.json`` from one split directory.

    Reads with explicit UTF-8 (the JSON interchange encoding) instead of the
    platform's default locale encoding, so prompts with non-ASCII characters
    load identically on every OS.
    """
    with open(split_dir / "completions.json", "r", encoding="utf-8") as f:
        return json.load(f)


test_completions = _load_completions(test_path)
train_completions = _load_completions(train_path)
val_completions = _load_completions(val_path)

First, check that the saved completions match the original splits (same prompts, in the same order).

In [None]:
def match_original(completions, original):
    """Check that saved completions line up one-to-one with the original split.

    Args:
        completions: iterable of dicts, each with a ``"prompt"`` key (as loaded
            from ``completions.json``).
        original: iterable of instructions, in the order they were generated.

    Returns:
        True iff both inputs have the same length and every completion's
        ``"prompt"`` equals the instruction at the same position.
    """
    from itertools import zip_longest

    # Plain zip() stops at the shorter input, so a truncated (or padded)
    # completions file would still "match". zip_longest pads with a sentinel
    # so any length mismatch is detected, while still accepting any iterable.
    missing = object()
    for completion, data in zip_longest(completions, original, fillvalue=missing):
        if completion is missing or data is missing:
            return False
        if completion["prompt"] != data:
            return False
    return True

In [None]:
match_original(completions=test_completions, original=total_test)

In [None]:
match_original(completions=val_completions, original=total_val)

In [None]:
match_original(completions=train_completions, original=total_train)

Then, check that no prompt appears in more than one split.

In [None]:
def is_overlap(completions1, completions2):
    """Return True if any prompt from ``completions1`` also appears in ``completions2``.

    Note: despite its name, ``completions2`` is searched for prompt values
    directly, so callers pass a collection of prompts rather than completion
    dicts.
    """
    return any(entry["prompt"] in completions2 for entry in completions1)

In [None]:
is_overlap(
    completions1=train_completions,
    completions2=[entry["prompt"] for entry in val_completions],
)

In [None]:
is_overlap(
    completions1=train_completions + val_completions,
    completions2=[entry["prompt"] for entry in test_completions],
)