In [1]:
import sys
sys.path.insert(0, '../')

from datasets import load_dataset
import random
from nnsight import LanguageModel
import torch as t
from torch import nn
import matplotlib.pyplot as plt
from attribution import patching_effect
from dictionary_learning import AutoEncoder, ActivationBuffer
from dictionary_learning.interp import examine_dimension
from dictionary_learning.utils import zst_to_generator

  from .autonotebook import tqdm as notebook_tqdm
Matplotlib created a temporary cache directory at /tmp/matplotlib-a87gknl9 because the default path (/share/u/smarks/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [22]:
# Experiment configuration.
SEED = 42
DEVICE = 'cuda:0'
layer = 4  # index into model.gpt_neox.layers; the probes read this layer's last-token output
batch_size = 1024  # NOTE(review): not referenced by the helpers in this notebook (they default to 128)

# Bias in Bios, restricted to two professions (values are the dataset's profession ids).
dataset = load_dataset("LabHC/bias_in_bios")
male_prof, female_prof = 'professor', 'nurse'
profession_dict = {male_prof: 21, female_prof: 13}

model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE)

def get_data(train=True, ambiguous=True, batch_size=128, seed=SEED):
    """Build shuffled, batched (texts, true_labels, spurious_labels) examples from Bias in Bios.

    Only rows whose profession is `male_prof` (true label 0) or `female_prof`
    (true label 1) are used; the spurious label is the row's `gender` field.

    Args:
        train: use the 'train' split if True, otherwise the 'test' split.
        ambiguous: if True, keep only (male_prof, gender 0) and (female_prof, gender 1),
            so the true and spurious labels always agree. If False, keep all four
            (profession, gender) combinations so the two labels are decorrelated.
        batch_size: examples per returned batch (the last batch may be smaller).
        seed: seed for the deterministic shuffle of examples.

    Returns:
        A list of (texts, true_labels, spurious_labels) tuples. `texts` is a list of
        strings, and both label entries are integer tensors on DEVICE. Every included
        (true, spurious) profile is truncated to the size of the smallest one.
    """
    data = dataset['train'] if train else dataset['test']

    # Bucket texts by (true label, gender) in ONE pass over the split. Rows keep their
    # dataset order, so the truncation and seeded shuffle below select exactly the
    # same examples as scanning the split once per group would.
    male_id, female_id = profession_dict[male_prof], profession_dict[female_prof]
    groups = {(label, gender): [] for label in (0, 1) for gender in (0, 1)}
    for x in data:
        for label, prof_id in ((0, male_id), (1, female_id)):
            if x['profession'] == prof_id and (label, x['gender']) in groups:
                groups[(label, x['gender'])].append(x['hard_text'])

    # (true label, spurious label) profiles to include, in concatenation order.
    profiles = [(0, 0), (1, 1)] if ambiguous else [(0, 0), (0, 1), (1, 0), (1, 1)]
    # Balance the profiles by truncating each to the size of the smallest one.
    n = min(len(groups[p]) for p in profiles)
    texts = [text for p in profiles for text in groups[p][:n]]
    true_labels = [p[0] for p in profiles for _ in range(n)]
    spurious_labels = [p[1] for p in profiles for _ in range(n)]

    # One seeded permutation of indices, applied to texts and both label lists.
    idxs = list(range(len(texts)))
    random.Random(seed).shuffle(idxs)
    texts = [texts[i] for i in idxs]
    true_labels = [true_labels[i] for i in idxs]
    spurious_labels = [spurious_labels[i] for i in idxs]

    return [
        (
            texts[i:i + batch_size],
            t.tensor(true_labels[i:i + batch_size], device=DEVICE),
            t.tensor(spurious_labels[i:i + batch_size], device=DEVICE),
        )
        for i in range(0, len(texts), batch_size)
    ]

