In [None]:
import torch
import einops
from datasets import concatenate_datasets, load_dataset, Dataset
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.sentence as nas
import random
from sklearn.model_selection import train_test_split 

from sae_refusal.model.gemma import GemmaModel
from sae_refusal.model.embedding import EmbeddingModel
from sae_refusal import set_seed, clear_memory
from sae_refusal.data import (
    load_wmdp,
    split,
    to_instructions,
    sample_data,
    load_harmbench,
    load_adv_malicious_instruct
)
from sae_refusal.pipeline.select_directions import select_direction
from sae_refusal.pipeline.utils import compute_pca, generate_and_save_completions
from sae_refusal.probe import LinearProbe
from sae_refusal.pipeline.generate_directions import generate_directions_rmu
from sae_refusal.pipeline.hook import (
    get_all_direction_ablation_hooks,
    get_all_direction_ablation_hooks_rmu,
    get_activation_addition_input_pre_hook,
    get_activation_addition_input_pre_hook_rmu
)
from sae_refusal.plot import (
    plot_scores_plotly,
    plot_refusal_scores_plotly,
)
from sae_refusal.pipeline.activations import (
    get_activations,
    get_activations_pre
)
from IPython.display import display, Markdown


In [None]:
# Experiment configuration: everything a reader might want to tweak.

# Reproducibility
SEED = 42

# Models being compared (base model and its RMU variant)
MODEL_NAME = "google/gemma-2-2b"
RMU_NAME = "lenguyen1807/gemma-2-2b-RMU"

# Data
SAMPLE_SIZE = 300  # For each dataset: target prompt count after augmentation
VAL_SIZE = 0.2     # fraction of each domain held out for direction selection
MAX_LEN = 1024     # NOTE(review): not referenced by the visible cells

# Evaluation: the junk score is read at layer LAYER_ID + 1; outputs go under ARTIFACT_DIR.
LAYER_ID = 9
ARTIFACT_DIR = "results/ablated"
RMU_NAME = "lenguyen1807/gemma-2-2b-RMU"

In [None]:
# Seed everything first, then switch torch to inference-only mode:
# autograd off globally, and "high" float32 matmul precision (TF32 allowed).
set_seed(SEED)
torch.set_grad_enabled(mode=False)
torch.set_float32_matmul_precision(precision="high")

### Load model

In [None]:
# Base model and its RMU variant, wrapped with identical settings (base first, then RMU).
base_model, rmu_model = (GemmaModel(name, type="none") for name in (MODEL_NAME, RMU_NAME))

In [None]:
# Restore the trained linear probe; its input width is the base model's hidden size.
# NOTE(review): absolute, machine-specific path; consider deriving it from ARTIFACT_DIR.
probe_weight = torch.load(
    "/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/weight.pt",
    weights_only=True,
)
probe = LinearProbe(base_model.model.config.hidden_size)
probe.load_state_dict(probe_weight)

In [None]:
# Normalisation statistics passed to select_direction as probe_mean / probe_std.
# Loaded in order: mean, then std.
mean, std = (
    torch.load(path, weights_only=True)
    for path in (
        "/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/mean.pt",
        "/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/std.pt",
    )
)

### Load dataset

#### 1. Load training datasets

In [None]:
# HarmBench prompts, returned as a (bio, cyber) pair.
(harmbench_bio, harmbench_cyber) = load_harmbench()

In [None]:
# Adversarial MaliciousInstruct prompts, returned as a (bio, cyber) pair.
(adv_bio, adv_cyber) = load_adv_malicious_instruct()

In [None]:
# CatHarmfulQA (English split), keeping only two subcategories; it is pooled into the
# bio set below. The category columns are dropped, presumably so the schema matches
# the other bio datasets for concatenate_datasets.
catqa = (
    load_dataset("walledai/CatHarmfulQA", split="en")
    .filter(lambda row: row["subcategory"] in ["Hazardous material handling", "Drug"])
    .remove_columns(["category", "subcategory"])
)

In [None]:
# Pool every source per domain:
#   bio   = HarmBench + CatHarmfulQA + adversarial
#   cyber = HarmBench + adversarial
total_bio, total_cyber = (
    concatenate_datasets([harmbench_bio, catqa, adv_bio]),
    concatenate_datasets([harmbench_cyber, adv_cyber]),
)

