In [1]:
import warnings
# Silence every warning category for the rest of the kernel session.
# NOTE(review): a blanket 'ignore' also hides deprecation warnings -- confirm that is intended.
warnings.filterwarnings('ignore')

In [2]:
import math
import os
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from transformer_lens import HookedEncoder, HookedTransformerConfig
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

In [3]:
# Session setup: working directory, Hugging Face credentials and torch defaults.

# NOTE: this moves one directory up on *every* execution, so the cell is not
# idempotent -- run it exactly once per kernel session.
os.chdir('..')
print('Changed working directory to parent directory')

# Export the access token stored on disk as HF_TOKEN. A missing token file is
# reported instead of aborting the cell, so an HF_TOKEN that is already set in
# the environment stays in effect and the torch setup below still runs.
try:
    with open(os.path.expanduser('~/.huggingface/token')) as f:
        os.environ['HF_TOKEN'] = f.read().strip()
except FileNotFoundError:
    print('No Hugging Face token file found; HF_TOKEN left unchanged')
else:
    # Double quotes inside the replacement field keep this f-string valid on
    # Python versions before 3.12. Only a short prefix is printed so the
    # secret itself never lands in the saved notebook output.
    print(f'Hugging Face token loaded: {os.environ["HF_TOKEN"][:3]}...')

