In [1]:
# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops
# Imports for displaying vis in Colab / notebook

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

# utility to clear variables out of the memory & and clearing cuda cache
import gc
def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()

Device: mps


In [3]:
import os
from dotenv import load_dotenv

# Load environment variables from the .env file
load_dotenv()

# Access the PYTHONPATH variable
pythonpath = os.getenv('PYTHONPATH')

# Print to verify
print(f"PYTHONPATH: {pythonpath}")

datapath = pythonpath + "/data"


PYTHONPATH: /Users/tommasomencattini/Desktop/GitHub/SAE-Tuning-Merging


In [4]:
N_CONTEXT = 128 # should be equal to the context size of the dataset if it's pretokenized
N_BATCHES = 8
TOTAL_BATCHES = 10


GEMMA = True

if GEMMA == True:
    BASE_MODEL = "google/gemma-2b"
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
else:
    BASE_MODEL = "gpt2-small"
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    DATASET_NAME = "Skylion007/openwebtext"


SAVING_NAME = BASE_MODEL if "/" not in BASE_MODEL else BASE_MODEL.split("/")[-1]
# BASE_MODEL = "gpt2-small"
# FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'

# this is a tokenized language dataset that the base model's SAE was originally trained on
# DATASET_NAME = "Skylion007/openwebtext"

# DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"

# Or if we want to use Pile:
# DATASET_NAME = "monology/pile-uncopyrighted"

In [5]:
SAVING_NAME

'gemma-2b'

In [6]:
print(f"Base model: {SAVING_NAME}")
print(f"Dataset: {BASE_MODEL}")
print(f"Fine-tune model: {FINETUNE_MODEL}")

Base model: gemma-2b
Dataset: google/gemma-2b
Fine-tune model: shahdishank/gemma-2b-it-finetune-python-codes


In [7]:
LAYER_NUM = 6
SAE_LAYER = f'blocks.{LAYER_NUM}.hook_resid_pre'

In [8]:
from datasets import load_dataset

# Load the dataset in streaming mode
dataset = load_dataset(DATASET_NAME, split="train", streaming=True)

In [8]:
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2b into HookedTransformer


In [10]:
from sae_lens import LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distibuion)
    model_name=BASE_MODEL, # the huggingface model name
    dataset_path=DATASET_NAME,

    is_dataset_tokenized=False,
    streaming=True,

    # Activation Store Parameters
    store_batch_size_prompts=N_BATCHES,
    context_size=N_CONTEXT,

    # Misc
    device=device,
    seed=42,
)

Run name: 2048-L1-0.001-LR-0.0003-Tokens-2.000e+06
n_tokens_per_buffer (millions): 0.02048
Lower bound: n_contexts_per_buffer (millions): 0.00016
Total training steps: 488
Total wandb updates: 48
n_tokens_per_feature_sampling_window (millions): 1048.576
n_tokens_per_dead_feature_window (millions): 524.288
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 8.19e+06


In [11]:
from sae_lens import ActivationsStore

# Instantiate an activation store to easily sample tokenized batches from our dataset
activation_store = ActivationsStore.from_config(
    model=base_model, # the actual model once loaded
    cfg=cfg
)



1. Assume there’s an input tensors of shape `[N_BATCH, N_CONTEXT]`, containing a sample from the dataset in Task 1.
2. Run the base model with [run_with_cache](https://transformerlensorg.github.io/TransformerLens/generated/code/transformer_lens.hook_points.html#transformer_lens.hook_points.HookedRootModule.run_with_cache) on this input tensor, storing the activations in the cache object as it’s done on the Github notebook.
3. Save the activations tensor using torch.save
4. Similarly, run the finetune on the input tensor, storing & saving the activations
5. Return (save) two tensors of the shape `[N_BATCH, N_CONTEXT, N_ACTIVATIONS]`, one for each model.

In [12]:
from tqdm import tqdm

all_acts = []
all_tokens = []  # This will store the tokens for reuse

for k in tqdm(range(TOTAL_BATCHES)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
    
    # Store tokens for later reuse
    all_tokens.append(tokens)
    
    # Run the model and store the activations
    _, cache = base_model.run_with_cache(tokens, stop_at_layer=LAYER_NUM + 1, \
                                         names_filter=[SAE_LAYER])  # [N_BATCH, N_CONTEXT, D_MODEL]
    all_acts.append(cache[SAE_LAYER])

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    torch.cuda.empty_cache()

# Concatenate all feature activations into a single tensor
all_acts = torch.cat(all_acts)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT, D_MODEL]

# Concatenate all tokens into a single tensor for reuse
all_tokens = torch.cat(all_tokens)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

  0%|                                                                                                                                | 0/10 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.69it/s]


