In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from datasets import load_dataset
import random
from nnsight import LanguageModel
import torch as t
from torch import nn
from attribution import patching_effect, Submodule
from dictionary_learning import AutoEncoder, ActivationBuffer, JumpReluAutoEncoder
from dictionary_learning.dictionary import IdentityDict
from dictionary_learning.interp import examine_dimension
from dictionary_loading_utils import load_saes_and_submodules
from dictionary_learning.utils import hf_dataset_to_generator
from transformers import BitsAndBytesConfig
from huggingface_hub import hf_hub_download, list_repo_files
from tqdm import tqdm
from typing import Literal
import gc
from collections import defaultdict
import hashlib


DEBUGGING = False

if DEBUGGING:
    tracer_kwargs = dict(scan=True, validate=True)
else:
    tracer_kwargs = dict(scan=False, validate=False)

# model hyperparameters
DTYPE = t.bfloat16
DEVICE = 'cuda:0'
# model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE, dispatch=True)
model = LanguageModel('google/gemma-2-2b', device_map=DEVICE, dispatch=True,
                      attn_implementation="eager", torch_dtype=DTYPE)
activation_dim = 2304

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
# dataset hyperparameters
dataset = load_dataset("LabHC/bias_in_bios")
profession_dict = {'professor' : 21, 'nurse' : 13}
male_prof = 'professor'
female_prof = 'nurse'

# data preparation hyperparameters
SEED = 42

def get_text_batches(
    split: Literal["train", "test"] = "train",
    ambiguous=True, 
    batch_size=32, 
    seed=SEED
):
    data = dataset[split]
    if ambiguous:
        neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        n = min([len(neg), len(pos)])
        neg, pos = neg[:n], pos[:n]
        data = neg + pos
        labels = [0]*n + [1]*n
        idxs = list(range(2*n))
        random.Random(seed).shuffle(idxs)
        data, labels = [data[i] for i in idxs], [labels[i] for i in idxs]
        true_labels = spurious_labels = labels
    else:
        neg_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
        neg_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 1]
        pos_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 0]
        pos_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
        n = min([len(neg_neg), len(neg_pos), len(pos_neg), len(pos_pos)])
        neg_neg, neg_pos, pos_neg, pos_pos = neg_neg[:n], neg_pos[:n], pos_neg[:n], pos_pos[:n]
        data = neg_neg + neg_pos + pos_neg + pos_pos
        true_labels     = [0]*n + [0]*n + [1]*n + [1]*n
        spurious_labels = [0]*n + [1]*n + [0]*n + [1]*n
        idxs = list(range(4*n))
        random.Random(seed).shuffle(idxs)
        data, true_labels, spurious_labels = [data[i] for i in idxs], [true_labels[i] for i in idxs], [spurious_labels[i] for i in idxs]

    batches = [
        (data[i:i+batch_size], t.tensor(true_labels[i:i+batch_size], device=DEVICE), t.tensor(spurious_labels[i:i+batch_size], device=DEVICE)) for i in range(0, len(data), batch_size)
    ]

    return batches

def get_subgroups(
        split: Literal["train", "test"] = "test",
        batch_size=32,
):
    data = dataset[split]
    neg_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 0]
    neg_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[male_prof] and x['gender'] == 1]
    pos_neg = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 0]
    pos_pos = [x['hard_text'] for x in data if x['profession'] == profession_dict[female_prof] and x['gender'] == 1]
    neg_neg_labels, neg_pos_labels, pos_neg_labels, pos_pos_labels = (0, 0), (0, 1), (1, 0), (1, 1)
    subgroups = [(neg_neg, neg_neg_labels), (neg_pos, neg_pos_labels), (pos_neg, pos_neg_labels), (pos_pos, pos_pos_labels)]
    
    out = {}
    for data, label_profile in subgroups:
        out[label_profile] = []
        for i in range(0, len(data), batch_size):
            text = data[i:i+batch_size]
            out[label_profile].append(
                (
                    text,
                    t.tensor([label_profile[0]]*len(text), device=DEVICE),
                    t.tensor([label_profile[1]]*len(text), device=DEVICE)
                )
            )
    return out

In [3]:
def pool_acts(acts, attn_mask):
    return (acts * attn_mask[:, :, None]).sum(1) / attn_mask.sum(1)[:, None]

