In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from datasets import load_dataset
import random
from nnsight import LanguageModel
import torch as t
from torch import nn
from attribution import patching_effect, Submodule
from dictionary_learning import AutoEncoder, ActivationBuffer, JumpReluAutoEncoder
from dictionary_learning.dictionary import IdentityDict
from dictionary_learning.interp import examine_dimension
from dictionary_loading_utils import load_saes_and_submodules
from dictionary_learning.utils import hf_dataset_to_generator
from transformers import BitsAndBytesConfig
from huggingface_hub import hf_hub_download, list_repo_files
from tqdm import tqdm
from typing import Literal
import gc
from collections import defaultdict
import hashlib


# Debug mode switches on both of nnsight's `scan` and `validate` tracer options;
# otherwise both are off.
DEBUGGING = False
tracer_kwargs = {"scan": DEBUGGING, "validate": DEBUGGING}

# model hyperparameters
DTYPE = t.bfloat16
DEVICE = 'cuda:0'
# model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=DEVICE, dispatch=True)
model = LanguageModel(
    'google/gemma-2-2b',
    device_map=DEVICE,
    dispatch=True,
    attn_implementation="eager",
    torch_dtype=DTYPE,
)
# Default input width for probes (train_probe's `d_probe`) and for the IdentityDict baselines.
activation_dim = 2304

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
# dataset hyperparameters
dataset = load_dataset("LabHC/bias_in_bios")
# Profession name -> integer id, compared against each row's 'profession' field.
profession_dict = {'professor' : 21, 'nurse' : 13}
# In the ambiguous split, `male_prof` rows are kept only when gender == 0 and
# `female_prof` rows only when gender == 1 (see get_text_batches).
# NOTE(review): presumably gender 0 = male and 1 = female — confirm against the dataset card.
male_prof = 'professor'
female_prof = 'nurse'

# data preparation hyperparameters
SEED = 42

def get_text_batches(
    split: Literal["train", "test"] = "train",
    ambiguous=True,
    batch_size=32,
    seed=SEED
):
    """Build shuffled, class-balanced batches of biography texts with their labels.

    The true label is the profession (0 for `male_prof`, 1 for `female_prof`) and the
    spurious label is the row's `gender` field.

    Args:
        split: which split of the global `dataset` to read.
        ambiguous: if True, keep only the two groups where the true and spurious labels
            agree; if False, keep all four (profession, gender) combinations.
        batch_size: number of texts per batch (the last batch may be shorter).
        seed: seed for the deterministic shuffle.

    Returns:
        A list of `(texts, true_labels, spurious_labels)` tuples, where `texts` is a list
        of strings and both label entries are tensors on `DEVICE`.
    """
    rows = dataset[split]
    male_id = profession_dict[male_prof]
    female_id = profession_dict[female_prof]

    def texts_for(profession_id, gender):
        return [
            row['hard_text']
            for row in rows
            if row['profession'] == profession_id and row['gender'] == gender
        ]

    # Each entry is ((true_label, spurious_label), texts); the order fixes the
    # pre-shuffle layout of the examples.
    if ambiguous:
        groups = [
            ((0, 0), texts_for(male_id, 0)),
            ((1, 1), texts_for(female_id, 1)),
        ]
    else:
        groups = [
            ((0, 0), texts_for(male_id, 0)),
            ((0, 1), texts_for(male_id, 1)),
            ((1, 0), texts_for(female_id, 0)),
            ((1, 1), texts_for(female_id, 1)),
        ]

    # Truncate every group to the size of the smallest one so the classes are balanced.
    n = min(len(group_texts) for _, group_texts in groups)
    texts, true_labels, spurious_labels = [], [], []
    for (true_label, spurious_label), group_texts in groups:
        texts.extend(group_texts[:n])
        true_labels.extend([true_label] * n)
        spurious_labels.extend([spurious_label] * n)

    # Apply one shared permutation so texts and both label lists stay aligned.
    order = list(range(len(texts)))
    random.Random(seed).shuffle(order)
    texts = [texts[i] for i in order]
    true_labels = [true_labels[i] for i in order]
    spurious_labels = [spurious_labels[i] for i in order]

    batches = []
    for start in range(0, len(texts), batch_size):
        stop = start + batch_size
        batches.append((
            texts[start:stop],
            t.tensor(true_labels[start:stop], device=DEVICE),
            t.tensor(spurious_labels[start:stop], device=DEVICE),
        ))

    return batches

def get_subgroups(
        split: Literal["train", "test"] = "test",
        batch_size=32,
):
    """Batch every (profession, gender) subgroup of a split separately, without balancing.

    Returns:
        A dict keyed by `(true_label, spurious_label)` in the order (0, 0), (0, 1), (1, 0),
        (1, 1). Each value is a list of `(texts, true_labels, spurious_labels)` batches
        whose label tensors are constant and live on `DEVICE`.
    """
    rows = dataset[split]
    profession_ids = (profession_dict[male_prof], profession_dict[female_prof])

    out = {}
    # The true label is the position of the profession in `profession_ids`.
    for true_label, profession_id in enumerate(profession_ids):
        for spurious_label in (0, 1):
            texts = [
                row['hard_text']
                for row in rows
                if row['profession'] == profession_id and row['gender'] == spurious_label
            ]
            group_batches = []
            for start in range(0, len(texts), batch_size):
                chunk = texts[start:start + batch_size]
                group_batches.append((
                    chunk,
                    t.tensor([true_label] * len(chunk), device=DEVICE),
                    t.tensor([spurious_label] * len(chunk), device=DEVICE),
                ))
            out[(true_label, spurious_label)] = group_batches
    return out

In [3]:
def pool_acts(acts, attn_mask):
    """Mask-weighted mean of `acts` over dimension 1.

    `acts` is indexed as (batch, position, feature) and `attn_mask` as (batch, position).
    Each position is weighted by its mask value and the sum is divided by the per-row
    mask total, so positions with mask 0 do not contribute.
    """
    weights = attn_mask.unsqueeze(-1)
    masked_sum = (acts * weights).sum(dim=1)
    mask_total = attn_mask.sum(dim=1).unsqueeze(-1)
    return masked_sum / mask_total

@t.no_grad()
def collect_activations(
    model,
    layer,
    text_batches,
):
    """Lazily yield pooled layer outputs for each batch, together with that batch's labels.

    Args:
        model: nnsight-wrapped language model exposing `model.model.layers`.
        layer: index into `model.model.layers` whose output is read.
        text_batches: sized sequence of `(texts, *labels)` tuples (it must support
            `len()` for the progress bar).

    Yields:
        `(pooled_acts, *labels)` per batch, where `pooled_acts` is the result of
        `pool_acts` applied to the layer output and the attention mask.
    """
    with tqdm(total=len(text_batches), desc="Collecting activations") as pbar:
        for text_batch, *labels in text_batches:
            with model.trace(text_batch, **tracer_kwargs):
                # presumably the tokenizer's attention mask, taken from the traced call's kwargs — TODO confirm
                attn_mask = model.input[1]['attention_mask']
                # first element of the layer's output; presumably the hidden states — TODO confirm
                acts = model.model.layers[layer].output[0]
                pooled_acts = pool_acts(acts, attn_mask).save()
            # The saved value is read only after the trace context has exited.
            yield pooled_acts.value, *labels
            pbar.update(1)

