In [1]:
from nnsight import LanguageModel
from activation_utils import SparseAct
import torch as t
import plotly.graph_objects as go
from loading_utils import load_examples
from dictionary_learning import AutoEncoder
from dictionary_learning.dictionary import IdentityDict
from ablation_sam import run_with_ablations

  from .autonotebook import tqdm as notebook_tqdm


In [26]:
# --- experiment configuration -------------------------------------------------
device = 'cuda:0'
dataset = 'rc'
start_layer = 0 # explain the model starting here
ablation = 'mean'

model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=device, dispatch=True)

# Pre-computed circuits; only the 'nodes' entry of each saved file is used here.
feature_circuit_file = f'circuits/{dataset}_train_dict10_node0.2_edge0.02_n100_aggsum.pt'
neuron_circuit_file = f'circuits/{dataset}_train_dictid_node0.2_edge0.02_n100_aggsum.pt'
feature_circuit = t.load(feature_circuit_file)['nodes']
neuron_circuit = t.load(neuron_circuit_file)['nodes']

# Evaluation examples for the chosen dataset.
examples_file = f'/share/projects/dictionary_circuits/data/phenomena/{dataset}_test.json'
examples = load_examples(examples_file, 40, model, length=6)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [35]:
# load submodules
# Ordered list of the submodules that get explained: for each layer its
# attention block, its MLP, and the layer itself. A negative start_layer
# additionally puts the embedding in front of the list.
submodules = []
if start_layer < 0: submodules.append(model.gpt_neox.embed_in)
# Clamp the first layer index at 0: a negative start_layer only requests the
# embedding; iterating from a negative index would otherwise wrap around and
# add the last layer(s) a second time at the front of the list.
for i in range(max(start_layer, 0), len(model.gpt_neox.layers)):
    layer = model.gpt_neox.layers[i]
    submodules.extend([layer.attention, layer.mlp, layer])

# Map every submodule (embedding plus each layer's attention / MLP / layer) to
# the string name used to look it up in the circuit files.
submod_names = {model.gpt_neox.embed_in: 'embed'}
for i in range(len(model.gpt_neox.layers)):
    layer = model.gpt_neox.layers[i]
    submod_names.update({
        layer.attention: f'attn_{i}',
        layer.mlp: f'mlp_{i}',
        layer: f'resid_{i}',
    })

In [36]:
# load dictionaries
dict_id = 10

activation_dim = 512
expansion_factor = 64
dict_size = expansion_factor * activation_dim


def _load_autoencoder(location):
    """Create an AutoEncoder on `device` and load the weights saved under `location`.

    `location` is the sub-directory of the pythia-70m-deduped autoencoder
    folder (e.g. 'embed' or 'attn_out_layer0').
    """
    autoencoder = AutoEncoder(activation_dim, dict_size).to(device)
    autoencoder.load_state_dict(t.load(
        f'/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped/{location}/{dict_id}_{dict_size}/ae.pt'
    ))
    return autoencoder


# One autoencoder per submodule: embedding first, then attention output,
# MLP output and layer (residual) output of every layer.
feat_dicts = {model.gpt_neox.embed_in: _load_autoencoder('embed')}
for i in range(len(model.gpt_neox.layers)):
    layer = model.gpt_neox.layers[i]
    feat_dicts[layer.attention] = _load_autoencoder(f'attn_out_layer{i}')
    feat_dicts[layer.mlp] = _load_autoencoder(f'mlp_out_layer{i}')
    feat_dicts[layer] = _load_autoencoder(f'resid_out_layer{i}')

# Neuron baseline: an IdentityDict for every submodule under study.
neuron_dicts = {}
for submod in submodules:
    neuron_dicts[submod] = IdentityDict(activation_dim).to(device)


