In [None]:
def install_dependencies():
    ! pip install datasets
    ! pip install torch==2.4.*
    ! pip install ipywidgets
    ! pip install transformers
    ! pip install wandb
    ! git clone git@github.com:amirabdullah19852020/sae.git | true
    ! cd sae && ls -a && pip install .
    ! cd .. && pip install . --upgrade

install_dependencies()

In [None]:
import json
from dataclasses import dataclass

import torch
import wandb

from datasets import load_dataset, concatenate_datasets
from huggingface_hub import upload_folder
from sae import SaeConfig, SaeTrainer, TrainConfig
from sae.data import chunk_and_tokenize
from textwrap import dedent
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM
from QuantaTextToSql import sql_interp_model_location
from QuantaTextToSql.ablate.load_model import free_memory


# Every cs dataset index and base-model index (not referenced by the cells below).
all_cs_num = list(range(1, 4))
all_model_num = list(range(4))

# How many copies of the training split get concatenated before shuffling, i.e.
# how many times the SAE sees each example -- effectively the epoch count.
dataset_multiplication_factor = 3

# Weights & Biases project for logging, and the Hugging Face repo that the
# trained SAE folders are uploaded to.
wandb_project_name = "sql_interp_saes"
sae_repo_name = "withmartian/sql_interp_saes"

# NOTE(review): one run is opened here for the whole notebook; if SaeTrainer
# also calls wandb.init, confirm which run its metrics land in.
wandb.init(project=wandb_project_name)

In [None]:
@dataclass
class ModelDatasetFullName:
    """Identifies one (base model, cs dataset) pair and their Hugging Face names.

    Attributes:
        model_id: Integer id of the base model; selects the hookpoints and
            appears in the SAE run name as ``bm<model_id>``.
        dataset_id: Integer id of the cs dataset; appears as ``cs<dataset_id>``.
        full_dataset_name: Hugging Face dataset id, e.g. ``withmartian/cs1_dataset``.
        full_model_name: Model location returned by ``sql_interp_model_location``.
    """
    # The ids are ints: they come from `selected_model_ids` / `dataset_ids`
    # and are compared against int literals when choosing hookpoints.
    model_id: int
    dataset_id: int
    full_dataset_name: str
    full_model_name: str

# Tinystories only.
selected_model_ids = [1]

# From cs1 through to cs3.
# TODO: figure out which dataset to use for a "base" SAE.
dataset_ids = [1, 2, 3]

# Prompt template with slots for the instruction and the table context only.
# The "### Response:" section is left open, so any answer text must be appended
# after formatting (extra positional args to str.format are silently ignored).
alpaca_prompt = "### Instruction:\n{}\n### Context:\n{}\n### Response:\n"

In [None]:
# One entry per (model, dataset) pair, models in the outer position. Built with a
# comprehension so the loop variables (model_id, dataset_id, ...) do not leak into
# the notebook's global namespace, where the same names are used as function locals.
all_model_dataset_full_names = [
    ModelDatasetFullName(
        model_id=model_id,
        dataset_id=dataset_id,
        full_dataset_name=f"withmartian/cs{dataset_id}_dataset",
        full_model_name=sql_interp_model_location(model_id, dataset_id),
    )
    for model_id in selected_model_ids
    for dataset_id in dataset_ids
]

In [None]:
# Hookpoint name patterns per base-model id. The `*` presumably stands for the
# transformer block index and is expanded by the SAE trainer — confirm.
_HOOKPOINTS_BY_MODEL_ID = {
    1: ("transformer.h.*.attn", "transformer.h.*.mlp.act"),
}


def get_hookpoints(model_id: int):
    """Return the SAE hookpoint patterns registered for ``model_id``.

    Returns:
        A new list of module-name patterns; callers may mutate it freely.

    Raises:
        ValueError: if no hookpoints are registered for ``model_id``. This is
            an ``Exception`` subclass, so existing ``except Exception`` handlers
            still catch it.
    """
    try:
        return list(_HOOKPOINTS_BY_MODEL_ID[model_id])
    except KeyError:
        raise ValueError(f"Hookpoints not added yet for model id {model_id}") from None

In [None]:
def get_sae_config(model_dataset_full_name: ModelDatasetFullName, model_abbrev,
                   expansion_factor: int = 4, k: int = 32, lr: float = 5e-4):
    """Build the SAE training config for one (model, dataset) pair.

    Args:
        model_dataset_full_name: the pair to train on; its ``model_id`` selects
            the hookpoints and its names are recorded in ``custom_config``.
        model_abbrev: short tag such as ``bm1_cs1``; the run (and therefore the
            local/uploaded folder) is named ``saes_<model_abbrev>``.
        expansion_factor: forwarded to ``SaeConfig``.
        k: forwarded to ``SaeConfig`` (presumably the TopK active-latent count).
        lr: learning rate forwarded to ``TrainConfig``.

    Returns:
        ``(train_config, custom_config)``; ``custom_config`` is the metadata dict
        that gets written to ``model_config.json`` alongside the uploaded SAEs.
    """
    # Pass k explicitly so the configured value actually reaches SaeConfig
    # instead of the library default being used silently.
    sae_config = SaeConfig(expansion_factor=expansion_factor, k=k)
    hookpoints = get_hookpoints(model_dataset_full_name.model_id)

    custom_config = {
        "model_name": model_dataset_full_name.full_model_name,
        "dataset_name": model_dataset_full_name.full_dataset_name,
        "model_abbrev": model_abbrev,
    }

    train_config = TrainConfig(
        sae=sae_config, hookpoints=hookpoints, run_name=f"saes_{model_abbrev}", lr=lr,
    )
    return train_config, custom_config