In [4]:
# probe training hyperparameters

# model layer for attaching linear classification head
layer = 22

class Probe(nn.Module):
    """Linear classification head mapping an activation vector to a single logit."""

    def __init__(self, activation_dim, dtype=DTYPE):
        super().__init__()
        self.net = nn.Linear(in_features=activation_dim, out_features=1, bias=True, dtype=dtype)

    def forward(self, x):
        """Return the logits with the trailing size-1 dimension squeezed away."""
        return self.net(x).squeeze(-1)

def train_probe(
    activation_batches,
    label_idx=0,
    lr=1e-2,
    epochs=1,
    d_probe=activation_dim,
    seed=SEED,
):
    """Fit a linear `Probe` with AdamW and binary cross-entropy.

    Args:
        activation_batches: iterable of `(acts, *labels)` tuples. It may be a one-shot
            generator such as the one returned by `collect_activations`.
        label_idx: which entry of `labels` to train against.
        lr: AdamW learning rate.
        epochs: number of passes over `activation_batches`.
        d_probe: input width of the probe.
        seed: torch seed set before the probe is initialised.

    Returns:
        `(probe, losses)`, where `losses` holds one float per optimisation step.
    """
    t.manual_seed(seed)

    probe = Probe(d_probe).to(DEVICE)
    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    losses = []

    # A one-shot iterator is exhausted after the first pass, so later epochs would
    # silently train on nothing. Cache it when more than one epoch is requested.
    # With epochs == 1 the batches are still streamed.
    if epochs > 1 and iter(activation_batches) is activation_batches:
        activation_batches = list(activation_batches)

    for _epoch in range(epochs):
        for act, *labels in activation_batches:
            optimizer.zero_grad()
            logits = probe(act)
            # `.to(logits)` casts the labels to the logits' dtype and device.
            loss = criterion(logits, labels[label_idx].to(logits))
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    return probe, losses

@t.no_grad()
def test_probe(
    probe,
    activation_batches,
):
    corrects = defaultdict(list)
    for acts, *labels in activation_batches:
        logits = probe(acts)
        preds = (logits > 0.0).long()
        for idx, label in enumerate(labels):
            corrects[idx].append(preds == label)

    accs = {idx: t.cat(corrects[idx]).float().mean().item() for idx in corrects}

    return accs

In [5]:
def clean_acts(text_batches):
    """Pooled, un-ablated activations of `layer` for the given text batches."""
    return collect_activations(model, layer, text_batches)

# The oracle probe is trained on the balanced (unambiguous) training split.
oracle, _ = train_probe(
    activation_batches=clean_acts(get_text_batches(split="train", ambiguous=False))
)

ambiguous_accs = test_probe(
    oracle,
    activation_batches=clean_acts(get_text_batches(split="test", ambiguous=True)),
)
print(f"ambiguous test accuracy: {ambiguous_accs[0]}")

unambiguous_accs = test_probe(
    oracle,
    activation_batches=clean_acts(get_text_batches(split="test", ambiguous=False)),
)
print(f"ground truth accuracy: {unambiguous_accs[0]}")
print(f"unintended feature accuracy: {unambiguous_accs[1]}")

for subgroup, batches in get_subgroups().items():
    subgroup_accs = test_probe(oracle, activation_batches=clean_acts(batches))
    print(f"subgroup {subgroup} accuracy: {subgroup_accs[0]}")

Collecting activations:   0%|          | 0/141 [00:00<?, ?it/s]You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Collecting activations: 100%|██████████| 141/141 [00:24<00:00,  5.87it/s]
Collecting activations: 100%|██████████| 269/269 [00:43<00:00,  6.24it/s]


ambiguous test accuracy: 0.957597553730011


Collecting activations: 100%|██████████| 55/55 [00:08<00:00,  6.46it/s]


ground truth accuracy: 0.9504608511924744
unintended feature accuracy: 0.5057603716850281


Collecting activations: 100%|██████████| 507/507 [01:26<00:00,  5.88it/s]


subgroup (0, 0) accuracy: 0.9862387776374817


Collecting activations: 100%|██████████| 417/417 [01:10<00:00,  5.92it/s]


subgroup (0, 1) accuracy: 0.9644010066986084


Collecting activations: 100%|██████████| 14/14 [00:01<00:00,  7.15it/s]


subgroup (1, 0) accuracy: 0.9308755993843079


Collecting activations: 100%|██████████| 135/135 [00:19<00:00,  6.77it/s]

subgroup (1, 1) accuracy: 0.928438663482666





In [6]:
# The probe trained on the ambiguous split is cached on disk, keyed by layer and dtype.
save_path = f"probe_layer_{layer}_{str(DTYPE).split('.')[-1]}.pt"
if os.path.exists(save_path):
    print(f"loading probe from {save_path}")
    # The cache holds a fully pickled Probe module (see t.save below), which cannot be
    # read with weights_only=True, the default from PyTorch 2.6 onwards.
    # Unpickling runs arbitrary code, so only load files this notebook wrote itself.
    # map_location puts the probe on DEVICE regardless of where it was saved from.
    probe = t.load(save_path, map_location=DEVICE, weights_only=False)
else:
    probe, _ = train_probe(
        activation_batches=collect_activations(model, layer, get_text_batches(split="train", ambiguous=True))
    )
    t.save(probe, save_path)
    print(f"probe saved to {save_path}")

ambiguous_accs = test_probe(probe, activation_batches=collect_activations(model, layer, get_text_batches(split="test", ambiguous=True)))
print(f"ambiguous test accuracy: {ambiguous_accs[0]}")

unambiguous_accs = test_probe(probe, activation_batches=collect_activations(model, layer, get_text_batches(split="test", ambiguous=False)))
print(f"ground truth accuracy: {unambiguous_accs[0]}")
print(f"unintended feature accuracy: {unambiguous_accs[1]}")

for subgroup, batches in get_subgroups().items():
    subgroup_accs = test_probe(probe, activation_batches=collect_activations(model, layer, batches))
    print(f"subgroup {subgroup} accuracy: {subgroup_accs[0]}")

loading probe from probe_layer_22_bfloat16.pt


Collecting activations: 100%|██████████| 269/269 [00:42<00:00,  6.27it/s]


ambiguous test accuracy: 0.9965148568153381


Collecting activations: 100%|██████████| 55/55 [00:08<00:00,  6.44it/s]


