# Setup

Inspired by https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=gvH_9J2WOJ9A

In [1]:
# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd
from IPython import get_ipython  # type: ignore

ipython = get_ipython()
assert ipython is not None
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")


torch.set_grad_enabled(False)

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
print(f"Device: {device}")

Device: cpu


In [2]:
MSR_df = pd.read_csv("MSR_data_vul.csv")
MSR_df.head()

Unnamed: 0,Access Gained,Attack Origin,Authentication Required,Availability,CVE ID,CVE Page,CWE ID,Complexity,Confidentiality,Integrity,...,lang,lines_after,lines_before,parentID,patch,project,project_after,project_before,vul,vul_func_with_fix
0,,Remote,Single system,Partial,CVE-2015-8467,https://www.cvedetails.com/cve/CVE-2015-8467/,CWE-264,Medium,Partial,Partial,...,C,struct ldb_context *ldb = ldb_module_ge...,,a819d2b440aafa3138d95ff6e8b824da885a70e9,"@@ -1558,12 +1558,15 @@ static int samldb_chec...",samba,https://git.samba.org/?p=samba.git;a=blob;f=so...,https://git.samba.org/?p=samba.git;a=blob;f=so...,1,static int samldb_check_user_account_control_a...
1,,Remote,Not required,Partial,CVE-2015-8382,https://www.cvedetails.com/cve/CVE-2015-8382/,CWE-119,Low,Partial,,...,C,"memset(offsets, 0, size_offsets*sizeof(...",,1a2ec3fc60e428c47fd59c9dd7966c71ca44024d,"@@ -640,7 +640,7 @@ PHPAPI void php_pcre_match...",php,https://git.php.net/?p=php-src.git;a=blob;f=ex...,https://git.php.net/?p=php-src.git;a=blob;f=ex...,1,PHPAPI void php_pcre_match_impl(pcre_cache_ent...


In [3]:
RELEASE = ... #"llama_scope_lxa_32x"
SAE_ID = ... #"Llama3_1-8B-Base-L0A-32x" 
HOOK_POINT = ... #"residuals"
MODEL_NAME = "gpt2-small" #"meta-llama/Llama-3.1-8B"

In [5]:
import einops
import numpy as np
import torch
from IPython.display import HTML, IFrame
from jaxtyping import Float

from transformer_lens import HookedTransformer
from drl_patches.sparse_autoencoders.utils import imshow, line, scatter, residual_stack_to_logit_diff, visualize_attention_patterns
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

torch.set_grad_enabled(False)
print("Disabled automatic differentiation")

  from .autonotebook import tqdm as notebook_tqdm


Loaded pretrained model gpt2-small into HookedTransformer
Disabled automatic differentiation


In [6]:
prompts = [MSR_df.iloc[0]['func_before']] + [MSR_df.iloc[0]['func_after']]

n_samples = 1

In [8]:
tokens = model.to_tokens(prompts, prepend_bos=True)

# Run the model and cache all activations
original_logits: Float[torch.Tensor, "batch seq_len voc_size"]

original_logits, cache = model.run_with_cache(tokens)
print(original_logits.shape)

# Converts to Token IDs
logit_diff_directions = torch.tensor([model.to_single_token("1"), model.to_single_token("0")])
print(model.tokens_to_residual_directions(logit_diff_directions).shape)

torch.Size([2, 123, 50257])
torch.Size([2, 768])


Logit difference is actually a really nice and elegant metric and is a particularly nice aspect of the setup of Indirect Object Identification. In general, there are two natural ways to interpret the model's outputs: the output logits, or the output log probabilities (or probabilities).

The logits are much nicer and easier to understand, as noted above. However, the model is trained to optimize the cross-entropy loss (the average of log probability of the correct token). This means it does not directly optimize the logits, and indeed if the model adds an arbitrary constant to every logit, the log probabilities are unchanged.

In [9]:
# Check the residual stream shape
print(cache["resid_post", -1].shape)


# This is to understand why the model would prefer 1 over 0
# Converts to Token IDs
logit_diff_directions = torch.tensor([model.to_single_token("1"), model.to_single_token("0")])
# Creates Embeddings for both tokens
logit_diff_directions = model.tokens_to_residual_directions(logit_diff_directions)




accumulated_residual, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
logit_lens_logit_diffs = residual_stack_to_logit_diff(prompts, accumulated_residual, cache, logit_diff_directions)
line(
    logit_lens_logit_diffs,
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    title="Logit Difference From Accumulate Residual Stream",
)



