In [12]:
%load_ext autoreload
%autoreload 2

import sys
import os
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from nnsight import LanguageModel
from activation_utils import SparseAct
import torch as t
import plotly.graph_objects as go
from loading_utils import load_examples
from dictionary_learning import AutoEncoder
from dictionary_learning.dictionary import IdentityDict
from dictionary_loading_utils import load_saes_and_submodules
from ablation import run_with_ablations
from scipy import interpolate
import math
from tqdm import tqdm
from statistics import stdev
import hashlib

# Master switch for the debugging options below.
DEBUGGING = False

# Both options follow DEBUGGING. Not referenced in the cells visible here;
# presumably intended to be splatted into model.trace(...) -- TODO confirm.
tracer_kwargs = dict(validate=DEBUGGING, scan=DEBUGGING)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Which model to analyse and where to put it.
model_name = "google/gemma-2-2b"
device = 'cuda:0'

# Each supported model is loaded in a fixed precision, looked up by its name.
dtype = {
    "EleutherAI/pythia-70m-deduped": t.float32,
    "google/gemma-2-2b": t.bfloat16,
}[model_name]

model = LanguageModel(
    model_name,
    attn_implementation="eager",
    torch_dtype=dtype,
    device_map=device,
    dispatch=True,
)

# `dictionaries` is indexed by submodule (see neuron_dicts below).
submodules, dictionaries = load_saes_and_submodules(
    model,
    include_embed=False,
    dtype=dtype,
    device=device,
)

# One IdentityDict per submodule, sized to that submodule's dictionary
# activation_dim; passed to get_fcs for the "neurons" setting.
neuron_dicts = {
    submod: IdentityDict(dictionaries[submod].activation_dim)
    for submod in submodules
}

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading Gemma SAEs:  96%|█████████▋| 26/27 [01:32<00:03,  3.56s/it]


In [17]:
# Lowest layer index to keep, per supported model.
start_layer = {
    "EleutherAI/pythia-70m-deduped": 2,
    "google/gemma-2-2b": 8,
}[model_name]

# The text after the last underscore of a submodule's name is parsed as an
# integer; submodules whose integer is below start_layer are dropped.
submodules = [
    submod
    for submod in submodules
    if start_layer <= int(submod.name.split('_')[-1])
]


In [18]:
# use mean ablation
def ablation_fn(x):
    """Return a tensor shaped like `x` where every slice along dim 0 is the dim-0 mean of `x`."""
    mean_over_dim0 = x.mean(dim=0)
    return mean_over_dim0.expand_as(x)

