In [1]:
import torch
import einops
from datasets import concatenate_datasets, load_from_disk
import numpy as np

from sae_refusal.model.gemma import GemmaModel
from sae_refusal.model.embedding import EmbeddingModel
from sae_refusal import set_seed, clear_memory
from sae_refusal.data import (
    load_wmdp,
    split,
    to_instructions,
    sample_data,
    load_harmbench
)
from sae_refusal.pipeline.select_directions import select_direction, get_junk_scores
from sae_refusal.pipeline.utils import compute_pca, generate_and_save_completions, proj, evaluate_multiple_choice
from sae_refusal.probe import LinearProbe
from sae_refusal.pipeline.generate_directions import generate_directions_rmu
from sae_refusal.pipeline.hook import (
    get_all_direction_ablation_hooks,
    get_all_direction_ablation_hooks_rmu,
    get_activation_addition_input_pre_hook,
    get_activation_addition_input_pre_hook_rmu
)
from sae_refusal.plot import (
    plot_scores_plotly,
    plot_refusal_scores_plotly,
)
from sae_refusal.pipeline.activations import (
    get_activations,
    get_activations_pre
)
from IPython.display import display, Markdown

In [2]:
# Notebook-wide configuration.
# Layer index; the junk-score helper defined below is bound to layer LAYER_ID + 1.
LAYER_ID = 9
SAMPLE_SIZE = 300  # For each dataset
# NOTE(review): SAMPLE_SIZE, VAL_SIZE, MAX_LEN and ARTIFACT_DIR are not referenced
# in the cells visible here -- presumably used further down or left over; confirm.
VAL_SIZE = 0.2
SEED = 42  # passed to set_seed() in the next cell
MAX_LEN = 1024

ARTIFACT_DIR = "results/ablated"

# Checkpoint identifiers handed to GemmaModel.
MODEL_NAME = "google/gemma-2-2b"  # only loaded when TYPE == "base"
RMU_NAME = "lenguyen1807/gemma-2-2b-RMU"  # always loaded as rmu_model

In [3]:
# Seed through the project helper so runs are repeatable.
set_seed(SEED)
# Switch autograd off globally; no cell visible here performs training.
torch.set_grad_enabled(False)
# Let PyTorch use faster, reduced-precision kernels for float32 matmuls.
torch.set_float32_matmul_precision("high")

In [4]:
# Selects what the accuracy cell evaluates:
#   "rmu"          -> the RMU model with no hooks,
#   "base"         -> the base model (MODEL_NAME) with no hooks,
#   anything else  -> the RMU model with the ablation hooks stored under all_dirs[TYPE].
TYPE = "base"

### Load model

In [5]:
# Load the RMU checkpoint through the project wrapper.
# NOTE(review): the meaning of type="none" is defined inside GemmaModel and is not visible here.
rmu_model = GemmaModel(RMU_NAME, type="none")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
# Restore the saved linear probe. weights_only=True keeps torch.load from
# unpickling arbitrary objects.
probe_weight = torch.load("/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/weight.pt", weights_only=True)
# The probe is constructed with the RMU model's hidden size.
probe = LinearProbe(rmu_model.model.config.hidden_size)
# Left as the last expression so the key-matching report is displayed.
probe.load_state_dict(probe_weight)

<All keys matched successfully>

In [7]:
# Statistics stored next to the probe weights; later bound to get_junk_scores
# as probe_mean / probe_std.
# NOTE(review): presumably used to standardise activations before probing -- confirm in get_junk_scores.
mean = torch.load("/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/mean.pt", weights_only=True)
std = torch.load("/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/std.pt", weights_only=True)

### Load data

In [8]:
# Datasets consumed by evaluate_junk_scores, which reads their "question" column.
bio_test = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/t_bio")
cyber_test = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/t_cyber")

In [9]:
# Datasets consumed by evaluate_qa, which reads their "question", "choices"
# and "answer" columns for multiple-choice scoring.
bio_final = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/left_bio")
cyber_final = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/left_cyber")

In [10]:
# Sanity check: number of examples in each multiple-choice evaluation set.
len(bio_final), len(cyber_final)

(373, 1087)

### Load directions

In [11]:
# For every direction type, load the selected ("best") direction and a random
# control direction, and build the ablation hooks for each on the RMU model.
all_dirs = {}