torch.Size([2, 123, 768])


In [10]:
per_layer_residual, labels = cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs = residual_stack_to_logit_diff(prompts, per_layer_residual, cache, logit_diff_directions)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

We can further break down the output of each attention layer into the sum of the outputs of each attention head. Each attention layer consists of 12 heads, which each act independently and additively.

In [11]:
per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_logit_diffs = residual_stack_to_logit_diff(prompts, per_head_residual, cache, logit_diff_directions)
per_head_logit_diffs = einops.rearrange(
    per_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)
imshow(
    per_head_logit_diffs,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
)

Tried to stack head results when they weren't cached. Computing head results now


In [17]:
top_k = 3

top_positive_logit_attr_heads = torch.topk(
    per_head_logit_diffs.flatten(), k=top_k
).indices

positive_html = visualize_attention_patterns(
    model,
    top_positive_logit_attr_heads,
    cache,
    tokens[0],
    f"Top {top_k} Positive Logit Attribution Heads",
)

top_negative_logit_attr_heads = torch.topk(
    -per_head_logit_diffs.flatten(), k=top_k
).indices

negative_html = visualize_attention_patterns(
    model,
    top_negative_logit_attr_heads,
    cache,
    tokens[0],
    title=f"Top {top_k} Negative Logit Attribution Heads",
)

HTML(positive_html + negative_html)

We can assume that the information is processed in some specific 

- 0_attn_out
- 10_mlp_out
- 11_attn_out (-)
- 11_mlp_out

The residual stream shows
- 0_mid
- 11_pre
- final_post (the most)


The attnetion patterns show
- layre 11 head 0
- layre 11 head 8
- layre 10 head 4


In [48]:
RELEASE = "gpt2-small-res-jb"
SAE_ID = "blocks.0.hook_resid_pre"
HOOK_POINT = ... #"residuals"
MODEL_NAME = "gpt2-small" #"meta-llama/Llama-3.1-8B"

In [49]:
from IPython import get_ipython  # type: ignore

ipython = get_ipython()
assert ipython is not None
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px
import pandas as pd

# Imports for displaying vis in Colab / notebook

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Device: mps


In [50]:
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# TODO: Make this nicer.
df = pd.DataFrame.from_records(
    {k: v.__dict__ for k, v in get_pretrained_saes_directory().items()}
).T
df.drop(
    columns=[
        "expected_var_explained",
        "expected_l0",
        "config_overrides",
        "conversion_func",
    ],
    inplace=True,
)

print("-" * 50)
print("SAEs in the feature splitting release")
for k, v in (
    df.loc[df.release == RELEASE, "saes_map"]
    .values[0]
    .items()
):
    print(f"SAE id: {k} for hook point: {v}")


--------------------------------------------------
SAEs in the feature splitting release
SAE id: blocks.0.hook_resid_pre for hook point: blocks.0.hook_resid_pre
SAE id: blocks.1.hook_resid_pre for hook point: blocks.1.hook_resid_pre
SAE id: blocks.2.hook_resid_pre for hook point: blocks.2.hook_resid_pre
SAE id: blocks.3.hook_resid_pre for hook point: blocks.3.hook_resid_pre
SAE id: blocks.4.hook_resid_pre for hook point: blocks.4.hook_resid_pre
SAE id: blocks.5.hook_resid_pre for hook point: blocks.5.hook_resid_pre
SAE id: blocks.6.hook_resid_pre for hook point: blocks.6.hook_resid_pre
SAE id: blocks.7.hook_resid_pre for hook point: blocks.7.hook_resid_pre
SAE id: blocks.8.hook_resid_pre for hook point: blocks.8.hook_resid_pre
SAE id: blocks.9.hook_resid_pre for hook point: blocks.9.hook_resid_pre
SAE id: blocks.10.hook_resid_pre for hook point: blocks.10.hook_resid_pre
SAE id: blocks.11.hook_resid_pre for hook point: blocks.11.hook_resid_pre
SAE id: blocks.11.hook_resid_post for hook 

In [51]:
# from transformer_lens import HookedTransformer
from sae_lens import SAE, HookedSAETransformer

model = HookedSAETransformer.from_pretrained(MODEL_NAME, device=device)

# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=RELEASE,
    sae_id=SAE_ID,
    device=device,
)

Loaded pretrained model gpt2-small into HookedTransformer




This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)



In [52]:
print(sae.cfg.__dict__)

