In [1]:
import sys
sys.path.insert(0, '../')

from datasets import load_dataset
import random
from nnsight import LanguageModel
import torch as t
from torch import nn
import matplotlib.pyplot as plt
from attribution import patching_effect
from dictionary_learning import AutoEncoder, ActivationBuffer
from dictionary_learning.interp import examine_dimension
from dictionary_learning.utils import zst_to_generator

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(
Matplotlib created a temporary cache directory at /tmp/matplotlib-8czlsmw3 because the default path (/share/u/smarks/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [2]:
# dataset hyperparameters
dataset = load_dataset("LabHC/bias_in_bios")
profession_dict = {'professor' : 21, 'nurse' : 13}
male_prof = 'professor'
female_prof = 'nurse'

# model hyperparameters
DEVICE = 'cuda:0'
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE)
activation_dim = 512
layer = 4

# dictionary hyperparameters
dict_id = 10
expansion_factor = 64
dictionary_size = expansion_factor * activation_dim

# data preparation hyperparameters
batch_size = 1024
SEED = 42

def get_data(train=True, ambiguous=True, batch_size=128, seed=SEED):
    if train:
        data = dataset['train']
    else:
        data = dataset['test']
    if ambiguous:
        neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        n = min([len(neg), len(pos)])
        neg, pos = neg[:n], pos[:n]
        data = neg + pos
        labels = [0]*n + [1]*n
        idxs = list(range(2*n))
        random.Random(seed).shuffle(idxs)
        data, labels = [data[i] for i in idxs], [labels[i] for i in idxs]
        true_labels = spurious_labels = labels
    else:
        neg_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        neg_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 1]
        pos_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 0]
        pos_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        n = min([len(neg_neg), len(neg_pos), len(pos_neg), len(pos_pos)])
        neg_neg, neg_pos, pos_neg, pos_pos = neg_neg[:n], neg_pos[:n], pos_neg[:n], pos_pos[:n]
        data = neg_neg + neg_pos + pos_neg + pos_pos
        true_labels     = [0]*n + [0]*n + [1]*n + [1]*n
        spurious_labels = [0]*n + [1]*n + [0]*n + [1]*n
        idxs = list(range(4*n))
        random.Random(seed).shuffle(idxs)
        data, true_labels, spurious_labels = [data[i] for i in idxs], [true_labels[i] for i in idxs], [spurious_labels[i] for i in idxs]

    batches = [
        (data[i:i+batch_size], t.tensor(true_labels[i:i+batch_size], device=DEVICE), t.tensor(spurious_labels[i:i+batch_size], device=DEVICE)) for i in range(0, len(data), batch_size)
    ]

    return batches

def get_subgroups(train=True, ambiguous=True, batch_size=128, seed=SEED):
    if train:
        data = dataset['train']
    else:
        data = dataset['test']
    if ambiguous:
        neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        neg_labels, pos_labels = (0, 0), (1, 1)
        subgroups = [(neg, neg_labels), (pos, pos_labels)]
    else:
        neg_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        neg_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 1]
        pos_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 0]
        pos_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        neg_neg_labels, neg_pos_labels, pos_neg_labels, pos_pos_labels = (0, 0), (0, 1), (1, 0), (1, 1)
        subgroups = [(neg_neg, neg_neg_labels), (neg_pos, neg_pos_labels), (pos_neg, pos_neg_labels), (pos_pos, pos_pos_labels)]
    
    out = {}
    for data, label_profile in subgroups:
        out[label_profile] = []
        for i in range(0, len(data), batch_size):
            text = data[i:i+batch_size]
            out[label_profile].append(
                (
                    text,
                    t.tensor([label_profile[0]]*len(text), device=DEVICE),
                    t.tensor([label_profile[1]]*len(text), device=DEVICE)
                )
            )
    return out

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# probe training hyperparameters
lr = 1e-2
epochs = 1