In [19]:
# get m(C) for the circuit obtained by thresholding nodes with the given threshold
def get_fcs(
        dataset,
        model,
        submodules,
        dictionaries,
        ablation_fn,
        thresholds,
        handle_errors = 'default', # also 'remove' or 'resid_only'
        use_neurons = False,
        random = False,
        n_examples=40,
):
    """Compute faithfulness and completeness of thresholded circuits.

    For every threshold, the node masks are ``circuit[submod.name].abs() > threshold``
    and the metric is evaluated through ``run_with_ablations`` once as-is (``fc``)
    and once with ``complement=True`` (``fccomp``). Both are normalised as
    ``(value - fempty) / (fm - fempty)``, where ``fm`` is the metric of the
    un-ablated model on the clean inputs and ``fempty`` is the metric with
    all-False node masks.

    The metric is the last-position logit of the clean answer minus that of the
    patch answer, averaged over the examples.

    Results are cached in ``faithfulness/<md5>.pt`` (relative to the working
    directory). The key is built from the dataset, model name, submodule names,
    thresholds, ``handle_errors``, ``use_neurons``, ``random`` and ``n_examples``.
    It does NOT include ``ablation_fn`` or ``dictionaries``, so changing either
    requires clearing the cache by hand. With ``random=True`` the sampled masks
    are unseeded and the first sampled result is what gets cached.

    Relies on the notebook-level global ``device`` for the answer-index tensors
    and the node masks.

    Args:
        dataset: Dataset name; used in the circuit path and in
            ``../data/{dataset}_test.json``.
        model: Language model; ``model.config._name_or_path`` selects between
            the "gemma-2-2b" and "pythia-70m-deduped" circuit files.
        submodules: Submodules to consider; each needs a ``.name`` that is a
            key of the loaded circuit's ``'nodes'`` dict.
        dictionaries: Mapping from submodule to a dictionary exposing
            ``dict_size``.
        ablation_fn: Forwarded to ``run_with_ablations``.
        thresholds: Iterable of thresholds; each one becomes a key of the result.
        handle_errors: ``'default'`` keeps the thresholded ``resc`` masks,
            ``'remove'`` sets every ``resc`` mask to False, ``'resid_only'``
            sets it to False for submodules whose name lacks ``"resid"``.
        use_neurons: Load the ``_neurons`` variant of the circuit file.
        random: Replace each ``act`` mask with Bernoulli samples whose
            probability is (selected nodes / total nodes), and set every
            ``resc`` mask to True.
        n_examples: Number of examples requested from ``load_examples``.

    Returns:
        dict with ``"fm"``, ``"fempty"`` and, per threshold, a dict with
        ``"n_nodes"``, ``"fc"``, ``"fccomp"``, ``"faithfulness"`` and
        ``"completeness"``.

    Raises:
        ValueError: if ``handle_errors`` is not one of the three supported modes.
    """
    # Fail fast: an unknown mode used to behave like 'default' silently (and was
    # cached under its own key), which hides typos.
    valid_modes = ('default', 'remove', 'resid_only')
    if handle_errors not in valid_modes:
        raise ValueError(
            f"handle_errors must be one of {valid_modes}, got {handle_errors!r}"
        )

    if "gemma-2" in model.config._name_or_path:
        model_name = "gemma-2-2b"
    else:
        model_name = "pythia-70m-deduped"

    # Keep this string exactly as-is: it determines the cache file names.
    hash_str = dataset + model_name + str([s.name for s in submodules]) + str(thresholds) + handle_errors + str(use_neurons) + str(random) + str(n_examples)
    cache_key = hashlib.md5(hash_str.encode()).hexdigest()
    cache_dir = "faithfulness"
    cache_path = f"{cache_dir}/{cache_key}.pt"
    if os.path.exists(cache_path):
        return t.load(cache_path)

    circuit_path = f"../circuits/{model_name}_{dataset}_train_n100_aggnone" + (
        "_neurons" if use_neurons else ""
    ) + "_nodeall.pt"
    # NOTE(review): t.load unpickles the file; only point this at circuit files you trust.
    circuit = t.load(circuit_path)
    circuit = circuit['nodes']

    examples = load_examples(
        f"../data/{dataset}_test.json", n_examples, model, use_min_length_only=True,
    )

    clean_inputs = [e['clean_prefix'] for e in examples]
    patch_inputs = [e['patch_prefix'] for e in examples]

    def last_token_ids(field):
        # The last token id of each tokenized answer is used as the answer index.
        return t.tensor(
            [model.tokenizer(e[field]).input_ids[-1] for e in examples],
            dtype=t.long,
            device=device
        )

    clean_answer_idxs = last_token_ids('clean_answer')
    patch_answer_idxs = last_token_ids('patch_answer')

    def metric_fn(model):
        # Logit difference (clean answer - patch answer) at the final position.
        logits = model.output.logits[:,-1,:]
        return (
            - t.gather(logits, dim=-1, index=patch_answer_idxs.view(-1, 1)).squeeze(-1) + \
            t.gather(logits, dim=-1, index=clean_answer_idxs.view(-1, 1)).squeeze(-1)
        )

    def count_nodes(node_masks):
        # Number of True entries over all act and resc masks, as a Python int.
        return int(sum(m.act.sum() + m.resc.sum() for m in node_masks.values()))

    def mean_metric(node_masks, **extra):
        # Mean metric over the examples for the given node masks.
        return run_with_ablations(
            clean_inputs,
            patch_inputs,
            model,
            submodules,
            dictionaries,
            nodes=node_masks,
            metric_fn=metric_fn,
            ablation_fn=ablation_fn,
            **extra,
        ).mean().item()

    out = {}

    with t.no_grad():
        # m(M): un-ablated model on the clean inputs
        with model.trace(clean_inputs):
            metric = metric_fn(model).save()
        fm = metric.value.mean().item()

        out["fm"] = fm

        # get m(∅)
        fempty_nodes = {
            submod : SparseAct(
                act=t.zeros(dictionaries[submod].dict_size, dtype=t.bool),
                resc=t.zeros(1, dtype=t.bool)
            ).to(device)
            for submod in submodules
        }

        fempty = mean_metric(fempty_nodes)

        out["fempty"] = fempty

        for threshold in tqdm(thresholds):
            nodes = {
                submod : circuit[submod.name].abs() > threshold
                for submod in submodules
            }
            if handle_errors == 'remove':
                for k in nodes:
                    nodes[k].resc = t.zeros_like(nodes[k].resc, dtype=t.bool)
            elif handle_errors == 'resid_only':
                for k in nodes:
                    if "resid" not in k.name:
                        nodes[k].resc = t.zeros_like(nodes[k].resc, dtype=t.bool)

            n_nodes = count_nodes(nodes)
            if random:
                total_nodes = sum(
                    m.act.numel() + m.resc.numel() for m in nodes.values()
                )
                # Sample act masks at the same density as the thresholded circuit.
                p = n_nodes / total_nodes
                for k in nodes:
                    nodes[k].act = t.bernoulli(
                        t.ones_like(nodes[k].act, dtype=t.float) * p
                    ).to(device).to(dtype=t.bool)
                    nodes[k].resc = t.ones_like(nodes[k].resc, dtype=t.bool).to(device)
                # Report the size of the circuit that is actually evaluated.
                n_nodes = count_nodes(nodes)

            fc = mean_metric(nodes)
            fccomp = mean_metric(nodes, complement=True)

            out[threshold] = {
                "n_nodes": n_nodes,
                "fc": fc,
                "fccomp": fccomp,
                "faithfulness": (fc - fempty) / (fm - fempty),
                "completeness": (fccomp - fempty) / (fm - fempty),
            }

    # The cache directory may not exist yet on a fresh checkout; without this the
    # save fails after all of the compute above has already been done.
    os.makedirs(cache_dir, exist_ok=True)
    t.save(out, cache_path)

    return out