# The loop variable is deliberately not called `dir`: at notebook top level that
# would rebind the builtin dir() to a string for the rest of the kernel session.
for dir_name in ["bio", "cyber", "general", "shared", "general_same"]:
    path = f"/home/ubuntu/thesis/sae/results/ablated/direction_{dir_name}"
    best_dir = torch.load(f"{path}/best_dir.pt", weights_only=True)
    rand_dir = torch.load(f"{path}/rand.pt", weights_only=True)

    # Hooks that ablate the selected direction.
    best_pre_hooks, best_hooks = get_all_direction_ablation_hooks_rmu(
        model_base=rmu_model,
        directions=[best_dir]
    )

    # Hooks that ablate the random control direction.
    rand_pre_hooks, rand_hooks = get_all_direction_ablation_hooks_rmu(
        model_base=rmu_model,
        directions=[rand_dir]
    )

    # Key names are read verbatim by evaluate_junk_scores / evaluate_qa; note
    # that "rand_pre_hook" is singular while the other hook keys are plural.
    all_dirs[dir_name] = {
        "best_dir": best_dir,
        "rand_dir": rand_dir,
        "pre_hooks": best_pre_hooks,
        "hooks": best_hooks,
        "rand_pre_hook": rand_pre_hooks,
        "rand_hooks": rand_hooks
    }

In [12]:
from scipy import stats

def t_test(results_treatment, results_control, alpha=0.05, alternative='two-sided'):
    """Run a paired t-test, print an interpretation, and return the statistics.

    Args:
        results_treatment: Per-item results under the treatment, paired
            element-wise with ``results_control``.
        results_control: Per-item results under the control.
        alpha: Significance threshold the p-value is compared against.
        alternative: Forwarded to ``scipy.stats.ttest_rel``; the printed
            interpretation distinguishes 'less', 'greater' and anything else
            (treated as two-sided).

    Returns:
        The ``(t_statistic, p_value)`` pair produced by
        ``scipy.stats.ttest_rel``, so callers can tabulate or post-process the
        result instead of reading it off stdout.
    """
    t_statistic, p_value = stats.ttest_rel(results_treatment, results_control, alternative=alternative)

    print(f"T-statistic: {t_statistic}")
    print(f"P-value: {p_value}")

    # Interpret the result
    if p_value < alpha:
        print("The result is statistically significant.")
        if alternative == 'less':
            print("We reject the null hypothesis: the treatment intervention is significantly less than the control.")
        elif alternative == 'greater':
            print("We reject the null hypothesis: the treatment intervention is significantly greater than the control.")
        else:
            print("We reject the null hypothesis: the interventions have a different effect.")
    else:
        # A NaN p-value also lands here, because NaN < alpha is False.
        print("The result is not statistically significant.")
        print("We fail to reject the null hypothesis: we cannot conclude the interventions have a different effect.")

    return t_statistic, p_value

In [13]:
from functools import partial

# Pre-bind everything that stays fixed across junk-score evaluations, so later
# calls only pass the instructions, the hooks and the batch size.
# The device is read once, at definition time: if rmu_model is moved afterwards
# (the accuracy cell does this when TYPE == "base"), the probe tensors bound
# here are not moved with it.
_rmu_device = rmu_model.model.device

junk_fn = partial(
    get_junk_scores,
    model=rmu_model,
    junk_layer=LAYER_ID + 1,
    probe=probe.to(_rmu_device),
    probe_mean=mean.to(_rmu_device),
    probe_std=std.to(_rmu_device),
)

