In [1]:
import torch as t
import torch.nn as nn
import pandas as pd
import numpy as np
from tqdm import tqdm
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)

In [2]:
# Load gemma-2-2b as a HookedSAETransformer so SAEs can be spliced into its hooks.
model2b = HookedSAETransformer.from_pretrained(
    "gemma-2-2b",
    device="cuda",
)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [3]:
# Load the plain and negated "Alice/Bob" city statement sets.
city_alice = pd.read_csv("cities_alice.csv")
neg_city_alice = pd.read_csv("neg_cities_alice.csv")

# Stack both sets and shuffle the rows. DataFrame.sample is given no random_state,
# so it draws from numpy's global RNG, which is seeded right before it.
np.random.seed(123)
data = (
    pd.concat([city_alice, neg_city_alice])
    .sample(frac=1)
    .reset_index(drop=True)
)
data.head(10)

Unnamed: 0,statement,label,city,country,correct_country,has_alice,has_not,has_alice xor has_not,has_alice xor label,has_not xor label,has_alice xor has_not xor label
0,Alice: The city of Hegang is in China.,1,Hegang,China,China,True,False,True,False,True,False
1,Bob: The city of Cangzhou is in China.,1,Cangzhou,China,China,False,False,False,True,True,True
2,Alice: The city of Mogadishu is in Somalia.,1,Mogadishu,Somalia,Somalia,True,False,True,False,True,False
3,Bob: The city of Laibin is in Ethiopia.,0,Laibin,Ethiopia,China,False,False,False,False,False,False
4,Bob: The city of Feira de Santana is in Brazil.,1,Feira de Santana,Brazil,Brazil,False,False,False,True,True,True
5,Alice: The city of Karachi is not in Pakistan.,0,Karachi,Pakistan,Pakistan,True,True,False,True,True,False
6,Alice: The city of Luohe is in Brazil.,0,Luohe,Brazil,China,True,False,True,True,False,True
7,Bob: The city of Tlalpan is not in Mexico.,0,Tlalpan,Mexico,Mexico,False,True,True,False,True,True
8,Bob: The city of Bekasi is in China.,0,Bekasi,China,Indonesia,False,False,False,False,False,False
9,Alice: The city of Langfang is not in India.,1,Langfang,India,China,True,True,False,False,False,True


In [4]:
model2b.cfg