In [37]:
# load data
# Stack the per-example prefixes into batches and collect the answer token
# indices. Everything is placed on the configured `device` (the same one the
# model was loaded on) instead of a hard-coded GPU, so changing `device` in the
# configuration cell keeps inputs and model together.
clean_inputs = t.cat([e['clean_prefix'] for e in examples], dim=0).to(device)
clean_answer_idxs = t.tensor([e['clean_answer'] for e in examples], dtype=t.long, device=device)
patch_inputs = t.cat([e['patch_prefix'] for e in examples], dim=0).to(device)
patch_answer_idxs = t.tensor([e['patch_answer'] for e in examples], dtype=t.long, device=device)
def metric_fn(model):
    """Per-example logit difference: clean-answer logit minus patch-answer logit.

    Reads `model.embed_out.output` at the last sequence position and gathers
    the entries selected by `clean_answer_idxs` and `patch_answer_idxs`.
    """
    final_logits = model.embed_out.output[:,-1,:]
    patch_logit = t.gather(final_logits, dim=-1, index=patch_answer_idxs.view(-1, 1)).squeeze(-1)
    clean_logit = t.gather(final_logits, dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
    return clean_logit - patch_logit

In [38]:
# set ablation fn
# TODO: a 'resample' ablation is not supported yet; the naive version
# (x[t.randperm(x.act.shape[0])]) is wrong for SparseActs.
if ablation == 'zero':
    def ablation_fn(x):
        """Zero ablation: return `x.zeros_like()`."""
        return x.zeros_like()
elif ablation == 'mean':
    def ablation_fn(x):
        """Mean ablation: the mean of `x` over dim 0, expanded back to the shape of `x`."""
        return x.mean(dim=0).expand_as(x)
else:
    # Fail loudly instead of leaving ablation_fn undefined (or stale from an
    # earlier run of this cell with a different setting).
    raise ValueError(f"Unsupported ablation {ablation!r}; expected 'zero' or 'mean'")


In [39]:
# get F(M)
# F(M) is the mean of metric_fn over clean_inputs with no ablation applied.
# NOTE(review): the order of the two context managers is deliberately left
# as-is; presumably nnsight defers execution until the trace context exits,
# so swapping them could change when no_grad is active. Confirm before reordering.
with model.trace(clean_inputs), t.no_grad():
    # .save() keeps the traced metric available after the trace block.
    metric = metric_fn(model).save()
fm = metric.value.mean()
print(f"F(M) = {fm}")

F(M) = 1.4239044189453125


In [40]:
# get F(∅)
# The empty circuit: for every submodule a node mask that is False everywhere
# (no dictionary features and no residual selected).
empty_nodes = {}
for submod in submodules:
    empty_nodes[submod] = SparseAct(
        act=t.zeros(dict_size, dtype=t.bool),
        resc=t.zeros(1, dtype=t.bool),
    ).to(device)

# Mean metric when running with the empty node set.
fempty = run_with_ablations(
    clean_inputs,
    patch_inputs,
    model,
    submodules,
    feat_dicts,
    nodes=empty_nodes,
    metric_fn=metric_fn,
    ablation_fn=ablation_fn,
).mean()
print(f"F(∅) = {fempty}")

F(∅) = -0.020782470703125


In [41]:
# sweeping over thresholds, get fc and fc' for:
# feature circuit with residuals
# feature circuit without residuals
# neuron circuit
# random feature circuit with all residuals
#
# Results are stored per circuit kind and per threshold:
#   fc[kind][threshold]      - mean metric when running with the circuit's nodes
#   fccomp[kind][threshold]  - same, with complement=True
#   n_nodes[kind][threshold] - number of nodes reported for the circuit
# NOTE: `threshold` is a 0-d tensor and is used directly as the dict key;
# tensors hash by identity, so read results back via .values() / .items().

circuit_kinds = ('features', 'features_wo_resids', 'neurons', 'random')
fc = {kind : {} for kind in circuit_kinds}
fccomp = {kind : {} for kind in circuit_kinds}
n_nodes = {kind : {} for kind in circuit_kinds}
thresholds = t.logspace(-4, 0, 15, 10)

# Dedicated, seeded generator so the random baseline is reproducible on re-run
# without touching the global RNG state.
random_circuit_rng = t.Generator().manual_seed(0)


def _mean_metric(dictionaries, nodes, **run_kwargs):
    """Mean of metric_fn from run_with_ablations for the given dictionaries and node masks.

    Extra keyword arguments (e.g. complement, handle_resids) are forwarded
    unchanged to run_with_ablations.
    """
    return run_with_ablations(
        clean_inputs,
        patch_inputs,
        model,
        submodules,
        dictionaries,
        nodes=nodes,
        metric_fn=metric_fn,
        ablation_fn=ablation_fn,
        **run_kwargs
    ).mean()


with t.no_grad():
    for threshold in thresholds:
        # Boolean node masks: keep a node when its absolute circuit value exceeds the threshold.
        feat_nodes = {
            submod : feature_circuit[submod_names[submod]].abs() > threshold for submod in submodules
        }
        neuron_nodes = {
            submod : neuron_circuit[submod_names[submod]].abs() > threshold for submod in submodules
        }

        # Selected dictionary features per submodule (residual entries not included).
        feat_counts = {submod : feat_nodes[submod].act.sum().item() for submod in submodules}
        n_feature_nodes = sum(feat_counts.values())
        n_resid_nodes = sum(feat_nodes[submod].resc.sum().item() for submod in submodules)

        n_nodes['features'][threshold] = n_feature_nodes + n_resid_nodes
        n_nodes['features_wo_resids'][threshold] = n_feature_nodes
        n_nodes['neurons'][threshold] = sum(neuron_nodes[submod].act.sum().item() for submod in submodules)
        n_nodes['random'][threshold] = n_nodes['features_wo_resids'][threshold]

        # Random baseline: all residuals on, plus as many randomly chosen features
        # in each submodule as the feature circuit selected there. The total
        # therefore equals n_nodes['random'][threshold], the value used on the x-axis.
        random_nodes = {}
        for submod in submodules:
            nodes = SparseAct(act=t.zeros(dict_size, dtype=t.bool), resc=t.ones(1, dtype=t.bool)).to(device)
            chosen = t.randperm(dict_size, generator=random_circuit_rng)[:feat_counts[submod]]
            nodes.act[chosen] = True
            random_nodes[submod] = nodes

        # Metric with the circuit's nodes.
        fc['features'][threshold] = _mean_metric(feat_dicts, feat_nodes)
        fc['features_wo_resids'][threshold] = _mean_metric(feat_dicts, feat_nodes, handle_resids='remove')
        fc['neurons'][threshold] = _mean_metric(neuron_dicts, neuron_nodes)
        fc['random'][threshold] = _mean_metric(feat_dicts, random_nodes)

        # Metric with complement=True.
        fccomp['features'][threshold] = _mean_metric(feat_dicts, feat_nodes, complement=True)
        fccomp['features_wo_resids'][threshold] = _mean_metric(feat_dicts, feat_nodes, handle_resids='keep', complement=True)
        fccomp['neurons'][threshold] = _mean_metric(neuron_dicts, neuron_nodes, complement=True)
        fccomp['random'][threshold] = _mean_metric(feat_dicts, random_nodes, complement=True)


In [42]:
def faithfulness(fc, fempty, fm):
    """Normalised faithfulness score: (fc - fempty) / (fm - fempty).

    The result is 0 when `fc` equals `fempty` and 1 when `fc` equals `fm`.

    Args:
        fc: metric value of the circuit being scored.
        fempty: metric value of the empty circuit, F(∅).
        fm: metric value of the full model, F(M).
            Each may be a single-element tensor or a plain Python number.

    Returns:
        The score as a Python float.

    Raises:
        ZeroDivisionError: if plain Python numbers are passed and fm == fempty
            (with tensors the division yields inf/nan instead).
    """
    # float() accepts both single-element tensors and plain numbers, whereas
    # .item() only exists on tensors.
    return float((fc - fempty) / (fm - fempty))

fig = go.Figure()

# One line per circuit kind: (key into n_nodes / fc, legend label).
# x = number of nodes at each threshold, y = faithfulness of the circuit.
for kind, label in [
    ('features', 'Feature Circuit'),
    ('features_wo_resids', 'Feature Circuit w/o Residuals'),
    ('neurons', 'Neuron Circuit'),
    ('random', 'Random Feature Circuit'),
]:
    node_counts = list(n_nodes[kind].values())
    scores = [faithfulness(value, fempty, fm) for value in fc[kind].values()]
    fig.add_trace(go.Scatter(x=node_counts, y=scores, name=label))

fig.update_layout(
    xaxis_type='log',
    xaxis_title='Number of nodes',
    yaxis_title='Faithfulness',
    title='Faithfulness of C',
)

fig.show()


In [43]:
fig = go.Figure()

# One line per circuit kind: (key into n_nodes / fccomp, legend label).
# x = number of nodes at each threshold, y = faithfulness of the complement run.
for kind, label in [
    ('features', 'Feature Circuit'),
    ('features_wo_resids', 'Feature Circuit w/o Residuals'),
    ('neurons', 'Neuron Circuit'),
    ('random', 'Random Feature Circuit'),
]:
    node_counts = list(n_nodes[kind].values())
    scores = [faithfulness(value, fempty, fm) for value in fccomp[kind].values()]
    fig.add_trace(go.Scatter(x=node_counts, y=scores, name=label))

fig.update_layout(
    xaxis_type='log',
    xaxis_title='Number of nodes',
    yaxis_title='Faithfulness',
    # The backslash is escaped explicitly: a bare "\ " is an invalid escape
    # sequence and triggers a warning on recent Python versions. The rendered
    # title is unchanged.
    title='Faithfulness of M \\ C',
)

fig.show()