class Probe(nn.Module):
    def __init__(self, activation_dim):
        super().__init__()
        self.net = nn.Linear(activation_dim, 1, bias=True)

    def forward(self, x):
        logits = self.net(x).squeeze(-1)
        return logits.sigmoid()

def train_probe(get_acts, label_idx=0, batches=get_data(), lr=1e-2, epochs=1, seed=SEED):
    t.manual_seed(seed)
    probe = Probe(512).to('cuda:0')
    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCELoss()

    losses = []
    for epoch in range(epochs):
        for batch in batches:
            text = batch[0]
            labels = batch[label_idx+1] 
            acts = get_acts(text)
            logits = probe(acts)
            loss = criterion(logits, labels.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

    return probe, losses

def test_probe(probe, get_acts, label_idx=0, batches=get_data(train=False), seed=SEED):
    with t.no_grad():
        corrects = []

        for batch in batches:
            text = batch[0]
            labels = batch[label_idx+1]
            acts = get_acts(text)
            logits = probe(acts)
            preds = (logits > 0.5).long()
            corrects.append((preds == labels).float())
        return t.cat(corrects).mean().item()

In [4]:
def get_acts(text):
    with model.invoke(text):
        acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
    return acts.value.clone()

In [5]:
probe0, _ = train_probe(get_acts, label_idx=0, batches=get_data(ambiguous=False), lr=lr, epochs=epochs)
probe1, _ = train_probe(get_acts, label_idx=1, batches=get_data(ambiguous=False), lr=lr, epochs=epochs)

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [6]:
batches = get_data(train=False, ambiguous=False)
print('Probe 0 accuracy:', test_probe(probe0, get_acts, batches=batches, label_idx=0))
print('Probe 1 accuracy:', test_probe(probe1, get_acts, batches=batches, label_idx=1))

Probe 0 accuracy: 0.900921642780304
Probe 1 accuracy: 0.9688940048217773


In [7]:
probe, _ = train_probe(get_acts, label_idx=0, lr=lr, epochs=epochs)
print('Ambiguous test accuracy:', test_probe(probe, get_acts, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=1))

Ambiguous test accuracy: 0.9915195107460022
Ground truth accuracy: 0.6261520981788635
Spurious accuracy: 0.8680875301361084


In [8]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9941375851631165
Accuracy for (0, 1): 0.13586179912090302
Accuracy for (1, 0): 0.38479262590408325
Accuracy for (1, 1): 0.9881505370140076


In [4]:
submodules = [
    model.gpt_neox.layers[i] for i in range(layer + 1)
]
dictionaries = {}
for i in range(layer + 1):
    ae = AutoEncoder(512, 64 * 512).to(DEVICE)
    ae.load_state_dict(t.load(f'/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped/resid_out_layer{i}/{dict_id}_{dictionary_size}/ae.pt'))
    dictionaries[submodules[i]] = ae

def metric_fn(model):
    return probe(model.gpt_neox.layers[layer].output[0][:,-1,:])

In [10]:
neg_inputs, pos_inputs = [], []
for x in dataset['train']:
    if x['profession'] == profession_dict[male_prof] and x['gender'] == 0 and len(neg_inputs) < 16:
        neg_inputs.append(x['hard_text'])
    if x['profession'] == profession_dict[female_prof] and x['gender'] == 1 and len(pos_inputs) < 16:
        pos_inputs.append(x['hard_text'])

neg_effects = patching_effect(
    neg_inputs,
    None,
    model,
    submodules,
    dictionaries,
    metric_fn,
    method='ig'
).effects
neg_effects = {k : v.sum(dim=1).mean(dim=0) for k, v in neg_effects.items()}

pos_effects = patching_effect(
    pos_inputs,
    None,
    model,
    submodules,
    dictionaries,
    metric_fn,
    method='ig'
).effects

pos_effects = {k : v.sum(dim=1).mean(dim=0) for k, v in pos_effects.items()}

In [11]:
threshold = 0.005

for i, submodule in enumerate(submodules):
    print(f"Layer {i}:")
    effect = neg_effects[submodule].act
    for feature_idx in t.nonzero(effect):
        value = effect[tuple(feature_idx)]
        if value > threshold:
            print(f"    Multindex: {tuple(feature_idx.tolist())}, Value: {value}")

for i, submodule in enumerate(submodules):
    print(f"Layer {i}:")
    effect = pos_effects[submodule].act
    for feature_idx in t.nonzero(effect):
        value = effect[tuple(feature_idx)]
        if -value > threshold:
            print(f"    Multindex: {tuple(feature_idx.tolist())}, Value: {value}")

Layer 0:
    Multindex: (4074,), Value: 0.007796672638505697
    Multindex: (10282,), Value: 0.005953371524810791
    Multindex: (14149,), Value: 0.008430274203419685
    Multindex: (18775,), Value: 0.016031203791499138
    Multindex: (18967,), Value: 0.044523485004901886
    Multindex: (22084,), Value: 0.014413665048778057
    Multindex: (23255,), Value: 0.006609784439206123
    Multindex: (23898,), Value: 0.00813205074518919
    Multindex: (24435,), Value: 0.006274200975894928
    Multindex: (29626,), Value: 0.012494005262851715
    Multindex: (31616,), Value: 0.006371437571942806
Layer 1:
    Multindex: (2995,), Value: 0.013699337840080261
    Multindex: (8920,), Value: 0.029961880296468735
    Multindex: (12128,), Value: 0.02386729046702385
    Multindex: (12436,), Value: 0.00632051657885313
Layer 2:
    Multindex: (4433,), Value: 0.014091068878769875
    Multindex: (4539,), Value: 0.015132302418351173
    Multindex: (8944,), Value: 0.023961298167705536
    Multindex: (10700,), Val

In [8]:
component_idx = 4
feat_idx = 12420

submodule = submodules[component_idx]
dictionary = dictionaries[submodule]

# interpret some features
data = zst_to_generator('/share/data/datasets/pile/the-eye.eu/public/AI/pile/train/00.jsonl.zst')
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    out_feats=512,
    in_batch_size=128,
    n_ctxs=512,
)