ground truth accuracy: 0.6774193644523621
unintended feature accuracy: 0.8191244006156921


Collecting activations: 100%|██████████| 507/507 [01:25<00:00,  5.93it/s]


subgroup (0, 0) accuracy: 0.9985806345939636


Collecting activations: 100%|██████████| 417/417 [01:09<00:00,  5.99it/s]


subgroup (0, 1) accuracy: 0.5165602564811707


Collecting activations: 100%|██████████| 14/14 [00:01<00:00,  7.30it/s]


subgroup (1, 0) accuracy: 0.1820276528596878


Collecting activations: 100%|██████████| 135/135 [00:19<00:00,  6.92it/s]

subgroup (1, 1) accuracy: 0.9944238066673279





In [8]:
# loading dictionaries
# `submodules` is the collection of hooked model components and `dictionaries` maps each
# of them to its dictionary (later cells index `dictionaries[submodule]`).
# NOTE(review): `thru_layer=layer` presumably loads SAEs for layers 0..layer inclusive — confirm.
submodules, dictionaries = load_saes_and_submodules(
    model, 
    thru_layer=layer,
    include_embed=False,
    dtype=DTYPE,
    device=DEVICE,
)

def metric_fn(model, labels=None):
    """Attribution metric: the probe logit, sign-flipped according to the label.

    Pools the output of layer `layer` with the attention mask, applies the global
    `probe`, and returns the logit where `labels == 0` and its negation elsewhere.
    """
    mask = model.input[1]['attention_mask']
    pooled = pool_acts(model.model.layers[layer].output[0], mask)
    logits = probe(pooled)
    return t.where(labels == 0, logits, -logits)

Loading Gemma SAEs:   0%|          | 0/23 [00:00<?, ?it/s]

Loading Gemma SAEs: 100%|██████████| 23/23 [01:18<00:00,  3.40s/it]


In [7]:
# find most influential features
# n_batches = 200
n_batches = 100
batch_size = 1

nodes = None

# t.save does not create missing directories, so make sure the cache directory exists.
os.makedirs("effects", exist_ok=True)

for batch_idx, (clean, labels, _) in tqdm(enumerate(get_text_batches(split="train", ambiguous=True, batch_size=batch_size)), total=n_batches):
    if batch_idx == n_batches:
        break

    # if effects are already cached, skip
    # NOTE(review): the cache key covers only the input text and the submodule names.
    # It does not cover the probe, layer or attribution method, so clear `effects/`
    # by hand when any of those change.
    hash_input = clean + [s.name for s in submodules]
    hash_str = ''.join(hash_input)
    hash_digest = hashlib.md5(hash_str.encode()).hexdigest()
    cache_path = f"effects/{hash_digest}.pt"
    if os.path.exists(cache_path):
        continue

    effects, *_ = patching_effect(
        clean,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    # Store detached CPU copies keyed by submodule name so the cache does not pin GPU memory.
    to_save = {
        k.name : v.detach().to("cpu") for k, v in effects.items()
    }
    t.save(to_save, cache_path)

    del effects, _
    gc.collect()


100%|██████████| 100/100 [00:00<00:00, 76315.57it/s]


In [8]:
# aggregate effects
aggregated_effects = {submodule.name : 0 for submodule in submodules}
n_examples = 0  # examples actually accumulated, used as the normaliser

for idx, (clean, *_) in enumerate(get_text_batches(split="train", ambiguous=True, batch_size=batch_size)):
    if idx == n_batches:
        break
    hash_input = clean + [s.name for s in submodules]
    hash_str = ''.join(hash_input)
    hash_digest = hashlib.md5(hash_str.encode()).hexdigest()
    # The cache stores objects with an `.act` attribute rather than bare tensors, so it
    # needs full unpickling (weights_only defaults to True from PyTorch 2.6 onwards).
    # These files are written by this notebook; do not point this at untrusted files.
    effects = t.load(f"effects/{hash_digest}.pt", weights_only=False)
    for submodule in submodules:
        # Sum over positions and over the batch. Normalisation happens once, below.
        # Taking a batch mean here and then dividing by batch_size * n_batches
        # would divide by the batch size twice.
        aggregated_effects[submodule.name] += (
            effects[submodule.name].act[:,1:,:] # remove BOS features
        ).sum(dim=1).sum(dim=0)
    n_examples += len(clean)

# Per-example average. This equals the old batch_size * n_batches normaliser when
# batch_size == 1 and all n_batches batches exist.
aggregated_effects = {k : v / max(n_examples, 1) for k, v in aggregated_effects.items()}

In [10]:
# Print, per submodule, every feature whose aggregated effect exceeds the cutoff of 6.
count = 0
for name, effect in aggregated_effects.items():
    print(name)
    selected = t.nonzero(effect > 6).flatten().tolist()
    for feat in selected:
        print(f"  {feat} : {effect[feat].item()}")
    count += len(selected)
print(f"total features: {count}")

attn_0
mlp_0
resid_0
  4449 : 7.4375
attn_1
mlp_1
resid_1
  4521 : 7.25
  11782 : 6.375
attn_2
mlp_2
resid_2
  5853 : 6.71875
attn_3
mlp_3
resid_3
attn_4
mlp_4
resid_4
attn_5
mlp_5
resid_5
  2864 : 9.25
  11682 : 8.1875
attn_6
mlp_6
resid_6
  4068 : 13.0625
  7008 : 7.1875
attn_7
mlp_7
resid_7
  7111 : 8.0625
attn_8
mlp_8
resid_8
  6952 : 7.09375
  9949 : 19.5
attn_9
mlp_9
resid_9
  6952 : 9.375
  15246 : 14.875
attn_10
mlp_10
resid_10
  3711 : 14.5625
  6952 : 9.125
attn_11
mlp_11
resid_11
  3013 : 27.375
  4467 : 7.3125
  11649 : 6.15625
attn_12
mlp_12
resid_12
  6335 : 9.25
  11114 : 36.25
  11480 : 6.28125
attn_13
mlp_13
resid_13
  192 : 34.5
  495 : 6.09375
  14755 : 6.90625
attn_14
mlp_14
resid_14
  2354 : 30.375
  12665 : 6.46875
attn_15
mlp_15
resid_15
  798 : 25.875
  6211 : 6.6875
attn_16
mlp_16
resid_16
  6047 : 6.40625
  15567 : 6.40625
  16351 : 26.25
attn_17
mlp_17
resid_17
  6011 : 19.125
attn_18
mlp_18
resid_18
  61 : 15.75
attn_19
mlp_19
resid_19
  13002 : 10.5
attn_20

In [43]:
# interpret features with Neuronpedia API

submodule_name = "attn_21"
feature_idx = 10118

from IPython.display import IFrame
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release="gemma-2-2b", sae_id=None, feature_idx=feature_idx):
    """Fill the Neuronpedia embed URL template with the release, SAE id and feature index."""
    return html_template.format(sae_release, sae_id, feature_idx)

