In [1]:
import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from datasets import load_dataset
import random
from nnsight import LanguageModel
import torch as t
from torch import nn
import matplotlib.pyplot as plt
from attribution import patching_effect
from dictionary_learning import AutoEncoder, ActivationBuffer
from dictionary_learning.dictionary import IdentityDict
from dictionary_learning.interp import examine_dimension
from dictionary_learning.utils import hf_dataset_to_generator
from tqdm import tqdm
from activation_utils import SparseAct
import gc

# Single switch: nnsight's `scan` and `validate` options are turned on together,
# and only while debugging.
DEBUGGING = False

tracer_kwargs = {'scan': bool(DEBUGGING), 'validate': bool(DEBUGGING)}

  from .autonotebook import tqdm as notebook_tqdm
Matplotlib created a temporary cache directory at /tmp/matplotlib-oygbwsli because the default path (/share/u/smarks/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [2]:
# dataset hyperparameters
dataset = load_dataset("LabHC/bias_in_bios")
# Integer ids that get_data/get_subgroups compare against each row's 'profession' field.
profession_dict = {'professor' : 21, 'nurse' : 13}
# `male_prof` rows receive true label 0 and `female_prof` rows true label 1 (see get_data).
male_prof = 'professor'
female_prof = 'nurse'

# model hyperparameters
DEVICE = 'cuda:0'
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE, dispatch=True)
activation_dim = 512
# Index of the transformer block whose output is mean-pooled and fed to the probes.
layer = 4

# dictionary hyperparameters
# `dict_id` and `dictionary_size` together form the checkpoint directory name
# '{dict_id}_{dictionary_size}' used when the dictionaries are loaded.
dict_id = 10
expansion_factor = 64
dictionary_size = expansion_factor * activation_dim

# data preparation hyperparameters
# NOTE(review): this value is never read in this notebook -- get_data/get_subgroups
# default to batch_size=128 and the attribution cells rebind batch_size to 4.
batch_size = 1024
SEED = 42

def get_data(train=True, ambiguous=True, batch_size=128, seed=SEED):
    """Build shuffled, group-balanced batches of biographies.

    Each example carries two binary labels: the true label (0 for `male_prof`,
    1 for `female_prof`) and the spurious label (the row's 'gender' value).

    Args:
        train: use the 'train' split if True, otherwise the 'test' split.
        ambiguous: if True keep only the groups where both labels agree,
            (0, 0) and (1, 1); otherwise keep all four label combinations.
        batch_size: number of texts per batch (the last batch may be smaller).
        seed: seed for the deterministic shuffle.

    Returns:
        A list of `(texts, true_labels, spurious_labels)` tuples, where `texts`
        is a list of strings and both label entries are tensors on `DEVICE`.
        Every kept group is truncated to the size of the smallest one.

    Assumes `male_prof` and `female_prof` name two different professions.
    """
    split = dataset['train'] if train else dataset['test']
    true_label_of = {
        profession_dict[male_prof]: 0,
        profession_dict[female_prof]: 1,
    }
    # (true_label, spurious_label) profiles, in the order they are concatenated.
    if ambiguous:
        profiles = [(0, 0), (1, 1)]
    else:
        profiles = [(0, 0), (0, 1), (1, 0), (1, 1)]

    # One pass over the split instead of one full scan per group; rows keep
    # their dataset order inside each group.
    texts_by_profile = {profile: [] for profile in profiles}
    for x in split:
        profile = (true_label_of.get(x['profession']), x['gender'])
        if profile in texts_by_profile:
            texts_by_profile[profile].append(x['hard_text'])

    # Balance the groups by truncating each to the smallest group's size.
    n = min(len(texts_by_profile[profile]) for profile in profiles)
    data, true_labels, spurious_labels = [], [], []
    for true_label, spurious_label in profiles:
        data.extend(texts_by_profile[(true_label, spurious_label)][:n])
        true_labels.extend([true_label] * n)
        spurious_labels.extend([spurious_label] * n)

    # Shuffle texts and both label lists with one shared permutation.
    idxs = list(range(len(data)))
    random.Random(seed).shuffle(idxs)
    data = [data[i] for i in idxs]
    true_labels = [true_labels[i] for i in idxs]
    spurious_labels = [spurious_labels[i] for i in idxs]

    batches = [
        (
            data[i:i+batch_size],
            t.tensor(true_labels[i:i+batch_size], device=DEVICE),
            t.tensor(spurious_labels[i:i+batch_size], device=DEVICE),
        )
        for i in range(0, len(data), batch_size)
    ]

    return batches

