In [None]:
def install_dependencies():
    """Clone and pip-install the `sae` and `TinySQL` repositories from GitHub.

    The body consists of IPython `!` shell escapes, so this function only
    works when the cell is executed by an IPython/Jupyter kernel. Any `sae`
    or `TinySQL` directory already present in the current working directory
    is deleted before cloning.
    """
    # Remove any previous checkouts before cloning again.
    ! rm -rf sae
    ! rm -rf TinySQL
    # Clone the SAE training library and install it from the checkout.
    # `cd` and `pip install` are chained with `&&` on one shell line.
    ! git clone https://github.com/amirabdullah19852020/sae.git
    ! cd sae && pip install .
    # Same procedure for the TinySQL package.
    ! git clone https://github.com/withmartian/TinySQL.git
    ! cd TinySQL && pip install .

install_dependencies()

In [None]:
from dataclasses import dataclass
from functools import partial
import json
import os
import torch
import wandb

from datasets import load_dataset, concatenate_datasets
from huggingface_hub import upload_folder
from tqdm import tqdm

from sae import SaeConfig, SaeTrainer, TrainConfig
from textwrap import dedent
from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer
from TinySQL import sql_interp_model_location
from TinySQL.load_data.load_model import free_memory

# W&B project name passed to `wandb.init` below.
wandb_project_name = "sql_interp_saes"
# Hugging Face Hub repo id that `upload_saes` passes as `repo_id` to `upload_folder`.
sae_repo_name = "withmartian/sql_interp_saes"

# Starts a W&B run as soon as this cell executes.
wandb.init(project=wandb_project_name)

In [None]:
# NOTE(review): `all_cs_num` and `all_model_num` are not referenced by any of the
# cells visible here — confirm whether they are still needed.
all_cs_num = [1, 2, 3]
all_model_num = [1, 2, 3]

# This is essentially the same as "num_epochs" or number of times the SAE sees each data point.
# (The mapped dataset is concatenated with itself this many times before tokenization.)
dataset_multiplication_factor = 7
# Passed to SaeConfig as `expansion_factor`.
expansion_factor = 15
# Every example is padded/truncated to this many tokens by `tokenize_function`.
max_seq_len = 450
# Passed to TrainConfig as `batch_size`.
batch_size = 16
# One ModelDatasetFullName (and therefore one training run) is created per value of k.
all_k = [256]

@dataclass
class ModelDatasetFullName:
    """Identifies one (model, dataset, k) combination to train SAEs for.

    Bundles the numeric ids used to pick hookpoints together with the fully
    qualified Hugging Face names of the model and dataset.
    """

    # Numeric model family id; forwarded to `get_hookpoints`.
    model_id: str
    # Numeric dataset ("cs") id the model/dataset names were built from.
    dataset_id: str
    # Hub name of the dataset, e.g. "withmartian/cs1_dataset_synonyms".
    full_dataset_name: str
    # Hub name of the model in "<org>/<name>" form.
    full_model_name: str
    # Whether the synonym variant of the model/dataset is used.
    syn: bool
    # Top-k sparsity passed to SaeConfig.
    k: int

    def load_model(self):
        """Load and return the causal LM referenced by `full_model_name`."""
        # Use the instance attribute: the bare name `full_model_name` would
        # resolve to whatever global happened to be left in the kernel.
        return AutoModelForCausalLM.from_pretrained(self.full_model_name)

    def get_model_abbrev(self):
        """Return the short run/folder prefix: "saes_<model name>_syn=<syn>".

        The model name is the second "/"-separated component of
        `full_model_name`, so that attribute must contain an "<org>/" prefix.
        """
        model_name_split = self.full_model_name.split("/")[1]
        return f"saes_{model_name_split}_syn={self.syn}"

# 1, 2 and 3 map to TinyStories, Qwen and Llama respectively.
# Set the below list to any of [1,2,3] based on what you need to train.
selected_model_ids = [2]

# Which synonym settings to train for; each value is forwarded as `synonym=`
# when resolving the model name and also selects the "_synonyms" dataset variant.
selected_syn = [True]

# From cs1 through to cs3.
# Need to figure out a dataset for "base" SAE.
dataset_ids = [1, 2, 3]

# Prompt template filled positionally by `map_text` with
# (english_prompt, create_statement, sql_statement).
alpaca_prompt = "### Instruction: {} ### Context: {} ### Response: {}"