HookedTransformerConfig:
{'NTK_by_parts_factor': 8.0,
 'NTK_by_parts_high_freq_factor': 4.0,
 'NTK_by_parts_low_freq_factor': 1.0,
 'act_fn': 'gelu_pytorch_tanh',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 16.0,
 'attn_scores_soft_cap': 50.0,
 'attn_types': ['global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'globa

In [5]:
# Load the canonical 16k-width Gemma Scope SAE for layer 19 (release/id strings below).
# from_pretrained is unpacked into three values here: the SAE, its config dict and sparsity.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_19/width_16k/canonical",
    device="cuda",
)

In [6]:
# Quick smoke test: continue a short prompt by 10 tokens and display the text.
a = model2b.generate(
    "Bing bong ding dong",
    max_new_tokens=10,
)
a

  0%|          | 0/10 [00:00<?, ?it/s]

"Bing bong ding dong, I'm blowing this one, 9"

In [10]:
# Tokenize two prompts of different lengths as one padded batch and move it to the GPU.
some_batch = model2b.tokenizer(
    ["Ding don bing bong", "King kong ding dong whoop whoop"],
    return_tensors="pt",
    padding=True,
).to("cuda")

In [11]:
# Inspect which positions are real tokens (1) vs. padding (0) in each row.
some_batch["attention_mask"]

tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')

In [25]:
# Forward the padded batch with the layer-19 SAE attached, caching only the SAE's
# input and reconstruction at blocks.19.hook_resid_post. The second element of the
# result is the cache, indexed by hook name.
leb = model2b.run_with_cache_with_saes(
    some_batch['input_ids'],
    saes=sae,
    names_filter=lambda hook_name: hook_name in (
        'blocks.19.hook_resid_post.hook_sae_recons',
        'blocks.19.hook_resid_post.hook_sae_input',
    ),
)

In [44]:
# SAE reconstruction of the layer-19 residual stream for both prompts in the batch.
# Row 0 is the shorter prompt and is right-padded (its attention mask is 0 at
# positions 5-8), so its trailing positions hold activations for pad tokens.
leb[1]['blocks.19.hook_resid_post.hook_sae_recons']

tensor([[[ 6.3655e+01,  2.7575e+02, -1.7398e+02,  ...,  2.9820e+02,
          -6.2675e+01,  2.9702e+01],
         [-4.8884e+00, -6.8970e-01,  4.2473e-01,  ...,  2.6384e+00,
           7.4499e-01,  1.6682e-01],
         [-3.9028e+00, -8.6650e-01, -4.0607e+00,  ...,  1.5525e+00,
           5.1970e+00, -1.1291e-01],
         [-5.4336e+00,  1.4239e+00,  2.3322e+00,  ...,  1.5917e+00,
          -5.3451e-01,  2.7083e+00],
         [-9.6548e+00,  3.2607e+00,  4.4400e+00,  ..., -2.2930e+00,
           1.9895e+00,  2.9883e+00]],

        [[ 6.3655e+01,  2.7575e+02, -1.7398e+02,  ...,  2.9820e+02,
          -6.2675e+01,  2.9702e+01],
         [-6.8911e+00,  8.6366e+00, -2.9912e+00,  ...,  7.0088e-01,
          -1.5728e+00,  1.3779e+00],
         [-4.2630e+00,  1.3513e+00, -2.4401e+00,  ..., -2.6403e+00,
          -4.0709e+00,  3.4806e+00],
         [-3.7446e+00, -6.9691e-01, -8.2747e-01,  ...,  5.4570e+00,
          -1.1321e+00,  3.9979e-01],
         [-6.1591e+00,  2.7186e+00, -5.2874e-01,  ...

In [None]:
class Probe(nn.Module):
    """Linear probe mapping an activation vector to a single logit.

    Args:
        activation_dim: size of the last dimension of the inputs.
    """

    def __init__(self, activation_dim):
        super().__init__()
        self.net = nn.Linear(activation_dim, 1, bias=True)

    def forward(self, x):
        """Return logits of shape ``x.shape[:-1]`` (one per activation vector)."""
        # nn.Linear outputs (..., 1); drop the trailing singleton so logits line up with labels.
        logits = self.net(x).squeeze(-1)
        return logits

def train_probe(get_acts, label_idx=0, batches=None, lr=1e-2, epochs=1, dim=512, seed=None,
                device='cuda', show_progress=True):
    """Train a `Probe` with BCE-with-logits loss on activations produced by `get_acts`.

    Args:
        get_acts: callable mapping a batch's text (``batch[0]``) to an activation
            tensor whose last dimension is ``dim``.
        label_idx: which label column to train on; labels are read from
            ``batch[label_idx + 1]`` and must be a tensor.
        batches: iterable of ``(text, label_0, label_1, ...)`` batches. Defaults to
            ``get_data()``, looked up when the function is called (``get_data`` must
            be defined by then).
        lr: AdamW learning rate.
        epochs: number of passes over ``batches``.
        dim: input dimension of the probe.
        seed: seed passed to ``torch.manual_seed`` before the probe is created.
            Defaults to the notebook-level ``SEED``, looked up at call time.
        device: device the probe lives on; activations and labels are moved there.
        show_progress: wrap each epoch's batch loop in a tqdm progress bar.

    Returns:
        ``(probe, losses)``: the trained probe and the list of per-step losses.
    """
    # Resolve defaults at call time. As default arguments, `get_data()` and `SEED`
    # were evaluated when the function was defined (NameError if not yet defined),
    # and a single default batch iterable would have been shared across calls.
    if batches is None:
        batches = get_data()
    if seed is None:
        seed = SEED
    # A one-shot iterator (e.g. a generator) is empty after the first pass; keep
    # its batches so every epoch trains on them.
    if epochs > 1 and iter(batches) is batches:
        batches = list(batches)

    t.manual_seed(seed)
    probe = Probe(dim).to(device)
    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    param_dtype = probe.net.weight.dtype

    losses = []
    for _ in range(epochs):
        batch_iter = tqdm(batches) if show_progress else batches
        for batch in batch_iter:
            text = batch[0]
            labels = batch[label_idx + 1]
            # Align acts with the probe's device/dtype (e.g. half-precision model
            # activations would otherwise fail inside nn.Linear).
            acts = get_acts(text).to(device=device, dtype=param_dtype)
            logits = probe(acts)
            # BCEWithLogitsLoss needs float targets on the same device as the logits.
            loss = criterion(logits, labels.to(device=logits.device, dtype=logits.dtype))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

    return probe, losses

In [None]:
# Reference code adapted from another project. It is not runnable as-is here:
# `DataManager` and `LRProbe` are neither defined nor imported in this notebook.
# NOTE(review): the API looks like the geometry-of-truth repo's DataManager/LRProbe -- confirm.
# It loads activations of 'llama-2-13b-reset' at layer 14, not gemma-2-2b.

labels = [
    'has_alice',
    'has_not',
    'label',
    'has_alice xor has_not',
    'has_alice xor label',
    'has_not xor label',
    'has_alice xor has_not xor label',
]
PROBE_DATASETS = ['cities_alice', 'neg_cities_alice']
PROBE_ACTS_MODEL = 'llama-2-13b-reset'
PROBE_LAYER = 14

# Validation accuracy of one logistic-regression probe per target label.
accs = {}
for label in labels:
    dm = DataManager()
    for dataset in PROBE_DATASETS:
        dm.add_dataset(dataset, PROBE_ACTS_MODEL, PROBE_LAYER, label=label, center=False, split=0.8)
    # Distinct names so the `labels` list above is not clobbered by tensors.
    train_acts, train_labels = dm.get('train')
    probe = LRProbe.from_data(train_acts, train_labels, bias=True)
    val_acts, val_labels = dm.get('val')
    # Outputs are rounded to 0/1 before comparison -- presumably probabilities; confirm.
    acc = (probe(val_acts).round() == val_labels).float().mean()
    accs[label] = acc.item()