In [1]:
import sys
sys.path.insert(0, '../')

from datasets import load_dataset
import random
from nnsight import LanguageModel
import torch as t
from torch import nn
import matplotlib.pyplot as plt
from attribution import patching_effect
from dictionary_learning import AutoEncoder, ActivationBuffer
from dictionary_learning.dictionary import IdentityDict
from dictionary_learning.interp import examine_dimension
from dictionary_learning.utils import zst_to_generator
from tqdm import tqdm
import gc

DEBUGGING = False

if DEBUGGING:
    tracer_kwargs = dict(scan=True, validate=True)
else:
    tracer_kwargs = dict(scan=False, validate=False)

  from .autonotebook import tqdm as notebook_tqdm
Matplotlib created a temporary cache directory at /tmp/matplotlib-ws3j_0wx because the default path (/share/u/smarks/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [2]:
# dataset hyperparameters
dataset = load_dataset("LabHC/bias_in_bios")
profession_dict = {'professor' : 21, 'nurse' : 13}
male_prof = 'professor'
female_prof = 'nurse'

# model hyperparameters
DEVICE = 'cuda:0'
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE, dispatch=True)
activation_dim = 512
layer = 4

# dictionary hyperparameters
dict_id = 10
expansion_factor = 64
dictionary_size = expansion_factor * activation_dim

# data preparation hyperparameters
batch_size = 1024
SEED = 42

def get_data(train=True, ambiguous=True, batch_size=128, seed=SEED):
    if train:
        data = dataset['train']
    else:
        data = dataset['test']
    if ambiguous:
        neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        n = min([len(neg), len(pos)])
        neg, pos = neg[:n], pos[:n]
        data = neg + pos
        labels = [0]*n + [1]*n
        idxs = list(range(2*n))
        random.Random(seed).shuffle(idxs)
        data, labels = [data[i] for i in idxs], [labels[i] for i in idxs]
        true_labels = spurious_labels = labels
    else:
        neg_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        neg_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 1]
        pos_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 0]
        pos_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        n = min([len(neg_neg), len(neg_pos), len(pos_neg), len(pos_pos)])
        neg_neg, neg_pos, pos_neg, pos_pos = neg_neg[:n], neg_pos[:n], pos_neg[:n], pos_pos[:n]
        data = neg_neg + neg_pos + pos_neg + pos_pos
        true_labels     = [0]*n + [0]*n + [1]*n + [1]*n
        spurious_labels = [0]*n + [1]*n + [0]*n + [1]*n
        idxs = list(range(4*n))
        random.Random(seed).shuffle(idxs)
        data, true_labels, spurious_labels = [data[i] for i in idxs], [true_labels[i] for i in idxs], [spurious_labels[i] for i in idxs]

    batches = [
        (data[i:i+batch_size], t.tensor(true_labels[i:i+batch_size], device=DEVICE), t.tensor(spurious_labels[i:i+batch_size], device=DEVICE)) for i in range(0, len(data), batch_size)
    ]

    return batches

def get_subgroups(train=True, ambiguous=True, batch_size=128, seed=SEED):
    if train:
        data = dataset['train']
    else:
        data = dataset['test']
    if ambiguous:
        neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        neg_labels, pos_labels = (0, 0), (1, 1)
        subgroups = [(neg, neg_labels), (pos, pos_labels)]
    else:
        neg_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        neg_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 1]
        pos_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 0]
        pos_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        neg_neg_labels, neg_pos_labels, pos_neg_labels, pos_pos_labels = (0, 0), (0, 1), (1, 0), (1, 1)
        subgroups = [(neg_neg, neg_neg_labels), (neg_pos, neg_pos_labels), (pos_neg, pos_neg_labels), (pos_pos, pos_pos_labels)]
    
    out = {}
    for data, label_profile in subgroups:
        out[label_profile] = []
        for i in range(0, len(data), batch_size):
            text = data[i:i+batch_size]
            out[label_profile].append(
                (
                    text,
                    t.tensor([label_profile[0]]*len(text), device=DEVICE),
                    t.tensor([label_profile[1]]*len(text), device=DEVICE)
                )
            )
    return out

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# probe training hyperparameters
lr = 1e-2
epochs = 1