@t.no_grad()
def collect_activations(
    model,
    layer,
    text_batches,
):
    with tqdm(total=len(text_batches), desc="Collecting activations") as pbar:
        for text_batch, *labels in text_batches:
            with model.trace(text_batch, **tracer_kwargs):
                attn_mask = model.input[1]['attention_mask']
                acts = model.model.layers[layer].output[0]
                pooled_acts = pool_acts(acts, attn_mask).save()
            yield pooled_acts.value, *labels
            pbar.update(1)

In [4]:
# probe training hyperparameters

layer = 22 # model layer for attaching linear classification head

class Probe(nn.Module):
    def __init__(self, activation_dim, dtype=DTYPE):
        super().__init__()
        self.net = nn.Linear(activation_dim, 1, bias=True, dtype=dtype)

    def forward(self, x):
        logits = self.net(x).squeeze(-1)
        return logits

def train_probe(
    activation_batches,
    label_idx=0,
    lr=1e-2,
    epochs=1,
    seed=SEED,
):
    t.manual_seed(seed)

    probe = Probe(activation_dim).to(DEVICE)
    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    losses = []

    for epoch in range(epochs):
        for act, *labels, in activation_batches:
            optimizer.zero_grad()
            logits = probe(act)
            loss = criterion(logits, labels[label_idx].to(logits))
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    return probe, losses

@t.no_grad()
def test_probe(
    probe,
    activation_batches,
):
    corrects = defaultdict(list)
    for acts, *labels in activation_batches:
        logits = probe(acts)
        preds = (logits > 0.0).long()
        for idx, label in enumerate(labels):
            corrects[idx].append(preds == label)

    accs = {idx: t.cat(corrects[idx]).float().mean().item() for idx in corrects}

    return accs

In [None]:
oracle, _ = train_probe(
    activation_batches=collect_activations(model, layer, get_text_batches(split="train", ambiguous=False))
)

ambiguous_accs = test_probe(oracle, activation_batches=collect_activations(model, layer, get_text_batches(split="test", ambiguous=True)))
print(f"ambiguous test accuracy: {ambiguous_accs[0]}")

unambiguous_accs = test_probe(oracle, activation_batches=collect_activations(model, layer, get_text_batches(split="test", ambiguous=False)))
print(f"ground truth accuracy: {unambiguous_accs[0]}")
print(f"unintended feature accuracy: {unambiguous_accs[1]}")

for subgroup, batches in get_subgroups().items():
    subgroup_accs = test_probe(oracle, activation_batches=collect_activations(model, layer, batches))
    print(f"subgroup {subgroup} accuracy: {subgroup_accs[0]}")

In [5]:
save_path = f"probe_layer_{layer}_{str(DTYPE).split('.')[-1]}.pt"
if os.path.exists(save_path):
    print(f"loading probe from {save_path}")
    probe = t.load(save_path)
else:
    probe, _ = train_probe(
        activation_batches=collect_activations(model, layer, get_text_batches(split="train", ambiguous=True))
    )
    t.save(probe, save_path)
    print(f"probe saved to {save_path}")

# ambiguous_accs = test_probe(probe, activation_batches=collect_activations(model, layer, get_text_batches(split="test", ambiguous=True)))
# print(f"ambiguous test accuracy: {ambiguous_accs[0]}")

# unambiguous_accs = test_probe(probe, activation_batches=collect_activations(model, layer, get_text_batches(split="test", ambiguous=False)))
# print(f"ground truth accuracy: {unambiguous_accs[0]}")
# print(f"unintended feature accuracy: {unambiguous_accs[1]}")

# for subgroup, batches in get_subgroups().items():
#     subgroup_accs = test_probe(probe, activation_batches=collect_activations(model, layer, batches))
#     print(f"subgroup {subgroup} accuracy: {subgroup_accs[0]}")

loading probe from probe_layer_22_bfloat16.pt


In [6]:
# loading dictionaries
submodules, dictionaries = load_saes_and_submodules(
    model, 
    "google/gemma-2-2b", 
    thru_layer=layer,
    include_embed=False,
    dtype=DTYPE,
    device=DEVICE,
)

