In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")



In [2]:
# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops

# import the LLM
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

# Choose the accelerator: Apple MPS takes priority, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

# utility to free unreferenced objects and clear the accelerator memory cache
import gc
def clear_cache():
    """Run the garbage collector, then release cached accelerator memory.

    Collecting first lets tensors that are no longer referenced be freed, so
    their blocks can then be returned by the allocator. The notebook may run on
    CUDA or Apple MPS (see the device selection above), and each backend keeps
    its own cache, so both are cleared when available.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()

Device: cuda


In [3]:
import os
from dotenv import load_dotenv
from pathlib import Path

if not COLAB:
    # Load environment variables (notably PYTHONPATH) from the .env file
    load_dotenv()

    # PYTHONPATH may be unset, or may hold several entries separated by
    # os.pathsep. `Path(None)` would raise an opaque TypeError, so fail with a
    # clear message instead, and treat the first entry as the project root.
    pythonpath_env = os.getenv('PYTHONPATH')
    if not pythonpath_env:
        raise EnvironmentError("PYTHONPATH is not set; define it in your .env file "
                               "so the data directory can be located.")
    pythonpath = Path(pythonpath_env.split(os.pathsep)[0])

    # Print to verify
    print(f"PYTHONPATH: {pythonpath}")

    datapath = pythonpath / 'data'
    print(f"Data path: {datapath}")
else:
    datapath = Path('./')
    print(f"Data path: {datapath}")

Data path: .


In [4]:
# define the model to work with
GEMMA = True

# Both configurations attach the SAE to the residual stream of the same layer.
layer_num = 6

if GEMMA:
    RELEASE = 'gemma-2b-res-jb'
    BASE_MODEL = "gemma-2b"
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    FINETUNE_DATASET = 'flytech/python-codes-25k'
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
    hook_part = 'post'
else:
    RELEASE = 'gpt2-small-res-jb'
    BASE_MODEL = "gpt2-small"
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    FINETUNE_DATASET = 'legacy-datasets/wikipedia'
    DATASET_NAME = "Skylion007/openwebtext"
    hook_part = 'pre'

In [17]:
def list_flatten(nested_list):
    """Remove one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    flat = []
    for inner in nested_list:
        flat.extend(inner)
    return flat

# A very handy function Neel wrote to get context around a feature activation
def make_token_df(tokens, len_prefix=5, len_suffix=3, model=None):
    """Build a per-token DataFrame with a short textual context window.

    Args:
        tokens: 2-D tensor of token ids, indexed as tokens[prompt, position].
        len_prefix: Number of tokens shown before the current token.
        len_suffix: Number of tokens shown after the current token.
        model: Object exposing `to_str_tokens(ids) -> list[str]`, such as the
            HookedSAETransformer used in this notebook. It is required; the
            `None` default only keeps the signature backward compatible.

    Returns:
        DataFrame with one row per (prompt, position) in row-major order and
        columns str_tokens, unique_token ("<tok>/<pos>"), context
        ("<prefix>|<tok>|<suffix>"), prompt, pos and label ("<prompt>/<pos>").

    Raises:
        ValueError: if `model` is not given.
    """
    if model is None:
        raise ValueError("make_token_df needs a `model` with a `to_str_tokens` method")

    n_prompts, seq_len = tokens.shape[0], tokens.shape[1]
    rows = dict(str_tokens=[], unique_token=[], context=[], prompt=[], pos=[], label=[])
    for b in range(n_prompts):
        str_tok = model.to_str_tokens(tokens[b])
        for p in range(seq_len):
            current = str_tok[p]
            prefix = "".join(str_tok[max(0, p - len_prefix):p])
            # Slicing clamps at the end of the list, so no explicit upper bound is
            # needed. The previous `min(seq_len - 1, ...)` bound excluded the final
            # token from every suffix window that should have reached it.
            suffix = "".join(str_tok[p + 1:p + 1 + len_suffix])
            rows["str_tokens"].append(current)
            rows["unique_token"].append(f"{current}/{p}")
            rows["context"].append(f"{prefix}|{current}|{suffix}")
            rows["prompt"].append(b)
            rows["pos"].append(p)
            rows["label"].append(f"{b}/{p}")
    return pd.DataFrame(rows)

