In [1]:
import os

import einops
import numpy as np
import plotly.graph_objects as go
import torch
from datasets import concatenate_datasets, load_from_disk
from IPython.display import display, Markdown
from sklearn.utils import resample
from tqdm.auto import tqdm

from sae_refusal import set_seed, clear_memory
from sae_refusal.data import (
    load_wmdp,
    split,
    to_instructions,
    sample_data,
    load_harmbench
)
from sae_refusal.model.embedding import EmbeddingModel
from sae_refusal.model.gemma import GemmaModel
from sae_refusal.pipeline.activations import (
    get_activations,
    get_activations_pre
)
from sae_refusal.pipeline.generate_directions import generate_directions_rmu
from sae_refusal.pipeline.hook import (
    get_all_direction_ablation_hooks,
    get_all_direction_ablation_hooks_rmu,
    get_activation_addition_input_pre_hook,
    get_activation_addition_input_pre_hook_rmu
)
from sae_refusal.pipeline.select_directions import select_direction
from sae_refusal.pipeline.utils import compute_pca, generate_and_save_completions
from sae_refusal.plot import (
    plot_scores_plotly,
    plot_refusal_scores_plotly,
)
from sae_refusal.probe import LinearProbe

In [2]:
# Layer highlighted in the per-layer score plot (passed as `layer_of_interest`).
LAYER_ID = 9
SAMPLE_SIZE = 400  # For each dataset
VAL_SIZE = 0.2
SEED = 42  # handed to set_seed() in the next cell
MAX_LEN = 1024
# Number of bootstrap resamples drawn per domain (bio / cyber).
BOOTSTRAP_K = 1000

# Base directory under which this notebook writes its figures.
ARTIFACT_DIR = "results/ablated"

# Checkpoint identifiers handed to GemmaModel: the base model and its RMU counterpart.
MODEL_NAME = "google/gemma-2-2b"
RMU_NAME = "lenguyen1807/gemma-2-2b-RMU"

In [3]:
# Seed first so everything below runs from the configured seed.
set_seed(SEED)
# Analysis-only notebook: switch autograd off globally.
torch.set_grad_enabled(False)
# "high" lets float32 matmuls use faster reduced-precision kernels where available.
torch.set_float32_matmul_precision("high")

### Load model

In [4]:
# Load the base checkpoint and the RMU checkpoint through the same wrapper so
# their activations can be compared later.
# NOTE(review): the meaning of type="none" is defined by GemmaModel — confirm there.
base_model = GemmaModel(MODEL_NAME, type="none")
rmu_model = GemmaModel(RMU_NAME, type="none")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Restore the saved probe weights; weights_only=True restricts unpickling to
# tensors and plain containers.
probe_weight = torch.load("/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/weight.pt", weights_only=True)
# The probe's input width is the base model's hidden size.
probe = LinearProbe(base_model.model.config.hidden_size)
# Left as the last expression so the key-matching report is displayed.
probe.load_state_dict(probe_weight)

<All keys matched successfully>

In [6]:
# Tensors stored next to the probe weights.
# NOTE(review): presumably the normalisation statistics the probe was trained with — confirm.
mean = torch.load("/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/mean.pt", weights_only=True)
std = torch.load("/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/std.pt", weights_only=True)

### Load dataset

In [7]:
# Reload the train/validation splits previously saved to disk for each domain.
bio_train = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/bio_train")
bio_val = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/bio_val")
cyber_train = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/cyber_train")
cyber_val = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/cyber_val")

In [8]:
# Convert the three datasets to instruction lists in one call; the results are
# unpacked in the same order as the datasets are passed in.
(
    bio_train_instructions,
    cyber_train_instructions,
    val_instructions,  # we combine both for validation
) = to_instructions(
    [
        bio_train,
        cyber_train,
        concatenate_datasets([bio_val, cyber_val])
    ],
    # Each record contributes its "question" field.
    lambda x: x["question"],
)

### Create bootstrapping dataset

Get activations first

In [9]:
def _rmu_activations_for(instructions, batch_size):
    """Run the RMU model over ``instructions`` and return its activations at positions -5..-1."""
    return get_activations_pre(
        rmu_model.model,
        rmu_model.tokenizer,
        instructions,
        tokenize_instructions_fn=rmu_model.tokenize_instructions_fn,
        block_modules=rmu_model.model_block_modules,
        positions=list(range(-5, 0)),
        batch_size=batch_size,
    )


