### Load the SAEs from the Hugging Face Hub.

In [None]:
# Uninstall the `triton` package from the current environment (`-y` skips the confirmation prompt).
# NOTE(review): the reason for removing it is not stated in this notebook — confirm it is still needed.
! pip uninstall triton -y

In [None]:
import json
import os
import psutil
import re

from dataclasses import dataclass
from pathlib import Path

import nnsight
import numpy as np
import sae
import torch
import torch.fx

from datasets import load_dataset
from huggingface_hub import snapshot_download
from nnsight import NNsight, LanguageModel
from sae import Sae

from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from TinySQL import sql_interp_model_location
from TinySQL.training_data.fragments import field_names, table_names

In [None]:
def process_info():
    """Print the memory usage of the current process in megabytes.

    Two lines are printed: the resident set size (RSS) and the virtual
    memory size (VMS), both as reported by ``psutil`` for this process id.
    """
    bytes_per_mb = 1024 ** 2
    usage = psutil.Process(os.getpid()).memory_info()

    rss_mb = usage.rss / bytes_per_mb
    vms_mb = usage.vms / bytes_per_mb

    print(f"RSS: {rss_mb:.2f} MB")
    print(f"VMS: {vms_mb:.2f} MB")

process_info()

In [None]:
# List the attributes exposed by `table_names` to see what it provides.
dir(table_names)

In [None]:
# NOTE(review): `all_table_names` is not referenced again in the visible part of this notebook.
all_table_names = table_names.get_TableInfo()
# List the attributes exposed by `field_names` to see what it provides.
dir(field_names)

In [None]:
# Hugging Face Hub repository passed to `snapshot_download` later in this notebook.
repo_name = "withmartian/sql_interp_saes"

# Change this to work with another model alias.
model_alias = "saes_bm1_cs1"
# Local directory the snapshot is downloaded into; `LoadedSAES.load_from_path` also reads from it.
cache_dir = "working_directory"
# Seed passed to the dataset shuffle in `SaeCollector`.
seed = 42

# Print the process memory usage at this point of the notebook.
process_info()

In [None]:
@dataclass
class ModelDatasetFullName:
    """Identifiers of a model/dataset pair together with their full names."""
    model_id: str
    dataset_id: str
    full_dataset_name: str
    full_model_name: str

    def load_model(self):
        """Load the causal language model named by ``self.full_model_name``.

        Returns:
            Whatever ``AutoModelForCausalLM.from_pretrained`` returns for
            that name.
        """
        # The name must be read from the instance: a bare `full_model_name`
        # is not defined in this scope and would raise NameError.
        return AutoModelForCausalLM.from_pretrained(self.full_model_name)