In [18]:
import torch
import pandas as pd
from tqdm import tqdm

def get_feature_activations(model, sae, activation_store, total_batches=20,
                            feature_list=None, len_prefix=8, len_suffix=5, return_loss=False):
    """Collect SAE feature activations and their token contexts over several batches.

    For each of `total_batches` batches drawn from `activation_store`, run `model`
    up to the SAE's hook point, encode the cached activations with `sae`, and keep
    only the token positions where the selected features have a positive sum.

    Args:
        model: Model providing `run_with_cache` and `to_str_tokens`.
        sae: SAE whose `cfg` provides `d_sae`, `hook_name` and `hook_layer`.
        activation_store: Source of token batches via `get_batch_tokens()`.
        total_batches: Number of batches to process; must be >= 1.
        feature_list: Indices of features to keep; defaults to all `d_sae` features.
        len_prefix, len_suffix: Context window passed to `make_token_df`.
        return_loss: If True, run the full model and also return the mean loss.

    Returns:
        (token_df, feature_acts), or (token_df, feature_acts, mean_loss) when
        `return_loss` is True. `token_df` has one row per kept token position.
        Its index is the flat position within the originating batch, so index
        values repeat across batches. `feature_acts` is a float16 tensor of shape
        (n_kept_positions, len(feature_list)) whose rows align positionally with
        `token_df`.

    Raises:
        ValueError: if `total_batches` < 1.
    """
    if total_batches < 1:
        # With no batches, pd.concat / torch.cat below would fail with an unhelpful error.
        raise ValueError(f"total_batches must be >= 1, got {total_batches}")

    # If no specific feature list is provided, use all features available in the SAE configuration.
    if feature_list is None:
        feature_list = torch.arange(sae.cfg.d_sae, device='cpu')  # Initialize on CPU to save GPU memory

    # Per-batch results, concatenated once after the loop.
    all_feature_acts = []
    all_token_dfs = []
    # Running sum of per-batch losses (kept on CPU), only when requested.
    total_loss = torch.zeros(1) if return_loss else None

    pbar = tqdm(range(total_batches))
    for i in pbar:
        tokens = activation_store.get_batch_tokens()
        # Token contexts for every position in the batch, tagged with the batch number.
        tokens_df = make_token_df(tokens, len_prefix=len_prefix, len_suffix=len_suffix, model=model)
        tokens_df["batch"] = i

        if return_loss:
            # Full forward pass: returns the loss and caches only the SAE hook.
            loss, cache = model.run_with_cache(tokens, return_type='loss', names_filter=[sae.cfg.hook_name])
            total_loss += loss.cpu()  # Accumulate loss on CPU to save GPU memory
        else:
            # Stop right after the hooked layer; later layers are not needed here.
            _, cache = model.run_with_cache(tokens, stop_at_layer=sae.cfg.hook_layer+1, names_filter=[sae.cfg.hook_name])

        sae_in = cache[sae.cfg.hook_name].half()  # Reduce precision to save memory
        feature_acts = sae.encode(sae_in).half()
        # Collapse all leading dims into one row per token position (assumes the
        # feature axis is last). The previous squeeze() + flatten(0, 1) broke for
        # a single-prompt batch: squeeze() dropped the batch dim, and flatten(0, 1)
        # then merged the position and feature axes.
        feature_acts = feature_acts.reshape(-1, feature_acts.shape[-1])

        # Keep positions where the selected features sum to > 0 (i.e. at least one
        # fired, assuming non-negative activations such as ReLU).
        fired_mask = feature_acts[:, feature_list].sum(dim=-1) > 0
        token_df = tokens_df.iloc[fired_mask.cpu().nonzero().flatten().numpy()]

        all_token_dfs.append(token_df)
        all_feature_acts.append(feature_acts[fired_mask][:, feature_list])

        # Free the activation cache before the next batch.
        del cache
        torch.cuda.empty_cache()

    all_token_dfs = pd.concat(all_token_dfs)
    all_feature_acts = torch.cat(all_feature_acts)

    if return_loss:
        # Mean of the per-batch losses.
        total_loss = total_loss / total_batches
        return all_token_dfs, all_feature_acts, total_loss
    return all_token_dfs, all_feature_acts

