In [1]:
import warnings
warnings.filterwarnings('ignore')

In [18]:
import math
import os
from pathlib import Path

import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from transformer_lens import HookedEncoder, HookedTransformerConfig
from transformer_lens.loading_from_pretrained import get_pretrained_state_dict

In [3]:
# Work from the project root (the notebook's parent directory). The guard keeps the
# cell idempotent: re-running it must not climb one more directory each time.
if '_NOTEBOOK_DIR' not in globals():
    _NOTEBOOK_DIR = Path.cwd()
    os.chdir(_NOTEBOOK_DIR.parent)
    print('Changed working directory to parent directory')

# Prefer an HF_TOKEN already set in the environment; otherwise read it from the first
# token file that exists (the legacy location first, then huggingface_hub's default).
if 'HF_TOKEN' not in os.environ:
    token_paths = [
        Path('~/.huggingface/token').expanduser(),
        Path('~/.cache/huggingface/token').expanduser(),
    ]
    token_path = next((p for p in token_paths if p.is_file()), None)
    if token_path is None:
        raise FileNotFoundError(
            f"No Hugging Face token found; looked in: {', '.join(map(str, token_paths))}"
        )
    os.environ['HF_TOKEN'] = token_path.read_text().strip()
# Do not echo any part of the secret into the (shareable) notebook output.
print('Hugging Face token loaded.')