def get_subgroups(train=True, ambiguous=True, batch_size=128, seed=SEED):
    """Batch Bias in Bios examples separately for each (true label, spurious label) subgroup.

    True label 0 is `male_prof` and 1 is `female_prof`; the spurious label is the
    row's `gender` field.

    Args:
        train: use the 'train' split if True, otherwise the 'test' split.
        ambiguous: if True, return only the (0, 0) and (1, 1) subgroups; otherwise
            return all four combinations.
        batch_size: examples per batch (the last batch of a subgroup may be smaller).
        seed: accepted for signature parity with get_data; unused here because
            subgroups are neither shuffled nor balanced.

    Returns:
        A dict mapping (true_label, spurious_label) to a list of
        (texts, true_labels, spurious_labels) batches. The label tensors are constant
        within a subgroup and live on DEVICE.
    """
    data = dataset['train'] if train else dataset['test']

    # Bucket texts by (true label, gender) in a single pass, preserving dataset order.
    male_id, female_id = profession_dict[male_prof], profession_dict[female_prof]
    groups = {(label, gender): [] for label in (0, 1) for gender in (0, 1)}
    for x in data:
        for label, prof_id in ((0, male_id), (1, female_id)):
            if x['profession'] == prof_id and (label, x['gender']) in groups:
                groups[(label, x['gender'])].append(x['hard_text'])

    # Insertion order of the returned dict follows this list.
    profiles = [(0, 0), (1, 1)] if ambiguous else [(0, 0), (0, 1), (1, 0), (1, 1)]

    out = {}
    for label_profile in profiles:
        texts = groups[label_profile]
        out[label_profile] = [
            (
                chunk,
                t.tensor([label_profile[0]] * len(chunk), device=DEVICE),
                t.tensor([label_profile[1]] * len(chunk), device=DEVICE),
            )
            for chunk in (texts[i:i + batch_size] for i in range(0, len(texts), batch_size))
        ]
    return out


        

# def get_batched_data(model, layer, data, true_labels, spurious_labels, batch_size=128):
#     batches = []

#     for i in range(0, len(data), batch_size):
#         text_batch = data[i:i+batch_size]
#         with model.invoke(text_batch):
#             acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
#         acts = acts.value.clone()
#         true_labels_batch = t.tensor(true_labels[i:i+batch_size]).to(acts.device)
#         spurious_labels_batch = t.tensor(spurious_labels[i:i+batch_size]).to(acts.device)
#         batches.append((acts, true_labels_batch, spurious_labels_batch))

#     return batches

# batches = {
#     'train' : {
#         'ambiguous' : get_batched_data(model, layer, *get_data(train=True, ambiguous=True)), 
#         'unambiguous' : get_batched_data(model, layer, *get_data(train=True, ambiguous=False))
#     },
#     'test' : {
#         'ambiguous' : get_batched_data(model, layer, *get_data(train=False, ambiguous=True)), 
#         'unambiguous' : get_batched_data(model, layer, *get_data(train=False, ambiguous=False))
#     }
# }

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Probe-training hyperparameters, passed to the train_probe calls below.
lr, epochs = 1e-2, 1

class Probe(nn.Module):
    """Linear probe: one affine map from `activation_dim` inputs to a scalar,
    passed through a sigmoid to give a per-example probability."""

    def __init__(self, activation_dim):
        super(Probe, self).__init__()
        # The attribute name `net` fixes the state_dict keys ('net.weight', 'net.bias').
        self.net = nn.Linear(in_features=activation_dim, out_features=1, bias=True)

    def forward(self, x):
        # (..., activation_dim) -> (...): drop the trailing size-1 output dim, then sigmoid.
        return self.net(x).squeeze(-1).sigmoid()

