# Model config

In [1]:
# define the model to work with
MODEL = 'GEMMA'  # one of: GEMMA, GPT2, MISTRAL

# Each branch sets the SAE release, the base / finetuned model identifiers, the dataset,
# the tokenizer name and the residual-stream hook location (hook_part + layer_num).
if MODEL == 'GEMMA':
    RELEASE = 'gemma-2b-res-jb'
    BASE_MODEL = "google/gemma-2b"
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
    FINETUNE_PATH = None
    BASE_TOKENIZER_NAME = BASE_MODEL

    hook_part = 'post'
    layer_num = 6
elif MODEL == 'GPT2':
    RELEASE = 'gpt2-small-res-jb'
    BASE_MODEL = "gpt2-small"
    FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'
    FINETUNE_PATH = None
    DATASET_NAME = "Skylion007/openwebtext"
    BASE_TOKENIZER_NAME = BASE_MODEL

    hook_part = 'pre'
    layer_num = 6
elif MODEL == 'MISTRAL':
    RELEASE = 'mistral-7b-res-wg'
    BASE_MODEL = "mistral-7b"
    DATASET_NAME = "monology/pile-uncopyrighted"
    BASE_TOKENIZER_NAME = 'mistralai/Mistral-7B-v0.1'

    FINETUNE_MODEL = 'meta-math/MetaMath-Mistral-7B' #DeepMount00/Mistral-Ita-7b
    # When not None, the finetuned weights are loaded from this path instead of FINETUNE_MODEL
    FINETUNE_PATH = '/content/drive/My Drive/Finetunes/MetaMath-Mistral-7B'

    hook_part = 'pre'
    layer_num = 8
else:
    # Fail here with a clear message instead of a NameError on layer_num below
    raise ValueError(f"Unsupported MODEL {MODEL!r}; expected one of 'GEMMA', 'GPT2', 'MISTRAL'")

# Name of the residual-stream hook point the activations are read from
SAE_HOOK = f'blocks.{layer_num}.hook_resid_{hook_part}'

In [2]:
# Detect whether the notebook runs in Google Colab: the google.colab import only succeeds there.
try:
    import google.colab
    IN_COLAB = True
    # NOTE(review): the packages are installed without version pins; the recorded output
    # shows sae_lens 3.22.2 and transformer_lens 2.7.0 - consider pinning for reproducibility.
    %pip install sae-lens transformer-lens

    # Mount Google Drive at /content/drive
    from google.colab import drive
    drive.mount('/content/drive')
except ImportError:
    # Not in Colab: nothing is installed or mounted
    IN_COLAB = False

Collecting sae-lens
  Downloading sae_lens-3.22.2-py3-none-any.whl.metadata (5.1 kB)
Collecting transformer-lens
  Downloading transformer_lens-2.7.0-py3-none-any.whl.metadata (12 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae-lens)
  Downloading automated_interpretability-0.0.6-py3-none-any.whl.metadata (778 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae-lens)
  Downloading babe-0.0.7-py3-none-any.whl.metadata (10 kB)
Collecting datasets<3.0.0,>=2.17.1 (from sae-lens)
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting matplotlib<4.0.0,>=3.8.3 (from sae-lens)
  Downloading matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting plotly-express<0.5.0,>=0.4.1 (from sae-lens)
  Downloading plotly_express-0.4.1-py2.py3-none-any.whl.metadata (1.7 kB)
Collecting pytest-profiling<2.0.0,>=1.7.0 (from sae-lens)
  Downloading pytest_profiling-1.7.0-py2.py3-none-any.whl.metadata (12 kB)
Collecting python-dot

In [3]:
# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops
from datasets import load_dataset
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sae_lens import LanguageModelSAERunnerConfig
from sae_lens import ActivationsStore
import os
from dotenv import load_dotenv
import typing
from dataclasses import dataclass
from tqdm import tqdm
import logging

# GPU memory saver: disable autograd globally (this notebook does not need gradient computation)
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7d4475eab700>

In [4]:
from pathlib import Path

def get_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_kind = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_kind)

