In [1]:
import warnings
# Silence every warning category for the rest of the session to keep cell output compact.
# NOTE(review): this also hides deprecation warnings raised by the libraries used below — confirm that is intended.
warnings.filterwarnings('ignore')

In [2]:
import math
import os
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from transformer_lens import HookedTransformerConfig
from transformer_lens.loading_from_pretrained import get_pretrained_state_dict

In [3]:
# Move to the parent of the directory the kernel started in. The starting
# directory is remembered on the first run so that re-running this cell does
# not climb one level further each time.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = Path.cwd()
os.chdir(NOTEBOOK_DIR.parent)
print('Changed working directory to parent directory')

# Expose the locally stored Hugging Face token through the HF_TOKEN
# environment variable, which a later cell reads back.
with open(os.path.expanduser('~/.huggingface/token')) as f:
    hf_token = f.read().strip()
os.environ['HF_TOKEN'] = hf_token
# Only a short prefix is echoed so the secret does not end up in saved output.
# The slice is taken from a local variable because reusing single quotes
# inside a single-quoted f-string is a SyntaxError before Python 3.12.
print(f'Hugging Face token loaded: {hf_token[:3]}...')

# No gradients are needed in this notebook, so autograd is disabled globally.
torch.set_grad_enabled(False)
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Changed working directory to parent directory
Hugging Face token loaded: hf_...
Using device: cpu


In [4]:
# Hugging Face repo id of the model under study. The same string is also used
# as a relative path for the local snapshot directory in the cells below.
checkpoint = 'distilbert/distilbert-base-uncased-finetuned-sst-2-english'

### 1. Download weights from Hugging Face

In [5]:
# Fetch the repository snapshot only when nothing exists yet at the relative
# path named after the repo id; files are requested as real copies, not symlinks.
checkpoint_dir = Path(checkpoint)
if not checkpoint_dir.exists():
    print(f'Downloading {checkpoint}...')
    snapshot_download(
        repo_id=checkpoint,
        local_dir=checkpoint_dir.resolve(),
        local_dir_use_symlinks=False,
    )
    print(f'Downloaded {checkpoint}.')

Downloading distilbert/distilbert-base-uncased-finetuned-sst-2-english...


Fetching 17 files: 100%|██████████| 17/17 [01:46<00:00,  6.26s/it]

Downloaded distilbert/distilbert-base-uncased-finetuned-sst-2-english.





### 2. Load model from the downloaded checkpoint

In [None]:
# Load the sequence classifier from the local snapshot directory in float32.
hf_model = DistilBertForSequenceClassification.from_pretrained(
    Path(checkpoint).resolve(),
    torch_dtype=torch.float32,
    # Fall back to None rather than '' when HF_TOKEN is unset: an empty string
    # is still an explicitly supplied (and invalid) token value, whereas None
    # means "no explicit token".
    token=os.environ.get('HF_TOKEN'),
)

In [None]:
# Display the Hugging Face config; its fields (dim, n_heads, hidden_dim, ...)
# are the values copied into the HookedTransformerConfig further down.
hf_model.config

In [None]:
# Display the module tree of the loaded model for reference.
hf_model

### 3. Hooked `state_dict`

In [None]:
# Mirror the Hugging Face model's hyperparameters in a TransformerLens config.
hf_cfg = hf_model.config
d_head = hf_cfg.dim // hf_cfg.n_heads  # per-head width: model width split evenly across heads

cfg = HookedTransformerConfig(
    # --- identity ---
    model_name=checkpoint,
    tokenizer_name=checkpoint,
    original_architecture=hf_cfg.architectures[0],
    # --- dimensions ---
    d_model=hf_cfg.dim,
    d_head=d_head,
    d_mlp=hf_cfg.hidden_dim,
    n_ctx=hf_cfg.max_position_embeddings,
    n_heads=hf_cfg.n_heads,
    n_layers=hf_cfg.n_layers,
    d_vocab=hf_cfg.vocab_size,
    d_vocab_out=len(hf_cfg.id2label),  # one output per classification label
    # --- numerics ---
    act_fn=hf_cfg.activation,
    eps=1e-12,  # NOTE(review): hard-coded; presumably the model's LayerNorm epsilon — confirm
    attn_scale=math.sqrt(d_head),  # sqrt(d_head); noted by the author as the library default
    device=device,
    # --- optional hook points: all off for now (memory intensive, may be tested later) ---
    use_hook_tokens=False,
    use_attn_result=False,
    use_split_qkv_input=False,
    use_hook_mlp_in=False,
    use_attn_in=False,
    # Left disabled by the author:
    # use_attn_scale=True,                      # as in Attention is All You Need -- may need to be adjusted
    # ungroup_grouped_query_attention=False,    # wild guess
)

In [None]:
# Obtain the pretrained state dict for `cfg` from TransformerLens' loader.
# The already loaded `hf_model` is handed over so the loader does not have to
# instantiate a second copy of the model on its own.
# NOTE(review): confirm the loader supports cfg.original_architecture for this checkpoint.
state_dict = get_pretrained_state_dict(checkpoint, cfg, hf_model=hf_model)

### 4. Hooked `Encoder`

In [None]:
# Load the tokenizer that belongs to the same checkpoint as the model.
tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)

In [None]:
# TODO: placeholder — the hooked model has not been constructed yet, so this is
# just a string and cannot be called.
hooked_model = '?'

In [None]:
# TODO: disabled smoke test. Re-enable once `hooked_model` is a real model
# instead of the placeholder string; it tokenizes one sentence padded to 512
# tokens and feeds the input ids through the hooked model.
# inputs = tokenizer(
#   'I am in love with it.', 
#   return_tensors='pt', 
#   padding='max_length', 
#   truncation=True,
#   max_length=512
# ).to(device)

# logits = hooked_model(inputs['input_ids'])
# logits

# Sources

1. [Ground truth - Arena::Indirect Object Identification](https://arena-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification)
2. [Interpretability in the wild: A circuit for indirect object identification in GPT-2 small, by Wang, K., et al.](https://arxiv.org/pdf/2211.00593)
3. [NOTEBOOK - Exploratory Analysis Demo, by Neel Nanda](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=WXktSe0CvBdh)
4. [An analogy for understanding transformers, by Callum McDougall](https://www.lesswrong.com/posts/euam65XjigaCJQkcN/an-analogy-for-understanding-transformers)