In [14]:
def get_intervals_df(df, value_column, descending=True, n_tokens=10,
                     interval_fractions=[0.4, 0.4, 0.2], sample_random=False):
    # Sort the dataframe by the value column
    df_sorted = df.sort_values(by=value_column, ascending=not descending)

    # Calculate the maximum activation value
    max_act = df_sorted[value_column].max()

    # Determine the number of tokens for each interval
    interval_counts = [int(n_tokens * frac) for frac in interval_fractions]

    # Initialize a list to store dataframes for each interval
    interval_dfs = []

    # Calculate the thresholds for each interval and sample tokens
    current_threshold = max_act
    for i, count in enumerate(interval_counts):
        if i < len(interval_fractions) - 1:
            # Calculate the threshold for the next interval
            next_threshold = max_act * (1 - sum(interval_fractions[:i + 1]))
        else:
            # For the last interval, no need to set a lower threshold
            next_threshold = -float('inf')

        # Filter the dataframe for the current interval
        interval_df = df_sorted[(df_sorted[value_column] <= current_threshold) &
                                (df_sorted[value_column] > next_threshold)]

        # Sample or select the top tokens
        if sample_random:
            interval_df = interval_df.sample(n=min(count, len(interval_df)))
        else:
            interval_df = interval_df.head(count)

        # Add the selected tokens to the list
        interval_dfs.append(interval_df)

        # Update the current threshold
        current_threshold = next_threshold

    # Concatenate the interval dataframes and return the final result
    final_df = pd.concat(interval_dfs).head(n_tokens)

    return final_df

def get_feature_intervals_df(feature, token_df, act_df, descending=True, sample_random=False,
                             n_tokens=10, interval_fractions=(0.4, 0.4, 0.2)):
    """Pick example tokens for one SAE feature, spread over its activation range.

    Args:
        feature: Column index of the feature in `act_df` (int or 0-d tensor).
        token_df: Token-context DataFrame whose rows align positionally with
            the rows of `act_df` (as produced by `get_feature_activations`).
        act_df: 2-D activation tensor indexed as act_df[token_row, feature].
        descending, sample_random, n_tokens, interval_fractions: Forwarded to
            `get_intervals_df`.

    Returns:
        A new DataFrame of selected rows with an added `activation_value`
        column. `token_df` itself is not modified.
    """
    activations = act_df[:, feature]
    # Order rows by this feature's activation before interval selection.
    _, sorted_idx = torch.sort(activations, descending=descending)

    # assign() returns a copy. Writing the column into `token_df` in place would
    # mutate the caller's shared frame on every call (e.g. once per feature in a loop).
    feature_token_df = token_df.assign(activation_value=activations.cpu().numpy())
    sorted_token_df = feature_token_df.iloc[sorted_idx.cpu().numpy()]

    return get_intervals_df(sorted_token_df, 'activation_value', sample_random=sample_random,
                            descending=descending, n_tokens=n_tokens, interval_fractions=interval_fractions)

### Base model: looking for the features activating in the finetuning dataset's context

In [5]:
# Base (non-finetuned) model, loaded as a HookedSAETransformer in float16.
base_model = HookedSAETransformer.from_pretrained(
    BASE_MODEL,
    device=device,
    dtype=torch.float16,
)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [6]:
# import the required libraries
from sae_lens import SAE

# Pretrained SAE for the configured layer and residual-stream hook.
sae_id = f'blocks.{layer_num}.hook_resid_{hook_part}'

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=RELEASE,
    sae_id=sae_id,
    device=device,
)
cfg_dict