def metric_fn(model, labels=None):
    attn_mask = model.input[1]['attention_mask']
    acts = model.model.layers[layer].output[0]
    acts = pool_acts(acts, attn_mask)
    
    return t.where(
        labels == 0,
        probe(acts),
        - probe(acts)
    )

Loading Gemma SAEs:   0%|          | 0/23 [00:00<?, ?it/s]

Loading Gemma SAEs: 100%|██████████| 23/23 [01:17<00:00,  3.38s/it]


In [10]:
# find most influential features
# n_batches = 200
n_batches = 100
batch_size = 1

nodes = None

for batch_idx, (clean, labels, _) in tqdm(enumerate(get_text_batches(split="train", ambiguous=True, batch_size=batch_size)), total=n_batches):
    if batch_idx == n_batches:
        break

    # if effects are already cached, skip
    hash_input = clean + [s.name for s in submodules]
    hash_str = ''.join(hash_input)
    hash_digest = hashlib.md5(hash_str.encode()).hexdigest()
    if os.path.exists(f"effects/{hash_digest}.pt"):
        continue

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    to_save = {
        k.name : v.detach().to("cpu") for k, v in effects.items()
    }
    t.save(to_save, f"effects/{hash_digest}.pt")

    del effects, _
    gc.collect()


  0%|          | 0/100 [00:00<?, ?it/s]You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 100/100 [19:51<00:00, 11.91s/it]


In [11]:
# aggregate effects
aggregated_effects = {submodule.name : 0 for submodule in submodules}

for idx, (clean, *_) in enumerate(get_text_batches(split="train", ambiguous=True, batch_size=batch_size)):
    if idx == n_batches:
        break
    hash_input = clean + [s.name for s in submodules]
    hash_str = ''.join(hash_input)
    hash_digest = hashlib.md5(hash_str.encode()).hexdigest()
    effects = t.load(f"effects/{hash_digest}.pt")
    for submodule in submodules:
        aggregated_effects[submodule.name] += (
            effects[submodule.name].act[:,1:,:] # remove BOS features
        ).sum(dim=1).mean(dim=0)

aggregated_effects = {k : v / (batch_size * n_batches) for k, v in aggregated_effects.items()}

In [30]:
count = 0
for k, v in aggregated_effects.items():
    print(f"\"{k}\": [")
    for idx in (v > 6).nonzero():
        count += 1
        print(f"    {idx.item()} : {v[idx].item()},")
    print("],")
print(f"total features: {count}")

"attn_0": [
],
"mlp_0": [
],
"resid_0": [
    4449 : 7.4375,
],
"attn_1": [
],
"mlp_1": [
],
"resid_1": [
    4521 : 7.25,
    11782 : 6.375,
],
"attn_2": [
],
"mlp_2": [
],
"resid_2": [
    5853 : 6.71875,
],
"attn_3": [
],
"mlp_3": [
],
"resid_3": [
],
"attn_4": [
],
"mlp_4": [
],
"resid_4": [
],
"attn_5": [
],
"mlp_5": [
],
"resid_5": [
    2864 : 9.25,
    11682 : 8.1875,
],
"attn_6": [
],
"mlp_6": [
],
"resid_6": [
    4068 : 13.0625,
    7008 : 7.1875,
],
"attn_7": [
],
"mlp_7": [
],
"resid_7": [
    7111 : 8.0625,
],
"attn_8": [
],
"mlp_8": [
],
"resid_8": [
    6952 : 7.09375,
    9949 : 19.5,
],
"attn_9": [
],
"mlp_9": [
],
"resid_9": [
    6952 : 9.375,
    15246 : 14.875,
],
"attn_10": [
],
"mlp_10": [
],
"resid_10": [
    3711 : 14.5625,
    6952 : 9.125,
],
"attn_11": [
],
"mlp_11": [
],
"resid_11": [
    3013 : 27.375,
    4467 : 7.3125,
    11649 : 6.15625,
],
"attn_12": [
],
"mlp_12": [
],
"resid_12": [
    6335 : 9.25,
    11114 : 36.25,
    11480 : 6.28125,
],
"attn_1

In [43]:
# interpret features with Neuronpedia API

submodule_name = "attn_21"
feature_idx = 10118

from IPython.display import IFrame
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release="gemma-2-2b", sae_id=None, feature_idx=feature_idx):
    return html_template.format(sae_release, sae_id, feature_idx)