In [None]:
@dataclass
class SaeOutput:
    """Output of one SAE for a single piece of text.

    The activation fields hold plain Python lists (the caller converts
    tensors with ``.tolist()``), so instances carry no tensors themselves
    apart from the ``sae`` module.
    """
    sae_name: str
    sae: Sae
    text: str
    tokens: list[str]
    # Nested lists, one inner list per row of the activation matrix
    # (presumably one row per token — confirm against the caller).
    raw_acts: list[list[float]]
    top_acts: list[list[float]]
    top_indices: list[list[int]]

    def zero_out_except_top_n(self, scores, indices, n):
        """
        Zero out all but the top n scores in the scores vector, preserving the order.

        Args:
            scores: sequence of numeric scores.
            indices: sequence of the same length as ``scores``; it is
                returned unchanged.
            n: number of largest scores to keep. Values ``<= 0`` zero out
                every score.

        Returns:
            A tuple ``(filtered_scores, indices)`` where ``filtered_scores``
            is a new list with every score outside the top ``n`` replaced
            by ``0``.

        Raises:
            ValueError: if ``scores`` and ``indices`` differ in length.
        """
        if len(scores) != len(indices):
            raise ValueError("Scores and indices lists must have the same length.")

        if n <= 0:
            return [0] * len(scores), indices

        # Rank positions by descending score. The sort is stable, so among
        # tied scores the earliest positions are kept.
        ranked_positions = sorted(range(len(scores)), key=lambda pos: -scores[pos])
        # A set keeps the membership test below O(1) per element.
        kept_positions = set(ranked_positions[:n])

        filtered_scores = [
            score if pos in kept_positions else 0
            for pos, score in enumerate(scores)
        ]
        return filtered_scores, indices

    def zero_out_except_top_n_for_multiple(self, scores_list, indices_list, n):
        """Apply :meth:`zero_out_except_top_n` to every (scores, indices) row.

        Args:
            scores_list: sequence of score rows.
            indices_list: sequence of index rows, same length as ``scores_list``.
            n: number of largest scores to keep in each row.

        Returns:
            A tuple ``(filtered_scores_list, filtered_indices_list)`` of new
            outer lists, one entry per input row.

        Raises:
            ValueError: if the two outer sequences, or any pair of rows,
                differ in length.
        """
        if len(scores_list) != len(indices_list):
            raise ValueError("Scores and indices lists must have the same length.")

        filtered_scores_list = []
        filtered_indices_list = []

        for scores, indices in zip(scores_list, indices_list):
            filtered_scores, filtered_indices = self.zero_out_except_top_n(scores, indices, n)
            filtered_scores_list.append(filtered_scores)
            filtered_indices_list.append(filtered_indices)

        return filtered_scores_list, filtered_indices_list

    def reconstruction_error(self, k=128):
        """Return the relative reconstruction error using the top ``k`` activations.

        The error is ``norm(decoded - raw) / norm(raw)``, returned as a
        Python float.

        NOTE: the tensors are moved with ``.cuda()``, so a CUDA device is
        required.
        """
        decoded_activations = self.decode_to_activations(k)
        raw_acts = torch.tensor(self.raw_acts).cuda()

        difference = decoded_activations - raw_acts

        relative_error = torch.norm(difference) / torch.norm(raw_acts)
        return relative_error.item()

    def decode_to_activations(self, k=128):
        """Decode the stored SAE activations, keeping only the top ``k`` per row.

        Returns:
            The value returned by ``self.sae.decode`` for the filtered
            activations and their indices.

        NOTE: the inputs are moved with ``.cuda()``, so a CUDA device is
        required.
        """
        # The filtering helpers build new lists and never mutate their
        # inputs, so the stored lists can be passed directly.
        filtered_acts, top_k_indices = self.zero_out_except_top_n_for_multiple(
            self.top_acts, self.top_indices, n=k
        )
        return self.sae.decode(
            top_acts=torch.tensor(filtered_acts).cuda(),
            top_indices=torch.tensor(top_k_indices).cuda(),
        )

@dataclass
class GroupedSaeOutput:
    """SAE outputs of every layer for one piece of text, plus per-token tags."""
    sae_outputs_by_layer: dict[str, SaeOutput]
    text: str
    tokens: list[str]
    tags: list[str]
    tags_by_index: list[str]

    def apply_tags(self, table_vocab=None, field_vocab=None):
        """Recompute ``self.tags`` with one tag per entry of ``self.tokens``.

        A token found in the table vocabulary is tagged ``"TABLE"``;
        otherwise a token found in the field vocabulary is tagged
        ``"FIELD"``; every other token is tagged ``"NONE"``.

        Args:
            table_vocab: container of table-name tokens. Defaults to the
                notebook-level ``table_names``.
            field_vocab: container of field-name tokens. Defaults to the
                notebook-level ``field_names``.
        """
        # NOTE(review): the notebook-level `table_names` / `field_names` are
        # imported from `TinySQL.training_data.fragments`. If they are
        # modules rather than containers, the `in` tests below raise
        # TypeError — pass explicit collections in that case.
        if table_vocab is None:
            table_vocab = table_names
        if field_vocab is None:
            field_vocab = field_names

        self.tags = [
            "TABLE" if token in table_vocab
            else "FIELD" if token in field_vocab
            else "NONE"
            for token in self.tokens
        ]

    def get_reconstruction_error_by_layer(self, layer, k):
        """Return the reconstruction error of one layer's SAE output.

        Args:
            layer: key into ``self.sae_outputs_by_layer``.
            k: number of top activations passed on to
                ``reconstruction_error``.

        Raises:
            KeyError: if ``layer`` is not in ``self.sae_outputs_by_layer``.
        """
        return self.sae_outputs_by_layer[layer].reconstruction_error(k=k)
        