# Submodule names look like "<type>_<layer number>".
submodule_type, submodule_number = submodule_name.split('_')

# SAE id template for each submodule type.
sae_id_templates = {
    "resid": "{}-gemmascope-res-16k",
    "attn": "{}-gemmascope-att-16k",
    "mlp": "{}-gemmascope-mlp-16k",
}
if submodule_type not in sae_id_templates:
    raise ValueError("Unknown submodule type")
sae_id = sae_id_templates[submodule_type].format(submodule_number)

html = get_dashboard_html(sae_release="gemma-2-2b", sae_id=sae_id, feature_idx=feature_idx)
IFrame(html, width=1200, height=600)

In [19]:
# Hand-selected features for the submodules that have any, each with a short
# human-written annotation. Submodules not listed here get an empty list below.
_selected_feats = {
    "resid_0": [
        4449, # "he", "him"
    ],
    "resid_1": [
        4521, # "his"
        11782, # "woman", "she", "her", "female"
    ],
    "resid_2": [
        5853, # "he"
    ],
    "resid_5": [
        2864, # "her"
        11682, # "he"
    ],
    "resid_6": [
        4068, # female pronouns
        7008, # gendered pronouns
    ],
    "resid_7": [
        7111, # text about women
    ],
    "resid_8": [
        6952, # gendered pronouns
        9949, # female pronouns
    ],
    "resid_9": [
        6952, # "she"
        15246, # female pronouns
    ],
    "resid_10": [
        3711, # promotes female-associated words
        6952, # masculine pronouns
    ],
    "resid_11": [
        3013, # descriptions of women
        4467, # "He"
        11649, # "she", "her"
    ],
    "resid_12": [
        6335, # "He"
        11114, # descriptions of women
        11480, # "his", "her"
    ],
    "resid_13": [
        192, # descriptions of women
        495, # "her"
        14755, # "he"
    ],
    "resid_14": [
        2354, # descriptions of women
        12665, # male pronouns
    ],
    "resid_15": [
        798, # promotes feminine pronouns
        6211, # gendered pronouns
    ],
    "resid_16": [
        6047, # women, girls, women's names
        15567, # "he", "she"
        16351, # promotes feminine pronouns
    ],
    "resid_17": [
        6011, # promotes feminine pronouns
    ],
    "resid_18": [
        61, # descriptions of female royalty
    ],
    "resid_19": [
        13002, # clauses starting with "she"
    ],
    "attn_20": [
        9711, # "her", "she"
    ],
    "resid_20": [
        7116, # promotes masculine pronouns
        # 7324, # nursing
        12545, # promotes female pronouns
    ],
    "attn_21": [
        2740, # promotes feminine pronouns
        10118, # masculine pronouns
    ],
    "resid_21": [
        1065, # promotes masculine pronouns
        # 1653, # nursing
        4430, # promotes feminine pronouns
    ],
    "resid_22": [
        1208, # promotes masculine pronouns
        3497, # promotes feminine pronouns
        # 7961, # nursing
    ],
}

# One entry per attn/mlp/resid submodule of layers 0-22, in that order.
feats_to_ablate = {
    f"{kind}_{layer_idx}": list(_selected_feats.get(f"{kind}_{layer_idx}", []))
    for layer_idx in range(23)
    for kind in ("attn", "mlp", "resid")
}
print(f"number of features to ablate: {sum([len(v) for v in feats_to_ablate.values()])}")

number of features to ablate: 43


In [20]:
# putting feats_to_ablate in a more useful format
def n_hot(feats, dim):
    """Return a length-`dim` boolean mask on `DEVICE` that is True exactly at the indices in `feats`."""
    mask = t.zeros(dim, dtype=t.bool, device=DEVICE)
    for position in feats:
        mask[position] = True
    return mask

# NOTE(review): this rebinds `feats_to_ablate` from a name-keyed dict of index lists to a
# submodule-keyed dict of boolean masks. Re-running this cell on its own therefore fails
# on the `submodule.name` lookup; re-run the cell that defines the name-keyed dict first.
feats_to_ablate = {
    submodule : n_hot(feats_to_ablate[submodule.name], dictionaries[submodule].dict_size) for submodule in submodules
}

In [39]:
@t.no_grad()
def collect_acts_ablated(
    text_batches,
    model,
    submodules,
    dictionaries,
    to_ablate,
    layer,
):
    """Yield pooled layer outputs computed with selected dictionary features zero-ablated.

    For every submodule that has at least one feature selected, the activation is encoded
    with its dictionary, the selected features are set to zero, and the activation is
    replaced by the decoded features plus the original reconstruction error.

    Submodules with nothing selected are left untouched. Previously the emptiness check
    used `len()`, which is never 0 for the fixed-length boolean masks built by `n_hot`,
    so every submodule went through an encode/decode round trip. That cost time and
    perturbed activations by floating-point rounding, so accuracies computed before
    this fix may differ slightly.

    Args:
        text_batches: sized sequence of `(texts, *labels)` tuples.
        model: nnsight-wrapped language model.
        submodules: submodules to intervene on, in intervention order.
        dictionaries: mapping from submodule to its dictionary.
        to_ablate: mapping from submodule to either a boolean mask over dictionary
            features or a collection of feature indices.
        layer: index of the layer whose pooled output is returned.

    Yields:
        `(pooled_act, *labels)` for each batch.
    """
    def _selects_any(feat_idxs):
        # A boolean mask always has full length, so test its contents instead.
        if isinstance(feat_idxs, t.Tensor) and feat_idxs.dtype == t.bool:
            return bool(feat_idxs.any())
        return len(feat_idxs) > 0

    # The selection does not change between batches, so decide once which submodules to touch.
    active_submodules = [
        submodule for submodule in submodules if _selects_any(to_ablate[submodule])
    ]

    with tqdm(total=len(text_batches), desc="Collecting activations with ablations") as pbar:
        for text, *labels in text_batches:
            with model.trace(text, **tracer_kwargs):
                for submodule in active_submodules:
                    dictionary = dictionaries[submodule]
                    feat_idxs = to_ablate[submodule]
                    x = submodule.get_activation()
                    x_hat, f = dictionary(x, output_features=True)
                    # Keep the dictionary's reconstruction error so that only the
                    # ablated features change the activation.
                    res = x - x_hat
                    f[:, :,feat_idxs] = 0. # zero ablation
                    submodule.set_activation(dictionary.decode(f) + res)
                attn_mask = model.input[1]['attention_mask']
                act = model.model.layers[layer].output[0]
                pooled_act = pool_acts(act, attn_mask).save()
            yield pooled_act.value, *labels
            pbar.update(1)


# Accuracy after ablating features judged irrelevant by human annotators