{'model_name': 'gemma-2b',
 'model_class_name': 'HookedTransformer',
 'hook_point': 'blocks.6.hook_resid_post',
 'hook_point_eval': 'blocks.{layer}.attn.pattern',
 'hook_point_layer': 6,
 'hook_point_head_index': None,
 'dataset_path': 'HuggingFaceFW/fineweb',
 'streaming': True,
 'is_dataset_tokenized': False,
 'context_size': 1024,
 'use_cached_activations': False,
 'cached_activations_path': None,
 'd_in': 2048,
 'd_sae': 16384,
 'b_dec_init_method': 'zeros',
 'expansion_factor': 8,
 'activation_fn': 'relu',
 'normalize_sae_decoder': False,
 'noise_scale': 0.0,
 'from_pretrained_path': None,
 'apply_b_dec_to_input': False,
 'decoder_orthogonal_init': False,
 'decoder_heuristic_init': True,
 'init_encoder_as_decoder_transpose': True,
 'n_batches_in_buffer': 64,
 'training_tokens': 1228800000,
 'finetuning_tokens': 0,
 'store_batch_size_prompts': 8,
 'train_batch_size_tokens': 4096,
 'normalize_activations': 'none',
 'device': 'cuda',
 'seed': 42,
 'dtype': 'torch.float32',
 'prepend_

In [9]:
from datasets import Dataset, load_dataset

# Stream the finetuning dataset (streaming=True) instead of materialising it locally.
if not GEMMA:
    # I didn't manage to find the Portuguese wikipedia dump, it's not specified in the model card.
    # The apparent one based on the exception I got seems to be 20220301.pt, but it doesn't work...
    finetune_dataset = load_dataset(
        FINETUNE_DATASET,
        '20220301.simple',
        split='train',
        streaming=True,
        trust_remote_code=True,
    )
else:
    finetune_dataset = load_dataset(FINETUNE_DATASET, split='train', streaming=True)

In [8]:
from sae_lens import ActivationsStore

# Build the store from the SAE's config, but feed it the finetuning dataset.
activation_store = ActivationsStore.from_sae(
    model=base_model,
    sae=sae,
    dataset=finetune_dataset,
    streaming=True,
    device=device,
    # Conservative batch sizes so the same settings also fit larger models in memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
)

batch_size_prompts = activation_store.store_batch_size_prompts
# Tokens per batch = prompts per batch x tokens per prompt (context size).
batch_size_tokens = batch_size_prompts * activation_store.context_size

batch_size_prompts, batch_size_tokens



(8, 8192)

In [11]:
# Activations of every SAE feature for the base model over 20 finetune-dataset batches.
all_token_dfs, all_feature_acts = get_feature_activations(
    model=base_model,
    sae=sae,
    activation_store=activation_store,
    total_batches=20,
)
clear_cache()
all_feature_acts.shape

100%|██████████| 20/20 [00:23<00:00,  1.16s/it]


torch.Size([163840, 16384])

In [12]:
import plotly.express as px
# Mean over the token positions kept by get_feature_activations (those where at
# least one feature fired), one value per SAE feature.
all_feature_mean_acts = all_feature_acts.mean(0)

# The title names the model, so this plot is distinguishable from the finetuned-model one.
fig = px.line(
    all_feature_mean_acts.cpu(),
    title="Mean feature activations: base model on finetune dataset",
    labels={"index": "Feature", "value": "Activation"},
)
fig.show()

In [13]:
# File suffix "faof" = "Feature Activations on the Finetune dataset" (base model).
torch.save(
    obj=all_feature_mean_acts,
    f=datapath / f'{BASE_MODEL}_base_faof.pt',
)

Let's see whether we found some genuine features that activate in the context of the finetuning dataset.

In [15]:
# Extract the top-10 features by mean activation: their indices and mean values
top_activating_features_values, top_activating_features = all_feature_mean_acts.topk(k=10)
top_activating_features, top_activating_features_values

(tensor([11631,  5439,  3309, 14173,  6947, 15417, 13728, 16277, 11683, 12522],
        device='cuda:0'),
 tensor([6.1992, 2.6465, 1.9238, 1.5693, 1.5293, 1.1357, 1.0742, 1.0293, 1.0215,
         0.8975], device='cuda:0', dtype=torch.float16))

In [16]:
# To prevent line breaks while printing dataframes
pd.set_option('display.expand_frame_repr', False)

# For each top feature, show example tokens spread over its activation range.
for i, (feature, mean_feature_value) in enumerate(
        zip(top_activating_features, top_activating_features_values)):
    print(f'\n{i+1}) Feature #{feature}: Mean activation = {mean_feature_value}\n' + '-'*100)

    act_intervals_df = get_feature_intervals_df(
        feature, all_token_dfs, all_feature_acts, n_tokens=15,
    )
    print(act_intervals_df)


1) Feature #11631: Mean activation = 6.19921875
----------------------------------------------------------------------------------------------------
     str_tokens unique_token                                            context  prompt  pos  label  batch  activation_value