# The two domains only differ in their instruction set and batch size.
bio_rmu_activations = _rmu_activations_for(bio_train_instructions, batch_size=8)

cyber_rmu_activations = _rmu_activations_for(cyber_train_instructions, batch_size=4)

100%|██████████| 30/30 [00:18<00:00,  1.60it/s]
100%|██████████| 60/60 [00:39<00:00,  1.50it/s]


In [10]:
def _base_activations_for(instructions, batch_size):
    """Run the base model over ``instructions`` and return its activations at positions -5..-1."""
    return get_activations_pre(
        base_model.model,
        base_model.tokenizer,
        instructions,
        tokenize_instructions_fn=base_model.tokenize_instructions_fn,
        block_modules=base_model.model_block_modules,
        positions=list(range(-5, 0)),
        batch_size=batch_size,
    )


# Same instruction sets and batch sizes as the RMU pass, so rows pair up one-to-one.
bio_base_activations = _base_activations_for(bio_train_instructions, batch_size=8)

cyber_base_activations = _base_activations_for(cyber_train_instructions, batch_size=4)

100%|██████████| 30/30 [00:16<00:00,  1.81it/s]
100%|██████████| 60/60 [00:40<00:00,  1.47it/s]


In [11]:
# The bootstrap step subtracts base activations from RMU activations
# element-wise, so each pair must have exactly the same shape. Fail fast here
# rather than letting a mismatch broadcast silently later on.
assert bio_base_activations.shape == bio_rmu_activations.shape, (
    "bio base/RMU activation shapes differ"
)
assert cyber_base_activations.shape == cyber_rmu_activations.shape, (
    "cyber base/RMU activation shapes differ"
)

print("Base activations shape: ", bio_base_activations.shape)
print("RMU activations shape: ", bio_rmu_activations.shape)

Base activations shape:  torch.Size([240, 26, 5, 2304])
RMU activations shape:  torch.Size([240, 26, 5, 2304])


Then bootstrap over that activation dataset instead (resampling the cached activations rather than re-running the models).

In [12]:
def calculate_bootstrap_directions(
    rmu_activations,
    base_activations,
    num_resamples,
    random_seed_offset=0,
    stack=True
):
    """Bootstrap the mean RMU-minus-base activation difference.

    The element-wise difference ``rmu_activations - base_activations`` is
    resampled with replacement along its first axis; each resample has as many
    rows as the original and is reduced to its mean over that axis.

    Args:
        rmu_activations: Tensor of RMU-model activations; axis 0 is resampled.
        base_activations: Tensor of base-model activations, subtracted
            element-wise from ``rmu_activations``.
        num_resamples: Number of bootstrap resamples to draw.
        random_seed_offset: Added to the resample index to form the
            ``random_state`` of each draw, so different calls can use
            disjoint seed ranges.
        stack: If true, return one tensor with the resamples stacked on a new
            leading axis; otherwise return a list of per-resample tensors.

    Returns:
        The bootstrap mean differences, stacked or as a list (see ``stack``).
    """
    delta = rmu_activations - base_activations
    row_ids = list(range(delta.shape[0]))

    def _one_direction(seed):
        # Draw len(row_ids) rows with replacement, reproducibly for this seed.
        picked = resample(
            row_ids,
            replace=True,
            n_samples=len(row_ids),
            random_state=seed,
        )
        return delta[picked].mean(dim=0)

    directions = [
        _one_direction(step + random_seed_offset)
        for step in tqdm(range(num_resamples), desc="Bootstrapping Directions")
    ]

    return torch.stack(directions) if stack else directions

In [13]:
# Keep the resamples as plain lists (stack=False); the PCA cells below iterate
# over them one direction at a time.
bio_bootstrap_directions = calculate_bootstrap_directions(
    rmu_activations=bio_rmu_activations,
    base_activations=bio_base_activations,
    num_resamples=BOOTSTRAP_K,
    random_seed_offset=0,
    stack=False,
)
# Offsetting by BOOTSTRAP_K gives cyber a seed range disjoint from bio's.
cyber_bootstrap_directions = calculate_bootstrap_directions(
    rmu_activations=cyber_rmu_activations,
    base_activations=cyber_base_activations,
    num_resamples=BOOTSTRAP_K,
    random_seed_offset=BOOTSTRAP_K,
    stack=False,
)

Bootstrapping Directions:   0%|          | 0/1000 [00:00<?, ?it/s]

