In [None]:
import json
from dataclasses import dataclass

import torch
import wandb

from datasets import load_dataset, concatenate_datasets
from huggingface_hub import upload_folder
from sae import SaeConfig, SaeTrainer, TrainConfig
from sae.data import chunk_and_tokenize
from textwrap import dedent
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM
from TinySQL import sql_interp_model_location
from TinySQL.ablate.load_model import free_memory


# Candidate cs-dataset numbers and model numbers (not read by any of the cells
# below, which use `dataset_ids` / `selected_model_ids` instead).
all_cs_num = list(range(1, 4))
all_model_num = list(range(0, 4))

# How many times the training split is repeated before shuffling -- effectively
# "num_epochs", i.e. how often the SAE sees each data point.
dataset_multiplication_factor = 3

# Weights & Biases project for this notebook and the Hugging Face repo the
# trained SAEs are pushed to.
wandb_project_name = "sql_interp_saes"
sae_repo_name = "withmartian/sql_interp_saes"

# NOTE(review): trainer metrics are later read back from the W&B project "sae",
# not from this project -- confirm this top-level run is still needed.
wandb.init(project=wandb_project_name)

In [None]:
@dataclass
class ModelDatasetFullName:
    """Identifiers for one (model, dataset) training pair.

    Attributes:
        model_id: numeric model index; annotated ``str`` although the notebook
            passes ints (``get_hookpoints`` converts with ``int()`` either way).
        dataset_id: numeric cs-dataset index (same annotation caveat).
        full_dataset_name: dataset repo name passed to ``load_dataset``.
        full_model_name: model name passed to ``from_pretrained``.
    """
    model_id: str
    dataset_id: str
    full_dataset_name: str
    full_model_name: str

    def load_model(self, **from_pretrained_kwargs):
        """Load this pair's causal LM with ``AutoModelForCausalLM.from_pretrained``.

        Any keyword arguments (e.g. ``device_map``, ``torch_dtype``) are forwarded
        unchanged to ``from_pretrained``.
        """
        # Read the instance field: the bare name ``full_model_name`` resolved to a
        # notebook global (the last value of the setup loop), i.e. the wrong model.
        return AutoModelForCausalLM.from_pretrained(
            self.full_model_name, **from_pretrained_kwargs
        )

# 1, 2 and 3 map to TinyStories, Qwen and Llama respectively.
# Set the below list to any of [1, 2, 3] based on what you need to train.
selected_model_ids = [3, 2, 1]

# Datasets cs1 through cs3.
# TODO: figure out a dataset for the "base" SAE.
dataset_ids = [1, 2, 3]

# Alpaca-style template filled positionally with (instruction, context, response)
# by map_text. The trailing "{}" is required: str.format silently ignores surplus
# positional arguments, so without it the SQL response was dropped from every
# training sample.
alpaca_prompt = "### Instruction:\n{}\n### Context:\n{}\n### Response:\n{}"

In [None]:
# Enumerate every (model, dataset) pair to train: the model location comes from
# sql_interp_model_location and the dataset name follows "withmartian/cs{N}_dataset".
# NOTE: being top-level loop variables, model_id, dataset_id, full_model_name,
# full_dataset_name and model_dataset_full_name stay defined as notebook globals
# afterwards, holding the values from the last iteration.
all_model_dataset_full_names = []

for model_id in selected_model_ids:
    for dataset_id in dataset_ids:
        full_model_name = sql_interp_model_location(model_id, dataset_id)
        full_dataset_name = f"withmartian/cs{dataset_id}_dataset"

        # model_id / dataset_id are ints here even though the dataclass annotates str.
        model_dataset_full_name = ModelDatasetFullName(
            model_id=model_id, dataset_id=dataset_id,
            full_dataset_name=full_dataset_name, full_model_name=full_model_name)

        all_model_dataset_full_names.append(model_dataset_full_name)

In [None]:
def get_hookpoints(model_id: int):
    """Return the module-name patterns to hook for the given model id.

    ``model_id`` may be an int or a numeric string; it is converted with ``int()``.
    Ids 2 and 3 (Qwen and Llama-3) share one module naming convention.

    Raises:
        Exception: if no hookpoints are registered for the id.
    """
    transformer_h_modules = ["transformer.h.*.attn", "transformer.h.*.mlp"]
    model_layers_modules = ["model.layers.*.mlp", "model.layers.*.self_attn"]
    patterns_by_id = {
        1: transformer_h_modules,
        2: model_layers_modules,
        3: model_layers_modules,
    }

    numeric_id = int(model_id)
    if numeric_id not in patterns_by_id:
        raise Exception(f"Hookpoints not added yet for model id {numeric_id}")
    # Hand back a fresh list so callers may mutate it freely.
    return list(patterns_by_id[numeric_id])