# Extract the type and number from submodule_name
submodule_type, submodule_number = submodule_name.split('_')

# Construct the sae_id based on the type
if submodule_type == "resid":
    sae_id = f"{submodule_number}-gemmascope-res-16k"
elif submodule_type == "attn":
    sae_id = f"{submodule_number}-gemmascope-att-16k"
elif submodule_type == "mlp":
    sae_id = f"{submodule_number}-gemmascope-mlp-16k"
else:
    raise ValueError("Unknown submodule type")

html = get_dashboard_html(sae_release="gemma-2-2b", sae_id=sae_id, feature_idx=feature_idx)
IFrame(html, width=1200, height=600)

In [44]:
feats_to_ablate = {
    "attn_0": [
    ],
    "mlp_0": [
    ],
    "resid_0": [
        4449, # "he", "him"
    ],
    "attn_1": [
    ],
    "mlp_1": [
    ],
    "resid_1": [
        4521, # his"
        11782, # # "woman", "she", "her", "female"
    ],
    "attn_2": [
    ],
    "mlp_2": [
    ],
    "resid_2": [
        5853, # "he"
    ],
    "attn_3": [
    ],
    "mlp_3": [
    ],
    "resid_3": [
    ],
    "attn_4": [
    ],
    "mlp_4": [
    ],
    "resid_4": [
    ],
    "attn_5": [
    ],
    "mlp_5": [
    ],
    "resid_5": [
        2864, # "her"
        11682, # "he"
    ],
    "attn_6": [
    ],
    "mlp_6": [
    ],
    "resid_6": [
        4068, # female pronouns
        7008, # gendered pronouns
    ],
    "attn_7": [
    ],
    "mlp_7": [
    ],
    "resid_7": [
        7111, # text about women
    ],
    "attn_8": [
    ],
    "mlp_8": [
    ],
    "resid_8": [
        6952, # gendered pronouns
        9949, # female pronouns
    ],
    "attn_9": [
    ],
    "mlp_9": [
    ],
    "resid_9": [
        6952, # "she"
        15246, # female pronouns
    ],
    "attn_10": [
    ],
    "mlp_10": [
    ],
    "resid_10": [
        3711, # promotes female-associated words
        6952, # masculine pronouns
    ],
    "attn_11": [
    ],
    "mlp_11": [
    ],
    "resid_11": [
        3013, # descriptions of women
        4467, # "He"
        11649, # "she", "her"
    ],
    "attn_12": [
    ],
    "mlp_12": [
    ],
    "resid_12": [
        6335, # "He"
        11114, # descriptions of women
        11480, # "his", "her"
    ],
    "attn_13": [
    ],
    "mlp_13": [
    ],
    "resid_13": [
        192, # descriptions of women
        495, # "her"
        14755, # "he"
    ],
    "attn_14": [
    ],
    "mlp_14": [
    ],
    "resid_14": [
        2354, # descriptions of women
        12665, # male pronouns
    ],
    "attn_15": [
    ],
    "mlp_15": [
    ],
    "resid_15": [
        798, # promotes feminine pronouns
        6211, # gendered pronouns
    ],
    "attn_16": [
    ],
    "mlp_16": [
    ],
    "resid_16": [
        6047, # women, girls, women's names
        15567, # "he", "she"
        16351, # promotes feminine pronouns
    ],
    "attn_17": [
    ],
    "mlp_17": [
    ],
    "resid_17": [
        6011, # promotes feminine pronouns
    ],
    "attn_18": [
    ],
    "mlp_18": [
    ],
    "resid_18": [
        61, # descriptions of female royalty
    ],
    "attn_19": [
    ],
    "mlp_19": [
    ],
    "resid_19": [
        13002, # clauses starting with "she"
    ],
    "attn_20": [
        9711, # "her", "she"
    ],
    "mlp_20": [
    ],
    "resid_20": [
        7116, # promotes masculine pronouns
        # 7324, # nursing
        12545, # promotes female pronouns
    ],
    "attn_21": [
        2740, # promotes feminine pronouns
        10118, # masculine pronouns
    ],
    "mlp_21": [
    ],
    "resid_21": [
        1065, # promotes masculine pronouns
        # 1653, # nursing
        4430, # promotes feminine pronouns
    ],
    "attn_22": [
    ],
    "mlp_22": [
    ],
    "resid_22": [
        1208, # promotes masculine pronouns
        3497, # promotes feminine pronouns
        # 7961, # nursing
    ],
}
print(f"number of features to ablate: {sum([len(v) for v in feats_to_ablate.values()])}")

