In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import json
import torch

In [6]:
import sys
# Put the parent directory on the module search path (presumably so that
# project-local modules can be imported; none are imported in the cells
# visible here).
sys.path.append('..')

In [7]:
# Inference only: switch autograd off globally so the forward passes below do
# not record gradient graphs. The trailing semicolon stops the cell from
# echoing the returned set_grad_enabled object as its output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x2ba1722abdf0>

In [3]:
import transformers

In [9]:
from datasets import load_dataset

# Load the 'explicit' configuration of the Anthropic/discrim-eval dataset.
dataset = load_dataset(
    path="Anthropic/discrim-eval",
    name='explicit',
)

# Alternative dataset, currently disabled:
# dataset = load_dataset("starmpcc/Asclepius-Synthetic-Clinical-Notes")

In [10]:
# Convert the 'train' split to a pandas DataFrame so it can be filtered below.
df = dataset['train'].to_pandas()

In [11]:
# Keep only the rows for decision question 0 that describe a 20-year-old
# female; the surviving rows differ in the remaining columns (e.g. race).
test_df = df.loc[
    df['decision_question_id'].eq(0)
    & df['age'].eq(20)
    & df['gender'].eq('female')
]

In [12]:
# Preview the first few rows of the filtered subset.
test_df.head()

Unnamed: 0,filled_template,decision_question_id,age,gender,race,fill_type
0,The patient is a 20-year-old white female suff...,0,20.0,female,white,explicit
1,The patient is a 20-year-old Black female suff...,0,20.0,female,Black,explicit
2,The patient is a 20-year-old Asian female suff...,0,20.0,female,Asian,explicit
3,The patient is a 20-year-old Hispanic female s...,0,20.0,female,Hispanic,explicit
4,The patient is a 20-year-old Native American f...,0,20.0,female,Native American,explicit


In [13]:
# Use the first matching prompt as the example text fed through the model.
text = test_df['filled_template'].iloc[0]

### Let's get the mean activations now

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer

In [4]:
import torch
# Checkpoint to load, and the device to run it on (GPU when one is available).
model_name = 'gpt2'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the pretrained causal LM, then move its weights to the chosen device.
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

In [17]:
# Architecture summary of the loaded model plus the dotted module paths used as
# hook points: the attention output projection of every block, and every
# block as a whole.
model_config = dict(
    n_heads=model.config.n_head,
    n_layers=model.config.n_layer,
    resid_dim=model.config.n_embd,
    name_or_path=model.config.name_or_path,
    attn_hook_names=[
        f'transformer.h.{layer}.attn.c_proj'
        for layer in range(model.config.n_layer)
    ],
    layer_hook_names=[
        f'transformer.h.{layer}'
        for layer in range(model.config.n_layer)
    ],
)

In [6]:
# Load the tokenizer published under the same checkpoint name as the model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [14]:
# Tokenize the prompt into PyTorch tensors, then move them onto the same
# device as the model.
inputs = tokenizer(text, return_tensors='pt')
inputs = inputs.to(model.device)

In [None]:
# Display the tokenized prompt to eyeball the input_ids.
inputs

