In [1]:
import torch
import einops
from datasets import concatenate_datasets, load_from_disk

from sae_refusal.model.gemma import GemmaModel
from sae_refusal.model.embedding import EmbeddingModel
from sae_refusal import set_seed, clear_memory
from sae_refusal.data import (
    load_wmdp,
    split,
    to_instructions,
    sample_data,
    load_harmbench
)
from sae_refusal.pipeline.select_directions import select_direction
from sae_refusal.pipeline.utils import compute_pca, generate_and_save_completions, proj, evaluate_multiple_choice
from sae_refusal.probe import LinearProbe
from sae_refusal.pipeline.generate_directions import generate_directions_rmu
from sae_refusal.pipeline.hook import (
    get_all_direction_ablation_hooks,
    get_all_direction_ablation_hooks_rmu,
    get_activation_addition_input_pre_hook,
    get_activation_addition_input_pre_hook_rmu
)
from sae_refusal.plot import (
    plot_scores_plotly,
    plot_refusal_scores_plotly,
)
from sae_refusal.pipeline.activations import (
    get_activations,
    get_activations_pre
)
from IPython.display import display, Markdown

In [2]:
# --- Checkpoints under comparison -------------------------------------------
MODEL_NAME = "google/gemma-2-2b"
RMU_NAME = "lenguyen1807/gemma-2-2b-RMU"

# --- Output location passed to the plotting helpers as `artifact_dir` -------
ARTIFACT_DIR = "results/ablated"

# --- Experiment knobs --------------------------------------------------------
# LAYER_ID is the reference layer; the plots below highlight LAYER_ID +/- 1.
LAYER_ID = 9
SEED = 42
# The three values below are not referenced in the visible cells of this notebook.
SAMPLE_SIZE = 300  # For each dataset
VAL_SIZE = 0.2
MAX_LEN = 1024

In [3]:
# Inference-only notebook: switch autograd off globally before anything runs.
torch.set_grad_enabled(False)
# Seed via the project helper so repeated runs use the same SEED.
set_seed(SEED)
# Kept as the final statement so the cell displays no output value.
torch.set_float32_matmul_precision("high")

### Load model