out = examine_dimension(
    model,
    submodule,
    buffer,
    dictionary,
    dim_idx=feat_idx,
    n_inputs=256
)
print(out.top_tokens)
print(out.top_affected)
out.top_contexts

[(' leaves', 3.1683294773101807), (' piece', 2.8998165130615234), (' telling', 2.8359341621398926), (' total', 2.607029914855957), (' tells', 2.443530559539795), (' needing', 2.205601930618286), (' notice', 2.185490608215332), (' finish', 2.0870466232299805), ('accept', 1.889017105102539), (' goes', 1.8599400520324707), (' [...]', 1.8255202770233154), (' 1960', 1.7827792167663574), (' donors', 1.685542345046997), (' Lady', 1.6427953243255615), (' Although', 1.6094988584518433), (' needs', 1.589821219444275), ('ret', 1.516148567199707), (' donations', 1.4957566261291504), (' counting', 1.4677677154541016), (' impressed', 1.4677202701568604), (' jumping', 1.4470441341400146), (' Instead', 1.440384864807129), (' jumped', 1.4324177503585815), ('iza', 1.3016449213027954), (' knees', 1.225874900817871), (' career', 1.162060022354126), (' puzzle', 1.1596356630325317), (' participate', 1.0996900796890259), (' inspired', 0.989930272102356), (' mate', 0.9897094964981079)]
[(' her', 1.75376546382

In [18]:
to_ablate = {
    submodules[0] : [
        # 4074, # predicts code
        # 18775, # certain periods, unclear
        18967, # 'He'
        22084, # 'he'
        29626, # 'his'
        # 951, # 'with'
        1022, # 'she'
        # 1692, # 'or'
        2079, # 'Woman' or 'Ladies'
        # 2493, # 'nurse(s)'
        3122, # 'Her'
        # 5648, # 'care' in medical context
        # 5950, # 'to'
        # 9610, # certain periods, unclear
        9651, # female names
        10060, # 'She'
        # 11778, # 'phone(s)'
        # 13675, # certain verbs
        # 15032, # unclear
        # 23666, # 'and'
        # 24418, # unclear,
        26504, # 'her'
        # 26586, # unclear
        # 31201, # 'nursing'
    ],
    submodules[1] : [
        2995, # 'He'
        8920, # 'he'
        12128, # 'his'
        # 12436, # gendered pronouns?
        4592, # 'her'
        9877, # female names
        # 12882, # 'share'
        # 14918, # certain periods, unclear
        15017, # 'she'
        # 17369, # predicts phone numbers
        26204, # 'Her'
        # 26476, # certain periods, unclear
        # 26969, # related to nursing
        # 28145, # 'medical' 
        30248, # female names
    ],
    submodules[2] : [
        4433, # promotes male-associated words
        4539, # gendered pronouns
        8944, # capitalized gendered pronouns
        # 10700, # unclear
        11656, # promotes male-associated words
        # 14559, # periods in biographies, might promote male words a bit
        # 15225, # something about dates?
        # 15938, # unclear, might promote male words a bit
        # 19014, # certain periods, unclear
        # 27803, # unclear
        29206, # gendered pronouns
        30263, # gendered pronouns
        # 277, # spammy advertisements
        1995, # promotes female-associated words
        # 4770, # unclear
        9128, # female pronouns
        10635, # female-associated word
        # 10757, # unclear
        12440, # promotes female-associated words
        # 14638, # related to contact information?
        # 17774, # related to terms and conditions?
        # 21331, # certain periods, unclear
        # 26413, # fires on periods in ?blog posts/first-person writing?
        # 27838, # promotes parts of urls
        29295, # female names
        # 29371, # certain periods, unclear
        # 31098, # nursing-related words
    ],
    submodules[3] : [
        # 93, # periods in medical contexts
        # 2751, # words related to research
        # 13474, # active in bios of medical professionals
        # 15246, # promotes words about art
        # 24661, # periods in bios
        27334, # promotes male-associated words
        # 27867, # periods in bios, might promote male words a bit
        # 7539, # unclear 
        # 10295, # certain periods
        # 13542, # promotes advertising words
        # 14401, # unclear
        19558, # promotes female-associated words
        # 20526, # promotes medical words
        22152, # promotes female-associated words
        # 23375, # addresses
        23545, # 'she'
        # 24484, # promotes ?names of diseases?
        24806, # 'her'
        # 27051, # promotes capitalized words
        30802, # 'woman'/'women'
        # 31182, # contact information
        # 31751, # accommodation words
    ],
    submodules[4] : [
        # 4125, # periods in bios
        # 11987, # promotes verbs in bios
        # 12332, # promotes words related to academia
        # 14658, # periods
        30220, # promotes male pronouns
        # 1804, # words related to RSVPing
        # 4926, # certain periods
        # 5731, # promotes medicinal words
        # 6869, # promotes verbs about phone calls
        9766, # promotes female-associated words
        12420, # promotes female pronouns
        # 20708, # contact info
        # 21979, # promotes capitalized words
        23207, # promotes gendered words, especial female
        # 30612, # fundraising
        # 31282, # capitalized words about contacting
    ]
        
}

In [19]:
def run_with_ablations(
        model,
        inputs,
        submodules,
        dictionaries,
        to_ablate,
        inference=True,
):
    with model.invoke(inputs, fwd_args={'inference': inference}):
        for submodule in submodules:
            dictionary = dictionaries[submodule]
            x = submodule.output
            is_resid = (type(x.shape) == tuple)
            if is_resid:
                x = x[0]
            x_hat = dictionary(x)
            residual = x - x_hat

            f = dictionary.encode(x)
            ablation_idxs = t.Tensor(to_ablate[submodule]).long()
            f[:, :, ablation_idxs] = 0.
            x_hat = dictionary.decode(f)
            if is_resid:
                submodule.output[0][:] = x_hat + residual
            else:
                submodule.output = x_hat + residual
            
        acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
    return acts.value.clone()

In [20]:
def get_acts_abl(text):
    return run_with_ablations(
        model,
        text,
        submodules,
        dictionaries,
        to_ablate,
    )

In [21]:
print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=1))

