In [1]:
from pathlib import Path

import einops
import plotly.graph_objects as go
import torch
from datasets import load_from_disk

from sae_refusal.model.gemma import GemmaModel
from sae_refusal.model.embedding import EmbeddingModel
from sae_refusal import set_seed, clear_memory
from sae_refusal.data import (
    load_wmdp,
    split,
    to_instructions,
    sample_data,
    format_qa
)
from sae_refusal.probe import LinearProbe
from sae_refusal.pipeline.generate_directions import generate_directions_rmu
from sae_refusal.plot import (
    plot_scores_plotly,
    plot_refusal_scores_plotly,
)
from sae_refusal.pipeline.activations import (
    get_activations,
    get_activations_pre,
    get_sae_encoded_features,
    get_refusal_gradients,
    get_sae_reconstructed_pre
)
from sae_refusal.sae import (
    load_sae_gemma_2b,
    layer_sparisity_widths
)
from sae_refusal.pipeline.utils import (
    generate_and_save_completions
)
from sae_refusal.plot import (
    plot_scores_plotly
)

In [2]:
# Run configuration shared by the cells below.

# Layer index passed as `refusal_layer_idx` to get_refusal_gradients.
LAYER_ID = 9
# NOTE(review): SAMPLE_SIZE, VAL_SIZE and MAX_LEN are not referenced by any
# cell visible in this notebook section.
SAMPLE_SIZE = 300  # For each dataset
VAL_SIZE = 0.2
# Seed handed to set_seed().
SEED = 42
MAX_LEN = 1024

# Prefix for the PDF figures written by the plotting cells. It already ends
# with "/", and those cells join with another "/", giving a double slash.
ARTIFACT_DIR = "results/ablated/"

# Model identifiers passed to GemmaModel (base, RMU, and RMU-ablated variants).
MODEL_NAME = "google/gemma-2-2b"
RMU_NAME = "lenguyen1807/gemma-2-2b-RMU"
ABLATED_NAME = "lenguyen1807/gemma-2-2b-RMU-ablated"

In [3]:
# Seed through the project's helper (which RNGs it covers is not visible here).
set_seed(SEED)
# "high" trades some float32 matmul precision for speed on supporting hardware.
torch.set_float32_matmul_precision("high")

In [4]:
# Dataset domain for this run. "bio" selects bio_raw as the gradient prompts;
# any other value falls through to cyber_raw. The value is also embedded in
# the figure filenames and plot titles.
TYPE = "cyber"

### Load model

In [5]:
# Load the three checkpoints compared in this notebook: base, RMU and
# RMU-ablated.
# NOTE(review): type="none" presumably loads the model without any
# modification applied — confirm against GemmaModel.
base_model = GemmaModel(MODEL_NAME, type="none")
rmu_model = GemmaModel(RMU_NAME, type="none")
ablated_model = GemmaModel(ABLATED_NAME, type="none")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Restore a saved linear probe sized to the base model's hidden dimension.
# NOTE(review): the checkpoint path is absolute and machine-specific rather
# than derived from ARTIFACT_DIR; `probe` is not used by any cell visible here.
probe_weight = torch.load("/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/weight.pt", weights_only=True)
probe = LinearProbe(base_model.model.config.hidden_size)
probe.load_state_dict(probe_weight)

<All keys matched successfully>

In [7]:
# Tensors stored in the same directory as the probe weights.
# NOTE(review): presumably the normalisation statistics for the probe's
# inputs — confirm. Neither is used by any cell visible here.
mean = torch.load("/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/mean.pt", weights_only=True)
std = torch.load("/home/ubuntu/thesis/sae/results/ablated/probes/best_probes/std.pt", weights_only=True)

In [8]:
# Saved direction tensor; passed below as `refusal_direction` to
# get_refusal_gradients for all three models.
best_diff = torch.load("/home/ubuntu/thesis/sae/results/ablated/direction_shared/best_dir.pt", weights_only=True)

### Load dataset

In [9]:
# Load the previously saved bio and cyber datasets from disk.
# NOTE(review): absolute, machine-specific paths.
bio_test = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/t_bio")
cyber_test = load_from_disk("/home/ubuntu/thesis/sae/results/ablated/data/t_cyber")

In [10]:
# Raw "question" column of each dataset; used as the prompts for the gradient
# computation (selected by TYPE).
bio_raw = bio_test["question"]
cyber_raw = cyber_test["question"]

In [11]:
# Apply format_qa to every record of each dataset, in order.
bio_reason = list(map(format_qa, bio_test))
cyber_reason = list(map(format_qa, cyber_test))

### Junk Gradient

In [None]:
# Refusal gradient for the base model. Prompts are chosen by TYPE and the
# direction is best_diff.
# NOTE(review): presumably the gradient, taken at block 5, of the projection
# onto `refusal_direction` measured at block LAYER_ID and token position -5 —
# confirm against get_refusal_gradients. gradient_layer_idx=5 matches the
# layer of the SAE loaded later (layer_id=5).
base_grad = get_refusal_gradients(
    model=base_model.model,
    tokenize_instructions_fn=base_model.tokenize_instructions_fn,
    block_modules=base_model.model_block_modules,
    prompts=bio_raw if TYPE == "bio" else cyber_raw,
    refusal_direction=best_diff,
    gradient_layer_idx=5,
    refusal_layer_idx=LAYER_ID,
    batch_size=1,
    positions=[-5]
)