def get_subgroups(train=True, ambiguous=True, batch_size=128, seed=SEED):
    """Batch the biographies separately for each (true, spurious) label profile.

    Unlike `get_data`, groups are neither truncated to a common size nor
    shuffled; rows keep their dataset order.

    Args:
        train: use the 'train' split if True, otherwise the 'test' split.
        ambiguous: if True return only the (0, 0) and (1, 1) profiles,
            otherwise all four label combinations.
        batch_size: number of texts per batch (the last batch may be smaller).
        seed: unused; kept so the signature mirrors `get_data`.

    Returns:
        A dict mapping each `(true_label, spurious_label)` profile to a list of
        `(texts, true_labels, spurious_labels)` batches with tensors on `DEVICE`.

    Assumes `male_prof` and `female_prof` name two different professions.
    """
    split = dataset['train'] if train else dataset['test']
    true_label_of = {
        profession_dict[male_prof]: 0,
        profession_dict[female_prof]: 1,
    }
    if ambiguous:
        profiles = [(0, 0), (1, 1)]
    else:
        profiles = [(0, 0), (0, 1), (1, 0), (1, 1)]

    # One pass over the split instead of one full scan per profile.
    texts_by_profile = {profile: [] for profile in profiles}
    for x in split:
        profile = (true_label_of.get(x['profession']), x['gender'])
        if profile in texts_by_profile:
            texts_by_profile[profile].append(x['hard_text'])

    out = {}
    for label_profile in profiles:
        data = texts_by_profile[label_profile]
        out[label_profile] = []
        for i in range(0, len(data), batch_size):
            text = data[i:i+batch_size]
            out[label_profile].append(
                (
                    text,
                    t.tensor([label_profile[0]]*len(text), device=DEVICE),
                    t.tensor([label_profile[1]]*len(text), device=DEVICE)
                )
            )
    return out

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# probe training hyperparameters
class Probe(nn.Module):
    """Linear probe: a single affine unit that emits one logit per input vector."""

    def __init__(self, activation_dim):
        super().__init__()
        self.net = nn.Linear(in_features=activation_dim, out_features=1, bias=True)

    def forward(self, x):
        # The linear layer leaves a trailing size-1 axis; drop it so the result
        # has one logit per input row.
        return self.net(x).squeeze(dim=-1)

def train_probe(get_acts, label_idx=0, batches=None, lr=1e-2, epochs=1, dim=512, seed=SEED):
    """Train a linear `Probe` with binary cross-entropy on pooled activations.

    Args:
        get_acts: callable mapping a list of texts to a tensor of features,
            one row of width `dim` per text.
        label_idx: which label of each batch to fit -- 0 selects `batch[1]`,
            1 selects `batch[2]`.
        batches: iterable of `(texts, labels_0, labels_1)` tuples. When None,
            `get_data()` is called. The default is resolved lazily so that
            defining this function does not itself build the dataset.
        lr: AdamW learning rate.
        epochs: number of passes over `batches`.
        dim: input width of the probe.
        seed: torch seed set before the probe is initialised.

    Returns:
        `(probe, losses)`: the trained probe and the per-step loss values.
    """
    if batches is None:
        batches = get_data()

    t.manual_seed(seed)
    # Use the notebook-wide device rather than a second hard-coded 'cuda:0'.
    probe = Probe(dim).to(DEVICE)
    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    losses = []
    for epoch in range(epochs):
        for batch in batches:
            text = batch[0]
            labels = batch[label_idx+1]
            acts = get_acts(text)
            logits = probe(acts)
            # BCEWithLogitsLoss needs float targets; labels are integer tensors.
            loss = criterion(logits, labels.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

    return probe, losses

def test_probe(probe, get_acts, label_idx=0, batches=None, seed=SEED):
    """Return the accuracy of `probe` over `batches`.

    Args:
        probe: module mapping features to one logit per example.
        get_acts: callable mapping a list of texts to the probe's input features.
        label_idx: which label to score against -- 0 selects `batch[1]`,
            1 selects `batch[2]`.
        batches: iterable of `(texts, labels_0, labels_1)` tuples. When None,
            `get_data(train=False)` is called. The default is resolved lazily
            so that defining this function does not itself build the dataset.
        seed: unused; kept for interface compatibility.

    Returns:
        Mean accuracy over all examples as a Python float. A logit above 0 is
        counted as a prediction of class 1.
    """
    if batches is None:
        batches = get_data(train=False)

    with t.no_grad():
        corrects = []

        for batch in batches:
            text = batch[0]
            labels = batch[label_idx+1]
            acts = get_acts(text)
            logits = probe(acts)
            preds = (logits > 0.0).long()
            corrects.append((preds == labels).float())
        # Concatenate per-example correctness so a smaller final batch is
        # weighted by its actual size.
        return t.cat(corrects).mean().item()
    
def get_acts(text):
    """Mean-pool the output of block `layer` over the tokens selected by the attention mask."""
    with t.no_grad(), model.trace(text, **tracer_kwargs):
        mask = model.input[1]['attention_mask']
        hidden = model.gpt_neox.layers[layer].output[0]
        # Zero out masked positions, then divide by the number of unmasked tokens.
        summed = (hidden * mask[:, :, None]).sum(1)
        pooled = (summed / mask.sum(1)[:, None]).save()
    return pooled.value

In [7]:
# Oracle baseline: the probe is trained on the four-group (ambiguous=False) training
# batches, where the true and spurious labels are not tied together.
oracle, _ = train_probe(get_acts, label_idx=0, batches=get_data(ambiguous=False))
# No `batches` argument: test_probe uses its default, get_data(train=False).
print("ambiguous test accuracy", test_probe(oracle, get_acts, label_idx=0))
# Same test batches scored against the true label (label_idx=0) and the spurious label (label_idx=1).
batches = get_data(train=False, ambiguous=False)
print("ground truth accuracy:", test_probe(oracle, get_acts, batches=batches, label_idx=0))
print("unintended feature accuracy:", test_probe(oracle, get_acts, batches=batches, label_idx=1))

ambiguous test accuracy 0.9318076372146606
Oracle probe
ground truth accuracy: 0.9302995204925537
unintended feature accuracy: 0.49366357922554016


In [23]:
# Per-subgroup true-label accuracy of the oracle probe on the test split; the
# worst-group accuracy is the smallest of the printed values.
subgroups = get_subgroups(train=False, ambiguous=False)
# Keys are (true_label, spurious_label) profiles. NOTE: the loop rebinds the global `batches`.
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(oracle, get_acts, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9468682408332825
Accuracy for (0, 1): 0.926398754119873
Accuracy for (1, 0): 0.9239631295204163
Accuracy for (1, 1): 0.9189126491546631


In [4]:
# Baseline probe: trained on train_probe's default batches, get_data(), i.e. the
# ambiguous training set where the true and spurious labels coincide.
# NOTE: `probe` is read as a global by metric_fn and by the later ablation cells.
probe, _ = train_probe(get_acts, label_idx=0)
print('Ambiguous test accuracy:', test_probe(probe, get_acts, label_idx=0))
# Four-group test batches scored against the true label (0) and the spurious label (1).
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=0))
print('Unintended feature accuracy:', test_probe(probe, get_acts, batches=batches, label_idx=1))

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Ambiguous test accuracy: 0.9955855011940002
Ground truth accuracy: 0.6186636090278625
Unintended feature accuracy: 0.8744239807128906