Ambiguous test accuracy: 0.8661710023880005
Ground truth accuracy: 0.8254608511924744
Spurious accuracy: 0.5547235012054443


In [22]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9879666566848755
Accuracy for (0, 1): 0.9644010066986084
Accuracy for (1, 0): 0.559907853603363
Accuracy for (1, 1): 0.7453531622886658


In [23]:
new_probe, _ = train_probe(get_acts_abl, label_idx=0, lr=lr, epochs=epochs)
print('Ambiguous test accuracy:', test_probe(new_probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=1))

Ambiguous test accuracy: 0.9350603818893433
Ground truth accuracy: 0.89573734998703
Spurious accuracy: 0.5339861512184143


In [24]:
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9301449656486511
Accuracy for (0, 1): 0.8476905226707458
Accuracy for (1, 0): 0.8709677457809448
Accuracy for (1, 1): 0.9442378878593445


In [92]:
def out_fn(model):
    return probe(model.gpt_neox.layers[layer].output[0][:,-1,:])

# get accuracy after ablating above features
batches = get_data(train=False, ambiguous=True, batch_size=128, seed=SEED)

with t.no_grad():
    corrects = []
    for batch in batches:
        text = batch[0]
        labels = batch[1]
        probs = run_with_ablations(
            model,
            text,
            submodules,
            dictionaries,
            to_ablate,
            out_fn
        )
        preds = (probs > 0.5).long()
        corrects.append((preds == labels).float())
    print("Accuracy on ambiguous data:", t.cat(corrects).mean().item())