Bootstrapping Directions:   0%|          | 0/1000 [00:00<?, ?it/s]

In [14]:
# The bootstrap results are lists, so inspect the shape of the first resample only.
print("Bio bootstrap directions shape:", bio_bootstrap_directions[0].shape)
print("Cyber bootstrap directions shape:", cyber_bootstrap_directions[0].shape)

Bio bootstrap directions shape: torch.Size([26, 5, 2304])
Cyber bootstrap directions shape: torch.Size([26, 5, 2304])


### Bootstrapping Analysis

In [15]:
# One PCA per layer, fitted on that layer's slice of every bio bootstrap
# direction; 12 components are requested from compute_pca.
bio_pcas = []
for layer in range(rmu_model.model.config.num_hidden_layers):
    bio_pcas.append(
        compute_pca(
            tensors=[direction[layer] for direction in bio_bootstrap_directions],
            k=12,
        )
    )

In [16]:
# One PCA per layer, fitted on that layer's slice of every cyber bootstrap
# direction; 12 components are requested from compute_pca.
cyber_pcas = []
for layer in range(rmu_model.model.config.num_hidden_layers):
    cyber_pcas.append(
        compute_pca(
            tensors=[direction[layer] for direction in cyber_bootstrap_directions],
            k=12,
        )
    )

In [17]:
# Previously saved "best" directions for each setting; they are compared
# against the bootstrap PCA components below.
cyber_best = torch.load("/home/ubuntu/thesis/sae/results/ablated/direction_cyber/best_dir.pt", weights_only=True)
shared_best = torch.load("/home/ubuntu/thesis/sae/results/ablated/direction_shared/best_dir.pt", weights_only=True)
bio_best = torch.load("/home/ubuntu/thesis/sae/results/ablated/direction_bio/best_dir.pt", weights_only=True)

### Visualize Analysis

In [18]:
def calculate_cosine_sim(pcas, dir):
    """Cosine similarity between each PCA's first component and a direction.

    Args:
        pcas: Sequence of mappings, each exposing a ``"components"`` entry
            whose element 0 is compared against ``dir``.
        dir: Direction tensor compared against every first component
            (similarity is taken along axis 0).

    Returns:
        Tensor with one similarity per entry of ``pcas``, stacked on the last axis.
    """
    # NOTE(review): the sign of a principal component is usually arbitrary, so
    # the sign of these similarities may be too — confirm how compute_pca
    # orients its components.
    similarity = torch.nn.functional.cosine_similarity
    per_pca = [similarity(pca["components"][0], dir, dim=0) for pca in pcas]
    return torch.stack(per_pca, dim=-1)

In [19]:
# Collect each layer's "total_variance" entry into one tensor per domain
# (layers along the last axis).
cyber_total_var = torch.stack([pca["total_variance"] for pca in cyber_pcas], dim=-1)
bio_total_var = torch.stack([pca["total_variance"] for pca in bio_pcas], dim=-1)

In [20]:
# Per-layer similarity of each domain's first principal component to its own
# best direction and to the shared best direction.
# unsqueeze(dim=0) adds a leading axis of size 1.
# NOTE(review): presumably the layout plot_scores_plotly expects — confirm.
cyber_cosine = calculate_cosine_sim(cyber_pcas, cyber_best).unsqueeze(dim=0)
cyber_shared_cosine = calculate_cosine_sim(cyber_pcas, shared_best).unsqueeze(dim=0)
bio_cosine = calculate_cosine_sim(bio_pcas, bio_best).unsqueeze(dim=0)
bio_shared_cosine = calculate_cosine_sim(bio_pcas, shared_best).unsqueeze(dim=0)

In [23]:
# Per-layer cosine similarity of each domain's first principal component to
# the domain-specific and shared best directions, with LAYER_ID highlighted.
plot_scores_plotly(
    scores_group={
        "Cybersecurity_Cybersecurity": cyber_cosine,
        "Biology_Biology": bio_cosine,
        "Cybersecurity_Shared": cyber_shared_cosine,
        "Biology_Shared": bio_shared_cosine
    },
    layer_of_interest=LAYER_ID,
    # Title typo fixed ("correspoding" -> "corresponding"); it is rendered into the saved figure.
    title="Cosine Similarity between PCA of each group and corresponding best directions",
    ylabel="Cosine similarity",
    artifact_dir=f"{ARTIFACT_DIR}/bootstrap",
    artifact_name="cosine_best"
)

