# Setup

In [None]:
# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd
from IPython import get_ipython  # type: ignore

import einops
import numpy as np
import torch
from IPython.display import HTML, IFrame
from jaxtyping import Float

from transformer_lens import HookedTransformer
from drl_patches.sparse_autoencoders.utils import imshow, line, scatter, residual_stack_to_logit_diff, visualize_attention_patterns



ipython = get_ipython()
assert ipython is not None
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")


torch.set_grad_enabled(False)

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
print(f"Device: {device}")

RELEASE = "gemma-scope-2b-pt-res-canonical"
MODEL_NAME = "google/gemma-2-2b"
SAE_ID = "layer_0/width_16k/canonical"
GEMMA_SCOPE_0_WIDTH_16K_CANONICAL = "layer_0/width_16k/canonical"
cache_component = "hook_resid_post.hook_sae_acts_post"

# SAE_ID = "blocks.0.hook_resid_pre"
# MODEL_NAME = "gpt2-small"
# RELEASE = "gpt2-small-res-jb"
# cache_component = "hook_resid_pre.hook_sae_acts_post"

# model = HookedTransformer.from_pretrained(
#     MODEL_NAME,
#     center_unembed=True,
#     center_writing_weights=True,
#     fold_ln=True,
#     # refactor_factored_attn_matrices=True,
# )

torch.set_grad_enabled(False)
print("Disabled automatic differentiation")
# from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

model = HookedSAETransformer.from_pretrained(MODEL_NAME, device=device)


In [47]:

# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=RELEASE,
    sae_id=SAE_ID,
    device=device,
)


In [48]:
MSR_df = pd.read_csv("artifacts/gbug-java.csv")

# Create an empty list to store the token counts
n_tokens_list = []

# Iterate over the rows
for row_idx, row in MSR_df.iterrows():
    n_tokens = model.to_tokens(row["func_before"], prepend_bos=True)
    n_tokens_list.append(n_tokens.shape[1])

# Assign the token counts as a new column
MSR_df["n_tokens"] = n_tokens_list

# Sort the dataframe by the number of tokens starting with the smallest
MSR_df = MSR_df.sort_values(by="n_tokens", ascending=True)

# Reset the index
MSR_df = MSR_df.reset_index(drop=True)

MSR_df.head()


