In [1]:
from nnsight import LanguageModel
import pandas as pd
import torch as t
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cuda:0'

model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=device)


In [3]:
class Probe(nn.Module):
    def __init__(self, activation_dim):
        super().__init__()
        self.net = nn.Linear(activation_dim, 1, bias=True)

    def forward(self, x):
        logits = self.net(x).squeeze(-1)
        return logits.sigmoid()

In [4]:
lr = 1e-2
epochs = 20
layer = 2

In [5]:
t.manual_seed(42)

# train control probes for each feature individually
train_data = pd.read_csv('data/train_data.csv')
feat1_labels = t.randint(0, 2, (len(train_data),)).to(device)
feat2_labels = t.randint(0, 2, (len(train_data),)).to(device)
inputs = []
for (_, row), label1, label2 in zip(train_data.iterrows(), feat1_labels, feat2_labels):
    if label1 == 0 and label2 == 0:
        inputs.append(row['singular'])
    elif label1 == 0 and label2 == 1:
        inputs.append(row['singular'].upper())
    elif label1 == 1 and label2 == 0:
        inputs.append(row['plural'])
    elif label1 == 1 and label2 == 1:
        inputs.append(row['plural'].upper())

with model.invoke(inputs):
    acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
acts = acts.value.clone()

probe1 = Probe(acts.shape[-1]).to(device)
probe2 = Probe(acts.shape[-1]).to(device)
opt1 = t.optim.AdamW(probe1.parameters(), lr=lr)
opt2 = t.optim.AdamW(probe2.parameters(), lr=lr)

for _ in range(epochs):
    opt1.zero_grad(), opt2.zero_grad()
    logits1 = probe1(acts)
    logits2 = probe2(acts)
    loss1 = nn.BCELoss()(logits1, feat1_labels.float())
    loss2 = nn.BCELoss()(logits2, feat2_labels.float())
    loss1.backward(), loss2.backward()
    opt1.step(), opt2.step()

# test control probes
test_data = pd.read_csv('data/test_data.csv')
feat1_labels = t.randint(0, 2, (len(test_data),)).to(device)
feat2_labels = t.randint(0, 2, (len(test_data),)).to(device)
inputs = []
for (_, row), label1, label2 in zip(test_data.iterrows(), feat1_labels, feat2_labels):
    if label1 == 0 and label2 == 0:
        inputs.append(row['singular'])
    elif label1 == 0 and label2 == 1:
        inputs.append(row['singular'].upper())
    elif label1 == 1 and label2 == 0:
        inputs.append(row['plural'])
    elif label1 == 1 and label2 == 1:
        inputs.append(row['plural'].upper())

with model.invoke(inputs):
    acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
acts = acts.value.clone()

probs1 = probe1(acts)
probs2 = probe2(acts)
preds1, preds2 = probs1.round(), probs2.round()
acc1 = (preds1 == feat1_labels).float().mean().item()
acc2 = (preds2 == feat2_labels).float().mean().item()

print(f'Control probe 1 accuracy: {acc1}')
print(f'Control probe 2 accuracy: {acc2}')


You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Control probe 1 accuracy: 0.970802903175354
Control probe 2 accuracy: 0.9927006959915161


In [6]:
t.manual_seed(42)
probe = Probe(512).to(device)
optimizer = t.optim.AdamW(probe.parameters(), lr=lr)
losses = []

train_data = pd.read_csv('data/train_data.csv')
labels = t.randint(0, 2, (len(train_data),)).to(device)
inputs = [
    row['singular'] if label == 0 else row['plural'].upper() for (_, row), label in zip(train_data.iterrows(), labels)
]
train_inputs = inputs
with model.invoke(inputs):
    acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
acts = acts.value.clone()

for _ in range(epochs):
    optimizer.zero_grad()
    probs = probe(acts)
    loss = nn.BCELoss()(probs, labels.float())
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

In [7]:
losses

[1.0341346263885498,
 0.5637803673744202,
 0.5266983509063721,
 0.44276073575019836,
 0.3562548756599426,
 0.2869155704975128,
 0.23607821762561798,
 0.20007425546646118,
 0.1744411736726761,
 0.1548157036304474,
 0.13781310617923737,
 0.12156275659799576,
 0.10560642927885056,
 0.09034047275781631,
 0.07642094045877457,
 0.064373679459095,
 0.05443716421723366,
 0.04657389223575592,
 0.04056061804294586,
 0.036086101084947586]