In [None]:
# Refusal gradient for the RMU-ablated model, using the same prompts,
# direction, layers and token position as the base-model call so the results
# are comparable.
# NOTE(review): the exact quantity differentiated is defined inside
# get_refusal_gradients, which is not visible here.
ablated_grad = get_refusal_gradients(
    model=ablated_model.model,
    tokenize_instructions_fn=ablated_model.tokenize_instructions_fn,
    block_modules=ablated_model.model_block_modules,
    prompts=bio_raw if TYPE == "bio" else cyber_raw,
    refusal_direction=best_diff,
    gradient_layer_idx=5,
    refusal_layer_idx=LAYER_ID,
    batch_size=1,
    positions=[-5]
)

In [None]:
# Refusal gradient for the RMU model, using the same prompts, direction,
# layers and token position as the base-model call so the results are
# comparable.
# NOTE(review): the exact quantity differentiated is defined inside
# get_refusal_gradients, which is not visible here.
rmu_grad = get_refusal_gradients(
    model=rmu_model.model,
    tokenize_instructions_fn=rmu_model.tokenize_instructions_fn,
    block_modules=rmu_model.model_block_modules,
    prompts=bio_raw if TYPE == "bio" else cyber_raw,
    refusal_direction=best_diff,
    gradient_layer_idx=5,
    refusal_layer_idx=LAYER_ID,
    batch_size=1,
    positions=[-5]
)

### Load SAE

In [None]:
# sae_dict = dict()
# for layer in range(base_model.model.config.num_hidden_layers):
#     sae_dict[layer] = load_sae_gemma_2b(layer)

In [12]:
# Load the SAE for layer 5 — the same index used as `gradient_layer_idx` in
# the gradient cells.
sae_layer = load_sae_gemma_2b(layer_id=5)

#### SAE Gradient

In [13]:
# Decoder weight of the SAE. Its first dimension is indexed by latent id
# further down (e.g. sae_layer_decoder[16213]).
sae_layer_decoder = sae_layer.state_dict()["W_dec"]

In [None]:
# Cosine similarity, along the last dimension, between each model's gradient
# (moved to the decoder's device) and the SAE decoder weights.
# NOTE(review): assumes each gradient broadcasts against W_dec so the result
# is one score per latent (plot_similarity_histogram requires a 1-D tensor) —
# confirm shapes.
base_cosine = torch.nn.functional.cosine_similarity(base_grad.to(sae_layer_decoder.device), sae_layer_decoder, dim=-1)
ablated_cosine = torch.nn.functional.cosine_similarity(ablated_grad.to(sae_layer_decoder.device), sae_layer_decoder, dim=-1)
rmu_cosine = torch.nn.functional.cosine_similarity(rmu_grad.to(sae_layer_decoder.device), sae_layer_decoder, dim=-1)

