In [1]:
import torch
import einops
from datasets import concatenate_datasets, load_dataset, load_from_disk
import json

from sae_refusal.model.gemma import GemmaModel
from sae_refusal.model.embedding import EmbeddingModel
from sae_refusal import set_seed, clear_memory
from sae_refusal.data import (
    load_wmdp,
    split,
    to_instructions,
    sample_data,
    load_harmbench
)
from sae_refusal.pipeline.select_directions import select_direction
from sae_refusal.pipeline.utils import compute_pca, generate_and_save_completions, proj
from sae_refusal.probe import LinearProbe
from sae_refusal.pipeline.generate_directions import generate_directions_rmu
from sae_refusal.pipeline.hook import (
    get_all_direction_ablation_hooks,
    get_all_direction_ablation_hooks_rmu,
    get_activation_addition_input_pre_hook,
    get_activation_addition_input_pre_hook_rmu
)
from sae_refusal.plot import (
    plot_scores_plotly,
    plot_refusal_scores_plotly,
)
from sae_refusal.pipeline.activations import (
    get_activations,
    get_activations_pre
)
from IPython.display import display, Markdown

In [2]:
# Notebook-wide configuration constants.
LAYER_ID = 9  # not referenced by the cells shown in this section
SAMPLE_SIZE = 300  # For each dataset
VAL_SIZE = 0.2  # passed to split(...) as test_size
SEED = 42  # passed to set_seed(...) and sample_data(...)
MAX_LEN = 1024  # not referenced by the cells shown in this section

# NOTE(review): the cells below hard-code absolute paths rather than using this constant.
ARTIFACT_DIR = "results/ablated"

MODEL_NAME = "google/gemma-2-2b"  # checkpoint loaded as base_model
RMU_NAME = "lenguyen1807/gemma-2-2b-RMU"  # checkpoint loaded as rmu_model

In [3]:
# Global runtime setup applied to every cell below.
set_seed(SEED)  # project helper; presumably seeds the RNGs used here — confirm
torch.set_grad_enabled(False)  # turn autograd off globally
torch.set_float32_matmul_precision("high")  # permit faster, lower-precision float32 matmul kernels

### Load model

In [4]:
# Load both checkpoints through the project wrapper with identical arguments.
# NOTE(review): the meaning of type="none" is defined by GemmaModel — confirm there.
base_model = GemmaModel(MODEL_NAME, type="none")
rmu_model = GemmaModel(RMU_NAME, type="none")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### Load directions

In [5]:
# Saved "shared" direction and its placebo control (rand.pt is written by the
# "Random Ablation" section of this notebook).
shared_diff = torch.load("/home/ubuntu/thesis/sae/results/ablated/direction_shared/best_dir.pt", weights_only=True)
shared_rand = torch.load("/home/ubuntu/thesis/sae/results/ablated/direction_shared/rand.pt", weights_only=True)

In [6]:
# Saved cyber direction and its placebo control.
# NOTE(review): cyber_diff is rebound later in the "Cyber" section to the output
# of generate_directions_rmu, so the name refers to two different tensors.
cyber_diff = torch.load("/home/ubuntu/thesis/sae/results/ablated/direction_cyber/best_dir.pt", weights_only=True)
cyber_rand = torch.load("/home/ubuntu/thesis/sae/results/ablated/direction_cyber/rand.pt", weights_only=True)

In [7]:
# Saved bio direction and its placebo control.
# NOTE(review): bio_diff is rebound later in the "Bio" section to the output
# of generate_directions_rmu, so the name refers to two different tensors.
bio_diff = torch.load("/home/ubuntu/thesis/sae/results/ablated/direction_bio/best_dir.pt", weights_only=True)
bio_rand = torch.load("/home/ubuntu/thesis/sae/results/ablated/direction_bio/rand.pt", weights_only=True)

In [8]:
# Saved "general" and "general_same" directions.
# general_same_diff is the tensor passed to rmu_model.mod_model in the "Mod model" section.
general_diff = torch.load("/home/ubuntu/thesis/sae/results/ablated/direction_general/best_dir.pt", weights_only=True)
general_same_diff = torch.load("/home/ubuntu/thesis/sae/results/ablated/direction_general_same/best_dir.pt", weights_only=True)

### Filtered

In [None]:
# Bio and General have no directions that satisfy the constraints,
# but what if we relax the constraints further?
# Load the unfiltered per-direction evaluation records for bio.
with open("/home/ubuntu/thesis/sae/results/ablated/direction_bio/direction_evaluations.json", mode="r", encoding="utf-8") as f:
    bio_results = json.load(f)

In [None]:
# Load the unfiltered per-direction evaluation records for the "general" setting.
with open("/home/ubuntu/thesis/sae/results/ablated/direction_general/direction_evaluations.json", mode="r", encoding="utf-8") as f:
    general_results = json.load(f)

