In [58]:
from nnsight import LanguageModel
from activation_utils import SparseAct
import torch as t
import plotly.express as px
from plotly.subplots import make_subplots
from loading_utils import load_examples
from dictionary_learning import AutoEncoder
from dictionary_learning.dictionary import IdentityDict
from ablation_sam import run_with_ablations
from tqdm import tqdm
import pandas as pd
from collections import defaultdict

In [59]:
# Scalar run settings.
device = 'cuda:0'
start_layer = 0 # explain the model starting here
ablation = 'mean'  # read by the "set ablation fn" cell ('zero' or 'mean')

# Model under study.
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map=device, dispatch=True)

# Only the 'nodes' entry of each saved circuit file is kept.
feature_circuit, neuron_circuit = (
    t.load(circuit_file)['nodes']
    for circuit_file in (
        'circuits/rc_dict10_node0.1_edge0.01_n30_aggsum.pt',
        'circuits/rc_dictid_node0.02_edge0.002_n30_aggsum.pt',
    )
)

# Evaluation examples.
examples = load_examples(
    '/share/projects/dictionary_circuits/data/phenomena/rc_test.json',
    40,
    model,
    length=6,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [60]:
# load submodules
# The embedding module is only included when start_layer is negative; after it
# come attention, MLP and the layer itself for every layer from start_layer on.
submodules = [model.gpt_neox.embed_in] if start_layer < 0 else []
for i in range(start_layer, len(model.gpt_neox.layers)):
    layer_module = model.gpt_neox.layers[i]
    submodules += [layer_module.attention, layer_module.mlp, layer_module]

# Human-readable name for every submodule of every layer (not just the ones
# from start_layer on), keyed by the module object itself.
submod_names = {model.gpt_neox.embed_in: 'embed'}
for i in range(len(model.gpt_neox.layers)):
    layer_module = model.gpt_neox.layers[i]
    submod_names.update({
        layer_module.attention: f'attn_{i}',
        layer_module.mlp: f'mlp_{i}',
        layer_module: f'resid_{i}',
    })


# select submodules
# Keep the submodules whose name contains `criterium` and print their names.
# (The variable name `selected_sumodules` is kept exactly as it was spelled.)
criterium = "mlp"
criterium_name = "feature_" + criterium
selected_sumodules = [
    submod for submod in submodules if criterium in submod_names[submod]
]
for submod in selected_sumodules:
    print(submod_names[submod])

mlp_0
mlp_1
mlp_2
mlp_3
mlp_4
mlp_5


In [61]:
# load dictionaries
dict_id = 10

activation_dim = 512
expansion_factor = 64
dict_size = expansion_factor * activation_dim

ae_root = '/share/projects/dictionary_circuits/autoencoders/pythia-70m-deduped'


def _load_autoencoder(subdir):
    """Create an AutoEncoder on `device` and load the weights saved under `subdir`.

    The checkpoint path is `{ae_root}/{subdir}/{dict_id}_{dict_size}/ae.pt`.
    """
    autoencoder = AutoEncoder(activation_dim, dict_size).to(device)
    autoencoder.load_state_dict(t.load(f'{ae_root}/{subdir}/{dict_id}_{dict_size}/ae.pt'))
    return autoencoder


# One autoencoder per submodule: embedding first, then attention / MLP / layer
# output for every layer, in that order.
feat_dicts = {model.gpt_neox.embed_in: _load_autoencoder('embed')}
for i in range(len(model.gpt_neox.layers)):
    layer_module = model.gpt_neox.layers[i]
    feat_dicts[layer_module.attention] = _load_autoencoder(f'attn_out_layer{i}')
    feat_dicts[layer_module.mlp] = _load_autoencoder(f'mlp_out_layer{i}')
    feat_dicts[layer_module] = _load_autoencoder(f'resid_out_layer{i}')

# Neuron baseline: an IdentityDict for every submodule in `submodules`.
neuron_dicts = {}
for submod in submodules:
    neuron_dicts[submod] = IdentityDict(activation_dim).to(device)

In [62]:
# load data
# Concatenate the per-example prefixes along dim 0 and collect the answer
# indices as long tensors. Everything is placed on the configured `device`
# (rather than a hard-coded 'cuda:0') so that changing `device` in the config
# cell cannot leave these tensors on a different device than the model.
clean_inputs = t.cat([e['clean_prefix'] for e in examples], dim=0).to(device)
clean_answer_idxs = t.tensor([e['clean_answer'] for e in examples], dtype=t.long, device=device)
patch_inputs = t.cat([e['patch_prefix'] for e in examples], dim=0).to(device)
patch_answer_idxs = t.tensor([e['patch_answer'] for e in examples], dtype=t.long, device=device)


def metric_fn(model):
    """Clean-answer minus patch-answer value of the last-position `model.embed_out` output.

    Reads the module-level index tensors `clean_answer_idxs` and
    `patch_answer_idxs`; each is reshaped to a column and used to gather along
    the last dimension of `model.embed_out.output[:, -1, :]`.

    Returns:
        The gathered clean-answer values minus the gathered patch-answer
        values, with the trailing singleton dimension squeezed away.
    """
    last_position_out = model.embed_out.output[:, -1, :]

    def gather_answers(answer_idxs):
        # Pick one entry per row: the entry at that row's answer index.
        return t.gather(last_position_out, dim=-1, index=answer_idxs.view(-1, 1)).squeeze(-1)

    return -gather_answers(patch_answer_idxs) + gather_answers(clean_answer_idxs)

In [63]:
# set ablation fn
# Maps the `ablation` setting to the function that produces the replacement
# activations. An unsupported value raises immediately instead of leaving
# `ablation_fn` undefined (or silently reusing one from an earlier run).
# if args.ablation == 'resample': ablation_fn = lambda x: x[t.randperm(x.act.shape[0])] # TODO this is wrong for SparseActs
if ablation == 'zero':
    ablation_fn = lambda x: x.zeros_like()
elif ablation == 'mean':
    # mean over dim 0, expanded back to the shape of x
    ablation_fn = lambda x: x.mean(dim=0).expand_as(x)
else:
    raise ValueError(f"Unsupported ablation {ablation!r}; expected 'zero' or 'mean'")


In [64]:
# get F(M)
# Evaluate metric_fn on the clean inputs inside a single trace with nothing
# else applied, save the result, and average it into the scalar `fm` that the
# faithfulness normalization uses later.
# NOTE(review): t.no_grad() is the inner context here, so it exits before the
# trace context does — confirm it covers the actual forward pass in the nnsight
# version in use.
with model.trace(clean_inputs), t.no_grad():
    metric = metric_fn(model).save()
fm = metric.value.mean()
print(f"F(M) = {fm}")

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


F(M) = 1.4239044189453125


In [65]:
# # get F(∅)
# fempty = run_with_ablations(
#     clean_inputs,
#     patch_inputs,
#     model,
#     submodules,
#     feat_dicts,
#     nodes = {
#         submod : SparseAct(act=t.zeros(dict_size, dtype=t.bool), resc=t.zeros(1, dtype=t.bool)).to(device)
#         for submod in submodules
#     },
#     metric_fn=metric_fn,
#     ablation_fn=ablation_fn
# )
# fempty = fempty.mean()
# print(f"F(∅) = {fempty}")

In [66]:
# Sweep over node thresholds, one layer at a time. For every (layer, threshold)
# the mean metric of the circuit is stored in `fc` and the mean metric of its
# complement in `fccomp`, for four circuits:
#   features            feature circuit with residuals
#   features_wo_resids  feature circuit without residuals
#   neurons             neuron circuit
#   random              random feature circuit with all residuals
# `n_nodes` holds the matching node counts and `fempties[layer]` the mean
# metric when no node of the layer's submodules is part of the circuit.

circuit_kinds = ('features', 'features_wo_resids', 'neurons', 'random')
fc = {kind: {} for kind in circuit_kinds}
fccomp = {kind: {} for kind in circuit_kinds}
n_nodes = {kind: {} for kind in circuit_kinds}
fempties = dict()

thresholds = t.logspace(-4, -1, steps=20, base=10)
layers = range(start_layer, len(model.gpt_neox.layers))

# Dedicated, seeded generator so the random baseline is identical on every
# Restart & Run All without touching the global RNG state.
random_node_rng = t.Generator().manual_seed(0)


def _mean_ablated_metric(submods, dictionaries, nodes, **kwargs):
    """Call run_with_ablations on the clean/patch inputs and return the mean of its result.

    Args:
        submods: submodules handed to run_with_ablations.
        dictionaries: mapping from submodule to dictionary (feat_dicts or neuron_dicts).
        nodes: mapping from submodule to a boolean SparseAct node mask.
        **kwargs: forwarded unchanged (e.g. handle_resids=..., complement=True).
    """
    return run_with_ablations(
        clean_inputs,
        patch_inputs,
        model,
        submods,
        dictionaries,
        nodes=nodes,
        metric_fn=metric_fn,
        ablation_fn=ablation_fn,
        **kwargs
    ).mean()


with t.no_grad():
    for layer in layers:
        for kind in circuit_kinds:
            fc[kind][layer] = dict()
            fccomp[kind][layer] = dict()
            n_nodes[kind][layer] = dict()

        # select submodules
        # Match the layer index exactly against the name suffix ('attn_3' -> '3').
        # A substring test would also pick e.g. 'attn_10' for layer 1 in deeper models.
        criterium = str(layer)
        criterium_name = f'layer_{criterium}'
        selected_submodules = []
        names = []
        for submod in submodules:
            name = submod_names[submod]
            if 'resid' not in name and name.rsplit('_', 1)[-1] == criterium:
                selected_submodules.append(submod)
                names.append(name)
        print(f'Layer {layer}: {names} submodules')

        # get F(∅): all-False node masks for every selected submodule
        empty_nodes = {
            submod : SparseAct(act=t.zeros(dict_size, dtype=t.bool), resc=t.zeros(1, dtype=t.bool)).to(device)
            for submod in selected_submodules
        }
        fempties[layer] = _mean_ablated_metric(selected_submodules, feat_dicts, empty_nodes)

        for threshold in tqdm(thresholds):
            # A node belongs to the circuit when its absolute value exceeds the threshold.
            feat_nodes = {
                submod : feature_circuit[submod_names[submod]].abs() > threshold for submod in selected_submodules
            }
            neuron_nodes = {
                submod : neuron_circuit[submod_names[submod]].abs() > threshold for submod in selected_submodules
            }

            n_feature_acts = sum(feat_nodes[submod].act.sum().item() for submod in selected_submodules)
            n_feature_rescs = sum(feat_nodes[submod].resc.sum().item() for submod in selected_submodules)
            n_nodes['features'][layer][threshold] = n_feature_acts + n_feature_rescs
            n_nodes['features_wo_resids'][layer][threshold] = n_feature_acts
            n_nodes['neurons'][layer][threshold] = sum(neuron_nodes[submod].act.sum().item() for submod in selected_submodules)
            n_nodes['random'][layer][threshold] = n_feature_acts

            # Random baseline: residual kept (resc=True) plus randomly chosen features.
            # NOTE(review): every selected submodule receives n_random random features,
            # where n_random is the feature count summed over ALL selected submodules, so
            # the random circuit holds len(selected_submodules) * n_random features while
            # n_nodes['random'] records n_random — confirm this is intended.
            n_random = n_nodes['random'][layer][threshold]
            random_nodes = {}
            for submod in selected_submodules:
                nodes = SparseAct(act=t.zeros(dict_size, dtype=t.bool), resc=t.ones(1, dtype=t.bool)).to(device)
                nodes.act[t.randperm(dict_size, generator=random_node_rng)[:n_random]] = True
                random_nodes[submod] = nodes

            # Metric of each circuit.
            fc['features'][layer][threshold] = _mean_ablated_metric(
                selected_submodules, feat_dicts, feat_nodes)
            fc['features_wo_resids'][layer][threshold] = _mean_ablated_metric(
                selected_submodules, feat_dicts, feat_nodes, handle_resids='remove')
            fc['neurons'][layer][threshold] = _mean_ablated_metric(
                selected_submodules, neuron_dicts, neuron_nodes)
            fc['random'][layer][threshold] = _mean_ablated_metric(
                selected_submodules, feat_dicts, random_nodes)

            # Metric of each circuit's complement.
            fccomp['features'][layer][threshold] = _mean_ablated_metric(
                selected_submodules, feat_dicts, feat_nodes, complement=True)
            fccomp['features_wo_resids'][layer][threshold] = _mean_ablated_metric(
                selected_submodules, feat_dicts, feat_nodes, handle_resids='keep', complement=True)
            fccomp['neurons'][layer][threshold] = _mean_ablated_metric(
                selected_submodules, neuron_dicts, neuron_nodes, complement=True)
            fccomp['random'][layer][threshold] = _mean_ablated_metric(
                selected_submodules, feat_dicts, random_nodes, complement=True)


Layer 0: ['attn_0', 'mlp_0'] submodules


100%|██████████| 20/20 [00:56<00:00,  2.82s/it]


Layer 1: ['attn_1', 'mlp_1'] submodules


100%|██████████| 20/20 [00:58<00:00,  2.91s/it]


Layer 2: ['attn_2', 'mlp_2'] submodules


100%|██████████| 20/20 [00:54<00:00,  2.74s/it]


Layer 3: ['attn_3', 'mlp_3'] submodules


100%|██████████| 20/20 [00:55<00:00,  2.79s/it]


Layer 4: ['attn_4', 'mlp_4'] submodules


100%|██████████| 20/20 [00:54<00:00,  2.73s/it]


Layer 5: ['attn_5', 'mlp_5'] submodules


100%|██████████| 20/20 [00:55<00:00,  2.75s/it]


In [67]:
def faithfulness(fc, fempty, fm):
    """Normalize a circuit metric between the empty-circuit and full-model metrics.

    Computes (fc - fempty) / (fm - fempty) and returns it as a Python scalar
    via `.item()`, so the result of the arithmetic must be a single-element
    tensor (or another object exposing `.item()`).

    Args:
        fc: metric of the circuit.
        fempty: metric of the empty circuit.
        fm: metric of the full model.
    """
    gained_by_circuit = fc - fempty
    gained_by_model = fm - fempty
    ratio = gained_by_circuit / gained_by_model
    return ratio.item()

In [68]:
# Display the per-layer empty-circuit metric F(∅) computed by the sweep above.
fempties

{0: tensor(0.1280, device='cuda:0'),
 1: tensor(1.2364, device='cuda:0'),
 2: tensor(0.8297, device='cuda:0'),
 3: tensor(1.0127, device='cuda:0'),
 4: tensor(0.8813, device='cuda:0'),
 5: tensor(0.5803, device='cuda:0')}

In [69]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Interactive figure with two panels that share the same faithfulness values:
#   col 1: faithfulness vs. number of nodes
#   col 2: faithfulness vs. node threshold
# A slider selects which layer's traces are visible.

# Define the range of layers
layer_range = list(range(start_layer, len(model.gpt_neox.layers)))

# One colour per circuit type
colors = dict(features='blue', features_wo_resids='turquoise', neurons='red', random='black')

# Plain Python floats for the x values and the hover text
threshold_values = thresholds.tolist()

# Create a figure with subplots
fig = make_subplots(rows=1, cols=2)

# Layer of every trace, in the order the traces are added. The slider builds its
# visibility masks from this list, so each mask has exactly one entry per trace
# and does not depend on how the traces happen to be ordered.
trace_layers = []

for fc_name in fc.keys():
    for layer in layer_range:
        num_nodes = list(n_nodes[fc_name][layer].values())
        ys = [faithfulness(x, fempties[layer], fm) for x in fc[fc_name][layer].values()]
        hover_text = [
            f"Threshold: {thresh:.2e}<br>Nodes: {n}<br>Faithfulness: {f:.2f}"
            for thresh, n, f in zip(threshold_values, num_nodes, ys)
        ]
        # Only the first layer is shown initially; the slider switches layers.
        visible = layer == layer_range[0]

        # Same ys in both panels; only the x values and the legend suffix differ.
        panels = [(num_nodes, "#nodes"), (threshold_values, "thresh")]
        for col, (xs, suffix) in enumerate(panels, start=1):
            fig.add_trace(
                go.Scatter(
                    x=xs,
                    y=ys,
                    mode="markers",
                    name=f"{fc_name} layer {layer} ({suffix})",
                    text=hover_text,
                    hoverinfo="text",
                    marker=dict(color=colors[fc_name]),
                    visible=visible
                ),
                row=1,
                col=col
            )
            trace_layers.append(layer)

fig.update_layout(title="Faithfulness of single layers of C")

# Label the y-axis of both panels
fig.update_yaxes(title_text="Faithfulness")

# Update x-axis for column 1
fig.update_xaxes(
    title_text="Number of nodes",
    type="log",
    range=[0, 3.5],  # on a log axis the range is given in log10 units
    row=1, col=1
)

# Update x-axis for column 2
fig.update_xaxes(
    title_text="Node threshold",
    type="log",
    row=1, col=2
)

# Slider: one step per layer, showing exactly the traces that belong to it
fig.update_layout(
    sliders=[
        dict(
            active=0,
            pad={"t": 50},
            steps=[
                dict(
                    label=f"Layer {layer}",
                    method="update",
                    args=[{"visible": [trace_layer == layer for trace_layer in trace_layers]}],
                ) for layer in layer_range
            ],
        )
    ],
)

# Save and show the figure
fig.write_html("faithfulness_per_layer.html")
fig.show()

In [70]:
import os

import chart_studio
import chart_studio.plotly as py
import chart_studio.tools as tls

# Upload the figure to Chart Studio.
# Credentials are read from the environment instead of being written into the
# notebook: export CHART_STUDIO_USERNAME and CHART_STUDIO_API_KEY before starting
# the kernel (a missing variable raises KeyError naming it).
# NOTE: the API key that used to be hard-coded in this cell was committed with the
# notebook and should be treated as compromised — revoke it and issue a new one.
chart_studio_username = os.environ["CHART_STUDIO_USERNAME"]
chart_studio_api_key = os.environ["CHART_STUDIO_API_KEY"]
tls.set_credentials_file(username=chart_studio_username, api_key=chart_studio_api_key)

py.plot(fig, filename="Faithfulness of single layers of C", auto_open=False)

'https://plotly.com/~canrager/1/'

In [71]:
# Display the node thresholds used in the sweep.
thresholds

tensor([1.0000e-04, 1.4384e-04, 2.0691e-04, 2.9764e-04, 4.2813e-04, 6.1585e-04,
        8.8587e-04, 1.2743e-03, 1.8330e-03, 2.6367e-03, 3.7927e-03, 5.4556e-03,
        7.8476e-03, 1.1288e-02, 1.6238e-02, 2.3357e-02, 3.3598e-02, 4.8329e-02,
        6.9519e-02, 1.0000e-01])

In [72]:
# Index of the last threshold whose layer-0 'features' entry in `fc` exceeds 0.9.
# NOTE(review): `fc` stores the raw mean metric, not the normalized value returned
# by faithfulness() — confirm that comparing the raw metric to 0.9 is intended.
t.where(t.tensor(list(fc['features'][0].values())) > 0.9)[0][-1]

tensor(18)

In [92]:
# For every circuit type and layer, find the largest node threshold whose
# faithfulness still exceeds `faithfulness_thresh`, and record
# (layer, fc_name, threshold, number of nodes at that threshold).
# faithfulness_threshs = t.linspace(0.9, 1, 10)
faithfulness_thresh = 0.9
max_thresholds = defaultdict(list)
fc_names = ["neurons", "features"]
# for faithfulness_thresh in faithfulness_threshs:
for fc_name in fc_names:
    for layer in layer_range:
        # select submodules
        # Match the layer index exactly against the name suffix ('attn_3' -> '3').
        # A substring test would also pick e.g. 'attn_10' for layer 1 in deeper models.
        criterium = str(layer)
        criterium_name = f'layer_{criterium}'
        selected_submodules = []
        names = []
        for submod in submodules:
            name = submod_names[submod]
            if 'resid' not in name and name.rsplit('_', 1)[-1] == criterium:
                selected_submodules.append(submod)
                names.append(name)
        print(f'Layer {layer}: {names} submodules')

        # Faithfulness at every threshold, in sweep order. The loop variable is
        # deliberately not called `fc`, so it cannot be confused with the `fc` dict.
        faithfulness_per_layer = t.tensor([
            faithfulness(metric_value, fempties[layer], fm)
            for metric_value in fc[fc_name][layer].values()
        ])
        largest_thresh = thresholds[t.where(faithfulness_per_layer > faithfulness_thresh)[0]]
        if len(largest_thresh) == 0:
            # Nothing is recorded for this (circuit, layer); say so instead of skipping silently.
            print(f'  no threshold reaches faithfulness > {faithfulness_thresh} for {fc_name} in layer {layer}')
            continue

        # thresholds are ascending, so the last passing entry is the largest one
        l_thresh = largest_thresh[-1].item()
        if fc_name == "features":
            feat_nodes = {
                submod : feature_circuit[submod_names[submod]].abs() > l_thresh for submod in selected_submodules
            }
            # feature circuits count residual nodes as well
            num_nodes = sum(feat_nodes[submod].act.sum().item() + feat_nodes[submod].resc.sum().item() for submod in selected_submodules)
        elif fc_name == "neurons":
            neuron_nodes = {
                submod : neuron_circuit[submod_names[submod]].abs() > l_thresh for submod in selected_submodules
            }
            num_nodes = sum(neuron_nodes[submod].act.sum().item() for submod in selected_submodules)
        else:
            # Avoid recording a stale num_nodes from a previous iteration.
            raise ValueError(f"No node count defined for circuit type {fc_name!r}")
        max_thresholds[faithfulness_thresh].append((layer, fc_name, l_thresh, num_nodes))

Layer 0: ['attn_0', 'mlp_0'] submodules
Layer 1: ['attn_1', 'mlp_1'] submodules
Layer 2: ['attn_2', 'mlp_2'] submodules
Layer 3: ['attn_3', 'mlp_3'] submodules
Layer 4: ['attn_4', 'mlp_4'] submodules
Layer 5: ['attn_5', 'mlp_5'] submodules
Layer 0: ['attn_0', 'mlp_0'] submodules
Layer 1: ['attn_1', 'mlp_1'] submodules
Layer 2: ['attn_2', 'mlp_2'] submodules
Layer 3: ['attn_3', 'mlp_3'] submodules
Layer 4: ['attn_4', 'mlp_4'] submodules
Layer 5: ['attn_5', 'mlp_5'] submodules


In [94]:
# Plot, per layer, the number of nodes recorded in `max_thresholds` for each
# circuit type in `fc_names`.
fig = make_subplots(rows=1, cols=1)

# for faithfulness_thresh, max_thresholds in max_thresholds_dict.items():
#         fc_name = "features"
for fc_name in fc_names:
    # Take x (layer) and y (node count) from the same records, so the points stay
    # aligned even when a layer has no entry for this circuit type.
    records = [
        (recorded_layer, recorded_nodes)
        for recorded_layer, recorded_name, _, recorded_nodes in max_thresholds[faithfulness_thresh]
        if recorded_name == fc_name
    ]
    xs = [recorded_layer for recorded_layer, _ in records]
    ys = [recorded_nodes for _, recorded_nodes in records]
    print(ys)
    # Add one scatter trace per circuit type
    fig.add_trace(
        go.Scatter(
            x=xs,
            y=ys,
            mode="markers",
            # name=f"{faithfulness_thresh:.2f}",
            name=fc_name,
            marker=dict(color=colors[fc_name])
        ),
        row=1,
        col=1
    )

# Set the layout of the figure
fig.update_layout(
    title=f"Minimum number of nodes to achieve at least {faithfulness_thresh:.2f} faithfulness",
    xaxis=dict(title="Layer"),
    yaxis=dict(title= "Number of nodes"),
    # legend=dict(title="Faithfulness threshold"),
)

# Show the figure
# fig.write_html("thresholds_per_layer.html")
fig.show()

[68, 36, 149, 120, 347, 185]
[37, 7, 74, 217, 86, 85]


In [None]:
# Scratch check: class name of the first entry in `submodules`.
submodules[0].__class__.__name__

In [None]:
# Scratch check: save the input of the first submodule during a trace on the
# clean inputs, then list the keys of its second element.
# NOTE(review): presumably element 1 holds the keyword arguments of the call — confirm.
with model.trace(clean_inputs), t.no_grad():
    submodule_input = submodules[0].input.save()

submodule_input.value[1].keys()

In [None]:
# sender_nodes: dict mapping sender -> output
# receiver: dict mapping receiver -> senders

# tracing call

# sender_outputs = defaultdict(list)
# for all nodes in circuit


## Structure
# get outputs of submodules on patch inputs