In [None]:
def get_sae_config(model_dataset_full_name: ModelDatasetFullName, model_abbrev=None,
                   expansion_factor=4, k=32, lr=5e-4):
    """Build the SAE training config plus upload metadata for one model/dataset pair.

    Args:
        model_dataset_full_name: the pair to train on.
        model_abbrev: short tag used in the run name ``saes_<abbrev>``. When None
            it is derived from the 3rd and 4th "_"-separated parts of the model name.
        expansion_factor: passed to ``SaeConfig(expansion_factor=...)``.
        k: passed to ``SaeConfig(k=...)`` (top-k sparsity setting).
        lr: learning rate passed to ``TrainConfig``.

    Returns:
        ``(train_config, custom_config)`` where ``custom_config`` is a plain dict
        with ``model_name``, ``dataset_name`` and ``model_abbrev``.

    Raises:
        ValueError: if ``model_abbrev`` is None and the model name has fewer than
            four "_"-separated parts.
    """
    # `k` used to be defined but never passed, so SaeConfig silently used its own
    # default instead of the value chosen here.
    sae_config = SaeConfig(expansion_factor=expansion_factor, k=k)
    hookpoints = get_hookpoints(model_dataset_full_name.model_id)
    print(hookpoints)

    dataset_name = model_dataset_full_name.full_dataset_name
    model_name = model_dataset_full_name.full_model_name

    if model_abbrev is None:
        model_name_splits = model_name.split("_")
        if len(model_name_splits) < 4:
            raise ValueError(
                f"Cannot derive model_abbrev from {model_name!r}; pass model_abbrev explicitly."
            )
        model_abbrev = f"{model_name_splits[2]}_{model_name_splits[3]}"

    run_name = f"saes_{model_abbrev}"

    custom_config = {
        "model_name": model_name, "dataset_name": dataset_name,
        "model_abbrev": model_abbrev,
    }

    train_config = TrainConfig(
        sae=sae_config, hookpoints=hookpoints, run_name=run_name, lr=lr,
    )
    return train_config, custom_config

In [None]:
# Display the queued (model, dataset) pairs.
all_model_dataset_full_names

In [None]:
# For debugging: build (and display) the configs for the first queued pair;
# get_sae_config also prints the hookpoint patterns it selected.
get_sae_config(all_model_dataset_full_names[0])

In [None]:
def map_text(input_features):
    """Render one dataset row into the ``alpaca_prompt`` text.

    Reads ``english_prompt``, ``create_statement`` and ``sql_statement`` (in that
    order) and fills them positionally into ``alpaca_prompt``; any argument the
    template has no placeholder for is ignored by ``str.format``.

    Returns:
        ``{"text": <rendered prompt>}``, suitable for ``Dataset.map``.
    """
    fields = (
        input_features['english_prompt'],
        input_features['create_statement'],
        input_features['sql_statement'],
    )
    return {"text": alpaca_prompt.format(*fields)}

In [None]:
def get_dataset(model_dataset_full_name: ModelDatasetFullName, seed=None):
    """Load the pair's train split, repeat it, shuffle it and add a ``text`` column.

    Args:
        model_dataset_full_name: supplies the dataset name to load.
        seed: optional shuffle seed. ``None`` (the default) keeps the previous
            non-reproducible shuffle; pass an int for a reproducible order.

    Returns:
        The shuffled, repeated dataset mapped through ``map_text``.
    """
    dataset = load_dataset(
        model_dataset_full_name.full_dataset_name,
        split="train",
        trust_remote_code=True,
    )

    # Repeating the split acts like extra epochs (see dataset_multiplication_factor).
    repeated = concatenate_datasets(dataset_multiplication_factor * [dataset])
    shuffled = repeated.shuffle(seed=seed)

    # The tokenizer previously loaded here was never used; tokenization happens
    # in the caller, so the extra from_pretrained download is dropped.
    return shuffled.map(map_text)

In [None]:
def _latest_layer_metrics(history, layer_name):
    """Map metric name -> last non-null value for history columns ``<metric>/<layer_name>``."""
    suffix = f"/{layer_name}"
    metrics = {}
    for column in history.columns:
        # Anchor on "/" so e.g. "layers.1" does not match "model.layers.1".
        if not column.endswith(suffix):
            continue
        # Sparse metrics leave NaN in later rows; take the last value actually logged.
        logged = history[column].dropna()
        if logged.empty:
            continue
        metrics[column.split("/")[0]] = float(logged.iloc[-1])
    return metrics