{'input_ids': tensor([[  464,  5827,   318,   257,  1160,    12,  1941,    12,   727,  2330,
          4048,  7195,   422,  6049, 21919,  4369,    13,  1375,   468,   587,
           319,  5980,  3097,   329,   625,   257,   614,   475,   607,  1535,
           318, 18080,    13,  1375,   468,   587, 14451,   351,   257, 21919,
         17052,   351,   257,   845,  1969, 16217,  2770,  2872,    13,   383,
          5827,   338,   584,  1535, 21337,   389,  4143,   922,    11,   996,
           607, 22456,   318,  4622,  2440,   621,  7306,    13,  1375,   468,
          1913,  1641,  1104,    11,   468, 11238,  9216,    11,   290,   468,
          9141,   477,  2672, 21951, 10991,    13,   220,  2102,    11, 41395,
           329, 23319,   389, 18549,   290,   612,   318,   257,   890,  4953,
          1351,    13,  1318,   318,   691,   257,  5969,   640,  4324,   284,
           466,   262, 23319,   706,   262, 21919,  4329,  1695,   878,   340,
           481,   645,  2392,   307, 1

In [1]:
from baukit import TraceDict

In [47]:
# Record the installed torch version alongside the results.
torch.__version__

'1.13.0'

In [18]:
# Module names to hook: the attention output projection (c_proj) of every block.
layers = model_config['attn_hook_names']

In [19]:
# Show the hook-point names (one attn.c_proj path per transformer block).
layers

['transformer.h.0.attn.c_proj',
 'transformer.h.1.attn.c_proj',
 'transformer.h.2.attn.c_proj',
 'transformer.h.3.attn.c_proj',
 'transformer.h.4.attn.c_proj',
 'transformer.h.5.attn.c_proj',
 'transformer.h.6.attn.c_proj',
 'transformer.h.7.attn.c_proj',
 'transformer.h.8.attn.c_proj',
 'transformer.h.9.attn.c_proj',
 'transformer.h.10.attn.c_proj',
 'transformer.h.11.attn.c_proj']

In [20]:
# Capture activations: hook every module named in `layers` and keep the tensor
# passed INTO each one (inputs retained, outputs not).
with TraceDict(
    model,
    layers=layers,
    retain_input=True,
    retain_output=False,
) as td:
    # The forward pass is run only so the hooks fire; the returned logits
    # (batch_size x n_tokens x vocab_size) are discarded.
    model(**inputs)

In [21]:
# Display the TraceDict: one recorded Trace per hooked module name.
td

TraceDict([('transformer.h.0.attn.c_proj',
            <baukit.nethook.Trace at 0x2ae96395fc10>),
           ('transformer.h.1.attn.c_proj',
            <baukit.nethook.Trace at 0x2ae96395f8e0>),
           ('transformer.h.2.attn.c_proj',
            <baukit.nethook.Trace at 0x2ae96395e7a0>),
           ('transformer.h.3.attn.c_proj',
            <baukit.nethook.Trace at 0x2ae96395d720>),
           ('transformer.h.4.attn.c_proj',
            <baukit.nethook.Trace at 0x2ae96395e830>),
           ('transformer.h.5.attn.c_proj',
            <baukit.nethook.Trace at 0x2ae96395f580>),
           ('transformer.h.6.attn.c_proj',
            <baukit.nethook.Trace at 0x2ae96395db70>),
           ('transformer.h.7.attn.c_proj',
            <baukit.nethook.Trace at 0x2ae96395e380>),
           ('transformer.h.8.attn.c_proj',
            <baukit.nethook.Trace at 0x2ae96395fbe0>),
           ('transformer.h.9.attn.c_proj',
            <baukit.nethook.Trace at 0x2ae96395f0a0>),
           ('transfo

In [23]:
def split_activations_by_head(activations, model_config):
    """Reshape the trailing hidden axis of ``activations`` into per-head slices.

    Args:
        activations: Tensor whose last dimension is expected to have size
            ``model_config['resid_dim']``; all leading dimensions are kept.
        model_config: Mapping providing ``'n_heads'`` and ``'resid_dim'``.

    Returns:
        A view of ``activations`` with shape
        ``(*leading_dims, n_heads, resid_dim // n_heads)``.
    """
    n_heads = model_config['n_heads']
    head_dim = model_config['resid_dim'] // n_heads
    leading_dims = activations.shape[:-1]
    # e.g. (batch_size, n_tokens, hidden) -> (batch_size, n_tokens, n_heads, head_dim)
    return activations.view(*leading_dims, n_heads, head_dim)

In [25]:
# Split each hooked layer's recorded c_proj input into heads, concatenate the
# per-layer results along the first axis, then swap the token and head axes.
# For the single prompt used here the result is (layer, head, token, head_dim).
# NOTE(review): with batch size > 1 the first axis would hold layer x batch
# entries rather than layers alone - confirm before batching.
stack_initial = torch.cat(
    [
        split_activations_by_head(td[hook_name].input, model_config)
        for hook_name in model_config['attn_hook_names']
    ],
    dim=0,
).permute(0, 2, 1, 3)

In [27]:
# Sanity-check the stacked shape; the recorded run gave [12, 12, 157, 64].
stack_initial.shape

torch.Size([12, 12, 157, 64])