In [2]:
import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from datasets import load_dataset
import random
from nnsight import LanguageModel
import torch as t
from torch import nn
from attribution import patching_effect
from dictionary_learning import AutoEncoder, ActivationBuffer
from dictionary_learning.dictionary import IdentityDict
# from dictionary_learning.interp import examine_dimension
# no circuitsviz; no z standard library
# from dictionary_learning.utils import hf_dataset_to_generator
from tqdm import tqdm
import gc

# Flip to True to turn on the tracer's `scan` and `validate` options while debugging.
DEBUGGING = False

# Both options are always switched together; forwarded as **kwargs to every model.trace call.
tracer_kwargs = {'scan': bool(DEBUGGING), 'validate': bool(DEBUGGING)}

# model hyperparameters
# DEVICE = 'cuda:0'
DEVICE = "cpu"
activation_dim = 512
model = LanguageModel(
    'EleutherAI/pythia-70m-deduped',
    device_map=DEVICE,
    dispatch=True,
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# dataset hyperparameters
# Loaded once here; get_data / get_subgroups read dataset['train'] and dataset['test'].
dataset = load_dataset("LabHC/bias_in_bios")
# Integer ids compared against each row's 'profession' field when filtering.
profession_dict = {'professor' : 21, 'nurse' : 13}
# The two classes of the probing task: `male_prof` rows get true label 0 and
# `female_prof` rows get true label 1 (see get_data).
male_prof = 'professor'
female_prof = 'nurse'

# data preparation hyperparameters
# batch_size = 1024
batch_size = 8
SEED = 42

# NOTE: the default batch sizes are the module-level `batch_size` for get_data and 128 for
# get_subgroups (an earlier note here said both had been set to 64 to fit on a 24GB VRAM GPU).
def get_data(train=True, ambiguous=True, batch_size=batch_size, seed=SEED):
    """Build shuffled, group-balanced batches of biographies.

    A row is kept only if its 'profession' equals the id of `male_prof` (true label 0)
    or of `female_prof` (true label 1). The spurious label is the row's 'gender' (0 or 1).

    Args:
        train: read dataset['train'] if True, otherwise dataset['test'].
        ambiguous: if True keep only the groups where true and spurious label coincide,
            (0, 0) and (1, 1); if False keep all four (true, spurious) groups.
        batch_size: number of texts per batch (the last batch may be shorter).
        seed: seed of the private `random.Random` used for shuffling.

    Returns:
        List of (texts, true_labels, spurious_labels) tuples; `texts` is a list of
        'hard_text' strings and both label entries are tensors on DEVICE.
    """
    split = dataset['train'] if train else dataset['test']
    # Hoisted out of the scan: these lookups do not depend on the row.
    male_id = profession_dict[male_prof]
    female_id = profession_dict[female_prof]

    # One pass over the split, bucketing texts by (true label, spurious label).
    # Row order inside each bucket matches what a dedicated filtering pass would give.
    groups = {(0, 0): [], (0, 1): [], (1, 0): [], (1, 1): []}
    for x in split:
        profession = x['profession']
        if profession == male_id:
            true_label = 0
        elif profession == female_id:
            true_label = 1
        else:
            continue
        gender = x['gender']
        if gender == 0:
            spurious_label = 0
        elif gender == 1:
            spurious_label = 1
        else:
            continue
        groups[(true_label, spurious_label)].append(x['hard_text'])

    if ambiguous:
        profiles = [(0, 0), (1, 1)]
    else:
        profiles = [(0, 0), (0, 1), (1, 0), (1, 1)]

    # Truncate every group to the size of the smallest one so the groups are balanced.
    n = min(len(groups[profile]) for profile in profiles)
    data, true_labels, spurious_labels = [], [], []
    for profile in profiles:
        data += groups[profile][:n]
        true_labels += [profile[0]] * n
        spurious_labels += [profile[1]] * n

    # Shuffle texts and both label lists with one shared permutation.
    idxs = list(range(len(data)))
    random.Random(seed).shuffle(idxs)
    data = [data[i] for i in idxs]
    true_labels = [true_labels[i] for i in idxs]
    spurious_labels = [spurious_labels[i] for i in idxs]

    batches = [
        (
            data[i:i+batch_size],
            t.tensor(true_labels[i:i+batch_size], device=DEVICE),
            t.tensor(spurious_labels[i:i+batch_size], device=DEVICE),
        )
        for i in range(0, len(data), batch_size)
    ]

    return batches

def get_subgroups(train=True, ambiguous=True, batch_size=128, seed=SEED):
    """Batch the data separately for each (true label, spurious label) subgroup.

    Rows are selected as in get_data: true label 0/1 for rows whose 'profession' is the
    id of `male_prof` / `female_prof`, spurious label taken from 'gender'. Unlike
    get_data, subgroups are neither truncated to a common size nor shuffled.

    Args:
        train: read dataset['train'] if True, otherwise dataset['test'].
        ambiguous: if True return only the (0, 0) and (1, 1) subgroups, else all four.
        batch_size: number of texts per batch (the last batch may be shorter).
        seed: accepted for signature parity with get_data; not used here.

    Returns:
        Dict mapping each (true, spurious) label profile to a list of
        (texts, true_labels, spurious_labels) batches, with label tensors on DEVICE.
    """
    split = dataset['train'] if train else dataset['test']
    # Hoisted out of the scan: these lookups do not depend on the row.
    male_id = profession_dict[male_prof]
    female_id = profession_dict[female_prof]

    if ambiguous:
        profiles = [(0, 0), (1, 1)]
    else:
        profiles = [(0, 0), (0, 1), (1, 0), (1, 1)]

    # One pass over the split instead of one filtering pass per subgroup.
    texts_by_profile = {profile: [] for profile in profiles}
    for x in split:
        profession = x['profession']
        if profession == male_id:
            true_label = 0
        elif profession == female_id:
            true_label = 1
        else:
            continue
        gender = x['gender']
        if gender == 0:
            spurious_label = 0
        elif gender == 1:
            spurious_label = 1
        else:
            continue
        bucket = texts_by_profile.get((true_label, spurious_label))
        if bucket is not None:
            bucket.append(x['hard_text'])

    out = {}
    for profile in profiles:
        texts = texts_by_profile[profile]
        out[profile] = []
        for i in range(0, len(texts), batch_size):
            text = texts[i:i+batch_size]
            out[profile].append(
                (
                    text,
                    t.tensor([profile[0]]*len(text), device=DEVICE),
                    t.tensor([profile[1]]*len(text), device=DEVICE)
                )
            )
    return out

In [8]:
# probe training hyperparameters
from torch import nn
layer = 4 # model layer for attaching linear classification head

class Probe(nn.Module):
    """Linear probe: a single affine map from `activation_dim` features to one logit."""

    def __init__(self, activation_dim):
        super().__init__()
        # Attribute name `net` is part of the saved state-dict keys; keep it.
        self.net = nn.Linear(in_features=activation_dim, out_features=1, bias=True)

    def forward(self, x):
        """Return one logit per input row, with the trailing size-1 dimension removed."""
        return self.net(x).squeeze(-1)

def train_probe(get_acts, label_idx=0, batches=None, lr=1e-2, epochs=1, dim=512, seed=SEED):
    """Train a linear `Probe` on activations with binary cross-entropy.

    Args:
        get_acts: callable mapping a list of texts to a (batch, dim) activation tensor.
        label_idx: which label to fit; 0 selects batch[1] and 1 selects batch[2].
        batches: iterable of (texts, labels_0, labels_1) tuples. When None, `get_data()`
            is called at call time. Previously it was evaluated as a default argument,
            i.e. a full dataset scan every time this `def` was executed.
        lr: AdamW learning rate.
        epochs: number of passes over `batches`.
        dim: input dimension of the probe.
        seed: torch seed set before the probe is initialised.

    Returns:
        (probe, losses): the trained probe and the per-step loss values.
    """
    if batches is None:
        batches = get_data()

    t.manual_seed(seed)
    probe = Probe(dim).to(DEVICE)
    print('probe')
    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    losses = []
    for epoch in range(epochs):
        for batch in tqdm(batches):
            text = batch[0]
            labels = batch[label_idx+1]
            acts = get_acts(text)
            logits = probe(acts)
            # BCEWithLogitsLoss needs float targets; the labels are integer tensors.
            loss = criterion(logits, labels.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

    return probe, losses

def test_probe(probe, get_acts, label_idx=0, batches=None, seed=SEED):
    """Return the accuracy of `probe` over `batches`.

    Args:
        probe: module mapping activations to logits; a logit > 0 is predicted as class 1.
        get_acts: callable mapping a list of texts to an activation tensor.
        label_idx: which label to score against; 0 selects batch[1], 1 selects batch[2].
        batches: iterable of (texts, labels_0, labels_1) tuples. When None,
            `get_data(train=False)` is called at call time. Previously it was evaluated
            as a default argument whenever this `def` was executed.
        seed: unused; kept so existing call sites keep working.

    Returns:
        Mean accuracy over all examples as a Python float.
    """
    if batches is None:
        batches = get_data(train=False)

    with t.no_grad():
        corrects = []

        for batch in batches:
            text = batch[0]
            labels = batch[label_idx+1]
            acts = get_acts(text)
            logits = probe(acts)
            preds = (logits > 0.0).long()
            corrects.append((preds == labels).float())
        # Concatenate per-example hits so a short final batch is weighted correctly.
        return t.cat(corrects).mean().item()
    
def get_acts(text):
    """Return the output of layer `layer`, averaged over positions with attention mask 1.

    Runs `model` on `text` inside a trace without gradients and returns the saved value
    (presumably of shape (batch, hidden) — TODO confirm).
    """
    with t.no_grad(), model.trace(text, **tracer_kwargs):
        mask = model.inputs[1]['attention_mask']
        hidden = model.gpt_neox.layers[layer].output[0]
        # Zero out masked positions, then divide the sum by the number of unmasked tokens.
        masked_sum = (hidden * mask[:, :, None]).sum(1)
        pooled = (masked_sum / mask.sum(1)[:, None]).save()
    return pooled.value

In [4]:
batches=get_data(ambiguous=False)
# (text, profession, gender)

In [5]:
# Quick look at the structure returned by get_data: each batch is a 3-tuple
# (texts, true_labels, spurious_labels), i.e. (bios, profession labels, gender labels).
# NOTE(review): the message printed on the second line abbreviates this as "(text, gender, )".
print('batches: ', len(batches))
print('each batch: ', len(batches[0][0]), 'data points in tuples of (text, gender, )')
print(batches[0][0][3])
print(f'batch 1 sample: {batches[0][0]} \n {batches[0][1]} \n {batches[0][2]}')
print(f'batch dtype: {len(batches[0])}, {type(batches[0])}, {type(batches[0][1])} \n str : {type(batches[0][0])}')

batches:  564
each batch:  8 data points in tuples of (text, gender, )
After graduating from the University of San Francisco she went on to explore different areas of nursing including, orthopedic trauma/surgery, emergency room, critical care and pediatric intensive care unit. She is also a Public Health Nurse with an expertise in working with multicultural patients as they intersect within a healthcare setting.
batch 1 sample: ['She graduated with honors in 2006. Having more than 11 years of diverse experiences, especially in NURSE PRACTITIONER, Sharon R Hoskin affiliates with many hospitals including Ut Southwestern University Hospital St Paul, Ut Southwestern University Hospital-zale Lipshy, and cooperates with other doctors and specialists in medical group The University Of Texas Southwestern Medical Center At Dallas. Call Sharon R Hoskin on phone number (214) 648-1454 for more information and advises or to book an appointment.', 'Chirag Patel’s long-term research goal is to addres

In [None]:
# oracle, _ = train_probe(get_acts, label_idx=0, batches=get_data(ambiguous=False))
# print("ambiguous test accuracy", test_probe(oracle, get_acts, label_idx=0))

In [26]:
## Train the oracle probe on the balanced (ambiguous=False) training set, fitting label 0.
oracle, _ = train_probe(get_acts, label_idx=0, batches=get_data(ambiguous=False))
# No `batches` argument here, so test_probe falls back to its default test batches.
print("ambiguous test accuracy", test_probe(oracle, get_acts, label_idx=0))
## Balanced test set (all four profession x gender groups).
batches = get_data(train=False, ambiguous=False)
## Accuracy against label 0 (profession).
print("ground truth accuracy:", test_probe(oracle, get_acts, batches=batches, label_idx=0))
## Accuracy against label 1 (gender).
print("unintended feature accuracy:", test_probe(oracle, get_acts, batches=batches, label_idx=1))

probe


100%|██████████| 564/564 [01:03<00:00,  8.83it/s]


ambiguous test accuracy 0.921352207660675
ground truth accuracy: 0.9320276379585266
unintended feature accuracy: 0.4873271882534027


In [None]:
# Oracle probe accuracy on every (true label, spurious label) subgroup of the test set;
# the worst-group accuracy is the smallest of the printed values.
# Note: the loop rebinds the notebook-level name `batches`.
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(oracle, get_acts, batches=batches, label_idx=0))

In [9]:

probe, _ = train_probe(get_acts, label_idx=0)

probe


100%|██████████| 2798/2798 [05:07<00:00,  9.09it/s]


In [25]:
# Save only the state dict (model parameters)
t.save(probe.state_dict(), '../exp_redo/24_10_probe_model_state.pth')

In [10]:
probe ## 512

Probe(
  (net): Linear(in_features=512, out_features=1, bias=True)
)

In [11]:
print('Ambiguous test accuracy:', test_probe(probe, get_acts, label_idx=0))

Ambiguous test accuracy: 0.9907063245773315


In [23]:
## Evaluate `probe` on the balanced test set (all four profession x gender groups).
batches = get_data(train=False, ambiguous=False)
# label_idx=0 scores against the profession labels, label_idx=1 against the gender labels.
print('Ground truth accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=0))
print('Unintended feature accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=1))

Ground truth accuracy: 0.6059907674789429
Unintended feature accuracy: 0.8790322542190552


In [None]:
## Train on the ambiguous data (train_probe's default batches), where profession and
## gender labels coincide, so the probe can rely on either signal.
probe, _ = train_probe(get_acts, label_idx=0)
# Default test batches: the ambiguous test split.
print('Ambiguous test accuracy:', test_probe(probe, get_acts, label_idx=0))
## Balanced test set, where profession and gender are decoupled.
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=0))
print('Unintended feature accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=1))

In [None]:
# Accuracy of the probe trained on ambiguous data, per (true, spurious) test subgroup.
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts, batches=batches, label_idx=0))

In [17]:
# loading dictionaries

# dictionary hyperparameters
dict_id = 10
expansion_factor = 64
dictionary_size = expansion_factor * activation_dim
DEVICE = 'cpu'
path = "../dictionary_learning"
layer = 4  # dictionaries are loaded for transformer layers 0..layer inclusive

# `submodules` fixes the component order (embedding first, then attention / MLP / layer
# output for each layer); `dictionaries` maps each of those submodules to its autoencoder.
submodules = [model.gpt_neox.embed_in]
dictionaries = {
    model.gpt_neox.embed_in: AutoEncoder.from_pretrained(
        f'{path}/dictionaries/pythia-70m-deduped/embed/{dict_id}_{dictionary_size}/ae.pt',
        device=DEVICE
    )
}
for i in range(layer + 1):
    block = model.gpt_neox.layers[i]
    # (submodule, directory prefix) pairs, in the order they must appear in `submodules`.
    for sub, kind in ((block.attention, 'attn_out'), (block.mlp, 'mlp_out'), (block, 'resid_out')):
        submodules.append(sub)
        dictionaries[sub] = AutoEncoder.from_pretrained(
            f'{path}/dictionaries/pythia-70m-deduped/{kind}_layer{i}/{dict_id}_{dictionary_size}/ae.pt',
            device=DEVICE
        )

def metric_fn(model, labels=None):
    """Signed probe logit used as the attribution metric.

    Must be called inside a trace of `model`. Mean-pools the output of layer `layer` over
    positions with attention mask 1, feeds it to the global `probe`, and returns the logit
    where `labels == 0` and the negated logit elsewhere.

    Args:
        model: the traced language model.
        labels: per-example label tensor; callers pass it via `metric_kwargs`.
            NOTE(review): the None default does not look usable as-is — confirm.
    """
    attn_mask = model.inputs[1]['attention_mask']
    acts = model.gpt_neox.layers[layer].output[0]
    acts = acts * attn_mask[:, :, None]
    acts = acts.sum(1) / attn_mask.sum(1)[:, None]

    # Run the probe once and reuse the result for both branches of `where`.
    logits = probe(acts)
    return t.where(
        labels == 0,
        logits,
        - logits
    )

# NOTE: a middle layer of the model is used for training the classifier probe.
# Open question: is the middle layer chosen because that is where higher-level features are expected to be learnt?

  state_dict = t.load(path, map_location=t.device('cpu'))


In [18]:
submodules

[Embedding(50304, 512),
 GPTNeoXSdpaAttention(
   (rotary_emb): GPTNeoXRotaryEmbedding()
   (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
   (dense): Linear(in_features=512, out_features=512, bias=True)
   (attention_dropout): Dropout(p=0.0, inplace=False)
 ),
 GPTNeoXMLP(
   (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
   (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
   (act): GELUActivation()
 ),
 GPTNeoXLayer(
   (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
   (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
   (post_attention_dropout): Dropout(p=0.0, inplace=False)
   (post_mlp_dropout): Dropout(p=0.0, inplace=False)
   (attention): GPTNeoXSdpaAttention(
     (rotary_emb): GPTNeoXRotaryEmbedding()
     (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
     (dense): Linear(in_features=512, out_features=512, bias=True)
     

In [20]:
len(dictionaries)

16

In [21]:
dictionaries.keys()

dict_keys([Embedding(50304, 512), GPTNeoXSdpaAttention(
  (rotary_emb): GPTNeoXRotaryEmbedding()
  (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
  (dense): Linear(in_features=512, out_features=512, bias=True)
  (attention_dropout): Dropout(p=0.0, inplace=False)
), GPTNeoXMLP(
  (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
  (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
  (act): GELUActivation()
), GPTNeoXLayer(
  (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (post_attention_dropout): Dropout(p=0.0, inplace=False)
  (post_mlp_dropout): Dropout(p=0.0, inplace=False)
  (attention): GPTNeoXSdpaAttention(
    (rotary_emb): GPTNeoXRotaryEmbedding()
    (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
    (dense): Linear(in_features=512, out_features=512, bias=True)
    (attention_

In [13]:
data = get_data(train=True, ambiguous=True, batch_size=batch_size, seed=SEED)
len(data) 

2798

In [15]:
data[1] # (text, profession, gender)

(['She graduated with honors from University Of Mississippi School Of Medicine in 2011. Having more than 5 years of diverse experiences, especially in NURSE PRACTITIONER, Jennifer S Stewart affiliates with Highland Community Hospital, and cooperates with other doctors and specialists in medical group Hattiesburg Clinic Pa. Call Jennifer S Stewart on phone number (601) 268-5888 for more information and advises or to book an appointment.',
  'She graduated with honors in 2009. Having more than 7 years of diverse experiences, especially in NURSE PRACTITIONER, Tammy M Milholen affiliates with many hospitals including Decatur County General Hospital, Jackson-madison County General Hospital, Henderson County Community Hospital, and cooperates with other doctors and specialists in medical group Christian Family Medicine Inc. Call Tammy M Milholen on phone number (731) 847-6010 for more information and advises or to book an appointment.',
  'His work seeks to understand how individuals, across

In [27]:
# find most influential features
# n_batches = 25
n_batches = 2
# Rebinds the notebook-level `batch_size`; it is passed explicitly to get_data below
# (get_data's own default was bound when get_data was defined).
batch_size = 4

running_total = 0
nodes = None

# Iterate over ambiguous training batches; each one is (texts, true labels, spurious labels)
# and only the first two entries are used here.

for batch_idx, (clean, labels, _) in tqdm(enumerate(get_data(train=True, ambiguous=True, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break
    print(batch_idx, clean, labels)

    # Estimate the effect of each dictionary feature on metric_fn using integrated
    # gradients (method='ig'). Inputs: the texts, the model, the submodules with their
    # dictionaries, and the metric function. Only the first returned value is kept.

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )

    with t.no_grad():
        # Accumulate a batch-size-weighted sum so that dividing by running_total below
        # gives a per-example average. Presumably dim 1 is the token position and dim 0
        # the batch — TODO confirm against patching_effect.
        if nodes is None:
            nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)
    # Drop the per-batch results before the next iteration to limit memory use.
    del effects, _
    gc.collect()

## Average effect per example for every component.
nodes = {k : v / running_total for k, v in nodes.items()}

  0%|          | 0/2 [00:00<?, ?it/s]

0 ['She remains actively involved with the Burn Treatment Center after working there for nearly 18 years. Throughout these years she was a Nursing Assistant, Staff Nurse, Assistant Nurse Manager, and most recently the Nurse Manager. She is a member of the American Burn Association, the American Association of Critical Care Nurses, and the Health Information Management Systems Society. She has been co-director of Miracle Burn Camp since 2001 & assisted in creating the Young Adult Burn Survivor retreat in 2007. Alison is the mother of two and stays active as a board member for the local youth hockey association and keeping up with her boys. To Contact Alison, please: *protected email*', 'She graduated with honors in 1999. Having more than 17 years of diverse experiences, especially in NURSE PRACTITIONER, Shannon R Russom affiliates with University Of Vermont Medical Center, and cooperates with other doctors and specialists in medical group University Of Vermont Medical Center Inc. Call S




AttributeError: Above exception when execution Node: 'fetch_attr_1' in Graph: '139893558527760'

In [None]:
# List every feature whose averaged effect exceeds 0.1, component by component,
# and count them. Components are numbered in the iteration order of `nodes`.
n_features = 0
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    for idx in (effect > 0.1).nonzero():
        print(idx.item(), effect[idx].item())
        n_features += 1
print(f"total features: {n_features}")

#### To note:
- How are they attributing it back to the dataset?

In [None]:
# interpret features
from dictionary_learning.interp import examine_dimension
# no circuitsviz; no z standard library
from dictionary_learning.utils import hf_dataset_to_generator

# change the following two lines to pick which feature to interpret
# component_idx indexes into `submodules`; feat_idx is the feature index in that
# component's dictionary.
component_idx = 9
feat_idx = 31098

submodule = submodules[component_idx]
dictionary = dictionaries[submodule]

# interpret some features
# Note: this rebinds the notebook-level name `data` to a text generator.
data = hf_dataset_to_generator("monology/pile-uncopyrighted")
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    d_submodule=512,
    refresh_batch_size=128, # decrease to fit on smaller GPUs
    n_ctxs=512, # decrease to fit on smaller GPUs
    device=DEVICE
)

out = examine_dimension(
    model,
    submodule,
    buffer,
    dictionary,
    dim_idx=feat_idx,
    n_inputs=256 # decrease to fit on smaller GPUs
)
# The last expression is displayed by the notebook as the cell output.
print(out.top_tokens)
print(out.top_affected)
out.top_contexts

In [None]:
feats_to_ablate = {
    submodules[0] : [
        946, # 'his'
        # 5719, # 'research'
        7392, # 'He'
        # 10784, # 'Nursing'
        17846, # 'He'
        22068, # 'His'
        # 23079, # 'tastes'
        # 25904, # 'nursing'
        28533, # 'She'
        29476, # 'he'
        31461, # 'His'
        31467, # 'she'
        32081, # 'her'
        32469, # 'She'
    ],
    submodules[1] : [
        # 23752, # capitalized words, especially pronouns
    ],
    submodules[2] : [
        2995, # 'he'
        3842, # 'She'
        10258, # female names
        13387, # 'she'
        13968, # 'He'
        18382, # 'her'
        19369, # 'His'
        28127, # 'She'
        30518, # 'He'
    ],
    submodules[3] : [
        1022, # 'she'
        9651, # female names
        10060, # 'She'
        18967, # 'He'
        22084, # 'he'
        23898, # 'His'
        # 24799, # promotes surnames
        26504, # 'her'
        29626, # 'his'
        # 31201, # 'nursing'
    ],
    submodules[4] : [
        # 8147, # unclear, something with names
    ],
    submodules[5] : [
        24159, # 'She', 'she'
        25018, # female names
    ],
    submodules[6] : [
        4592, # 'her'
        8920, # 'he'
        9877, # female names
        12128, # 'his'
        15017, # 'she'
        # 17369, # contact info
        # 26969, # related to nursing
        30248, # female names
    ],
    submodules[7] : [
        13570, # promotes male-related words
        27472, # female names, promotes female-related words
    ],
    submodules[8] : [
    ],
    submodules[9] : [
        1995, # promotes female-associated words
        9128, # feminine pronouns
        11656, # promotes male-associated words
        12440, # promotes female-associated words
        # 14638, # related to contact information?
        29206, # gendered pronouns
        29295, # female names
        # 31098, # nursing-related words
    ],
    submodules[10] : [
        2959, # promotes female-associated words
        19128, # promotes male-associated words
        22029, # promotes female-associated words
    ],
    submodules[11] : [
    ],
    submodules[12] : [
        19558, # promotes female-associated words
        23545, # 'she'
        24806, # 'her'
        27334, # promotes male-associated words
        31453, # female names
    ],
    submodules[13] : [
        31101, # promotes female-associated words
    ],
    submodules[14] : [
    ],
    submodules[15] : [
        9766, # promotes female-associated words
        12420, # promotes female pronouns
        30220, # promotes male pronouns
    ]
}

print(f"Number of features to ablate: {sum(len(v) for v in feats_to_ablate.values())}")

In [None]:
# putting feats_to_ablate in a more useful format
def n_hot(feats, dim=dictionary_size):
    """Return a boolean mask of length `dim` on DEVICE that is True at every index in `feats`."""
    mask = t.zeros(dim, dtype=t.bool, device=DEVICE)
    for index in feats:
        mask[index] = True
    return mask

# Convert each index list into a boolean mask. Entries that are already tensors are kept
# as they are, so re-running this cell is harmless: feeding a mask back into n_hot would
# iterate over its booleans and `out[True] = True` would switch on every feature.
feats_to_ablate = {
    submodule : feats if isinstance(feats, t.Tensor) else n_hot(feats)
    for submodule, feats in feats_to_ablate.items()
}


In [None]:
# utilities for ablating features
# For every submodule, record whether `.output.shape` is a plain tuple; the ablation
# helpers use this flag to decide whether to read and write `output[0]` instead of `output`.
# NOTE(review): the exact `type(...) == tuple` test looks deliberate. torch.Size is a
# tuple subclass, so isinstance() would presumably also be True for a single-tensor
# output — confirm before "modernising" this check.
is_tuple = {}
with t.no_grad(), model.trace("_"):
    for submodule in submodules:
        is_tuple[submodule] = type(submodule.output.shape) == tuple

def get_acts_ablated(
    text,
    model,
    submodules,
    dictionaries,
    to_ablate
):
    """Mean-pooled layer-`layer` activations with selected dictionary features zeroed.

    For each submodule the output is encoded with its dictionary, the features selected
    by `to_ablate[submodule]` are set to zero, and the output is replaced by the decoded
    features plus the original reconstruction error.

    Args:
        text: input text(s) passed to `model.trace`.
        model: the traced language model.
        submodules: submodules to intervene on, in order.
        dictionaries: mapping submodule -> dictionary.
        to_ablate: mapping submodule -> index (boolean mask) of features to zero.

    Returns:
        The saved value of the activations averaged over positions with attention mask 1.
    """

    with t.no_grad(), model.trace(text, **tracer_kwargs):
        for submodule in submodules:
            dictionary = dictionaries[submodule]
            feat_idxs = to_ablate[submodule]
            x = submodule.output
            if is_tuple[submodule]:
                x = x[0]
            x_hat, f = dictionary(x, output_features=True)
            # Reconstruction error; added back so only the ablated features change.
            res = x - x_hat
            f[...,feat_idxs] = 0. # zero ablation
            if is_tuple[submodule]:
                submodule.output[0][:] = dictionary.decode(f) + res
            else:
                submodule.output = dictionary.decode(f) + res
        # `model.inputs` (not `model.input`) to match get_acts and metric_fn, the accessors
        # used by the cells that ran successfully in this notebook.
        attn_mask = model.inputs[1]['attention_mask']
        act = model.gpt_neox.layers[layer].output[0]
        act = act * attn_mask[:, :, None]
        act = act.sum(1) / attn_mask.sum(1)[:, None]
        act = act.save()
    return act.value


# Accuracy after ablating features judged irrelevant by human annotators

In [None]:
def get_acts_abl(text):
    """Activations for `text` with the hand-picked features in `feats_to_ablate` zeroed."""
    return get_acts_ablated(text, model, submodules, dictionaries, feats_to_ablate)

# Ambiguous test split first (test_probe's default batches), then the balanced split.
print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
for message, which_label in (('Ground truth accuracy:', 0), ('Spurious accuracy:', 1)):
    print(message, test_probe(probe, get_acts_abl, batches=batches, label_idx=which_label))

In [None]:
# Per-subgroup accuracy of `probe` when activations come from `get_acts_abl`.
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))

# Concept bottleneck probing baseline

In [None]:
# Concept words for the bottleneck baseline. Each entry starts with a space.
concepts = [    
    ' nurse',
    ' healthcare',
    ' hospital',
    ' patient',
    ' medical',
    ' clinic',
    ' triage',
    ' medication',
    ' emergency',
    ' surgery',
    ' professor',
    ' academia',
    ' research',
    ' university',
    ' tenure',
    ' faculty',
    ' dissertation',
    ' sabbatical',
    ' publication',
    ' grant',
]
# get concept vectors
# One vector per concept: the layer-`layer` output at the last sequence position.
# NOTE(review): assumes the last position is a real token for every concept (depends on
# the tokenizer's padding side) — confirm.
with t.no_grad(), model.trace(concepts):
    concept_vectors = model.gpt_neox.layers[layer].output[0][:, -1, :].save()
# Centre the vectors by subtracting their mean across concepts.
concept_vectors = concept_vectors.value - concept_vectors.value.mean(0, keepdim=True)

def get_bottleneck(text):
    """Cosine similarity between mean-pooled activations and every concept vector.

    The layer-`layer` output is averaged over positions with attention mask 1 and then
    compared with each row of the global `concept_vectors`.

    Returns:
        The saved similarity tensor (presumably (batch, n_concepts) — TODO confirm).
    """
    with t.no_grad(), model.trace(text, **tracer_kwargs):
        # `model.inputs` (not `model.input`) to match get_acts and metric_fn, the accessors
        # used by the cells that ran successfully in this notebook.
        attn_mask = model.inputs[1]['attention_mask']
        acts = model.gpt_neox.layers[layer].output[0]
        acts = acts * attn_mask[:, :, None]
        acts = acts.sum(1) / attn_mask.sum(1)[:, None]
        # compute cosine similarity with concept vectors
        # The denominator is the outer product of activation norms and concept norms.
        sims = (acts @ concept_vectors.T) / (acts.norm(dim=-1)[:, None] @ concept_vectors.norm(dim=-1)[None])
        sims = sims.save()
    return sims.value

In [None]:
# Train a probe on the concept-similarity features (one input dimension per concept)
# using train_probe's default batches, then score it on the balanced test set.
cbp_probe, _ = train_probe(get_bottleneck, label_idx=0, dim=len(concepts))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=0))
print('Unintended feature accuracy:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=1))


In [None]:
# Per-subgroup accuracy of the concept-bottleneck probe on the test set.
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=0))

# Get skyline neuron performance

In [None]:
# get neurons which are most influential for giving gender label
# An IdentityDict per submodule stands in for the learned dictionaries, so the
# attribution is presumably computed over raw neurons instead of dictionary features.
neuron_dicts = {
    submodule : IdentityDict(activation_dim).to(DEVICE) for submodule in submodules
}

n_batches = 25
batch_size = 4

running_total = 0
nodes = None

# Balanced training data; `labels` is the third tuple entry, i.e. the spurious (gender) labels.
for batch_idx, (clean, _, labels) in tqdm(enumerate(get_data(train=True, ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        neuron_dicts,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    with t.no_grad():
        # Batch-size-weighted running sum; normalised by running_total after the loop.
        if nodes is None:
            nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)
    # Free the per-batch results before the next iteration.
    del effects, _
    gc.collect()

# Average effect per example for every component (overwrites the earlier `nodes`).
nodes = {k : v / running_total for k, v in nodes.items()}

In [None]:
# Collect, per component, the indices of neurons whose averaged effect exceeds 0.2135.
# NOTE(review): pairing by position assumes `nodes` iterates in the same order as
# `submodules` — confirm against patching_effect.
neurons_to_ablate = {}
total_neurons = 0
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    neurons_to_ablate[submodules[component_idx]] = []
    for idx in (effect.act > 0.2135).nonzero():
        print(idx.item(), effect[idx].item())
        neurons_to_ablate[submodules[component_idx]].append(idx.item())
        total_neurons += 1
print(f"total neurons: {total_neurons}")

# Turn each index list into a boolean mask over the `activation_dim` neurons. The list is
# passed to n_hot directly; wrapping it in another list only worked through list-valued
# index assignment.
neurons_to_ablate = {
    submodule : n_hot(neuron_idxs, dim=activation_dim) for submodule, neuron_idxs in neurons_to_ablate.items()
}

In [None]:
def get_acts_abl(text):
    """Mean-pooled layer-`layer` activations with the neurons in `neurons_to_ablate` mean-ablated.

    For every submodule, the selected neurons of its output are overwritten with their
    mean over dims 0 and 1 of the current input, and the modified output is written back.
    """
    with t.no_grad(), model.trace(text, **tracer_kwargs):
        for submodule in submodules:
            x = submodule.output
            if is_tuple[submodule]:
                x = x[0]
            x[...,neurons_to_ablate[submodule]] = x.mean(dim=(0,1))[...,neurons_to_ablate[submodule]] # mean ablation
            if is_tuple[submodule]:
                submodule.output[0][:] = x
            else:
                submodule.output = x

        # `model.inputs` (not `model.input`) to match get_acts and metric_fn, the accessors
        # used by the cells that ran successfully in this notebook.
        attn_mask = model.inputs[1]['attention_mask']
        act = model.gpt_neox.layers[layer].output[0]
        act = act * attn_mask[:, :, None]
        act = act.sum(1) / attn_mask.sum(1)[:, None]
        act = act.save()
    return act.value

In [None]:
# Evaluate `probe` with activations from `get_acts_abl`: first on test_probe's default
# (ambiguous) test batches, then on the balanced test set against both labels.
print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=1))

In [None]:
# Per-subgroup accuracy of `probe` with activations from `get_acts_abl`.
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))

# Get skyline feature performance

In [None]:
# get features which are most useful for predicting gender label
n_batches = 25
batch_size = 4

running_total = 0
running_nodes = None

# Balanced training data; `labels` is the third tuple entry, i.e. the spurious (gender) labels.
for batch_idx, (clean, _, labels) in tqdm(enumerate(get_data(train=True, ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    with t.no_grad():
        # Batch-size-weighted running sum; normalised by running_total after the loop.
        if running_nodes is None:
            running_nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                running_nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)
    # Free the per-batch results before the next iteration.
    del effects, _
    gc.collect()

# Average effect per example for every component (overwrites the earlier `nodes`).
nodes = {k : v / running_total for k, v in running_nodes.items()}

In [None]:
# Collect, per component, the indices of features whose averaged effect exceeds 0.1107.
# NOTE(review): pairing by position assumes `nodes` iterates in the same order as
# `submodules` — confirm against patching_effect.
top_feats_to_ablate = {}
total_features = 0
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    top_feats_to_ablate[submodules[component_idx]] = []
    for idx in (effect > 0.1107).nonzero():
        print(idx.item(), effect[idx].item())
        top_feats_to_ablate[submodules[component_idx]].append(idx.item())
        total_features += 1
print(f"total features: {total_features}")

In [None]:
# Convert each index list into a boolean mask. Entries that are already tensors are kept
# as they are, so re-running this cell is harmless: feeding a mask back into n_hot would
# iterate over its booleans and `out[True] = True` would switch on every feature.
top_feats_to_ablate = {
    submodule : feats if isinstance(feats, t.Tensor) else n_hot(feats)
    for submodule, feats in top_feats_to_ablate.items()
}
# Rebinds `get_acts_abl` to ablate the automatically selected top features.
get_acts_abl = lambda text : get_acts_ablated(text, model, submodules, dictionaries, top_feats_to_ablate)

In [None]:
# Evaluate `probe` with activations from `get_acts_abl`: first on test_probe's default
# (ambiguous) test batches, then on the balanced test set against both labels.
print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=1))

In [None]:
# Per-subgroup accuracy of `probe` with activations from `get_acts_abl`.
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))

# Retraining probe on activations after ablating features

In [None]:
def get_acts_abl(text):
    """Activations for `text` with the features in `feats_to_ablate` zeroed."""
    return get_acts_ablated(text, model, submodules, dictionaries, feats_to_ablate)

# Retrain from scratch on the ablated activations (train_probe's default batches).
new_probe, _ = train_probe(get_acts_abl, label_idx=0)
print('Ambiguous test accuracy:', test_probe(new_probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
for message, which_label in (('Ground truth accuracy:', 0), ('Unintended feature accuracy:', 1)):
    print(message, test_probe(new_probe, get_acts_abl, batches=batches, label_idx=which_label))

In [None]:
# Per-subgroup accuracy of the retrained probe on ablated activations.
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=0))