In [2]:
from collections import defaultdict, namedtuple
import logging
from typing import cast, Dict, List, Tuple, Union
from typing_extensions import get_args, Literal

import os
import numpy as np
import torch
import yaml
import argparse
import pandas as pd
from tqdm import tqdm
import sys
import random

sys.path.append('/home/src/experiments/utils')
sys.path.append('/home/src/experiments')

from utils.probing_utils import AccuracyProbe
from utils.data_utils import makeHooks, decomposeHeads, decomposeSingleHead

from transformers import BertModel, BertTokenizer
from transformer_lens import HookedEncoder

  warn(f"Failed to load image Python extension: {e}")


# CHECK RESID STREAM

In [3]:
model = BertModel.from_pretrained('bert-base-cased')
model_hooked = HookedEncoder.from_pretrained('bert-base-cased')

If using BERT for interpretability research, keep in mind that BERT has some significant architectural differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.


Moving model to device:  cuda
Loaded pretrained model bert-base-cased into HookedTransformer


### Attention Hooks

In [4]:
def makeHooks(model, cache : defaultdict, remove_batch_dim=False, device='cpu'):
    def hook_self_attention(module, input, output):
        if isinstance(output, tuple):
            output = output[0]
        if remove_batch_dim:
            cache['attention'].append(output[0].detach().to(device))
        else:
            cache['attention'].append(output.detach().to(device))

    def hook_attn_output(module, input, output):
        if isinstance(output, tuple):
            output = output[0]
        if remove_batch_dim:
            cache['attention_out'].append(output[0].detach().to(device))
        else:
            cache['attention_out'].append(output.detach().to(device))
            
    for layer in model.encoder.layer:
        layer.attention.self.register_forward_hook(hook_self_attention)
        
    for layer in model.encoder.layer:
        layer.attention.output.dense.register_forward_hook(hook_attn_output)
    return cache

In [5]:
def decomposeHeads(model, attention_vectors):
    """
    Decompose attention heads into subspaces.
    `(cache['attention'][0] @ model.encoder.layer[0].attention.output.dense.weight.data.T) + model.encoder.layer[0].attention.output.dense.bias`
    """
    attention_head_dict = {}
    assert len(attention_vectors) == 12
    for i, attn_layer in enumerate(tqdm(attention_vectors, desc='Decomposing attention heads')):
        output_matrix = model.encoder.layer[i].attention.output.dense.weight.data.T
        for j in range(model.config.num_attention_heads):
            output_slice = output_matrix[j*64:(j+1)*64, :]
            if len(attn_layer.shape) == 2:
                attn_slice = attn_layer[:, j*64:(j+1)*64]
            elif len(attn_layer.shape) == 3:
                attn_slice = attn_layer[:, :, j*64:(j+1)*64] 
                ## if batch dim intact
            else:
                raise ValueError('Attention layer has unexpected shape')
            
            attention_head_dict[AttnHead(i, j)] =  attn_slice @ output_slice
    return attention_head_dict

def decomposeSingleHead(model, attention_vector, layer, head):
    """
    Decompose attention heads into subspaces.
    layer is 1-indexed so use layer-1
    """
    attention_head_dict = {}
    output_matrix = model.encoder.layer[layer-1].attention.output.dense.weight.data.T
    output_slice = output_matrix[head*64:(head+1)*64, :]
    if len(attention_vector.shape) == 2:
        attn_slice = attention_vector[:, head*64:(head+1)*64]
    elif len(attention_vector.shape) == 3:
        attn_slice = attention_vector[:, :, head*64:(head+1)*64] 
        ## if batch dim intact
    else:
        raise ValueError('Attention layer has unexpected shape')
    return attn_slice @ output_slice.cpu().numpy()

def getAttnOMatrix(model, layer):
    output_matrix = model.encoder.layer[layer].attention.output.dense.weight.data.T

In [6]:
cache = defaultdict(list)
cache = makeHooks(model, cache)

### Example

In [36]:
token_example = torch.tensor([[101, 40, 23, 12, 34, 2, 103, 4323, 12, 102]])
_ = model(token_example)
_, cache_h = model_hooked.run_with_cache(token_example)
out_head = cache_h.stack_head_results(layer=1, return_labels=True, incl_remainder=True)

Tried to stack head results when they weren't cached. Computing head results now


In [168]:
out_head_dict = {val: ten for ten, val in zip(out_head[0], out_head[1])}
attn_out_l0 = (cache_h['blocks.0.attn.hook_result'].sum(axis=2) + model.encoder.layer[0].attention.output.dense.bias)
calc_l0_h0 = decomposeSingleHead(model, cache['attention'][0].detach().numpy(), 1, 0)

In [169]:
assert torch.all(cache_h['blocks.0.attn.hook_result'][:, :, 0, :] == out_head_dict['L0H0'])
assert np.isclose(cache_h['blocks.0.attn.hook_result'][:, :, 0, :].detach().numpy(), calc_l0_h0, 1e-4).mean()  > 0.98
assert np.isclose(attn_out.detach().numpy(), cache_h['blocks.0.hook_attn_out'].detach().numpy(), 1e-4).mean() > 0.98
assert np.isclose(cache['attention_out'][0], cache_h['blocks.0.hook_attn_out'].detach().numpy(), 1e-4).mean() > 0.98

## making sure all ways to calculate heads and attention out are the same

In [141]:
AttnHead = namedtuple("AttnHead", "layer head")
heads_decomp = decomposeHeads(model, cache['attention'])

In [204]:
out_head_all = cache_h.stack_head_results(layer=-1, return_labels=True, incl_remainder=False)
out_head_all_dict = {val: ten for ten, val in zip(out_head_all[0], out_head_all[1])}

In [1]:
## attn ones were messed up by 0-indexing

# FIX ATTENTION PROBING

In [7]:
from aheads import generate_algo_task_dataset

In [None]:
generate_algo_task_dataset