In [None]:
# Download only the files matching `<model_alias>/*` from the Hub repository
# into `cache_dir`, and keep the path returned by `snapshot_download` as a `Path`.
repo_path = Path(
    snapshot_download(repo_name, allow_patterns=f"{model_alias}/*", local_dir=cache_dir)
)

In [None]:
def format_example(example):
    """Add ``prompt`` and ``response`` entries to a dataset row.

    The prompt is built from the row's ``english_prompt`` and
    ``create_statement`` values; the response is the row's
    ``sql_statement``. The row is modified in place and also returned.
    """
    instruction = example['english_prompt']
    schema = example['create_statement']

    template = "### Instruction: {} ### Context: {} ### Response: "
    example['prompt'] = template.format(instruction, schema)
    example['response'] = example['sql_statement']
    return example

In [None]:
class LoadedSAES:
    """A language model, its tokenizer, its dataset and one SAE per layer.

    Instances are normally built with :meth:`load_from_path`. The class
    offers helpers to run text through the model under an nnsight trace and
    to encode the captured activations with the per-layer SAEs.
    """

    def __init__(self, dataset_name: str, full_model_name: str, model_alias: str,
        tokenizer: AutoTokenizer, language_model: LanguageModel, layers: list[str],
        layer_to_directory: dict, k: str, base_path: str, layer_to_saes: dict):
        """Store the components, then load and format the dataset.

        Args:
            dataset_name: name passed to ``load_dataset``.
            full_model_name: full name of the language model.
            model_alias: alias of the SAE collection.
            tokenizer: tokenizer used to split text into tokens.
            language_model: nnsight ``LanguageModel`` used for tracing.
            layers: layer names to encode, in iteration order.
            layer_to_directory: layer name -> directory of its SAE.
            k: the ``k`` value of the loaded SAE directory.
            base_path: directory the SAEs were loaded from.
            layer_to_saes: layer name -> loaded SAE.
        """
        self.dataset_name = dataset_name
        self.full_model_name = full_model_name
        self.model_alias = model_alias
        self.tokenizer = tokenizer
        self.language_model = language_model
        self.layers = layers
        self.layer_to_directory = layer_to_directory
        self.k = k
        self.base_path = base_path
        self.layer_to_saes = layer_to_saes

        # Loading happens eagerly: the dataset is fetched and every row is
        # passed through `format_example` at construction time.
        self.dataset = self.get_dataset()
        self.mapped_dataset = self.dataset.map(format_example)

    @staticmethod
    def get_all_subdirectories(path):
        """Return the joined paths of all non-hidden directories directly under ``path``.

        Entries whose name starts with ``"."`` are skipped. The order is
        that of ``os.listdir``.
        """
        subdirectories = [
            os.path.join(path, name) for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name)) and not name.startswith(".")
        ]
        return subdirectories

    def nnsight_eval_string_for_layer(self, layer: str):
        """
        Converts transformer.h[0].mlp into self.language_model.transformer.h[0].mlp.output.save() for nnsight

        Numeric path components written as ``.N.`` are rewritten to ``[N].``
        before the expression is assembled.
        """
        subbed_layer = re.sub(r'\.([0-9]+)\.', r'[\1].', layer)
        return f"self.language_model.{subbed_layer}.output.save()"

    def encode_to_activations_for_layer(self, text: str, layer: str):
        """Run ``text`` through the model and return the saved output of ``layer``.

        Returns:
            The first element of the saved output when its length is
            greater than one, otherwise the saved output itself
            (presumably tuple outputs versus a single tensor with batch
            size 1 — TODO confirm).

        Raises:
            NotImplementedError: if ``"bm1"`` is not part of
                ``self.full_model_name``; capture is only implemented for
                those models.
        """
        if "bm1" not in self.full_model_name:
            # Previously this path fell through to an unbound local variable.
            raise NotImplementedError(
                "Activation capture is only implemented for 'bm1' models, "
                f"got {self.full_model_name!r}."
            )

        with self.language_model.trace() as tracer:
            with tracer.invoke(text):
                eval_string = self.nnsight_eval_string_for_layer(layer)
                # NOTE(security): `eval` runs an expression built from the
                # layer name, which comes from directory names on disk.
                # Only use this with trusted SAE checkouts.
                my_output = eval(eval_string)

        if len(my_output) > 1:
            return my_output[0]
        return my_output

    def encode_to_sae_for_layer(self, text: str, layer: str):
        """Encode ``text`` with the SAE of ``layer`` and wrap the result.

        The layer activations are moved to CUDA, encoded with
        ``self.layer_to_saes[layer]``, and the first element of the raw
        activations, top activations and top indices is converted to nested
        Python lists.

        Returns:
            A ``SaeOutput`` for this layer and text.
        """
        activations = self.encode_to_activations_for_layer(text, layer).cuda()
        raw_acts = activations[0].cpu().detach().numpy().tolist()
        relevant_sae = self.layer_to_saes[layer]
        sae_acts_and_features = relevant_sae.encode(activations)
        tokens = self.tokenizer.tokenize(text)

        top_acts = sae_acts_and_features.top_acts[0].cpu().detach().numpy().tolist()
        top_indices = sae_acts_and_features.top_indices[0].cpu().detach().numpy().tolist()

        sae_output = SaeOutput(
            sae_name=layer, text=text, tokens=tokens, top_acts=top_acts, top_indices=top_indices, raw_acts=raw_acts, sae=relevant_sae
        )

        return sae_output

    def encode_to_all_saes(self, text: str):
        """Encode ``text`` with the SAE of every layer in ``self.layers``.

        Returns:
            A ``GroupedSaeOutput`` holding the per-layer outputs, the
            tokens of ``text`` and the tags computed by ``apply_tags``.
        """
        sae_outputs_by_layer = {layer: self.encode_to_sae_for_layer(text=text, layer=layer) for layer in self.layers}
        tokens = self.tokenizer.tokenize(text)
        result = GroupedSaeOutput(sae_outputs_by_layer=sae_outputs_by_layer, text=text, tokens=tokens, tags=[], tags_by_index=[])
        result.apply_tags()
        # The grouped output must be returned; without it callers get None.
        return result

    @staticmethod
    def load_from_path(model_alias: str, k: str):
        """Build a ``LoadedSAES`` from ``{cache_dir}/{model_alias}/k={k}``.

        Args:
            model_alias: alias directory under the notebook-level ``cache_dir``.
            k: the ``k`` value of the directory to load; it is converted
                with ``str()``, so an int works too.

        Returns:
            A ``LoadedSAES`` whose model and dataset names come from
            ``model_config.json`` in that directory.
        """
        k = str(k)

        base_path = f"{cache_dir}/{model_alias}/k={k}"

        print(f"Loading from path {base_path}")
        subdirectories = LoadedSAES.get_all_subdirectories(base_path)

        # `basename` instead of splitting on "/" so the layer name is also
        # correct where `os.path.join` uses a different separator.
        # Only layers whose name contains "1" are kept.
        layer_to_directory = {
            os.path.basename(directory): directory
            for directory in subdirectories
            if "1" in os.path.basename(directory)
        }
        layers = sorted(layer_to_directory)

        with open(f"{base_path}/model_config.json", "r") as f_in:
            model_config = json.load(f_in)

        dataset_name = model_config["dataset_name"]
        full_model_name = model_config["model_name"]
        language_model = LanguageModel(full_model_name)
        tokenizer = language_model.tokenizer

        layer_to_saes = {layer: Sae.load_from_disk(directory).cuda() for layer, directory in layer_to_directory.items()}

        return LoadedSAES(dataset_name=dataset_name, full_model_name=full_model_name,
                          model_alias=model_alias, layers=layers, layer_to_directory=layer_to_directory,
                          tokenizer=tokenizer, k=k, base_path=base_path,
                          layer_to_saes=layer_to_saes, language_model=language_model)

    def get_dataset(self):
        """Load and return the dataset named by ``self.dataset_name``."""
        return load_dataset(self.dataset_name)

