### Load the SAEs from the Hugging Face Hub

In [None]:
import json
import os
import re

from dataclasses import dataclass, field
from pathlib import Path

import nnsight
import sae
import torch
import torch.fx

from datasets import load_dataset
from huggingface_hub import snapshot_download
from nnsight import NNsight, LanguageModel
from sae import Sae
from transformers import AutoTokenizer, AutoModelForCausalLM

from TinySQL import sql_interp_model_location

In [None]:
# Hugging Face repository holding the trained SAEs.
repo_name = "withmartian/sql_interp_saes"

# Sub-folder of the repository to use; change this to work with another model alias.
model_alias = "saes_bm1_cs1"

# Local directory the SAE files are downloaded into and later loaded from.
cache_dir = "working_directory"

In [None]:
@dataclass
class ModelDatasetFullName:
    """Identifiers for a model and a dataset, in short and full form."""

    model_id: str
    dataset_id: str
    full_dataset_name: str
    full_model_name: str

    def load_model(self):
        """Load the causal language model named by ``self.full_model_name``.

        Returns:
            The model returned by ``AutoModelForCausalLM.from_pretrained``.
        """
        # Read the instance field; a bare ``full_model_name`` is an undefined
        # global here and raised NameError.
        return AutoModelForCausalLM.from_pretrained(self.full_model_name)

In [None]:
@dataclass
class SaeOutput:
    """Result of encoding one text with the SAE for a single layer."""

    sae_name: str
    text: str
    tokens: list[str]
    tags_by_index: list[str]
    # (activation, feature_index) pairs; presumably one inner list per token
    # position -- TODO confirm against the SAE's top_acts layout.
    feature_acts_and_indices: list[list[tuple[float, int]]]
    features_indices: list[list[int]]


@dataclass
class GroupedSaeOutput:
    """SAE results for one text across several layers, keyed by layer name."""

    sae_outputs_by_layer: dict[str, SaeOutput]
    text: str
    tokens: list[str]
    # Optional tags for the text; defaults to a fresh empty list so callers
    # may omit it (default_factory avoids a shared mutable default).
    tags: list[str] = field(default_factory=list)

In [None]:
# Download only the files under the selected alias into the local cache
# directory and keep the returned local path as a Path object.
repo_path = Path(
    snapshot_download(
        repo_id=repo_name,
        allow_patterns=f"{model_alias}/*",
        local_dir=cache_dir,
    )
)

In [None]:
def format_example(example):
    """Add Alpaca-style ``prompt`` and ``response`` entries to a dataset row.

    ``prompt`` combines ``english_prompt`` (instruction) and
    ``create_statement`` (context); ``response`` copies ``sql_statement``.
    The row is modified in place and also returned, as ``Dataset.map``
    expects.
    """
    prompt_template = "### Instruction: {} ### Context: {} ### Response: "
    instruction = example['english_prompt']
    context = example['create_statement']
    example['prompt'] = prompt_template.format(instruction, context)
    example['response'] = example['sql_statement']
    return example

