In [1]:
# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops
# Imports for displaying vis in Colab / notebook

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

# utility to clear variables out of the memory & and clearing cuda cache
import gc
def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()

In [13]:
# define the model to work with 
GEMMA = False

if GEMMA == True:
    BASE_MODEL = "google/gemma-2b"
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
else:
    BASE_MODEL = "gpt2-small"
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    DATASET_NAME = "Skylion007/openwebtext"

layer_num = 6
hook_part = "pre"
TOTAL_BATCHES = 500


### Task 4.1 Pretrained case

In [3]:
from saetuning.utils import *

In [4]:
# import the LLM
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)

In [5]:
# import the required libraries
from sae_lens import SAE

# define the SAE id
sae_id = f"blocks.{layer_num}.hook_resid_{hook_part}"
# load the SAE model
sae, cfg_dict, sparsity = SAE.from_pretrained(
        release = 'gpt2-small-res-jb',
        sae_id = sae_id, # in the case of GPT-2 res SAEs, it coincides with the hook name
        device = device
    )

In [6]:
# this must be checked for the forward method of sae.encode_xxx
cfg_dict["activation_fn_str"]

In [7]:
from sae_lens import ActivationsStore

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=base_model,
    sae=sae,
    streaming=True,
    # dataset=chess_dataset,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)

#### 4.1.1 L0 loss

In [14]:
# L0_loss(x, threshold=1e-8)
# get_substitution_loss(tokens, model, sae, sae_layer)

from tqdm import tqdm

all_tokens_L0 = []  # This will store the tokens for reuse
all_L0 = []

for k in tqdm(range(TOTAL_BATCHES)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
    
    # Store tokens for later reuse
    all_tokens.append(tokens)
    
    # Run the model and store the activations 
    _, cache = base_model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                         names_filter=[sae_id])  # [N_BATCH, N_CONTEXT, D_MODEL]
    
    # Get the activations from the cache at the sae_id 
    original_activations = cache[sae_id]
    
    # Encode the activations with the SAE
    feature_activations = sae.encode_standard(original_activations) # the result of the encode method of the sae on the "sae_id" activations (a specific activation tensor of the LLM) 

    feature_activations.to('cpu')
    
    # Store the encoded activations
    all_L0.append(L0_loss(feature_activations))

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del original_activations
    del feature_activations
    torch.cuda.empty_cache()

# Concatenate all tokens into a single tensor for reuse
all_tokens_L0 = torch.cat(all_tokens_L0)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 500/500 [00:40<00:00, 12.28it/s]


In [17]:
torch.tensor(all_L0).mean()

tensor(50.3976)

#### 4.1.2 Substitution Loss

In [18]:
# L0_loss(x, threshold=1e-8)
# get_substitution_loss(tokens, model, sae, sae_layer)

from tqdm import tqdm

all_tokens_SL = []  # This will store the tokens for reuse
all_SL = []

for k in tqdm(range(TOTAL_BATCHES)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
    
    # Store tokens for later reuse
    all_tokens_SL.append(tokens)
    
    # Store loss 
    all_SL.append(get_substitution_loss(tokens, base_model, sae, sae_id))

# Concatenate all tokens into a single tensor for reuse
all_tokens = torch.cat(all_tokens_SL)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 500/500 [02:34<00:00,  3.24it/s]


In [20]:
torch.tensor(all_SL).mean()

tensor(3.6367, dtype=torch.float16)

### Task 4.1 FineTuned case

In [21]:
from saetuning.utils import *

In [None]:
# import the LLM
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)

In [22]:
# Load the finetune model and its tokenizer
finetune_tokenizer = AutoTokenizer.from_pretrained(FINETUNE_MODEL)
finetune_model_hf = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL)
finetune_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, hf_model=finetune_model_hf, dtype=torch.float16)

finetune_tokenizer

tokenizer_config.json:   0%|          | 0.00/92.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/850k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/508k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/120 [00:00<?, ?B/s]



Loaded pretrained model gpt2-small into HookedTransformer


GPT2TokenizerFast(name_or_path='pierreguillou/gpt2-small-portuguese', vocab_size=50257, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [23]:
# import the required libraries
from sae_lens import SAE

# define the SAE id
sae_id = f"blocks.{layer_num}.hook_resid_{hook_part}"
# load the SAE model
sae, cfg_dict, sparsity = SAE.from_pretrained(
        release = 'gpt2-small-res-jb',
        sae_id = sae_id, # in the case of GPT-2 res SAEs, it coincides with the hook name
        device = device
    )

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [24]:
# this must be checked for the forward method of sae.encode_xxx
cfg_dict["activation_fn_str"]

'relu'

In [25]:
from sae_lens import ActivationsStore

# a convenient way to instantiate an activation store is to use the from_sae method
activation_store = ActivationsStore.from_sae(
    model=base_model,
    sae=sae,
    streaming=True,
    # dataset=chess_dataset,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=32,
    device=device,
)



#### 4.2.1 L0 loss

In [27]:
# L0_loss(x, threshold=1e-8)
# get_substitution_loss(tokens, model, sae, sae_layer)

from tqdm import tqdm

all_tokens_L0 = []  # This will store the tokens for reuse
all_L0 = []

for k in tqdm(range(TOTAL_BATCHES)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
    
    # Store tokens for later reuse
    all_tokens_L0.append(tokens)
    
    # Run the model and store the activations 
    _, cache = finetune_model.run_with_cache(tokens, stop_at_layer=layer_num + 1, \
                                         names_filter=[sae_id])  # [N_BATCH, N_CONTEXT, D_MODEL]
    
    # Get the activations from the cache at the sae_id 
    original_activations = cache[sae_id]
    
    # Encode the activations with the SAE
    feature_activations = sae.encode_standard(original_activations) # the result of the encode method of the sae on the "sae_id" activations (a specific activation tensor of the LLM) 

    feature_activations.to('cpu')
    
    # Store the encoded activations
    all_L0.append(L0_loss(feature_activations))

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    del original_activations
    del feature_activations
    torch.cuda.empty_cache()

# Concatenate all tokens into a single tensor for reuse
all_tokens_L0 = torch.cat(all_tokens_L0)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 500/500 [00:39<00:00, 12.51it/s]


In [28]:
torch.tensor(all_L0).mean()

tensor(74.1725)

#### 4.2.2 Substitution Loss

In [29]:
# L0_loss(x, threshold=1e-8)
# get_substitution_loss(tokens, model, sae, sae_layer)

from tqdm import tqdm

all_tokens_SL = []  # This will store the tokens for reuse
all_SL = []

for k in tqdm(range(TOTAL_BATCHES)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
    
    # Store tokens for later reuse
    all_tokens_SL.append(tokens)
    
    # Store loss 
    all_SL.append(get_substitution_loss(tokens, finetune_model, sae, sae_id))

# Concatenate all tokens into a single tensor for reuse
all_tokens = torch.cat(all_tokens_SL)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 500/500 [02:35<00:00,  3.21it/s]


In [30]:
torch.tensor(all_SL).mean()

tensor(10.1172, dtype=torch.float16)