def get_env_var():
    """Resolve the project root directory and the data directory.

    In Colab (``IN_COLAB`` is true) two fixed paths are returned: the current
    directory and the ``sae_data`` folder on the mounted Google Drive.
    Otherwise the ``.env`` file is loaded and the project root is read from the
    ``PYTHONPATH`` environment variable; the data directory is ``<root>/data``.

    Returns:
        tuple[Path, Path]: ``(pythonpath, datapath)``.

    Raises:
        RuntimeError: if ``PYTHONPATH`` is not set outside Colab.
    """
    if IN_COLAB:
        return Path('./'), Path('/content/drive/My Drive/sae_data')

    # Load environment variables from the .env file
    load_dotenv()

    # Read the PYTHONPATH variable; fail with a clear message instead of the
    # opaque TypeError that Path(None) would raise
    raw_pythonpath = os.getenv('PYTHONPATH')
    if raw_pythonpath is None:
        raise RuntimeError(
            "The PYTHONPATH environment variable is not set; define it in the environment "
            "or in the .env file so that the data directory can be located"
        )
    pythonpath = Path(raw_pythonpath)

    # Print to verify
    print(f"PYTHONPATH: {pythonpath}")
    datapath = pythonpath / 'data'
    print(f"DATAPATH: {datapath}")

    return pythonpath, datapath

import gc
def clear_cache():
    """Run the Python garbage collector, then ask PyTorch to empty its CUDA cache."""
    gc.collect()
    torch.cuda.empty_cache()

# Utilities for loading the finetuned model

In [16]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def adjust_state_dict(model, base_model_vocab_size):
    """Return the model's state dict with its vocab-sized matrices cut to the base vocab size.

    The embedding (``model.embed_tokens.weight``) and unembedding (``lm_head.weight``)
    entries are replaced by their first ``base_model_vocab_size`` rows whenever they
    have more rows than that; every other entry is left as it is.
    """
    weights = model.state_dict()

    # Embedding matrix first, then the unembedding (lm_head) matrix
    for entry in ('model.embed_tokens.weight', 'lm_head.weight'):
        matrix = weights[entry]
        if matrix.shape[0] > base_model_vocab_size:
            weights[entry] = matrix[:base_model_vocab_size, :]

    return weights

def load_hf_model(path, base_model=BASE_MODEL, device='cuda', dtype=None):
    """Load a finetuned HuggingFace checkpoint and wrap it into a HookedSAETransformer.

    Args:
        path: checkpoint name or path given to the HuggingFace ``from_pretrained`` loaders.
        base_model: base model name given to ``HookedSAETransformer.from_pretrained``.
        device: device given to ``HookedSAETransformer.from_pretrained``.
        dtype: dtype given to ``HookedSAETransformer.from_pretrained``.

    Returns:
        ``(tokenizer, finetune_model)``: the checkpoint's tokenizer and the wrapped model.
    """
    tokenizer = AutoTokenizer.from_pretrained(path)
    hf_model = AutoModelForCausalLM.from_pretrained(path)

    if base_model == 'mistral-7b':
        # Cut the finetune's vocabulary down to the Mistral 7B base vocab size
        mistral_vocab_size = 32000
        trimmed_weights = adjust_state_dict(hf_model, mistral_vocab_size)

        # Shrink the embedding layers first, then load the trimmed weights into them
        hf_model.resize_token_embeddings(mistral_vocab_size)
        hf_model.load_state_dict(trimmed_weights, strict=False)

    # Wrap the (possibly adjusted) HF model into the HookedSAETransformer
    finetune_model = HookedSAETransformer.from_pretrained(
        base_model,
        device=device,
        hf_model=hf_model,
        dtype=dtype,
    )

    # Drop the HF model: it is already wrapped into finetune_model
    del hf_model
    clear_cache()

    return tokenizer, finetune_model

In [6]:
@dataclass
class TokenizerComparisonConfig:
    """Names of the two tokenizers that compare_tokenizers() loads and compares."""
    # LLMs
    # Each name is passed to AutoTokenizer.from_pretrained
    BASE_MODEL_TOKENIZER: str
    FINETUNE_MODEL_TOKENIZER: str

@dataclass
class ActivationStoringConfig:
    """Settings for sampling tokens and storing base / finetuned model activations."""
    # LLMs
    BASE_MODEL: str  # base model name given to HookedSAETransformer.from_pretrained
    FINETUNE_MODEL: str  # finetuned checkpoint name; loaded when FINETUNE_PATH is None

    # dataset
    DATASET_NAME: str  # passed to the activation store config as dataset_path
    N_CONTEXT: int  # context size (tokens per prompt) of the activation store
    N_BATCHES: int  # prompts per batch: used as store_batch_size_prompts and as the token slice size
    TOTAL_BATCHES: int  # number of batches run through each model

    # SAE configs
    LAYER_NUM : int  # the forward pass is stopped after this layer (stop_at_layer=LAYER_NUM + 1)
    SAE_HOOK : str  # name of the hook point whose activations are stored

    # misc
    DTYPE: torch.dtype = torch.float16  # dtype the models are loaded with
    IS_DATASET_TOKENIZED: bool = False  # forwarded to the activation store config
    FINETUNE_PATH: typing.Optional[str] = None  # when set, used instead of FINETUNE_MODEL to load the finetune