number of features to ablate: 43


In [45]:
# putting feats_to_ablate in a more useful format
def n_hot(feats, dim):
    out = t.zeros(dim, dtype=t.bool, device=DEVICE)
    for feat in feats:
        out[feat] = True
    return out

feats_to_ablate = {
    submodule : n_hot(feats_to_ablate[submodule.name], dictionaries[submodule].dict_size) for submodule in submodules
}

In [46]:
@t.no_grad()
def collect_acts_ablated(
    text_batches,
    model,
    submodules,
    dictionaries,
    to_ablate,
    layer,
):
    with tqdm(total=len(text_batches), desc="Collecting activations with ablations") as pbar:
        for text, *labels in text_batches:
            with model.trace(text, **tracer_kwargs):
                for submodule in submodules:
                    dictionary = dictionaries[submodule]
                    feat_idxs = to_ablate[submodule]
                    if len(feat_idxs) == 0:
                        continue
                    x = submodule.get_activation()
                    x_hat, f = dictionary(x, output_features=True)
                    res = x - x_hat
                    f[:, :,feat_idxs] = 0. # zero ablation
                    submodule.set_activation(dictionary.decode(f) + res)
                attn_mask = model.input[1]['attention_mask']
                act = model.model.layers[layer].output[0]
                pooled_act = pool_acts(act, attn_mask).save()
            yield pooled_act.value, *labels
            pbar.update(1)


# Accuracy after ablating features judged irrelevant by human annotators

In [47]:
ambiguous_accs = test_probe(
    probe,
    activation_batches=collect_acts_ablated(
        get_text_batches(split="test", ambiguous=True), 
        model, submodules, dictionaries, feats_to_ablate, layer
    ),
)
print(f"Ambiguous test accuracy: {ambiguous_accs[0]}")

unambiguous_accs = test_probe(
    probe,
    activation_batches=collect_acts_ablated(
        get_text_batches(split="test", ambiguous=False),
        model, submodules, dictionaries, feats_to_ablate, layer
    ),
)
print(f"Ground truth accuracy: {unambiguous_accs[0]}")
print(f"Spurious accuracy: {unambiguous_accs[1]}")


Collecting activations with ablations: 100%|██████████| 269/269 [02:33<00:00,  1.75it/s]


Ambiguous test accuracy: 0.7659153938293457


Collecting activations with ablations: 100%|██████████| 55/55 [00:31<00:00,  1.77it/s]

Ground truth accuracy: 0.7603686451911926
Spurious accuracy: 0.514976978302002





# Concept bottleneck probing baseline

In [41]:
concepts = [    
    ' nurse',
    ' healthcare',
    ' hospital',
    ' patient',
    ' medical',
    ' clinic',
    ' triage',
    ' medication',
    ' emergency',
    ' surgery',
    ' professor',
    ' academia',
    ' research',
    ' university',
    ' tenure',
    ' faculty',
    ' dissertation',
    ' sabbatical',
    ' publication',
    ' grant',
]
# get concept vectors
with t.no_grad(), model.trace(concepts):
    concept_vectors = model.model.layers[layer].output[0][:, -1, :].save()
concept_vectors = concept_vectors.value - concept_vectors.value.mean(0, keepdim=True)

def get_bottleneck(text):
    with t.no_grad(), model.trace(text, **tracer_kwargs):
        attn_mask = model.input[1]['attention_mask']
        acts = model.model.layers[layer].output[0]
        acts = acts * attn_mask[:, :, None]
        acts = acts.sum(1) / attn_mask.sum(1)[:, None]
        # compute cosine similarity with concept vectors
        sims = (acts @ concept_vectors.T) / (acts.norm(dim=-1)[:, None] @ concept_vectors.norm(dim=-1)[None])
        sims = sims.save()
    return sims.value