In [25]:
# Per-subgroup true-label accuracy of the baseline probe on the test split.
subgroups = get_subgroups(train=False, ambiguous=False)
# Keys are (true_label, spurious_label) profiles. NOTE: the loop rebinds the global `batches`.
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9978401064872742
Accuracy for (0, 1): 0.24355988204479218
Accuracy for (1, 0): 0.2626728117465973
Accuracy for (1, 1): 0.9934943914413452


In [5]:
def _load_dictionary(dict_name):
    """Load the autoencoder checkpoint stored under `dict_name` and return it on DEVICE."""
    ae = AutoEncoder(activation_dim, dictionary_size).to(DEVICE)
    path = f'../dictionaries/pythia-70m-deduped/{dict_name}/{dict_id}_{dictionary_size}/ae.pt'
    # map_location keeps loading independent of the device the checkpoint was saved from.
    ae.load_state_dict(t.load(path, map_location=DEVICE))
    return ae


def _add_submodule(submodule, dict_name):
    """Append `submodule` to `submodules` and register its dictionary."""
    submodules.append(submodule)
    dictionaries[submodule] = _load_dictionary(dict_name)


# `submodules` fixes the component order (embedding, then attention / MLP / block
# output for each layer 0..layer); later cells index it by position.
submodules = []
dictionaries = {}

_add_submodule(model.gpt_neox.embed_in, 'embed')

for i in range(layer + 1):
    _add_submodule(model.gpt_neox.layers[i].attention, f'attn_out_layer{i}')
    _add_submodule(model.gpt_neox.layers[i].mlp, f'mlp_out_layer{i}')
    _add_submodule(model.gpt_neox.layers[i], f'resid_out_layer{i}')

def metric_fn(model, labels=None):
    """Signed logit of the global `probe` on mean-pooled block-`layer` activations.

    Args:
        model: the traced language model; its attention mask and the output of
            block `layer` are read from it.
        labels: per-example binary label tensor. Required in practice: the
            default of None is not handled.

    Returns:
        One value per example: the probe logit where `labels == 0`, and the
        negated probe logit elsewhere.
    """
    attn_mask = model.input[1]['attention_mask']
    acts = model.gpt_neox.layers[layer].output[0]
    acts = acts * attn_mask[:, :, None]
    acts = acts.sum(1) / attn_mask.sum(1)[:, None]

    # Run the probe once and reuse the logits for both branches.
    logits = probe(acts)
    return t.where(
        labels == 0,
        logits,
        - logits
    )

In [6]:
# Attribute the probe metric to dictionary features, averaged over the first
# `n_batches` batches of the ambiguous training data.
n_batches = 25
# NOTE: rebinds the global `batch_size` set in the hyperparameter cell.
batch_size = 4

running_total = 0
nodes = None

