In [1]:
from nnsight import LanguageModel
from circuit import get_circuit, slice_dag, CircuitNode, sum_dag, mean_dag
import torch as t
from dictionary_learning import AutoEncoder
from attribution import patching_effect, get_grad
from graph_utils import WeightedDAG

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load pythia-70m-deduped and, for every layer, one sparse-autoencoder dictionary for the
# MLP output and one for the residual-stream (layer) output.
# `submodules` keeps the order mlp0, resid0, mlp1, resid1, ... because get_circuit receives it as-is.
MODEL_DEVICE = 'cpu'
# NOTE(review): dictionaries are placed on CUDA while the model is loaded with device_map='cpu'
# (this mirrors the previous cell). Confirm that get_circuit / patching_effect reconcile the
# devices, or set both constants to the same device.
DICT_DEVICE = 'cuda'
AE_ROOT = '/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped'
D_MODEL = 512              # activation width fed to each AutoEncoder (pythia-70m hidden size)
DICT_SIZE = 64 * D_MODEL   # 32768 features, matching the '5_32768' checkpoint directories

model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=MODEL_DEVICE)


def load_dictionary(kind, layer):
    """Build an AutoEncoder on DICT_DEVICE and load the `{kind}_out_layer{layer}` checkpoint.

    kind: 'mlp' or 'resid' — selects the checkpoint subdirectory under AE_ROOT.
    layer: integer layer index.
    """
    ae = AutoEncoder(D_MODEL, DICT_SIZE).to(DICT_DEVICE)
    # map_location keeps the checkpoint tensors on the same device as the module they fill.
    state_dict = t.load(f'{AE_ROOT}/{kind}_out_layer{layer}/5_32768/ae.pt', map_location=DICT_DEVICE)
    ae.load_state_dict(state_dict)
    return ae


submodules = []
submodule_names = {}
dictionaries = {}
for layer in range(len(model.gpt_neox.layers)):
    block = model.gpt_neox.layers[layer]
    # mlp first, then the whole layer (residual output) — the same order as before.
    for kind, submodule in (('mlp', block.mlp), ('resid', block)):
        submodule_names[submodule] = f'{kind}{layer}'
        submodules.append(submodule)
        dictionaries[submodule] = load_dictionary(kind, layer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
def single_token_id(text):
    """Return the vocabulary id of `text`, which must encode to exactly one token.

    Encodes without special tokens so a prepended BOS (if the tokenizer adds one) cannot be
    mistaken for the answer token. Raises ValueError if `text` splits into several tokens,
    instead of silently picking the last sub-token.
    """
    ids = model.tokenizer(text, add_special_tokens=False).input_ids
    if len(ids) != 1:
        raise ValueError(f"{text!r} tokenizes to {len(ids)} tokens {ids}; expected exactly one")
    return ids[0]


# Answer tokens for the singular (clean) and plural (patch) prompts.
clean_idx = single_token_id(" is")
patch_idx = single_token_id(" are")

def metric_fn(model):
    """Logit of the patch answer minus the logit of the clean answer, at the last position.

    `model` is the traced model whose `embed_out` output is read. The token ids come from the
    notebook globals `patch_idx` (" are") and `clean_idx` (" is"). Returns one value per
    batch element.
    """
    final_position = model.embed_out.output[:, -1]
    return final_position[:, patch_idx] - final_position[:, clean_idx]

# Singular-subject prompts (answer " is") and their plural counterparts (answer " are").
# NOTE(review): the first batch is presumably the clean run and the second the patch run —
# confirm against circuit.get_circuit.
clean_prompts = ["The man", "The tall boy"]
patch_prompts = ["The men", "The tall boys"]

dag = get_circuit(clean_prompts, patch_prompts, model, submodules, submodule_names, dictionaries, metric_fn)

# NOTE(review): presumably sums the node effects over dim 1 (positions?) and then averages
# over dim 0 (examples?); verify the axis meanings in circuit.sum_dag / circuit.mean_dag.
summed_dag = sum_dag(dag, 1)
reduced_dag = mean_dag(summed_dag, 0, crosses=False)

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
