## Setup

In [1]:
import wget
import os
import json
import random
import einops
import torch
import pandas as pd
import numpy as np
import plotly.express as px
from tqdm import tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login, login
from sae_lens import SAE, HookedSAETransformer
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from transformer_lens.utils import test_prompt
from functools import partial
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import re

import gc

# trigger garbage collection
gc.collect()

# free all memory in use by cuda
torch.cuda.empty_cache()

# Get the current GPU memory usage
allocated_memory = torch.cuda.memory_allocated()
reserved_memory = torch.cuda.memory_reserved()

print(f"Allocated memory: {allocated_memory / (1024 ** 3):.2f} GB")
print(f"Reserved memory: {reserved_memory / (1024 ** 3):.2f} GB")


  from .autonotebook import tqdm as notebook_tqdm


Allocated memory: 0.00 GB
Reserved memory: 0.00 GB


In [2]:
torch.set_grad_enabled(False) # avoid blowing up mem
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

Device: cuda


In [3]:
GEMMA_2_2B_MODEL = "google/gemma-2-2b"
GEMMA_2_2B_RELEASE = "gemma-scope-2b-pt-{}"
model_name = GEMMA_2_2B_MODEL

# the model we want to test, and attach the SAE to it
model = HookedSAETransformer.from_pretrained(model_name, device=device)

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  7.64it/s]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model google/gemma-2-2b into HookedTransformer


In [6]:
specific_layers = [
    ('mlp', layer, '65k') for layer in range(model.cfg.n_layers)
    # ('mlp', layer, '16k') for layer in range(model.cfg.n_layers)
]

## SAE Format Parsing and Preparation

In [7]:
import requests

def get_explanations_api(modelId, saeId):
    url = f"https://www.neuronpedia.org/api/explanation/export?modelId={modelId}&saeId={saeId}"
    print("explanations url = ", url)
    headers = {"Content-Type": "application/json"}
    response = requests.get(url, headers=headers)
    result = response.json()
    # for example: https://www.neuronpedia.org/api/explanation/export?modelId=gemma-2-2b&saeId=5-gemmascope-att-16k
    # "message": "We don't yet have an explanation export for this model and SAE. Please contact support@neuronpedia.org if you need this now and we'll get it to you ASAP."
    if 'message' in result:
        return None
    return response.json()

# get explanations using neuronpedia API
def get_explanations_by_features(modelId, saeId, features):
    data = get_explanations_api(modelId, saeId)
    if data is None:
        return []
    exp_df = pd.DataFrame(data)
    # rename index to "feature"
    exp_df.rename(columns={"index": "feature"}, inplace=True)
    exp_df["feature"] = exp_df["feature"].astype(int)
    exp_df["description"] = exp_df["description"].apply(lambda x: x.lower())
    exp_df = exp_df[exp_df['explanationModelName'] == 'gpt-4o-mini']
    exp_df = exp_df[exp_df['feature'].isin(features)]
    return list(exp_df['description'])

In [8]:
def get_all_saes_df():
  # Each row is a "release" which has multiple SAEs which may have different configs / match different hook points in a model.
  all_saes_df = pd.DataFrame.from_records({k:v.__dict__ for k,v in get_pretrained_saes_directory().items()}).T
  all_saes_df.drop(columns=["expected_var_explained", "expected_l0",
                            "config_overrides", "conversion_func"], inplace=True)
  return all_saes_df

# given a specific release, get its sae ids 
# if we pass "gemma-scope-2b-pt-" we get all the types (mlp, att, res)
def get_sae_ids(release, all_saes_df, is_gemma=True):
    saes_map = all_saes_df[(all_saes_df['release'].str.contains(release)) & ~(all_saes_df['release'].str.contains('canonical'))]['saes_map']
    df = pd.DataFrame(saes_map)
    sae_ids_df = pd.DataFrame(columns=['id', 'type'])
    for layer_type in df.index:
        ids = df.loc[layer_type]['saes_map'].keys()
        temp_df = pd.DataFrame({'id': ids, 'type': layer_type})
        sae_ids_df = pd.concat([sae_ids_df, temp_df])        

    if is_gemma:
      sae_ids_df['layer'] = sae_ids_df['id'].apply(lambda x: x.split('/')[0])
      sae_ids_df['width'] = sae_ids_df['id'].apply(lambda x: x.split('/')[1])
      sae_ids_df['average_l0'] = sae_ids_df['id'].apply(lambda x: x.split('/')[2])
    return sae_ids_df

def get_average_l0_from_csv(layer_type, layer_num, layer_width):
    df = pd.read_csv('gemma-2-2b.csv')
    id_str = f"{layer_num}-gemmascope-{layer_type}-{layer_width}"
    return df[df['id'] == id_str]['average'].iloc[0]

def get_saes_info_for_gemma2_specific_layers(specific_layers):
    all_saes_df = get_all_saes_df()
    sae_ids_df = get_sae_ids('gemma-scope-2b-pt-', all_saes_df, is_gemma=True)
    sae_ids = []
    for layer_type, layer_num, layer_width in specific_layers:
        # here we do have a specific type
        average_l0 = get_average_l0_from_csv(layer_type, layer_num, layer_width)
        sae_id_temp = sae_ids_df[(sae_ids_df['layer'] == f'layer_{layer_num}') & 
                                    (sae_ids_df['average_l0'] == average_l0) & 
                                    (sae_ids_df['type'].str.contains(layer_type))].iloc[0]['id']
        sae_ids.append(sae_id_temp)
    sae_single_ids_df = sae_ids_df[sae_ids_df['id'].isin(sae_ids)]
    # extract layer number based on id
    sae_single_ids_df['layer_num'] = sae_single_ids_df['layer'].apply(lambda x: x.split('/')[0].split('_')[1])
    return sae_single_ids_df.reset_index(drop=True)

In [9]:
saes_df = get_saes_info_for_gemma2_specific_layers(specific_layers).iloc[:-1] # TODO: There's a bug that adds a resid layer that we don't want, that's why we do iloc[:-1]
saes_df

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  sae_single_ids_df['layer_num'] = sae_single_ids_df['layer'].apply(lambda x: x.split('/')[0].split('_')[1])