In [None]:
# Load the unfiltered per-direction evaluation records for the "general_same" setting.
with open("/home/ubuntu/thesis/sae/results/ablated/direction_general_same/direction_evaluations.json", mode="r", encoding="utf-8") as f:
    general_same_results = json.load(f)

In [None]:
# Cyber already has a filtered evaluation file on disk, so it is loaded directly
# instead of being re-filtered with filter_fn.
with open("/home/ubuntu/thesis/sae/results/ablated/direction_cyber/direction_evaluations_filtered.json", mode="r", encoding="utf-8") as f:
    cyber_filtered = json.load(f)

In [None]:
def filter_fn(
    results,
    n_layer=None,
    kl_score=0.2,
    induce_score=0.0,
    prune_layer_percentage=0.20,
):
    """Keep the direction evaluations that pass every threshold, best first.

    Args:
        results: Iterable of dicts carrying the keys "layer", "kl_div_score",
            "steering_score" and "ablation_score".
        n_layer: Total number of layers used to compute the pruning cutoff.
            When None (the default) it is read from
            ``rmu_model.model.config.num_hidden_layers`` at call time.
        kl_score: Entries whose "kl_div_score" exceeds this value are dropped.
        induce_score: Entries whose "steering_score" is below this value are
            dropped.
        prune_layer_percentage: Fraction of the last layers to exclude; entries
            with ``layer >= int(n_layer * (1 - prune_layer_percentage))`` are
            dropped.

    Returns:
        A new list with the surviving entries sorted by ascending
        "ablation_score".
    """
    if n_layer is None:
        # Resolve lazily: binding the default in the signature froze the value
        # at definition time and required the model to be loaded just to
        # define this function.
        n_layer = rmu_model.model.config.num_hidden_layers

    # Loop-invariant: first layer index that counts as "too late" in the model.
    first_pruned_layer = int(n_layer * (1.0 - prune_layer_percentage))

    filtered_results = []
    for result in results:
        if result["layer"] >= first_pruned_layer:
            continue
        if result["kl_div_score"] > kl_score:
            continue
        if result["steering_score"] < induce_score:
            continue
        filtered_results.append(result)
    return sorted(filtered_results, key=lambda x: x["ablation_score"])

In [None]:
# Filter the bio evaluations with filter_fn's default thresholds.
bio_filtered = filter_fn(bio_results)

In [None]:
# Show the surviving entry with the lowest ablation_score (filter_fn sorts ascending).
bio_filtered[0]

In [None]:
# Relax the KL threshold from the default 0.2 to 0.4 for the "general" setting.
general_filtered = filter_fn(general_results, kl_score=0.4)

In [None]:
# Show the surviving entry with the lowest ablation_score (filter_fn sorts ascending).
general_filtered[0]

In [None]:
# Relax the KL threshold from the default 0.2 to 0.3 for the "general_same" setting.
general_same_filtered = filter_fn(general_same_results, kl_score=0.3)

In [None]:
# Display every surviving "general_same" entry (the whole list, not just the first).
general_same_filtered

### Get missing directions

In [None]:
# Reload the previously saved bio / cyber train and validation datasets from disk.
bio_train = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/bio_train")
bio_val = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/bio_val")
cyber_train = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/cyber_train")
cyber_val = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/cyber_val")

In [None]:
# Convert each dataset with to_instructions, reading the "question" field of
# every row. Validation pools the bio and cyber validation sets into one.
combined_val = concatenate_datasets([bio_val, cyber_val])
bio_train_instructions, cyber_train_instructions, val_instructions = to_instructions(
    [bio_train, cyber_train, combined_val],
    lambda row: row["question"],
)

#### Bio

In [None]:
# Generate candidate directions from both models on the bio training instructions.
# The result is indexed as [position, layer, :] in the next cell.
# positions covers -5..-1 (presumably the last five token positions — confirm
# in generate_directions_rmu).
# NOTE(review): this rebinds bio_diff, which earlier held the tensor loaded from
# direction_bio/best_dir.pt.
bio_diff = generate_directions_rmu(
    base_model,
    rmu_model,
    bio_train_instructions,
    batch_size=8,
    positions=list(range(-5, 0))
)

In [None]:
# Get the best direction with relaxed constraints: take the first entry of
# bio_filtered (lowest ablation_score) and pick the matching slice.
# Raises IndexError if bio_filtered is empty.
bio_layer = bio_filtered[0]["layer"]
bio_pos = bio_filtered[0]["position"]
bio_best = bio_diff[bio_pos, bio_layer, :]

In [None]:
# Persist the filtered bio evaluations next to the unfiltered file.
with open("/home/ubuntu/thesis/sae/results/ablated/direction_bio/direction_evaluations_filtered.json", mode="w+", encoding="utf-8") as f:
    json.dump(bio_filtered, f, indent=4)