In [None]:
class LoadedSAES:
    """A language model, its tokenizer, a dataset and one SAE per layer.

    Instances are normally built with ``load_from_path``, which reads
    ``model_config.json`` and one SAE sub-directory per layer.
    """

    def __init__(self, dataset_name: str, full_model_name: str, model_alias: str,
        tokenizer: AutoTokenizer, language_model: LanguageModel, layers: list[str],
        layer_to_directory: dict, k: str, base_path: str, layer_to_saes: dict):
        """Store the components, then load the dataset and format it.

        Args:
            dataset_name: Name passed to ``datasets.load_dataset``.
            full_model_name: Model name read from ``model_config.json``.
            model_alias: Alias (repository sub-folder) the SAEs came from.
            tokenizer: Tokenizer used to split text into tokens.
            language_model: nnsight ``LanguageModel`` traced for activations.
            layers: Layer names (dotted module paths) that have an SAE.
            layer_to_directory: Maps each layer name to its SAE directory.
            k: The ``k`` value used in the ``k={k}`` directory name.
            base_path: Directory the SAEs were loaded from.
            layer_to_saes: Maps each layer name to its loaded ``Sae``.
        """
        self.dataset_name = dataset_name
        self.full_model_name = full_model_name
        self.model_alias = model_alias
        self.tokenizer = tokenizer
        self.language_model = language_model
        self.layers = layers
        self.layer_to_directory = layer_to_directory
        self.k = k
        self.base_path = base_path
        self.layer_to_saes = layer_to_saes

        # Loads the dataset, then adds prompt/response columns via format_example.
        self.dataset = self.get_dataset()
        self.mapped_dataset = self.dataset.map(format_example)

    @staticmethod
    def get_all_subdirectories(path):
        """Return the paths of all non-hidden sub-directories of ``path``."""
        subdirectories = [
            os.path.join(path, name) for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name)) and not name.startswith(".")
        ]
        return subdirectories

    def nnsight_eval_string_for_layer(self, layer: str):
        """
        Converts "transformer.h.0.mlp" into
        "self.language_model.transformer.h[0].mlp.output.save()".

        Only numeric components surrounded by dots are turned into indices.
        Kept for backward compatibility; activations are now resolved with
        ``_envoy_for_layer`` instead of ``eval``.
        """
        subbed_layer = re.sub(r'\.([0-9]+)\.', r'[\1].', layer)
        return f"self.language_model.{subbed_layer}.output.save()"

    def _envoy_for_layer(self, layer: str):
        """Walk ``self.language_model`` along the dotted module path ``layer``.

        Numeric components are used as indices and all others as attribute
        names. For example, "transformer.h.0.mlp" resolves to
        ``self.language_model.transformer.h[0].mlp``.
        """
        node = self.language_model
        for part in layer.split("."):
            node = node[int(part)] if part.isdigit() else getattr(node, part)
        return node

    def encode_to_activations_for_layer(self, text: str, layer: str):
        """Trace ``text`` through the model and return the output of ``layer``.

        Raises:
            NotImplementedError: if ``self.full_model_name`` does not contain
                "bm1" (previously this surfaced as an UnboundLocalError).

        Returns:
            The saved module output, or its first element when ``len()`` of
            the output is greater than 1 (e.g. a tuple output).
        """
        if "bm1" not in self.full_model_name:
            raise NotImplementedError(
                f"Activation extraction is only implemented for 'bm1' models, "
                f"got {self.full_model_name!r}"
            )
        with self.language_model.trace() as tracer:
            with tracer.invoke(text):
                # Resolve the module by attribute/index lookup instead of
                # eval()-ing a string built from on-disk directory names.
                my_output = self._envoy_for_layer(layer).output.save()
        # NOTE(review): len() > 1 is used to detect tuple outputs; a tensor
        # whose first dimension is 1 is returned unchanged. Confirm this holds
        # for every hooked module.
        if len(my_output) > 1:
            return my_output[0]
        return my_output

    def encode_to_sae_for_layer(self, text: str, layer: str):
        """Encode ``text`` with the SAE for ``layer`` and return a ``SaeOutput``.

        Only row 0 of the SAE's ``top_acts`` / ``top_indices`` is kept.
        ``tags_by_index`` is left empty.
        """
        activations = self.encode_to_activations_for_layer(text, layer)
        print(f'Activations are {activations.shape} at {layer}')

        relevant_sae = self.layer_to_saes[layer]
        sae_acts_and_features = relevant_sae.encode(activations)
        tokens = self.tokenizer.tokenize(text)

        # .cpu() lets this work when the model/SAE run on an accelerator.
        top_acts = sae_acts_and_features.top_acts[0].detach().cpu().tolist()
        top_indices = sae_acts_and_features.top_indices[0].detach().cpu().tolist()

        zipped_acts_and_indices = [
            list(zip(acts_row, indices_row))
            for acts_row, indices_row in zip(top_acts, top_indices)
        ]

        return SaeOutput(
            sae_name=layer,
            text=text,
            tokens=tokens,
            # Required field; tags are not computed here.
            tags_by_index=[],
            feature_acts_and_indices=zipped_acts_and_indices,
            features_indices=top_indices,
        )

    def encode_to_all_saes(self, text: str):
        """Encode ``text`` with every layer's SAE and group the results."""
        sae_outputs_by_layer = {
            layer: self.encode_to_sae_for_layer(text=text, layer=layer)
            for layer in self.layers
        }
        tokens = self.tokenizer.tokenize(text)

        # ``tags`` is not passed: it is not a required GroupedSaeOutput field,
        # and passing it unconditionally raised TypeError.
        return GroupedSaeOutput(
            sae_outputs_by_layer=sae_outputs_by_layer, text=text, tokens=tokens
        )

    @classmethod
    def load_from_path(cls, model_alias: str, k: str):
        """Build an instance from ``{cache_dir}/{model_alias}/k={k}``.

        That directory must contain ``model_config.json`` (with
        ``dataset_name`` and ``model_name`` keys) and one sub-directory per
        layer, each loaded with ``Sae.load_from_disk``. Sub-directory names
        are used as layer names.
        """
        k = str(k)

        base_path = f"{cache_dir}/{model_alias}/k={k}"

        print(f"Loading from path {base_path}")
        subdirectories = cls.get_all_subdirectories(base_path)

        # basename() is separator-aware, unlike split("/").
        layer_to_directory = {
            os.path.basename(directory): directory for directory in subdirectories
        }
        layers = sorted(layer_to_directory)

        with open(f"{base_path}/model_config.json", "r", encoding="utf-8") as f_in:
            model_config = json.load(f_in)

        dataset_name = model_config["dataset_name"]
        full_model_name = model_config["model_name"]
        # Build the model after the config file has been closed.
        language_model = LanguageModel(full_model_name)
        tokenizer = language_model.tokenizer

        layer_to_saes = {
            layer: Sae.load_from_disk(directory)
            for layer, directory in layer_to_directory.items()
        }

        return cls(dataset_name=dataset_name, full_model_name=full_model_name,
                   model_alias=model_alias, layers=layers, layer_to_directory=layer_to_directory,
                   tokenizer=tokenizer, k=k, base_path=base_path,
                   layer_to_saes=layer_to_saes, language_model=language_model)

    def get_dataset(self):
        """Load ``self.dataset_name`` with ``datasets.load_dataset``."""
        return load_dataset(self.dataset_name)

In [None]:
# Load the k=16 SAEs (and the associated model/dataset) for ``model_alias``.
loaded_saes = LoadedSAES.load_from_path(k=16, model_alias=model_alias)

In [None]:
class SaeCollector:
    """
    This class is responsible for collecting a large amount of text,
    assigning tags to each token for a feature name, and also SAE outputs.
    These can then be used for probes and feature analysis.
    (Still to add: ablations.)
    """

    def __init__(self, loaded_saes, restricted_tags=None):
        self.loaded_saes = loaded_saes
        self.restricted_tags = restricted_tags or []
        self.mapped_dataset = loaded_saes.mapped_dataset
        self.tokenizer = self.loaded_saes.tokenizer
        self.supported_token_tags = {'last_token', 'eng_table_name', 'eng_field_name'}