In [22]:
datasets = ["rc", "nounpp", "simple", "within_rc"]
thresholds = t.logspace(-4, 0, 15).tolist()


# One entry per setting: (name, dictionaries to use, extra get_fcs keyword arguments).
# Settings are evaluated in this order, and datasets in list order within each setting.
outs = {
    setting : {
        dataset : get_fcs(
            dataset,
            model,
            submodules,
            setting_dicts,
            ablation_fn,
            thresholds = thresholds,
            **extra_kwargs,
        )
        for dataset in datasets
    }
    for setting, setting_dicts, extra_kwargs in (
        ("features", dictionaries, {}),
        ("features_wo_errs", dictionaries, {"handle_errors": 'remove'}),
        ("features_wo_some_errs", dictionaries, {"handle_errors": 'resid_only'}),
        ("neurons", neuron_dicts, {"use_neurons": True}),
    )
}

In [33]:
# plot faithfulness results
fig = go.Figure()

colors = dict(
    features='blue',
    features_wo_errs='red',
    features_wo_some_errs='green',
    neurons='purple',
)

for setting, subouts in outs.items():

    # Per-dataset series in threshold order: node counts (x) and faithfulness (y).
    node_counts = {
        dataset : [subouts[dataset][thr]['n_nodes'] for thr in thresholds]
        for dataset in datasets
    }
    scores = {
        dataset : [subouts[dataset][thr]['faithfulness'] for thr in thresholds]
        for dataset in datasets
    }

    # Range of node counts covered by every dataset, pulled in by one at each end.
    x_min = max(min(counts) for counts in node_counts.values()) + 1
    x_max = min(max(counts) for counts in node_counts.values()) - 1
    fs = {
        dataset : interpolate.interp1d(node_counts[dataset], scores[dataset])
        for dataset in datasets
    }
    xs = t.logspace(math.log10(x_min), math.log10(x_max), steps=100, base=10).tolist()

    # Faint line per dataset.
    for dataset in datasets:
        fig.add_trace(go.Scatter(
            x = node_counts[dataset],
            y = scores[dataset],
            mode='lines', line=dict(color=colors[setting]), opacity=0.17, showlegend=False
        ))

    # Solid line: mean of the interpolated per-dataset curves.
    fig.add_trace(go.Scatter(
        x=xs,
        y=[ sum(f(x) for f in fs.values()) / len(fs) for x in xs ],
        mode='lines', line=dict(color=colors[setting]), name=setting
    ))