Unnamed: 0,id,type,layer,width,average_l0,layer_num
0,layer_0/width_65k/average_l0_72,gemma-scope-2b-pt-mlp,layer_0,width_65k,average_l0_72,0
1,layer_1/width_65k/average_l0_127,gemma-scope-2b-pt-mlp,layer_1,width_65k,average_l0_127,1
2,layer_2/width_65k/average_l0_134,gemma-scope-2b-pt-mlp,layer_2,width_65k,average_l0_134,2
3,layer_3/width_65k/average_l0_68,gemma-scope-2b-pt-mlp,layer_3,width_65k,average_l0_68,3
4,layer_4/width_65k/average_l0_66,gemma-scope-2b-pt-mlp,layer_4,width_65k,average_l0_66,4
5,layer_5/width_65k/average_l0_86,gemma-scope-2b-pt-mlp,layer_5,width_65k,average_l0_86,5
6,layer_6/width_65k/average_l0_101,gemma-scope-2b-pt-mlp,layer_6,width_65k,average_l0_101,6
7,layer_7/width_65k/average_l0_115,gemma-scope-2b-pt-mlp,layer_7,width_65k,average_l0_115,7
8,layer_8/width_65k/average_l0_110,gemma-scope-2b-pt-mlp,layer_8,width_65k,average_l0_110,8
9,layer_9/width_65k/average_l0_77,gemma-scope-2b-pt-mlp,layer_9,width_65k,average_l0_77,9


## SAE Feature Identification Pipeline

In [12]:
from dataclasses import dataclass

@dataclass
class Result:
    prompt: str
    subject: str
    object: str
    logit_diff: float
    clean_pred: str
    ablated_pred: str
    clean_completion: str
    ablated_completion: str
    found_features: list
    min_features: list = None
    min_ablated_completion: str = None
    min_ablation_num: int = None
    feature_diffs: dict = None

    def __str__(self):
        out = f"Result for '{self.prompt}'\n\tLogit diff: {self.logit_diff}\n\t'{self.clean_pred}' -> '{self.ablated_pred}'\n\t'{self.clean_completion}' -> '{self.ablated_completion}'\n\tAll features: {self.found_features}"
        if self.min_features:
            out += f"\n\tMin features: {self.min_features}"
        if self.min_ablated_completion:
            out += f"\n\tMin ablated completion: '{self.min_ablated_completion}'"
        if self.min_ablation_num:
            out += f"\n\tMin ablation num: '{self.min_ablation_num}'"

        return out + "\n"

    def __repr__(self):
        return str(self)

In [13]:
# Usef for finding the starting index of the subject tokens in a prompt
def get_tok_index_in_prompt(prompt, tokens):
    for i in range(len(prompt)-len(tokens)+1):
        found = True
        for j, k in enumerate(range(i, i+len(tokens))):
            if prompt[k] != tokens[j]:
                found = False
                break

        if found:
            return i

    assert False, f"Token not found in prompt: {model.to_string(tokens)}"


# Get the index of the subject in the prompt (string)
def get_last_subject_index(prompt, subject):
    prompt_tokens = model.to_tokens(prompt).squeeze().tolist()
    subject_tokens = model.to_tokens(subject, prepend_bos=False).squeeze().tolist()
    if isinstance(subject_tokens, int):
        subject_tokens = [subject_tokens]
    index = get_tok_index_in_prompt(prompt_tokens, subject_tokens)
    last_index = index + len(subject_tokens) - 1
    assert subject_tokens[-1] == prompt_tokens[last_index], f"Got mismatched tokens: {model.to_string(subject_tokens[-1])} != {model.to_string(prompt_tokens[last_index])}" 
    return last_index

In [14]:
# Function used to generate an n-length completion from a prompt
def pred_loop(model, prompt, n=10, use_hooks=False, saes=None, fwd_hooks=None):
    prompt = model.to_tokens(prompt)
    with torch.no_grad():
        if saes:
            for sae, _ in saes:
                model.add_sae(sae)
        for i in range(n):
            output = model(prompt)[:, -1] if not use_hooks else model.run_with_hooks(prompt, fwd_hooks=fwd_hooks)[:, -1]
            predicted = torch.argmax(output, dim=-1)
            prompt = torch.cat([prompt, predicted.unsqueeze(0)], dim=-1)

    return model.to_string(prompt)[0]

# Attach SAE and run the model with the prompt to get the top activating columns (features)
def get_top_activating_cols_for_sae(prompt, sae, sae_id, layer, top_feature_num=10, top_tokens_from_proj=10, tok_index=-1):
    with torch.no_grad():
        # Run with single SAE and get cache
        model.reset_saes()
        model.add_sae(sae)
        _, cache = model.run_with_cache(prompt, return_type=None)
        model.reset_saes()
        
        # Get the top activating columns
        top = cache[f"blocks.{layer}.hook_mlp_out.hook_sae_acts_post"][0,tok_index].topk(top_feature_num)
        top_indices = top.indices
        top_values = top.values

        logits = (sae.W_dec[top_indices].half() @ model.unembed.W_U.half())
        token_indices = logits.topk(top_tokens_from_proj, dim=-1).indices
        tokens = model.to_string(token_indices)
        
        # Get the top tokens from each of the top activating columns
        cols = []
        for i, (index, value) in enumerate(zip(top_indices, top_values)):
            cols.append((tokens[i], sae_id, index.item(), value.item()))

        # Free up memory
        del logits
        del cache
        del sae
        # gc.collect()
        # torch.cuda.empty_cache()
    
        return cols

# Get the top activating columns for each SAE in the model
def get_top_activating_cols(prompt, subject_index):
    top_per_layer = []

    for i in range(len(saes_df)):
        # Get SAE
        row = saes_df.iloc[i]
        sae, _, _ = SAE.from_pretrained(release = row.type, sae_id = row.id, device = device)
        sae.use_error_term = False # TODO: should test if true/false is needed

        # Get top activating columns and their tokens
        top_per_layer.append(get_top_activating_cols_for_sae(prompt, sae, row.id, row.layer_num, 200, 100, subject_index))

    return top_per_layer
    

def get_sae_features_for_relation(model, prompt, subject_index, object):
    prompt_toks = model.to_tokens(prompt)

    # Get the top activating columns for each SAE in the model
    top_features_per_layer = get_top_activating_cols(prompt_toks, subject_index)

    # Find the features that contain the object and return them
    object_features = []
    for i, layer in enumerate(top_features_per_layer):
        for j, (column_tokens, sae_id, index, value) in enumerate(layer):
            if value <= 0:
                continue
            # Choose this feature as meaningful if the object appears in ANY of the words from that feature's column
            if object.strip().lower() in column_tokens.strip().lower():
                object_features.append((sae_id, index, value))

                # Hacky stopgap for too many features - should make this a limit per SAE or something
                # if len(object_features) >= 10:
                    # return object_features

                # break - TODO: this is probably a bug so commenting out

    return object_features