In [14]:
def ablated_acts(text_batches):
    """Pooled activations for `text_batches` with the features in `feats_to_ablate` zero-ablated."""
    return collect_acts_ablated(text_batches, model, submodules, dictionaries, feats_to_ablate, layer)

ambiguous_accs = test_probe(
    probe,
    activation_batches=ablated_acts(get_text_batches(split="test", ambiguous=True)),
)
print(f"Ambiguous test accuracy: {ambiguous_accs[0]}")

unambiguous_accs = test_probe(
    probe,
    activation_batches=ablated_acts(get_text_batches(split="test", ambiguous=False)),
)
print(f"Ground truth accuracy: {unambiguous_accs[0]}")
print(f"Spurious accuracy: {unambiguous_accs[1]}")

for subgroup, batches in get_subgroups().items():
    subgroup_accs = test_probe(probe, activation_batches=ablated_acts(batches))
    print(f"Subgroup {subgroup} accuracy: {subgroup_accs[0]}")


Collecting activations with ablations: 100%|██████████| 269/269 [02:33<00:00,  1.76it/s]


Ambiguous test accuracy: 0.7659153938293457


Collecting activations with ablations: 100%|██████████| 55/55 [00:30<00:00,  1.80it/s]


Ground truth accuracy: 0.7603686451911926
Spurious accuracy: 0.514976978302002


Collecting activations with ablations: 100%|██████████| 507/507 [05:08<00:00,  1.64it/s]


Subgroup (0, 0) accuracy: 0.9977167248725891


Collecting activations with ablations: 100%|██████████| 417/417 [04:06<00:00,  1.69it/s]


Subgroup (0, 1) accuracy: 0.9885091781616211


Collecting activations with ablations: 100%|██████████| 14/14 [00:07<00:00,  1.99it/s]


Subgroup (1, 0) accuracy: 0.5


Collecting activations with ablations: 100%|██████████| 135/135 [01:11<00:00,  1.90it/s]

Subgroup (1, 1) accuracy: 0.533224880695343





# Concept bottleneck probing baseline

In [5]:
# Concept words for the bottleneck baseline: the first ten relate to nursing/healthcare,
# the last ten to academia. Each keeps its leading space.
concepts = [    
    ' nurse',
    ' healthcare',
    ' hospital',
    ' patient',
    ' medical',
    ' clinic',
    ' triage',
    ' medication',
    ' emergency',
    ' surgery',
    ' professor',
    ' academia',
    ' research',
    ' university',
    ' tenure',
    ' faculty',
    ' dissertation',
    ' sabbatical',
    ' publication',
    ' grant',
]
# get concept vectors
with t.no_grad():
    with model.trace(concepts):
        # Output of layer `layer` at the final position of every concept string.
        # NOTE(review): assumes left padding, so that position -1 is each concept's last
        # real token — confirm the tokenizer's padding side.
        concept_vectors = model.model.layers[layer].output[0][:, -1, :].save()
    # Subtract the mean across concepts, then scale each vector to unit norm.
    concept_vectors = concept_vectors.value - concept_vectors.value.mean(0, keepdim=True)
    concept_vectors = concept_vectors / concept_vectors.norm(dim=-1, keepdim=True)
    

@t.no_grad()
def get_bottleneck_activations(
    text_batches,
    model,
    layer,
):
    """Yield, per batch, the similarity of the pooled layer output to each concept vector.

    Args:
        text_batches: sized sequence of `(texts, *labels)` tuples.
        model: nnsight-wrapped language model.
        layer: index of the layer whose output is pooled.

    Yields:
        `(sims, *labels)`, where `sims` has one column per row of the global
        `concept_vectors`.
    """
    with tqdm(total=len(text_batches), desc="Collecting bottleneck activations") as pbar:
        for text, *labels in text_batches:
            with model.trace(text, **tracer_kwargs):
                attn_mask = model.input[1]['attention_mask']
                acts = model.model.layers[layer].output[0]
                pooled_acts = pool_acts(acts, attn_mask)
                # Unit-normalise so the dot product below is a cosine similarity
                # (the concept vectors are already unit norm).
                pooled_acts = pooled_acts / pooled_acts.norm(dim=-1, keepdim=True)
                # compute cosine similarity with concept vectors
                sims = (pooled_acts @ concept_vectors.T).save()
            yield sims.value, *labels
            pbar.update(1)

You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [6]:
def bottleneck_acts(text_batches):
    """Concept-similarity features of `layer` for the given text batches."""
    return get_bottleneck_activations(text_batches, model, layer)

# The probe's input width is the number of concepts, since each example is
# represented by its similarity to every concept vector.
cbp_probe, _ = train_probe(
    bottleneck_acts(get_text_batches(split="train", ambiguous=True)),
    d_probe=len(concepts),
)

ambiguous_accs = test_probe(cbp_probe, bottleneck_acts(get_text_batches(split="test", ambiguous=True)))
print(f"Ambiguous test accuracy: {ambiguous_accs[0]}")

unambiguous_accs = test_probe(cbp_probe, bottleneck_acts(get_text_batches(split="test", ambiguous=False)))
print(f"Ground truth accuracy: {unambiguous_accs[0]}")
print(f"Spurious accuracy: {unambiguous_accs[1]}")

for subgroup, batches in get_subgroups().items():
    subgroup_accs = test_probe(cbp_probe, bottleneck_acts(batches))
    print(f"Subgroup {subgroup} accuracy: {subgroup_accs[0]}")

Collecting bottleneck activations: 100%|██████████| 700/700 [01:59<00:00,  5.87it/s]
Collecting bottleneck activations: 100%|██████████| 269/269 [00:43<00:00,  6.12it/s]


Ambiguous test accuracy: 0.9138011336326599


Collecting bottleneck activations: 100%|██████████| 55/55 [00:08<00:00,  6.38it/s]


Ground truth accuracy: 0.9020737409591675
Spurious accuracy: 0.5011520981788635


Collecting bottleneck activations: 100%|██████████| 507/507 [01:26<00:00,  5.87it/s]


Subgroup (0, 0) accuracy: 0.8963282704353333


Collecting bottleneck activations: 100%|██████████| 417/417 [01:09<00:00,  5.98it/s]


Subgroup (0, 1) accuracy: 0.8675178289413452


Collecting bottleneck activations: 100%|██████████| 14/14 [00:01<00:00,  7.26it/s]


Subgroup (1, 0) accuracy: 0.9193548560142517


Collecting bottleneck activations: 100%|██████████| 135/135 [00:19<00:00,  6.91it/s]

Subgroup (1, 1) accuracy: 0.9321561455726624





# Get skyline neuron performance

In [10]:
# get neurons which are most influential for giving gender label
# An IdentityDict stands in for each SAE here, so attribution runs over the raw
# activation dimensions ("neurons").
neuron_dicts = {
    submodule : IdentityDict(activation_dim).to(DEVICE) for submodule in submodules
}