fig.update_xaxes(range=(0, 30000))
fig.update_yaxes(range=(0, 1.1))

fig.update_layout(
    xaxis_title='Nodes',
    yaxis_title='Faithfulness',
    width=800,
    height=375,
    # transparent plot background
    plot_bgcolor='rgba(0,0,0,0)',
    # grey gridlines, mirrored axis lines, outside ticks
    yaxis=dict(gridcolor='rgb(200,200,200)', mirror=True, ticks='outside', showline=True),
    xaxis=dict(gridcolor='rgb(200,200,200)', mirror=True, ticks='outside', showline=True),
)

fig.show()
# fig.write_image('faithfulness.pdf')

In [32]:
# plot completeness results
fig = go.Figure()

colors = {
    'features' : 'blue',
    'features_wo_errs' : 'red',
    'features_wo_some_errs' : 'green',
    'neurons' : 'purple'
}

for setting, subouts in outs.items():

    # Per-dataset series in threshold order: node counts (x) and completeness (y).
    # The loop variable is `thr`, not `t`, so it does not shadow the torch alias.
    node_counts = {
        dataset : [subouts[dataset][thr]['n_nodes'] for thr in thresholds]
        for dataset in datasets
    }
    scores = {
        dataset : [subouts[dataset][thr]['completeness'] for thr in thresholds]
        for dataset in datasets
    }

    # Range of node counts covered by every dataset, pulled in by one at each end.
    x_min = max(min(counts) for counts in node_counts.values()) + 1
    x_max = min(max(counts) for counts in node_counts.values()) - 1
    fs = {
        dataset : interpolate.interp1d(node_counts[dataset], scores[dataset])
        for dataset in datasets
    }
    xs = t.logspace(math.log10(x_min), math.log10(x_max), 100, 10).tolist()

    # Faint line per dataset.
    for dataset in datasets:
        fig.add_trace(go.Scatter(
            x = node_counts[dataset],
            y = scores[dataset],
            mode='lines', line=dict(color=colors[setting]), opacity=0.17, showlegend=False
        ))
    # Solid line: mean of the interpolated per-dataset curves.
    fig.add_trace(go.Scatter(
        x=xs,
        y=[ sum([f(x) for f in fs.values()]) / len(fs) for x in xs ],
        mode='lines', line=dict(color=colors[setting]), name=setting
    ))

# NOTE(review): the faithfulness plot uses an x-range of (0, 30000); confirm 3000 is intended here.
fig.update_xaxes(range=(0,3000))
fig.update_yaxes(range=(-.15, 1))

fig.update_layout(
    xaxis_title='Nodes',
    # This figure plots the 'completeness' values; the label previously read 'Faithfulness'.
    yaxis_title='Completeness',
    width=800,
    height=375,
    # set white background color
    plot_bgcolor='rgba(0,0,0,0)',
    # add grey gridlines
    yaxis=dict(gridcolor='rgb(200,200,200)',mirror=True,ticks='outside',showline=True),
    xaxis=dict(gridcolor='rgb(200,200,200)', mirror=True, ticks='outside', showline=True),
)
# fig.show()
fig.write_image('completeness.pdf')