# Try to ablate the features and see how the model's prediction changes
def try_ablation(prompt, object, subject_index, features, ablation_num=-20, do_completion=True):

    # Ablation function - notice that currently we are ablating all features indexed by 'cols', in
    # all token positions (even though we only searched updates for the subject token). It would be
    # better if we can make ablating only the subject token be effective.
    def ablate_feature_hook(acts, hook, index, cols):
        for col in cols:
            for i in [index-1, index, index+1, -1, -2]:
                acts[:,i,col] = ablation_num

    # Create map so we can group features by SAE (to not try to double attach the same SAE)
    saes_to_features = {}
    for sae_id, index, _ in features:
        if sae_id not in saes_to_features:
            saes_to_features[sae_id] = []

        saes_to_features[sae_id].append(index)

    # Create SAEs and hooks
    saes = [(SAE.from_pretrained(release = "gemma-scope-2b-pt-res", sae_id = sae_id, device = device)[0], cols) for sae_id, cols in saes_to_features.items()]
    fwd_hooks = [(sae.cfg.hook_name + '.hook_sae_acts_post', partial(ablate_feature_hook, cols=cols, index=subject_index)) for sae, cols in saes]

    # Run the model with and without the ablation
    clean_logits = model(prompt)
    ablated_logits = model.run_with_hooks_with_saes(prompt, saes=[sae for sae,_ in saes], fwd_hooks=fwd_hooks)

    object_tok = model.to_tokens(object, prepend_bos=False).squeeze(0)[0]

    # Get the logit values for the object token in the clean and ablated logits
    clean_logit_value = clean_logits[0,-1,object_tok].item() 
    ablated_logit_value = ablated_logits[0,-1,object_tok].item()

    # Get the predictions for the clean and ablated logits
    clean_pred = clean_logits[0,-1].argmax().item()
    ablated_pred = ablated_logits[0,-1].argmax().item()

    # Get the completions for the clean and ablated logits
    if do_completion:
        clean_completion = pred_loop(model, prompt, n=15, use_hooks=False)
        ablated_completion = pred_loop(model, prompt, n=15, use_hooks=True, saes=saes, fwd_hooks=fwd_hooks)
    else:
        clean_completion = None
        ablated_completion = None

    # Free up memory
    del clean_logits
    del ablated_logits
    del saes
    # Maybe also delete hooks - also maybe add the gc.collect() and erase cache stuff

    return ablated_logit_value - clean_logit_value, clean_pred, ablated_pred, clean_completion, ablated_completion

# Run the pipeline for a single prompt - find the features, ablate them, see how the model's prediction changes, return all results
def run_pipeline_one_prompt(prompt, subject, object):
    with torch.no_grad():
        subject_index = get_last_subject_index(prompt, subject)
        features = get_sae_features_for_relation(model, prompt, subject_index, object)

        # Super hacky way to make the ablation work for different number of features - if we have lots of features, a high
        # ablation_num might cause a lot of damage, so we try to minimize it for more features. We can try to play with this.
        if len(features) <= 4:
            ablation_num = -20
        elif len(features) <= 6:
            ablation_num = -15
        elif len(features) <= 8:
            ablation_num = -10
        else:
            ablation_num = -5


        logit_change, clean_pred, ablated_pred, clean_completion, ablated_completion = try_ablation(prompt, object, subject_index, features, ablation_num=ablation_num)
        return Result(prompt, subject, object, logit_change, model.to_string(clean_pred), model.to_string(ablated_pred), clean_completion, ablated_completion, list(set(features)))

# Run the pipeline for a dataset
def run_pipline(dataset):
    out = []
    for prompt, subject, object in tqdm(dataset):
        out.append(run_pipeline_one_prompt(prompt, subject, object))

    return out

In [15]:
def find_important_features_by_ablation(prompt, object, subject_index, features, ablation_num=-30):
    diffs = {}
    for feature in features:
        logit_diff, clean_pred, ablated_pred, _, _ = try_ablation(prompt, object, subject_index, [feature], ablation_num=ablation_num, do_completion=False)
        diffs[feature] = logit_diff

    sorted_features = sorted(diffs.keys(), key=lambda key: diffs[key])

    ablation_num=-70

    ablnum = ablation_num
    needed_features = []
    for feature in sorted_features:
        needed_features.append(feature)
        logit_diff, clean_pred, ablated_pred, _, ablated_comp = try_ablation(prompt, object, subject_index, needed_features, ablation_num=ablnum, do_completion=True)
        if object.lower() not in ablated_comp.lower():
            break
            
    return needed_features, try_ablation(prompt, object, subject_index, needed_features, ablation_num=ablnum, do_completion=True)[-1], ablnum

def add_min_important_features_to_results(results):
    for result in tqdm(results):
        result.min_features, result.min_ablated_completion = find_important_features_by_ablation(result.prompt, result.object, get_last_subject_index(result.prompt, result.subject), result.found_features)

In [16]:
def score_features_by_ablation(prompt, object, subject_index, features, ablation_num=-30):
    diffs = {}
    for feature in features:
        logit_diff, clean_pred, ablated_pred, _, _ = try_ablation(prompt, object, subject_index, [feature], ablation_num=ablation_num, do_completion=False)
        diffs[feature] = logit_diff

    return diffs

In [17]:
def validate_dataset(dataset):
    for prompt, subject, object in dataset:
        # assert len(model.to_str_tokens(object, prepend_bos=False)) == 1, f"Object must be a single token: {object}"
        get_last_subject_index(prompt, subject)

### Trial Run

In [43]:
# Notice prepended spacing of some tokens
sample_dataset = [
    ("Barack Obama was born in", "Barack Obama", " Hawaii"),
    ("George Bush was the governor of", "George Bush", " Texas"),
    ("Beats is owned by", "Beats", " Apple"),
    ("Bill Gates is the founder of", "Bill Gates", " Microsoft"),
    ("The Dalai Lama is a spiritual leader from", " Dalai Lama", " Tibet"),
]

validate_dataset(sample_dataset)
out = run_pipline(sample_dataset)
out

100%|██████████| 5/5 [01:33<00:00, 18.73s/it]


