In [1]:
from nnsight import LanguageModel
from circuit import get_circuit, slice_dag, CircuitNode, sum_dag, mean_dag
import torch as t
from dictionary_learning import AutoEncoder
from attribution import patching_effect, get_grad
from graph_utils import WeightedDAG

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map='cpu')

# Pretrained dictionaries (sparse autoencoders), one per (site, layer).
AE_ROOT = '/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped'
AE_RUN = '5_32768'                  # checkpoint sub-directory used for every site/layer
ACTIVATION_DIM = 512                # input width passed to each AutoEncoder
DICT_SIZE = 64 * ACTIVATION_DIM     # number of dictionary features (32768)
# NOTE(review): the LanguageModel above lives on CPU while the dictionaries are
# moved to CUDA (as the original `.cuda()` calls did). Confirm that get_circuit /
# attribution moves activations between devices, or align the two placements.
AE_DEVICE = 'cuda'


def load_dictionary(site: str, layer: int) -> AutoEncoder:
    """Build an AutoEncoder and load its pretrained weights for one site/layer.

    Args:
        site: checkpoint directory prefix, e.g. 'mlp_out' or 'resid_out'.
        layer: transformer layer index.

    Returns:
        The AutoEncoder with its state dict loaded, moved to AE_DEVICE.
    """
    ae = AutoEncoder(ACTIVATION_DIM, DICT_SIZE)
    # Load to CPU first so this does not depend on the device the checkpoint
    # was saved from; the single load replaces the previous duplicated call.
    state = t.load(f'{AE_ROOT}/{site}_layer{layer}/{AE_RUN}/ae.pt', map_location='cpu')
    ae.load_state_dict(state)
    return ae.to(AE_DEVICE)


submodules = []
submodule_names = {}
dictionaries = {}
for layer in range(len(model.gpt_neox.layers)):
    # Order matters for downstream consumers of `submodules`: each layer's MLP
    # output is registered before that layer's residual-stream output.
    sites = [
        (model.gpt_neox.layers[layer].mlp, f'mlp{layer}', 'mlp_out'),
        (model.gpt_neox.layers[layer], f'resid{layer}', 'resid_out'),
    ]
    for submodule, name, site in sites:
        submodule_names[submodule] = name
        submodules.append(submodule)
        dictionaries[submodule] = load_dictionary(site, layer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [8]:
# Token ids of the final sub-token of " is" (clean answer) and " are" (patched answer).
clean_idx = model.tokenizer(" is").input_ids[-1]
patch_idx = model.tokenizer(" are").input_ids[-1]

def metric_fn(model):
    """Return logit(" are") - logit(" is") read at the last sequence position.

    `model` is the LanguageModel handed in by the attribution code (it
    intentionally shadows the global of the same name). The result is built
    from `model.embed_out.output`; presumably that is (batch, seq, vocab), so
    this yields one difference per prompt — TODO confirm.
    """
    unembed_out = model.embed_out.output
    plural_logit = unembed_out[:, -1, patch_idx]
    singular_logit = unembed_out[:, -1, clean_idx]
    return plural_logit - singular_logit

# Clean prompts (singular subjects) paired index-by-index with patch prompts (plural subjects).
dag = get_circuit(
    ["The man", "The tall boy"],
    ["The men", "The tall boys"],
    model,
    submodules,
    submodule_names,
    dictionaries,
    metric_fn,
)

# NOTE(review): sum over dim 1, then mean over dim 0 (crosses=False); presumably
# positions then prompts — confirm against circuit.sum_dag / circuit.mean_dag.
reduced_dag = mean_dag(sum_dag(dag, 1), 0, crosses=False)

In [10]:
# One line per edge of the reduced DAG: "<edge>: <weight>".
# Keep the f-string: if edge_weight returns a 0-dim tensor, f-string formatting
# renders it as a plain float, whereas str()/print(a, b) would show "tensor(...)".
for e in reduced_dag.edges:
    weight = reduced_dag.edge_weight(e)
    print(f'{e}: {weight}')

(mlp0/tensor([25126]), y): 0.1246533915400505
(mlp0/tensor([25126]), resid0/tensor([24128])): 0.005013970658183098
(mlp0/tensor([25126]), resid0/tensor([1198])): 0.047263868153095245
(mlp0/tensor([25126]), resid3/tensor([21519])): 0.007365047000348568
(mlp0/tensor([25126]), resid3/tensor([5969])): 0.009712073020637035
(mlp0/tensor([25126]), mlp4/tensor([10980])): 0.034069810062646866
(mlp0/tensor([25126]), resid4/tensor([9879])): 0.010627629235386848
(mlp0/tensor([25126]), mlp5/tensor([7716])): 0.005865936167538166
(mlp0/tensor([18832]), y): 0.2966051399707794
(mlp0/tensor([18832]), resid0/tensor([24128])): 0.017304087057709694
(mlp0/tensor([18832]), resid0/tensor([20837])): 0.06075422838330269
(mlp0/tensor([18832]), resid0/tensor([1198])): 0.6082330942153931
(mlp0/tensor([18832]), resid3/tensor([21519])): 0.008559301495552063
(mlp0/tensor([18832]), mlp4/tensor([10980])): 0.00859974417835474
(mlp0/tensor([18832]), mlp5/tensor([6589])): 0.012030516751110554
(mlp0/tensor([18832]), resid5

In [11]:
# Display the reduced DAG's node table (rendered as a dict of node -> scalar tensor).
# NOTE(review): `_nodes` is a private attribute of graph_utils.WeightedDAG; prefer a
# public accessor if the class provides one.
reduced_dag._nodes

{y: tensor(9.2292),
 mlp0/tensor([25126]): tensor(0.1301),
 mlp0/tensor([18832]): tensor(0.5547),
 mlp0/tensor([15910]): tensor(0.1396),
 mlp0/tensor([13270]): tensor(0.3526),
 mlp0/tensor([3855]): tensor(0.1948),
 mlp0/tensor([776]): tensor(0.1114),
 mlp0/tensor([27101]): tensor(0.3040),
 mlp0/tensor([6683]): tensor(1.6339),
 resid0/tensor([31147]): tensor(0.0972),
 resid0/tensor([29378]): tensor(0.2960),
 resid0/tensor([24128]): tensor(0.3694),
 resid0/tensor([20837]): tensor(0.0631),
 resid0/tensor([13833]): tensor(2.0817),
 resid0/tensor([9008]): tensor(0.2391),
 resid0/tensor([1198]): tensor(0.5633),
 mlp1/tensor([21038]): tensor(0.1549),
 mlp1/tensor([5594]): tensor(0.2252),
 mlp1/tensor([12117]): tensor(0.1557),
 resid1/tensor([17887]): tensor(0.6304),
 resid1/tensor([8832]): tensor(0.3117),
 resid1/tensor([6538]): tensor(0.0536),
 resid1/tensor([28003]): tensor(0.5377),
 resid1/tensor([27061]): tensor(0.1326),
 resid1/tensor([19307]): tensor(0.0652),
 resid1/tensor([10595]): te