In [14]:
def evaluate_junk_scores(all_dirs, type="shared"):
    """Compute junk scores on the bio and cyber test questions.

    Each dataset is scored twice through ``junk_fn``: once with the ablation
    hooks of the selected direction and once with those of the random control
    direction, both taken from ``all_dirs[type]``.

    Args:
        all_dirs: Mapping from direction type to its hook dictionary.
        type: Key of ``all_dirs`` whose hooks are applied.

    Returns:
        ``{"bio": (scores, rand_scores), "cyber": (scores, rand_scores)}`` where
        every entry has been moved to CPU and converted to a NumPy array.
    """
    entry = all_dirs[type]

    def _score(dataset, pre_key, hook_key):
        # One junk_fn call over a dataset's questions with the chosen hook pair.
        return junk_fn(
            instructions=dataset["question"],
            fwd_pre_hooks=entry[pre_key],
            fwd_hooks=entry[hook_key],
            batch_size=8,
        )

    def _as_numpy(scores):
        return scores.cpu().numpy()

    # Selected direction first, then the random control, bio before cyber.
    bio_best = _score(bio_test, "pre_hooks", "hooks")
    cyber_best = _score(cyber_test, "pre_hooks", "hooks")
    bio_control = _score(bio_test, "rand_pre_hook", "rand_hooks")
    cyber_control = _score(cyber_test, "rand_pre_hook", "rand_hooks")

    return {
        "bio": (_as_numpy(bio_best), _as_numpy(bio_control)),
        "cyber": (_as_numpy(cyber_best), _as_numpy(cyber_control)),
    }


In [15]:
def evaluate_qa(model, all_dirs, type="shared", base=False, rand=False):
    """Run multiple-choice evaluation on the cyber and bio question sets.

    Args:
        model: Model wrapper handed to ``evaluate_multiple_choice``.
        all_dirs: Mapping from direction type to its hook dictionary.
        type: Key of ``all_dirs`` to take hooks from; ignored when ``base``.
        base: If true, evaluate with no hooks at all (takes precedence over
            ``rand``).
        rand: If true (and ``base`` is false), use the random-direction hooks
            instead of the selected-direction hooks.

    Returns:
        ``{"bio": ..., "cyber": ...}`` holding whatever
        ``evaluate_multiple_choice`` returns for each dataset.
    """
    if base:
        fwd_pre_hooks, fwd_hooks = [], []
    else:
        entry = all_dirs[type]
        pre_key, hook_key = ("rand_pre_hook", "rand_hooks") if rand else ("pre_hooks", "hooks")
        fwd_pre_hooks, fwd_hooks = entry[pre_key], entry[hook_key]

    def _run(dataset, batch_size):
        # Shared call shape; only the dataset and batch size differ.
        return evaluate_multiple_choice(
            model=model,
            questions=dataset["question"],
            choices=dataset["choices"],
            ground_truth_indices=dataset["answer"],
            fwd_pre_hooks=fwd_pre_hooks,
            fwd_hooks=fwd_hooks,
            batch_size=batch_size,
        )

    # Cyber is evaluated first and with a much smaller batch than bio.
    cyber_results = _run(cyber_final, batch_size=2)
    bio_results = _run(bio_final, batch_size=32)

    return {"bio": bio_results, "cyber": cyber_results}

### Accuracy

In [16]:
# Multiple-choice accuracy for the configuration selected by TYPE.
if TYPE == "rmu":
    # RMU model with no hooks.
    results = evaluate_qa(model=rmu_model, all_dirs=all_dirs, base=True)
elif TYPE == "base":
    # The RMU model is moved to CPU before the base model is loaded
    # (presumably to free accelerator memory -- confirm).
    # NOTE(review): junk_fn bound probe/mean/std to the device rmu_model was on
    # when it was defined; after this move they may no longer share a device
    # with rmu_model -- verify before calling junk_fn in this mode.
    rmu_model.model.to("cpu")
    base_model = GemmaModel(MODEL_NAME, type="none")
    # Base model with no hooks.
    results = evaluate_qa(model=base_model, all_dirs=all_dirs, base=True)
else:
    # Any other value is used as a key of all_dirs: RMU model with that
    # direction's ablation hooks.
    results = evaluate_qa(model=rmu_model, all_dirs=all_dirs, type=TYPE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating MCQA Pairs:   0%|          | 0/2174 [00:00<?, ?it/s]

Evaluating MCQA Pairs:   0%|          | 0/47 [00:00<?, ?it/s]

In [17]:
# Report the first element of each dataset's result, labelled as accuracy.
_sections = (("--- Bio ---", "bio"), ("--- Cyber ---", "cyber"))

print(f"Results for direction type: {TYPE}")
for _header, _key in _sections:
    print(_header)
    print(f"Accuracy: {results[_key][0]}")

Results for direction type: base
--- Bio ---
Accuracy: 0.3324396782841823
--- Cyber ---
Accuracy: 0.43698252069917204