batches = get_data(train=False, ambiguous=False, batch_size=128, seed=SEED)

with t.no_grad():
    truth_corrects, spurious_corrects = [], []
    for batch in batches:
        text = batch[0]
        true_labels = batch[1]
        spurious_labels = batch[2]
        probs = run_with_ablations(
            model,
            text,
            submodules,
            dictionaries,
            to_ablate,
            out_fn
        )
        preds = (probs > 0.5).long()
        truth_corrects.append((preds == true_labels).float())
        spurious_corrects.append((preds == spurious_labels).float())
    print("Ground truth accuracy", t.cat(truth_corrects).mean().item())
    print("Spurious accuracy", t.cat(spurious_corrects).mean().item())

Accuracy on ambiguous data: 0.4993029832839966
Ground truth accuracy 0.4976958632469177
Spurious accuracy 0.5011520981788635


In [None]:
# retrain probe with ablated model

def out_fn(model):
    return model.gpt_neox.layers[layer].output[0][:,-1,:]

t.manual_seed(SEED)
new_probe = Probe(512).to('cuda:0')
optimizer = t.optim.AdamW(new_probe.parameters(), lr=lr)
criterion = nn.BCELoss()

batches = get_data(train=True, ambiguous=True, batch_size=64, seed=SEED)

losses = []
for epoch in range(epochs):
    for batch in batches:
        text = batch[0]
        labels = batch[1]
        acts = run_with_ablations(
            model,
            text,
            submodules,
            dictionaries,
            to_ablate,
            out_fn
        ).clone()
        probs = new_probe(acts)
        loss = criterion(probs, labels.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
