In [1]:
import torch
import einops
from datasets import concatenate_datasets, load_dataset

from sae_refusal.model.gemma import GemmaModel
from sae_refusal.model.embedding import EmbeddingModel
from sae_refusal import set_seed, clear_memory
from sae_refusal.data import (
    load_wmdp,
    split,
    to_instructions,
    sample_data,
    load_harmbench
)
from sae_refusal.pipeline.select_directions import select_direction
from sae_refusal.pipeline.utils import compute_pca, generate_and_save_completions
from sae_refusal.probe import LinearProbe
from sae_refusal.pipeline.generate_directions import generate_directions_rmu
from sae_refusal.pipeline.hook import (
    get_all_direction_ablation_hooks,
    get_all_direction_ablation_hooks_rmu,
    get_activation_addition_input_pre_hook,
    get_activation_addition_input_pre_hook_rmu
)
from sae_refusal.plot import (
    plot_scores_plotly,
    plot_refusal_scores_plotly,
)
from sae_refusal.pipeline.activations import (
    get_activations,
    get_activations_pre
)
from IPython.display import display, Markdown

In [2]:
# Index of the layer under study; the cells below pass LAYER_ID + 1 as
# `junk_layer` to select_direction and to the plots.
LAYER_ID = 9
# Total number of prompts sampled from the *concatenation* of all source
# datasets (not per dataset): 300 -> 240 train / 60 validation with VAL_SIZE=0.2.
SAMPLE_SIZE = 300
VAL_SIZE = 0.2  # fraction of the sample held out for validation
SEED = 42
# NOTE(review): not referenced by the sampling cell, which passes max_len=None.
# Either wire it into sample_data or drop it.
MAX_LEN = 1024

# Output root, relative to the working directory (figures are written under
# results/ablated/...).
ARTIFACT_DIR = "results/ablated"

# Hugging Face model ids: the base model and (presumably) its RMU-unlearned variant.
MODEL_NAME = "google/gemma-2-2b"
RMU_NAME = "lenguyen1807/gemma-2-2b-RMU"

In [3]:
# Seed RNGs via the project helper (which generators it covers is defined in
# sae_refusal) before anything stochastic runs.
set_seed(SEED)
# Inference-only notebook: disable autograd globally for the kernel session.
torch.set_grad_enabled(False)
# Allow reduced-precision internal math (e.g. TF32) for float32 matmuls, for speed.
torch.set_float32_matmul_precision("high")

### Load model