In [8]:
t.manual_seed(42)
# get accuracy on ambiguous test set
test_data = pd.read_csv('data/test_data.csv')
labels = t.randint(0, 2, (len(test_data),)).to(device)
inputs = [
    row['singular'] if label == 0 else row['plural'].upper() for (_, row), label in zip(test_data.iterrows(), labels)
]
with model.invoke(inputs):
    acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
acts = acts.value.clone()

preds = probe(acts).round().long()
acc = (preds == labels).float().mean().item()
print(f'acc on ambiguous: {acc}')

t.manual_seed(42)
# get accuracy on disambiguating test set
test_data = pd.read_csv('data/test_data.csv')
feat1_labels = t.randint(0, 2, (len(test_data),)).to(device)
feat2_labels = t.randint(0, 2, (len(test_data),)).to(device)
inputs = []
for (_, row), label1, label2 in zip(test_data.iterrows(), feat1_labels, feat2_labels):
    if label1 == 0 and label2 == 0:
        inputs.append(row['singular'])
    elif label1 == 0 and label2 == 1:
        inputs.append(row['singular'].upper())
    elif label1 == 1 and label2 == 0:
        inputs.append(row['plural'])
    elif label1 == 1 and label2 == 1:
        inputs.append(row['plural'].upper())
test_inputs = inputs

with model.invoke(inputs):
    acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
acts = acts.value.clone()

preds = probe(acts).round().long()
feat1_acc = (preds == feat1_labels).float().mean().item()
feat2_acc = (preds == feat2_labels).float().mean().item()
print(f'feat1_acc: {feat1_acc}, feat2_acc: {feat2_acc}')


acc on ambiguous: 1.0
feat1_acc: 0.5766423344612122, feat2_acc: 0.970802903175354


In [9]:
from attribution import patching_effect
from dictionary_learning.dictionary import AutoEncoder

submodules = [
    model.gpt_neox.layers[i] for i in range(layer + 1)
]
dictionaries = []
for i in range(layer + 1):
    ae = AutoEncoder(512, 64 * 512).to(device)
    ae.load_state_dict(t.load(f'/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped/resid_out_layer{i}/0_32768/ae.pt'))
    dictionaries.append(ae)

def metric_fn(model):
    return probe(model.gpt_neox.layers[layer].output[0][:,-1,:])

effects, _ = patching_effect(
    "YARNS",
    None,
    model,
    submodules,
    dictionaries,
    metric_fn,
    method='ig'
)

effects = {k : v.abs().mean(dim=0).mean(dim=0) for k, v in effects.items()}
for i, submodule in enumerate(submodules):
    print(f"Layer {i}:")
    effect = effects[submodule]
    for feature_idx in t.nonzero(effect):
        value = effect[tuple(feature_idx)]
        if value.abs() > 0.01:
            print(f"    Multindex: {tuple(feature_idx.tolist())}, Value: {value}")

Layer 0:
    Multindex: (4324,), Value: 0.028520982712507248
    Multindex: (4362,), Value: 0.011580727994441986
    Multindex: (5650,), Value: 0.03548774868249893
    Multindex: (8537,), Value: 0.011580727994441986
    Multindex: (17126,), Value: 0.012919182889163494
    Multindex: (22182,), Value: 0.013656112365424633
    Multindex: (25864,), Value: 0.05403883755207062
    Multindex: (29293,), Value: 0.018411755561828613
    Multindex: (32286,), Value: 0.011580727994441986
Layer 1:
    Multindex: (740,), Value: 0.01760883629322052
    Multindex: (14465,), Value: 0.06407757103443146
    Multindex: (16629,), Value: 0.011317379772663116
    Multindex: (17573,), Value: 0.01387042086571455
    Multindex: (18471,), Value: 0.13989761471748352
    Multindex: (19812,), Value: 0.01919487677514553
    Multindex: (22990,), Value: 0.020656585693359375
    Multindex: (27592,), Value: 0.021570025011897087
    Multindex: (27759,), Value: 0.017101265490055084
    Multindex: (27933,), Value: 0.0176157

In [10]:
input = "YARNS"

to_ablate = {
    submodules[0] : [
        5650,
        17126,
        22182,
        25864
    ],
    submodules[1] : [
        14465,
        18471,
        22990
    ],
    submodules[2] : [
        16421,
        22968,
        27888
    ]
}

with model.invoke(input):
    acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
acts = acts.value.clone()
pred_before = probe(acts).item()