In [None]:
# Augment text
synonym_aug = naw.SynonymAug(aug_src="wordnet")  # synonym replacement
contextual_aug = naw.ContextualWordEmbsAug(
    model_path="bert-base-uncased", action="substitute", device="cuda"
)
back_translation_aug = naw.BackTranslationAug(
    from_model_name="facebook/wmt19-en-de",
    to_model_name="facebook/wmt19-de-en",
    device="cuda",
)

In [None]:
# Pool that augment_text picks from with random.choice.
augmenters = [
    synonym_aug,
    contextual_aug,
    back_translation_aug,
]

In [None]:
def augment_text(texts, augmenters, max_time=3, target_size=SAMPLE_SIZE, max_attempts=None):
    """Grow ``texts`` to ``target_size`` entries by appending randomly augmented variants.

    Repeatedly picks a random source prompt and a random augmenter, and appends the
    result if it is not already present. Each source prompt is augmented at most
    ``max_time`` times. Randomness comes from the global ``random`` module, so the
    output depends on its seed.

    Args:
        texts: Sequence of source prompts (strings). Copied into a list; never mutated.
        augmenters: Objects exposing ``.augment(text)``. The result may be a string or a
            list of strings; lists are joined with a single space.
        max_time: Maximum number of accepted augmentations per source prompt.
        target_size: Desired total size (originals plus augmentations).
        max_attempts: Upper bound on augmenter calls. Defaults to ``10 * target_size``.
            Prevents an endless loop when augmenters keep failing or keep producing
            duplicates.

    Returns:
        List of strings: the originals followed by the new variants. Shorter than
        ``target_size`` if every source reached ``max_time`` or the attempt budget ran
        out. Unchanged originals if ``texts`` already has ``target_size`` entries or more.
    """
    sources = list(texts)
    augmented = list(sources)
    # Nothing to sample from (or nothing allowed): the loop below could never progress.
    if not sources or not augmenters or max_time <= 0:
        return augmented

    if max_attempts is None:
        max_attempts = 10 * max(target_size, 1)

    counts = [0] * len(sources)
    seen = set(augmented)  # O(1) duplicate check instead of scanning the list
    exhausted = 0  # number of sources that have reached max_time
    attempts = 0

    # Terminates once the target is met, every source is exhausted, or the budget is spent.
    while len(augmented) < target_size and exhausted < len(sources) and attempts < max_attempts:
        idx = random.randrange(len(sources))
        if counts[idx] >= max_time:
            # This source is already augmented enough; draw again.
            continue

        attempts += 1
        aug = random.choice(augmenters)
        try:
            new_text = aug.augment(sources[idx])
        except Exception:
            # Best effort: one augmenter failing on one prompt should not abort the run.
            continue

        # Normalise to str so duplicates of the originals are detected too.
        if isinstance(new_text, list):
            new_text = " ".join(new_text)

        if new_text not in seen:
            seen.add(new_text)
            augmented.append(new_text)
            counts[idx] += 1
            if counts[idx] >= max_time:
                exhausted += 1

    return augmented

In [None]:
# Expand each domain to SAMPLE_SIZE prompts.
# list(...) materialises the column into a plain, copyable list. Newer `datasets`
# releases may return a lazy Column object for ds["prompt"], and that object lacks .copy().
bio_augment = augment_text(list(total_bio["prompt"]), augmenters=augmenters)
cyber_augment = augment_text(list(total_cyber["prompt"]), augmenters=augmenters)

In [None]:
# `.augment` can return a list of strings; collapse any such entry into one string.
bio_augment, cyber_augment = (
    [item if not isinstance(item, list) else " ".join(item) for item in prompts]
    for prompts in (bio_augment, cyber_augment)
)

In [None]:
# Persist the augmented prompts, one per line, for inspection and reuse.
# Explicit UTF-8 lets back-translated text containing non-ASCII characters be written
# on any platform's default locale. Mode "w" (not "w+") because the handle is never read.
# NOTE(review): a prompt containing "\n" would span several lines in these files.
# NOTE(review): absolute, machine-specific paths.
for out_path, prompts in (
    ("/home/ubuntu/thesis/sae/results/ablated/data/bio_augment.txt", bio_augment),
    ("/home/ubuntu/thesis/sae/results/ablated/data/cyber_augment.txt", cyber_augment),
):
    with open(out_path, "w", encoding="utf-8") as f:
        for txt in prompts:
            f.write(txt + "\n")