# Inference only: disable autograd globally for the rest of the session.
torch.set_grad_enabled(False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Changed working directory to parent directory
Hugging Face token loaded: hf_...
Using device: cpu


In [4]:
# Hugging Face repo id of the SST-2 sentiment fine-tune; also reused as the local download directory.
checkpoint = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"

### 1. Download weights from Hugging Face

In [5]:
# Fetch a full copy of the repo into ./<checkpoint> once; later runs reuse the existing directory.
local_dir = Path(checkpoint)
if not local_dir.exists():
    print(f'Downloading {checkpoint}...')
    snapshot_download(
        repo_id=checkpoint,
        local_dir=local_dir.resolve(),
        local_dir_use_symlinks=False,
    )
    print(f'Downloaded {checkpoint}.')

### 2. Load the model from the downloaded snapshot

In [6]:
# Load the fine-tuned classifier from the local snapshot directory, in full precision.
hf_load_kwargs = dict(
    torch_dtype=torch.float32,
    token=os.environ.get("HF_TOKEN", ""),
)
hf_model = DistilBertForSequenceClassification.from_pretrained(Path(checkpoint).resolve(), **hf_load_kwargs)

In [7]:
# Inspect the fine-tuned model's config: dim, n_heads, hidden_dim, n_layers, activation, etc.
# are the values copied into the HookedTransformerConfig in section 3.
hf_model.config

DistilBertConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "/Users/marcosf/Desktop/research/mech_interp/distilbert/distilbert-base-uncased-finetuned-sst-2-english",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "finetuning_task": "sst-2",
  "hidden_dim": 3072,
  "id2label": {
    "0": "NEGATIVE",
    "1": "POSITIVE"
  },
  "initializer_range": 0.02,
  "label2id": {
    "NEGATIVE": 0,
    "POSITIVE": 1
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.49.0",
  "vocab_size": 30522
}

In [8]:
# Print the module tree to see the HF parameter layout (q_lin/k_lin/v_lin/out_lin, sa_layer_norm, ffn, ...)
# that the conversion to TransformerLens names has to map.
hf_model

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


### 3. TransformerLens config and converted `state_dict`

In [None]:
# TransformerLens config mirroring the DistilBERT architecture read from hf_model.config.
cfg = HookedTransformerConfig(
  model_name=checkpoint,
  tokenizer_name=checkpoint,
  d_model=hf_model.config.dim,
  d_head=hf_model.config.dim // hf_model.config.n_heads,
  d_mlp=hf_model.config.hidden_dim,
  n_ctx=hf_model.config.max_position_embeddings,
  n_heads=hf_model.config.n_heads,
  n_layers=hf_model.config.n_layers,
  act_fn=hf_model.config.activation,
  original_architecture=hf_model.config.architectures[0],
  d_vocab=hf_model.config.vocab_size,
  # NOTE(review): this only sizes HookedEncoder's own `unembed` (stacked on its MLM head),
  # which is not DistilBERT's pre_classifier/classifier and is not loaded from the checkpoint.
  d_vocab_out=len(hf_model.config.id2label),
  eps=1e-12,                                                                # DistilBERT LayerNorm eps (see module printout)
  # DistilBERT is an encoder: every token attends to every other token. The config
  # default is 'causal', which would mask out all future positions.
  attention_dir='bidirectional',
  device=device,
  attn_scale=math.sqrt(hf_model.config.dim // hf_model.config.n_heads),     # lib default; matches DistilBERT's 1/sqrt(d_head)
  use_hook_tokens=False,                                                    # memory intensive, but may be tested later
  use_attn_result=False,                                                    # memory intensive, but may be tested later
  use_split_qkv_input=False,                                                # memory intensive, but may be tested later
  use_hook_mlp_in=False,                                                    # memory intensive, but may be tested later
  use_attn_in=False,                                                        # memory intensive, but may be tested later
)

In [20]:
# Convert the checkpoint's weights into TransformerLens parameter names (keys shown below).
# NOTE(review): only the repo id and cfg are passed, so this presumably loads the HF weights
# itself rather than reusing `hf_model` — confirm it reads the same snapshot.
state_dict = get_pretrained_state_dict(checkpoint, cfg)

### 4. Build the `HookedEncoder` and load the converted weights

In [11]:
# Tokenizer for the same checkpoint; handed to HookedEncoder and used to encode the test sentence.
tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)

In [21]:
# Build the hooked encoder from cfg; its parameters are overwritten from `state_dict` via load_state_dict below.
hooked_model = HookedEncoder(cfg, tokenizer)

Moving model to device:  cpu


In [22]:
# The hooked model's own state_dict, used to compare its keys with the converted `state_dict`.
hm_sd = hooked_model.state_dict()

In [23]:
# Parameter/buffer names the HookedEncoder expects.
hm_sd.keys()

odict_keys(['embed.embed.W_E', 'embed.pos_embed.W_pos', 'embed.token_type_embed.W_token_type', 'embed.ln.w', 'embed.ln.b', 'blocks.0.attn.W_Q', 'blocks.0.attn.W_O', 'blocks.0.attn.b_Q', 'blocks.0.attn.b_O', 'blocks.0.attn.W_K', 'blocks.0.attn.W_V', 'blocks.0.attn.b_K', 'blocks.0.attn.b_V', 'blocks.0.attn.mask', 'blocks.0.attn.IGNORE', 'blocks.0.ln1.w', 'blocks.0.ln1.b', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.0.ln2.w', 'blocks.0.ln2.b', 'blocks.1.attn.W_Q', 'blocks.1.attn.W_O', 'blocks.1.attn.b_Q', 'blocks.1.attn.b_O', 'blocks.1.attn.W_K', 'blocks.1.attn.W_V', 'blocks.1.attn.b_K', 'blocks.1.attn.b_V', 'blocks.1.attn.mask', 'blocks.1.attn.IGNORE', 'blocks.1.ln1.w', 'blocks.1.ln1.b', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out', 'blocks.1.ln2.w', 'blocks.1.ln2.b', 'blocks.2.attn.W_Q', 'blocks.2.attn.W_O', 'blocks.2.attn.b_Q', 'blocks.2.attn.b_O', 'blocks.2.attn.W_K', 'blocks.2.attn.W_V', 'blocks

In [24]:
# Parameter names produced by the conversion from the HF checkpoint.
state_dict.keys()

dict_keys(['embed.embed.W_E', 'embed.pos_embed.W_pos', 'embed.ln.w', 'embed.ln.b', 'blocks.0.attn.W_Q', 'blocks.0.attn.b_Q', 'blocks.0.attn.W_K', 'blocks.0.attn.b_K', 'blocks.0.attn.W_V', 'blocks.0.attn.b_V', 'blocks.0.attn.W_O', 'blocks.0.attn.b_O', 'blocks.0.ln1.w', 'blocks.0.ln1.b', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.0.ln2.w', 'blocks.0.ln2.b', 'blocks.1.attn.W_Q', 'blocks.1.attn.b_Q', 'blocks.1.attn.W_K', 'blocks.1.attn.b_K', 'blocks.1.attn.W_V', 'blocks.1.attn.b_V', 'blocks.1.attn.W_O', 'blocks.1.attn.b_O', 'blocks.1.ln1.w', 'blocks.1.ln1.b', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out', 'blocks.1.ln2.w', 'blocks.1.ln2.b', 'blocks.2.attn.W_Q', 'blocks.2.attn.b_Q', 'blocks.2.attn.W_K', 'blocks.2.attn.b_K', 'blocks.2.attn.W_V', 'blocks.2.attn.b_V', 'blocks.2.attn.W_O', 'blocks.2.attn.b_O', 'blocks.2.ln1.w', 'blocks.2.ln1.b', 'blocks.2.mlp.W_in', 'blocks.2.mlp.b_in', 'blocks.2.mlp.W_out

In [25]:
# (expected by HookedEncoder, provided by the conversion)
len(hm_sd), len(state_dict)

(123, 104)

In [26]:
# Names the HookedEncoder has but the converted checkpoint does not provide.
set(hm_sd).difference(state_dict)

{'blocks.0.attn.IGNORE',
 'blocks.0.attn.mask',
 'blocks.1.attn.IGNORE',
 'blocks.1.attn.mask',
 'blocks.2.attn.IGNORE',
 'blocks.2.attn.mask',
 'blocks.3.attn.IGNORE',
 'blocks.3.attn.mask',
 'blocks.4.attn.IGNORE',
 'blocks.4.attn.mask',
 'blocks.5.attn.IGNORE',
 'blocks.5.attn.mask',
 'embed.token_type_embed.W_token_type',
 'mlm_head.W',
 'mlm_head.b',
 'mlm_head.ln.b',
 'mlm_head.ln.w',
 'nsp_head.W',
 'nsp_head.b',
 'pooler.W',
 'pooler.b',
 'unembed.W_U',
 'unembed.b_U'}

In [None]:
# DistilBERT has no token-type (segment) embeddings. The HookedEncoder, however, owns an
# `embed.token_type_embed.W_token_type` parameter (see the key diff above), which
# TransformerLens' BERT embedding adds to every token. The converted state_dict does not
# provide it, so without this it would keep whatever values HookedEncoder created it with.
# Zero it so the embedding is word + position only, as in DistilBERT.
token_type_key = 'embed.token_type_embed.W_token_type'
state_dict_to_load = {**state_dict, token_type_key: torch.zeros_like(hm_sd[token_type_key])}

# strict=False is needed because the BERT-specific heads (mlm/nsp/pooler/unembed) and the
# attention mask buffers have no DistilBERT counterpart. Fail loudly if anything else was
# left unloaded instead of silently running with placeholder weights.
load_result = hooked_model.load_state_dict(state_dict_to_load, strict=False)
expected_missing_prefixes = ('mlm_head.', 'nsp_head.', 'pooler.', 'unembed.')
expected_missing_suffixes = ('.attn.mask', '.attn.IGNORE')
unexpected_missing = [
  k for k in load_result.missing_keys
  if not k.startswith(expected_missing_prefixes) and not k.endswith(expected_missing_suffixes)
]
assert not unexpected_missing, f'Parameters left unloaded: {unexpected_missing}'
assert not load_result.unexpected_keys, f'Converted keys not used by the model: {load_result.unexpected_keys}'
load_result

In [None]:
# Encode one example sentence. Pad only to the longest sequence in the batch (a no-op for a
# single sentence) rather than to the full context: every pad position costs compute and has
# to be masked out wherever the model is run.
inputs = tokenizer(
  'I am in love with it.',
  return_tensors='pt',
  padding=True,
  truncation=True,
  max_length=cfg.n_ctx,  # position-embedding limit (max_position_embeddings = 512)
).to(device)

In [None]:
import torch.nn.functional as F

In [None]:
# HookedEncoder's own output logits come from its MLM head + unembed, which were NOT loaded
# from the checkpoint (see the missing keys above), so they are meaningless here.
# DistilBertForSequenceClassification classifies from the final hidden state of the first
# ([CLS]) token via pre_classifier -> ReLU -> classifier; reproduce that on top of the hooked
# encoder's last block output, and compare with the original HF model on the same inputs.
_, cache = hooked_model.run_with_cache(
  inputs['input_ids'],
  one_zero_attention_mask=inputs['attention_mask'],  # do not attend to padding tokens
)
final_hidden = cache[f'blocks.{cfg.n_layers - 1}.hook_normalized_resid_post']
cls_hidden = final_hidden[:, 0]
logits = hf_model.classifier(F.relu(hf_model.pre_classifier(cls_hidden)))

hf_logits = hf_model(**inputs).logits  # reference: unmodified HF forward pass

{
  'hooked_logits': logits,
  'hf_logits': hf_logits,
  'max_abs_diff': (logits - hf_logits).abs().max().item(),
  'hooked_probs': F.softmax(logits, dim=-1),
  'predicted_label': hf_model.config.id2label[logits.argmax(dim=-1)[0].item()],
}

# Sources

1. [Ground truth - Arena::Indirect Object Identification](https://arena-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification)
2. [Interpretability in the wild: A circuit for indirect object identification in GPT-2 small, by Wang, K, et. al.](https://arxiv.org/pdf/2211.00593)
3. [NOTEBOOK - Exploratory Analysis Demo, by Neel Nanda](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=WXktSe0CvBdh)
4. [An analogy for understanding transformers, by Callum McDougall](https://www.lesswrong.com/posts/euam65XjigaCJQkcN/an-analogy-for-understanding-transformers)