In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import random
import yaml
from omegaconf import DictConfig, OmegaConf
from interventions import three_operands
from tqdm.notebook import tqdm
import numpy as np
from functools import partial
import pickle

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer
import transformer_lens.patching as patching
import seaborn as sns
import matplotlib.pyplot as plt

seed = 0

Matplotlib created a temporary cache directory at /tmp/matplotlib-l3josmrd because the default path (/scratch_local/eickhoff/esx208/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.
Fontconfig error: Cannot load default config file: No such file: (null)
Fontconfig error: No writable cache directories
Fontconfig error: No writable cache directories
Fontconfig error: No writable cache directories
Fontconfig error: No writable cache directories
Fontconfig error: No writable cache directories
Fontconfig error: No writable cache directories
Fontconfig error: No writable cache directories


In [2]:
random.seed(seed)
torch.manual_seed(seed)
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f2ca9923df0>

In [3]:
model_name = 'EleutherAI/pythia-12b-deduped-v0'
model_name_lens = 'pythia-12b-deduped-v0'
# model_name = 'EleutherAI/pythia-6.9b-deduped-v0'
# model_name_lens = 'pythia-6.9b-deduped-v0'
# model_name_lens = 'facebook/opt-125m'
# model_name = 'mosaicml/mpt-7b'
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = HookedTransformer.from_pretrained_no_processing(
    model_name,
    dtype=torch.float16
)
model.eval()
conf = OmegaConf.load('conf/config.yaml')



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model EleutherAI/pythia-12b-deduped-v0 into HookedTransformer


In [4]:
conf = OmegaConf.load('./conf/config.yaml')
conf.model = model_name
# conf.max_n = 9
conf.max_n = 20
# conf.n_shots = 1
# intervention_list = three_operands.get_arithmetic_data_three_operands(tokenizer, conf)
intervention_list = pickle.load(open('/mnt/qb/work/eickhoff/esx208/arithmetic-lm/data/intervention_1_shots_max_20_arabic_further_templates.pkl', 'rb'))

Exception: Error while attempting to unpickle Tokenizer: data did not match any variant of untagged enum ModelWrapper at line 1 column 1559948

In [5]:
clean_logits, clean_cache = model.run_with_cache(intervention_list[0].base_string_tok)

In [7]:
model.tokens_to_residual_directions(intervention_list[0].base_string_tok).shape

torch.Size([1, 14, 5120])

In [8]:
from jaxtyping import Float, Int, Bool
from typing import Literal, Callable
from torch import Tensor
from transformer_lens.hook_points import HookPoint
import itertools

In [None]:
def patch_or_freeze_head_vectors(
    orig_head_vector: Float[Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    new_cache: ActivationCache,
    orig_cache: ActivationCache,
    head_to_patch: tuple[int, int],
) -> Float[Tensor, "batch pos head_index d_head"]:
    '''
    This helps implement step 2 of path patching. We freeze all head outputs (i.e. set them
    to their values in orig_cache), except for head_to_patch (if it's in this layer) which
    we patch with the value from new_cache.

    head_to_patch: tuple of (layer, head)
        we can use hook.layer() to check if the head to patch is in this layer
    '''
    # Setting using ..., otherwise changing orig_head_vector will edit cache value too
    orig_head_vector[...] = orig_cache[hook.name][...]
    if head_to_patch[0] == hook.layer():
        orig_head_vector[:, :, head_to_patch[1]] = new_cache[hook.name][:, :, head_to_patch[1]]
    return orig_head_vector

def get_path_patch_head_to_final_resid_post(
    model,
    patching_metric,
    interventions
):
    model.reset_hooks()
    results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device="cuda", dtype=torch.float32)
    
    resid_post_hook_name = utils.get_act_name("resid_post", model.cfg.n_layers - 1)
    resid_post_name_filter = lambda name: name == resid_post_hook_name
    
    
def get_logit_diff(logits, answer_token_indices):
    if len(logits.shape)==3:
        # Get final logits only
        logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()
    

def get_path_patch_head_to_final_resid_post(
    model: HookedTransformer,
    interventions
) -> Float[Tensor, "layer head"]:
    # SOLUTION
    '''
    Performs path patching (see algorithm in appendix B of IOI paper), with:

        sender head = (each head, looped through, one at a time)
        receiver node = final value of residual stream

    Returns:
        tensor of metric values for every possible sender head
    '''
    
    for intervention in interventions:
        model.reset_hooks()
        results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device="cuda", dtype=torch.float32)

        resid_post_hook_name = utils.get_act_name("resid_post", model.cfg.n_layers - 1)
        resid_post_name_filter = lambda name: name == resid_post_hook_name
        
        # ========== Step 1 ==========
        # Gather activations on x_orig and x_new

        # Note the use of names_filter for the run_with_cache function. Using it means we
        # only cache the things we need (in this case, just attn head outputs).
        z_name_filter = lambda name: name.endswith("z")
        if new_cache is None:
            _, new_cache = model.run_with_cache(
                intervention.base_string_tok,
                names_filter=z_name_filter,
                return_type=None
            )
        if orig_cache is None:
            _, orig_cache = model.run_with_cache(
                intervention.alt_string_tok,
                names_filter=z_name_filter,
                return_type=None
            )
            
        answer_token_indices = torch.tensor([[intervention.res_base_tok[0], intervention.pred_res_alt_tok]]).to(model.cfg.device)
        metric = partial(get_logit_diff, answer_token_indices=answer_token_indices)
            
        # Looping over every possible sender head (the receiver is always the final resid_post)
        # Note use of itertools (gives us a smoother progress bar)
        for (sender_layer, sender_head) in tqdm(list(itertools.product(range(model.cfg.n_layers), range(model.cfg.n_heads)))):

            # ========== Step 2 ==========
            # Run on x_orig, with sender head patched from x_new, every other head frozen

            hook_fn = partial(
                patch_or_freeze_head_vectors,
                new_cache=new_cache,
                orig_cache=orig_cache,
                head_to_patch=(sender_layer, sender_head),
            )
            model.add_hook(z_name_filter, hook_fn)

            _, patched_cache = model.run_with_cache(
                intervention.alt_string_tok,
                names_filter=resid_post_name_filter,
                return_type=None
            )
            # if (sender_layer, sender_head) == (9, 9):
            #     return patched_cache
            assert set(patched_cache.keys()) == {resid_post_hook_name}

            # ========== Step 3 ==========
            # Unembed the final residual stream value, to get our patched logits

            patched_logits = model.unembed(model.ln_final(patched_cache[resid_post_hook_name]))

            # Save the results
            results[sender_layer, sender_head] = patching_metric(patched_logits)

        return results

In [13]:
path_patch_head_to_final_resid_post = get_path_patch_head_to_final_resid_post(model, intervention_list[:1])

AttributeError: 'Intervention' object has no attribute 'pred_res_alt_tok'