In [None]:
def map_text(input_features):
    """Render one dataset row into a single prompt string.

    Reads the 'english_prompt', 'create_statement' and 'sql_statement' fields
    of `input_features`, substitutes them (in that order) into the module-level
    `alpaca_prompt` template, and returns the result as `{"text": <prompt>}`.
    """
    prompt_fields = (
        input_features['english_prompt'],    # instruction
        input_features['create_statement'],  # context
        input_features['sql_statement'],     # response
    )
    return {"text": alpaca_prompt.format(*prompt_fields)}


# Define the tokenization function
def tokenize_function(max_seq_len, tokenizer, example):
    """Tokenize `example["text"]` to a fixed length.

    Calls `tokenizer` on the text with padding to `max_seq_len`, truncation of
    anything longer, and PyTorch tensors requested as the return format.
    Returns whatever the tokenizer call returns. The leading two parameters
    are positioned so they can be pre-bound with `functools.partial`.
    """
    tokenizer_options = dict(
        padding="max_length",
        truncation=True,
        max_length=max_seq_len,
        return_tensors='pt',
    )
    texts = example["text"]
    return tokenizer(texts, **tokenizer_options)

In [None]:
# Build one ModelDatasetFullName per (model, dataset, syn, k) combination.
all_model_dataset_full_names = []

for model_id in selected_model_ids:
    for dataset_id in dataset_ids:
        for syn in selected_syn:
            for k in all_k:
                full_model_name = sql_interp_model_location(model_id, dataset_id, synonym=syn)
                # Synonym runs read from the "_synonyms" variant of the dataset.
                full_dataset_name = (
                    f"withmartian/cs{dataset_id}_dataset_synonyms"
                    if syn
                    else f"withmartian/cs{dataset_id}_dataset"
                )

                model_dataset_full_name = ModelDatasetFullName(
                    model_id=model_id,
                    dataset_id=dataset_id,
                    full_dataset_name=full_dataset_name,
                    full_model_name=full_model_name,
                    syn=syn,
                    k=k,
                )
                all_model_dataset_full_names += [model_dataset_full_name]

In [None]:
all_model_dataset_full_names

In [None]:
def get_hookpoints(model_id: int):
    """Return the list of module names to attach SAEs to for a model family.

    Args:
        model_id: Model family id (anything `int()` accepts). 1 uses
            "transformer.h.*" module patterns; 2 and 3 use "model.layers.<i>"
            module names over a per-model layer range.

    Returns:
        A list of hookpoint strings: all MLP hookpoints followed by all
        attention hookpoints.

    Raises:
        ValueError: if no hookpoints are defined for `model_id`.
    """
    model_id = int(model_id)
    if model_id == 1:
        return ["transformer.h.*.attn", "transformer.h.*.mlp"]

    # Qwen and LLama-3 have same module naming conventions; only the layer
    # range differs between them.
    layer_ranges = {2: range(5, 20), 3: range(7, 18)}
    if model_id not in layer_ranges:
        raise ValueError(f"Hookpoints not added yet for model id {model_id}")

    layers = layer_ranges[model_id]
    # Every branch must return: previously the model_id == 2 branch fell
    # through and handed `None` to TrainConfig.
    return (
        [f"model.layers.{i}.mlp" for i in layers]
        + [f"model.layers.{i}.self_attn" for i in layers]
    )

In [None]:
def get_sae_config(model_dataset_full_name: ModelDatasetFullName, model_abbrev=None):
    """Build the training configuration for one model/dataset combination.

    Args:
        model_dataset_full_name: Supplies `k`, `model_id` and the full model
            and dataset names.
        model_abbrev: Optional run-name prefix; when None it is taken from
            `model_dataset_full_name.get_model_abbrev()`.

    Returns:
        A `(train_config, custom_config)` tuple: the `TrainConfig` for the
        trainer and a plain dict with "model_name", "dataset_name" and
        "model_abbrev".

    Side effect: prints the hookpoints selected for the model.
    """
    mdfn = model_dataset_full_name
    top_k = mdfn.k

    sae_settings = SaeConfig(expansion_factor=expansion_factor, k=top_k)
    hook_names = get_hookpoints(mdfn.model_id)
    print(hook_names)

    abbrev = mdfn.get_model_abbrev() if model_abbrev is None else model_abbrev

    extra_settings = {
        "model_name": mdfn.full_model_name,
        "dataset_name": mdfn.full_dataset_name,
        "model_abbrev": abbrev,
    }
    trainer_settings = TrainConfig(
        sae=sae_settings,
        hookpoints=hook_names,
        run_name=f"{abbrev}/k={top_k}",
        batch_size=batch_size,
        save_every=100000,
    )
    return trainer_settings, extra_settings