class Probe(nn.Module):
    def __init__(self, activation_dim):
        super().__init__()
        self.net = nn.Linear(activation_dim, 1, bias=True)

    def forward(self, x):
        logits = self.net(x).squeeze(-1)
        return logits

def train_probe(get_acts, label_idx=0, batches=get_data(), lr=1e-2, epochs=1, seed=SEED):
    t.manual_seed(seed)
    probe = Probe(512).to('cuda:0')
    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    losses = []
    for epoch in range(epochs):
        for batch in batches:
            text = batch[0]
            labels = batch[label_idx+1] 
            acts = get_acts(text)
            logits = probe(acts)
            loss = criterion(logits, labels.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

    return probe, losses

def test_probe(probe, get_acts, label_idx=0, batches=get_data(train=False), seed=SEED):
    with t.no_grad():
        corrects = []

        for batch in batches:
            text = batch[0]
            labels = batch[label_idx+1]
            acts = get_acts(text)
            logits = probe(acts)
            preds = (logits > 0.0).long()
            corrects.append((preds == labels).float())
        return t.cat(corrects).mean().item()

In [4]:
def get_acts(text):
    with model.trace(text, **tracer_kwargs), t.no_grad():
        acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
    return acts.value

In [5]:
probe0, _ = train_probe(get_acts, label_idx=0, batches=get_data(ambiguous=False), lr=lr, epochs=epochs)
probe1, _ = train_probe(get_acts, label_idx=1, batches=get_data(ambiguous=False), lr=lr, epochs=epochs)

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [6]:
batches = get_data(train=False, ambiguous=False)
print('Probe 0 accuracy:', test_probe(probe0, get_acts, batches=batches, label_idx=0))
print('Probe 1 accuracy:', test_probe(probe1, get_acts, batches=batches, label_idx=1))
del probe0, probe1
gc.collect()

Probe 0 accuracy: 0.900921642780304
Probe 1 accuracy: 0.9688940048217773


6046

In [7]:
probe, _ = train_probe(get_acts, label_idx=0, lr=lr, epochs=epochs)
print('Ambiguous test accuracy:', test_probe(probe, get_acts, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=1))

Ambiguous test accuracy: 0.9915195107460022
Ground truth accuracy: 0.6261520981788635
Spurious accuracy: 0.8680875301361084


In [8]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9941375851631165
Accuracy for (0, 1): 0.13586179912090302
Accuracy for (1, 0): 0.38479262590408325
Accuracy for (1, 1): 0.9881505370140076


In [5]:
submodules = [
    model.gpt_neox.layers[i] for i in range(layer + 1)
]
dictionaries = {}
for i in range(layer + 1):
    ae = AutoEncoder(512, 64 * 512).to(DEVICE)
    ae.load_state_dict(t.load(f'/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped/resid_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt'))
    dictionaries[submodules[i]] = ae

def metric_fn(model, labels=None):
    return t.where(
        labels == 0,
        probe(model.gpt_neox.layers[layer].output[0][:,-1,:]),
        - probe(model.gpt_neox.layers[layer].output[0][:,-1,:]) # NOTE: 1 - probe if using sigmoid
    )

In [10]:
n_batches = 25
batch_size = 4

running_total = 0
nodes = None

for batch_idx, (clean, labels, _) in tqdm(enumerate(get_data(train=True, ambiguous=True, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    with t.no_grad():
        if nodes is None:
            nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)
    del effects, _
    gc.collect()

nodes = {k : v / running_total for k, v in nodes.items()}

100%|██████████| 25/25 [00:21<00:00,  1.17it/s]


In [11]:
n_features = 0
for i, effect in zip(range(layer+1), nodes.values()):
    print(f"Layer {i}:")
    for idx in (effect.act.abs() > 0.1).nonzero():
        print(idx.item(), effect.act[idx].item())
        n_features += 1
print(f"total features: {n_features}")

Layer 0:
241 -0.10441476851701736
1022 0.2669174075126648
3122 0.11432410776615143
4074 -0.5427650213241577
9610 0.2321901023387909
9651 0.7651873826980591
10060 0.8378416299819946
10282 0.14649847149848938
18967 0.6124393939971924
22084 0.3060828149318695
23255 0.2215987890958786
23898 0.36252906918525696
24418 0.10960207879543304
24435 0.2099316418170929
26504 0.4111306667327881
29626 0.2115512490272522
Layer 1:
2995 0.17399932444095612
4592 0.46434274315834045
8920 0.656842052936554
9877 0.41685497760772705
10115 0.10725012421607971
12128 0.5694507360458374
14918 0.21164773404598236
15017 1.2835333347320557
17369 0.22672824561595917
18585 0.2379533350467682
26476 -0.13248230516910553
30248 0.7422191500663757
Layer 2:
1995 0.5639375448226929
8944 0.24482528865337372
9128 1.0962494611740112
11656 0.1112290546298027
14559 0.3551482558250427
14638 0.346517413854599
17961 -0.11233101785182953
21331 0.2790890336036682
26413 0.14219459891319275
27838 0.11404399573802948
29206 0.20579946041

In [8]:
component_idx = 4
feat_idx = 12420

submodule = submodules[component_idx]
dictionary = dictionaries[submodule]

# interpret some features
data = zst_to_generator('/share/data/datasets/pile/the-eye.eu/public/AI/pile/train/00.jsonl.zst')
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    out_feats=512,
    in_batch_size=128,
    n_ctxs=512,
)

out = examine_dimension(
    model,
    submodule,
    buffer,
    dictionary,
    dim_idx=feat_idx,
    n_inputs=256
)
print(out.top_tokens)
print(out.top_affected)
out.top_contexts

[(' leaves', 3.1683294773101807), (' piece', 2.8998165130615234), (' telling', 2.8359341621398926), (' total', 2.607029914855957), (' tells', 2.443530559539795), (' needing', 2.205601930618286), (' notice', 2.185490608215332), (' finish', 2.0870466232299805), ('accept', 1.889017105102539), (' goes', 1.8599400520324707), (' [...]', 1.8255202770233154), (' 1960', 1.7827792167663574), (' donors', 1.685542345046997), (' Lady', 1.6427953243255615), (' Although', 1.6094988584518433), (' needs', 1.589821219444275), ('ret', 1.516148567199707), (' donations', 1.4957566261291504), (' counting', 1.4677677154541016), (' impressed', 1.4677202701568604), (' jumping', 1.4470441341400146), (' Instead', 1.440384864807129), (' jumped', 1.4324177503585815), ('iza', 1.3016449213027954), (' knees', 1.225874900817871), (' career', 1.162060022354126), (' puzzle', 1.1596356630325317), (' participate', 1.0996900796890259), (' inspired', 0.989930272102356), (' mate', 0.9897094964981079)]
[(' her', 1.75376546382

In [6]:
feats_to_ablate = {
    submodules[0] : [
        # 4074, # predicts code
        # 18775, # certain periods, unclear
        18967, # 'He'
        22084, # 'he'
        29626, # 'his'
        # 951, # 'with'
        1022, # 'she'
        # 1692, # 'or'
        2079, # 'Woman' or 'Ladies'
        # 2493, # 'nurse(s)'
        3122, # 'Her'
        # 5648, # 'care' in medical context
        # 5950, # 'to'
        # 9610, # certain periods, unclear
        9651, # female names
        10060, # 'She'
        # 11778, # 'phone(s)'
        # 13675, # certain verbs
        # 15032, # unclear
        # 23666, # 'and'
        # 24418, # unclear,
        26504, # 'her'
        # 26586, # unclear
        # 31201, # 'nursing'
    ],
    submodules[1] : [
        2995, # 'He'
        8920, # 'he'
        12128, # 'his'
        # 12436, # gendered pronouns?
        4592, # 'her'
        9877, # female names
        # 12882, # 'share'
        # 14918, # certain periods, unclear
        15017, # 'she'
        # 17369, # predicts phone numbers
        26204, # 'Her'
        # 26476, # certain periods, unclear
        # 26969, # related to nursing
        # 28145, # 'medical' 
        30248, # female names
    ],
    submodules[2] : [
        4433, # promotes male-associated words
        4539, # gendered pronouns
        8944, # capitalized gendered pronouns
        # 10700, # unclear
        11656, # promotes male-associated words
        # 14559, # periods in biographies, might promote male words a bit
        # 15225, # something about dates?
        # 15938, # unclear, might promote male words a bit
        # 19014, # certain periods, unclear
        # 27803, # unclear
        29206, # gendered pronouns
        30263, # gendered pronouns
        # 277, # spammy advertisements
        1995, # promotes female-associated words
        # 4770, # unclear
        9128, # female pronouns
        10635, # female-associated word
        # 10757, # unclear
        12440, # promotes female-associated words
        # 14638, # related to contact information?
        # 17774, # related to terms and conditions?
        # 21331, # certain periods, unclear
        # 26413, # fires on periods in ?blog posts/first-person writing?
        # 27838, # promotes parts of urls
        29295, # female names
        # 29371, # certain periods, unclear
        # 31098, # nursing-related words
    ],
    submodules[3] : [
        # 93, # periods in medical contexts
        # 2751, # words related to research
        # 13474, # active in bios of medical professionals
        # 15246, # promotes words about art
        # 24661, # periods in bios
        27334, # promotes male-associated words
        # 27867, # periods in bios, might promote male words a bit
        # 7539, # unclear 
        # 10295, # certain periods
        # 13542, # promotes advertising words
        # 14401, # unclear
        19558, # promotes female-associated words
        # 20526, # promotes medical words
        22152, # promotes female-associated words
        # 23375, # addresses
        23545, # 'she'
        # 24484, # promotes ?names of diseases?
        24806, # 'her'
        # 27051, # promotes capitalized words
        30802, # 'woman'/'women'
        # 31182, # contact information
        # 31751, # accommodation words
    ],
    submodules[4] : [
        # 4125, # periods in bios
        # 11987, # promotes verbs in bios
        # 12332, # promotes words related to academia
        # 14658, # periods
        30220, # promotes male pronouns
        # 1804, # words related to RSVPing
        # 4926, # certain periods
        # 5731, # promotes medicinal words
        # 6869, # promotes verbs about phone calls
        9766, # promotes female-associated words
        12420, # promotes female pronouns
        # 20708, # contact info
        # 21979, # promotes capitalized words
        23207, # promotes gendered words, especial female
        # 30612, # fundraising
        # 31282, # capitalized words about contacting
    ]      
}

print(f"Number of features to ablate: {sum(len(v) for v in feats_to_ablate.values())}")

Number of features to ablate: 38


In [7]:
def get_acts_abl(
    text,
    model,
    submodules,
    dictionaries,
    to_ablate
):
    with model.trace("test"), t.no_grad():
        is_tuple = {}
        for submodule in submodules:
            is_tuple[submodule] = type(submodule.output.shape) == tuple

    with model.trace(text, **tracer_kwargs), t.no_grad():
        for submodule in submodules:
            dictionary = dictionaries[submodule]
            feat_idxs = to_ablate[submodule]
            x = submodule.output
            if is_tuple[submodule]:
                x = x[0]
            x_hat, f = dictionary(x, output_features=True)
            res = x - x_hat
            for idx in feat_idxs:
                f[..., idx] = 0.
            if is_tuple[submodule]:
                submodule.output[0][:] = dictionary.decode(f) + res
            else:
                submodule.output = dictionary.decode(f) + res
        out = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
    return out.value


# Accuracy after ablating features deemed irrelevant by human annotators

In [22]:
print('Ambiguous test accuracy:', test_probe(probe, lambda text : get_acts_abl(text, model, submodules, dictionaries, feats_to_ablate), label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, lambda text : get_acts_abl(text, model, submodules, dictionaries, feats_to_ablate), batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, lambda text : get_acts_abl(text, model, submodules, dictionaries, feats_to_ablate), batches=batches, label_idx=1))

Ambiguous test accuracy: 0.8661710023880005
Ground truth accuracy: 0.8254608511924744
Spurious accuracy: 0.5547235012054443


In [23]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, lambda text : get_acts_abl(text, model, submodules, dictionaries, feats_to_ablate), batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9879666566848755
Accuracy for (0, 1): 0.9644010066986084
Accuracy for (1, 0): 0.559907853603363
Accuracy for (1, 1): 0.7453531622886658


# Get skyline neuron performance

In [24]:
# get neurons which are most influential for giving gender label
neuron_dicts = {
    submodule : IdentityDict(activation_dim).to(DEVICE) for submodule in submodules
}

n_batches = 25
batch_size = 4

running_total = 0
running_nodes = None

for batch_idx, (clean, _, labels) in tqdm(enumerate(get_data(train=True, ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        neuron_dicts,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    with t.no_grad():
        if running_nodes is None:
            running_nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                running_nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)

nodes = {k : v / running_total for k, v in running_nodes.items()}

100%|██████████| 25/25 [00:11<00:00,  2.23it/s]


In [34]:
neurons_to_ablate = {}
for i, effect in zip(range(layer+1), nodes.values()):
    neurons_to_ablate[submodules[i]] = []
    print(f"Layer {i}:")
    for idx in (effect.act.abs() > 0.3).nonzero():
        neurons_to_ablate[submodules[i]].append(idx.item())
        print(idx.item(), effect.act[idx].item())
print(f"total neurons: {sum(len(v) for v in neurons_to_ablate.values())}")

Layer 0:
111 0.552754282951355
156 0.9466100335121155
165 -0.33742231130599976
410 0.5108632445335388
Layer 1:
14 0.3330404460430145
23 -1.9429857730865479
56 0.534831166267395
111 1.3249824047088623
148 -0.4178657829761505
156 -1.1101523637771606
248 -0.3936334252357483
258 0.5677260756492615
369 0.5060036182403564
376 0.7261002063751221
410 0.5521716475486755
478 0.4259152114391327
503 0.41798722743988037
Layer 2:
17 -0.643139123916626
23 0.5534395575523376
47 0.8207612633705139
63 0.31175464391708374
111 -1.3454006910324097
129 -0.4927549958229065
136 0.37165170907974243
137 -0.4441964328289032
156 -4.553019046783447
186 0.38707372546195984
258 0.38416942954063416
271 -0.9130510687828064
365 0.3123455047607422
369 0.6829093098640442
390 0.31727081537246704
416 0.44325557351112366
478 0.3030056357383728
Layer 3:
111 -0.5690432190895081
136 -0.40042388439178467
156 -4.656957149505615
172 -0.49626481533050537
271 -0.4498133063316345
376 0.3341180980205536
478 0.36204877495765686
Layer 

In [35]:
print('Ambiguous test accuracy:', test_probe(probe, lambda text : get_acts_abl(text, model, submodules, neuron_dicts, neurons_to_ablate), label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, lambda text : get_acts_abl(text, model, submodules, neuron_dicts, neurons_to_ablate), batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, lambda text : get_acts_abl(text, model, submodules, neuron_dicts, neurons_to_ablate), batches=batches, label_idx=1))

Ambiguous test accuracy: 0.8055297136306763
Ground truth accuracy: 0.5673962831497192
Spurious accuracy: 0.7344470024108887


In [36]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, lambda text : get_acts_abl(text, model, submodules, neuron_dicts, neurons_to_ablate), batches=batches, label_idx=0))

Accuracy for (0, 0): 0.8897870779037476
Accuracy for (0, 1): 0.4113405644893646
Accuracy for (1, 0): 0.2142857164144516
Accuracy for (1, 1): 0.7265334725379944


# Get skyline feature performance

In [37]:
n_batches = 25
batch_size = 4

running_total = 0
running_nodes = None

for batch_idx, (clean, _, labels) in tqdm(enumerate(get_data(train=True, ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    with t.no_grad():
        if running_nodes is None:
            running_nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                running_nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)

nodes = {k : v / running_total for k, v in running_nodes.items()}

100%|██████████| 25/25 [00:16<00:00,  1.48it/s]


In [42]:
top_feats_to_ablate = {}
for i, effect in zip(range(layer+1), nodes.values()):
    top_feats_to_ablate[submodules[i]] = []
    print(f"Layer {i}:")
    for idx in (effect.act.abs() > 0.15).nonzero():
        top_feats_to_ablate[submodules[i]].append(idx.item())
        print(idx.item(), effect.act[idx].item())
print(f"total features: {sum(len(v) for v in top_feats_to_ablate.values())}")

Layer 0:
1022 0.43988922238349915
3122 0.259931743144989
4074 -0.24184541404247284
9610 0.19874531030654907
9651 0.7201375961303711
10060 1.4157994985580444
18967 0.40968555212020874
22084 0.2687976062297821
23255 0.18106205761432648
23898 0.18152987957000732
24435 0.16508552432060242
26504 0.18438325822353363
29626 0.24421903491020203
Layer 1:
2995 -0.1729012280702591
4592 0.25175902247428894
8920 0.508174479007721
9877 0.2943287491798401
12128 0.4129084646701813
14918 0.17536771297454834
15017 2.1895370483398438
18585 0.23559540510177612
26204 0.15615183115005493
30248 0.41855838894844055
Layer 2:
1995 0.7555692195892334
9128 1.641681432723999
14559 0.29228782653808594
21331 0.19326896965503693
29295 0.32329061627388
29371 0.20549967885017395
Layer 3:
19558 1.900686264038086
23545 0.6123898029327393
24806 0.19349133968353271
27334 0.24706138670444489
27867 0.18328824639320374
Layer 4:
4125 0.19962818920612335
9766 0.21931304037570953
12420 2.674450397491455
30220 0.9574550986289978
t

In [43]:
print('Ambiguous test accuracy:', test_probe(probe, lambda text : get_acts_abl(text, model, submodules, dictionaries, top_feats_to_ablate), label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, lambda text : get_acts_abl(text, model, submodules, dictionaries, top_feats_to_ablate), batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, lambda text : get_acts_abl(text, model, submodules, dictionaries, top_feats_to_ablate), batches=batches, label_idx=1))

Ambiguous test accuracy: 0.8877788186073303
Ground truth accuracy: 0.8381336331367493
Spurious accuracy: 0.5604838728904724


In [44]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, lambda text : get_acts_abl(text, model, submodules, dictionaries, top_feats_to_ablate), batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9879666566848755
Accuracy for (0, 1): 0.9487795233726501
Accuracy for (1, 0): 0.6013824939727783
Accuracy for (1, 1): 0.7899628281593323


# Retraining probe on activations after ablating features

In [11]:
print(model.gpt_neox)

ReferenceError: weakly-referenced object no longer exists

In [8]:
new_probe, _ = train_probe(lambda text : get_acts_abl(text, model, submodules, dictionaries, feats_to_ablate), label_idx=0, lr=lr, epochs=epochs)
print('Ambiguous test accuracy:', test_probe(new_probe, lambda text: get_acts_abl(text, model, submodules, dictionaries, feats_to_ablate), label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(new_probe, lambda text : get_acts_abl(text, model, submodules, dictionaries, feats_to_ablate), batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(new_probe, lambda text : get_acts_abl(text, model, submodules, dictionaries, feats_to_ablate), batches=batches, label_idx=1))

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [27136, 32768]], which is output 0 of AsStridedBackward0, is at version 4; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

In [24]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9301449656486511
Accuracy for (0, 1): 0.8476905226707458
Accuracy for (1, 0): 0.8709677457809448
Accuracy for (1, 1): 0.9442378878593445


In [42]:
# get neurons which are most influential for giving gender label
neuron_dicts = {
    submodule : IdentityDict(activation_dim).to(DEVICE) for submodule in submodules
}

n_batches = 25
batch_size = 4

running_total = 0
running_nodes = None

for batch_idx, (clean, _, labels) in tqdm(enumerate(get_data(train=True, ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        neuron_dicts,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    with t.no_grad():
        if running_nodes is None:
            running_nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                running_nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)

nodes = {k : v / running_total for k, v in running_nodes.items()}

100%|██████████| 25/25 [00:16<00:00,  1.48it/s]


In [46]:
n_features = 0
to_ablate = {}
for i, effect in zip(range(layer+1), nodes.values()):
    to_ablate[submodules[i]] = []
    print(f"Layer {i}:")
    for idx in (effect.act.abs() > 0.3).nonzero():
        to_ablate[submodules[i]].append(idx.item())
        print(idx.item(), effect.act[idx].item())
        n_features += 1
print(f"total features: {n_features}")

Layer 0:
1022 0.8367617726325989
3122 0.5025932192802429
4074 -0.41542863845825195
9610 0.3709331750869751
9651 1.3308827877044678
10060 2.6564531326293945
18967 0.6820940375328064
22084 0.4557006359100342
23255 0.3255564570426941
24435 0.30812686681747437
26504 0.3546774685382843
29626 0.4150571823120117
Layer 1:
2995 -0.3428948223590851
4592 0.4850669801235199
8920 0.8407686352729797
9877 0.5462824702262878
12128 0.6813495755195618
14918 0.31241682171821594
15017 4.105208396911621
18585 0.4312697947025299
26204 0.30695322155952454
30248 0.7768656611442566
Layer 2:
1995 1.3887615203857422
9128 3.060696840286255
14559 0.5412063598632812
21331 0.3440110981464386
29295 0.6148474812507629
29371 0.3843125104904175
Layer 3:
19558 3.5814743041992188
23545 1.1425334215164185
24806 0.3723568022251129
27334 0.40034037828445435
27867 0.31886884570121765
Layer 4:
4125 0.3233107030391693
9766 0.3931291997432709
12420 5.099288463592529
30220 1.5810459852218628
total features: 37


In [47]:
circuit = {}
for submod in submodules:
    submod_circuit = SparseAct(act=t.zeros(32768, dtype=t.bool), resc=t.zeros(1, dtype=t.bool)).to(DEVICE)
    for idx in to_ablate[submod]:
        submod_circuit.act[idx] = True
    circuit[submod] = submod_circuit

In [49]:
def get_acts_abl(text):
    with model.trace(text), t.no_grad():
        for submodule in submodules:
            dictionary = dictionaries[submodule] #neuron_dicts[submodule]
            submod_nodes = nodes[submodule].clone()
            x = submodule.output
            is_tuple = type(x.shape) == tuple
            if is_tuple:
                x = x[0]
            f = dictionary.encode(x)
            res = x - dictionary(x)

            # ablate features
            f[...,circuit[submodule].act] = 0.
            
            if is_tuple:
                submodule.output[0][:] = dictionary.decode(f) + res
            else:
                submodule.output = dictionary.decode(f) + res

        out = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
    return out.value

In [50]:
print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=1))

Ambiguous test accuracy: 0.9207713603973389
Ground truth accuracy: 0.8577188849449158
Spurious accuracy: 0.570852518081665


In [92]:
def out_fn(model):
    return probe(model.gpt_neox.layers[layer].output[0][:,-1,:])

# get accuracy after ablating above features
batches = get_data(train=False, ambiguous=True, batch_size=128, seed=SEED)

with t.no_grad():
    corrects = []
    for batch in batches:
        text = batch[0]
        labels = batch[1]
        probs = run_with_ablations(
            model,
            text,
            submodules,
            dictionaries,
            to_ablate,
            out_fn
        )
        preds = (probs > 0.5).long()
        corrects.append((preds == labels).float())
    print("Accuracy on ambiguous data:", t.cat(corrects).mean().item())

batches = get_data(train=False, ambiguous=False, batch_size=128, seed=SEED)

with t.no_grad():
    truth_corrects, spurious_corrects = [], []
    for batch in batches:
        text = batch[0]
        true_labels = batch[1]
        spurious_labels = batch[2]
        probs = run_with_ablations(
            model,
            text,
            submodules,
            dictionaries,
            to_ablate,
            out_fn
        )
        preds = (probs > 0.5).long()
        truth_corrects.append((preds == true_labels).float())
        spurious_corrects.append((preds == spurious_labels).float())
    print("Ground truth accuracy", t.cat(truth_corrects).mean().item())
    print("Spurious accuracy", t.cat(spurious_corrects).mean().item())

Accuracy on ambiguous data: 0.4993029832839966
Ground truth accuracy 0.4976958632469177
Spurious accuracy 0.5011520981788635


In [None]:
# retrain probe with ablated model

def out_fn(model):
    return model.gpt_neox.layers[layer].output[0][:,-1,:]

t.manual_seed(SEED)
new_probe = Probe(512).to('cuda:0')
optimizer = t.optim.AdamW(new_probe.parameters(), lr=lr)
criterion = nn.BCELoss()

batches = get_data(train=True, ambiguous=True, batch_size=64, seed=SEED)

losses = []
for epoch in range(epochs):
    for batch in batches:
        text = batch[0]
        labels = batch[1]
        acts = run_with_ablations(
            model,
            text,
            submodules,
            dictionaries,
            to_ablate,
            out_fn
        ).clone()
        probs = new_probe(acts)
        loss = criterion(probs, labels.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