[Result for 'Barack Obama was born in'
 	Logit diff: -5.333288192749023
 	' Hawaii' -> ' '
 	'<bos>Barack Obama was born in Hawaii, but he was raised in Indonesia. He was a student at Occidental' -> '<bos>Barack Obama was born in 1964 in the small town of Kinyanya, Kenya'
 	All features: [('layer_20/width_16k/average_l0_109', 12504, 4.3818769454956055), ('layer_8/width_16k/average_l0_136', 14111, 1.0018670558929443), ('layer_6/width_16k/average_l0_133', 4999, 6.1822404861450195)],
 Result for 'George Bush was the governor of'
 	Logit diff: 0.04281425476074219
 	' Texas' -> ' Florida'
 	'<bos>George Bush was the governor of Texas when he was elected president in 2000. He was' -> '<bos>George Bush was the governor of Florida when he was elected president in 1988. He was'
 	All features: [('layer_18/width_16k/average_l0_106', 4413, 2.529614210128784), ('layer_10/width_16k/average_l0_110', 13846, 1.6948988437652588), ('layer_9/width_16k/average_l0_88', 11567, 1.6754446029663086), ('layer_8

In [45]:
for result in out:
    result.min_features, result.min_ablated_completion, result.min_ablation_num = find_important_features_by_ablation(result.prompt, result.object, get_last_subject_index(result.prompt, result.subject), result.found_features)

out

[Result for 'Barack Obama was born in'
 	Logit diff: -5.333288192749023
 	' Hawaii' -> ' '
 	'<bos>Barack Obama was born in Hawaii, but he was raised in Indonesia. He was a student at Occidental' -> '<bos>Barack Obama was born in 1964 in the small town of Kinyanya, Kenya'
 	All features: [('layer_20/width_16k/average_l0_109', 12504, 4.3818769454956055), ('layer_8/width_16k/average_l0_136', 14111, 1.0018670558929443), ('layer_6/width_16k/average_l0_133', 4999, 6.1822404861450195)]
 	Min features: [('layer_6/width_16k/average_l0_133', 4999, 6.1822404861450195)]
 	Min ablated completion: '<bos>Barack Obama was born in 1962, and he was the son of a Muslim.'
 	Min ablation num: '-70',
 Result for 'George Bush was the governor of'
 	Logit diff: 0.04281425476074219
 	' Texas' -> ' Florida'
 	'<bos>George Bush was the governor of Texas when he was elected president in 2000. He was' -> '<bos>George Bush was the governor of Florida when he was elected president in 1988. He was'
 	All features: [

## Running on Dataset

### Get Dataset

In [10]:
df = pd.read_csv("./gpt_dataset.csv")
df

Unnamed: 0,prompt,subject,object
0,Barack Obama was born in,Barack Obama,Hawaii
1,George Bush was the governor of,George Bush,Texas
2,Beats is owned by,Beats,Apple
3,Bill Gates is the founder of,Bill Gates,Microsoft
4,The Dalai Lama is a spiritual leader from,Dalai Lama,Tibet
...,...,...,...
58,Mount Kilimanjaro is located in,Mount Kilimanjaro,Tanzania
59,The Dead Sea is located in,Dead Sea,Jordan
60,Mount Vesuvius is located near,Mount Vesuvius,Naples
61,The Louvre is the most visited,Louvre,museum


### Filter

In [83]:
for _, row in tqdm(df.iterrows()):
    # obj_toks = model.to_str_tokens(row.object, prepend_bos=False)
    # if len(obj_toks) >= 2:
    #     row.object = obj_toks[0]

    # try:
    validate_dataset([(row.prompt, row.subject, row.object)])
    # except:
    #     validate_dataset([(row.prompt, " " + row.subject, row.object)])
    #     row.subject = " " + row.subject
    

100it [00:00, 3545.93it/s]


In [129]:
list(df.itertuples(index=False))

[Pandas(_0=0, prompt='Isaac Newton formulated the law of', subject='Isaac Newton', object=' gravity'),
 Pandas(_0=1, prompt='Albert Einstein developed the theory of', subject='Albert Einstein', object=' relativity'),
 Pandas(_0=2, prompt='Pablo Picasso painted', subject='Pablo Picasso', object=' Guernica'),
 Pandas(_0=3, prompt='Marie Curie discovered the element', subject='Marie Curie', object=' radium'),
 Pandas(_0=4, prompt='The Eiffel Tower is located in', subject=' Eiffel Tower', object=' Paris'),
 Pandas(_0=5, prompt='Elon Musk is the CEO of', subject='Elon Musk', object=' Tesla'),
 Pandas(_0=6, prompt='Cristiano Ronaldo plays for', subject='Cristiano Ronaldo', object=' Al-Nassr'),
 Pandas(_0=7, prompt='William Shakespeare wrote', subject='William Shakespeare', object=' Hamlet'),
 Pandas(_0=8, prompt='Napoleon Bonaparte was exiled to', subject='Napoleon Bonaparte', object=' Elba'),
 Pandas(_0=9, prompt='The Great Wall of China is in', subject=' Great Wall of China', object=' Chin

In [328]:
df.iloc[98]

Unnamed: 0                               98
prompt        The Nobel Prize is awarded in
subject                         Nobel Prize
object                               Sweden
Name: 98, dtype: object

In [317]:
model.to_str_tokens(" Rubicon")

['<bos>', ' Rub', 'icon']

In [29]:
row = df.iloc[2]
prompt, subject, object = row.prompt, row.subject, row.object
# prompt, subject, object = "Barack Obama was born in", "Barack Obama", " Hawaii"
# prompt, subject, object = "George Bush was the governor of", "George Bush", " Texas"
# prompt, subject, object = "Beats is owned by", "Beats", " Apple"
# prompt, subject, object = "Bill Gates is the founder of", "Bill Gates", " Microsoft"
# prompt, subject, object = "The Dalai Lama is a spiritual leader from", " Dalai Lama", " Tibet"


# result = run_pipeline_one_prompt(prompt, subject, object)
result.min_features, result.min_ablated_completion, result.min_ablation_num = find_important_features_by_ablation(
    result.prompt,
    result.object,
    get_last_subject_index(result.prompt, result.subject),
    result.found_features,
)
result

Failed at ablation -30 with #features=1 - '<bos>Marie Curie discovered the element radium in 1898. She found that radium disintegrates'
Failed at ablation -30 with #features=2 - '<bos>Marie Curie discovered the element radium in 1898. She found that radium disintegrates'
Failed at ablation -30 with #features=3 - '<bos>Marie Curie discovered the element radium in 1898. She found that radium reacted with substances'
Failed at ablation -30 with #features=4 - '<bos>Marie Curie discovered the element radium in 1898. She found that the element had a'
Failed at ablation -50 with #features=1 - '<bos>Marie Curie discovered the element radium in 1898. She found that the element had a'
Failed at ablation -50 with #features=2 - '<bos>Marie Curie discovered the element radium in 1898. She found that the element had a'
Failed at ablation -50 with #features=3 - '<bos>Marie Curie discovered the element radium in 1898. She found that the element had an'
Failed at ablation -50 with #features=4 - '<bos>M

Result for 'Marie Curie discovered the element'
	Logit diff: -1.4076080322265625
	' radium' -> ' radium'
	'<bos>Marie Curie discovered the element radium in 1898. She was awarded the Nobel Prize in' -> '<bos>Marie Curie discovered the element radium in 1898. She was awarded the Nobel Prize in'
	All features: [('layer_22/width_16k/average_l0_121', 13848, 3.5524840354919434), ('layer_16/width_16k/average_l0_72', 2072, 3.521700382232666), ('layer_20/width_16k/average_l0_109', 15142, 3.436981201171875), ('layer_6/width_16k/average_l0_133', 3606, 2.357353687286377)]
	Min features: [('layer_6/width_16k/average_l0_133', 3606, 2.357353687286377)]
	Min ablated completion: '<bos>Marie Curie discovered the element Curie in 1898. She was a French physicist and chemist'
	Min ablation num: '-70'

In [176]:
subject_index = get_last_subject_index(prompt, subject)

def ablate_feature_hook(acts, hook, index, cols):
    for col in cols:
        for i in [index-1, index, index+1, -1, -2]:
            acts[:,i,col] = -60


saes_to_features = {}
for sae_id, index, _ in result.min_features:
    if sae_id not in saes_to_features:
        saes_to_features[sae_id] = []

    saes_to_features[sae_id].append(index)

# Create SAEs and hooks
saes = [(SAE.from_pretrained(release = "gemma-scope-2b-pt-mlp", sae_id = sae_id, device = device)[0], cols) for sae_id, cols in saes_to_features.items()]
fwd_hooks = [(sae.cfg.hook_name + '.hook_sae_acts_post', partial(ablate_feature_hook, cols=cols, index=subject_index)) for sae, cols in saes]
pred_loop(model, prompt, n=20, use_hooks=True, saes=saes, fwd_hooks=fwd_hooks)

'<bos>The capital of the country Lionel Messi plays for is Barcelona. The capital of the country Cristiano Ronaldo plays for is Manchester.\n\nThe capital of the country'

In [172]:
results[2]

Result for 'Pablo Picasso painted'
	Logit diff: 0.0
	' a' -> ' a'
	'<bos>Pablo Picasso painted a portrait of his mistress, Marie-Thérèse Walter, in 1' -> '<bos>Pablo Picasso painted a portrait of his mistress, Marie-Thérèse Walter, in 1'
	All features: []
	Min ablated completion: '<bos>Pablo Picasso painted the famous <em>Les</em> <em>N</em><em>o</em><em>tres'

### Run

In [15]:
results = []
for _, row in tqdm(list(df.iterrows())):
    results.append(run_pipeline_one_prompt(row.prompt, row.subject, row.object))

100%|██████████| 63/63 [20:27<00:00, 19.48s/it]  


In [16]:
results

[Result for 'Barack Obama was born in'
 	Logit diff: -5.333288192749023
 	' Hawaii' -> ' '
 	'<bos>Barack Obama was born in Hawaii, but he was raised in Indonesia. He was a student at Occidental' -> '<bos>Barack Obama was born in 1964 in the small town of Kinyanya, Kenya'
 	All features: [('layer_6/width_16k/average_l0_133', 4999, 6.1822404861450195), ('layer_8/width_16k/average_l0_136', 14111, 1.0018670558929443), ('layer_20/width_16k/average_l0_109', 12504, 4.3818769454956055)],
 Result for 'George Bush was the governor of'
 	Logit diff: 0.04281425476074219
 	' Texas' -> ' Florida'
 	'<bos>George Bush was the governor of Texas when he was elected president in 2000. He was' -> '<bos>George Bush was the governor of Florida when he was elected president in 1988. He was'
 	All features: [('layer_4/width_16k/average_l0_85', 16103, 1.7314324378967285), ('layer_9/width_16k/average_l0_88', 11567, 1.6754446029663086), ('layer_8/width_16k/average_l0_136', 8213, 2.7836596965789795), ('layer_18/

In [17]:
for result in tqdm(results):
    result.min_features, result.min_ablated_completion, result.min_ablation_num = find_important_features_by_ablation(result.prompt, result.object, get_last_subject_index(result.prompt, result.subject), result.found_features)

results

100%|██████████| 63/63 [24:58<00:00, 23.78s/it] 


[Result for 'Barack Obama was born in'
 	Logit diff: -5.333288192749023
 	' Hawaii' -> ' '
 	'<bos>Barack Obama was born in Hawaii, but he was raised in Indonesia. He was a student at Occidental' -> '<bos>Barack Obama was born in 1964 in the small town of Kinyanya, Kenya'
 	All features: [('layer_6/width_16k/average_l0_133', 4999, 6.1822404861450195), ('layer_8/width_16k/average_l0_136', 14111, 1.0018670558929443), ('layer_20/width_16k/average_l0_109', 12504, 4.3818769454956055)]
 	Min features: [('layer_6/width_16k/average_l0_133', 4999, 6.1822404861450195)]
 	Min ablated completion: '<bos>Barack Obama was born in 1961, the son of a Kenyan father and a white'
 	Min ablation num: '-70',
 Result for 'George Bush was the governor of'
 	Logit diff: 0.04281425476074219
 	' Texas' -> ' Florida'
 	'<bos>George Bush was the governor of Texas when he was elected president in 2000. He was' -> '<bos>George Bush was the governor of Florida when he was elected president in 1988. He was'
 	All feat

In [18]:
import pickle
# with open("results_16_fixed.pkl", "wb") as f:
#     pickle.dump(results, f)

with open("results_65_fixed.pkl", "rb") as f:
    results = pickle.load(f)

In [18]:
for result in tqdm(results):
    result.feature_diffs = score_features_by_ablation(result.prompt, result.object, get_last_subject_index(result.prompt, result.subject), result.found_features, ablation_num=-20)

100%|██████████| 63/63 [19:55<00:00, 18.97s/it] 


In [118]:
Result(**results[0].__dict__)

Result for 'Barack Obama was born in'
	Logit diff: -5.333288192749023
	' Hawaii' -> ' '
	'<bos>Barack Obama was born in Hawaii, but he was raised in Indonesia. He was a student at Occidental' -> '<bos>Barack Obama was born in 1964 in the small town of Kinyanya, Kenya'
	All features: [('layer_20/width_16k/average_l0_109', 12504, 4.3818769454956055), ('layer_8/width_16k/average_l0_136', 14111, 1.0018670558929443), ('layer_6/width_16k/average_l0_133', 4999, 6.1822404861450195)]
	Min features: [('layer_6/width_16k/average_l0_133', 4999, 6.1822404861450195)]
	Min ablated completion: '<bos>Barack Obama was born in 1962, and he was the son of a Muslim.'
	Min ablation num: '-70'

In [119]:
results2 = []
for result in results:
    results2.append(Result(**result.__dict__))

In [120]:
with open("results2.pkl", "wb") as f:
    pickle.dump(results2, f)

In [123]:
results2[10].feature_diffs

{('layer_24/width_16k/average_l0_73',
  8291,
  7.590907096862793): -1.12371826171875,
 ('layer_6/width_16k/average_l0_133',
  8360,
  1.6326402425765991): -0.6614704132080078,
 ('layer_23/width_16k/average_l0_128',
  15497,
  37.62406539916992): 0.022794723510742188,
 ('layer_13/width_16k/average_l0_112',
  1146,
  1.9262585639953613): -0.0060710906982421875,
 ('layer_12/width_16k/average_l0_108',
  6557,
  4.149245262145996): 0.3075523376464844,
 ('layer_4/width_16k/average_l0_85',
  9499,
  1.5218584537506104): -0.45745086669921875}

In [128]:
import re

results[0].min_features[0]

def get_exp_by_feature(sae, feature):
    layer = int(re.match(r"layer_(\d+)/", sae).group(1))
    return get_explanations_by_features("gemma-2-2b", f"res|{layer}|{sae}", feature)

index = 20
sae_id, feature = results[index].min_features[0][0], results[index].min_features[0][1]
print(results[index].prompt)
print(sae_id, feature)
get_exp_by_feature(sae_id, feature)

The Colosseum is located in
layer_14/width_65k/average_l0_89 10216
explanations url =  https://www.neuronpedia.org/api/explanation/export?modelId=gemma-2-2b&saeId=res|14|layer_14/width_65k/average_l0_89


[]

In [20]:
# sae, _, _ = SAE.from_pretrained(release = "gemma-scope-2b-pt-mlp", sae_id = sae_id, device = device)
sae, _, _ = SAE.from_pretrained(release = "gemma-scope-2b-pt-res", sae_id = "layer_5/width_65k/average_l0_105", device = device)

In [38]:
feature = 34324
model.unembed(sae.W_dec[feature]).topk(200).values

tensor([0.8150, 0.7861, 0.7471, 0.7338, 0.7214, 0.7212, 0.7133, 0.7000, 0.6989,
        0.6881, 0.6835, 0.6752, 0.6660, 0.6606, 0.6573, 0.6543, 0.6527, 0.6475,
        0.6436, 0.6430, 0.6424, 0.6397, 0.6396, 0.6355, 0.6334, 0.6319, 0.6300,
        0.6297, 0.6267, 0.6263, 0.6244, 0.6230, 0.6222, 0.6209, 0.6186, 0.6170,
        0.6167, 0.6151, 0.6135, 0.6126, 0.6120, 0.6091, 0.6024, 0.6023, 0.6022,
        0.6017, 0.5986, 0.5956, 0.5952, 0.5947, 0.5905, 0.5903, 0.5899, 0.5892,
        0.5842, 0.5834, 0.5824, 0.5819, 0.5778, 0.5777, 0.5776, 0.5770, 0.5758,
        0.5753, 0.5748, 0.5743, 0.5736, 0.5726, 0.5721, 0.5719, 0.5702, 0.5702,
        0.5689, 0.5664, 0.5648, 0.5634, 0.5633, 0.5632, 0.5616, 0.5613, 0.5611,
        0.5600, 0.5598, 0.5585, 0.5568, 0.5556, 0.5549, 0.5544, 0.5540, 0.5532,
        0.5532, 0.5521, 0.5515, 0.5513, 0.5511, 0.5508, 0.5506, 0.5506, 0.5495,
        0.5489, 0.5472, 0.5469, 0.5451, 0.5446, 0.5444, 0.5438, 0.5436, 0.5432,
        0.5426, 0.5416, 0.5413, 0.5411, 

In [39]:
print(model.to_str_tokens(model.unembed(sae.W_dec[feature]).topk(100).indices))

[' незавершена', ' Normdatei', ' chi̍t', ' متعلقه', 'dymyr', 'ngdoc', ' Efq', 'WithIOException', 'QMetaType', ' iſt', 'ressee', ' ſtate', 'ſelf', ' StatefulWidget', 'neſs', ' myſelf', '՚', '\x12', ' Majefty', 'Lähteet', 'حياته', 'toires', ' Stron', '*/;', ' disambiguazione', 'centralwidget', ' arşivlendi', 'resaid', ' configureStore', 'Diwedd', ' ſch', ' ſind', ' pleaſure', ' insuffisamment', ' geox', ' BoxDecoration', '={({', ' تانيه', ' endregion', ' للمعارف', ' cdti', ' CSRF', 'ſelves', ' laſt', ' CHtml', ' ſche', 'Jeografia', ' onData', '!")\r', ' tématu', ' يتيمه', ' itſelf', ' Moines', ' NUKAT', 'makeConstraints', 'DataContract', ' crdi', ' löyty', ' Pingback', ' niedersachsen', 'ագրություններ', ' Reſ', ' تضيفلها', 'vocable', 'noinspection', 'createClass', '!*\\', ' ftate', ' reaſon', ' myrtle', 'Pratique', 'runOnUiThread', ' */;', " '\\\\;'", 'ValueStyle', 'الحياه', 'ulier', ' beſt', ' herren', '---*/', 'énario', 'מבר', 'uxxxx', 'esgue', ' stdClass', ' Anſ', ' sé', ' femininas',

In [40]:
print(model.to_str_tokens(model.unembed(sae.W_enc[:, feature]).topk(50).indices))

[' chi̍t', ' Efq', ' Normdatei', 'neſs', ' ſtate', 'QMetaType', ' arşivlendi', '*/;', 'ſelf', ' myſelf', '՚', ' CHtml', ' незавершена', ' itſelf', ' geox', 'toires', ' يتيمه', ' تانيه', 'ressee', 'ſelves', ' löyty', 'Jeografia', ' laſt', ' MotionEvent', 'withIdentifier', ' Majefty', ' Monfieur', ' femininas', ' ftate', ' pleaſure', ' tslint', ' PrintWriter', ' Reſ', ' متعلقه', '!")\r', 'WithIOException', ' للمعارف', ' configureStore', 'حياته', 'nofollow', ' Betracht', 'Халык', ' reaſon', 'Pratique', '--)\r', ' BoxDecoration', 'INSEE', ' Chriftian', ' stdClass', ' */;']


In [223]:
try_ablation("I want to write all of the names in a", " list", 2, [("layer_5/width_65k/average_l0_105", 34326,0)], 0)

(-2.465116500854492,
 1889,
 3821,
 '<bos>I want to write all of the names in a list and then print them out. I have the list of names in a',
 '<bos>I want to write all of the names in a single line in a single line in a single line in a single line in')

In [222]:
model.reset_hooks(); model.reset_saes()

In [219]:
pred_loop(model, "I want to write all of the names in a")

'<bos>I want to write all of the names in a list and then print them out. I have the'

## Lowering Key Match

In [95]:
result = run_pipeline_one_prompt("George Bush was the governor of", "George Bush", " Texas")
result.min_features, result.min_ablated_completion, result.min_ablation_num = find_important_features_by_ablation(result.prompt, result.object, get_last_subject_index(result.prompt, result.subject), result.found_features)
result

Result for 'George Bush was the governor of'
	Logit diff: -5.885408401489258
	' Texas' -> ' Florida'
	'<bos>George Bush was the governor of Texas when he was elected president in 2000. He was' -> '<bos>George Bush was the governor of Florida when he was elected president in 2000. He was'
	All features: [('layer_4/width_65k/average_l0_66', 13357, 1.395000696182251), ('layer_7/width_65k/average_l0_115', 64391, 4.383184909820557), ('layer_20/width_65k/average_l0_88', 6027, 11.44218921661377), ('layer_8/width_65k/average_l0_110', 46230, 2.594975471496582), ('layer_14/width_65k/average_l0_89', 30758, 3.128693103790283)]
	Min features: [('layer_14/width_65k/average_l0_89', 30758, 3.128693103790283)]
	Min ablated completion: '<bos>George Bush was the governor of New York when he was elected president. He was the governor of New York'
	Min ablation num: '-70'

In [76]:
result = run_pipeline_one_prompt("Barack Obama was born in", "Barack Obama", " Hawaii")
result.min_features, result.min_ablated_completion, result.min_ablation_num = find_important_features_by_ablation(result.prompt, result.object, get_last_subject_index(result.prompt, result.subject), result.found_features)
result

Result for 'Barack Obama was born in'
	Logit diff: -0.29739952087402344
	' Hawaii' -> ' Hawaii'
	'<bos>Barack Obama was born in Hawaii, but he was raised in Indonesia. He was a student at Occidental' -> '<bos>Barack Obama was born in Hawaii, but he grew up in Indonesia. He was raised by his mother'
	All features: [('layer_6/width_65k/average_l0_101', 62465, 1.2823169231414795)]
	Min features: [('layer_6/width_65k/average_l0_101', 62465, 1.2823169231414795)]
	Min ablated completion: '<bos>Barack Obama was born in Kenya, raised in Indonesia, and educated in Hawaii. He is the first'
	Min ablation num: '-70'

In [8]:
sae = SAE.from_pretrained(release = "gemma-scope-2b-pt-mlp", sae_id = "layer_14/width_65k/average_l0_89", device = device)[0]

In [9]:
model.to_str_tokens(model.unembed(sae.W_dec[30758]).topk(200).indices)

[' Texas',
 ' Texan',
 'Texas',
 ' Dallas',
 ' Texans',
 'Dallas',
 ' TX',
 ' Houston',
 'TX',
 ' texas',
 'Houston',
 'txn',
 'TEXAS',
 'iterable',
 'tx',
 'texas',
 ' Tx',
 ' TEXAS',
 ' Lubbock',
 ' tx',
 'Tex',
 'python',
 ' Teks',
 ' Waco',
 '卦',
 'Tx',
 ' star',
 ' Cowboys',
 ' Cactus',
 ' ranch',
 ' Tex',
 ' colonias',
 ' TEX',
 ' Oklahoma',
 'Gujarat',
 'Oklahoma',
 ' wholesale',
 'TEX',
 'Baylor',
 ' Galveston',
 'cycle',
 ' UNT',
 'пита',
 ' tex',
 '>>',
 ' Amarillo',
 ' ane',
 ' cycle',
 ' Ana',
 ' stars',
 ' OMITBAD',
 'Ana',
 ' recovered',
 ' placed',
 ' ana',
 'gallon',
 ' python',
 ' sack',
 ' Ranch',
 ' Southwestern',
 ' we',
 ' QH',
 'documentElement',
 ' in',
 'ไข',
 ' Python',
 ' txn',
 'Cactus',
 'Ezek',
 ' Hester',
 'nak',
 ' Gujarat',
 ' Denton',
 '>>>',
 'ansas',
 ' estrela',
 'Wholesale',
 'rank',
 ' Qui',
 ' deserto',
 ' Fort',
 '[::-',
 'сал',
 ' מק',
 ' juist',
 ' Lucio',
 'RequestMethod',
 ' hard',
 'stackoverflow',
 ' Austin',
 '习',
 'timedelta',
 ' jLabel',

In [10]:
model.to_str_tokens(model.unembed(sae.W_enc[:,30758]).topk(200).indices)

[' Texan',
 ' Texans',
 ' Dallas',
 'Dallas',
 ' Texas',
 'mal',
 ' but',
 ' TX',
 ' éstos',
 'DateTimeField',
 'Ӏ',
 'nt',
 ' ab',
 ' SWIG',
 'Texas',
 ' OMITBAD',
 ' But',
 '<h3>',
 ' Waco',
 'tx',
 'ilent',
 '繁',
 "('../../../",
 ' Cowboys',
 ' in',
 'ाल',
 ' éstas',
 'UT',
 ' Com',
 'sellors',
 ' Tud',
 ' juist',
 'ณ์',
 'ly',
 '_',
 'al',
 ' прошло',
 ' "../../../',
 ' Iqbal',
 'ut',
 'timmt',
 ' we',
 'Scaffold',
 'Pyx',
 ' Houston',
 ' hề',
 ' usus',
 ' learnt',
 ' considérons',
 'ival',
 'else',
 ' star',
 ' Na',
 "('../",
 'ual',
 'たい',
 ' Mais',
 ' about',
 ' stars',
 'Gujarat',
 'But',
 ' fortuna',
 'эй',
 ' again',
 'Houston',
 ' else',
 ' gone',
 ' it',
 'Cowboy',
 " '../",
 'шед',
 'おり',
 ' Um',
 ' Hark',
 'sic',
 'TX',
 ' Aber',
 '壕',
 '\n\n\n',
 ' ISD',
 ' then',
 ' now',
 ' moti',
 'сал',
 'mLast',
 'etro',
 ' past',
 ' tá',
 '     ',
 ' Segu',
 ' Austin',
 'utilise',
 'Ab',
 'τως',
 ' वाली',
 ' succes',
 ' diikuti',
 ' whereas',
 ' Nach',
 'Twas',
 'Tud',
 'OU',
 'TEX

In [34]:
cache = model.run_with_cache("George Bush was the governor of")[1]

In [12]:
model.to_str_tokens(model.unembed(cache["blocks.14.hook_resid_mid"][0,2]).topk(50).indices)

[' Bush',
 'Bush',
 ',',
 ' and',
 '’',
 ' was',
 ' who',
 ' BUSH',
 "'",
 'Busch',
 ' bush',
 ' is',
 ' on',
 ' announced',
 ' (',
 ' yesterday',
 ' has',
 ' in',
 'hhhhhhhh',
 ' whoſe',
 ' признаки',
 ' whofe',
 'かしい',
 ' over',
 ' ſaid',
 ' through',
 ' –',
 ' לכם',
 '-',
 ' that',
 ' ',
 'Obama',
 ' today',
 ' got',
 ' had',
 ' Majefty',
 ' from',
 'Laughter',
 'Washington',
 ' with',
 ' apparently',
 'NULL',
 ' suddenly',
 'Przypisy',
 ' (“',
 ' Efq',
 ' announce',
 'monių',
 ' President',
 ' urged']

In [13]:
cache["blocks.14.hook_resid_mid"][0,2] @ sae.W_enc[:,30758]

tensor(10.5064, device='cuda:0')

In [14]:
torch.set_grad_enabled(True)
torch.is_grad_enabled()

True

In [17]:
def train_min_shift(key, col_num, epochs=1000, lr=1e-3):
    losses, matches, diffs = [], [], []
    
    v = sae.W_enc[:,col_num].clone()
    delta = torch.zeros_like(v, requires_grad=True, device=device)

    indices = torch.randint(0, 256000, (100000,))
    normal_output = model.embed.W_E[indices] @ v

    optimizer = torch.optim.Adam([delta], lr=lr)

    pbar = tqdm(total=epochs)
    for epoch in range(epochs):
        
        new_v = v + delta
        key_match = key @ new_v
        key_loss = (-1000 - key_match) ** 2
        
        new_output = model.embed.W_E[indices] @ new_v
        output_diff = (new_output - normal_output).abs().max()

        loss = key_loss + output_diff

        losses.append(loss.item())
        matches.append(key_match.item())
        diffs.append(output_diff.item())

        loss.backward(retain_graph=True)
        optimizer.step()
        optimizer.zero_grad()
        
        pbar.update(1)
        pbar.set_description(f"Epoch: {epoch+1}/{epochs}, Loss: {loss.item():.2f}, Match: {key_match.item():.2f}, Key Loss: {key_loss.item():.2f}, Diff: {output_diff.item()/1000:.2f}")

    return delta, losses, matches, diffs

In [18]:
key = cache["blocks.14.hook_resid_mid"][0,2].clone()
delta, losses, matches, diffs = train_min_shift(key, 30758)

Epoch: 1000/1000, Loss: 96.84, Match: -999.94, Key Loss: 0.00, Diff: 0.10: 100%|██████████| 1000/1000 [00:21<00:00, 46.37it/s]       


In [49]:
key @ (sae.W_enc[:,30758] + delta)

tensor(-2010.3872, device='cuda:0', grad_fn=<DotBackward0>)

In [20]:
model.to_str_tokens(model.run_with_saes("George Bush was the governor of", saes=[sae])[0,-1].argmax())

[' Texas']

In [21]:
model.to_str_tokens(model.run_with_saes("The Lone Star State, also known as", saes=[sae])[0,-1].argmax())

[' Texas']

In [22]:
backup = sae.W_enc[:,30758].clone()

In [24]:
sae.W_enc.requires_grad = False

In [25]:
sae.W_enc[:, 30758] = backup + delta

In [26]:
torch.allclose(sae.W_enc[:, 30758] - backup, delta)

True

In [29]:
model.to_str_tokens(model.run_with_saes("George Bush was the governor of", saes=[sae])[0,-1].argmax())

tensor(9447, device='cuda:0')

In [28]:
model.to_str_tokens(model.run_with_saes("The Lone Star State, also known as", saes=[sae])[0,-1].argmax())

[' Texas']

In [35]:
logits, cache2 = model.run_with_cache_with_saes("George Bush was the governor of", saes=[sae])

In [38]:
torch.allclose(cache["blocks.14.hook_resid_mid"][0,2], cache2["blocks.14.hook_resid_mid"][0,2])

True

In [54]:
(cache["blocks.14.hook_resid_mid"][0,2] @ sae.W_enc)[30758]

tensor(-999.9406, device='cuda:0', grad_fn=<SelectBackward0>)

In [70]:
sae.decode(sae.encode(cache["blocks.14.hook_resid_mid"][0,2]))

tensor([ 0.2053, -9.5506, -5.4975,  ..., -3.4410,  0.2597, -0.5893],
       device='cuda:0', grad_fn=<AddBackward0>)

In [59]:
cache2[sae.cfg.hook_name + ".hook_sae_acts_input"][0,2,30758]

KeyError: 'blocks.14.hook_mlp_out.hook_sae_acts_input'

In [55]:
[x for x in cache2.keys() if "sae" in x]

['blocks.14.hook_mlp_out.hook_sae_input',
 'blocks.14.hook_mlp_out.hook_sae_acts_pre',
 'blocks.14.hook_mlp_out.hook_sae_acts_post',
 'blocks.14.hook_mlp_out.hook_sae_recons',
 'blocks.14.hook_mlp_out.hook_sae_output']

In [None]:
sae.en

In [84]:
def hook_shite(act, hook):
    encoded = sae.encode_standard(cache["blocks.14.hook_resid_mid"])
    # print(encoded[0,2,30758])
    # encoded[0,2,30758] = -100
    return sae.decode(encoded)

model.to_str_tokens(model.run_with_hooks("George Bush was the governor of", fwd_hooks=[("blocks.14.hook_resid_post", hook_shite)])[0,-1].argmax())

[' Texas']