In [None]:
all_model_dataset_full_names

In [None]:
# For debugging.
get_sae_config(all_model_dataset_full_names[0])

In [None]:
def get_metrics_for_wandb_run_and_layer(run_name, layer_name):
    """Fetch the most recent logged metrics for one layer of a W&B run.

    Looks up runs in the "sae" W&B path whose name equals `run_name`, picks
    the one with the greatest `created_at`, and takes the last history row of
    every column whose key ends with `layer_name`.

    Args:
        run_name: Exact W&B run name to match.
        layer_name: Hookpoint/layer name; history keys must end with it.

    Returns:
        A dict mapping the part of each key before the first "/" to the last
        logged value as a float. An empty dict is returned (with a printed
        warning) when no run or no matching metric is found.
    """
    api = wandb.Api()
    matching_runs = [run for run in api.runs(path="sae") if run.name == run_name]

    if not matching_runs:
        print(f'Warning! No run found matching {run_name}. Return empty metrics.')
        # Return an empty dict, as the warning says, rather than falling
        # through to an implicit None.
        return {}

    # Get the latest run (by creation time)
    latest_run = max(matching_runs, key=lambda run: run.created_at)
    history = latest_run.history()

    # Keep the last value of every metric logged for this layer, keyed by the
    # metric name (the part of the column key before the first "/").
    latest_metrics = {
        key.split("/")[0]: float(history[key].iloc[-1])
        for key in history.columns if key.endswith(layer_name)
    }

    if not latest_metrics:
        print(f'Warning! No layer metrics found matching {layer_name}. Return empty metrics.')

    return latest_metrics

In [None]:
def upload_saes(train_config, custom_config=None):
    """Write per-layer metrics next to the trained SAEs and upload the folder.

    For every hookpoint in `train_config.hookpoints`, the latest W&B metrics
    are written to `<run_name>/<layer>/metrics.json`. If `custom_config` is
    given it is written to `<run_name>/model_config.json`. The whole
    `<run_name>` folder is then uploaded to the `sae_repo_name` Hub repo under
    the same relative path.

    Args:
        train_config: Object exposing `run_name` (also the local folder) and
            `hookpoints` (iterable of layer names, or None).
        custom_config: Optional JSON-serialisable dict stored alongside the SAEs.
    """
    run_name = train_config.run_name
    # Tolerate a config whose hookpoints were never set.
    layers = train_config.hookpoints or []
    folder_path = train_config.run_name

    for layer in layers:
        layer_dir = f"{folder_path}/{layer}"
        # Make sure the target directory exists before opening the file.
        os.makedirs(layer_dir, exist_ok=True)
        layer_metrics_path = f"{layer_dir}/metrics.json"
        metrics = get_metrics_for_wandb_run_and_layer(run_name, layer)

        with open(layer_metrics_path, "w") as f_out:
            # json.dump returns None, so its result is deliberately not kept.
            json.dump(metrics, f_out)

    if custom_config:
        os.makedirs(folder_path, exist_ok=True)
        with open(f"{folder_path}/model_config.json", "w") as f_out:
            json.dump(custom_config, f_out)

    model_abbrev = folder_path.replace("saes_", "")

    upload_folder(
        folder_path=f"{folder_path}",
        path_in_repo=f"{folder_path}",
        repo_id=sae_repo_name,
        commit_message=f"Uploading saes for {layers} and {model_abbrev}",  # Optional commit message
        repo_type="model",
        ignore_patterns=["*.tmp", "*.log"],  # Optionally exclude specific files
    )

