In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import os
from pathlib import Path

import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import transformer_lens

In [3]:
# Run from the project root (the notebook's parent directory). The notebook's
# own directory is remembered on first execution so that re-running this cell
# does not climb one level higher each time.
if '_NOTEBOOK_DIR' not in globals():
    _NOTEBOOK_DIR = Path.cwd()
os.chdir(_NOTEBOOK_DIR.parent)
print('Changed working directory to parent directory')

# Export the Hugging Face token when a non-empty token file is present;
# otherwise leave the environment untouched so the rest of the notebook can
# still proceed without one.
_token_file = Path(os.path.expanduser('~/.huggingface/token'))
_hf_token = _token_file.read_text().strip() if _token_file.is_file() else ''
if _hf_token:
    os.environ['HF_TOKEN'] = _hf_token
    # Only a short prefix is echoed so the secret never lands in saved output.
    # The token is read from a local variable to avoid nesting identical quotes
    # inside the f-string, which is a SyntaxError before Python 3.12.
    print(f'Hugging Face token loaded: {_hf_token[:3]}...')
else:
    print('No Hugging Face token found at ~/.huggingface/token; HF_TOKEN left unchanged')

# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Device preference order: MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

Changed working directory to parent directory
Hugging Face token loaded: hf_...
Using device: mps


In [4]:
# Identifier of the DistilBERT SST-2 sentiment checkpoint. It is used both as
# the model name and, in the model-loading cell, as a path relative to the
# current working directory.
checkpoint = 'distilbert/distilbert-base-uncased-finetuned-sst-2-english'

### 2. Download model from Hugging Face

In [5]:
# Load the tokenizer that matches the checkpoint.
# NOTE(review): presumably resolves to the local copy when a directory named
# like `checkpoint` exists under the working directory — confirm.
tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)

In [6]:
# Prefer a local copy of the checkpoint (relative to the working directory)
# and fall back to the Hub identifier when that directory does not exist.
_local_checkpoint = Path(checkpoint).resolve()
hf_model = DistilBertForSequenceClassification.from_pretrained(
    str(_local_checkpoint) if _local_checkpoint.is_dir() else checkpoint,
    torch_dtype=torch.float32,
    # An unset or empty HF_TOKEN becomes None ("no explicit token") instead
    # of an empty string, which is not a valid credential.
    token=os.environ.get('HF_TOKEN') or None,
)

In [7]:
# Display the configuration of the loaded model.
hf_model.config

DistilBertConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "/Users/marcosf/Desktop/research/mech_interp/distilbert/distilbert-base-uncased-finetuned-sst-2-english",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "finetuning_task": "sst-2",
  "hidden_dim": 3072,
  "id2label": {
    "0": "NEGATIVE",
    "1": "POSITIVE"
  },
  "initializer_range": 0.02,
  "label2id": {
    "NEGATIVE": 0,
    "POSITIVE": 1
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.49.0",
  "vocab_size": 30522
}

In [8]:
# Display the module tree of the loaded model.
hf_model

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


### 3. Hooked model

In [9]:
# Build a TransformerLens HookedEncoder for the checkpoint, handing it the
# already-loaded Hugging Face model and the device selected during setup.
model = transformer_lens.HookedEncoder.from_pretrained(
    checkpoint, 
    hf_model=hf_model,
    device=device
)

If using BERT for interpretability research, keep in mind that BERT has some significant architectural differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.


Moving model to device:  mps
Loaded pretrained model distilbert/distilbert-base-uncased-finetuned-sst-2-english into HookedEncoder


In [10]:
# Tokenize one example sentence, padded/truncated to a fixed 512-token window,
# and move the resulting tensors to the selected device.
inputs = tokenizer(
    'I love this movie!',
    max_length=512,
    padding='max_length',
    truncation=True,
    return_tensors='pt',
).to(device)

In [11]:
# Score the single tokenized sentence and report its most likely label.
# The no_grad context is kept so the cell stays inference-only even if
# autograd has been re-enabled elsewhere in the session.
with torch.no_grad():
    output = model(**inputs)
    # Normalize the logits over the last (class) axis.
    probs = F.softmax(output.logits, dim=-1)
    # Reduce over the class axis only. A bare argmax() indexes the flattened
    # tensor and would silently return a wrong class id for more than one
    # input row; with dim=-1, .item() raises instead in that case.
    predicted_class_id = probs.argmax(dim=-1).item()
    confidence_score = probs[0, predicted_class_id].item()
    predicted_label = hf_model.config.id2label[predicted_class_id]

print({'label': predicted_label, 'score': confidence_score})

{'label': 'POSITIVE', 'score': 0.999690055847168}


# Sources

1. [Ground truth - Arena::Indirect Object Identification](https://arena-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification)
2. [Interpretability in the wild: A circuit for indirect object identification in GPT-2 small, by Wang, K., et al.](https://arxiv.org/pdf/2211.00593)
3. [NOTEBOOK - Exploratory Analysis Demo, by Neel Nanda](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=WXktSe0CvBdh)
4. [An analogy for understanding transformers, by Callum McDougall](https://www.lesswrong.com/posts/euam65XjigaCJQkcN/an-analogy-for-understanding-transformers)