{'architecture': 'standard', 'd_in': 768, 'd_sae': 24576, 'activation_fn_str': 'relu', 'apply_b_dec_to_input': True, 'finetuning_scaling_factor': False, 'context_size': 128, 'model_name': 'gpt2-small', 'hook_name': 'blocks.0.hook_resid_pre', 'hook_layer': 0, 'hook_head_index': None, 'prepend_bos': True, 'dataset_path': 'Skylion007/openwebtext', 'dataset_trust_remote_code': True, 'normalize_activations': 'none', 'dtype': 'torch.float32', 'device': 'mps', 'sae_lens_training_version': None, 'activation_fn_kwargs': {}, 'neuronpedia_id': 'gpt2-small/0-res-jb', 'model_from_pretrained_kwargs': {'center_writing_weights': True}, 'seqpos_slice': (None,)}


In [53]:
from IPython.display import IFrame

# get a random feature from the SAE
feature_idx = torch.randint(0, sae.cfg.d_sae, (1,)).item()

html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"


def get_dashboard_html(sae_release="gpt2-small", sae_id="0-res-jb", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)


html = get_dashboard_html(
    feature_idx=feature_idx
)
IFrame(html, width=1200, height=600)

In [54]:

# TODO: Fix this for other entries
MSR_df = pd.read_csv("MSR_data_cleaned_vul.csv")
MSR_df.head()


Unnamed: 0,Access Gained,Attack Origin,Authentication Required,Availability,CVE ID,CVE Page,CWE ID,Complexity,Confidentiality,Integrity,...,lang,lines_after,lines_before,parentID,patch,project,project_after,project_before,vul,vul_func_with_fix
0,,Remote,Single system,Partial,CVE-2015-8467,https://www.cvedetails.com/cve/CVE-2015-8467/,CWE-264,Medium,Partial,Partial,...,C,struct ldb_context *ldb = ldb_module_ge...,,a819d2b440aafa3138d95ff6e8b824da885a70e9,"@@ -1558,12 +1558,15 @@ static int samldb_chec...",samba,https://git.samba.org/?p=samba.git;a=blob;f=so...,https://git.samba.org/?p=samba.git;a=blob;f=so...,1,static int samldb_check_user_account_control_a...
1,,Remote,Not required,Partial,CVE-2015-8382,https://www.cvedetails.com/cve/CVE-2015-8382/,CWE-119,Low,Partial,,...,C,"memset(offsets, 0, size_offsets*sizeof(...",,1a2ec3fc60e428c47fd59c9dd7966c71ca44024d,"@@ -640,7 +640,7 @@ PHPAPI void php_pcre_match...",php,https://git.php.net/?p=php-src.git;a=blob;f=ex...,https://git.php.net/?p=php-src.git;a=blob;f=ex...,1,PHPAPI void php_pcre_match_impl(pcre_cache_ent...
2,,Remote,Not required,Partial,CVE-2013-6712,https://www.cvedetails.com/cve/CVE-2013-6712/,CWE-119,Low,,,...,C,} while (!s->errors->error_coun...,} while (*ptr);\n,63f3ff7b5f89f50eb9df76c3d0860c04cc6e0f66,"@@ -1,4 +1,4 @@\n-/* Generated by re2c 0.13.5 ...",php,https://git.php.net/?p=php-src.git;a=blob;f=ex...,https://git.php.net/?p=php-src.git;a=blob;f=ex...,1,static int scan(Scanner *s)\n{\n\tuchar *curso...
3,,Remote,Not required,Partial,CVE-2013-6449,https://www.cvedetails.com/cve/CVE-2013-6449/,CWE-310,Medium,,,...,C,if (s->method->version == TLS1_2_VERSIO...,if (TLS1_get_version(s) >= TLS1_2_VERSI...,2ec4181ba92fc6b828687d2dc47c13dcd35a5d93,"@@ -4286,7 +4286,7 @@ need to go to SSL_ST_ACC...",openssl,https://git.openssl.org/gitweb/?p=openssl.git;...,https://git.openssl.org/gitweb/?p=openssl.git;...,1,long ssl_get_algorithm2(SSL *s)\n {\n ...
4,,Remote,Not required,Partial,CVE-2013-6420,https://www.cvedetails.com/cve/CVE-2013-6420/,CWE-119,Low,Partial,Partial,...,C,if (ASN1_STRING_type(timestr) != V_ASN1...,if (timestr->length < 13) {\n ...,32873cd0ddea7df8062213bb025beb6fb070e59d,"@@ -644,18 +644,28 @@ static time_t asn1_time_...",php,https://git.php.net/?p=php-src.git;a=blob;f=ex...,https://git.php.net/?p=php-src.git;a=blob;f=ex...,1,static time_t asn1_time_to_time_t(ASN1_UTCTIME...