In [None]:
# Overwrites the best_dir.pt file that the "Load directions" section reads.
torch.save(bio_best, "/home/ubuntu/thesis/sae/results/ablated/direction_bio/best_dir.pt")

#### Cyber

In [None]:
# Generate candidate directions from both models on the cyber training instructions.
# positions covers -5..-1 (presumably the last five token positions — confirm
# in generate_directions_rmu).
# NOTE(review): this rebinds cyber_diff, which earlier held the tensor loaded from
# direction_cyber/best_dir.pt.
cyber_diff = generate_directions_rmu(
    base_model,
    rmu_model,
    cyber_train_instructions,
    batch_size=8,
    positions=list(range(-5, 0))
)

In [None]:
# Get the best direction with relaxed constraints: take the first entry of the
# filtered cyber evaluations and pick the matching [position, layer] slice.
cyber_layer = cyber_filtered[0]["layer"]
cyber_pos = cyber_filtered[0]["position"]
# Index with the cyber coordinates; using bio_pos / bio_layer here would select
# the slice chosen for bio instead of the one chosen for cyber.
cyber_best = cyber_diff[cyber_pos, cyber_layer, :]

In [None]:
# Overwrites the best_dir.pt file that the "Load directions" section reads.
torch.save(cyber_best, "/home/ubuntu/thesis/sae/results/ablated/direction_cyber/best_dir.pt")

#### General

In [None]:
# AdvBench train split; drop "target" (presumably so its columns match the
# other datasets before concatenation — confirm).
adv_bench = load_dataset("walledai/AdvBench", split="train")
adv_bench = adv_bench.remove_columns("target")

In [None]:
# MaliciousInstruct train split, used as-is.
minstruct = load_dataset("walledai/MaliciousInstruct", split="train")

In [None]:
# HarmBench "standard" config, train split; drop "category" (presumably so its
# columns match the other datasets before concatenation — confirm).
harmbench_ds = load_dataset("walledai/HarmBench", "standard", split="train")
harmbench_ds = harmbench_ds.remove_columns("category")

In [None]:
# Pool the three prompt datasets, sample from the pool, split the sample into
# train / validation, then convert each part using the "prompt" field of a row.
harmful_pool = concatenate_datasets([adv_bench, minstruct, harmbench_ds])
harmful_sample = sample_data(
    data=harmful_pool,
    sample_size=SAMPLE_SIZE,
    seed=SEED,
    max_len=None,
    key="prompt",
)
harmful_splits = split(harmful_sample, test_size=VAL_SIZE)
train_instructions, val_instructions = to_instructions(
    data=harmful_splits,
    format_fn=lambda row: row["prompt"],
)

In [None]:
# Generate candidate directions from both models on the general training instructions.
# positions covers -5..-1 (presumably the last five token positions — confirm
# in generate_directions_rmu).
# NOTE(review): this rebinds general_diff, which earlier held the tensor loaded
# from direction_general/best_dir.pt.
general_diff = generate_directions_rmu(
    base_model,
    rmu_model,
    train_instructions,
    batch_size=8,
    positions=list(range(-5, 0))
)

In [None]:
# Get the best direction with relaxed constraints: take the first entry of
# general_filtered (lowest ablation_score) and pick the matching slice.
# Raises IndexError if general_filtered is empty.
general_layer = general_filtered[0]["layer"]
general_pos = general_filtered[0]["position"]
general_best = general_diff[general_pos, general_layer, :]

In [None]:
# Persist the filtered "general" evaluations next to the unfiltered file.
with open("/home/ubuntu/thesis/sae/results/ablated/direction_general/direction_evaluations_filtered.json", mode="w+", encoding="utf-8") as f:
    json.dump(general_filtered, f, indent=4)

In [None]:
# Overwrites the best_dir.pt file that the "Load directions" section reads.
torch.save(general_best, "/home/ubuntu/thesis/sae/results/ablated/direction_general/best_dir.pt")

#### General same

In [None]:
# Read the augmented prompts, one per line. The encoding is pinned to UTF-8 to
# match the other file reads in this notebook, and blank lines are skipped so
# that an empty file or a stray empty line never becomes an empty instruction.
with open('/home/ubuntu/thesis/sae/results/ablated/data/bio_augment.txt', 'r', encoding="utf-8") as f:
    bio_augment = [line for line in f.read().strip().split('\n') if line.strip()]

with open('/home/ubuntu/thesis/sae/results/ablated/data/cyber_augment.txt', 'r', encoding="utf-8") as f:
    cyber_augment = [line for line in f.read().strip().split('\n') if line.strip()]