In [None]:
def map_text(input_features):
    """Render one dataset row as a full prompt + SQL answer for SAE training.

    Args:
        input_features: a row with ``english_prompt``, ``create_statement`` and
            ``sql_statement`` string fields.

    Returns:
        ``{"text": ...}`` holding the Alpaca-style prompt followed by the SQL.
    """
    instruction = input_features['english_prompt']
    context = input_features['create_statement']
    response = input_features['sql_statement']
    # alpaca_prompt only has slots for the instruction and context; passing the
    # response as a third format() argument would be silently ignored and the SQL
    # would never reach the training text, so append it after the open
    # "### Response:" section instead.
    prompt = alpaca_prompt.format(instruction, context)
    return {"text": prompt + response}

In [None]:
def get_dataset(model_dataset_full_name: ModelDatasetFullName, seed=None):
    """Load a cs training split, format it, repeat it and shuffle it.

    Args:
        model_dataset_full_name: supplies the Hugging Face dataset id.
        seed: optional shuffle seed for a reproducible order; ``None`` (the
            default) gives an unseeded shuffle.

    Returns:
        The training split with a ``text`` column produced by ``map_text``,
        repeated ``dataset_multiplication_factor`` times and shuffled.
    """
    dataset = load_dataset(
        model_dataset_full_name.full_dataset_name,
        split="train",
        # NOTE(review): permits running a dataset loading script from the Hub;
        # likely unnecessary for a plain data repo — confirm and drop if so.
        trust_remote_code=True,
    )

    # Format once on the base split, then repeat: the copies are identical, so
    # mapping before concatenation yields the same rows for 1/N of the work.
    mapped_dataset = dataset.map(map_text)
    repeated_dataset = concatenate_datasets(dataset_multiplication_factor * [mapped_dataset])
    return repeated_dataset.shuffle(seed=seed)

In [None]:
def upload_saes(train_config, custom_config=None):
    """Upload the trained SAE folder for one run to ``sae_repo_name``.

    Args:
        train_config: the training config; its ``run_name`` is both the local
            folder holding the SAE files and the path inside the Hub repo, and
            its ``hookpoints`` are listed in the commit message.
        custom_config: optional metadata dict, written to
            ``<run_name>/model_config.json`` before uploading.
    """
    hookpoints = train_config.hookpoints
    folder_path = train_config.run_name
    # run_name is "saes_<model_abbrev>"; strip that prefix for the commit message.
    prefix = "saes_"
    model_abbrev = folder_path[len(prefix):] if folder_path.startswith(prefix) else folder_path

    # The metadata does not depend on the hookpoint, so write it exactly once.
    if custom_config:
        with open(f"{folder_path}/model_config.json", "w") as f_out:
            json.dump(custom_config, f_out)

    upload_folder(
        folder_path=folder_path,
        path_in_repo=folder_path,
        repo_id=sae_repo_name,
        commit_message=f"Uploading saes for {hookpoints} and {model_abbrev}",
        repo_type="model",
        ignore_patterns=["*.tmp", "*.log"],  # skip temp files and logs
    )

In [None]:
def train_and_upload_sae(model_and_dataset_full_name: ModelDatasetFullName):
    """Train SAEs for one (model, dataset) pair and upload them to the Hub.

    Loads ``full_model_name`` in bfloat16 onto CUDA, tokenizes the cs dataset
    with that model's tokenizer, fits the SAEs configured by ``get_sae_config``,
    and uploads the run folder via ``upload_saes``.

    The notebook calls this once per pair in a single kernel, so the model and
    trainer are released in ``finally`` (even if training fails); otherwise
    GPU memory from one pair may still be held when the next model loads.
    """
    import gc

    mdfn = model_and_dataset_full_name
    full_model_name = mdfn.full_model_name
    model_abbrev = f"bm{mdfn.model_id}_cs{mdfn.dataset_id}"

    model = AutoModelForCausalLM.from_pretrained(
        full_model_name, device_map={"": "cuda"},
        torch_dtype=torch.bfloat16,
    )
    trainer = None
    try:
        print(
            f"Working with model {full_model_name} with dataset name "
            f"{mdfn.full_dataset_name} for {model_abbrev}."
        )

        mapped_dataset = get_dataset(mdfn)
        tokenizer = AutoTokenizer.from_pretrained(full_model_name)
        tokenized = chunk_and_tokenize(mapped_dataset, tokenizer)

        train_config, custom_config = get_sae_config(mdfn, model_abbrev)
        print(f'Prepping trainer for {model_abbrev} with train_config {train_config}')

        trainer = SaeTrainer(train_config, tokenized, model)
        trainer.fit()

        upload_saes(train_config=train_config, custom_config=custom_config)
    finally:
        # Drop the last references to the CUDA model and the trainer, collect
        # any reference cycles, then release PyTorch's cached CUDA blocks.
        del trainer, model
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Train and upload one set of SAEs per (model, dataset) pair, then close the
# wandb run opened in the setup cell, even if one of the pairs fails.
try:
    for mdfn in all_model_dataset_full_names:
        train_and_upload_sae(mdfn)
finally:
    wandb.finish()