# `labels` is the first label tensor of each batch (the true label); on ambiguous
# data get_data returns the same labels for both entries.
for batch_idx, (clean, labels, _) in tqdm(enumerate(get_data(train=True, ambiguous=True, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    # NOTE(review): the second argument (None) is presumably the patch input and
    # 'ig' presumably selects integrated gradients -- confirm against attribution.patching_effect.
    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    with t.no_grad():
        # Sum over dim 1, average over dim 0, then weight by the batch size so the
        # final division by `running_total` yields a per-example mean.
        # NOTE(review): presumably dim 0 is batch and dim 1 is token position -- confirm.
        if nodes is None:
            nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)
    # Drop references to this batch's results before the next iteration.
    del effects, _
    gc.collect()

nodes = {k : v / running_total for k, v in nodes.items()}

100%|██████████| 25/25 [00:56<00:00,  2.24s/it]


In [8]:
# List every feature whose averaged effect exceeds 0.1, grouped by component.
n_features = 0
# NOTE(review): components are numbered by the iteration order of `nodes`; this is
# assumed to match the order of `submodules` -- confirm.
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    for idx in (effect > 0.1).nonzero():
        print(idx.item(), effect[idx].item())
        n_features += 1
print(f"total features: {n_features}")

Component 0:
946 0.21847236156463623
5719 0.1465141624212265
7392 0.30848637223243713
10784 0.1519770622253418
17846 0.314985066652298
22068 0.1732397824525833
23079 0.14993813633918762
25904 0.10273231565952301
28533 0.18635113537311554
29476 0.19284993410110474
31461 0.17133042216300964
31467 0.1645527184009552
32081 0.3282443881034851
32469 1.367475986480713
Component 1:
23752 0.10250095278024673
Component 2:
2995 0.10375188291072845
3842 0.14587925374507904
10258 0.2980944812297821
13387 0.14533501863479614
13968 0.13012699782848358
18382 0.25857868790626526
19369 0.18626706302165985
28127 1.0542335510253906
30518 0.1976676732301712
Component 3:
1022 0.3151320517063141
9651 0.5876054763793945
10060 2.4383997917175293
18967 0.6847831606864929
22084 0.2555842399597168
23898 0.48170098662376404
24799 0.1013026013970375
26504 0.3015204668045044
29626 0.28801754117012024
31201 0.17514200508594513
Component 4:
8147 0.10604129731655121
Component 5:
24159 0.24751698970794678
25018 0.494141

In [7]:
# Inspect a single dictionary feature: `component_idx` indexes `submodules`,
# `feat_idx` is the feature index inside that component's dictionary.
component_idx = 9
feat_idx = 31098

submodule = submodules[component_idx]
dictionary = dictionaries[submodule]

# interpret some features
data = hf_dataset_to_generator("monology/pile-uncopyrighted")
# NOTE(review): out_feats=512 duplicates `activation_dim` -- keep the two in sync.
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    out_feats=512,
    in_batch_size=128,
    n_ctxs=512,
)

out = examine_dimension(
    model,
    submodule,
    buffer,
    dictionary,
    dim_idx=feat_idx,
    n_inputs=256
)
print(out.top_tokens)
print(out.top_affected)
# Bare last expression so the notebook renders the top contexts.
out.top_contexts

[(' nursing', 5.101802825927734), (' Nursing', 3.935664653778076), (' nurse', 2.8416316509246826), (' nurses', 2.8025169372558594), (' RN', 1.5293747186660767), (' Teaching', 0.7016168832778931), (' rehabilitation', 0.6920667886734009), ('wife', 0.6749844551086426), ('unte', 0.6074659824371338), (' sewing', 0.6034234762191772), (' caring', 0.4292556643486023), ('ancy', 0.38825851678848267), ('akers', 0.38148021697998047), (' lending', 0.34468209743499756), (' volunteers', 0.3259948790073395), (' drinking', 0.2667080760002136), (' Clinical', 0.2624419331550598), (' inpatient', 0.23919713497161865), (' architect', 0.22629010677337646), (' Medical', 0.2175636887550354), (' relational', 0.19916152954101562), (' Dou', 0.19640293717384338), (' dialysis', 0.176143079996109), (' Leadership', 0.17218172550201416), ('hin', 0.14378610253334045), ('okin', 0.14130336046218872), (' executive', 0.11437082290649414), (' teaching', 0.114189013838768), (' poet', 0.10126757621765137), (' hospitals', 0.09

In [6]:
# Hand-curated features to ablate, keyed by submodule (position in `submodules`).
# Each entry's trailing comment is the annotator's description of the feature;
# entries that are commented out are deliberately kept (not ablated).
feats_to_ablate = {
    submodules[0] : [
        946, # 'his'
        # 5719, # 'research'
        7392, # 'He'
        # 10784, # 'Nursing'
        17846, # 'He'
        22068, # 'His'
        # 23079, # 'tastes'
        # 25904, # 'nursing'
        28533, # 'She'
        29476, # 'he'
        31461, # 'His'
        31467, # 'she'
        32081, # 'her'
        32469, # 'She'
    ],
    submodules[1] : [
        # 23752, # capitalized words, especially pronouns
    ],
    submodules[2] : [
        2995, # 'he'
        3842, # 'She'
        10258, # female names
        13387, # 'she'
        13968, # 'He'
        18382, # 'her'
        19369, # 'His'
        28127, # 'She'
        30518, # 'He'
    ],
    submodules[3] : [
        1022, # 'she'
        9651, # female names
        10060, # 'She'
        18967, # 'He'
        22084, # 'he'
        23898, # 'His'
        # 24799, # promotes surnames
        26504, # 'her'
        29626, # 'his'
        # 31201, # 'nursing'
    ],
    submodules[4] : [
        # 8147, # unclear, something with names
    ],
    submodules[5] : [
        24159, # 'She', 'she'
        25018, # female names
    ],
    submodules[6] : [
        4592, # 'her'
        8920, # 'he'
        9877, # female names
        12128, # 'his'
        15017, # 'she'
        # 17369, # contact info
        # 26969, # related to nursing
        30248, # female names
    ],
    submodules[7] : [
        13570, # promotes male-related words
        27472, # female names, promotes female-related words
    ],
    submodules[8] : [
    ],
    submodules[9] : [
        1995, # promotes female-associated words
        9128, # feminine pronouns
        11656, # promotes male-associated words
        12440, # promotes female-associated words
        # 14638, # related to contact information?
        29206, # gendered pronouns
        29295, # female names
        # 31098, # nursing-related words
    ],
    submodules[10] : [
        2959, # promotes female-associated words
        19128, # promotes male-associated words
        22029, # promotes female-associated words
    ],
    submodules[11] : [
    ],
    submodules[12] : [
        19558, # promotes female-associated words
        23545, # 'she'
        24806, # 'her'
        27334, # promotes male-associated words
        31453, # female names
    ],
    submodules[13] : [
        31101, # promotes female-associated words
    ],
    submodules[14] : [
    ],
    submodules[15] : [
        9766, # promotes female-associated words
        12420, # promotes female pronouns
        30220, # promotes male pronouns
    ]
}

# Count the selected features across all submodules.
print(f"Number of features to ablate: {sum(len(v) for v in feats_to_ablate.values())}")

Number of features to ablate: 55


In [7]:
def n_hot(feats, dim=dictionary_size):
    """Return a boolean mask of length `dim` that is True at the indices in `feats`.

    Args:
        feats: iterable of integer indices (may be empty).
        dim: length of the mask.

    Returns:
        A bool tensor of shape `(dim,)` on `DEVICE`.
    """
    out = t.zeros(dim, dtype=t.bool, device=DEVICE)
    # One indexed write instead of one device write per feature. The explicit
    # long dtype keeps an empty `feats` a valid (empty) index.
    idxs = t.as_tensor(list(feats), dtype=t.long, device=DEVICE)
    out[idxs] = True
    return out

# Convert each index list into a boolean mask. The name is rebound in place, so
# the isinstance guard keeps the cell idempotent: on a re-run the values are
# already masks and must not be passed through n_hot again.
feats_to_ablate = {
    submodule : feats if isinstance(feats, t.Tensor) else n_hot(feats)
    for submodule, feats in feats_to_ablate.items()
}


In [8]:
# Record, for each submodule, whether its output is a tuple, so the ablation
# functions know whether to index `[0]` before editing activations.
is_tuple = {}
with t.no_grad(), model.trace("_"):
    for submodule in submodules:
        # Exact type check on purpose: torch.Size is a subclass of tuple, so
        # isinstance(..., tuple) would also be True for a plain tensor's shape.
        # NOTE(review): presumably `.shape` is a plain tuple of shapes when the
        # module returns a tuple -- confirm against the nnsight version in use.
        is_tuple[submodule] = type(submodule.output.shape) == tuple

def get_acts_ablated(
    text,
    model,
    submodules,
    dictionaries,
    to_ablate
):
    """Mean-pooled block-`layer` activations with selected dictionary features zeroed.

    For every submodule, its output is encoded with the matching dictionary,
    the features selected by `to_ablate[submodule]` are set to zero, and the
    output is replaced by the decoded features plus the dictionary's original
    reconstruction error.
    """
    with t.no_grad(), model.trace(text, **tracer_kwargs):
        for submodule in submodules:
            sae = dictionaries[submodule]
            ablate_idxs = to_ablate[submodule]
            tuple_out = is_tuple[submodule]

            hidden = submodule.output[0] if tuple_out else submodule.output
            recon, feats = sae(hidden, output_features=True)
            # Keep the part the dictionary fails to reconstruct, so that only
            # the selected features change.
            residual = hidden - recon
            feats[..., ablate_idxs] = 0.
            if tuple_out:
                submodule.output[0][:] = sae.decode(feats) + residual
            else:
                submodule.output = sae.decode(feats) + residual

        mask = model.input[1]['attention_mask']
        pooled = model.gpt_neox.layers[layer].output[0] * mask[:, :, None]
        pooled = (pooled.sum(1) / mask.sum(1)[:, None]).save()
    return pooled.value


# Accuracy after ablating features judged irrelevant by human annotators

In [30]:
# Activation getter that ablates the hand-curated `feats_to_ablate` features.
# NOTE: `get_acts_abl` is rebound by later cells to ablate other feature sets.
get_acts_abl = lambda text : get_acts_ablated(text, model, submodules, dictionaries, feats_to_ablate)

# The original `probe` is evaluated unchanged; only its input activations are ablated.
print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=1))

Ambiguous test accuracy: 0.926347553730011
Ground truth accuracy: 0.8853686451911926
Spurious accuracy: 0.5397465229034424


In [23]:
# Per-subgroup true-label accuracy of the original probe after ablating the
# hand-curated features. The getter calls get_acts_ablated directly:
# `get_acts_abl` is a one-argument lambda and cannot take these five arguments.
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, lambda text : get_acts_ablated(text, model, submodules, dictionaries, feats_to_ablate), batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9879666566848755
Accuracy for (0, 1): 0.9644010066986084
Accuracy for (1, 0): 0.559907853603363
Accuracy for (1, 1): 0.7453531622886658


# Concept bottleneck probing baseline

In [53]:
# Concept words for the concept-bottleneck baseline: ten nursing/medical terms
# followed by ten professor/academic terms.
concepts = [    
    ' nurse',
    ' healthcare',
    ' hospital',
    ' patient',
    ' medical',
    ' clinic',
    ' triage',
    ' medication',
    ' emergency',
    ' surgery',
    ' professor',
    ' academia',
    ' research',
    ' university',
    ' tenure',
    ' faculty',
    ' dissertation',
    ' sabbatical',
    ' publication',
    ' grant',
]
# get concept vectors
# One vector per concept: the block-`layer` output at the last sequence position.
# NOTE(review): assumes the last position is a real token for every concept
# (e.g. left padding) -- confirm the tokenizer's padding side.
with t.no_grad(), model.trace(concepts):
    concept_vectors = model.gpt_neox.layers[layer].output[0][:, -1, :].save()
# Centre the vectors by subtracting their mean across concepts.
concept_vectors = concept_vectors.value - concept_vectors.value.mean(0, keepdim=True)

def get_bottleneck(text):
    """Cosine similarity of each text's mean-pooled activation with every concept vector."""
    with t.no_grad(), model.trace(text, **tracer_kwargs):
        mask = model.input[1]['attention_mask']
        hidden = model.gpt_neox.layers[layer].output[0]
        pooled = (hidden * mask[:, :, None]).sum(1) / mask.sum(1)[:, None]
        # Dot products divided by the outer product of the two sets of norms.
        dots = pooled @ concept_vectors.T
        norm_outer = pooled.norm(dim=-1)[:, None] @ concept_vectors.norm(dim=-1)[None]
        sims = (dots / norm_outer).save()
    return sims.value

In [54]:
# Concept-bottleneck probe: trained on the concept similarities instead of raw
# activations, so its input width is the number of concepts. It uses
# train_probe's default training batches.
cbp_probe, _ = train_probe(get_bottleneck, label_idx=0, dim=len(concepts))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=0))
print('Unintended feature accuracy:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=1))


Ground truth accuracy: 0.8335253596305847
Unintended feature accuracy: 0.600806474685669


In [55]:
# Per-subgroup true-label accuracy of the concept-bottleneck probe on the test split.
subgroups = get_subgroups(train=False, ambiguous=False)
# Keys are (true_label, spurious_label) profiles. NOTE: the loop rebinds the global `batches`.
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9276148676872253
Accuracy for (0, 1): 0.7864062786102295
Accuracy for (1, 0): 0.6774193644523621
Accuracy for (1, 1): 0.9409851431846619


# Get skyline neuron performance

In [56]:
# get neurons which are most influential for giving gender label
# An IdentityDict per submodule replaces the learned dictionaries, so effects are
# attributed to the `activation_dim` raw dimensions (neurons) of each submodule.
neuron_dicts = {
    submodule : IdentityDict(activation_dim).to(DEVICE) for submodule in submodules
}

n_batches = 25
batch_size = 4

running_total = 0
running_nodes = None

# `labels` is the third element of each batch: the spurious (gender) label, taken
# from the four-group (ambiguous=False) training data.
for batch_idx, (clean, _, labels) in tqdm(enumerate(get_data(train=True, ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        neuron_dicts,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    with t.no_grad():
        # Weight each batch's averaged effect by its size so the final division by
        # `running_total` yields a per-example mean.
        if running_nodes is None:
            running_nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                running_nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)
    # NOTE(review): unlike the two dictionary-feature attribution loops, this loop
    # does not `del effects` / gc.collect() between batches.

nodes = {k : v / running_total for k, v in running_nodes.items()}

100%|██████████| 25/25 [00:29<00:00,  1.19s/it]


In [78]:
# Select, per component, the neurons whose averaged effect exceeds the threshold.
# NOTE(review): 0.2135 is a hand-tuned cut-off; its rationale is not recorded here.
neurons_to_ablate = {}
total_neurons = 0
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    neurons_to_ablate[submodules[component_idx]] = []
    # NOTE(review): this reads `effect.act` for the comparison but indexes `effect`
    # directly below, while the dictionary-feature cells compare `effect` itself.
    # Confirm which form matches the installed attribution/activation_utils version.
    for idx in (effect.act > 0.2135).nonzero():
        print(idx.item(), effect[idx].item())
        neurons_to_ablate[submodules[component_idx]].append(idx.item())
        total_neurons += 1
print(f"total neurons: {total_neurons}")

# Turn each index list into a boolean mask over the `activation_dim` neurons.
# The list is passed to n_hot directly rather than wrapped in another list.
neurons_to_ablate = {
    submodule : n_hot(neuron_idxs, dim=activation_dim) for submodule, neuron_idxs in neurons_to_ablate.items()
}

Component 0:
2 0.2645128667354584
42 0.23859986662864685
57 0.22970567643642426
99 0.32683151960372925
130 0.2573539912700653
135 0.5450218319892883
187 0.39180734753608704
197 0.33717024326324463
256 0.2539202570915222
335 0.29710012674331665
343 0.2440398782491684
400 0.2511880695819855
417 0.23313088715076447
421 0.24550659954547882
Component 1:
23 0.21983039379119873
111 0.2551604211330414
156 0.3130127191543579
Component 2:
23 0.3266536593437195
156 0.4058776795864105
Component 3:
23 0.8572849631309509
66 0.23336535692214966
111 0.3592255413532257
156 1.0655235052108765
162 0.25350555777549744
193 0.24525998532772064
209 0.221907839179039
271 0.40670523047447205
334 0.23112164437770844
378 0.2777670621871948
386 0.21385173499584198
394 0.2510640621185303
410 0.23709604144096375
473 0.2567783296108246
503 0.30946648120880127
Component 4:
Component 5:
Component 6:
14 0.322757363319397
56 0.4079655706882477
89 0.22097961604595184
98 0.44845256209373474
111 1.0405858755111694
156 0.29

In [85]:
def get_act_abl(text):
    """Mean-pooled block-`layer` activations with the `neurons_to_ablate` neurons zeroed.

    The selected dimensions of every submodule's output are set to zero before
    the final pooled activation is read.
    """
    with t.no_grad(), model.trace(text, **tracer_kwargs):
        for submodule in submodules:
            tuple_out = is_tuple[submodule]
            hidden = submodule.output[0] if tuple_out else submodule.output
            hidden[..., neurons_to_ablate[submodule]] = 0.
            if tuple_out:
                submodule.output[0][:] = hidden
            else:
                submodule.output = hidden

        mask = model.input[1]['attention_mask']
        pooled = model.gpt_neox.layers[layer].output[0] * mask[:, :, None]
        pooled = (pooled.sum(1) / mask.sum(1)[:, None]).save()
    return pooled.value

In [86]:
# Evaluate the original probe on neuron-ablated activations. This must use
# `get_act_abl` (neuron ablation), not `get_acts_abl`, which at this point is
# still the dictionary-feature ablation lambda from an earlier cell.
print('Ambiguous test accuracy:', test_probe(probe, get_act_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_act_abl, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, get_act_abl, batches=batches, label_idx=1))

Ambiguous test accuracy: 0.7592936754226685
Ground truth accuracy: 0.6296082735061646
Spurious accuracy: 0.6238479018211365


In [82]:
# Per-subgroup true-label accuracy of the original probe on neuron-ablated
# activations. Uses `get_act_abl` (neuron ablation), not the feature-ablation
# lambda `get_acts_abl`.
subgroups = get_subgroups(train=False, ambiguous=False)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_act_abl, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.5157667398452759
Accuracy for (0, 1): 0.0402553491294384
Accuracy for (1, 0): 0.9838709831237793
Accuracy for (1, 1): 0.999535322189331


# Get skyline feature performance

In [9]:
# Skyline: attribute the probe metric to dictionary features using the spurious
# label on the four-group (ambiguous=False) training data.
n_batches = 25
batch_size = 4

running_total = 0
running_nodes = None

# `labels` is the third element of each batch: the spurious (gender) label.
for batch_idx, (clean, _, labels) in tqdm(enumerate(get_data(train=True, ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    with t.no_grad():
        # Weight each batch's averaged effect by its size so the final division by
        # `running_total` yields a per-example mean.
        if running_nodes is None:
            running_nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                running_nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)
    # Drop references to this batch's results before the next iteration.
    del effects, _
    gc.collect()

# NOTE: rebinds the global `nodes` that earlier attribution cells also write.
nodes = {k : v / running_total for k, v in running_nodes.items()}

100%|██████████| 25/25 [00:50<00:00,  2.04s/it]


In [24]:
# Select, per component, the dictionary features whose averaged effect exceeds the threshold.
# NOTE(review): 0.1107 is a hand-tuned cut-off; its rationale is not recorded here.
top_feats_to_ablate = {}
total_features = 0
# NOTE(review): `submodules[component_idx]` assumes `nodes` iterates in the same
# order as `submodules` -- confirm.
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    top_feats_to_ablate[submodules[component_idx]] = []
    for idx in (effect > 0.1107).nonzero():
        print(idx.item(), effect[idx].item())
        top_feats_to_ablate[submodules[component_idx]].append(idx.item())
        total_features += 1
print(f"total features: {total_features}")

Component 0:
946 0.2463148683309555
7392 0.7006653547286987
17846 0.259246826171875
28533 0.3435993492603302
29088 0.1107618659734726
29476 0.14512550830841064
31467 0.34349972009658813
32081 0.17352652549743652
32469 1.1966392993927002
Component 1:
Component 2:
3842 0.18715900182724
10258 0.20733851194381714
13387 0.2763107717037201
18382 0.1151869148015976
19369 0.11484511196613312
28127 0.9618880748748779
30518 0.20647713541984558
31645 0.24555140733718872
Component 3:
1022 0.49627578258514404
3122 0.2066667079925537
9651 0.44658446311950684
10060 2.178098678588867
18967 0.975749135017395
22084 0.16960865259170532
23898 0.27509018778800964
26504 0.135464608669281
29626 0.35908496379852295
Component 4:
Component 5:
24159 0.30867698788642883
25018 0.32441073656082153
Component 6:
4592 0.2154429703950882
8920 0.5691888332366943
9877 0.2772234380245209
12128 0.5521409511566162
12436 0.2059258371591568
15017 3.2283241748809814
26204 0.12540720403194427
30248 0.6767085790634155
Component 

In [26]:
# Convert each index list into a boolean mask. The name is rebound in place, so
# the isinstance guard keeps the cell idempotent: on a re-run the values are
# already masks and must not be passed through n_hot again.
top_feats_to_ablate = {
    submodule : feats if isinstance(feats, t.Tensor) else n_hot(feats)
    for submodule, feats in top_feats_to_ablate.items()
}
# Rebind the shared getter so the following cells ablate the skyline feature set.
get_acts_abl = lambda text : get_acts_ablated(text, model, submodules, dictionaries, top_feats_to_ablate)

In [28]:
# Evaluate the original probe with whatever feature set `get_acts_abl` currently
# ablates (the skyline set when the preceding cell has just been run).
print('Ambiguous test accuracy:', test_probe(probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))
print('Spurious accuracy:', test_probe(probe, get_acts_abl, batches=batches, label_idx=1))

Ambiguous test accuracy: 0.9309943914413452
Ground truth accuracy: 0.8853686451911926
Spurious accuracy: 0.5432027578353882


In [30]:
# Per-subgroup true-label accuracy of the original probe under the current `get_acts_abl`.
subgroups = get_subgroups(train=False, ambiguous=False)
# Keys are (true_label, spurious_label) profiles. NOTE: the loop rebinds the global `batches`.
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(probe, get_acts_abl, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9738969206809998
Accuracy for (0, 1): 0.9248967170715332
Accuracy for (1, 0): 0.7603686451911926
Accuracy for (1, 1): 0.8887081742286682


# Retraining probe on activations after ablating features

In [9]:
# Point the shared getter back at the hand-curated `feats_to_ablate` set.
get_acts_abl = lambda text : get_acts_ablated(text, model, submodules, dictionaries, feats_to_ablate)

# Retrain a fresh probe on the ablated activations (train_probe's default
# training batches), then evaluate it on the same ablated activations.
new_probe, _ = train_probe(get_acts_abl, label_idx=0)
print('Ambiguous test accuracy:', test_probe(new_probe, get_acts_abl, label_idx=0))
batches = get_data(train=False, ambiguous=False)
print('Ground truth accuracy:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=0))
print('Unintended feature accuracy:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=1))

Ambiguous test accuracy: 0.9572490453720093
Ground truth accuracy: 0.9308755993843079
Unintended feature accuracy: 0.5195852518081665


In [10]:
# Per-subgroup true-label accuracy of the retrained probe on ablated activations.
subgroups = get_subgroups(train=False, ambiguous=False)
# Keys are (true_label, spurious_label) profiles. NOTE: the loop rebinds the global `batches`.
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(new_probe, get_acts_abl, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.9477938413619995
Accuracy for (0, 1): 0.8895981907844543
Accuracy for (1, 0): 0.9423962831497192
Accuracy for (1, 1): 0.9679368138313293
