In [2]:
import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from nnsight import LanguageModel
import random
from datasets import load_dataset
import torch as t
import torch.nn as nn
from dictionary_learning import AutoEncoder
from circuit import get_circuit
from circuit_plotting import plot_circuit
from activation_utils import SparseAct
from tqdm import tqdm
import json
import gc

DEBUGGING = False
# Extra nnsight tracing checks ('validate' and 'scan') are switched on together, only while debugging.
tracer_kwargs = dict.fromkeys(('validate', 'scan'), bool(DEBUGGING))

DEVICE = 'cuda:0'
MODEL_NAME = 'EleutherAI/pythia-70m-deduped'
# nnsight wrapper around the HF model; NOTE(review): dispatch=True presumably loads weights eagerly — confirm in nnsight docs.
model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE)

In [3]:
# utilities for loading data
# Bias in Bios; each record's 'profession' is an integer id and 'gender' an integer code.
dataset = load_dataset("LabHC/bias_in_bios")
male_prof, female_prof = 'professor', 'nurse'
profession_dict = {male_prof: 21, female_prof: 13}

SEED = 42
# NOTE(review): this global appears unused — get_data/get_subgroups take their own batch_size
# (default 128), and the circuit-discovery cell reassigns batch_size before reading it.
batch_size = 1024

def get_data(train=True, ambiguous=True, batch_size=128, seed=SEED):
    """Balanced, shuffled batches for the `male_prof` (label 0) vs `female_prof` (label 1) task.

    Args:
        train: read the 'train' split if True, otherwise 'test'.
        ambiguous: if True, keep only `male_prof` bios with gender 0 and `female_prof` bios
            with gender 1, truncated to equal size, so profession and gender are perfectly
            correlated. If False, keep all four (profession, gender) subgroups, each truncated
            to the size of the smallest, so the two attributes are uncorrelated.
        batch_size: number of texts per batch (the last batch may be smaller).
        seed: seed for the example shuffle; identical arguments give an identical order.

    Returns:
        list of (texts, true_labels, spurious_labels) tuples. true_labels is the profession
        label (0 for `male_prof`, 1 for `female_prof`); spurious_labels is the gender code.
        Both are integer tensors on DEVICE.
    """
    data = dataset['train'] if train else dataset['test']
    neg_id = profession_dict[male_prof]    # profession with true label 0
    pos_id = profession_dict[female_prof]  # profession with true label 1

    # Each entry: ((profession id, gender), true label, spurious label), in concatenation order.
    if ambiguous:
        groups = [((neg_id, 0), 0, 0), ((pos_id, 1), 1, 1)]
    else:
        groups = [((neg_id, 0), 0, 0), ((neg_id, 1), 0, 1), ((pos_id, 0), 1, 0), ((pos_id, 1), 1, 1)]

    # A single pass over the split buckets the texts of every needed subgroup at once.
    buckets = {key: [] for key, _, _ in groups}
    for x in data:
        key = (x['profession'], x['gender'])
        if key in buckets:
            buckets[key].append(x['hard_text'])

    # Truncate every subgroup to the smallest one so classes and subgroups are balanced.
    n = min(len(buckets[key]) for key, _, _ in groups)
    texts, true_labels, spurious_labels = [], [], []
    for key, true_label, spurious_label in groups:
        texts += buckets[key][:n]
        true_labels += [true_label] * n
        spurious_labels += [spurious_label] * n

    # Shuffle all three lists with one shared permutation so they stay aligned.
    idxs = list(range(len(texts)))
    random.Random(seed).shuffle(idxs)
    texts = [texts[i] for i in idxs]
    true_labels = [true_labels[i] for i in idxs]
    spurious_labels = [spurious_labels[i] for i in idxs]

    return [
        (
            texts[i:i+batch_size],
            t.tensor(true_labels[i:i+batch_size], device=DEVICE),
            t.tensor(spurious_labels[i:i+batch_size], device=DEVICE),
        )
        for i in range(0, len(texts), batch_size)
    ]

