In [1]:

# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops
# Imports for displaying vis in Colab / notebook

torch.set_grad_enabled(False)

# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

# utility to clear variables out of the memory & and clearing cuda cache
import gc
def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()

Device: mps


In [2]:
import os
from dotenv import load_dotenv

# Load environment variables from the .env file
load_dotenv()

# Access the PYTHONPATH variable
pythonpath = os.getenv('PYTHONPATH')

# Print to verify
print(f"PYTHONPATH: {pythonpath}")


PYTHONPATH: /Users/tommasomencattini/Desktop/GitHub/SAE-Tuning-Merging


In [3]:
datapath = pythonpath + "/data"

In [13]:
N_CONTEXT = 1024 # don't change this: it's the max context length, used by the original Gemma-2 dataset
N_BATCHES = 8
TOTAL_BATCHES = 20

USE_BASE_TOKENIZER=True
GEMMA = True

if GEMMA == True:
    BASE_MODEL = "google/gemma-2b"
    FINETUNE_MODEL = 'shahdishank/gemma-2b-it-finetune-python-codes'
    DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"
else:
    BASE_MODEL = "gpt2-small"
    FINETUNE_MODEL = "gpt2-small" #'pierreguillou/gpt2-small-portuguese'
    DATASET_NAME = "Skylion007/openwebtext"


SAVING_NAME = BASE_MODEL if "/" not in BASE_MODEL else BASE_MODEL.split("/")[-1]
# BASE_MODEL = "gpt2-small"
# FINETUNE_MODEL = 'pierreguillou/gpt2-small-portuguese'

# this is a tokenized language dataset that the base model's SAE was originally trained on
#DATASET_NAME = "Skylion007/openwebtext"

# DATASET_NAME = "ctigges/openwebtext-gemma-1024-cl"

# Or if we want to use Pile:
# DATASET_NAME = "monology/pile-uncopyrighted"

In [14]:
SAVING_NAME

'gemma-2b'

In [6]:
print(f"Base model: {SAVING_NAME}")
print(f"Dataset: {BASE_MODEL}")
print(f"Fine-tune model: {FINETUNE_MODEL}")

Base model: gemma-2b
Dataset: google/gemma-2b
Fine-tune model: shahdishank/gemma-2b-it-finetune-python-codes


In [7]:
LAYER_NUM = 6
SAE_LAYER = f'blocks.{LAYER_NUM}.hook_resid_pre'

In [8]:
from datasets import load_dataset

# Load the dataset in streaming mode
dataset = load_dataset(DATASET_NAME, split="train", streaming=True)

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

In [9]:
from sae_lens import SAE, HookedSAETransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, dtype=torch.float16)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2b into HookedTransformer


In [15]:
from sae_lens import LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distibuion)
    model_name=BASE_MODEL, # the huggingface model name
    dataset_path=DATASET_NAME,

    is_dataset_tokenized=True,
    streaming=True,

    # Activation Store Parameters
    store_batch_size_prompts=N_BATCHES,
    context_size=N_CONTEXT,

    # Misc
    device=device,
    seed=42,
)

Run name: 2048-L1-0.001-LR-0.0003-Tokens-2.000e+06
n_tokens_per_buffer (millions): 0.16384
Lower bound: n_contexts_per_buffer (millions): 0.00016
Total training steps: 488
Total wandb updates: 48
n_tokens_per_feature_sampling_window (millions): 8388.608
n_tokens_per_dead_feature_window (millions): 4194.304
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 8.19e+06


In [16]:
from sae_lens import ActivationsStore

# Instantiate an activation store to easily sample tokenized batches from our dataset
activation_store = ActivationsStore.from_config(
    model=base_model, # the actual model once loaded
    cfg=cfg
)

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

