In [1]:
from nnsight import LanguageModel
from circuit_dag import get_circuit, slice_dag, CircuitNode, sum_dag, mean_dag
import torch as t
from dictionary_learning import AutoEncoder, ActivationBuffer
from attribution import patching_effect, get_grad
from graph_utils import WeightedDAG

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map='cuda:0')
submodules = []
submodule_names = {}
dictionaries = {}
for layer in range(len(model.gpt_neox.layers)):
    submodule = model.gpt_neox.layers[layer].mlp
    submodule_names[submodule] = f'mlp{layer}'
    submodules.append(submodule)
    ae = AutoEncoder(512, 64 * 512).cuda()
    ae.load_state_dict(t.load(f'/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped/mlp_out_layer{layer}/5_32768/ae.pt'))
    dictionaries[submodule] = ae

    submodule = model.gpt_neox.layers[layer]
    submodule_names[submodule] = f'resid{layer}'
    submodules.append(submodule)
    ae = AutoEncoder(512, 64 * 512).cuda()
    ae.load_state_dict(t.load(f'/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped/resid_out_layer{layer}/5_32768/ae.pt'))
    dictionaries[submodule] = ae

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
clean_idx = model.tokenizer(" is").input_ids[-1]
patch_idx = model.tokenizer(" are").input_ids[-1]

def metric_fn(model):
    return model.embed_out.output[:,-1,patch_idx] - model.embed_out.output[:,-1,clean_idx]

dag = get_circuit(
    ["The man"],# "The tall boy"],
    ["The men"],# "The tall boys"],
    model,
    submodules,
    submodule_names,
    dictionaries,
    metric_fn,
)

# reduced_dag = mean_dag(sum_dag(dag, 1), 0, crosses=False)

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [4]:
dag._edges