def get_subgroups(train=True, ambiguous=True, batch_size=128, seed=SEED):
    """Batch each (profession, gender) subgroup separately, without balancing or shuffling.

    Args:
        train: read the 'train' split if True, otherwise 'test'.
        ambiguous: if True, emit only the (0, 0) and (1, 1) subgroups (`male_prof` with
            gender 0, `female_prof` with gender 1); otherwise all four label combinations.
        batch_size: number of texts per batch (the last batch of a subgroup may be smaller).
        seed: unused (nothing is shuffled here); kept so the signature mirrors get_data.

    Returns:
        dict mapping (true_label, spurious_label) -> list of (texts, true_labels,
        spurious_labels) batches, with label tensors on DEVICE.
    """
    data = dataset['train'] if train else dataset['test']
    neg_id = profession_dict[male_prof]    # profession with true label 0
    pos_id = profession_dict[female_prof]  # profession with true label 1

    # Each entry: ((profession id, gender), (true label, spurious label)), in output order.
    if ambiguous:
        subgroups = [((neg_id, 0), (0, 0)), ((pos_id, 1), (1, 1))]
    else:
        subgroups = [
            ((neg_id, 0), (0, 0)),
            ((neg_id, 1), (0, 1)),
            ((pos_id, 0), (1, 0)),
            ((pos_id, 1), (1, 1)),
        ]

    # A single pass over the split buckets the texts of every needed subgroup at once.
    texts_by_key = {key: [] for key, _ in subgroups}
    for x in data:
        key = (x['profession'], x['gender'])
        if key in texts_by_key:
            texts_by_key[key].append(x['hard_text'])

    out = {}
    for key, label_profile in subgroups:
        texts = texts_by_key[key]
        true_label, spurious_label = label_profile
        out[label_profile] = []
        for start in range(0, len(texts), batch_size):
            chunk = texts[start:start + batch_size]
            out[label_profile].append((
                chunk,
                t.tensor([true_label] * len(chunk), device=DEVICE),
                t.tensor([spurious_label] * len(chunk), device=DEVICE),
            ))
    return out

Downloading readme: 100%|██████████| 3.30k/3.30k [00:00<00:00, 61.5kB/s]


In [4]:
# probe training hyperparameters
# index of the transformer layer whose output the linear classification head reads
layer = 4

class Probe(nn.Module):
    """Linear probe mapping each activation vector to a single binary-classification logit."""

    def __init__(self, activation_dim):
        super().__init__()
        # The attribute is named `net` so state_dict keys are 'net.weight' / 'net.bias'.
        self.net = nn.Linear(in_features=activation_dim, out_features=1, bias=True)

    def forward(self, x):
        # (..., activation_dim) -> (..., 1) -> (...): drop the singleton output dimension.
        return self.net(x).squeeze(dim=-1)