In [7]:
def _load_cached_tensor(path):
    """Return the tensor stored at `path`, or None when it is missing or cannot be read."""
    try:
        # weights_only=True restricts unpickling to tensors and plain containers
        return torch.load(path, weights_only=True)
    except FileNotFoundError:
        return None
    except Exception as err:
        # Keep the original best-effort behaviour (recompute), but make the failure visible
        logging.warning("Ignoring unreadable cache file %s: %s", path, err)
        return None


def _hook_activations(model, tokens, LAYER_NUM, SAE_HOOK):
    """Run `model` on `tokens` up to layer LAYER_NUM and return the activations at SAE_HOOK."""
    _, cache = model.run_with_cache(tokens, stop_at_layer=LAYER_NUM + 1,
                                    names_filter=[SAE_HOOK])  # [N_BATCH, N_CONTEXT, D_MODEL]
    acts = cache[SAE_HOOK]

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    clear_cache()
    return acts


def get_activations_and_tokens(model, LAYER_NUM, SAE_HOOK, TOTAL_BATCHES, DATAPATH, SAVING_NAME_MODEL, SAVING_NAME_DS, N_BATCHES,
                               tokens_loading_path=None, activation_store=None, save=True, tokens_already_loaded=False):
    """
    Get activations and tokens (of which we took the activations) through the model (base or finetuned one)

    Two modes, selected by `tokens_already_loaded`:

    * False (base model): tokens are sampled from `activation_store`, TOTAL_BATCHES batches in
      total. Cached results are returned instead when both
      `tokens_{SAVING_NAME_MODEL}_on_{SAVING_NAME_DS}.pt` and
      `base_acts_{SAVING_NAME_MODEL}_on_{SAVING_NAME_DS}.pt` can be loaded from DATAPATH.
      The sampled tokens are always saved; the activations only when `save` is true.
    * True (finetuned model): tokens are loaded from `tokens_loading_path` and split into
      TOTAL_BATCHES slices of N_BATCHES rows. Cached activations are returned when
      `finetune_acts_{SAVING_NAME_MODEL}_on_{SAVING_NAME_DS}.pt` can be loaded from DATAPATH;
      newly computed activations are saved there when `save` is true.

    Args:
        model: model exposing `run_with_cache(tokens, stop_at_layer=..., names_filter=...)`.
        LAYER_NUM: the forward pass is stopped after this layer.
        SAE_HOOK: name of the hook whose activations are collected.
        TOTAL_BATCHES: number of batches to run.
        DATAPATH: directory holding the cache files (created if missing before saving).
        SAVING_NAME_MODEL, SAVING_NAME_DS: model / dataset names used in the cache file names.
        N_BATCHES: number of token rows per batch in the finetuned-model mode.
        tokens_loading_path: path of the saved tokens (required when tokens_already_loaded=True).
        activation_store: object exposing `get_batch_tokens()` (required when tokens_already_loaded=False).
        save: whether to save the computed activations.
        tokens_already_loaded: selects the mode described above.

    Returns:
        (all_acts, all_tokens): the per-batch activations and tokens concatenated with torch.cat.

    Raises:
        AssertionError: if the argument required by the selected mode is missing.
        ValueError: if the tokens cannot be loaded from `tokens_loading_path`.
    """
    data_dir = Path(DATAPATH)

    if not tokens_already_loaded:
        assert activation_store is not None, "The activation store must be passed for sampling when tokens_already_loaded is False"

        tokens_path = data_dir / f"tokens_{SAVING_NAME_MODEL}_on_{SAVING_NAME_DS}.pt"
        acts_path = data_dir / f"base_acts_{SAVING_NAME_MODEL}_on_{SAVING_NAME_DS}.pt"

        # If the tokens and activations are already computed, return them
        all_tokens = _load_cached_tensor(tokens_path)
        all_acts = _load_cached_tensor(acts_path) if all_tokens is not None else None
        if all_tokens is not None and all_acts is not None:
            return all_acts, all_tokens

        # Otherwise compute everything from scratch
        token_batches = []  # tokens are kept so the finetuned model can reuse the same sample
        act_batches = []
        for _ in tqdm(range(TOTAL_BATCHES)):
            # Get a batch of tokens from the dataset
            tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
            token_batches.append(tokens)

            # Run the model and store the activations
            act_batches.append(_hook_activations(model, tokens, LAYER_NUM, SAE_HOOK))

        all_acts = torch.cat(act_batches)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT, D_MODEL]
        all_tokens = torch.cat(token_batches)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

        # Make sure the target directory exists so the computed results are not lost on save
        data_dir.mkdir(parents=True, exist_ok=True)
        torch.save(all_tokens, tokens_path)

        if save:
            torch.save(all_acts, acts_path)

        return all_acts, all_tokens

    # Otherwise, we're dealing with the finetune model and want to load the same tokens sample
    assert tokens_loading_path is not None, "You must provide a path to the sample of tokens for the finetune model when calling this method with tokens_already_loaded=True"

    try:
        all_tokens = torch.load(tokens_loading_path, weights_only=True)
    except Exception as err:
        raise ValueError('A sample of tokens for the finetune model must be already saved at `tokens_loading_path` when calling this method with tokens_already_loaded=True') from err

    acts_path = data_dir / f"finetune_acts_{SAVING_NAME_MODEL}_on_{SAVING_NAME_DS}.pt"
    all_acts = _load_cached_tensor(acts_path)
    if all_acts is not None:
        return all_acts, all_tokens

    act_batches = []
    # Split the tokens back into batches and run the fine-tuned model
    for k in tqdm(range(TOTAL_BATCHES)):
        # Rows of all_tokens that belong to batch k
        tokens = all_tokens[k * N_BATCHES:(k + 1) * N_BATCHES]  # [N_BATCH, N_CONTEXT]

        # Run the fine-tuned model and store the activations
        act_batches.append(_hook_activations(model, tokens, LAYER_NUM, SAE_HOOK))

    # Concatenate all activations from the fine-tuned model into a single tensor
    all_acts = torch.cat(act_batches)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT, D_MODEL]

    if save:
        data_dir.mkdir(parents=True, exist_ok=True)
        torch.save(all_acts, acts_path)

    return all_acts, all_tokens

