## Investigating circuit quality

##### Effect of circuit size on quality metrics
- Dataset: Type of circuit (grammar, context, bigram)
    - grammar: RC dataset
    - context: BiB dataset
    - bigram: cluster with bigrams
- Dataset size (number of samples averaged over)

- Thresholds: Circuit size
    - Node 
    - Edge

##### Which types of nodes are discovered significantly better or worse than others?
- model region (layer depth; layer type)

Restrictions:
- layer_type: attn, mlp, resid (no embeddings)

##### What part of the ground truth are we missing?
- tracr
- **model perplexity after ablating circuit**

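##### Metrics:
Here $F(\cdot)$ is the task metric when only the given set of nodes is left unablated: $M$ is the full model, $C$ the circuit, and $\emptyset$ the empty set (everything ablated).

- Faithfulness:
    1.  $\dfrac{F(C) - F(\emptyset)}{F(M) - F(\emptyset)}$
    2.  $|F(C) - F(M)|$

- Completeness:
    1. $|F(\emptyset) - F(M \setminus C)|$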

In [1]:
import sys
sys.path.append('/home/can/dictionary-circuits')
from ablation_sam import run_with_ablations
import torch as t
from argparse import ArgumentParser
from nnsight import LanguageModel
from dictionary_learning import AutoEncoder

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map='cuda:0', dispatch=True)

# Submodules handed to run_with_ablations: the embedding, then every attention output,
# every MLP output and every layer (residual) output.
submodules = [model.gpt_neox.embed_in] + \
    [layer.attention for layer in model.gpt_neox.layers] + \
    [layer.mlp for layer in model.gpt_neox.layers] + \
    [layer for layer in model.gpt_neox.layers]

AE_ROOT = '/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped'


def _load_ae(subdir):
    """Load one 512 -> 64*512 AutoEncoder from `<AE_ROOT>/<subdir>/10_32768/ae.pt` onto cuda:0."""
    ae = AutoEncoder(512, 64 * 512).to('cuda:0')
    ae.load_state_dict(t.load(f'{AE_ROOT}/{subdir}/10_32768/ae.pt'))
    return ae


# One dictionary per submodule, keyed by the same module objects as in `submodules`
# (iterating model.gpt_neox.layers yields the same objects as indexing it).
dictionaries = {model.gpt_neox.embed_in: _load_ae('embed')}
for i, layer in enumerate(model.gpt_neox.layers):
    dictionaries[layer.attention] = _load_ae(f'attn_out_layer{i}')
    dictionaries[layer.mlp] = _load_ae(f'mlp_out_layer{i}')
    dictionaries[layer] = _load_ae(f'resid_out_layer{i}')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Analysis configuration (notebook port of the original script's CLI flags).
threshold = 0.1          # features with node effect strictly above this are kept in the circuit
ablation = "resample"    # one of 'resample', 'zero', 'mean'
# Circuit file name under /home/can/dictionary-circuits/circuits/.
# The original script defaulted to 'rc_dict10_node0.01_edge0.001_n30_aggsum.pt'
# (presumably the RC-dataset circuit).
circuit_name = "lin_effects_final-5-pos_nsamples8192_nctx64_cluster50of750_dict10_node0.1_edge0.01_n4_aggsum.pt"
faithfulness = True
completeness = True
handle_resids = "default"   # forwarded unchanged to run_with_ablations

In [4]:
# NOTE(review): t.load unpickles the file — only load circuit files you trust.
circuit = t.load(f'/home/can/dictionary-circuits/circuits/{circuit_name}')

examples = circuit['examples']
nodes_out = circuit['nodes']


def _select_nodes(effects, thresh):
    """Return the indices in `effects.act` whose effect exceeds `thresh`, plus 'res' if any
    entry of `effects.resc` exceeds it.

    Assumes `effects` is SparseAct-like: comparing it with a scalar and calling
    `.nonzero().squeeze(-1)` yields an object with `.act` / `.resc` index tensors
    (presumably feature nodes and the reconstruction-error node) — TODO confirm.
    """
    selected = (effects > thresh).nonzero().squeeze(-1)
    return list(selected.act) + (['res'] if len(selected.resc) > 0 else [])


# Circuit nodes per submodule, keyed like `submodules` / `dictionaries`.
nodes = {model.gpt_neox.embed_in: _select_nodes(nodes_out['embed'], threshold)}
for i, layer in enumerate(model.gpt_neox.layers):
    nodes[layer.attention] = _select_nodes(nodes_out[f'attn_{i}'], threshold)
    nodes[layer.mlp] = _select_nodes(nodes_out[f'mlp_{i}'], threshold)
    nodes[layer] = _select_nodes(nodes_out[f'resid_{i}'], threshold)

In [5]:
# Batch the examples stored with the circuit onto the GPU.
# NOTE(review): with the cluster circuit currently configured, this cell raised
# KeyError: 'patch_prefix' — its examples apparently carry no patch prompts, so the
# clean-vs-patch metric and resample ablation below cannot be used for it as-is.
clean_inputs = t.cat([e['clean_prefix'] for e in examples], dim=0).to('cuda:0')
clean_answer_idxs = t.tensor([e['clean_answer'] for e in examples], dtype=t.long, device='cuda:0')
patch_inputs = t.cat([e['patch_prefix'] for e in examples], dim=0).to('cuda:0')
patch_answer_idxs = t.tensor([e['patch_answer'] for e in examples], dtype=t.long, device='cuda:0')


def metric_fn(model):
    """Per-example logit difference at the last position: clean answer minus patch answer.

    Reads `model.embed_out.output` (the unembedding output), so it is meant to be called
    inside a `model.trace(...)` context. Returns one value per example.
    """
    last_logits = model.embed_out.output[:, -1, :]
    clean_logit = t.gather(last_logits, dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
    patch_logit = t.gather(last_logits, dim=-1, index=patch_answer_idxs.view(-1, 1)).squeeze(-1)
    return clean_logit - patch_logit


# How ablated activations are replaced. `x` is whatever run_with_ablations passes in —
# presumably a SparseAct-like object (hence the `.zeros_like()` method) — TODO confirm.
if ablation == 'resample':
    ablation_fn = lambda x: x
elif ablation == 'zero':
    ablation_fn = lambda x: x.zeros_like()
elif ablation == 'mean':
    ablation_fn = lambda x: x.mean(dim=0).expand_as(x)
else:
    # Fail here rather than with a NameError on `ablation_fn` in a later cell.
    raise ValueError(f"unknown ablation {ablation!r}; expected 'resample', 'zero' or 'mean'")

KeyError: 'patch_prefix'

In [None]:
# Faithfulness
ablation_outs = run_with_ablations(
    clean_inputs,
    patch_inputs,
    model,
    submodules,
    dictionaries,
    nodes,
    metric_fn,
    ablation_fn=ablation_fn,
    handle_resids=handle_resids
)
print(f"F(C) = {ablation_outs.mean()}")

with model.trace(clean_inputs):
    metric = metric_fn(model).save()
normal_outs = metric.value
print(f"F(M) = {normal_outs.mean()}")

all_ablated = run_with_ablations(
    clean_inputs,
    patch_inputs,
    model,
    submodules,
    dictionaries,
    nodes={submod : [] for submod in submodules},
    metric_fn=metric_fn,
    ablation_fn=ablation_fn,
    handle_resids=handle_resids
)
print(f"F(∅) = {all_ablated.mean()}")

print(f"|F(C) - F(M)| = {(ablation_outs - normal_outs).abs().mean()}")
print(f"|F(∅) - F(M)| = {(all_ablated - normal_outs).abs().mean()}")

print(normal_outs - ablation_outs)

: 

In [None]:
# Completeness

ablation_outs = run_with_ablations(
    clean_inputs,
    patch_inputs,
    model,
    submodules,
    dictionaries,
    nodes,
    metric_fn,
    complement=True,  # evaluate the complement of the circuit, F(M \ C)
    ablation_fn=ablation_fn,
    handle_resids=handle_resids
)
# `all_ablated` (F(∅)) comes from the faithfulness cell, which must run first.
# `\\` prints a literal backslash; a bare `\ ` is an invalid escape sequence.
print(f"F(M \\ C) = {ablation_outs.mean()}")
print(f'|F(∅) - F(M \\ C)| = {(all_ablated - ablation_outs).abs().mean()}')