In [None]:
# Generate candidate directions separately for the augmented bio and cyber
# prompts, using the same batch size and positions (-5..-1) for both.
bio_same_diff = generate_directions_rmu(
    base_model,
    rmu_model,
    bio_augment,
    batch_size=8,
    positions=list(range(-5, 0))
)

cyber_same_diff = generate_directions_rmu(
    base_model,
    rmu_model,
    cyber_augment,
    batch_size=8,
    positions=list(range(-5, 0))
)

In [None]:
# Element-wise mean of the bio and cyber direction tensors.
shared_same_diff = (bio_same_diff + cyber_same_diff) / 2

In [None]:
# Take the first entry of general_same_filtered (lowest ablation_score) and
# pick the matching [position, layer] slice from the averaged tensor.
# Raises IndexError if general_same_filtered is empty.
general_same_layer = general_same_filtered[0]["layer"]
general_same_pos = general_same_filtered[0]["position"]
general_same_best = shared_same_diff[general_same_pos, general_same_layer, :]

In [None]:
# Overwrites the best_dir.pt file that the "Load directions" section reads.
torch.save(general_same_best, "/home/ubuntu/thesis/sae/results/ablated/direction_general_same/best_dir.pt")

#### Random Ablation

In [None]:
def create_placebo_control(treatment_dir):
    """Draw a random direction orthogonal to ``treatment_dir`` with the same norm.

    The tensor is treated as one flat vector: a Gaussian sample of the same
    shape, dtype and device is drawn, its component along ``treatment_dir`` is
    removed, and the remainder is rescaled to ``treatment_dir``'s L2 norm.

    Args:
        treatment_dir: Non-zero tensor holding the treatment direction.

    Returns:
        A tensor shaped like ``treatment_dir`` that is orthogonal to it and has
        the same L2 norm.

    Raises:
        ValueError: If ``treatment_dir`` has zero norm, or the random draw has
            no component orthogonal to it (e.g. a single-element tensor).
    """
    treatment_sq_norm = torch.sum(treatment_dir * treatment_dir)
    if treatment_sq_norm == 0:
        # Without this guard the projection divides by zero and returns NaNs.
        raise ValueError("treatment_dir must have a non-zero norm")

    rand_dir = torch.randn_like(treatment_dir)

    # Projection of rand_dir onto treatment_dir, written out explicitly so the
    # result does not depend on an external helper's argument order or on
    # whether that helper expects a unit-norm direction.
    parallel_component = (
        torch.sum(rand_dir * treatment_dir) / treatment_sq_norm
    ) * treatment_dir

    # What remains after removing the parallel part is orthogonal by construction.
    orthogonal_component = rand_dir - parallel_component

    orthogonal_norm = torch.linalg.norm(orthogonal_component)
    if orthogonal_norm == 0:
        raise ValueError(
            "random draw has no component orthogonal to treatment_dir"
        )

    # Matching the treatment norm keeps the placebo comparable in magnitude.
    control_dir = orthogonal_component * (
        torch.linalg.norm(treatment_dir) / orthogonal_norm
    )

    return control_dir

In [None]:
# Build one norm-matched placebo control per treatment direction.
# shared_diff is the tensor loaded from disk; the *_best tensors come from the
# "Get missing directions" cells. shared_rand, bio_rand and cyber_rand replace
# the values loaded in the "Load directions" section.
shared_rand = create_placebo_control(shared_diff)
bio_rand = create_placebo_control(bio_best)
cyber_rand = create_placebo_control(cyber_best)
general_rand = create_placebo_control(general_best)
general_same_rand = create_placebo_control(general_same_best)

In [None]:
# Persist every placebo control as rand.pt inside its direction folder.
torch.save(bio_rand, "/home/ubuntu/thesis/sae/results/ablated/direction_bio/rand.pt")
torch.save(cyber_rand, "/home/ubuntu/thesis/sae/results/ablated/direction_cyber/rand.pt")
torch.save(shared_rand, "/home/ubuntu/thesis/sae/results/ablated/direction_shared/rand.pt")
torch.save(general_rand, "/home/ubuntu/thesis/sae/results/ablated/direction_general/rand.pt")
torch.save(general_same_rand, "/home/ubuntu/thesis/sae/results/ablated/direction_general_same/rand.pt")

### Mod model

In [9]:
# Modify the RMU model in "ablation" mode using general_same_diff, i.e. the
# best_dir.pt tensor loaded in the "Load directions" section.
# NOTE(review): presumably this removes the direction from the model — confirm
# in GemmaModel.mod_model.
rmu_model.mod_model(general_same_diff, type="ablation")

In [10]:
# Save the modified model. NOTE(review): this path is relative to the working
# directory, unlike the absolute paths used elsewhere in this notebook.
rmu_model.save("results/models/general_same")

The OrderedVocab you are attempting to save contains holes for indices [0], your vocabulary could be corrupted !