In [4]:
def train_probe(get_acts, label_idx=0, batches=None, lr=1e-2, epochs=1, seed=SEED, activation_dim=512):
    """Train a logistic-regression probe on activations returned by `get_acts`.

    Args:
        get_acts: callable mapping a list of texts to an activation tensor that is fed
            directly to the probe; presumably shaped (batch, activation_dim) — TODO confirm
            for each extractor passed in.
        label_idx: which label to fit: 0 for the true label (batch[1]), 1 for the
            spurious label (batch[2]).
        batches: list of (texts, true_labels, spurious_labels) tuples as produced by
            get_data. Defaults to get_data() (ambiguous train split), built at call time.
        lr: AdamW learning rate.
        epochs: number of passes over `batches`.
        seed: torch seed set immediately before the probe is initialised.
        activation_dim: input width of the probe (default 512, the previous hard-coded value).

    Returns:
        (probe, losses): the trained Probe and the list of per-step BCE losses.
    """
    if batches is None:
        # Built per call. The former def-time default `batches=get_data()` ran a full
        # dataset pass as soon as the cell executed and froze that result.
        batches = get_data()

    t.manual_seed(seed)
    probe = Probe(activation_dim).to(DEVICE)
    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCELoss()  # Probe.forward already applies the sigmoid

    losses = []
    for _ in range(epochs):
        for batch in batches:
            text = batch[0]
            labels = batch[label_idx + 1]
            acts = get_acts(text)
            probs = probe(acts)
            loss = criterion(probs, labels.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

    return probe, losses

In [5]:
def test_probe(probe, get_acts, label_idx=0, batches=None, seed=SEED):
    """Return the accuracy of `probe` on `batches`, thresholding its output at 0.5.

    Args:
        probe: a Probe (or any callable returning per-example probabilities).
        get_acts: callable mapping a list of texts to the probe's input activations.
        label_idx: label to score against: 0 for the true label (batch[1]), 1 for the
            spurious label (batch[2]).
        batches: list of (texts, true_labels, spurious_labels) tuples. Defaults to
            get_data(train=False, seed=seed) (ambiguous test split), built at call time.
        seed: shuffle seed for the default batches (previously accepted but ignored).

    Returns:
        Fraction of examples, over all batches, whose prediction equals the chosen label.

    Raises:
        ValueError: if `batches` is empty.
    """
    if batches is None:
        batches = get_data(train=False, seed=seed)

    with t.no_grad():
        corrects = []
        for batch in batches:
            text = batch[0]
            labels = batch[label_idx + 1]
            acts = get_acts(text)
            probs = probe(acts)
            preds = (probs > 0.5).long()
            corrects.append((preds == labels).float())
        if not corrects:
            # t.cat([]) would otherwise fail with an unrelated-looking error.
            raise ValueError('test_probe: no batches to evaluate')
        # Concatenate per-example results so the mean is over examples, not batches.
        return t.cat(corrects).mean().item()

In [6]:
def get_acts(text):
    """Return the last-token output of layer `layer` for a batch of texts.

    The value saved during the trace is cloned before being returned.
    """
    with model.invoke(text):
        last_token_out = model.gpt_neox.layers[layer].output[0][:, -1, :].save()
    return last_token_out.value.clone()

In [8]:
# One probe per label on the balanced (unambiguous) train split:
# probe0 is fit to the true label, probe1 to the spurious (gender) label.
probe0, probe1 = (
    train_probe(get_acts, label_idx=label_idx, batches=get_data(ambiguous=False), lr=lr, epochs=epochs)[0]
    for label_idx in (0, 1)
)

In [11]:
# Each probe is scored against the label it was trained on, on the balanced test split.
batches = get_data(train=False, ambiguous=False)
for label_idx, trained_probe in enumerate((probe0, probe1)):
    print(f'Probe {label_idx} accuracy:', test_probe(trained_probe, get_acts, batches=batches, label_idx=label_idx))

Probe 0 accuracy: 0.900921642780304
Probe 1 accuracy: 0.9688940048217773


In [7]:
# Train on the ambiguous split, where profession and gender always agree.
probe, _ = train_probe(get_acts, label_idx=0, lr=lr, epochs=epochs)
print('Ambiguous test accuracy:', test_probe(probe, get_acts, label_idx=0))
# On the balanced test split, measure agreement with the true and the spurious label.
batches = get_data(train=False, ambiguous=False)
for name, label_idx in (('Ground truth', 0), ('Spurious', 1)):
    print(f'{name} accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=label_idx))

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Ambiguous test accuracy: 0.9915195107460022
Ground truth accuracy: 0.6261520981788635
Spurious accuracy: 0.8680875301361084


In [23]:
# Accuracy of `probe` against the true label within each (true, spurious) subgroup
# of the balanced test split.
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile in subgroups:
    batches = subgroups[label_profile]
    print(f'Accuracy for {label_profile}:',
          test_probe(probe, get_acts, batches=batches, label_idx=0))

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Accuracy for (0, 0): 0.9941375851631165
Accuracy for (0, 1): 0.13586179912090302
Accuracy for (1, 0): 0.38479262590408325
Accuracy for (1, 1): 0.9881505370140076


In [32]:
# Layers 0..layer, each paired with its pretrained "resid_out" sparse autoencoder.
submodules = [model.gpt_neox.layers[i] for i in range(layer + 1)]
dictionaries = {}
for i, submodule in enumerate(submodules):
    ae = AutoEncoder(512, 64 * 512).to(DEVICE)
    ae.load_state_dict(t.load(f'/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped/resid_out_layer{i}/5_32768/ae.pt'))
    dictionaries[submodule] = ae

def metric_fn(model):
    """Probe output on the last-token output of layer `layer`.

    Reads `.output` of a layer, so presumably it is only meaningful while tracing `model`.
    """
    last_token_out = model.gpt_neox.layers[layer].output[0][:, -1, :]
    return probe(last_token_out)

N_ATTRIB_INPUTS = 16  # examples per class used for attribution


def _mean_feature_effects(inputs):
    """Per-submodule feature effects on `metric_fn` for `inputs`, computed with
    patching_effect(method='ig'), summed over dim 1 and averaged over dim 0.

    NOTE(review): presumably dim 0 = example and dim 1 = token position — confirm
    against patching_effect's output layout.
    """
    effects = patching_effect(
        inputs,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        method='ig',
    ).effects
    return {k: v.sum(dim=1).mean(dim=0) for k, v in effects.items()}


# First N_ATTRIB_INPUTS train examples of (male_prof, gender 0) and (female_prof, gender 1).
neg_inputs, pos_inputs = [], []
for x in dataset['train']:
    if x['profession'] == profession_dict[male_prof] and x['gender'] == 0 and len(neg_inputs) < N_ATTRIB_INPUTS:
        neg_inputs.append(x['hard_text'])
    if x['profession'] == profession_dict[female_prof] and x['gender'] == 1 and len(pos_inputs) < N_ATTRIB_INPUTS:
        pos_inputs.append(x['hard_text'])
    # Stop once both lists are full instead of scanning the rest of the split.
    if len(neg_inputs) >= N_ATTRIB_INPUTS and len(pos_inputs) >= N_ATTRIB_INPUTS:
        break

neg_effects = _mean_feature_effects(neg_inputs)
pos_effects = _mean_feature_effects(pos_inputs)

In [10]:
threshold = 0.005

# For every layer, list features whose mean effect exceeds `threshold` on the
# negative inputs, then those whose mean effect is below -threshold on the positive inputs.
for effects_by_submodule, sign in ((neg_effects, 1), (pos_effects, -1)):
    for i, submodule in enumerate(submodules):
        print(f"Layer {i}:")
        effect = effects_by_submodule[submodule]
        for feature_idx in t.nonzero(effect):
            value = effect[tuple(feature_idx)]
            if sign * value > threshold:
                print(f"    Multindex: {tuple(feature_idx.tolist())}, Value: {value}")

Layer 0:
    Multindex: (10313,), Value: 0.04134603962302208
    Multindex: (17329,), Value: 0.009056453593075275
    Multindex: (23553,), Value: 0.0062515512108802795
    Multindex: (25794,), Value: 0.010537972673773766
Layer 1:
    Multindex: (11727,), Value: 0.014927267096936703
    Multindex: (12955,), Value: 0.006697795353829861
    Multindex: (16731,), Value: 0.007774203550070524
    Multindex: (18680,), Value: 0.006383746396750212
Layer 2:
    Multindex: (2859,), Value: 0.00724233640357852
    Multindex: (5547,), Value: 0.014884809032082558
    Multindex: (12633,), Value: 0.00643932493403554
    Multindex: (17324,), Value: 0.009294000454246998
Layer 3:
    Multindex: (10539,), Value: 0.006092030089348555
    Multindex: (16734,), Value: 0.006312336772680283
    Multindex: (26887,), Value: 0.00739505747333169
    Multindex: (29985,), Value: 0.006092030089348555
Layer 4:
    Multindex: (8126,), Value: 0.016388770192861557
    Multindex: (23954,), Value: 0.008023547008633614
    Mul

In [51]:
# Inspect one dictionary feature on Pile text: print `top_tokens` and `top_affected`,
# and display `top_contexts` as the cell output.
component_idx, feat_idx = 4, 7536
submodule = submodules[component_idx]
dictionary = dictionaries[submodule]

# NOTE(review): absolute cluster path; consider a configurable data directory.
data = zst_to_generator('/share/data/datasets/pile/the-eye.eu/public/AI/pile/train/00.jsonl.zst')
buffer = ActivationBuffer(data, model, submodule, out_feats=512, in_batch_size=128, n_ctxs=512)

out = examine_dimension(model, submodule, buffer, dictionary, dim_idx=feat_idx, n_inputs=256)
print(out.top_tokens)
print(out.top_affected)
out.top_contexts

[(' →', 3.498818874359131), (')?', 0.7700076103210449), ('!', 0.6087144613265991), ('%).', 0.6013126373291016), ('.', 0.5607401132583618), (').', 0.19541782140731812), ('.', 0.16909795999526978), ('?', 0.15509594976902008), ('com', 0.13155770301818848), ('…', 0.05967291444540024), ('...', 0.0465102344751358), (' -', 0.02337704598903656), (']', 0.022350026294589043), (' available', 0.02159261703491211), ('References', 0.002975702518597245), (':', 0.0013361757155507803), (',', 0.0005869610467925668), ('It', 0.0), (' is', 0.0), (' done', 0.0), (' and', 0.0), (' submitted', 0.0), (' You', 0.0), (' can', 0.0), (' play', 0.0), (' “', 0.0), ('Sur', 0.0), ('vival', 0.0), (' of', 0.0), (' the', 0.0)]
[('UPDATE', 1.3970803022384644), ('#####', 1.3345152139663696), ('fefefe', 1.3260353803634644), ('Docket', 1.3101134300231934), ('EOF', 1.300095558166504), ('txt', 1.2410948276519775), ('<?', 1.2343525886535645), ('"""', 1.2244322299957275), ('################################', 1.2139382362365723),

In [52]:
# Dictionary feature indices to zero out, keyed by submodules[layer_idx].
# Active entries are ablated; commented-out entries are candidates that are NOT ablated.
# The trailing comments are informal interpretations of each feature.
to_ablate = {
    submodules[layer_idx]: feature_idxs
    for layer_idx, feature_idxs in enumerate([
        [  # layer 0
            10313, # He
            17329, # he
            #23553, # unclear
            25794, # his
            2910, # she but in very few cases?
            7187, # Her + His on two examples?
            # 11379, # or
            # 11674, # hospital
            # 11909, # call
            # 13051, # unclear
            # 13094, # share
            15628, # female names
            # 17078, # primary care or health care
            # 22846, # unclear
            # 29183, # certain verbs
            # 30927, # unclear
            31251, # her
            # 32356, # f
        ],
        [  # layer 1
            11727, # He
            # 12955, # unclear
            16731, # his
            # 18680, # certain /'s
            2578, # female names
            20964, # certain instances of her
            22287, # she
        ],
        [  # layer 2
            # 2859, # periods which end sentences
            5547, # He
            # 12633, # unclear
            # 17324, # periods in medical contexts
        ],
        [  # layer 3 (nothing ablated)
            # 10539, # periods in biographies
            # 16734, # unclear
            # 26887, # fires a lot in the biography of a particular male professor, unclear
            # 29985, # periods in biographies
            # 3216, # certain periods
            # 31219, # certain periods
        ],
        [  # layer 4
            # 8126, # periods in biographies
            # 23954, # certain periods
            # 30226, # same as 26887 above
            # 7536, # certain periods
            25160, # unclear, but predicts female-associated words
        ],
    ])
}

In [53]:
def run_with_ablations(
        model,
        inputs,
        submodules,
        dictionaries,
        to_ablate,
        inference=True,
):
    """Run `model` on `inputs` with selected dictionary features zeroed, and return
    a clone of the last-token output of layer `layer`.

    For each submodule, its output x is replaced by
        decode(f_ablated) + (x - dictionary(x)),
    where f_ablated is encode(x) with the features listed in to_ablate[submodule]
    set to 0. Adding back the reconstruction error (x - dictionary(x)) means only
    the chosen features change.

    Args:
        model: nnsight LanguageModel to trace.
        inputs: batch of texts.
        submodules: modules to intervene on, in order.
        dictionaries: maps each submodule to its autoencoder (encode/decode/forward).
        to_ablate: maps a submodule to a list of feature indices to zero. Submodules
            missing from the mapping are left unablated (previously a KeyError).
        inference: forwarded to model.invoke via fwd_args={'inference': ...}.
            NOTE: this is the 6th positional parameter; do not pass an output
            function here.

    Returns:
        Tensor with the saved last-token output of layer `layer`.
    """
    with model.invoke(inputs, fwd_args={'inference': inference}):
        for submodule in submodules:
            dictionary = dictionaries[submodule]
            x = submodule.output
            # Tuple-valued outputs (as reported by the traced `.shape`) carry the hidden
            # state in element 0. NOTE(review): relies on nnsight's proxy `.shape` behaviour.
            is_resid = (type(x.shape) == tuple)
            if is_resid:
                x = x[0]
            x_hat = dictionary(x)
            residual = x - x_hat

            f = dictionary.encode(x)
            # Integer index tensor built directly (no float round-trip); empty -> no-op.
            ablation_idxs = t.tensor(to_ablate.get(submodule, []), dtype=t.long)
            f[:, :, ablation_idxs] = 0.
            x_hat = dictionary.decode(f)
            if is_resid:
                submodule.output[0][:] = x_hat + residual
            else:
                submodule.output = x_hat + residual

        acts = model.gpt_neox.layers[layer].output[0][:, -1, :].save()
    return acts.value.clone()

In [54]:
def get_acts_abl(text):
    """Activations for a batch of texts computed by run_with_ablations, with the
    features in the global `to_ablate` zeroed across all `submodules`."""
    return run_with_ablations(model, text, submodules, dictionaries, to_ablate)

In [55]:
# Retrain on the ambiguous split with the `to_ablate` features removed, then evaluate.
new_probe, _ = train_probe(get_acts_abl, label_idx=0, lr=lr, epochs=epochs)
print('Ambiguous test accuracy:', test_probe(new_probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
for name, label_idx in (('Ground truth', 0), ('Spurious', 1)):
    print(f'{name} accuracy:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=label_idx))

Ambiguous test accuracy: 0.9613150358200073
Ground truth accuracy: 0.856566846370697
Spurious accuracy: 0.6111751198768616


In [56]:
# Per-subgroup accuracy (against the true label) of the probe retrained on ablated activations.
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile in subgroups:
    batches = subgroups[label_profile]
    print(f'Accuracy for {label_profile}:',
          test_probe(new_probe, get_acts_abl, batches=batches, label_idx=0))

In [92]:
def out_fn(model):
    """Return `probe` applied to the last-token output of layer `layer`.

    Reads a layer's `.output`, so presumably it is only meaningful while tracing `model`.
    NOTE(review): `out_fn` is defined again later in this notebook with a different
    return value (no probe applied); whichever cell ran last wins.
    """
    last_token_out = model.gpt_neox.layers[layer].output[0][:, -1, :]
    return probe(last_token_out)

# Accuracy of `probe` (trained on un-ablated activations) when the `to_ablate`
# features are zeroed. run_with_ablations returns the last-token activations of
# layer `layer`, so the probe is applied here. Its 6th positional parameter is
# `inference`, so an output function must not be passed there: doing so returned raw
# activations, and `(probs > 0.5) == labels` then compared (B, 512) against (B,).
batches = get_data(train=False, ambiguous=True, batch_size=128, seed=SEED)

with t.no_grad():
    corrects = []
    for batch in batches:
        text = batch[0]
        labels = batch[1]
        acts = run_with_ablations(model, text, submodules, dictionaries, to_ablate)
        probs = probe(acts)
        preds = (probs > 0.5).long()
        corrects.append((preds == labels).float())
    print("Accuracy on ambiguous data:", t.cat(corrects).mean().item())

batches = get_data(train=False, ambiguous=False, batch_size=128, seed=SEED)

with t.no_grad():
    truth_corrects, spurious_corrects = [], []
    for batch in batches:
        text = batch[0]
        true_labels = batch[1]
        spurious_labels = batch[2]
        acts = run_with_ablations(model, text, submodules, dictionaries, to_ablate)
        probs = probe(acts)
        preds = (probs > 0.5).long()
        truth_corrects.append((preds == true_labels).float())
        spurious_corrects.append((preds == spurious_labels).float())
    print("Ground truth accuracy", t.cat(truth_corrects).mean().item())
    print("Spurious accuracy", t.cat(spurious_corrects).mean().item())

Accuracy on ambiguous data: 0.4993029832839966
Ground truth accuracy 0.4976958632469177
Spurious accuracy 0.5011520981788635


In [None]:
# retrain probe with ablated model

def out_fn(model):
    """Return the last-token output of layer `layer`, with no probe applied.

    NOTE(review): shadows an earlier `out_fn` of the same name that applied `probe`.
    """
    layer_out = model.gpt_neox.layers[layer].output
    return layer_out[0][:, -1, :]

# Same procedure as train_probe, but with batch size 64 on the ambiguous train split.
t.manual_seed(SEED)
new_probe = Probe(512).to(DEVICE)
optimizer = t.optim.AdamW(new_probe.parameters(), lr=lr)
criterion = nn.BCELoss()  # Probe.forward already applies the sigmoid

batches = get_data(train=True, ambiguous=True, batch_size=64, seed=SEED)

losses = []
for epoch in range(epochs):
    for batch in batches:
        text = batch[0]
        labels = batch[1]
        # run_with_ablations already returns a cloned last-token activation tensor.
        # Do not pass an output function as its 6th positional argument: that slot
        # is `inference`.
        acts = run_with_ablations(model, text, submodules, dictionaries, to_ablate)
        probs = new_probe(acts)
        loss = criterion(probs, labels.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