In [8]:
def compare_tokenizers(cfg: TokenizerComparisonConfig):
    """Compare the vocabularies of the base and the finetuned tokenizer.

    Token-level differences are written to
    `<datapath>/log/<finetune name>_tokenizer_vocab_comparison_log.txt`, where
    `<datapath>` comes from get_env_var() and `<finetune name>` is the part of
    `cfg.FINETUNE_MODEL_TOKENIZER` after the last "/".

    Args:
        cfg: names of the two tokenizers; each is passed to AutoTokenizer.from_pretrained.

    Returns:
        pandas DataFrame with columns "Metric" and "Value": both vocab sizes, the number of
        keys present in only one vocab, the number of shared keys with different token ids,
        and the percentage of "good" tokens (shared key, same id) in each vocab.
    """
    base_model_tok_name = cfg.BASE_MODEL_TOKENIZER
    finetune_model_tok_name = cfg.FINETUNE_MODEL_TOKENIZER
    # Part after the last "/" (the whole name when there is no "/")
    saving_name_ft = finetune_model_tok_name.split("/")[-1]

    _, datapath = get_env_var()
    log_dir = datapath / "log"
    os.makedirs(log_dir, exist_ok=True)
    saving_path = log_dir / f'{saving_name_ft}_tokenizer_vocab_comparison_log.txt'

    base_tokenizer = AutoTokenizer.from_pretrained(base_model_tok_name)
    finetune_tokenizer = AutoTokenizer.from_pretrained(finetune_model_tok_name)

    # Setup the file-logger
    logger = logging.getLogger('tokenizer_vocab_comparison')

    # Remove AND close any existing handlers, so that nothing logs to the console
    # and no file handle from an earlier call stays open
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    logger.setLevel(logging.INFO)
    # Disable propagation to prevent any parent loggers from printing to the console
    logger.propagate = False

    # File handler with UTF-8 encoding is the only handler of this logger
    file_handler = logging.FileHandler(saving_path, encoding='utf-8')
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
    logger.addHandler(file_handler)

    try:
        # Extract vocabs
        base_vocab = base_tokenizer.get_vocab()
        finetune_vocab = finetune_tokenizer.get_vocab()

        # 1. Compare the keys (words/tokens)
        base_keys = set(base_vocab.keys())
        finetune_keys = set(finetune_vocab.keys())

        # Keys that are in one tokenizer but not in the other
        only_in_base = base_keys - finetune_keys
        only_in_finetune = finetune_keys - base_keys

        # Keys are logged in sorted order so that the log is the same on every run
        logger.info("Keys only in base tokenizer:")
        for key in sorted(only_in_base):
            logger.info(f"  {key}")

        logger.info("\nKeys only in fine-tuned tokenizer:")
        for key in sorted(only_in_finetune):
            logger.info(f"  {key}")

        # 2. Compare the values (token ids) of the keys both vocabs share
        mismatched_values = {}
        for key in sorted(base_keys & finetune_keys):
            base_value = base_vocab[key]
            finetune_value = finetune_vocab[key]
            if base_value != finetune_value:
                mismatched_values[key] = (base_value, finetune_value)

        logger.info("\nKeys with mismatched token IDs:")
        for key, (base_value, finetune_value) in mismatched_values.items():
            logger.info(f"  {key}: Base ID = {base_value}, Fine-tune ID = {finetune_value}")

        # Define variables based on results
        base_vocab_size = len(base_vocab)
        finetune_vocab_size = len(finetune_vocab)
        only_in_base_size = len(only_in_base)
        only_in_finetune_size = len(only_in_finetune)
        mismatched_values_size = len(mismatched_values)

        # "Good" tokens exist in both vocabs with the same id
        good_base_tokens_count = base_vocab_size - only_in_base_size - mismatched_values_size
        good_finetune_tokens_count = finetune_vocab_size - only_in_finetune_size - mismatched_values_size

        # Calculate percentages
        good_base_tokens_percent = good_base_tokens_count / base_vocab_size * 100
        good_finetune_tokens_percent = good_finetune_tokens_count / finetune_vocab_size * 100

        # Summary statistics
        summary_statistics = {
            "Base Tokenizer Size": base_vocab_size,
            "Fine-tune Tokenizer Size": finetune_vocab_size,
            "Keys only in Base": only_in_base_size,
            "Keys only in Fine-tune": only_in_finetune_size,
            "Keys with Mismatched Token IDs": mismatched_values_size,
            "Good Tokens in Base (%)": good_base_tokens_percent,
            "Good Tokens in Fine-tune (%)": good_finetune_tokens_percent
        }

        # Create a pandas DataFrame for display
        summary_df = pd.DataFrame(list(summary_statistics.items()), columns=["Metric", "Value"])
        logger.info(str(summary_df))

        return summary_df
    finally:
        # Flush and release the log file even when the comparison fails
        logger.removeHandler(file_handler)
        file_handler.close()