In [None]:
# Hold out VAL_SIZE of each domain for direction selection.
# A fixed random_state keeps the split reproducible. Otherwise sklearn uses numpy's
# global RNG, whose state depends on everything the earlier cells consumed.
bio_train_instructions, bio_val_instructions = train_test_split(
    bio_augment, test_size=VAL_SIZE, random_state=SEED
)
cyber_train_instructions, cyber_val_instructions = train_test_split(
    cyber_augment, test_size=VAL_SIZE, random_state=SEED
)

In [None]:
# Mixed bio + cyber validation pool, shuffled in place with the global `random` RNG.
val_instructions = [*bio_val_instructions, *cyber_val_instructions]
random.shuffle(val_instructions)

### Find directions

In [None]:
# One base-vs-RMU direction per domain (bio first, then cyber), computed on the last
# five token positions. These presumably line up with the five token labels used when plotting.
bio_diff, cyber_diff = (
    generate_directions_rmu(
        base_model,
        rmu_model,
        train_instructions,
        batch_size=8,
        positions=list(range(-5, 0)),
    )
    for train_instructions in (bio_train_instructions, cyber_train_instructions)
)

In [None]:
# Single candidate shared by both domains: the average of the two domain directions.
shared_diff = (cyber_diff + bio_diff) / 2

### Evaluate directions

In [None]:
# Score the shared direction on the held-out prompts. Scores are read at the layer
# right after LAYER_ID, and plots/metrics go to the "direction_general_same" run dir.
results = select_direction(
    # models and evaluation prompts
    base_model=base_model,
    rmu_model=rmu_model,
    instructions=val_instructions,
    candidate_directions=[shared_diff],
    # probe and its normalisation stats, moved to the model's device
    probe=probe.to(base_model.model.device),
    probe_mean=mean.to(base_model.model.device),
    probe_std=std.to(base_model.model.device),
    # output location and evaluation settings
    artifact_dir=f"{ARTIFACT_DIR}/direction_general_same",
    junk_layer=LAYER_ID + 1,
    batch_size=8,
)

In [None]:
# Presumably (metrics, best position, best layer, best directions); only metrics and
# best_dirs are used below.
(metrics, pos, layer, best_dirs) = results

In [None]:
# Save the winning direction in the same run directory that select_direction and the
# plots use ("direction_general_same"), so all artifacts of this run stay together.
torch.save(best_dirs[0], f"{ARTIFACT_DIR}/direction_general_same/best_dir.pt")

### Visualization

In [None]:
def plot(token_labels, artifact_dir, junk_layer, metrics):
    """Write the three per-token score plots for one direction-selection run.

    The three panels are: ablation on the RMU model, steering on the base model, and
    the KL divergence caused by ablation.

    Args:
        token_labels: x-axis labels, one per evaluated token position.
        artifact_dir: directory passed through to ``plot_refusal_scores_plotly``.
        junk_layer: the single layer the scores refer to.
        metrics: mapping with "ablation", "base_ablation", "steering", "base_steer"
            and "kl" entries. The "base_*" entries are reduced with
            ``.mean().item()`` to obtain each panel's baseline.
    """
    # (scores key, baseline key or None for a 0.0 baseline, title, y-label, artifact name)
    panels = (
        ("ablation", "base_ablation", "Ablating direction on RMU model", "Junk Score", "ablation_scores"),
        ("steering", "base_steer", "Adding direction on Base model", "Junk Score", "steering_scores"),
        ("kl", None, "KL Divergence when ablating direction on RMU model", "KL Divergence Score", "kl_div_score"),
    )
    for score_key, baseline_key, title, ylabel, artifact_name in panels:
        scores = metrics[score_key]
        if baseline_key is None:
            baseline = 0.0
        else:
            baseline = metrics[baseline_key].mean().item()
        plot_refusal_scores_plotly(
            refusal_scores=scores,
            baseline_refusal_score=baseline,
            artifact_dir=artifact_dir,
            token_labels=token_labels,
            layers=[junk_layer],
            title=title,
            ylabel=ylabel,
            artifact_name=artifact_name,
        )

In [None]:
# One label per evaluated position, from the 5th-last prompt token up to the last one.
plot(
    metrics=metrics,
    junk_layer=LAYER_ID + 1,
    artifact_dir=f"{ARTIFACT_DIR}/direction_general_same",
    token_labels=[
        "5th-last token",
        "4th-last token",
        "3rd-last token",
        "2nd-last token",
        "Last token",
    ],
)