with model.invoke(input):
    for submodule, dictionary in zip(submodules, dictionaries):
        x = submodule.output
        is_resid = type(x.shape) == tuple
        if is_resid:
            x = x[0]
        x_hat = dictionary(x)
        residual = x - x_hat
        
        f = dictionary.encode(x)
        ablation_idxs = to_ablate[submodule]
        for idx in ablation_idxs:
            f[..., idx] = 0
        x_hat = dictionary.decode(f)
        if is_resid:
            submodule.output[0][:] = x_hat + residual
        else:
            submodule.output = x_hat + residual
    acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
acts = acts.value.clone()
pred_after = probe(acts).item()

print(f'before: {pred_before}, after: {pred_after}')

before: 0.8413950800895691, after: 0.37445271015167236


In [11]:
# get mean feature activations on train set
train_data = pd.read_csv('data/train_data.csv')
labels = t.randint(0, 2, (len(train_data),)).to(device)
inputs = [
    row['singular'] if label == 0 else row['plural'].upper() for (_, row), label in zip(train_data.iterrows(), labels)
]

ablation_values = {}
with model.invoke(inputs):
    for submodule, dictionary in zip(submodules, dictionaries):
        x = submodule.output
        is_resid = type(x.shape) == tuple
        if is_resid:
            x = x[0]
        f = dictionary.encode(x)
        ablation_values[submodule] = f[:, -1, t.Tensor(to_ablate[submodule]).long()].mean(dim=0)

t.manual_seed(42)
# get accuracy on ambiguous test set
test_data = pd.read_csv('data/test_data.csv')
labels = t.randint(0, 2, (len(test_data),)).to(device)
inputs = [
    row['singular'] if label == 0 else row['plural'].upper() for (_, row), label in zip(test_data.iterrows(), labels)
]
with model.invoke(inputs):
    for submodule, dictionary in zip(submodules, dictionaries):
        x = submodule.output
        is_resid = type(x.shape) == tuple
        if is_resid:
            x = x[0]
        x_hat = dictionary(x)
        residual = x - x_hat
        
        f = dictionary.encode(x)
        ablation_idxs = to_ablate[submodule]
        f[:, -1, t.Tensor(ablation_idxs).long()] = ablation_values[submodule]
        x_hat = dictionary.decode(f)
        if is_resid:
            submodule.output[0][:] = x_hat + residual
        else:
            submodule.output = x_hat + residual
    acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
acts = acts.value.clone()

preds = probe(acts).round().long()
acc = (preds == labels).float().mean().item()
print(f'acc on ambiguous: {acc}')

t.manual_seed(42)
# get accuracy on disambiguating test set
test_data = pd.read_csv('data/test_data.csv')
feat1_labels = t.randint(0, 2, (len(test_data),)).to(device)
feat2_labels = t.randint(0, 2, (len(test_data),)).to(device)
inputs = []
for (_, row), label1, label2 in zip(test_data.iterrows(), feat1_labels, feat2_labels):
    if label1 == 0 and label2 == 0:
        inputs.append(row['singular'])
    elif label1 == 0 and label2 == 1:
        inputs.append(row['singular'].upper())
    elif label1 == 1 and label2 == 0:
        inputs.append(row['plural'])
    elif label1 == 1 and label2 == 1:
        inputs.append(row['plural'].upper())

with model.invoke(inputs):
    for submodule, dictionary in zip(submodules, dictionaries):
        x = submodule.output
        is_resid = type(x.shape) == tuple
        if is_resid:
            x = x[0]
        x_hat = dictionary(x)
        residual = x - x_hat
        
        f = dictionary.encode(x)
        ablation_idxs = to_ablate[submodule]
        f[:, :, t.Tensor(ablation_idxs).long()] = 0.
        x_hat = dictionary.decode(f)
        if is_resid:
            submodule.output[0][:] = x_hat + residual
        else:
            submodule.output = x_hat + residual
    acts = model.gpt_neox.layers[layer].output[0][:,-1,:].save()
acts = acts.value.clone()

preds = probe(acts).round().long()
feat1_acc = (preds == feat1_labels).float().mean().item()
feat2_acc = (preds == feat2_labels).float().mean().item()
print(f'feat1_acc: {feat1_acc}, feat2_acc: {feat2_acc}')

acc on ambiguous: 1.0
feat1_acc: 0.5912408828735352, feat2_acc: 0.8978102207183838


In [7]:
test_data = pd.read_json('/share/data/datasets/msgs/syntactic_category_lexical_content_the/test.jsonl', lines=True)