In [9]:
# Compare the base tokenizer with the tokenizer of the finetuned model
tokenizer_cfg = TokenizerComparisonConfig(
    BASE_MODEL_TOKENIZER=BASE_TOKENIZER_NAME,
    FINETUNE_MODEL_TOKENIZER=FINETUNE_MODEL,
)
tokenizer_comp_df = compare_tokenizers(tokenizer_cfg)

# Display the summary table
tokenizer_comp_df

tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/522 [00:00<?, ?B/s]

Unnamed: 0,Metric,Value
0,Base Tokenizer Size,256000.0
1,Fine-tune Tokenizer Size,256000.0
2,Keys only in Base,1.0
3,Keys only in Fine-tune,1.0
4,Keys with Mismatched Token IDs,0.0
5,Good Tokens in Base (%),99.999609
6,Good Tokens in Fine-tune (%),99.999609


# Activations config

In [10]:
N_CONTEXT = 1024 # number of context tokens per prompt
N_BATCHES = 5 # number of prompts per batch (used as the activation store's batch size)
TOTAL_BATCHES = 50 # number of batches run through each model

In [11]:
# Bundle the run settings into a single config object
cfg = ActivationStoringConfig(
    BASE_MODEL=BASE_MODEL,
    FINETUNE_MODEL=FINETUNE_MODEL,
    DATASET_NAME=DATASET_NAME,
    N_CONTEXT=N_CONTEXT,
    N_BATCHES=N_BATCHES,
    TOTAL_BATCHES=TOTAL_BATCHES,
    LAYER_NUM=layer_num,
    SAE_HOOK=SAE_HOOK,
    FINETUNE_PATH=FINETUNE_PATH,
)

# STEP 1: Get the device and the python and datapath
device = get_device()
_, datapath = get_env_var()

# Names used in the cache file names: the part after the last "/" of each identifier
# (the whole identifier when it contains no "/")
saving_name_base = cfg.BASE_MODEL.split("/")[-1]
saving_name_ft = cfg.FINETUNE_MODEL.split("/")[-1]
saving_name_ds = cfg.DATASET_NAME.split("/")[-1]