In [4]:
# Load both checkpoints through the same wrapper with identical settings:
# the reference model first, then the RMU checkpoint.
base_model, rmu_model = (
    GemmaModel(checkpoint, type="none")
    for checkpoint in (MODEL_NAME, RMU_NAME)
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Build a probe sized to the base model's hidden dimension, then restore the
# saved weights into it.
# NOTE(review): the checkpoint path is absolute and machine-specific; consider
# deriving it from ARTIFACT_DIR once the working directory is confirmed.
probe = LinearProbe(base_model.model.config.hidden_size)
probe_weight = torch.load(
    "/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/weight.pt",
    weights_only=True,
)
# Last expression on purpose: the cell displays the key-matching report.
probe.load_state_dict(probe_weight)

<All keys matched successfully>

In [6]:
# Tensors stored next to the probe weights (file names: mean.pt / std.pt).
# presumably the normalisation statistics used when the probe was trained — TODO confirm
mean, std = (
    torch.load(stats_path, weights_only=True)
    for stats_path in (
        "/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/mean.pt",
        "/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/std.pt",
    )
)

### Load dataset

In [7]:
# Reload the four previously saved dataset splits from disk, in the order
# bio train/val, then cyber train/val.
bio_train, bio_val, cyber_train, cyber_val = (
    load_from_disk(split_path)
    for split_path in (
        "/home/ubuntu/thesis/sae/results/ablated/data/bio_train",
        "/home/ubuntu/thesis/sae/results/ablated/data/bio_val",
        "/home/ubuntu/thesis/sae/results/ablated/data/cyber_train",
        "/home/ubuntu/thesis/sae/results/ablated/data/cyber_val",
    )
)

In [8]:
# Extract the "question" field of every row as the instruction text.
# The two validation splits are concatenated (bio first, then cyber) and
# handled as a single validation set.
bio_train_instructions, cyber_train_instructions, val_instructions = to_instructions(
    [bio_train, cyber_train, concatenate_datasets([bio_val, cyber_val])],
    lambda row: row["question"],
)

In [9]:
# Sanity check: number of training instructions per domain (cyber, bio).
tuple(map(len, (cyber_train_instructions, bio_train_instructions)))

(240, 240)

### Find directions

In [10]:
# Per-domain directions between the base and RMU models, computed over the
# last five token positions (-5 .. -1) of each instruction.
# presumably a base-vs-RMU difference of mean activations, going by the helper
# name and the "Calculating mean" progress bars — TODO confirm

# Biology
bio_diff = generate_directions_rmu(
    base_model,
    rmu_model,
    bio_train_instructions,
    positions=[-5, -4, -3, -2, -1],
    batch_size=8,
)

# Cybersecurity
cyber_diff = generate_directions_rmu(
    base_model,
    rmu_model,
    cyber_train_instructions,
    positions=[-5, -4, -3, -2, -1],
    batch_size=8,
)

Calculating mean: 100%|██████████| 30/30 [00:19<00:00,  1.52it/s]
Calculating mean: 100%|██████████| 30/30 [00:19<00:00,  1.52it/s]
Calculating mean: 100%|██████████| 30/30 [00:27<00:00,  1.10it/s]
Calculating mean: 100%|██████████| 30/30 [00:32<00:00,  1.09s/it]


In [11]:
# Shared direction = element-wise mean of the bio and cyber directions.
# The parentheses are required: `bio_diff + cyber_diff / 2` halves only
# `cyber_diff`, which skews the result towards the bio direction.
shared_diff = (bio_diff + cyber_diff) / 2

### Get all activations

In [12]:
# Collect the RMU model's activations at the last token position (-1) for
# every training instruction of each domain.
# NOTE(review): the cyber call uses batch_size=4 while bio uses 8 —
# presumably to fit longer cyber prompts in memory; confirm.

# Biology
bio_rmu_activations = get_activations_pre(
    rmu_model.model,
    rmu_model.tokenizer,
    bio_train_instructions,
    block_modules=rmu_model.model_block_modules,
    tokenize_instructions_fn=rmu_model.tokenize_instructions_fn,
    batch_size=8,
    positions=[-1],
)

# Cybersecurity
cyber_rmu_activations = get_activations_pre(
    rmu_model.model,
    rmu_model.tokenizer,
    cyber_train_instructions,
    block_modules=rmu_model.model_block_modules,
    tokenize_instructions_fn=rmu_model.tokenize_instructions_fn,
    batch_size=4,
    positions=[-1],
)

100%|██████████| 30/30 [00:22<00:00,  1.33it/s]
100%|██████████| 60/60 [00:46<00:00,  1.28it/s]


In [13]:
# Same extraction as for the RMU model, now on the base model: activations at
# the last token position (-1) for each domain's training instructions.
# NOTE(review): the cyber call uses batch_size=4 while bio uses 8 —
# presumably to fit longer cyber prompts in memory; confirm.

# Biology
bio_base_activations = get_activations_pre(
    base_model.model,
    base_model.tokenizer,
    bio_train_instructions,
    block_modules=base_model.model_block_modules,
    tokenize_instructions_fn=base_model.tokenize_instructions_fn,
    batch_size=8,
    positions=[-1],
)

# Cybersecurity
cyber_base_activations = get_activations_pre(
    base_model.model,
    base_model.tokenizer,
    cyber_train_instructions,
    block_modules=base_model.model_block_modules,
    tokenize_instructions_fn=base_model.tokenize_instructions_fn,
    batch_size=4,
    positions=[-1],
)

100%|██████████| 30/30 [00:18<00:00,  1.58it/s]
100%|██████████| 60/60 [00:48<00:00,  1.24it/s]


### Analysis

#### Similarity and Norm of Activations

In [None]:
# Cosine similarity between base and RMU activations, reduced over the last
# (feature) dimension, computed separately for each domain.
bio_act_sim, cyber_act_sim = (
    torch.nn.functional.cosine_similarity(base_acts, rmu_acts, dim=-1)
    for base_acts, rmu_acts in (
        (bio_base_activations, bio_rmu_activations),
        (cyber_base_activations, cyber_rmu_activations),
    )
)

In [None]:
# Plot the base-vs-RMU cosine similarity for both domains in one figure.
# `.squeeze(-1)` drops the trailing axis when it has size 1.
# presumably that axis is the single captured position (positions=[-1]) — TODO confirm
plot_scores_plotly(
    title="Cosine Similarity between Base and RMU model activations at the last token position",
    ylabel="Cosine similarity",
    scores_group=dict(
        Biology=bio_act_sim.squeeze(-1),
        Cybersecurity=cyber_act_sim.squeeze(-1),
    ),
    layer_of_interest=LAYER_ID + 1,
    artifact_name="cosine_similarity_plot_base",
    artifact_dir=ARTIFACT_DIR,
)

In [None]:
# L2 norm over the last (feature) dimension of every activation tensor,
# for both models and both domains.
bio_rmu_norm, bio_base_norm, cyber_rmu_norm, cyber_base_norm = (
    torch.linalg.norm(activations, dim=-1)
    for activations in (
        bio_rmu_activations,
        bio_base_activations,
        cyber_rmu_activations,
        cyber_base_activations,
    )
)

In [None]:
# Activation norms of the RMU and base models on the biology instructions.
plot_scores_plotly(
    title="Norm of Base Model and RMU Model activations at the last token position (on Biology)",
    ylabel="Norm",
    scores_group={
        "RMU Model": bio_rmu_norm.squeeze(-1),
        "Base Model": bio_base_norm.squeeze(-1),
    },
    layer_of_interest=LAYER_ID + 1,
    artifact_name="bio_norm_plot_base",
    artifact_dir=ARTIFACT_DIR,
)

In [None]:
# Activation norms of the RMU and base models on the cybersecurity instructions.
plot_scores_plotly(
    title="Norm of Base Model and RMU Model activations at the last token position (on Cybersecurity)",
    ylabel="Norm",
    scores_group={
        "RMU Model": cyber_rmu_norm.squeeze(-1),
        "Base Model": cyber_base_norm.squeeze(-1),
    },
    layer_of_interest=LAYER_ID + 1,
    artifact_name="cyber_norm_plot_base",
    artifact_dir=ARTIFACT_DIR,
)

#### Individual directions

In [None]:
# How aligned are the two domain directions? Cosine similarity over the last
# (feature) dimension.
cyber_bio_sim = torch.nn.functional.cosine_similarity(
    cyber_diff,
    bio_diff,
    dim=-1,
)

In [None]:
# Bio-vs-cyber direction similarity, one series per token position.
# NOTE(review): this cell passes LAYER_ID - 1 as `layers`, while the
# direction-norm plots in this notebook pass LAYER_ID + 1 — confirm which
# index is intended.
plot_refusal_scores_plotly(
    title="Cosine Similarity between Bio and Cyber directions",
    ylabel="Cosine Similarity",
    refusal_scores=cyber_bio_sim,
    baseline_refusal_score=None,
    layers=[LAYER_ID - 1],
    token_labels=["5th-last token", "4th-last token", "3rd-last token", "2nd-last token", "Last token"],
    artifact_name="cyber_bio_sim_base",
    artifact_dir=ARTIFACT_DIR,
)

In [18]:
# L2 norm of the cyber direction along the last (feature) dimension.
cyber_diff_norm = torch.linalg.vector_norm(cyber_diff, dim=-1)

In [19]:
# L2 norm of the bio direction along the last (feature) dimension.
bio_diff_norm = torch.linalg.vector_norm(bio_diff, dim=-1)

In [None]:
# Norm of the cyber direction, one series per token position.
# NOTE(review): this cell passes LAYER_ID + 1 as `layers`, while the
# similarity / ratio plots in this notebook pass LAYER_ID - 1 — confirm which
# index is intended.
plot_refusal_scores_plotly(
    title="Norm of Cyber Directions",
    ylabel="Norm",
    refusal_scores=cyber_diff_norm,
    baseline_refusal_score=0.0,
    layers=[LAYER_ID + 1],
    token_labels=["5th-last token", "4th-last token", "3rd-last token", "2nd-last token", "Last token"],
    artifact_name="cyber_diff_norm_base",
    artifact_dir=ARTIFACT_DIR,
)

In [None]:
# Norm of the bio direction, one series per token position.
# NOTE(review): this cell passes LAYER_ID + 1 as `layers`, while the
# similarity / ratio plots in this notebook pass LAYER_ID - 1 — confirm which
# index is intended.
plot_refusal_scores_plotly(
    title="Norm of Bio Directions",
    ylabel="Norm",
    refusal_scores=bio_diff_norm,
    baseline_refusal_score=0.0,
    layers=[LAYER_ID + 1],
    token_labels=["5th-last token", "4th-last token", "3rd-last token", "2nd-last token", "Last token"],
    artifact_name="bio_diff_norm_base",
    artifact_dir=ARTIFACT_DIR,
)

In [14]:
def proj(a, b, eps=1e-6):
    """Project ``a`` onto ``b`` along the last dimension.

    Computes ``((a . b) / (||b||^2 + eps)) * b``, where the dot product and the
    squared norm are both taken over the last axis.

    NOTE(review): this definition shadows the ``proj`` imported from
    ``sae_refusal.pipeline.utils`` at the top of the notebook — confirm which
    of the two is meant to be used and drop the other.

    Args:
        a: Tensor to project. It is moved to ``b``'s device first. Its leading
            dimensions must broadcast against those of ``b``.
        b: Tensor to project onto; its last dimension must match ``a``'s.
        eps: Added to ``||b||^2`` so that a zero (or near-zero) ``b`` yields a
            zero projection rather than NaN/inf. Defaults to the value that
            was previously hard-coded.

    Returns:
        The projection of ``a`` onto ``b``, with the broadcast shape of the
        two inputs.
    """
    a = a.to(b.device)  # Ensure 'a' is on the same device as 'b'

    # a . b, kept as (..., 1) so it broadcasts against b's (..., d_model).
    dot_product_ab = (a * b).sum(dim=-1, keepdim=True)

    # ||b||^2 with the same trailing singleton axis, plus eps for stability.
    b_squared_norm = (b * b).sum(dim=-1, keepdim=True) + eps

    return (dot_product_ab / b_squared_norm) * b

In [15]:
# Component of the bio direction that lies along the cyber direction.
bio_diff_proj = proj(
    bio_diff,
    cyber_diff,
)

In [16]:
# Remainder of the bio direction after removing its projection onto the
# cyber direction.
bio_diff_ortho = (
    bio_diff - bio_diff_proj
)

In [20]:
# Ratio of the squared norm of the projected component to the squared norm of
# the full bio direction; 1e-8 in the denominator guards against division by zero.
shared_ratio_bio = torch.square(
    torch.linalg.norm(bio_diff_proj, dim=-1)
) / (
    torch.square(bio_diff_norm) + 1e-8
)

In [None]:
# Shared ratio of the bio direction with respect to the cyber direction,
# one series per token position.
# NOTE(review): this cell passes LAYER_ID - 1 as `layers`, while the
# direction-norm plots in this notebook pass LAYER_ID + 1 — confirm which
# index is intended.
plot_refusal_scores_plotly(
    title="Shared Ratio of Bio Direction based on Cyber Direction",
    ylabel="Shared Ratio",
    refusal_scores=shared_ratio_bio,
    baseline_refusal_score=0.0,
    layers=[LAYER_ID - 1],
    token_labels=["5th-last token", "4th-last token", "3rd-last token", "2nd-last token", "Last token"],
    artifact_name="shared_ratio_bio_diff_base",
    artifact_dir=ARTIFACT_DIR,
)

In [21]:
# Compare (orthogonal bio component + cyber direction) with the shared
# direction via cosine similarity over the last (feature) dimension.
ortho_shared_sim = torch.nn.functional.cosine_similarity(
    bio_diff_ortho + cyber_diff,
    shared_diff,
    dim=-1,
)

In [24]:
# Similarity between the shared direction and (ortho + cyber), one series per
# token position.
# NOTE(review): this cell passes LAYER_ID - 1 as `layers`, while the
# direction-norm plots in this notebook pass LAYER_ID + 1 — confirm which
# index is intended.
plot_refusal_scores_plotly(
    title="Cosine Similarity between Shared and (Ortho + Cyber) directions",
    ylabel="Cosine Similarity",
    refusal_scores=ortho_shared_sim,
    baseline_refusal_score=None,
    layers=[LAYER_ID - 1],
    token_labels=["5th-last token", "4th-last token", "3rd-last token", "2nd-last token", "Last token"],
    artifact_name="ortho_shared_sim_base",
    artifact_dir=ARTIFACT_DIR,
)

Plot saved to results/ablated/figures/ortho_shared_sim_base.pdf