Plot saved to results/ablated/bootstrap/figures/cosine_best.pdf


In [24]:
def scree(
    results,
    artifact_dir,
    artifact_name,
    figsize=(8, 8),
    title="Scree Plot of Bio Bootstrap PCA",
):
    """Draw one explained-variance-ratio curve per PCA result and optionally save it.

    Args:
        results: Sequence of mappings, each with an ``"explained_ratio"``
            tensor. Entry ``idx`` is labelled ``"Layer {idx}"``.
        artifact_dir: Directory the PDF is written to; created if missing.
            A falsy value skips saving.
        artifact_name: File name (without extension) of the saved PDF.
        figsize: ``(width, height)``; each is multiplied by 100 to get pixels.
        title: Figure title.

    Returns:
        None. The figure is shown inline and, if requested, written to
        ``<artifact_dir>/<artifact_name>.pdf``.
    """
    fig = go.Figure()

    for idx, result in enumerate(results):
        explained_ratio = result["explained_ratio"].cpu().numpy()
        # Components are numbered from 1 on the x-axis.
        components = list(range(1, len(explained_ratio) + 1))

        fig.add_trace(
            go.Scatter(
                x=components,
                y=explained_ratio,
                text=[f"{v:.2f}" for v in explained_ratio],
                cliponaxis=False,
                mode="lines+markers",
                name=f"Layer {idx}"
            )
        )

    # Bar-only layout options (barmode/bargap) were dropped: this figure has
    # only scatter traces, so they had no effect.
    fig.update_layout(
        title=title,
        xaxis_title="Principal Component",
        yaxis_title="Explained Variance Ratio",
        template="plotly_white",
        legend=dict(
            yanchor="top", y=0.98, xanchor="right", x=0.98,
            bgcolor='white', bordercolor="LightGray", borderwidth=1,
        ),
        xaxis=dict(
            showgrid=True,
        ),
        yaxis=dict(
            showgrid=True,
        ),
        plot_bgcolor='white',
        width=figsize[0] * 100,
        height=figsize[1] * 100,
    )

    fig.update_xaxes(showline=True, linewidth=1.5, linecolor='black')
    fig.update_yaxes(showline=True, linewidth=1.5, linecolor='black')

    fig.show()

    if artifact_dir:
        # write_image does not create missing directories, so a fresh checkout
        # would otherwise fail here.
        os.makedirs(artifact_dir, exist_ok=True)
        fig.write_image(os.path.join(artifact_dir, f"{artifact_name}.pdf"), format="pdf")

In [26]:
# Scree curves for every layer's bio bootstrap PCA.
# NOTE(review): this writes to .../bootstrap while the cyber scree plots go to
# .../bootstrap/figures, and the file name spells "scee" — confirm both are intended.
scree(
    bio_pcas,
    title="Scree Plot of Bio Bootstrap PCA",
    artifact_dir=f"{ARTIFACT_DIR}/bootstrap",
    artifact_name="scee_bio_full",
)

In [28]:
# Same plot restricted to the first 15 entries of bio_pcas (layers 0-14) for readability.
# NOTE(review): the title is identical to the full-layer plot, and this writes to
# .../bootstrap rather than .../bootstrap/figures — confirm intended.
scree(
    bio_pcas[:15],
    title="Scree Plot of Bio Bootstrap PCA",
    artifact_dir=f"{ARTIFACT_DIR}/bootstrap",
    artifact_name="scee_bio_15",
)

In [30]:
# Scree curves for the first 15 entries of cyber_pcas (layers 0-14).
# NOTE(review): the file name spells "scee" — confirm before renaming, other
# documents may reference it.
scree(
    cyber_pcas[:15],
    title="Scree Plot of Cyber Bootstrap PCA",
    artifact_dir=f"{ARTIFACT_DIR}/bootstrap/figures",
    artifact_name="scee_cyber_15",
)

In [32]:
# Scree curves for every layer's cyber bootstrap PCA.
# NOTE(review): the file name spells "scee" — confirm before renaming, other
# documents may reference it.
scree(
    cyber_pcas,
    title="Scree Plot of Cyber Bootstrap PCA",
    artifact_dir=f"{ARTIFACT_DIR}/bootstrap/figures",
    artifact_name="scee_cyber_full",
)

