In [1]:
import torch
import einops
from datasets import concatenate_datasets
import random

from sae_refusal.model.gemma import GemmaModel
from sae_refusal.model.embedding import EmbeddingModel
from sae_refusal import set_seed, clear_memory
from sae_refusal.data import (
    load_wmdp,
    split,
    to_instructions,
    sample_data
)
from sae_refusal.probe import LinearProbe
from sae_refusal.pipeline.utils import (
    generate_and_save_completions
)
from sae_refusal.pipeline.generate_directions import generate_directions_rmu
from sae_refusal.plot import (
    plot_scores_plotly,
    plot_refusal_scores_plotly,
)
from sae_refusal.pipeline.activations import (
    get_activations,
    get_activations_pre
)

In [2]:
# Layer index used by later cells (its exact role is not visible here — TODO confirm).
LAYER_ID = 9
# Backward-compatible alias for the original, misspelled constant name.
AYER_ID = LAYER_ID
SAMPLE_SIZE = 300  # examples drawn per WMDP subject (bio and cyber)
VAL_SIZE = 0.2  # fraction of each sample held out as the validation split
SEED = 42
MAX_LEN = 1024  # length cap forwarded to sample_data; kept small for faster computation

ARTIFACT_DIR = "results/ablated"

# Model identifiers (presumably Hugging Face Hub repo ids — confirm in the model loaders).
MODEL_NAME = "google/gemma-2-2b"
RMU_NAME = "lenguyen1807/gemma-2-2b-RMU"
ABLATED_NAME = "lenguyen1807/gemma-2-2b-RMU-ablated"

In [3]:
# Reproducibility: seed everything handled by the project helper.
set_seed(SEED)

# Global torch switches for this inference-only notebook:
# turn autograd off everywhere, and use "high" float32 matmul precision.
torch.set_grad_enabled(False)
torch.set_float32_matmul_precision("high")

### Load Data

In [4]:
# Predicate that drops questions beginning with "which".
def filter_which(data):
    """Keep an example only if its question does not start with "which".

    Intended as the predicate for ``Dataset.filter``: examples for which it
    returns False are removed.

    Args:
        data: a single example mapping with a string ``"question"`` field.

    Returns:
        bool: False when the question starts with "which" (case-insensitive),
        True otherwise.
    """
    # Return an explicit bool instead of the example / None, which only worked
    # because the caller treated the result as truthy.
    return not data["question"].lower().startswith("which")

In [5]:
# Load both WMDP subjects (sample_size=None presumably means no sub-sampling — confirm in load_wmdp).
wmdp_bio, wmdp_cyber = load_wmdp(seed=SEED, sample_size=None)

In [6]:
# Skip the first 500 rows of each subject: those are reserved for the probe.
bio_data_raw, cyber_data_raw = (
    subject.select(range(500, len(subject)))
    for subject in (wmdp_bio, wmdp_cyber)
)

In [7]:
# Drop "which ..." questions from each subject, then draw SAMPLE_SIZE examples
# per subject. max_len=MAX_LEN is a length cap applied by sample_data (presumably
# measured on the `key` field, "question" — confirm in sample_data), kept small
# for faster downstream computation.
bio_data = bio_data_raw.filter(filter_which)
cyber_data = cyber_data_raw.filter(filter_which)

# Same sampling arguments for both subjects; bio is sampled first, then cyber.
bio_sample, cyber_sample = (
    sample_data(
        subject_data,
        sample_size=SAMPLE_SIZE,
        seed=SEED,
        max_len=MAX_LEN,
        key="question",
    )
    for subject_data in (bio_data, cyber_data)
)

In [8]:
# Hold out VAL_SIZE of each subject's sample for validation.
# NOTE(review): no seed is passed to `split`; determinism presumably relies on
# the earlier set_seed(SEED) call or split's own defaults — confirm.
(bio_train, bio_val), (cyber_train, cyber_val) = (
    split(subject_sample, test_size=VAL_SIZE)
    for subject_sample in (bio_sample, cyber_sample)
)

In [9]:
# Train / validation sizes for the bio sample.
(len(bio_train), len(bio_val))

(240, 60)

In [10]:
# Leftover pool: every post-probe row that was NOT drawn into the sample.
# Rows are matched by content through a precomputed set of hashable keys.
# `x not in <Dataset>` would iterate the whole sample for every row, i.e.
# O(n*m) row conversions. A key is the repr of the row's items sorted by column
# name; column names are unique, so sorting never compares values.
# NOTE(review): the pool is built from *_data_raw, so it still contains the
# "which ..." questions removed by filter_which, and possibly rows that
# sample_data's max_len cap would exclude — confirm this is intended.
sampled_bio_keys = {repr(sorted(row.items())) for row in bio_sample}
sampled_cyber_keys = {repr(sorted(row.items())) for row in cyber_sample}

left_bio = bio_data_raw.filter(
    lambda x: repr(sorted(x.items())) not in sampled_bio_keys
)
left_cyber = cyber_data_raw.filter(
    lambda x: repr(sorted(x.items())) not in sampled_cyber_keys
)

Filter:   0%|          | 0/773 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1487 [00:00<?, ? examples/s]

In [11]:
# Rows remaining in each subject's leftover pool.
(len(left_bio), len(left_cyber))

(473, 1187)

In [12]:
# t-test evaluation sets: the first 100 leftover rows of each subject
# (taken in dataset order, not a random draw).
t_bio, t_cyber = (
    leftover.select(range(100))
    for leftover in (left_bio, left_cyber)
)

In [13]:
# Remove the t-test rows from the leftover pools so the two sets are disjoint.
# Rows are excluded by content, not by slicing off a fixed position: re-running
# this cell is then a no-op instead of dropping another 100 rows each time, and
# the trim always matches however many rows t_bio / t_cyber hold.
# Key = repr of the row's items sorted by column name (column names are unique).
t_bio_keys = {repr(sorted(row.items())) for row in t_bio}
t_cyber_keys = {repr(sorted(row.items())) for row in t_cyber}

left_bio = left_bio.filter(lambda x: repr(sorted(x.items())) not in t_bio_keys)
left_cyber = left_cyber.filter(lambda x: repr(sorted(x.items())) not in t_cyber_keys)

In [14]:
# Write every split to ARTIFACT_DIR/data/<name>, in this order.
splits_to_save = {
    "bio_train": bio_train,
    "bio_val": bio_val,
    "cyber_train": cyber_train,
    "cyber_val": cyber_val,
    "t_bio": t_bio,
    "t_cyber": t_cyber,
    "left_bio": left_bio,
    "left_cyber": left_cyber,
}
for split_name, split_dataset in splits_to_save.items():
    split_dataset.save_to_disk(f"{ARTIFACT_DIR}/data/{split_name}")

Saving the dataset (0/1 shards):   0%|          | 0/240 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/60 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/240 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/60 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/100 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/373 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1087 [00:00<?, ? examples/s]