In [1]:
from datasets import load_dataset
from tqdm.auto import trange

dd = load_dataset("nyu-mll/glue", "cola")

In [2]:
import torch
from sae_lens import SAE, HookedSAETransformer
from probe_lens import ActivationStore, ProbeTrainer

model = HookedSAETransformer.from_pretrained("gpt2-small")
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

hook_point = "blocks.7.hook_resid_pre"
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb", sae_id=hook_point, device=device
)
model.add_sae(sae, hook_point)

Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [3]:
labels = dd['train'].features['label'].names

def filter_fn(activations):
    return activations[:, -1:, :]

store = ActivationStore(model, hooks=[sae.hook_sae_acts_post], class_names=labels, act_filter_fn=filter_fn)

In [4]:
ds = dd['train']
store.set_split("train")

store.add_labels(ds['label'][:20*64])

batch_size = 64
prompts_train = ds['sentence'][:20*64]
for i in trange(0, len(prompts_train), batch_size):
    batch = prompts_train[i : i + batch_size]
    model(batch)


  0%|          | 0/20 [00:00<?, ?it/s]

In [5]:
ds = dd['validation']
store.set_split("test")

store.add_labels(ds['label'])

batch_size = 64
prompts_test = ds['sentence']
for i in trange(0, len(prompts_test), batch_size):
    batch = prompts_test[i : i + batch_size]
    model(batch)


  0%|          | 0/17 [00:00<?, ?it/s]

In [6]:
store.detach()
act_dd = store.compile_dataset()
act_dd

DatasetDict({
    train: Dataset({
        features: ['label', 'blocks.7.hook_resid_pre.hook_sae_acts_post'],
        num_rows: 1280
    })
    test: Dataset({
        features: ['label', 'blocks.7.hook_resid_pre.hook_sae_acts_post'],
        num_rows: 1043
    })
})

In [8]:
probe_trainer = ProbeTrainer(
    act_dd,
    flatten_T="batch",
    wandb_project="Glue-COLA",
    max_iter=1000,
)
probe_trainer.train()
probe_trainer.save_probes("./glue_cola_probes")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Training probes:   0%|          | 0/1 [00:00<?, ?it/s]

[34m[1mwandb[0m: 
[34m[1mwandb[0m: Plotting blocks.7.hook_resid_pre.hook_sae_acts_post.


Probe Metrics for blocks.7.hook_resid_pre.hook_sae_acts_post:
              precision    recall  f1-score   support

unacceptable       0.34      0.07      0.12       322
  acceptable       0.69      0.94      0.80       721

    accuracy                           0.67      1043
   macro avg       0.52      0.51      0.46      1043
weighted avg       0.59      0.67      0.59      1043



[34m[1mwandb[0m: Logged feature importances.