Unnamed: 0,bug_id,func_before,func_after,n_tokens
0,vmzakharov-dataframe-ec-12af99192d24,@Override\n public String asStringLiter...,@Override\n public String asStringLiter...,29
1,jitterted-ensembler-60ec3bf0273b,public void joinAsSpectator(MemberId membe...,public void joinAsSpectator(MemberId membe...,38
2,jitterted-ensembler-60ec3bf0273b,public void joinAsSpectator(MemberId membe...,public void joinAsSpectator(MemberId membe...,38
3,AuthMe-ConfigMe-aa91a6b315ec,public void setComment(@NotNull String pat...,public void setComment(@NotNull String pat...,43
4,AuthMe-ConfigMe-aa91a6b315ec,public void setComment(@NotNull String pat...,public void setComment(@NotNull String pat...,43


In [58]:
for idx, row in enumerate(MSR_df.iterrows()):
    print(idx)
    row = row[1]
    # print(row)
    print(row["func_before"])
    # print(row["func_after"])
    # Write to file by appending
    with open(f"test.txt", "a") as f:
        f.write(f"Function {idx}:\n")
        f.write(f"Before:\n{row['func_before']}\n")
        # f.write(f"After:\n{row['func_after']}\n")
        f.write("\n\n")

0
    @Override
    public String asStringLiteral()
    {
        return '"' + this.value + '"';
    }

1
    public void joinAsSpectator(MemberId memberId) {
        membersAsSpectators.add(memberId);
        membersWhoAccepted.remove(memberId);
    }

2
    public void joinAsSpectator(MemberId memberId) {
        membersAsSpectators.add(memberId);
        membersWhoAccepted.remove(memberId);
    }

3
    public void setComment(@NotNull String path, @NotNull String... commentLines) {
        comments.put(path, Collections.unmodifiableList(Arrays.asList(commentLines)));
    }

4
    public void setComment(@NotNull String path, @NotNull String... commentLines) {
        comments.put(path, Collections.unmodifiableList(Arrays.asList(commentLines)));
    }

5
    @Override
    public String toString() {
        return rangesList.stream()
            .map(RangesList::formatRanges)
            .collect(joining(OR_JOINER));
    }

6
    @Override
    public String toString() {
        return 

In [63]:
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots

colors = [
    "#d3dee3",  # Soft Blue-Grey
    #"#fdf3d1",  # Light Yellow
    "#f4f4f4",  # Light Grey
    "#b8b09c",  # Neutral Grey
]

important_features = ["feature_6843", "feature_21821", "feature_6111", "feature_1352", "feature_22338", "feature_18861", "feature_6002", "feature_19462", "feature_22351", "feature_24401", "feature_7101", "feature_19337", "feature_9954", "feature_16488", "feature_2122", "feature_17229", "feature_19388", "feature_14084", "feature_9226", "feature_11786", "feature_9237", "feature_11233", "feature_1070", "feature_7261", "feature_2333", "feature_916", "feature_22779"]
important_features = important_features[:1]
#important_features = ["feature_3336"]
important_features = ["feature_1987"]

after_func_col = "func_after"
before_func_col = "func_before"
layer = 0
index = [f"feature_{i}" for i in range(sae.cfg.d_sae)]

# Focus on the first entry only
for entry_idx, entry in enumerate(MSR_df[29:30].iterrows()):
    entry = entry[1]
    prompt_texts = [
        str(entry[after_func_col]),   # vulnerable
        str(entry[before_func_col]),  # secure
    ]
    

    for feature in important_features:
        feature_idx = int(feature.split("_")[-1])

        # --- Process each prompt (vulnerable and secure) ---
        activations_dict = {}
        token_strings = {}

        feature_activation_df = pd.DataFrame(
            columns=["vulnerable", "secure", "diff"],
            index=index,
        )
        print("-------------------")
        print(f"Vulnerable prompt:\n {prompt_texts[0]}")
        print(f"Secure prompt:\n {prompt_texts[1]}")
        
        for prompt_idx, text in enumerate(prompt_texts):
            security = "vulnerable" if prompt_idx == 0 else "secure"
            
            # Forward pass
            _, cache = model.run_with_cache_with_saes(text, saes=[sae])
            cache_acts = cache[f"blocks.{layer}.{cache_component}"]

            # Extract activations for the selected feature across tokens
            token_wise_activation = cache_acts[0, :, feature_idx]

            token_wise_all_activations = cache_acts[0, -1, :]
            if prompt_idx == 0:
                feature_activation_df["vulnerable"] = token_wise_all_activations.cpu().numpy()
            else:
                feature_activation_df["secure"] = token_wise_all_activations.cpu().numpy()


            # Save
            activations_dict[security] = token_wise_activation.cpu()

            tokens = model.to_tokens(text, prepend_bos=True)

            token_string = model.tokenizer.convert_ids_to_tokens(tokens[0])
            # normalize "____________" to "____"
            token_string = [token.replace("▁", "") for token in token_string]
            token_strings[security] = token_string

        # --- Create the combined heatmap ---
        # Padd with zeros if the lengths are different
        max_length = max(len(activations_dict["vulnerable"]), len(activations_dict["secure"]))
        for key in activations_dict:
            if len(activations_dict[key]) < max_length:
                activations_dict[key] = torch.cat(
                    [activations_dict[key], torch.zeros(max_length - len(activations_dict[key]))]
                )
                token_strings[key] += [""] * (max_length - len(token_strings[key]))

        vulnerable_acts = activations_dict["vulnerable"]
        secure_acts = activations_dict["secure"]
        diff_acts = secure_acts - vulnerable_acts

        our_color_scale = [
            [0, colors[0]],
            [0.5, colors[1]],
            [1, colors[2]]
        ]

        fig = make_subplots(
            rows=1, cols=3,
            subplot_titles=("Buggy Tokens", "Patched Tokens", "Activation Differences"),
            shared_yaxes=False,
            horizontal_spacing=0.35
        )

        # Vulnerable Heatmap
        fig.add_trace(go.Heatmap(
            z=vulnerable_acts.reshape(-1, 1).numpy(),
            x=["Activation"],
            y=[f"{i}: {token}" for i, token in enumerate(token_strings["vulnerable"])],
            colorscale=our_color_scale,
            zmin=-1, zmax=1,
            colorbar=dict(title="Activation")
        ), row=1, col=1)

        # Secure Heatmap
        fig.add_trace(go.Heatmap(
            z=secure_acts.reshape(-1, 1).numpy(),
            x=["Activation"],
            y=[f"{i}: {token}" for i, token in enumerate(token_strings["secure"])],
            colorscale=our_color_scale,
            zmin=-1, zmax=1,
            showscale=False,
        ), row=1, col=2)

        # Difference Heatmap
        fig.add_trace(go.Heatmap(
            z=diff_acts.reshape(-1, 1).numpy(),
            x=["Activation Diff"],
            y=[f"{i}" for i, token in enumerate(diff_acts.reshape(-1, 1).numpy())],
            colorscale=our_color_scale,
            zmid=0,  # Important: Center around 0
            zmin=-1, zmax=1,
            showscale=False
        ), row=1, col=3)

        # --- Layout Updates ---
        fig.update_layout(
            title_text=f"Feature {feature_idx} Activations per Token",
            height=700,
            width=600,
            title_x=0.5,
            font=dict(family="Arial", size=14),
            margin=dict(l=30, r=30, t=80, b=30)
        )
        fig.show()


        feature_activation_df["diff"] = abs(
            feature_activation_df["vulnerable"] - feature_activation_df["secure"]
        )
        fig = px.line(
            feature_activation_df,
            title="Feature activations for the prompt",
            labels={"index": "Feature", "value": "Activation"},
        )

        # hide the x-ticks
        fig.update_xaxes(showticklabels=False)
        fig.show()



-------------------
Vulnerable prompt:
     @Override
    public String constructUrlFromEndpointPortDatabase(String endpoint, String port, String dbname) {
        String url = "jdbc:postgresql://" + endpoint;
        if (!StringUtils.isNullOrEmpty(port)) {
            url += ":" + port;
        }
        url += "/";
        if (!StringUtils.isNullOrEmpty(dbname)) {
            url += dbname;
        }
        return url;
    }

Secure prompt:
     @Override
    public String constructUrlFromEndpointPortDatabase(String endpoint, String port, String dbname) {
        String url = "jdbc:postgresql://" + endpoint;
        if (!StringUtils.isNullOrEmpty(port)) {
            url += ":" + port;
        }
        if (!StringUtils.isNullOrEmpty(dbname)) {
            url += "/" + dbname;
        }
        return url;
    }