In [44]:
prompt = [
    MSR_df.iloc[0]["func_before"],
    MSR_df.iloc[0]["func_after"],
]
_, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
print([(k, v.shape) for k, v in cache.items() if "sae" in k])

feature_activation_df = pd.DataFrame(
    cache["blocks.0.hook_resid_pre.hook_sae_acts_post"][0, -1, :].cpu().numpy(),
    index=[f"feature_{i}" for i in range(sae.cfg.d_sae)],
)
feature_activation_df.columns = ["vulnerable"]
feature_activation_df["secure"] = (
    cache["blocks.0.hook_resid_pre.hook_sae_acts_post"][1, -1, :].cpu().numpy()
)
feature_activation_df["diff"] = (
    abs(feature_activation_df["vulnerable"] - feature_activation_df["secure"])
)

fig = px.line(
    feature_activation_df,
    title="Feature activations for the prompt",
    labels={"index": "Feature", "value": "Activation"},
)

# hide the x-ticks
fig.update_xaxes(showticklabels=False)
fig.show()

[('blocks.0.hook_resid_pre.hook_sae_input', torch.Size([2, 1024, 768])), ('blocks.0.hook_resid_pre.hook_sae_acts_pre', torch.Size([2, 1024, 24576])), ('blocks.0.hook_resid_pre.hook_sae_acts_post', torch.Size([2, 1024, 24576])), ('blocks.0.hook_resid_pre.hook_sae_recons', torch.Size([2, 1024, 768])), ('blocks.0.hook_resid_pre.hook_sae_output', torch.Size([2, 1024, 768]))]


Features 22338, 1698, 9370 seem important

In [45]:

feature_activation_df.sort_values("diff", ascending=False)


Unnamed: 0,vulnerable,secure,diff
feature_22338,2.880469,0.000000,2.880469
feature_1698,0.000000,1.736947,1.736947
feature_9370,0.000000,0.794410,0.794410
feature_14954,0.000000,0.319746,0.319746
feature_3249,0.000000,0.246750,0.246750
...,...,...,...
feature_8287,0.000000,0.000000,0.000000
feature_8286,0.000000,0.000000,0.000000
feature_8284,0.000000,0.000000,0.000000
feature_8283,0.000000,0.000000,0.000000


In [47]:
important_features = feature_activation_df.sort_values("diff", ascending=False).index[:3]

for feature in important_features:
    feature_idx = int(feature.split("_")[-1])

    for prompt_idx, text in enumerate(prompt):
        
        security = "vulnerable" if prompt_idx == 0 else "secure"

        title = f"Feature {feature_idx} activations for {security} prompt {prompt_idx}"


        # # See the tokens that are most activated by the feature
        # feature_idx = 22338					

        # prompt_idx = 0

        # Foward pass the tokens through the SAE to get the activations per token
        token_wise_activation = cache["blocks.0.hook_resid_pre.hook_sae_acts_post"][prompt_idx, :, feature_idx]

        # Clip the activations to be between -10 and 10
        # token_wise_activation = torch.clamp(token_wise_activation, -10, 10)

        tokenized_prompt = model.tokenizer.tokenize(text)


        our_tokens = tokenized_prompt + [model.tokenizer.eos_token] #+ [model.tokenizer.eos_token] * (len(token_wise_activation) - len(tokenized_prompt)-1)

        if len(our_tokens) < len(token_wise_activation):
            our_tokens += [model.tokenizer.eos_token] * (len(token_wise_activation) - len(our_tokens))
        elif len(our_tokens) > len(token_wise_activation):
            our_tokens = our_tokens[:len(token_wise_activation)]
            # Some hacky way to do this
            


        print(len(our_tokens))
        # VIsualize the activations per token. I want a heatmap where the x-axis is the token index and the y-axis is the token.
        # The color of the cell should be the activation value.
        # We can use plotly for this.

        fig = px.imshow( token_wise_activation.unsqueeze(0).cpu().numpy(),title=title,labels={"x": "Token index", "y": "Token"},x=our_tokens)
        fig.show()


1024


1024


1024


1024


1024


1024