def train_probe(get_acts, label_idx=0, batches=None, lr=1e-2, epochs=1, dim=512, seed=SEED):
    """Train a linear Probe with BCE-with-logits loss on activations from `get_acts`.

    Args:
        get_acts: callable mapping a list of texts to a (batch, dim) activation tensor.
        label_idx: 0 trains against the true (profession) label, 1 against the spurious
            (gender) label, i.e. element label_idx + 1 of each batch tuple.
        batches: list of (texts, true_labels, spurious_labels) tuples. Defaults to
            get_data() (ambiguous training split), built on each call.
        lr: AdamW learning rate.
        epochs: number of passes over `batches`.
        dim: activation width expected by the probe.
        seed: torch seed set before the probe is initialized.

    Returns:
        (probe, losses): the trained Probe on DEVICE and the per-step training losses.
    """
    if batches is None:
        # Built per call; a `get_data()` default would be evaluated once, when this cell runs.
        batches = get_data()
    t.manual_seed(seed)  # reproducible probe initialization
    probe = Probe(dim).to(DEVICE)
    optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    losses = []
    for _ in range(epochs):
        for batch in batches:
            text = batch[0]
            labels = batch[label_idx + 1]
            logits = probe(get_acts(text))
            loss = criterion(logits, labels.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

    return probe, losses

def test_probe(probe, get_acts, label_idx=0, batches=None, seed=SEED):
    """Accuracy of `probe` against one of the two label columns.

    Args:
        probe: a trained Probe; a logit > 0 is read as a prediction of class 1.
        get_acts: callable mapping a list of texts to the activations the probe consumes.
        label_idx: 0 scores against the true (profession) label, 1 against the spurious
            (gender) label, i.e. element label_idx + 1 of each batch tuple.
        batches: list of (texts, true_labels, spurious_labels) tuples. Defaults to the
            ambiguous test split, get_data(train=False), built on each call.
        seed: unused; kept for signature compatibility.

    Returns:
        float: fraction of examples whose thresholded prediction matches the chosen label.

    Raises:
        ValueError: if `batches` is empty.
    """
    if batches is None:
        # Built per call; a `get_data(...)` default would be evaluated once, when this cell runs.
        batches = get_data(train=False)
    with t.no_grad():
        corrects = []
        for batch in batches:
            text = batch[0]
            labels = batch[label_idx + 1]
            logits = probe(get_acts(text))
            preds = (logits > 0.0).long()
            corrects.append((preds == labels).float())
    if not corrects:
        raise ValueError('test_probe needs at least one batch to evaluate')
    return t.cat(corrects).mean().item()
    
def get_acts(text):
    """Mean-pool layer `layer`'s output over non-padding tokens for a batch of texts.

    Runs `model` under an nnsight trace (no gradients). Positions where the attention mask
    is 0 are zeroed, and the sum over the sequence axis is divided by the number of
    unmasked tokens.

    Args:
        text: batch of input strings, as stored in the first element of get_data batches.

    Returns:
        2-D tensor, one pooled activation row per input text. NOTE(review): presumably
        (batch, 512) for pythia-70m, matching train_probe's default dim — confirm.
    """
    with t.no_grad(): 
        with model.trace(text, **tracer_kwargs):
            # NOTE(review): assumes model.input is (args, kwargs) with the tokenizer's attention_mask in kwargs — confirm for this nnsight version.
            attn_mask = model.input[1]['attention_mask']
            # Element 0 of the layer's output tuple; presumably hidden states of shape (batch, seq, d_model).
            acts = model.gpt_neox.layers[layer].output[0]
            # Zero padding positions, then average over the sequence axis (dim 1) using the unmasked-token count.
            acts = acts * attn_mask[:, :, None]
            acts = acts.sum(1) / attn_mask.sum(1)[:, None]
            acts = acts.save()
        # Read the saved result after the trace context has exited.
        return acts.value

In [5]:
probe, _ = train_probe(get_acts, label_idx=0)
print('Ambiguous test accuracy:', test_probe(probe, get_acts, label_idx=0))

# On the balanced split profession and gender are uncorrelated, so scoring against the true label
# (index 0) and the spurious gender label (index 1) shows which attribute the probe actually tracks.
batches = get_data(train=False, ambiguous=False)
for label_idx, title in enumerate(['Ground truth accuracy:', 'Unintended feature accuracy:']):
    print(title, test_probe(probe, get_acts, batches=batches, label_idx=label_idx))

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Ambiguous test accuracy: 0.9955855011940002
Ground truth accuracy: 0.6186636090278625
Unintended feature accuracy: 0.8744239807128906


In [7]:
# loading dictionaries: one pretrained autoencoder per submodule up to the probed layer
dict_id = 10
dict_size = 32768
dict_root = '../dictionary_learning/dictionaries/pythia-70m-deduped'

# Only the embedding and blocks 0..layer can influence the output of layer `layer`, which the probe reads.
embed = model.gpt_neox.embed_in
attns = [block.attention for block in model.gpt_neox.layers[:layer+1]]
mlps = [block.mlp for block in model.gpt_neox.layers[:layer+1]]
resids = model.gpt_neox.layers[:layer+1]

def _load_dictionary(submodule_name):
    """Load the autoencoder saved at `<dict_root>/<submodule_name>/<dict_id>_<dict_size>/ae.pt`."""
    return AutoEncoder.from_pretrained(
        f'{dict_root}/{submodule_name}/{dict_id}_{dict_size}/ae.pt',
        device=DEVICE,
    )

dictionaries = {embed: _load_dictionary('embed')}
for i in range(layer + 1):
    dictionaries[attns[i]] = _load_dictionary(f'attn_out_layer{i}')
    dictionaries[mlps[i]] = _load_dictionary(f'mlp_out_layer{i}')
    dictionaries[resids[i]] = _load_dictionary(f'resid_out_layer{i}')

In [None]:
# metric function for circuit discovery
def metric_fn(model, labels=None):
    """Signed probe logit used as the circuit-discovery metric.

    Recomputes get_acts' masked mean-pool of layer `layer` on the traced `model` and applies
    the global `probe`. Per example it returns the logit when the label is 0 and the negated
    logit when the label is 1. Since a positive logit predicts class 1, larger values mean
    the probe leans toward the wrong class. NOTE(review): confirm this is the sign get_circuit
    expects.

    Args:
        model: the nnsight model; reads model.input / layer outputs, so it must be called
            inside a trace (get_circuit passes it in).
        labels: tensor of 0/1 labels, one per example. Required despite the None default.

    Raises:
        ValueError: if labels is None.
    """
    if labels is None:
        raise ValueError('metric_fn requires per-example labels')
    attn_mask = model.input[1]['attention_mask']
    acts = model.gpt_neox.layers[layer].output[0]
    acts = acts * attn_mask[:, :, None]
    acts = acts.sum(1) / attn_mask.sum(1)[:, None]

    # A single probe forward pass shared by both branches of the select.
    logits = probe(acts)
    return t.where(labels == 0, logits, -logits)

In [None]:
# circuit discovery
# Average node/edge effects over the first `n_batches` ambiguous training batches. Each batch's
# effects are weighted by its size, so dividing by `running_total` yields a per-example mean.
# Accumulators are kept on the CPU between batches (presumably to bound GPU memory).
n_batches = 25
batch_size = 4

running_total = 0
running_nodes = None
running_edges = None
# get_data returns a list: slicing stops after n_batches and gives tqdm an accurate total.
for clean, labels, _ in tqdm(get_data(train=True, ambiguous=True, batch_size=batch_size, seed=SEED)[:n_batches]):
    nodes, edges = get_circuit(
        clean,
        None,
        model,
        embed,
        attns,
        mlps,
        resids,
        dictionaries,
        metric_fn,
        metric_kwargs={'labels': labels},
        node_threshold=0.2, # NOTE: use lower threshold if sigmoid
        edge_threshold=0.02,
    )
    weight = len(clean)
    running_total += weight
    if running_nodes is None:
        # The 'y' node is stored as None rather than accumulated.
        running_nodes = { k : weight * v.to('cpu') if k != 'y' else None for k, v in nodes.items() }
        running_edges = { k : { kk : weight * v.to('cpu') for kk, v in vv.items() } for k, vv in edges.items() }
    else:
        for k, effect in nodes.items():
            if k == 'y': continue
            running_nodes[k] += weight * effect.to('cpu')
        for k, effects in edges.items():
            for kk, effect in effects.items():
                running_edges[k][kk] += weight * effect.to('cpu')
    del nodes, edges
    gc.collect()

if running_total == 0:
    raise RuntimeError('circuit discovery processed no batches; check the get_data output')

for k in running_nodes.keys():
    if k == 'y': continue
    running_nodes[k] = running_nodes[k].to(DEVICE) / running_total
for k in running_edges.keys():
    for kk in running_edges[k].keys():
        running_edges[k][kk] = running_edges[k][kk].to(DEVICE) / running_total

In [None]:
# only plot positive effect nodes: negative effects are clamped to zero; None entries (e.g. 'y') stay None
def _positive_part(effect):
    if effect is None:
        return None
    return SparseAct(act=effect.act.clamp(min=0), resc=effect.resc.clamp(min=0))

running_nodes = {name: _positive_part(effect) for name, effect in running_nodes.items()}

# get annotations: optional feature annotations for the plot, keyed by each record's "Name".
# Best-effort: a missing, unreadable, or malformed file falls back to None (plot without annotations);
# any other exception (e.g. KeyboardInterrupt) propagates.
annotations_path = f'../annotations/{dict_id}_32768.jsonl'
try:
    annotations = {}
    with open(annotations_path, 'r') as annotations_data:
        for annotation_line in annotations_data:
            if not annotation_line.strip():
                continue  # a blank line should not discard the whole file
            annotation = json.loads(annotation_line)
            annotations[annotation["Name"]] = annotation["Annotation"]
except (OSError, ValueError, KeyError, TypeError) as err:
    print(f'annotations unavailable ({err!r}); plotting without them')
    annotations = None

In [None]:
# Plot the averaged circuit. NOTE(review): these thresholds presumably only filter what is drawn;
# discovery itself used node_threshold=0.2 / edge_threshold=0.02.
plot_circuit(
    running_nodes,
    running_edges,
    layers=layer + 1,  # blocks 0..layer, the same range the dictionaries were loaded for
    node_threshold=0.1,
    edge_threshold=0.01,
    pen_thickness=1,
    save_dir='../circuits/figures/bib_circuit',
    annotations=annotations
)