In [None]:
def plot_similarity_histogram(
    similarities,
    artifact_dir,
    type = "Bio",
    k: int = 10,
    figsize = (10, 6)
):
    """Plot a histogram of similarity scores, mark the extremes, show and save it.

    Args:
        similarities: 1-D tensor of similarity scores, one per feature. Any
            device and floating dtype is accepted; autograd tracking is ignored.
        artifact_dir: Despite the name, this is the full output *file* path
            handed to ``fig.write_image`` (e.g. ``".../cosine.pdf"``). Missing
            parent directories are created.
        type: Label interpolated into the figure title. (Shadows the builtin;
            the name is kept so existing keyword callers keep working.)
        k: How many of the highest and lowest scores to mark with vertical
            lines and list in the annotation boxes. Capped at the number of
            scores available.
        figsize: ``(width, height)``; each value is multiplied by 100 to get
            the figure size in pixels.

    Raises:
        ValueError: If ``similarities`` is not 1-dimensional.
    """
    if similarities.ndim != 1:
        raise ValueError("similarities tensor must be 1-dimensional")

    # Detach and upcast before leaving torch: tensors that require grad or use
    # bfloat16 cannot be converted to NumPy directly.
    sims = similarities.detach().float().cpu()
    sims_numpy = sims.numpy()

    # --- 1. Find the top and bottom k features ---
    # torch.topk raises when k exceeds the number of elements, so cap it.
    k = min(k, sims.numel())
    top_k = torch.topk(sims, k=k, largest=True)
    bottom_k = torch.topk(sims, k=k, largest=False)

    fig = go.Figure()

    # --- 2. Add the main histogram trace ---
    fig.add_trace(
        go.Histogram(
            x=sims_numpy,
            nbinsx=100,  # Increased bins for better resolution
            name="All Features",
            opacity=0.7,
            marker=dict(line_width=1, line_color="white")
        )
    )

    # --- 3. Add vertical lines for top and bottom features ---
    for top_value, bottom_value in zip(top_k.values.tolist(), bottom_k.values.tolist()):
        # Top k lines (blue, dotted)
        fig.add_vline(
            x=top_value,
            line_width=2,
            line_dash="dot",
            line_color="royalblue",
        )
        # Bottom k lines (red, dotted)
        fig.add_vline(
            x=bottom_value,
            line_width=2,
            line_dash="dot",
            line_color="firebrick",
        )

    # --- 4. Add annotation boxes for the indices ---
    # One index per line under a bold header.
    top_k_text = "<b>Top {} Indices</b><br>".format(k) + "<br>".join(str(i) for i in top_k.indices.tolist())
    bottom_k_text = "<b>Bottom {} Indices</b><br>".format(k) + "<br>".join(str(i) for i in bottom_k.indices.tolist())

    fig.add_annotation(
        text=top_k_text,
        align='left',
        showarrow=False,
        xref='paper',  # Use paper coordinates for x positioning
        yref='paper',  # Use paper coordinates for y positioning
        x=0.98,        # Positioned at 98% from the left edge of the plot
        y=0.95,        # Positioned at 95% from the bottom edge of the plot
        bordercolor="black",
        borderwidth=1,
        bgcolor="rgba(255, 255, 255, 0.8)" # Slightly transparent background
    )

    fig.add_annotation(
        text=bottom_k_text,
        align='left',
        showarrow=False,
        xref='paper',
        yref='paper',
        x=0.02,        # Positioned at 2% from the left edge
        y=0.95,
        bordercolor="black",
        borderwidth=1,
        bgcolor="rgba(255, 255, 255, 0.8)"
    )

    # --- 5. Final layout updates ---
    fig.update_layout(
        title=f"Cosine Similarity Between SAE Latents and Refusal Gradients ({type})",
        xaxis_title="Cosine Similarity",
        yaxis_title="Count",
        barmode="overlay",
        bargap=0.1,
        width=figsize[0]*100,
        height=figsize[1]*100,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="center",
            x=0.5
        )
    )

    # Make sure the destination folder exists so the export cannot fail on a
    # fresh checkout.
    Path(artifact_dir).parent.mkdir(parents=True, exist_ok=True)

    fig.show()
    fig.write_image(artifact_dir)

In [None]:
# Histogram for the base model, written as a PDF under ARTIFACT_DIR.
plot_similarity_histogram(base_cosine, artifact_dir=f"{ARTIFACT_DIR}/{TYPE}_cosine_latents_base.pdf", type=TYPE)

In [None]:
# Histogram for the RMU model, written as a PDF under ARTIFACT_DIR.
plot_similarity_histogram(rmu_cosine, artifact_dir=f"{ARTIFACT_DIR}/{TYPE}_cosine_latents_rmu.pdf", type=TYPE)

In [None]:
# Histogram for the RMU-ablated model, written as a PDF under ARTIFACT_DIR.
plot_similarity_histogram(ablated_cosine, artifact_dir=f"{ARTIFACT_DIR}/{TYPE}_cosine_latents_ablated.pdf", type=TYPE)

In [None]:
from IPython.display import IFrame
# Neuronpedia embed URL with three positional slots: release, SAE id, feature index.
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release, sae_id, feature_idx=0):
    """Return the Neuronpedia embed URL for one SAE feature.

    Despite the name, the result is a URL string (suitable for an IFrame
    ``src``), not HTML markup.
    """
    url_parts = (sae_release, sae_id, feature_idx)
    return html_template.format(*url_parts)
# Build the embed URL for a single feature and render it inline.
# NOTE(review): feature_idx=12187 is hard-coded; "5-gemmascope-res-16k"
# presumably names the layer-5 SAE loaded earlier — confirm.
html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id="5-gemmascope-res-16k", feature_idx=12187)
IFrame(html, width=1200, height=600)

### Steering

Candidate SAE latent indices (rows of the SAE decoder) for steering, by domain:

- CYBER: 4101, 2502
- BIO: 16213, 8085

In [14]:
# Pick two decoder rows to serve as steering directions.
# NOTE(review): 16213 and 8085 are the indices listed under BIO in the note
# above, while TYPE is set to "cyber" — confirm this is intended.
steer_dir_1 = sae_layer_decoder[16213]
steer_dir_2 = sae_layer_decoder[8085]

In [15]:
# Element-wise average of the two decoder rows, used as one steering direction.
steer_dir = (steer_dir_1 + steer_dir_2) / 2

In [16]:
# Apply the averaged direction to the RMU-ablated model in "ablation" mode.
# NOTE(review): the return value is discarded, so this presumably mutates
# ablated_model in place; re-running the cell may compound the change —
# confirm against GemmaModel.mod_model.
ablated_model.mod_model(steer_dir, type="ablation")

In [17]:
# Persist the modified model.
# NOTE(review): absolute, machine-specific output path.
ablated_model.save("/home/ubuntu/thesis/sae/results/models/sae_steer")

The OrderedVocab you are attempting to save contains holes for indices [0], your vocabulary could be corrupted !