In [None]:
def train_and_upload_sae(model_and_dataset_full_name: ModelDatasetFullName):
    """Train SAEs for one model/dataset combination and upload the result.

    Steps: free memory, load the model (bfloat16, on CUDA) and tokenizer,
    build the prompt dataset from the "train" split, repeat it
    `dataset_multiplication_factor` times, tokenize and shuffle it, run
    `SaeTrainer.fit()`, finish the W&B run, and upload the SAE folder.

    Args:
        model_and_dataset_full_name: Names of the model and dataset to use,
            plus the ids/k needed to build the training config.

    Returns:
        The loaded model.
    """
    mdfn = model_and_dataset_full_name
    full_model_name = mdfn.full_model_name
    # Take the dataset name from the argument. Reading a module-level
    # `full_dataset_name` would silently use whichever value was assigned last
    # in the kernel, regardless of the model being trained.
    full_dataset_name = mdfn.full_dataset_name

    print('Freeing memory')
    free_memory()

    print(f'Loading {full_model_name}')

    model = AutoModelForCausalLM.from_pretrained(
        full_model_name, device_map={"": "cuda"},
        torch_dtype=torch.bfloat16,
    )
    print(model)

    tokenizer = AutoTokenizer.from_pretrained(full_model_name)

    if tokenizer.pad_token_id is None:
        # tokenize_function pads to max length, which needs a pad token;
        # fall back to the EOS token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    model.config.pad_token_id = tokenizer.pad_token_id

    print(f"pad_token is {tokenizer.decode(model.config.pad_token_id)}")

    model_abbrev = mdfn.get_model_abbrev()

    print_message = f"""Working with model {full_model_name} with dataset name
    {mdfn.full_dataset_name} for {model_abbrev}.
    """

    print(dedent(print_message))

    partial_tokenize_function = partial(tokenize_function, max_seq_len, tokenizer)

    dataset = load_dataset(full_dataset_name)['train']
    # Drop every original column so only the rendered "text" field remains.
    mapped_dataset = dataset.map(map_text, remove_columns=dataset.column_names)
    # Repeating the dataset plays the role of multiple epochs.
    copies = [mapped_dataset] * dataset_multiplication_factor
    duplicated_dataset = concatenate_datasets(copies)

    tokenized_dataset = duplicated_dataset.map(partial_tokenize_function, batched=True, num_proc=8)
    tokenized_dataset = tokenized_dataset.shuffle(seed=42)
    print(tokenized_dataset)
    tokenized_dataset.set_format(type="torch")

    train_config, custom_config = get_sae_config(mdfn, model_abbrev)
    print(train_config)
    print(f'Prepping trainer for {model_abbrev} with train_config {train_config}')

    trainer = SaeTrainer(train_config, tokenized_dataset, model=model)
    trainer.fit()

    # Close the W&B run before uploading.
    wandb.finish()
    print("Uploading sae")
    upload_saes(train_config=train_config, custom_config=custom_config)
    return model

In [None]:
# Train and upload SAEs for every configured model/dataset combination, one after another.
for mdfn in all_model_dataset_full_names:
    # NOTE: despite its name, `tokenized` receives the model object that
    # train_and_upload_sae returns.
    tokenized=train_and_upload_sae(mdfn)

Freeing memory
Loading withmartian/sql_interp_bm2_cs1_experiment_4.2
Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwe

VBox(children=(Map (num_proc=8):   0%|          | 0/535500 [00:00<?, ? examples/s],))

Dataset({
    features: ['text', 'input_ids', 'attention_mask'],
    num_rows: 535500
})
None
TrainConfig(sae=SaeConfig(expansion_factor=15, normalize_decoder=True, num_latents=0, k=256, multi_topk=False, skip_connection=False), batch_size=16, grad_acc_steps=1, micro_acc_steps=1, lr=None, lr_warmup_steps=1000, auxk_alpha=0.0, dead_feature_threshold=10000000, hookpoints=None, init_seeds=[0], layers=[], layer_stride=1, transcode=False, distribute_modules=False, save_every=100000, log_to_wandb=True, run_name='saes_sql_interp_bm2_cs1_experiment_4.2_syn=True/k=256', wandb_log_frequency=1)
Prepping trainer for saes_sql_interp_bm2_cs1_experiment_4.2_syn=True with train_config TrainConfig(sae=SaeConfig(expansion_factor=15, normalize_decoder=True, num_latents=0, k=256, multi_topk=False, skip_connection=False), batch_size=16, grad_acc_steps=1, micro_acc_steps=1, lr=None, lr_warmup_steps=1000, auxk_alpha=0.0, dead_feature_threshold=10000000, hookpoints=None, init_seeds=[0], layers=[], layer_strid

Number of SAE parameters: 578_371_584
Number of model parameters: 494_032_768


VBox(children=(Training:   0%|          | 0/33468 [00:00<?, ?it/s],))