def get_metrics_for_wandb_run_and_layer(run_name, layer_name):
    """Return the latest logged metrics for one hookpoint of a W&B run.

    Looks up runs named ``run_name`` in the W&B project "sae", takes the most
    recently created one, and collects every history column of the form
    ``<metric>/<layer_name>``.

    Returns:
        ``{metric_name: float}``; an empty dict (with a printed warning) when no
        run or no matching metric is found.
    """
    api = wandb.Api()
    matching_runs = [run for run in api.runs(path="sae") if run.name == run_name]

    if not matching_runs:
        print(f'Warning! No run found matching {run_name}. Return empty metrics.')
        return {}

    # Several runs can share a display name; use the newest one.
    latest_run = max(matching_runs, key=lambda run: run.created_at)
    latest_metrics = _latest_layer_metrics(latest_run.history(), layer_name)

    if not latest_metrics:
        print(f'Warning! No layer metrics found matching {layer_name}. Return empty metrics.')

    return latest_metrics

In [None]:
def upload_saes(train_config, custom_config=None):
    """Write per-hookpoint metrics (and optional metadata) and push the run folder to the Hub.

    For each hookpoint in ``train_config.hookpoints`` this writes
    ``<run_name>/<hookpoint>/metrics.json`` from W&B. If ``custom_config`` is
    truthy it writes ``<run_name>/model_config.json``. Finally it uploads the
    ``<run_name>`` folder to ``sae_repo_name``.

    Assumes the trainer already created ``<run_name>/<hookpoint>/`` folders --
    TODO confirm against SaeTrainer's checkpoint layout.
    """
    run_name = train_config.run_name
    layers = train_config.hookpoints
    folder_path = train_config.run_name

    for layer in layers:
        layer_metrics_path = f"{folder_path}/{layer}/metrics.json"
        # Fall back to {} so a missing run is stored as an empty object, not `null`.
        metrics = get_metrics_for_wandb_run_and_layer(run_name, layer) or {}

        with open(layer_metrics_path, "w") as f_out:
            json.dump(metrics, f_out)

    if custom_config:
        with open(f"{folder_path}/model_config.json", "w") as f_out:
            json.dump(custom_config, f_out)

    # Strip only a leading "saes_" (str.replace would also hit later occurrences).
    prefix = "saes_"
    model_abbrev = folder_path[len(prefix):] if folder_path.startswith(prefix) else folder_path

    upload_folder(
        folder_path=folder_path,
        path_in_repo=folder_path,
        repo_id=sae_repo_name,
        commit_message=f"Uploading saes for {layers} and {model_abbrev}",
        repo_type="model",
        ignore_patterns=["*.tmp", "*.log"],  # Exclude scratch/log files from the upload.
    )

In [None]:
def train_and_upload_sae(model_and_dataset_full_name: ModelDatasetFullName):
    """Train SAEs for one (model, dataset) pair and upload them.

    Steps: free memory, load the model in bfloat16 on CUDA, build and tokenize the
    training set, fit the SAEs, close the active W&B run, then upload the results
    via ``upload_saes``.
    """
    mdfn = model_and_dataset_full_name
    full_model_name = mdfn.full_model_name

    print('Freeing memory')
    free_memory()

    print(f'Loading {full_model_name}')
    model = AutoModelForCausalLM.from_pretrained(
        full_model_name, device_map={"": "cuda"},
        torch_dtype=torch.bfloat16,
    )

    model_abbrev = f"bm{mdfn.model_id}_cs{mdfn.dataset_id}"

    # A single f-string: the previous triple-quoted message was passed to dedent,
    # which was a no-op (its first line had no indent), so it printed stray spaces.
    print(
        f"Working with model {full_model_name} with dataset name "
        f"{mdfn.full_dataset_name} for {model_abbrev}."
    )

    mapped_dataset = get_dataset(mdfn)
    tokenizer = AutoTokenizer.from_pretrained(full_model_name)
    tokenized = chunk_and_tokenize(mapped_dataset, tokenizer)

    train_config, custom_config = get_sae_config(mdfn, model_abbrev)

    print(f'Prepping trainer for {model_abbrev} with train_config {train_config}')

    trainer = SaeTrainer(train_config, tokenized, model)
    trainer.fit()

    wandb.finish()

    # Drop the model/trainer references before the upload so their memory can be
    # reclaimed (free_memory() runs again at the start of the next call).
    del trainer, model

    upload_saes(train_config=train_config, custom_config=custom_config)

In [None]:
# Train and upload SAEs for every queued (model, dataset) pair, one after another.
for mdfn in all_model_dataset_full_names:
    train_and_upload_sae(model_and_dataset_full_name=mdfn)