4876    execute  execute/780      users WHERE id=?'\ncursor.|execute|(safe_sql,       4  780  4/780      7         34.125000
4680    execute  execute/584   = conn.cursor()\ncursor.|execute|('CREATE TAB...       4  584  4/584     11         33.562500
452        Enum     Enum/452   import Enum, auto\nclass Color(|Enum|):\n    ...       0  452  0/452     15         33.093750
3911          (        (/839  .getActiveWindow()\n        if any|(|browser i...       3  839  3/839      4         33.062500
3695    execute  execute/623   conn.cursor()\n    cursor.|execute|('SELECT *...       3  623  3/623      7         33.000000
764          (*       (*/764  logging.INFO)\n    def wrapper|(*|args, **kwar...       0  764  0/764 

In [17]:
# Drop the base-model activations and token contexts before the finetuned-model run.
del all_feature_acts, all_token_dfs

In [18]:
# Unload the base model and its activation store, then reclaim memory.
del base_model, activation_store
clear_cache()

### Finetuned model: looking for the features activating in the finetuning dataset's context

In [5]:
# Load the finetuned HF model, then build a float16 HookedSAETransformer from it
# using the base model's name as the architecture. No separate tokenizer is loaded here.
finetune_model_hf = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL)
finetune_model = HookedSAETransformer.from_pretrained(
    BASE_MODEL,
    hf_model=finetune_model_hf,
    device=device,
    dtype=torch.float16,
)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [6]:
# The HF model was only needed to build `finetune_model`; release it and reclaim memory.
del finetune_model_hf
clear_cache()

In [7]:
# import the required libraries
from sae_lens import SAE

# Same pretrained SAE as for the base model (trained on the base model's activations).
sae_id = f'blocks.{layer_num}.hook_resid_{hook_part}'

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=RELEASE,
    sae_id=sae_id,
    device=device,
)
cfg_dict

{'model_name': 'gemma-2b',
 'model_class_name': 'HookedTransformer',
 'hook_point': 'blocks.6.hook_resid_post',
 'hook_point_eval': 'blocks.{layer}.attn.pattern',
 'hook_point_layer': 6,
 'hook_point_head_index': None,
 'dataset_path': 'HuggingFaceFW/fineweb',
 'streaming': True,
 'is_dataset_tokenized': False,
 'context_size': 1024,
 'use_cached_activations': False,
 'cached_activations_path': None,
 'd_in': 2048,
 'd_sae': 16384,
 'b_dec_init_method': 'zeros',
 'expansion_factor': 8,
 'activation_fn': 'relu',
 'normalize_sae_decoder': False,
 'noise_scale': 0.0,
 'from_pretrained_path': None,
 'apply_b_dec_to_input': False,
 'decoder_orthogonal_init': False,
 'decoder_heuristic_init': True,
 'init_encoder_as_decoder_transpose': True,
 'n_batches_in_buffer': 64,
 'training_tokens': 1228800000,
 'finetuning_tokens': 0,
 'store_batch_size_prompts': 8,
 'train_batch_size_tokens': 4096,
 'normalize_activations': 'none',
 'device': 'cuda',
 'seed': 42,
 'dtype': 'torch.float32',
 'prepend_

In [10]:
from sae_lens import ActivationsStore

# Build the store from the SAE's config, feeding the finetuning dataset to the finetuned model.
activation_store = ActivationsStore.from_sae(
    model=finetune_model,
    sae=sae,
    dataset=finetune_dataset,
    streaming=True,
    device=device,
    # Conservative batch sizes so the same settings also fit larger models in memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
)

batch_size_prompts = activation_store.store_batch_size_prompts
# Tokens per batch = prompts per batch x tokens per prompt (context size).
batch_size_tokens = batch_size_prompts * activation_store.context_size

batch_size_prompts, batch_size_tokens



(8, 8192)

In [19]:
# Activations of every SAE feature for the finetuned model over 20 finetune-dataset batches.
all_token_dfs, all_feature_acts = get_feature_activations(
    model=finetune_model,
    sae=sae,
    activation_store=activation_store,
    total_batches=20,
)
clear_cache()

100%|██████████| 20/20 [00:22<00:00,  1.13s/it]


In [20]:
import plotly.express as px
# Mean over the token positions kept by get_feature_activations (those where at
# least one feature fired), one value per SAE feature.
all_feature_mean_acts = all_feature_acts.mean(0)

# The title names the model, so this plot is distinguishable from the base-model one.
fig = px.line(
    all_feature_mean_acts.cpu(),
    title="Mean feature activations: finetuned model on finetune dataset",
    labels={"index": "Feature", "value": "Activation"},
)
fig.show()

In [21]:
# File suffix "faof" = "Feature Activations on the Finetune dataset" (finetuned model).
torch.save(
    obj=all_feature_mean_acts,
    f=datapath / f'{BASE_MODEL}_finetune_faof.pt',
)

In [22]:
# Extract the top-10 features by mean activation: their indices and mean values
top_activating_features_values, top_activating_features = all_feature_mean_acts.topk(k=10)
top_activating_features, top_activating_features_values

(tensor([ 8720,  4416, 11631, 13490,  3075,  4013, 15869,  5862, 12670,  8063],
        device='cuda:0'),
 tensor([2.7148, 1.9844, 1.5127, 1.2168, 0.9858, 0.9312, 0.7471, 0.6699, 0.6538,
         0.6182], device='cuda:0', dtype=torch.float16))

In [23]:
# To prevent line breaks while printing dataframes
pd.set_option('display.expand_frame_repr', False)

# For each top feature, show example tokens spread over its activation range.
for i, (feature, mean_feature_value) in enumerate(
        zip(top_activating_features, top_activating_features_values)):
    print(f'\n{i+1}) Feature #{feature}: Mean activation = {mean_feature_value}\n' + '-'*100)

    act_intervals_df = get_feature_intervals_df(
        feature, all_token_dfs, all_feature_acts, n_tokens=15,
    )
    print(act_intervals_df)


1) Feature #8720: Mean activation = 2.71484375
----------------------------------------------------------------------------------------------------
     str_tokens unique_token                                            context  prompt  pos  label  batch  activation_value
5612      <bos>    <bos>/492  plotlib to visualize the simulations.\n```|<bo...       5  492  5/492     16         279.50000
5526      <bos>    <bos>/406   sequence to a MIDI file.\n```|<bos>|How to si...       5  406  5/406     16         279.00000
5450      <bos>    <bos>/330   to identify and mitigate biases.\n```|<bos>|H...       5  330  5/330     16         277.75000
5695      <bos>    <bos>/575  -to-speech functionalities.\n```|<bos>|How to ...       5  575  5/575     16         276.75000
1922      <bos>    <bos>/898   model to classify incoming emails.\n```|<bos>...       1  898  1/898     16         274.50000
5918      <bos>    <bos>/798   Extract the most important sentences.\n```|<b...       5  798  5/798  