1. Assume there’s an input tensors of shape `[N_BATCH, N_CONTEXT]`, containing a sample from the dataset in Task 1.
2. Run the base model with [run_with_cache](https://transformerlensorg.github.io/TransformerLens/generated/code/transformer_lens.hook_points.html#transformer_lens.hook_points.HookedRootModule.run_with_cache) on this input tensor, storing the activations in the cache object as it’s done on the Github notebook.
3. Save the activations tensor using torch.save
4. Similarly, run the finetune on the input tensor, storing & saving the activations
5. Return (save) two tensors of the shape `[N_BATCH, N_CONTEXT, N_ACTIVATIONS]`, one for each model.

In [17]:
base_model.cfg.device # check the device 

'mps'

In [18]:
from tqdm import tqdm

all_acts = []
all_tokens = []  # This will store the tokens for reuse

for k in tqdm(range(TOTAL_BATCHES)):
    # Get a batch of tokens from the dataset
    tokens = activation_store.get_batch_tokens()  # [N_BATCH, N_CONTEXT]
    
    # Store tokens for later reuse
    all_tokens.append(tokens)
    
    # Run the model and store the activations
    _, cache = base_model.run_with_cache(tokens, stop_at_layer=LAYER_NUM + 1, \
                                         names_filter=[SAE_LAYER])  # [N_BATCH, N_CONTEXT, D_MODEL]
    all_acts.append(cache[SAE_LAYER])

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    torch.cuda.empty_cache()

# Concatenate all feature activations into a single tensor
all_acts = torch.cat(all_acts)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT, D_MODEL]

# Concatenate all tokens into a single tensor for reuse
all_tokens = torch.cat(all_tokens)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

100%|██████████| 20/20 [01:16<00:00,  3.81s/it]


In [19]:
torch.save(all_tokens, datapath + f"/tokens_{SAVING_NAME}.pt")
all_tokens.shape, all_acts.shape

(torch.Size([160, 1024]), torch.Size([160, 1024, 2048]))

In [20]:
torch.save(all_acts, datapath + f"/base_acts_{SAVING_NAME}.pt")

# Offload the first model from memory
del base_model
torch.cuda.empty_cache()

In [24]:
# finetune_tokenizer = AutoTokenizer.from_pretrained(FINETUNE_MODEL)
finetune_model_hf = AutoModelForCausalLM.from_pretrained(FINETUNE_MODEL)
finetune_model = HookedSAETransformer.from_pretrained(BASE_MODEL, device=device, hf_model=finetune_model_hf, dtype=torch.float16)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2b into HookedTransformer


In [25]:
# Load the stored tokens from the previous run (from the base model sampling)
# [TOTAL_BATCHES * N_BATCH, N_CONTEXT]

# Initialize an empty list to store activations from the fine-tuned model
all_acts = []

# Split the tokens back into batches and run the fine-tuned model
for k in tqdm(range(TOTAL_BATCHES)):
    # Calculate the start and end indices for the current batch
    start_idx = k * N_BATCHES
    end_idx = (k + 1) * N_BATCHES
    
    # Get the corresponding batch of tokens from all_tokens
    tokens = all_tokens[start_idx:end_idx]  # [N_BATCH, N_CONTEXT]
    # TODO: check if the tokens map to the same text using the finetune model tokenizer, as with the base model tokenizer
    
    # Run the fine-tuned model and store the activations
    _, cache = finetune_model.run_with_cache(tokens, stop_at_layer=LAYER_NUM + 1, \
                                             names_filter=[SAE_LAYER])  # [N_BATCH, N_CONTEXT, D_MODEL]
    all_acts.append(cache[SAE_LAYER])

    # Explicitly free up memory by deleting the cache and emptying the CUDA cache
    del cache
    torch.cuda.empty_cache()

# Concatenate all activations from the fine-tuned model into a single tensor
all_acts = torch.cat(all_acts)  # [TOTAL_BATCHES * N_BATCH, N_CONTEXT, D_MODEL]
all_acts.shape

100%|██████████| 20/20 [01:15<00:00,  3.78s/it]


torch.Size([160, 1024, 2048])

In [27]:
torch.save(all_acts, datapath + f"/finetune_acts_{SAVING_NAME}.pt")