In [13]:
torch.save(all_tokens, datapath + f"/tokens_{SAVING_NAME}.pt")
torch.save(all_acts, datapath + f"/base_acts_{SAVING_NAME}.pt")

all_tokens.shape, all_acts.shape

(torch.Size([80, 128]), torch.Size([80, 128, 768]))

In [11]:
# Offload the first model from memory, but save its tokenizer
base_tokenizer = base_model.tokenizer

del base_model
torch.cuda.empty_cache()

base_tokenizer

GemmaTokenizerFast(name_or_path='google/gemma-2b', vocab_size=256000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<bos>', 'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<start_of_turn>', '<end_of_turn>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<eos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("<bos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	4: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),
	5: AddedToken("<2mass>", rstrip=False, lstrip=False, single_wor

In [9]:
# Load the finetune model and its tokenizer
finetune_tokenizer = AutoTokenizer.from_pretrained(FINETUNE_MODEL)

tokenizer_config.json:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/522 [00:00<?, ?B/s]

In [None]:
finetune_model_hf = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL)
finetune_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, hf_model=finetune_model_hf, dtype=torch.float16)

finetune_tokenizer

In [16]:
# Load the stored tokens from the previous run (from the base model sampling)
all_tokens = torch.load(datapath + f"/tokens_{SAVING_NAME}.pt") # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

  all_tokens = torch.load(datapath + f"/tokens_{SAVING_NAME}.pt") # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]


## Tokenizers check
Here we'll check how different are the tokenizers of base vs finetuned models

In [12]:
base_vocab = base_tokenizer.get_vocab()
finetune_vocab = finetune_tokenizer.get_vocab()

In [15]:
import os
import logging

# Setup logs for mismatches between the base model tokenizer and the finetune one
log_dir = os.path.join(pythonpath, 'logs')
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

log_file = os.path.join(log_dir, f'{SAVING_NAME}_tokenizer_vocab_comparison_log.txt')
    
# Create a custom logger
logger = logging.getLogger('tokenizer_vocab_comparison')

# Clear any existing handlers to ensure no console logging
logger.handlers.clear()

# Set the log level
logger.setLevel(logging.INFO)

# Create file handler with UTF-8 encoding
file_handler = logging.FileHandler(log_file, encoding='utf-8')

# Set the logging format
formatter = logging.Formatter('%(asctime)s - %(message)s')
file_handler.setFormatter(formatter)

# Add only the file handler to the logger
logger.addHandler(file_handler)

# Disable propagation to prevent any parent loggers from printing to the console
logger.propagate = False

def compare_tokenizer_vocabularies(base_vocab, finetune_vocab):
    # 1. Compare the keys (words/tokens)
    base_keys = set(base_vocab.keys())
    finetune_keys = set(finetune_vocab.keys())
    
    # Keys that are in one tokenizer but not in the other
    only_in_base = base_keys - finetune_keys
    only_in_finetune = finetune_keys - base_keys
    
    logger.info("Keys only in base tokenizer:")
    for key in only_in_base:
        logger.info(f"  {key}")
    
    logger.info("\nKeys only in fine-tuned tokenizer:")
    for key in only_in_finetune:
        logger.info(f"  {key}")
    
    # 2. Compare the values (token ids)
    mismatched_values = {}
    for key in base_keys.intersection(finetune_keys):
        base_value = base_vocab[key]
        finetune_value = finetune_vocab[key]
        if base_value != finetune_value:
            mismatched_values[key] = (base_value, finetune_value)
    
    logger.info("\nKeys with mismatched token IDs:")
    for key, (base_value, finetune_value) in mismatched_values.items():
        logger.info(f"  {key}: Base ID = {base_value}, Fine-tune ID = {finetune_value}")

    return only_in_base, only_in_finetune, mismatched_values

In [17]:
only_in_base, only_in_finetune, mismatched_values = compare_tokenizer_vocabularies(base_vocab, finetune_vocab)
# Ensure the log is flushed
for handler in logger.handlers:
    handler.flush()

In [18]:
# Define variables based on results
base_vocab_size = len(base_vocab)
finetune_vocab_size = len(finetune_vocab)
only_in_base_size = len(only_in_base)
only_in_finetune_size = len(only_in_finetune)
mismatched_values_size = len(mismatched_values)

# Calculate good token counts
good_base_tokens_count = base_vocab_size - only_in_base_size - mismatched_values_size
good_finetune_tokens_count = finetune_vocab_size - only_in_finetune_size - mismatched_values_size

# Calculate percentages
good_base_tokens_percent = good_base_tokens_count / base_vocab_size * 100
good_finetune_tokens_percent = good_finetune_tokens_count / finetune_vocab_size * 100

print('Percentage of good tokens in the base vocab: ', good_base_tokens_percent)
print('Percentage of good tokens in the finetune vocab: ', good_finetune_tokens_percent)

Percentage of good tokens in the base vocab:  99.99960937499999
Percentage of good tokens in the finetune vocab:  99.99960937499999


In [19]:
# Summary statistics
summary_statistics = {
    "Base Tokenizer Size": base_vocab_size,
    "Fine-tune Tokenizer Size": finetune_vocab_size,
    "Keys only in Base": only_in_base_size,
    "Keys only in Fine-tune": only_in_finetune_size,
    "Keys with Mismatched Token IDs": mismatched_values_size,
    "Good Tokens in Base (%)": good_base_tokens_percent,
    "Good Tokens in Fine-tune (%)": good_finetune_tokens_percent
}

# Create a pandas DataFrame for display
summary_df = pd.DataFrame(list(summary_statistics.items()), columns=["Metric", "Value"])
summary_df

Unnamed: 0,Metric,Value
0,Base Tokenizer Size,256000.0
1,Fine-tune Tokenizer Size,256000.0
2,Keys only in Base,1.0
3,Keys only in Fine-tune,1.0
4,Keys with Mismatched Token IDs,0.0
5,Good Tokens in Base (%),99.999609
6,Good Tokens in Fine-tune (%),99.999609


In [20]:
only_in_base_size + mismatched_values_size

1

In [21]:
# import logging
# import os
# import sys

# # Setup logs for mismatches between the base model tokenizer and the finetune one
# log_dir = os.path.join(pythonpath, 'logs')
# if not os.path.exists(log_dir):
#     os.makedirs(log_dir)
    
# log_file = os.path.join(log_dir, f'{SAVING_NAME}_tokenizer_mismatch_log.txt')

# # Create a custom logger
# logger = logging.getLogger('tokenizer_mismatch_logger')

# # Set the log level
# logger.setLevel(logging.INFO)

# # Create file handler with UTF-8 encoding
# file_handler = logging.FileHandler(log_file, encoding='utf-8')

# # Set the logging format
# formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
# file_handler.setFormatter(formatter)

# # Add only the file handler to the logger
# if not logger.handlers:  # To prevent adding multiple handlers
#     logger.addHandler(file_handler)

# # Disable propagation to avoid logging output to console
# logger.propagate = False

# def check_tokenizers(tokens, current_batch,
#                      base_tokenizer=base_tokenizer, finetune_tokenizer=finetune_tokenizer):
#     # Decode tokens back to text using base and fine-tune tokenizers
#     base_decoded_text = base_tokenizer.batch_decode(tokens, skip_special_tokens=True)
#     finetune_decoded_text = finetune_tokenizer.batch_decode(tokens, skip_special_tokens=True)
    
#     # Compare decoded texts and log mismatches
#     for i, (base_text, finetune_text) in enumerate(zip(base_decoded_text, finetune_decoded_text)):
#         if base_text != finetune_text:
#             # Log the mismatch as a warning (shown in both file and console)
#             logger.warning(f"Batch {current_batch} - Mismatch at token index {i} in current batch.")
            
#             # Log the exact tokens for detailed investigation (only in the file)
#             for j, (base_token, finetune_token) in enumerate(zip(base_tokenizer.convert_ids_to_tokens(tokens[i]), 
#                                                                  finetune_tokenizer.convert_ids_to_tokens(tokens[i]))):
#                 if base_token != finetune_token:
#                     logger.info(f"Batch {current_batch}, Sequence {i}, Token position {j}: "
#                                 f"Base tokenizer = '{base_token}', Fine-tune tokenizer = '{finetune_token}'")

In [22]:
# Initialize an empty list to store activations from the fine-tuned model
all_acts = []

# Split the tokens back into batches and run the fine-tuned model
for k in tqdm(range(TOTAL_BATCHES)):
    # Calculate the start and end indices for the current batch
    start_idx = k * N_BATCHES
    end_idx = (k + 1) * N_BATCHES
    
    # Get the corresponding batch of tokens from all_tokens
    tokens = all_tokens[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]
    # Check if the tokens map to the same text using the finetune model tokenizer, as with the base model tokenizer
    # check_tokenizers(tokens, k)
    
    # Run the fine-tuned model and store the activations
    _, cache = finetune_model.run_with_cache(tokens, stop_at_layer=LAYER_NUM + 1, \
                                             names_filter=[SAE_LAYER])  # [N_BATCH, N_CONTEXT, D_MODEL]
    all_acts.append(cache[SAE_LAYER])

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    torch.cuda.empty_cache()

# Concatenate all activations from the fine-tuned model into a single tensor
all_acts = torch.cat(all_acts)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT, D_MODEL]
all_acts.shape

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 61.14it/s]


torch.Size([80, 128, 768])

In [23]:
torch.save(all_acts, datapath + f"/finetune_acts_{SAVING_NAME}.pt")