In [None]:
# Load the model, tokenizer, dataset and SAEs for `model_alias` from its `k=128` directory.
loaded_saes = LoadedSAES.load_from_path(model_alias=model_alias, k=128)

In [None]:
class SaeCollector:
    """
    This class is responsible for collecting a large amount of text,
    assigning tags to each token for a feature name, and also SAE outputs.
    These can then be used for probes and feature analysis.
    (Still to add: ablations.)
    """

    def __init__(self, loaded_saes, sample_size=2, restricted_tags=None):
        """Shuffle the dataset and encode ``sample_size`` examples.

        Args:
            loaded_saes: object providing ``mapped_dataset``, ``tokenizer``,
                ``layers`` and ``encode_to_all_saes``.
            sample_size: number of training rows to encode.
            restricted_tags: optional list of tags; stored as given, or as
                an empty list when omitted.
        """
        self.loaded_saes = loaded_saes
        self.restricted_tags = restricted_tags or []
        self.sample_size = sample_size
        # `shuffle` returns a new dataset instead of shuffling in place, so
        # the result has to be kept; otherwise the "random" subset is always
        # the first rows. The seed is the notebook-level `seed`.
        self.mapped_dataset = loaded_saes.mapped_dataset.shuffle(seed=seed)
        self.tokenizer = self.loaded_saes.tokenizer
        self.layers = self.loaded_saes.layers
        self.encoded_set = self.create_and_load_random_subset(sample_size=self.sample_size)

    def get_prompt_and_encoding_for_text(self, feature):
        """Encode the prompt of one dataset row with all SAEs.

        Args:
            feature: mapping with ``"prompt"`` and ``"response"`` entries.

        Returns:
            A dict with the keys ``"prompt"``, ``"response"`` and
            ``"encoding"``; the encoding is the result of
            ``encode_to_all_saes`` for the prompt.
        """
        prompt = feature["prompt"]
        return {
            "prompt": prompt,
            "response": feature["response"],
            "encoding": self.loaded_saes.encode_to_all_saes(prompt),
        }

    def get_avg_reconstruction_error_for_all_k_and_layers(self):
        """Return ``{layer: {k: average reconstruction error}}`` for every layer."""
        return {
            layer: self.get_avg_reconstruction_error_for_all_k(layer)
            for layer in self.layers
        }

    def get_avg_reconstruction_error_for_all_k(self, layer, min_range=0, max_range=128):
        """Average the reconstruction error of ``layer`` over the encoded set.

        Args:
            layer: layer name to evaluate.
            min_range: first ``k`` to evaluate (inclusive).
            max_range: end of the ``k`` range (exclusive, as in ``range``).

        Returns:
            A dict mapping each ``k`` to the ``np.average`` of the
            reconstruction errors of all encoded examples.
        """
        errors_by_k = {}
        for element in tqdm(self.encoded_set):
            encoding = element["encoding"]
            for k in range(min_range, max_range):
                recon_error = encoding.sae_outputs_by_layer[layer].reconstruction_error(k)
                errors_by_k.setdefault(k, []).append(recon_error)

        return {k: np.average(error_list) for k, error_list in errors_by_k.items()}

    def create_and_load_random_subset(self, sample_size: int):
        """Encode the first ``sample_size`` rows of the shuffled train split.

        Returns:
            A list with one dict per row, as produced by
            ``get_prompt_and_encoding_for_text``.
        """
        sampled_set = self.mapped_dataset['train'].select(range(sample_size))
        return [
            self.get_prompt_and_encoding_for_text(element)
            for element in tqdm(sampled_set)
        ]

In [None]:
# Build the collector with its default sample size; the sampled examples are encoded inside `__init__`.
sae_collector = SaeCollector(loaded_saes)

### Maximally activating datasets

In [None]:
# Look at the `tags` of the first encoded example's `encoding` entry.
element = sae_collector.encoded_set[0]["encoding"]
element.tags

### Monitor reconstruction errors.



In [None]:
# Average reconstruction error for every layer and every k, computed over the collector's encoded examples.
reconstruction_error_by_k_and_layer = sae_collector.get_avg_reconstruction_error_for_all_k_and_layers()

In [None]:
# reconstruction_error_by_k_and_layer

### Monitor ablation errors.

### Add taggers

In [None]:
from TinySQL.training_data.fragments import field_names, table_names

In [None]:
import TinySQL
# NOTE(review): elsewhere in this notebook the fragments package is imported as
# `TinySQL.training_data.fragments` — confirm that `TinySQL.fragments` resolves.
help(TinySQL.fragments)

In [None]:
# `encoded_set` is an attribute of the collector; no notebook-level variable of that name is defined.
first_element = sae_collector.encoded_set[0]

In [None]:
# first_element