In [None]:
def install_dependencies():
    """Clone and install the TinySQL and sae repositories plus extra pip packages.

    Any existing ``TinySQL`` / ``sae`` checkout in the working directory is
    deleted first, so re-running this cell always re-clones from scratch.

    Packages are installed with ``sys.executable -m pip`` so they land in the
    environment of the running kernel (a bare ``pip`` on PATH may belong to a
    different interpreter). Each command's output is echoed into the cell, and
    a non-zero exit status raises ``subprocess.CalledProcessError`` instead of
    being ignored and resurfacing later as an ``ImportError``.
    """
    import shutil
    import subprocess
    import sys

    def _run(cmd, cwd=None):
        # Capture then echo output so it shows in the cell regardless of how the
        # kernel forwards subprocess stdout, and fail loudly on error.
        result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
        print(result.stdout, end="")
        print(result.stderr, end="", file=sys.stderr)
        result.check_returncode()

    pip_install = [sys.executable, "-m", "pip", "install"]

    # (checkout directory, extra `git clone` arguments) for each source install.
    repos = [
        ("TinySQL", ["-b", "amir/check_errors", "https://github.com/withmartian/TinySQL.git"]),
        ("sae", ["https://github.com/amirabdullah19852020/sae.git"]),
    ]
    for directory, clone_args in repos:
        # Equivalent of `rm -rf <directory>`: no error if it does not exist.
        shutil.rmtree(directory, ignore_errors=True)
        _run(["git", "clone", *clone_args, directory])
        _run([*pip_install, "."], cwd=directory)

    _run([*pip_install, "sqlparse", "scikit-learn"])


install_dependencies()

In [None]:
from pathlib import Path

from datasets import Dataset, concatenate_datasets
from huggingface_hub import snapshot_download

from sae.sae_interp import GroupedSaeOutput, SaeOutput, SaeCollector, LoadedSAES

from TinySQL.training_data.data_analyzer import get_errors
from TinySQL.classifiers.t5_classifier import train_t5_classifier
from TinySQL import sql_interp_model_location

import wandb

In [None]:
# Configuration — change these values to work with another model / SAE set.
repo_name = "withmartian/sql_interp_saes"  # Hugging Face repo holding the SAE snapshots
cache_dir = "working_directory"            # local directory the SAE snapshot is downloaded to

model_num = 1
cs_num = 3
syn = True
k = 256        # passed as `k` when loading the SAEs
seed = 42      # random seed passed to downstream sampling (e.g. SaeCollector)

full_model_name = sql_interp_model_location(model_num=model_num, cs_num=cs_num, synonym=syn)
# "org/name" -> "saes_name_syn=<syn>"; names the per-model SAE folder in repo_name.
model_alias = f"saes_{full_model_name.split('/')[1]}_syn={syn}"
print(model_alias)

In [None]:
# Fetch the model's predictions grouped into a mapping with 'errors' and
# 'correct_predictions' entries (unpacked in the next cell).
# NOTE(review): fast=False presumably disables a quicker/subsampled mode —
# confirm in TinySQL.training_data.data_analyzer.get_errors.
correct_and_errors_dataset = get_errors(fast=False)

In [None]:
# Unpack the two prediction subsets, then report their sizes (errors first).
errors, correct_only = (
    correct_and_errors_dataset[split_name]
    for split_name in ("errors", "correct_predictions")
)
print(*(len(subset) for subset in (errors, correct_only)), sep="\n")

In [None]:
# Track this probing experiment in Weights & Biases. The config records the
# model/SAE selection and the seed so runs can be filtered and compared.
# NOTE(review): nothing in this notebook calls run.log(...); consider logging
# the T5 baseline and probe accuracies to this run.
run = wandb.init(
    project="sql_sae_linear_probe",
    name=f"{model_alias}_{model_num}_{cs_num}_{syn}_{k}",
    config={
        "model_alias": model_alias,
        "k": k,
        "model_num": model_num,
        "cs_num": cs_num,
        # syn is otherwise only visible inside the run name; seed affects sampling.
        "syn": syn,
        "seed": seed,
    },
)

In [None]:
# Build the binary classification set: label 0 for the error subset, label 1 for
# correct predictions. The classifier input text ("prompt") is the full_output column.
errors_labeled = errors.map(lambda row: {"label": 0, "prompt": row["full_output"]})
correct_labeled = correct_only.map(lambda row: {"label": 1, "prompt": row["full_output"]})
assert len(errors_labeled) and len(correct_labeled), "both classes are needed to train a classifier"

all_labels = concatenate_datasets([errors_labeled, correct_labeled])
# Shuffle with the configured seed so the classes are interleaved reproducibly.
all_labels = all_labels.shuffle(seed=seed)

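### Baseline: T5 classifier

Train a T5 classifier directly on the model outputs to get a reference accuracy for the SAE linear probes below.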

In [None]:
# (prompt, label) pairs in shuffled order — the input format for train_t5_classifier.
t5_labels = list(map(lambda row: (row["prompt"], row["label"]), all_labels))

In [None]:
# Baseline: train a T5 classifier on the raw text. Kept under t5_-prefixed names
# so the linear-probe cell further down (which binds `accuracy`) does not
# overwrite the baseline score.
t5_accuracy, t5_model = train_t5_classifier(t5_labels)
print(t5_accuracy)

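### Linear probes on SAE features

Encode the same labeled examples with the SAEs and train a sparse linear probe on the averaged SAE representation.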

In [None]:
from pathlib import Path

from huggingface_hub import snapshot_download
from sae.sae_interp import GroupedSaeOutput, SaeOutput, SaeCollector, LoadedSAES
from TinySQL import sql_interp_model_location

In [None]:
# Download only this model's SAE folder from the Hub into cache_dir and keep
# the returned local path.
repo_path = Path(
    snapshot_download(
        repo_id=repo_name,
        local_dir=cache_dir,
        allow_patterns=f"{model_alias}/*",
    )
)

In [None]:
# Load the SAEs for (model_alias, k) from cache_dir, attaching the labeled
# dataset; store_activations is disabled.
loaded_saes = LoadedSAES.load_from_path(
    model_alias=model_alias,
    k=k,
    cache_dir=cache_dir,
    dataset=all_labels,
    store_activations=False,
)

In [None]:
# Passed as SaeCollector's sample_size (presumably the number of examples it
# encodes — confirm in sae.sae_interp.SaeCollector).
SAE_SAMPLE_SIZE = 3000
sae_collector = SaeCollector(loaded_saes=loaded_saes, seed=seed, sample_size=SAE_SAMPLE_SIZE)

In [None]:
from TinySQL.classifiers.logistic_regression_classifier import train_linear_probe_sparse

In [None]:
# SAE-encoded examples from the collector; fed to train_linear_probe_sparse in the
# next cell, which reads its "averaged_representation" column.
dataset = sae_collector.encoded_set

In [None]:
# Train a linear probe (logistic_regression_classifier module) on the averaged SAE
# representation. Returns accuracy, top-ranked features, and predictions alongside
# labels (presumably of a held-out split, given the y_test name).
(accuracy, top_features, y_pred, y_test) = train_linear_probe_sparse(
    dataset,
    representation_column="averaged_representation",
)

In [None]:
# Top-ranked features returned by train_linear_probe_sparse.
top_features

In [None]:
# Linear-probe accuracy, as bound by the train_linear_probe_sparse cell.
accuracy

In [None]:
# Number of error examples (the label-0 class).
len(errors)

In [None]:
# Total number of examples in the combined, shuffled classification set.
len(all_labels)