n_batches = 100
batch_size = 1

running_total = 0
nodes = None

# t.save does not create missing directories, so make sure the cache directory exists.
os.makedirs("effects", exist_ok=True)

# `labels` is the third batch element, i.e. the spurious (gender) label.
for batch_idx, (clean, _, labels) in tqdm(enumerate(get_text_batches(split="train", ambiguous=False, batch_size=batch_size)), total=n_batches):
    if batch_idx == n_batches:
        break

    # if effects are already cached, skip
    # The "_neuron_skyline" suffix keeps these entries separate from the
    # SAE-feature effects cached in the same directory.
    hash_input = clean + [s.name for s in submodules]
    hash_str = ''.join(hash_input) + "_neuron_skyline"
    hash_digest = hashlib.md5(hash_str.encode()).hexdigest()
    cache_path = f"effects/{hash_digest}.pt"
    if os.path.exists(cache_path):
        continue

    effects, *_  = patching_effect(
        clean,
        None,
        model,
        submodules,
        neuron_dicts,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    # Store detached CPU copies keyed by submodule name so the cache does not pin GPU memory.
    to_save = {
        k.name : v.detach().to("cpu") for k, v in effects.items()
    }
    t.save(to_save, cache_path)

    del effects, _
    gc.collect()

100%|██████████| 100/100 [18:42<00:00, 11.23s/it]


In [11]:
# aggregated effects
aggregated_effects = {submodule.name : 0 for submodule in submodules}
n_examples = 0  # examples actually accumulated, used as the normaliser

for idx, (clean, *_) in enumerate(get_text_batches(split="train", ambiguous=False, batch_size=batch_size)):
    if idx == n_batches:
        break
    hash_input = clean + [s.name for s in submodules]
    hash_str = ''.join(hash_input) + "_neuron_skyline"
    hash_digest = hashlib.md5(hash_str.encode()).hexdigest()
    # The cache stores objects with an `.act` attribute rather than bare tensors, so it
    # needs full unpickling (weights_only defaults to True from PyTorch 2.6 onwards).
    # These files are written by this notebook; do not point this at untrusted files.
    effects = t.load(f"effects/{hash_digest}.pt", weights_only=False)
    for submodule in submodules:
        # Sum over positions and over the batch. Normalisation happens once, below.
        # Taking a batch mean here and then dividing by batch_size * n_batches
        # would divide by the batch size twice.
        aggregated_effects[submodule.name] += (
            effects[submodule.name].act[:,1:,:] # ignore BOS token
        ).sum(dim=1).sum(dim=0)
    n_examples += len(clean)

# Per-example average. This equals the old batch_size * n_batches normaliser when
# batch_size == 1 and all n_batches batches exist.
aggregated_effects = {k : v / max(n_examples, 1) for k, v in aggregated_effects.items()}

In [32]:
# Select, per submodule, every neuron whose aggregated effect exceeds the cutoff of 4.8.
top_neurons_to_ablate = {}
total_neurons = 0
for submodule in submodules:
    print(submodule.name)
    effect = aggregated_effects[submodule.name]
    picked = t.nonzero(effect > 4.8).flatten().tolist()
    for neuron in picked:
        print(f"  {neuron} : {effect[neuron].item()}")
    top_neurons_to_ablate[submodule.name] = picked
    total_neurons += len(picked)
print(f"total neurons: {total_neurons}")

attn_0
mlp_0
resid_0
attn_1
mlp_1
resid_1
  1570 : 6.6875
attn_2
mlp_2
resid_2
  1570 : 11.0625
attn_3
mlp_3
resid_3
  629 : 6.4375
attn_4
mlp_4
resid_4
  1570 : 8.8125
attn_5
mlp_5
resid_5
attn_6
mlp_6
resid_6
  624 : 11.625
  682 : 6.3125
  1068 : 6.3125
  1570 : 6.6875
attn_7
mlp_7
resid_7
  334 : 12.75
attn_8
mlp_8
resid_8
attn_9
mlp_9
resid_9
  334 : 6.84375
  682 : 9.25
attn_10
mlp_10
resid_10
attn_11
mlp_11
resid_11
  334 : 16.0
attn_12
mlp_12
resid_12
  334 : 5.5625
  682 : 12.625
  1570 : 6.96875
attn_13
mlp_13
resid_13
  334 : 8.0625
attn_14
mlp_14
resid_14
  334 : 13.1875
  682 : 9.3125
  1546 : 6.53125
  1570 : 5.03125
attn_15
mlp_15
resid_15
attn_16
mlp_16
resid_16
  334 : 5.3125
  1068 : 6.1875
  1149 : 5.4375
  1711 : 4.90625
attn_17
mlp_17
resid_17
  98 : 5.4375
  113 : 5.8125
  682 : 22.125
  784 : 5.75
  1068 : 22.125
  1149 : 6.46875
  1546 : 14.3125
  1570 : 8.5625
  1645 : 11.1875
attn_18
mlp_18
resid_18
  334 : 6.84375
attn_19
mlp_19
resid_19
  682 : 19.0
  784 : 

In [16]:
@t.no_grad()
def collect_acts_ablated_neurons(
    text_batches,
    model,
    submodules,
    to_ablate,
    layer,
):
    """Yield pooled layer outputs computed with selected activation dimensions mean-ablated.

    Args:
        text_batches: sized sequence of `(texts, *labels)` tuples.
        model: nnsight-wrapped language model.
        submodules: submodules whose activations are edited, in order.
        to_ablate: mapping from submodule to an index or mask over the last
            activation dimension.
        layer: index of the layer whose pooled output is returned.

    Yields:
        `(pooled_act, *labels)` for each batch.
    """
    with tqdm(total=len(text_batches), desc="Collecting activations with ablations") as pbar:
        for text, *labels in text_batches:
            with model.trace(text, **tracer_kwargs):
                for submodule in submodules:
                    x = submodule.get_activation()
                    # Overwrite the selected dimensions with their mean over dims (0, 1),
                    # i.e. over every position of every sequence in this batch.
                    # NOTE(review): the mean ignores the attention mask, so padding
                    # positions are included — confirm that is intended.
                    x[..., to_ablate[submodule]] = x.mean(dim=(0, 1))[..., to_ablate[submodule]] # mean ablation
                    submodule.set_activation(x)
                attn_mask = model.input[1]['attention_mask']
                act = model.model.layers[layer].output[0]
                pooled_act = pool_acts(act, attn_mask).save()
            yield pooled_act.value, *labels
            pbar.update(1)

In [33]:
# Re-key the selection from submodule name -> list of indices to
# submodule -> n_hot(...) mask. The old name-keyed mapping is read while the
# new one is being built and is replaced only afterwards, so the
# `[sm.name]` lookup below only works on the name-keyed version.
top_neurons_to_ablate = dict(
    (
        sm,
        n_hot(
            top_neurons_to_ablate[sm.name],
            dictionaries[sm].W_enc.shape[0],  # dimension of the submodule
        ),
    )
    for sm in submodules
)

In [34]:
def _neuron_ablated_probe_accs(text_batches):
    """Run `test_probe` with `probe` on activations collected from `text_batches`
    via `collect_acts_ablated_neurons` using `top_neurons_to_ablate`."""
    ablated_acts = collect_acts_ablated_neurons(
        text_batches, model, submodules, top_neurons_to_ablate, layer
    )
    return test_probe(probe, activation_batches=ablated_acts)


# ambiguous test split
ambiguous_accs = _neuron_ablated_probe_accs(
    get_text_batches(split="test", ambiguous=True)
)
print(f"Ambiguous test accuracy: {ambiguous_accs[0]}")

# unambiguous test split: entry 0 is reported as ground truth, entry 1 as spurious
unambiguous_accs = _neuron_ablated_probe_accs(
    get_text_batches(split="test", ambiguous=False)
)
print(f"Ground truth accuracy: {unambiguous_accs[0]}")
print(f"Spurious accuracy: {unambiguous_accs[1]}")

# one accuracy per subgroup returned by get_subgroups()
for subgroup, batches in get_subgroups().items():
    subgroup_accs = _neuron_ablated_probe_accs(batches)
    print(f"Subgroup {subgroup} accuracy: {subgroup_accs[0]}")


Collecting activations with ablations: 100%|██████████| 269/269 [00:47<00:00,  5.68it/s]


Ambiguous test accuracy: 0.9940752387046814


Collecting activations with ablations: 100%|██████████| 55/55 [00:09<00:00,  5.93it/s]


Ground truth accuracy: 0.650921642780304
Spurious accuracy: 0.843317985534668


Collecting activations with ablations: 100%|██████████| 507/507 [01:33<00:00,  5.40it/s]


Subgroup (0, 0) accuracy: 0.9986423254013062


Collecting activations with ablations: 100%|██████████| 417/417 [01:15<00:00,  5.50it/s]


Subgroup (0, 1) accuracy: 0.6099886894226074


Collecting activations with ablations: 100%|██████████| 14/14 [00:02<00:00,  6.24it/s]


Subgroup (1, 0) accuracy: 0.04608295112848282


Collecting activations with ablations: 100%|██████████| 135/135 [00:22<00:00,  6.11it/s]

Subgroup (1, 1) accuracy: 0.990938663482666





# Get skyline feature performance

In [48]:
# get features which are most useful for predicting gender label
n_batches = 100
batch_size = 1

# Neither name is read in this cell; they are still (re)set here so the kernel
# state this cell leaves behind is the same as before.
running_total = 0
running_nodes = None

# t.save does not create missing parent directories, so make sure the cache
# directory exists before the first write (no-op when it is already there).
os.makedirs("effects", exist_ok=True)

for batch_idx, (clean, _, labels) in tqdm(enumerate(get_text_batches(split="train", ambiguous=False, batch_size=batch_size, seed=SEED)), total=n_batches):
    if batch_idx == n_batches:
        break

    # Cache key: md5 of the batch text joined with the submodule names and a
    # "_feature_skyline" suffix.
    hash_input = clean + [s.name for s in submodules]
    hash_str = ''.join(hash_input) + "_feature_skyline"
    hash_digest = hashlib.md5(hash_str.encode()).hexdigest()
    cache_path = f"effects/{hash_digest}.pt"
    if os.path.exists(cache_path):
        # effects for this batch are already on disk; skip the expensive call
        continue

    effects, *_ = patching_effect(
        clean,
        None,
        model,
        submodules,
        dictionaries,
        metric_fn,
        metric_kwargs=dict(labels=labels),
        method='ig'
    )
    # store under submodule names (plain strings) with tensors detached and moved to CPU
    to_save = {
        k.name : v.detach().to("cpu") for k, v in effects.items()
    }
    t.save(to_save, cache_path)

    # drop the reference to the on-device effects before the next batch
    del effects
    gc.collect()

100%|██████████| 100/100 [18:59<00:00, 11.39s/it]


In [35]:
# aggregate effects
# Running sum per submodule name, starting from 0, over the cached per-batch effects.
aggregated_effects = dict.fromkeys((sm.name for sm in submodules), 0)

train_batches = get_text_batches(split="train", ambiguous=False, batch_size=batch_size)
for idx, (clean, *_) in enumerate(train_batches):
    if idx >= n_batches:
        break
    # md5 of batch text + submodule names + suffix identifies the cached file
    hash_str = ''.join(clean + [s.name for s in submodules]) + "_feature_skyline"
    hash_digest = hashlib.md5(hash_str.encode()).hexdigest()
    effects = t.load(f"effects/{hash_digest}.pt")
    for submodule in submodules:
        without_bos = effects[submodule.name].act[:, 1:, :]  # ignore BOS token
        # sum over dim 1, then average over dim 0
        aggregated_effects[submodule.name] += without_bos.sum(dim=1).mean(dim=0)

# NOTE(review): each batch is already averaged over dim 0 above, so the extra
# batch_size factor only changes the scale when batch_size != 1 — confirm intended.
normalizer = batch_size * n_batches
aggregated_effects = {
    name: total / normalizer for name, total in aggregated_effects.items()
}

In [36]:
# For every submodule, keep the feature indices whose aggregated effect is
# strictly greater than 6.1, printing each kept index with its effect.
top_feats_to_ablate = {}
total_features = 0
for submodule in submodules:
    sub_name = submodule.name
    print(f"Component {sub_name}:")
    sub_effects = aggregated_effects[sub_name]
    # nonzero() yields one row per hit; flatten to a plain list of ints
    kept = (sub_effects > 6.1).nonzero().flatten().tolist()
    for idx in kept:
        print(idx, sub_effects[idx].item())
    top_feats_to_ablate[sub_name] = kept
    total_features += len(kept)
print(f"total features: {total_features}")

Component attn_0:
Component mlp_0:
Component resid_0:
4449 8.0
Component attn_1:
Component mlp_1:
Component resid_1:
11782 6.96875
Component attn_2:
Component mlp_2:
Component resid_2:
5853 6.53125
12818 6.1875
Component attn_3:
Component mlp_3:
Component resid_3:
9949 7.0
Component attn_4:
Component mlp_4:
Component resid_4:
4675 6.53125
Component attn_5:
Component mlp_5:
Component resid_5:
2864 9.8125
11682 7.125
Component attn_6:
Component mlp_6:
Component resid_6:
4068 9.3125
7008 9.625
Component attn_7:
Component mlp_7:
Component resid_7:
7111 6.84375
Component attn_8:
Component mlp_8:
Component resid_8:
6952 7.875
9949 14.625
Component attn_9:
Component mlp_9:
Component resid_9:
6952 10.0
15246 11.25
Component attn_10:
Component mlp_10:
Component resid_10:
3711 11.1875
6952 10.25
Component attn_11:
Component mlp_11:
Component resid_11:
3013 19.75
4467 8.1875
11649 6.15625
Component attn_12:
Component mlp_12:
Component resid_12:
6335 10.1875
11114 25.75
Component attn_13:
Componen

In [37]:
def n_hot(feats, dim, device=None):
    """Return a boolean mask of length `dim` that is True at the positions in `feats`.

    Args:
        feats: iterable of integer indices to switch on. May be empty;
            repeated indices are harmless.
        dim: size of the returned mask (passed to `t.zeros`).
        device: device the returned mask lives on. When omitted, the
            notebook-level `DEVICE` is used, which is what this function
            always did before the parameter existed.

    Returns:
        A `t.bool` tensor of size `dim` on `device`.

    Raises:
        IndexError: if an index in `feats` is out of range for `dim`.
    """
    if device is None:
        device = DEVICE
    # Build the mask on the host with a single indexed write, so a bad index
    # raises IndexError there, then move the finished mask to `device` once.
    out = t.zeros(dim, dtype=t.bool)
    indices = t.as_tensor([int(feat) for feat in feats], dtype=t.long)
    out[indices] = True
    return out.to(device)

# Re-key the selection from submodule name -> list of indices to
# submodule -> n_hot(...) mask sized by that dictionary's dict_size. The old
# name-keyed mapping is read while the new one is built and replaced afterwards.
top_feats_to_ablate = dict(
    (sm, n_hot(top_feats_to_ablate[sm.name], dictionaries[sm].dict_size))
    for sm in submodules
)

In [40]:
def _feature_ablated_probe_accs(text_batches):
    """Run `test_probe` with `probe` on activations collected from `text_batches`
    via `collect_acts_ablated` using `top_feats_to_ablate`."""
    ablated_acts = collect_acts_ablated(
        text_batches, model, submodules, dictionaries, top_feats_to_ablate, layer
    )
    return test_probe(probe, activation_batches=ablated_acts)


# ambiguous test split
ambiguous_accs = _feature_ablated_probe_accs(
    get_text_batches(split="test", ambiguous=True)
)
print(f"Ambiguous test accuracy: {ambiguous_accs[0]}")

# unambiguous test split: entry 0 is reported as ground truth, entry 1 as spurious
unambiguous_accs = _feature_ablated_probe_accs(
    get_text_batches(split="test", ambiguous=False)
)
print(f"Ground truth accuracy: {unambiguous_accs[0]}")
print(f"Spurious accuracy: {unambiguous_accs[1]}")

# one accuracy per subgroup returned by get_subgroups()
for subgroup, batches in get_subgroups().items():
    subgroup_accs = _feature_ablated_probe_accs(batches)
    print(f"Subgroup {subgroup} accuracy: {subgroup_accs[0]}")


Collecting activations with ablations: 100%|██████████| 269/269 [02:34<00:00,  1.75it/s]


Ambiguous test accuracy: 0.8300418257713318


Collecting activations with ablations: 100%|██████████| 55/55 [00:31<00:00,  1.77it/s]


Ground truth accuracy: 0.8087557554244995
Spurious accuracy: 0.5368663668632507


Collecting activations with ablations: 100%|██████████| 507/507 [05:13<00:00,  1.62it/s]


Subgroup (0, 0) accuracy: 0.9973464608192444


Collecting activations with ablations: 100%|██████████| 417/417 [04:09<00:00,  1.67it/s]


Subgroup (0, 1) accuracy: 0.9850544333457947


Collecting activations with ablations: 100%|██████████| 14/14 [00:07<00:00,  1.97it/s]


Subgroup (1, 0) accuracy: 0.5668202638626099


Collecting activations with ablations: 100%|██████████| 135/135 [01:11<00:00,  1.90it/s]

Subgroup (1, 1) accuracy: 0.6619423627853394





# Retraining probe on activations after ablating features

In [41]:
def _retrain_ablated_acts(text_batches):
    """Collect activations for `text_batches` through `collect_acts_ablated`
    with `feats_to_ablate` as the ablation set."""
    # NOTE(review): this uses `feats_to_ablate`, not the `top_feats_to_ablate`
    # masks — confirm that is the intended feature set for retraining.
    return collect_acts_ablated(
        text_batches, model, submodules, dictionaries, feats_to_ablate, layer
    )


# fit a new probe on ablated activations from the ambiguous training split
retrained_probe, _ = train_probe(
    activation_batches=_retrain_ablated_acts(
        get_text_batches(split="train", ambiguous=True)
    ),
)

# ambiguous test split
ambiguous_test_accs = test_probe(
    retrained_probe,
    activation_batches=_retrain_ablated_acts(
        get_text_batches(split="test", ambiguous=True)
    ),
)
print(f"Ambiguous test accuracy: {ambiguous_test_accs[0]}")

# unambiguous test split: entry 0 is reported as ground truth, entry 1 as spurious
unambiguous_test_accs = test_probe(
    retrained_probe,
    activation_batches=_retrain_ablated_acts(
        get_text_batches(split="test", ambiguous=False)
    ),
)
print(f"Ground truth accuracy: {unambiguous_test_accs[0]}")
print(f"Spurious accuracy: {unambiguous_test_accs[1]}")

# one accuracy per subgroup returned by get_subgroups()
for subgroup, batches in get_subgroups().items():
    subgroup_accs = test_probe(
        retrained_probe,
        activation_batches=_retrain_ablated_acts(batches),
    )
    print(f"Subgroup {subgroup} accuracy: {subgroup_accs[0]}")

Collecting activations with ablations: 100%|██████████| 700/700 [06:59<00:00,  1.67it/s]
Collecting activations with ablations: 100%|██████████| 269/269 [02:38<00:00,  1.70it/s]


Ambiguous test accuracy: 0.977811336517334


Collecting activations with ablations: 100%|██████████| 55/55 [00:30<00:00,  1.78it/s]


Ground truth accuracy: 0.9504608511924744
Spurious accuracy: 0.524193525314331


Collecting activations with ablations: 100%|██████████| 507/507 [05:07<00:00,  1.65it/s]


Subgroup (0, 0) accuracy: 0.9780314564704895


Collecting activations with ablations: 100%|██████████| 417/417 [04:09<00:00,  1.67it/s]


Subgroup (0, 1) accuracy: 0.9289522767066956


Collecting activations with ablations: 100%|██████████| 14/14 [00:07<00:00,  1.97it/s]


Subgroup (1, 0) accuracy: 0.9585253596305847


Collecting activations with ablations: 100%|██████████| 135/135 [01:11<00:00,  1.89it/s]

Subgroup (1, 1) accuracy: 0.9795538783073425