In [4]:
# Wrap both checkpoints in the project's GemmaModel: the base model first, then
# the RMU variant. The generator is consumed in order, so the load order matches
# the two assignments it replaces.
# NOTE(review): the meaning of type="none" lives in sae_refusal.model.gemma — confirm.
base_model, rmu_model = (
    GemmaModel(model_id, type="none") for model_id in (MODEL_NAME, RMU_NAME)
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Load the best linear probe from the probe-training run.
# The path is built from ARTIFACT_DIR, which is relative to the working directory
# like the figure outputs, instead of a machine-specific absolute path.
# map_location="cpu" keeps loading device-independent; the probe is moved to the
# model's device when it is passed to select_direction.
probe_weight = torch.load(
    f"{ARTIFACT_DIR}/probes/best_probes/weight.pt",
    map_location="cpu",
    weights_only=True,
)
probe = LinearProbe(base_model.model.config.hidden_size)
probe.load_state_dict(probe_weight)  # strict by default: raises on missing/unexpected keys
probe.eval();  # inference only; trailing ';' suppresses the module repr

<All keys matched successfully>

In [6]:
# Normalisation statistics saved next to the probe. They are passed to
# select_direction as probe_mean / probe_std (presumably used to standardise
# activations before the probe — confirm in sae_refusal).
# The path is built from ARTIFACT_DIR rather than a machine-specific absolute
# path. The stats are loaded to CPU and moved to the model device later.
mean = torch.load(
    f"{ARTIFACT_DIR}/probes/best_probes/mean.pt",
    map_location="cpu",
    weights_only=True,
)
std = torch.load(
    f"{ARTIFACT_DIR}/probes/best_probes/std.pt",
    map_location="cpu",
    weights_only=True,
)

### Load dataset

#### 1. Load prompt datasets (AdvBench, MaliciousInstruct, HarmBench) and split into train / validation

In [7]:
# AdvBench: keep only the prompt column so it can be concatenated with the others.
adv_bench = load_dataset("walledai/AdvBench", split="train").remove_columns("target")

In [8]:
# MaliciousInstruct is used as-is.
minstruct = load_dataset(path="walledai/MaliciousInstruct", split="train")

In [9]:
# HarmBench ("standard" config): drop the category column to match the others' schema.
harmbench_ds = (
    load_dataset("walledai/HarmBench", "standard", split="train")
    .remove_columns("category")
)

In [10]:
# Pool the three prompt datasets and draw SAMPLE_SIZE prompts in total.
# Hold out VAL_SIZE of them for validation, then reduce each row to its prompt text.
# NOTE(review): max_len=None means the MAX_LEN config constant is not applied here.
prompt_pool = concatenate_datasets([adv_bench, minstruct, harmbench_ds])
prompt_sample = sample_data(
    data=prompt_pool,
    sample_size=SAMPLE_SIZE,
    seed=SEED,
    max_len=None,
    key="prompt",
)
prompt_splits = split(prompt_sample, test_size=VAL_SIZE)
train_instructions, val_instructions = to_instructions(
    data=prompt_splits,
    format_fn=lambda row: row["prompt"],
)

In [11]:
# Sanity check: (train size, validation size).
tuple(map(len, (train_instructions, val_instructions)))

(240, 60)

### Find candidate direction (base vs. RMU model)

In [12]:
# Candidate direction computed from the base and RMU models over the last five
# prompt positions (-5 … -1). Keep the plot's token_labels in sync with these.
diff = generate_directions_rmu(
    base_model,
    rmu_model,
    train_instructions,
    positions=[-5, -4, -3, -2, -1],
    batch_size=8,
)

Calculating mean: 100%|██████████| 30/30 [00:18<00:00,  1.60it/s]
Calculating mean: 100%|██████████| 30/30 [00:17<00:00,  1.73it/s]


### Evaluate directions

In [13]:
# Score the candidate direction on the held-out prompts. The probe and its
# normalisation stats are moved to the base model's device first
# (nn.Module.to moves `probe` in place).
model_device = base_model.model.device
results = select_direction(
    base_model=base_model,
    rmu_model=rmu_model,
    instructions=val_instructions,
    candidate_directions=[diff],
    probe=probe.to(model_device),
    probe_mean=mean.to(model_device),
    probe_std=std.to(model_device),
    artifact_dir=f"{ARTIFACT_DIR}/direction_general",
    junk_layer=LAYER_ID + 1,
    batch_size=8,
)

Calculating Junk Scores: 100%|██████████| 8/8 [00:04<00:00,  1.84it/s]
Calculating Junk Scores: 100%|██████████| 8/8 [00:04<00:00,  1.78it/s]
Computing KL between Ablation and Non-Ablation for Base Model at source position -5: 100%|██████████| 26/26 [01:54<00:00,  4.41s/it]
Computing KL between Ablation and Non-Ablation for Base Model at source position -4: 100%|██████████| 26/26 [01:55<00:00,  4.44s/it]
Computing KL between Ablation and Non-Ablation for Base Model at source position -3: 100%|██████████| 26/26 [01:54<00:00,  4.42s/it]
Computing KL between Ablation and Non-Ablation for Base Model at source position -2: 100%|██████████| 26/26 [01:59<00:00,  4.58s/it]
Computing KL between Ablation and Non-Ablation for Base Model at source position -1: 100%|██████████| 26/26 [01:55<00:00,  4.45s/it]
Computing Junk Ablation on RMU Model at source position -5: 100%|██████████| 26/26 [01:49<00:00,  4.21s/it]
Computing Junk Ablation on RMU Model at source position -4: 100%|██████████| 26/26 [0






In [14]:
# select_direction yields a 4-tuple; only `metrics` is used by the cells shown here.
(metrics, pos, layer, best_dirs) = results

### Visualization

In [16]:
def plot(token_labels, artifact_dir, junk_layer, metrics):
    """Save the three direction-evaluation figures built from `select_direction` metrics.

    Args:
        token_labels: x-axis labels, one per evaluated source position.
        artifact_dir: directory passed to the plotting helper, which writes
            figures such as ``<artifact_dir>/figures/<artifact_name>.pdf``.
        junk_layer: the single layer index passed to the helper as ``layers``.
        metrics: mapping returned by `select_direction`. This function reads the
            keys "ablation", "base_ablation", "steering", "base_steer" and "kl";
            the "base_*" entries must support ``.mean().item()``.
    """
    # (score key, baseline key -- None means a 0.0 baseline -- title, y-label, file stem)
    figure_specs = [
        ("ablation", "base_ablation", "Ablating direction on RMU model",
         "Junk Score", "ablation_scores"),
        ("steering", "base_steer", "Adding direction on Base model",
         "Junk Score", "steering_scores"),
        # select_direction computes this KL on the *base* model (its progress bar
        # reads "Computing KL between Ablation and Non-Ablation for Base Model"),
        # so the title must name the base model, not the RMU model.
        ("kl", None, "KL Divergence when ablating direction on Base model",
         "KL Divergence Score", "kl_div_score"),
    ]
    for score_key, baseline_key, title, ylabel, artifact_name in figure_specs:
        if baseline_key is None:
            baseline = 0.0
        else:
            baseline = metrics[baseline_key].mean().item()
        plot_refusal_scores_plotly(
            refusal_scores=metrics[score_key],
            baseline_refusal_score=baseline,
            artifact_dir=artifact_dir,
            token_labels=token_labels,
            layers=[junk_layer],
            title=title,
            ylabel=ylabel,
            artifact_name=artifact_name,
        )

In [19]:
plot(
    metrics=metrics,
    junk_layer=LAYER_ID + 1,
    artifact_dir=f"{ARTIFACT_DIR}/direction_general",
    # One label per position passed to generate_directions_rmu (-5 … -1).
    token_labels=[
        "5th-last token",
        "4th-last token",
        "3rd-last token",
        "2nd-last token",
        "Last token",
    ],
)

Plot saved to results/ablated/direction_general/figures/ablation_scores.pdf


Plot saved to results/ablated/direction_general/figures/steering_scores.pdf


Plot saved to results/ablated/direction_general/figures/kl_div_score.pdf