ling_accs, surface_accs = [], []
# get accuracy on test data
for batch_idx in range(len(test_data) // batch_size):
    inputs = test_data['sentence'][batch_idx * batch_size:(batch_idx + 1) * batch_size].tolist()
    ling_labels = test_data['linguistic_feature_label'][batch_idx * batch_size:(batch_idx + 1) * batch_size].tolist()
    surface_labels = test_data['surface_feature_label'][batch_idx * batch_size:(batch_idx + 1) * batch_size].tolist()

    with model.invoke(inputs) as invoker:
        hidden_states = model.gpt_neox.layers[-3].output[0].save()
    
    with t.no_grad():
        preds = probe(hidden_states.value)
        ling_acc = (preds.round() == t.Tensor(ling_labels).to('cuda:0')).float().mean()
        surface_acc = (preds.round() == t.Tensor(surface_labels).to('cuda:0')).float().mean()
        ling_accs.append(ling_acc.item())
        surface_accs.append(surface_acc.item())

print('ling acc:', sum(ling_accs) / len(ling_accs))
print('surface acc:', sum(surface_accs) / len(surface_accs))
    

ling acc: 0.9539930555555556
surface acc: 0.6663995726495726


In [8]:
from attribution import patching_effect
from dictionary_learning.dictionary import AutoEncoder

In [16]:
clean = "All grandsons do resemble the print and Debra is an organized child."
patch = "All grandsons do resemble a print and Debra is an banana child."

with model.invoke([clean, patch]) as invoker:
    hidden_states = model.gpt_neox.layers[-3].output[0].save()

with t.no_grad():
    preds = probe(hidden_states.value)
preds

tensor([0.4021, 0.0954], device='cuda:0')

In [17]:
def metric_fn(model):
    return probe(model.gpt_neox.layers[-3].output[0])

submodules = [
    model.gpt_neox.layers[i] for i in range(4)
] + [
    model.gpt_neox.layers[i].mlp for i in range(4)
]
dictionaries = []
for i in range(4):
    dictionary = AutoEncoder(512, 64 * 512).to(device)
    dictionary.load_state_dict(t.load(f'/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped/resid_out_layer{i}/0_32768/ae.pt'))
    dictionaries.append(dictionary)
for i in range(4):
    dictionary = AutoEncoder(512, 64 * 512).to(device)
    dictionary.load_state_dict(t.load(f'/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped/mlp_out_layer{i}/1_32768/ae.pt'))
    dictionaries.append(dictionary)

out = patching_effect(
    clean,
    patch,
    model,
    submodules,
    dictionaries,
    metric_fn,
)

In [19]:
effects, total_effect = out
print(f"Total effect: {total_effect}")
for layer, submodule in enumerate(submodules):
    print(f"Layer {layer}:")
    effect = effects[submodule]
    for feature_idx in t.nonzero(effect):
        value = effect[tuple(feature_idx)]
        if value.abs() > 0.1:
            print(f"    Multindex: {tuple(feature_idx.tolist())}, Value: {value}")

Total effect: tensor([-0.7627], device='cuda:0', grad_fn=<DivBackward0>)
Layer 0:
    Multindex: (0, 6, 23084), Value: -0.31019383668899536
    Multindex: (0, 6, 29115), Value: -0.2026602178812027
    Multindex: (0, 12, 9247), Value: 0.23665377497673035
    Multindex: (0, 12, 19133), Value: -0.13597454130649567
    Multindex: (0, 13, 1147), Value: 0.10391081869602203
    Multindex: (0, 13, 1256), Value: -0.12887312471866608
    Multindex: (0, 13, 3385), Value: -0.34761813282966614
    Multindex: (0, 13, 3613), Value: -0.11544839292764664
    Multindex: (0, 13, 5702), Value: -11.003226280212402
    Multindex: (0, 13, 5962), Value: -0.18838448822498322
    Multindex: (0, 13, 6959), Value: -0.9820340871810913
    Multindex: (0, 13, 15146), Value: -0.12887312471866608
    Multindex: (0, 13, 25951), Value: -0.24457168579101562
    Multindex: (0, 13, 26640), Value: -0.25634047389030457
    Multindex: (0, 13, 27692), Value: -0.12028089165687561
    Multindex: (0, 13, 31525), Value: -0.1246131

In [15]:
for i, tok in enumerate(invoker.input.input_ids[0]):
    print(f"{i}: {model.tokenizer.decode([tok])}")

0: All
1:  gr
2: ands
3: ons
4:  do
5:  resemble
6:  the
7:  print
8:  and
9:  De
10: bra
11:  is
12:  an
13:  organized
14:  child
15: .