# Inference only: switch autograd off globally.
torch.set_grad_enabled(False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Changed working directory to parent directory
Hugging Face token loaded: hf_...
Using device: cpu


# 1. Download weights from Hugging Face

In [4]:
# Hugging Face Hub repo id of the model under study; later cells also use it as a path relative to the working directory.
checkpoint = 'distilbert/distilbert-base-uncased-finetuned-sst-2-english'

In [5]:
# Fetch the checkpoint from the Hugging Face Hub only when no local copy
# exists yet. The repo id doubles as a path relative to the current working
# directory, so the files are written beneath it.
local_dir = Path(checkpoint).resolve()
if not local_dir.exists():
    print(f'Downloading {checkpoint}...')
    # `local_dir_use_symlinks` is deliberately not passed: current
    # huggingface_hub releases deprecate and ignore it, and always place real
    # files in `local_dir`.
    snapshot_download(repo_id=checkpoint, local_dir=local_dir)
    print(f'Downloaded {checkpoint}.')

# 2. Model setup

In [6]:
# Load the tokenizer and the sequence-classification model for the checkpoint,
# then move the model to the selected device.
tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
model = DistilBertForSequenceClassification.from_pretrained(checkpoint)
model = model.to(device)

# Cell output: the Hugging Face configuration of the loaded model.
model.config

DistilBertConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "distilbert/distilbert-base-uncased-finetuned-sst-2-english",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "finetuning_task": "sst-2",
  "hidden_dim": 3072,
  "id2label": {
    "0": "NEGATIVE",
    "1": "POSITIVE"
  },
  "initializer_range": 0.02,
  "label2id": {
    "NEGATIVE": 0,
    "POSITIVE": 1
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "torch_dtype": "float32",
  "transformers_version": "4.49.0",
  "vocab_size": 30522
}

In [None]:
# Mirror the Hugging Face DistilBERT hyper-parameters in a TransformerLens config.
hf_config = model.config
d_head = hf_config.dim // hf_config.n_heads  # width of a single attention head

cfg = HookedTransformerConfig(
    model_name=checkpoint,
    tokenizer_name=checkpoint,
    device=device,
    d_model=hf_config.dim,
    d_head=d_head,
    n_layers=hf_config.n_layers,
    n_ctx=hf_config.max_position_embeddings,
    act_fn=hf_config.activation,
    original_architecture=hf_config.architectures[0],
    # Author's note: everything below is optional.
    n_heads=hf_config.n_heads,
    d_mlp=hf_config.hidden_dim,
    d_vocab=hf_config.vocab_size,
    # One output per entry of the label map.
    d_vocab_out=len(hf_config.id2label),
    # Scaled attention as in "Attention Is All You Need" -- may need to be adjusted.
    use_attn_scale=True,
    # Value came from a ChatGPT query -- may need to be adjusted.
    eps=1e-12,
    # The next five switches are memory intensive, so they are off for now;
    # they may be tested later.
    use_hook_tokens=False,
    use_attn_result=False,
    use_split_qkv_input=False,
    use_hook_mlp_in=False,
    use_attn_in=False,
    # Wild guess.
    ungroup_grouped_query_attention=False,
    # sqrt(d_head); per the author this equals the library default.
    attn_scale=math.sqrt(d_head),
)

In [26]:
# Hugging Face access token taken from the environment; an empty string when HF_TOKEN is unset.
huggingface_token = os.getenv("HF_TOKEN", "")

In [20]:
# Absolute path of the local checkpoint copy. It is derived from the current
# working directory instead of being hard-coded to one machine's home folder,
# so the notebook also runs from other checkouts.
official_model_name = str(Path(checkpoint).resolve())
official_model_name

'/Users/marcosf/Desktop/research/mech_interp/distilbert/distilbert-base-uncased-finetuned-sst-2-english'

In [11]:
# Loader defaults: float32 weights and no extra keyword arguments.
dtype, kwargs = torch.float32, dict()
(dtype, kwargs)

(torch.float32, {})

In [19]:
from transformers import DistilBertForSequenceClassification

In [21]:
# Load the sequence-classification model from the local checkpoint path.
hf_model = DistilBertForSequenceClassification.from_pretrained(
    official_model_name,
    torch_dtype=dtype,
    # Forward a token only when one is configured. The trailing comma matters:
    # without it `None **kwargs` is parsed as a power expression, which raises
    # a TypeError for an empty token and silently drops `kwargs` otherwise.
    token=huggingface_token if len(huggingface_token) > 0 else None,
    **kwargs,
)

In [22]:
# Cell output: the module tree of the loaded Hugging Face model.
hf_model

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [None]:
# Build a TransformerLens HookedEncoder from cfg.
# NOTE(review): none of the visible cells load pretrained weights into it -- confirm before trusting its outputs.
hooked_model = HookedEncoder(cfg)

In [None]:
from transformer_lens import loading_from_pretrained as loading

# Ask TransformerLens for the pretrained state dict of the checkpoint.
# NOTE(review): presumably this converts the Hugging Face weights to the
# TransformerLens naming scheme -- confirm that DistilBERT is supported.
state_dict = loading.get_pretrained_state_dict(checkpoint, cfg)

# Show name -> shape rather than the raw tensors so the output stays readable
# and can be compared against the hooked model's parameters.
{name: tuple(weights.shape) for name, weights in state_dict.items()}

In [None]:
# Parameter names and shapes of the hooked model; dumping the full tensors
# would flood the notebook output.
{name: tuple(tensor.shape) for name, tensor in hooked_model.state_dict().items()}

In [None]:
# Tokenise a single sentence, padded or truncated to exactly 512 positions,
# and move the resulting PyTorch tensors to the selected device.
encoding = tokenizer(
    'I am in love with it.',
    max_length=512,
    truncation=True,
    padding='max_length',
    return_tensors='pt',
)
inputs = encoding.to(device)

In [None]:
import torch.nn.functional as F

In [None]:
# Forward pass through the hooked encoder. The prompt is padded to the full
# context length, so the tokenizer's attention mask is passed along; without
# it the padding positions would take part in attention.
logits = hooked_model(
    inputs['input_ids'],
    one_zero_attention_mask=inputs['attention_mask'],
)
logits

# Sources

1. [Ground truth - Arena::Indirect Object Identification](https://arena-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification)
2. [Interpretability in the wild: A circuit for indirect object identification in GPT-2 small, by Wang, K. et al.](https://arxiv.org/pdf/2211.00593)
3. [NOTEBOOK - Exploratory Analysis Demo, by Neel Nanda](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=WXktSe0CvBdh)
4. [An analogy for understanding transformers, by Callum McDougall](https://www.lesswrong.com/posts/euam65XjigaCJQkcN/an-analogy-for-understanding-transformers)