In [42]:
cbp_probe, _ = train_probe(get_bottleneck, label_idx=0, dim=len(concepts), batches=get_data(batch_size=batch_size))
batches = get_data(train=False, ambiguous=False, batch_size=batch_size)
print('Ground truth accuracy:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=0))
print('Unintended feature accuracy:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=1))



epoch 0:  [0.6971435546875, 0.7413330078125, 0.70391845703125, 0.6898193359375, 0.67218017578125, 0.67254638671875, 0.699127197265625, 0.68316650390625, 0.68304443359375, 0.6861419677734375, 0.6903018951416016, 0.6856677532196045, 0.6833877563476562, 0.67919921875, 0.6834793090820312, 0.6787109375, 0.6787109375, 0.6835174560546875, 0.673492431640625, 0.6913928985595703, 0.666854977607727, 0.6630859375, 0.676666259765625, 0.679840087890625, 0.6655349731445312, 0.6761436462402344, 0.6749162673950195, 0.6784286499023438, 0.6638031005859375, 0.6699142456054688, 0.6470947265625, 0.6814117431640625, 0.6573638916015625, 0.665270209312439, 0.636077880859375, 0.6762542724609375, 0.6952438354492188, 0.6825408935546875, 0.691162109375, 0.661895751953125, 0.63525390625, 0.685699462890625, 0.6450138092041016, 0.660736083984375, 0.66497802734375, 0.63134765625, 0.6502838134765625, 0.65863037109375, 0.6770496368408203, 0.6722412109375, 0.6806640625, 0.638916015625, 0.65374755859375, 0.6380615234375,

In [43]:
layer

19

In [44]:
# get subgroup accuracies
subgroups = get_subgroups(train=False, ambiguous=False, batch_size=batch_size)
for label_profile, batches in subgroups.items():
    print(f'Accuracy for {label_profile}:', test_probe(cbp_probe, get_bottleneck, batches=batches, label_idx=0))

Accuracy for (0, 0): 0.8610305190086365
Accuracy for (0, 1): 0.8427337408065796
Accuracy for (1, 0): 0.9193548560142517
Accuracy for (1, 1): 0.928438663482666


# Get skyline neuron performance

In [15]:
# get neurons which are most influential for giving gender label
neuron_dicts = {
    submodule : IdentityDict(activation_dim).to(DEVICE) for submodule in submodules
}

n_batches = 100
batch_size = 1

running_total = 0
nodes = None

for batch_idx, (clean, _, labels) in tqdm(enumerate(get_text_batches(split="train", ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    effects, _, _, _ = patching_effect(
        clean,
        None,
        model,
        submodules,
        neuron_dicts,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    with t.no_grad():
        if nodes is None:
            nodes = {k : len(clean) * v.sum(dim=1).mean(dim=0) for k, v in effects.items()}
        else:
            for k, v in effects.items():
                nodes[k] += len(clean) * v.sum(dim=1).mean(dim=0)
        running_total += len(clean)
    del effects, _
    gc.collect()

nodes = {k : v / running_total for k, v in nodes.items()}

100%|██████████| 200/200 [48:11<00:00, 14.46s/it]


In [16]:
for i, node in enumerate(nodes):
    t.save(nodes[node], open(f"effects_l0s_nearest_100/node_neuronskyline_{i}.pt", "wb"))

In [32]:
def n_hot(feats, dim=2304):
    out = t.zeros(dim, dtype=t.bool, device=DEVICE)
    for feat in feats:
        out[feat] = True
    return out

neurons_to_ablate = {}
total_neurons = 0
for component_idx, effect in enumerate(nodes.values()):
    print(f"Component {component_idx}:")
    neurons_to_ablate[submodules[component_idx]] = []
    if effect.act.shape[-1] != 2304:
        continue
    for idx in (effect.act > 1.8).nonzero():
        print(idx.item(), effect[idx].item())
        neurons_to_ablate[submodules[component_idx]].append(idx.item())
        total_neurons += 1
print(f"total neurons: {total_neurons}")

neurons_to_ablate = {
    submodule : n_hot([neuron_idx], dim=2304) for submodule, neuron_idx in neurons_to_ablate.items()
}

Component 0:
331 1.90625
535 6.90625
540 2.015625
1068 2.15625
1142 4.125
1393 6.4375
Component 1:
Component 2:
Component 3:
243 2.40625
679 3.296875
881 4.3125
1170 2.1875
1570 3.375
2135 2.328125
Component 4:
Component 5:
Component 6:
1002 4.65625
1570 3.640625
2135 8.25
Component 7:
Component 8:
Component 9:
334 3.78125
535 16.375
1570 2.15625
2135 6.71875
Component 10:
Component 11:
Component 12:
629 4.3125
1570 4.46875
Component 13:
Component 14:
Component 15:
482 2.234375
1068 2.140625
1261 2.09375
1570 4.25
Component 16:
Component 17:
Component 18:
624 2.65625
682 2.375
1068 2.203125
1408 3.1875
1570 4.1875
Component 19:
Component 20:
Component 21:
292 1.8515625
624 5.875
682 4.53125
1068 4.8125
1570 3.78125
Component 22:
Component 23:
Component 24:
535 17.25
1261 1.8671875
1699 1.9921875
Component 25:
Component 26:
1340 1.859375
Component 27:
1699 2.375
Component 28:
Component 29:
Component 30:
682 4.375
1068 2.78125
1170 2.421875
Component 31:
Component 32:
Component 33:
334 6

In [None]:
ambiguous_text_batches = get_text_batches(split="test", ambiguous=True, batch_size=batch_size)
print(
    'Ambiguous test accuracy:', 
    test_probe(
        probe,
        activation_batches=collect_acts_ablated(ambiguous_text_batches, model, submodules, neuron_dicts, neurons_to_ablate, layer),
    )
)
unambiguous_text_batches = get_text_batches(split="test", ambiguous=False, batch_size=batch_size)
print(
    "Ground truth accuracy:",
    test_probe(
        probe,
        activation_batches=collect_acts_ablated(unambiguous_text_batches, model, submodules, neuron_dicts, neurons_to_ablate, layer),
        label_idx=0,
    ),
)
print(
    "Spurious accuracy:",
    test_probe(
        probe,
        activation_batches=collect_acts_ablated(unambiguous_text_batches, model, submodules, neuron_dicts, neurons_to_ablate, layer),
        label_idx=1,
    ),
)
subgroups = get_subgroups(split="test", ambiguous=False, batch_size=batch_size)
for label_profile, batches in subgroups.items():
    print(
        f"Accuracy for {label_profile}:",
        test_probe(
            probe,
            activation_batches=collect_acts_ablated(batches, model, submodules, neuron_dicts, neurons_to_ablate, layer),
            label_idx=0,
        ),
    )

# Get skyline feature performance

In [48]:
# get features which are most useful for predicting gender label
n_batches = 100
batch_size = 1

running_total = 0
running_nodes = None

for batch_idx, (clean, _, labels) in tqdm(enumerate(get_text_batches(split="train", ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    hash_input = clean + [s.name for s in submodules]
    hash_str = ''.join(hash_input) + "_feature_skyline"
    hash_digest = hashlib.md5(hash_str.encode()).hexdigest()
    if os.path.exists(f"effects/{hash_digest}.pt"):
        continue

    effects, *_ = patching_effect(
        clean,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    to_save = {
        k.name : v.detach().to("cpu") for k, v in effects.items()
    }
    t.save(to_save, f"effects/{hash_digest}.pt")

    del effects
    gc.collect()

100%|██████████| 100/100 [18:59<00:00, 11.39s/it]


In [49]:
# aggregate effects
aggregated_effects = {submodule.name : 0 for submodule in submodules}

for idx, (clean, *_) in enumerate(get_text_batches(split="train", ambiguous=False, batch_size=batch_size)):
    if idx == n_batches:
        break
    hash_input = clean + [s.name for s in submodules]
    hash_str = ''.join(hash_input) + "_feature_skyline"
    hash_digest = hashlib.md5(hash_str.encode()).hexdigest()
    effects = t.load(f"effects/{hash_digest}.pt")
    for submodule in submodules:
        aggregated_effects[submodule.name] += (
            effects[submodule.name].act[:,1:,:] # ignore BOS token
        ).sum(dim=1).mean(dim=0)

aggregated_effects = {k : v / (batch_size * n_batches) for k, v in aggregated_effects.items()}

In [53]:
top_feats_to_ablate = {}
total_features = 0
for submodule in submodules:
    print(f"Component {submodule.name}:")
    top_feats_to_ablate[submodule.name] = []
    for idx in (aggregated_effects[submodule.name] > 6.1).nonzero():
        print(idx.item(), aggregated_effects[submodule.name][idx].item())
        top_feats_to_ablate[submodule.name].append(idx.item())
        total_features += 1
print(f"total features: {total_features}")

Component attn_0:
Component mlp_0:
Component resid_0:
4449 8.0
Component attn_1:
Component mlp_1:
Component resid_1:
11782 6.96875
Component attn_2:
Component mlp_2:
Component resid_2:
5853 6.53125
12818 6.1875
Component attn_3:
Component mlp_3:
Component resid_3:
9949 7.0
Component attn_4:
Component mlp_4:
Component resid_4:
4675 6.53125
Component attn_5:
Component mlp_5:
Component resid_5:
2864 9.8125
11682 7.125
Component attn_6:
Component mlp_6:
Component resid_6:
4068 9.3125
7008 9.625
Component attn_7:
Component mlp_7:
Component resid_7:
7111 6.84375
Component attn_8:
Component mlp_8:
Component resid_8:
6952 7.875
9949 14.625
Component attn_9:
Component mlp_9:
Component resid_9:
6952 10.0
15246 11.25
Component attn_10:
Component mlp_10:
Component resid_10:
3711 11.1875
6952 10.25
Component attn_11:
Component mlp_11:
Component resid_11:
3013 19.75
4467 8.1875
11649 6.15625
Component attn_12:
Component mlp_12:
Component resid_12:
6335 10.1875
11114 25.75
Component attn_13:
Componen

In [54]:
def n_hot(feats, dim):
    out = t.zeros(dim, dtype=t.bool, device=DEVICE)
    for feat in feats:
        out[feat] = True
    return out

top_feats_to_ablate = {
    submodule : n_hot(top_feats_to_ablate[submodule.name], dictionaries[submodule].dict_size) for submodule in submodules
}

In [55]:
ambiguous_accs = test_probe(
    probe,
    activation_batches=collect_acts_ablated(
        get_text_batches(split="test", ambiguous=True), 
        model, submodules, dictionaries, top_feats_to_ablate, layer
    ),
)
print(f"Ambiguous test accuracy: {ambiguous_accs[0]}")

unambiguous_accs = test_probe(
    probe,
    activation_batches=collect_acts_ablated(
        get_text_batches(split="test", ambiguous=False),
        model, submodules, dictionaries, top_feats_to_ablate, layer
    ),
)
print(f"Ground truth accuracy: {unambiguous_accs[0]}")
print(f"Spurious accuracy: {unambiguous_accs[1]}")


Collecting activations with ablations: 100%|██████████| 269/269 [02:32<00:00,  1.76it/s]


Ambiguous test accuracy: 0.8300418257713318


Collecting activations with ablations: 100%|██████████| 55/55 [00:30<00:00,  1.79it/s]

Ground truth accuracy: 0.8087557554244995
Spurious accuracy: 0.5368663668632507





# Retraining probe on activations after ablating features

In [57]:
retrained_probe, _ = train_probe(
    activation_batches=collect_acts_ablated(
        get_text_batches(split="train", ambiguous=True), 
        model, submodules, dictionaries, feats_to_ablate, layer
    ),
)
ambiguous_test_accs = test_probe(
    retrained_probe,
    activation_batches=collect_acts_ablated(
        get_text_batches(split="test", ambiguous=True),
        model, submodules, dictionaries, feats_to_ablate, layer
    ),
)
print(f"Ambiguous test accuracy: {ambiguous_test_accs[0]}")
unambiguous_test_accs = test_probe(
    retrained_probe,
    activation_batches=collect_acts_ablated(
        get_text_batches(split="test", ambiguous=False),
        model, submodules, dictionaries, feats_to_ablate, layer
    ),
)
print(f"Ground truth accuracy: {unambiguous_test_accs[0]}")
print(f"Spurious accuracy: {unambiguous_test_accs[1]}")

Collecting activations with ablations: 100%|██████████| 700/700 [06:59<00:00,  1.67it/s]
Collecting activations with ablations: 100%|██████████| 269/269 [02:40<00:00,  1.68it/s]


Ambiguous test accuracy: 0.977811336517334


Collecting activations with ablations: 100%|██████████| 55/55 [00:31<00:00,  1.75it/s]

Ground truth accuracy: 0.9504608511924744
Spurious accuracy: 0.524193525314331