In [33]:
def pairwise_cosine(directions):
    """Cosine similarity between every unordered pair of directions.

    Similarities are computed independently for each (layer, position) slot.
    The layer count is taken from the input itself, not from a global model.

    Args:
        directions: Either a tensor laid out as
            ``(batch, layer, seq_len, d_model)`` or a sequence of ``batch``
            equally-shaped ``(layer, seq_len, d_model)`` tensors.

    Returns:
        Tensor of shape ``(layer, seq_len, batch * (batch - 1) / 2)`` holding
        the strictly-upper-triangular entries (pairs ``i < j``) of each
        batch-by-batch cosine matrix, in row-major order.
    """
    # Accept the list form produced with stack=False as well as a stacked tensor.
    if not torch.is_tensor(directions):
        directions = torch.stack(list(directions))

    # (batch, layer, seq_len, d_model) -> (layer, seq_len, batch, d_model)
    per_slot = directions.permute(1, 2, 0, 3)
    # Unit-normalise along d_model so the dot products below are cosines.
    unit = torch.nn.functional.normalize(per_slot, p=2, dim=-1, eps=1e-8)
    # Batched Gram matrix: (layer, seq_len, batch, batch).
    cosine = unit @ unit.transpose(-1, -2)

    # Keep each unordered pair once and drop the diagonal (self-similarity).
    n_directions = cosine.shape[-1]
    row, col = torch.triu_indices(
        n_directions, n_directions, offset=1, device=cosine.device
    )

    return cosine[:, :, row, col]

In [34]:
# NOTE: this rebinds `bio_cosine`, which an earlier cell used for the
# best-direction similarities; re-running that earlier plot after this cell
# would pick up the pairwise tensor instead.
# bio_bootstrap_directions is a list of per-resample tensors here (stack=False).
bio_cosine = pairwise_cosine(bio_bootstrap_directions)

In [35]:
# NOTE: this rebinds `cyber_cosine`, which an earlier cell used for the
# best-direction similarities; re-running that earlier plot after this cell
# would pick up the pairwise tensor instead.
# cyber_bootstrap_directions is a list of per-resample tensors here (stack=False).
cyber_cosine = pairwise_cosine(cyber_bootstrap_directions)

In [36]:
def hist(
    bio_cosine,
    cyber_cosine,
    layer,
    artifact_dir,
):
    """Overlay the bio and cyber pairwise-cosine histograms for one layer.

    Args:
        bio_cosine: Pairwise cosine tensor indexed by layer on its first axis;
            the next axis (token positions, per the plot title) is averaged
            out before the histogram is drawn.
        cyber_cosine: Same layout as ``bio_cosine``, for the cyber directions.
        layer: Layer whose similarities are plotted; also used in the title
            and the output file name.
        artifact_dir: Directory the PDF is written to; created if missing.
            A falsy value skips saving.

    Returns:
        None. The figure is not displayed, only written to
        ``<artifact_dir>/layer_<layer>.pdf``.
    """
    fig = go.Figure()

    for trace_name, cosine in (
        ("Bio Directions", bio_cosine),
        ("Cyber Directions", cyber_cosine),
    ):
        fig.add_trace(
            go.Histogram(
                # Select the requested layer. This used to index with the
                # global LAYER_ID, so every per-layer file showed the same data.
                x=cosine[layer].mean(dim=0).cpu().numpy(),
                nbinsx=50,
                name=trace_name,
                opacity=0.6,
                marker=dict(line_width=1, line_color="white")
            )
        )

    fig.update_layout(
        title=f"Pairwise Cosine Similarity Distribution at layer {layer} - Averaging over token positions",
        xaxis_title="Cosine Similarity",
        yaxis_title="Count",
        barmode="overlay",
        bargap=0.1,
        width=800,
        height=600
    )

    if artifact_dir:
        # write_image does not create missing directories.
        os.makedirs(artifact_dir, exist_ok=True)
        fig.write_image(os.path.join(artifact_dir, f"layer_{layer}.pdf"), format="pdf")

In [37]:
# Write one pairwise-cosine histogram PDF per layer. The loop used to appear
# twice verbatim, regenerating and overwriting the same files.
for layer in tqdm(range(rmu_model.model.config.num_hidden_layers)):
    hist(bio_cosine, cyber_cosine, layer, artifact_dir=f"{ARTIFACT_DIR}/bootstrap/")

  0%|          | 0/26 [00:00<?, ?it/s]

  0%|          | 0/26 [00:00<?, ?it/s]