defaultdict(dict,
            {resid5/[0, 1, 296]: {y: tensor(0.6446, device='cuda:0')},
             resid5/[0, 1, 1593]: {y: tensor(2.1513, device='cuda:0')},
             resid5/[0, 1, 2640]: {y: tensor(0.6481, device='cuda:0')},
             resid5/[0, 1, 20133]: {y: tensor(0.5398, device='cuda:0')},
             resid5/[0, 1, 29911]: {y: tensor(0.2499, device='cuda:0')},
             mlp5/[0, 1, 6589]: {y: tensor(-0.1892, device='cuda:0'),
              resid5/[0, 1, 296]: tensor(0.1737, device='cuda:0'),
              resid5/[0, 1, 1593]: tensor(-0., device='cuda:0'),
              resid5/[0, 1, 2640]: tensor(0.0502, device='cuda:0'),
              resid5/[0, 1, 20133]: tensor(-0., device='cuda:0'),
              resid5/[0, 1, 29911]: tensor(0.1133, device='cuda:0')},
             mlp5/[0, 1, 7716]: {y: tensor(0.3358, device='cuda:0'),
              resid5/[0, 1, 296]: tensor(-0.0138, device='cuda:0'),
              resid5/[0, 1, 1593]: tensor(0., device='cuda:0'),
              

In [4]:
for n in dag.nodes:
    print(n,
          dag.node_weight(n),
          sum([dag.edge_weight((n, m)) for m in dag.get_children(n)])
    )

y tensor([8.8031], device='cuda:0') 0
resid5/[0, 1, 296] tensor(0.6446, device='cuda:0') tensor(0.6446, device='cuda:0')
resid5/[0, 1, 1593] tensor(2.1513, device='cuda:0') tensor(2.1513, device='cuda:0')
resid5/[0, 1, 2640] tensor(0.6481, device='cuda:0') tensor(0.6481, device='cuda:0')
resid5/[0, 1, 20133] tensor(0.5398, device='cuda:0') tensor(0.5398, device='cuda:0')
resid5/[0, 1, 29911] tensor(0.2499, device='cuda:0') tensor(0.2499, device='cuda:0')
mlp5/[0, 1, 6589] tensor(0.1480, device='cuda:0') tensor(0.3371, device='cuda:0')
mlp5/[0, 1, 7716] tensor(0.2168, device='cuda:0') tensor(0.3358, device='cuda:0')
resid4/[0, 1, 15641] tensor(0.1317, device='cuda:0') tensor(0.2590, device='cuda:0')
resid4/[0, 1, 18976] tensor(3.6909, device='cuda:0') tensor(4.3817, device='cuda:0')
resid4/[0, 1, 26554] tensor(0.5209, device='cuda:0') tensor(0.6653, device='cuda:0')
mlp4/[0, 1, 10980] tensor(0.8538, device='cuda:0') tensor(1.0157, device='cuda:0')
mlp4/[0, 1, 22377] tensor(0.1840, devic

In [8]:
for node in dag.nodes:
    if node.name == 'y': continue
    grad_to_y = grads_to_y[node.submodule][node.feat_idx]
    computed_grads = []
    for child in dag.get_children(node):
        if child.name == 'y':
            computed_grads.append(dag.edge_weight((node, child)))
        else:
            computed_grads.append(dag.edge_weight((node, child)) * grads_to_y[child.submodule][child.feat_idx])
    computed_grad = sum(computed_grads)
    print(node, grad_to_y - computed_grad)

resid5/[0, 1, 296] tensor(0., device='cuda:0')
resid5/[0, 1, 1593] tensor(0., device='cuda:0')
resid5/[0, 1, 2640] tensor(0., device='cuda:0')
resid5/[0, 1, 20133] tensor(0., device='cuda:0')
resid5/[0, 1, 29911] tensor(0., device='cuda:0')
mlp5/[0, 1, 6589] tensor(7.4506e-09, device='cuda:0')
mlp5/[0, 1, 7716] tensor(0., device='cuda:0')
resid4/[0, 1, 15641] tensor(0., device='cuda:0')
resid4/[0, 1, 18976] tensor(0., device='cuda:0')
resid4/[0, 1, 26554] tensor(-5.9605e-08, device='cuda:0')
mlp4/[0, 1, 10980] tensor(0., device='cuda:0')
mlp4/[0, 1, 22377] tensor(0., device='cuda:0')
resid3/[0, 1, 10208] tensor(-5.9605e-08, device='cuda:0')
resid3/[0, 1, 10990] tensor(-5.9605e-08, device='cuda:0')
resid3/[0, 1, 14341] tensor(1.1921e-07, device='cuda:0')
resid3/[0, 1, 16170] tensor(0., device='cuda:0')
resid3/[0, 1, 19411] tensor(-1.1921e-07, device='cuda:0')
resid3/[0, 1, 20739] tensor(-2.9802e-08, device='cuda:0')
resid3/[0, 1, 21519] tensor(1.4901e-08, device='cuda:0')
resid2/[0, 1, 

In [5]:
dag._nodes

{y: tensor([8.8031], device='cuda:0'),
 resid5/[0, 1, 296]: tensor(0.6446, device='cuda:0'),
 resid5/[0, 1, 1593]: tensor(2.1513, device='cuda:0'),
 resid5/[0, 1, 2640]: tensor(0.6481, device='cuda:0'),
 resid5/[0, 1, 20133]: tensor(0.5398, device='cuda:0'),
 resid5/[0, 1, 29911]: tensor(0.2499, device='cuda:0'),
 mlp5/[0, 1, 6589]: tensor(0.1480, device='cuda:0'),
 mlp5/[0, 1, 7716]: tensor(0.2168, device='cuda:0'),
 resid4/[0, 1, 15641]: tensor(0.1317, device='cuda:0'),
 resid4/[0, 1, 18976]: tensor(3.6909, device='cuda:0'),
 resid4/[0, 1, 26554]: tensor(0.5209, device='cuda:0'),
 mlp4/[0, 1, 10980]: tensor(0.8538, device='cuda:0'),
 mlp4/[0, 1, 22377]: tensor(0.1840, device='cuda:0'),
 resid3/[0, 1, 10208]: tensor(0.2882, device='cuda:0'),
 resid3/[0, 1, 10990]: tensor(0.1479, device='cuda:0'),
 resid3/[0, 1, 14341]: tensor(0.4305, device='cuda:0'),
 resid3/[0, 1, 16170]: tensor(0.3376, device='cuda:0'),
 resid3/[0, 1, 19411]: tensor(0.1758, device='cuda:0'),
 resid3/[0, 1, 20739]: 

In [4]:
dag._edges

defaultdict(dict,
            {resid5/[0, 1, 296]: {y: tensor(0.6446, device='cuda:0')},
             resid5/[0, 1, 1593]: {y: tensor(2.1513, device='cuda:0')},
             resid5/[0, 1, 2640]: {y: tensor(0.6481, device='cuda:0')},
             resid5/[0, 1, 20133]: {y: tensor(0.5398, device='cuda:0')},
             resid5/[0, 1, 29911]: {y: tensor(0.2499, device='cuda:0')},
             mlp5/[0, 1, 6589]: {resid5/[0, 1, 296]: tensor(0.1737, device='cuda:0'),
              resid5/[0, 1, 2640]: tensor(0.0502, device='cuda:0'),
              resid5/[0, 1, 29911]: tensor(0.1133, device='cuda:0')},
             mlp5/[0, 1, 7716]: {y: tensor(0.3358, device='cuda:0')},
             resid4/[0, 1, 15641]: {y: tensor(0.2590, device='cuda:0')},
             resid4/[0, 1, 18976]: {y: tensor(4.3817, device='cuda:0')},
             resid4/[0, 1, 26554]: {resid5/[0, 1, 296]: tensor(0.4305, device='cuda:0'),
              resid5/[0, 1, 2640]: tensor(0.1260, device='cuda:0'),
              resid5/[0,

In [4]:
for n in dag.nodes:
    if n.name == 'y': continue
    oomphs_out = []
    for m in dag.get_children(n):
        if m.name == 'y':
            oomphs_out.append(dag.edge_weight((n, m)))
        else:
            oomphs_out.append(dag.edge_weight((n, m)) * grads_to_y[m.submodule][m.feat_idx])
    oomphs = sum(oomphs_out)
    print(n, oomphs, grads_to_y[n.submodule][n.feat_idx])

TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'

In [6]:
for n in dag.nodes:
    sum_out = sum([dag.edge_weight((n, m)) for m in dag.get_children(n)])
    print(n, sum_out, dag.node_weight(n))

y 0 tensor([8.8031])
resid5/tensor([  0,   1, 296]) tensor(0.6446) tensor(0.6446)
resid5/tensor([   0,    1, 1593]) tensor(2.1513) tensor(2.1513)
resid5/tensor([   0,    1, 2640]) tensor(0.6481) tensor(0.6481)
resid5/tensor([    0,     1, 20133]) tensor(0.5399) tensor(0.5399)
resid5/tensor([    0,     1, 29911]) tensor(0.2500) tensor(0.2500)
mlp5/tensor([   0,    1, 6589]) tensor(0.3371) tensor(0.1480)
mlp5/tensor([   0,    1, 7716]) tensor(0.3358) tensor(0.2168)
resid4/tensor([    0,     1, 15641]) tensor(0.2533) tensor(0.1317)
resid4/tensor([    0,     1, 18976]) tensor(4.3618) tensor(3.6909)
resid4/tensor([    0,     1, 26554]) tensor(0.6652) tensor(0.5210)
mlp4/tensor([    0,     1, 10980]) tensor(1.1228) tensor(0.8537)
mlp4/tensor([    0,     1, 22377]) tensor(0.4600) tensor(0.1840)
resid3/tensor([    0,     1, 10208]) tensor(0.3910) tensor(0.2883)
resid3/tensor([    0,     1, 10990]) tensor(0.1585) tensor(0.1479)
resid3/tensor([    0,     1, 14341]) tensor(0.5493) tensor(0.4305)


In [5]:
for e in reduced_dag.edges:
    print(f'{e}: {reduced_dag.edge_weight(e)}')

(mlp0/tensor([25126]), y): 0.12455693632364273
(mlp0/tensor([25126]), resid0/tensor([24128])): 0.005014520604163408
(mlp0/tensor([25126]), resid0/tensor([1198])): 0.04726092889904976
(mlp0/tensor([25126]), resid3/tensor([21519])): 0.007359142880886793
(mlp0/tensor([25126]), resid3/tensor([5969])): 0.009711107239127159
(mlp0/tensor([25126]), mlp4/tensor([10980])): 0.03407903388142586
(mlp0/tensor([25126]), resid4/tensor([9879])): 0.010634061880409718
(mlp0/tensor([25126]), mlp5/tensor([7716])): 0.005865642800927162
(mlp0/tensor([18832]), y): 0.2968097925186157
(mlp0/tensor([18832]), resid0/tensor([24128])): 0.01730598509311676
(mlp0/tensor([18832]), resid0/tensor([20837])): 0.06074358522891998
(mlp0/tensor([18832]), resid0/tensor([1198])): 0.6081952452659607
(mlp0/tensor([18832]), resid3/tensor([21519])): 0.008550283499062061
(mlp0/tensor([18832]), mlp4/tensor([10980])): 0.00861691776663065
(mlp0/tensor([18832]), mlp5/tensor([6589])): 0.012028270401060581
(mlp0/tensor([18832]), resid5/t

In [6]:
reduced_dag._nodes

{y: tensor(9.2291),
 mlp0/tensor([25126]): tensor(0.1301),
 mlp0/tensor([18832]): tensor(0.5547),
 mlp0/tensor([15910]): tensor(0.1396),
 mlp0/tensor([13270]): tensor(0.3526),
 mlp0/tensor([3855]): tensor(0.1948),
 mlp0/tensor([776]): tensor(0.1115),
 mlp0/tensor([27101]): tensor(0.3040),
 mlp0/tensor([6683]): tensor(1.6339),
 resid0/tensor([31147]): tensor(0.0972),
 resid0/tensor([29378]): tensor(0.2960),
 resid0/tensor([24128]): tensor(0.3695),
 resid0/tensor([20837]): tensor(0.0631),
 resid0/tensor([13833]): tensor(2.0817),
 resid0/tensor([9008]): tensor(0.2391),
 resid0/tensor([1198]): tensor(0.5633),
 mlp1/tensor([21038]): tensor(0.1549),
 mlp1/tensor([5594]): tensor(0.2252),
 mlp1/tensor([12117]): tensor(0.1557),
 resid1/tensor([17887]): tensor(0.6304),
 resid1/tensor([8832]): tensor(0.3117),
 resid1/tensor([6538]): tensor(0.0536),
 resid1/tensor([28003]): tensor(0.5377),
 resid1/tensor([27061]): tensor(0.1326),
 resid1/tensor([19307]): tensor(0.0652),
 resid1/tensor([10595]): te