In [12]:
# STEP 2: Load the base model as a HookedSAETransformer
base_model = HookedSAETransformer.from_pretrained(
    cfg.BASE_MODEL,
    device=device,
    dtype=cfg.DTYPE,
)

# STEP 3: Build the config the activation store is created from
activation_store_cfg = LanguageModelSAERunnerConfig(
    # Data generating function (model + dataset)
    model_name=cfg.BASE_MODEL,
    dataset_path=cfg.DATASET_NAME,
    is_dataset_tokenized=cfg.IS_DATASET_TOKENIZED,
    streaming=True,
    # Activation store parameters: prompts per batch and tokens per prompt
    store_batch_size_prompts=cfg.N_BATCHES,
    context_size=cfg.N_CONTEXT,
    # Misc
    device=device,
    seed=42,
)

# STEP 4: Create the activation store used to sample tokenized batches from the dataset
activation_store = ActivationsStore.from_config(model=base_model, cfg=activation_store_cfg)

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]



Loaded pretrained model google/gemma-2b into HookedTransformer
Run name: 2048-L1-0.001-LR-0.0003-Tokens-2.000e+06
n_tokens_per_buffer (millions): 0.1024
Lower bound: n_contexts_per_buffer (millions): 0.0001
Total training steps: 488
Total wandb updates: 48
n_tokens_per_feature_sampling_window (millions): 8388.608
n_tokens_per_dead_feature_window (millions): 4194.304
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 8.19e+06


Downloading readme:   0%|          | 0.00/331 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

sae_lens.json:   0%|          | 0.00/340 [00:00<?, ?B/s]

In [13]:
# STEP 5: Sample tokens and collect the base model's activations on them
all_acts, all_tokens = get_activations_and_tokens(
    model=base_model,
    LAYER_NUM=cfg.LAYER_NUM,
    SAE_HOOK=cfg.SAE_HOOK,
    TOTAL_BATCHES=cfg.TOTAL_BATCHES,
    DATAPATH=datapath,
    SAVING_NAME_MODEL=saving_name_base,
    SAVING_NAME_DS=saving_name_ds,
    N_BATCHES=cfg.N_BATCHES,
    activation_store=activation_store,
)
all_acts.shape

  all_tokens = torch.load(DATAPATH / f"tokens_{SAVING_NAME_MODEL}_on_{SAVING_NAME_DS}.pt")
100%|██████████| 50/50 [00:32<00:00,  1.56it/s]


In [None]:
# STEP 6: Release the base model before the finetuned one is loaded
del activation_store  # deleted as well because it has base_model captured as a parameter
del base_model
clear_cache()

In [17]:
# STEP 7: Load the finetuned model, from FINETUNE_PATH when it is set, otherwise by its model name
finetune_tokenizer, finetune_model = load_hf_model(
    cfg.FINETUNE_MODEL if cfg.FINETUNE_PATH is None else cfg.FINETUNE_PATH,
    device=device,
    dtype=cfg.DTYPE,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2b into HookedTransformer


In [18]:
# STEP 8: Collect the finetuned model's activations
# The token sample saved by the base-model run is reused, so both models see the same tokens
tokens_loading_path = datapath / f"tokens_{saving_name_base}_on_{saving_name_ds}.pt"
all_acts_finetuned, all_tokens = get_activations_and_tokens(
    model=finetune_model,
    LAYER_NUM=cfg.LAYER_NUM,
    SAE_HOOK=cfg.SAE_HOOK,
    TOTAL_BATCHES=cfg.TOTAL_BATCHES,
    DATAPATH=datapath,
    SAVING_NAME_MODEL=saving_name_ft,
    SAVING_NAME_DS=saving_name_ds,
    N_BATCHES=cfg.N_BATCHES,
    tokens_loading_path=tokens_loading_path,
    tokens_already_loaded=True,
)

  all_tokens = torch.load(tokens_loading_path)
  all_acts = torch.load(DATAPATH / f"finetune_acts_{SAVING_NAME_MODEL}_on_{SAVING_NAME_DS}.pt")
100%|██████████| 50/50 [00:32<00:00,  1.52it/s]


In [22]:
# Shape of the finetuned model's activations; the first dimension stacks all batches
# (TOTAL_BATCHES * N_BATCHES rows)
all_acts_finetuned.shape

torch.Size([250, 1024, 2048])