In [1]:
from datasets import load_dataset, concatenate_datasets
import torch

from sae_refusal.data import (
    load_wmdp, 
    split,
    to_instructions,
    sample_data
)
from sae_refusal.model.gemma import GemmaModel
from sae_refusal.probe.utils import generate_and_save_with_activations
from transformers import GenerationConfig
from sae_refusal import set_seed, clear_memory
from sklearn.model_selection import train_test_split

In [2]:
# --- Reproducibility and sampling ---
SEED = 42
SAMPLE_SIZE = 500  # questions taken from each WMDP subset (bio and cyber)

# --- Where generated completions / activations live on disk ---
ARTIFACT_DIR = "results"

# --- Hugging Face Hub identifiers ---
MODEL_NAME = "google/gemma-2-2b-it"
RMU_NAME = "lenguyen1807/gemma-2-2b-it-RMU"  # checkpoint loaded for generation below
EMBEDDING_NAME = "sentence-transformers/all-mpnet-base-v2"

In [3]:
# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
# "high" lets PyTorch use faster reduced-precision internals for float32 matmuls.
torch.set_float32_matmul_precision("high")
# Seed every RNG the project helper knows about before any sampling/shuffling happens.
set_seed(SEED)

### Load data

In [4]:
# Full WMDP bio/cyber question sets; with sample_size=None the lengths shown below match
# the complete subsets, so sub-sampling is done explicitly in a later cell.
wmdp_bio, wmdp_cyber = load_wmdp(
    sample_size=None,
    seed=SEED,
)
alpaca_ds = load_dataset(path="tatsu-lab/alpaca", split="train")

In [5]:
tuple(len(ds) for ds in (wmdp_bio, wmdp_cyber, alpaca_ds))

(1273, 1987, 52002)

We take only the first `SAMPLE_SIZE` (500) questions from each WMDP subset (bio and cyber) to build the probe datasets.

In [6]:
# Keep the first SAMPLE_SIZE questions of each WMDP subset. The count is capped at the
# subset length so raising SAMPLE_SIZE past a subset's size cannot make `select` fail
# with an out-of-range index.
bio_sample = wmdp_bio.select(range(min(SAMPLE_SIZE, len(wmdp_bio))))
cyber_sample = wmdp_cyber.select(range(min(SAMPLE_SIZE, len(wmdp_cyber))))

In [7]:
tuple(map(len, (bio_sample, cyber_sample)))

(500, 500)

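For harmless prompts, we randomly sample the same total number of examples (`SAMPLE_SIZE * 2`) from Alpaca, keeping only instructions whose `input` field is empty.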

In [8]:
# Sample random 1000
alpaca_sample = sample_data(
    alpaca_ds.filter(lambda example: example['input'].strip() == ''),
    sample_size=SAMPLE_SIZE * 2,
    seed=SEED,
    max_len=None,
    key="instruction"
)

In [9]:
# Harmful prompts: bio + cyber questions, shuffled with the fixed seed before conversion.
harmful_pool = concatenate_datasets([bio_sample, cyber_sample]).shuffle(seed=SEED)
harmful_total = to_instructions(harmful_pool, format_fn=lambda row: row["question"])

# Harmless prompts: the sampled Alpaca instructions, used verbatim.
harmless_total = to_instructions(alpaca_sample, format_fn=lambda row: row["instruction"])

In [10]:
# Combined prompt pool: harmful prompts first, then harmless ones.
# NOTE(review): nothing marks which prompts are harmful vs. harmless after this point, and
# the splits below are not stratified on it; confirm labels are recovered downstream.
total_dataset = harmful_total + harmless_total

In [11]:
len(total_dataset)  # expected 4 * SAMPLE_SIZE: 2 * SAMPLE_SIZE harmful + 2 * SAMPLE_SIZE harmless

2000

In [12]:
# Hold out 20% as the test split. sklearn shuffles by default but does not stratify here.
# Keep SEED fixed: the completions saved on disk were generated for exactly this split.
total_train, total_test = train_test_split(
    total_dataset,
    random_state=SEED,
    test_size=0.2,
)

In [13]:
tuple(len(split) for split in (total_train, total_test))

(1600, 400)

In [14]:
total_train, total_val = train_test_split(total_train, test_size=0.2, random_state=SEED)

In [15]:
tuple(len(split) for split in (total_train, total_val))

(1280, 320)

### Generate activations and outputs

In [16]:
# Load the RMU checkpoint (RMU_NAME) through the project's GemmaModel wrapper.
# NOTE(review): `type="instruction"` presumably selects the instruction/chat prompt format — confirm.
rmu_model = GemmaModel(RMU_NAME, type="instruction")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [17]:
# Generate completions and activations for every split and cache them under
# ARTIFACT_DIR/probes_it/<split>/. A split is skipped when its completions.json already
# exists, so Restart & Run All only pays the expensive generation cost once.
# completions.json is written last, which means an interrupted split is regenerated next run.
from pathlib import Path
import json

splits_to_generate = {"train": total_train, "val": total_val, "test": total_test}

for split_name, split_instructions in splits_to_generate.items():
    split_dir = Path(ARTIFACT_DIR) / "probes_it" / split_name
    completions_file = split_dir / "completions.json"
    if completions_file.exists():
        continue  # cached from a previous run

    split_dir.mkdir(parents=True, exist_ok=True)
    rmu_outputs, rmu_activations = generate_and_save_with_activations(
        base_model=rmu_model,
        instructions=split_instructions,
        max_new_tokens=128,
        batch_size=8,
    )
    # One file per element of the returned activations, named layer_<index>.pt.
    for idx, activations in enumerate(rmu_activations):
        torch.save(activations, split_dir / f"layer_{idx}.pt")
    with open(completions_file, mode="w", encoding="utf-8") as f:
        json.dump(rmu_outputs, f, indent=4)

    # Release this split's results before generating the next one.
    del rmu_outputs, rmu_activations

### Sanity Checks

In [23]:
from pathlib import Path
import json

# Each split's completions and activations live in ARTIFACT_DIR/probes_it/<split>/.
probes_root = Path(ARTIFACT_DIR) / "probes_it"
test_path = probes_root / "test"
train_path = probes_root / "train"
val_path = probes_root / "val"

In [24]:
def load_completions(split_dir):
    """Load the ``completions.json`` stored in one split directory.

    Args:
        split_dir: directory (``str`` or ``Path``) containing ``completions.json``.

    Returns:
        The decoded JSON content.

    The file is read as UTF-8, matching the encoding it is written with, so the result
    does not depend on the platform's default encoding.
    """
    with open(Path(split_dir) / "completions.json", "r", encoding="utf-8") as f:
        return json.load(f)


test_completions = load_completions(test_path)
train_completions = load_completions(train_path)
val_completions = load_completions(val_path)

First, check that the saved completions match the original dataset splits prompt by prompt.

In [25]:
def match_original(completions, original):
    """Check that saved completions line up one-to-one with the prompts they came from.

    Args:
        completions: sequence of dicts, each with a ``"prompt"`` key, in saved order.
        original: sequence of prompts that were fed to generation.

    Returns:
        True when both sequences have the same length and every completion's
        ``"prompt"`` equals the prompt at the same position in ``original``;
        False otherwise.
    """
    # zip() stops at the shorter input, so without this check a truncated
    # completions file (or a shorter original list) would wrongly report a match.
    if len(completions) != len(original):
        return False
    return all(
        completion["prompt"] == prompt
        for completion, prompt in zip(completions, original)
    )

In [26]:
match_original(completions=test_completions, original=total_test)

True

In [27]:
match_original(completions=val_completions, original=total_val)

True

In [28]:
match_original(completions=train_completions, original=total_train)

True

Then, check that no prompt appears in more than one split.

In [29]:
def is_overlap(completions1, completions2):
    """Report whether any prompt in ``completions1`` also occurs in ``completions2``.

    Args:
        completions1: sequence of completion dicts, each with a ``"prompt"`` key.
        completions2: collection of prompts to compare against. Completion dicts
            (with a ``"prompt"`` key) are also accepted and reduced to their prompt.

    Returns:
        True as soon as a shared prompt is found, False otherwise.
    """
    other_prompts = [
        item["prompt"] if isinstance(item, dict) else item for item in completions2
    ]
    # Hash once so each membership test is O(1) instead of a full list scan.
    try:
        lookup = set(other_prompts)
    except TypeError:  # unhashable prompts: fall back to the linear scan
        lookup = other_prompts
    return any(completion["prompt"] in lookup for completion in completions1)

In [30]:
is_overlap(completions1=train_completions, completions2=[c["prompt"] for c in val_completions])

False

In [31]:
is_overlap(
    completions1=[*train_completions, *val_completions],
    completions2=[c["prompt"] for c in test_completions],
)

False