In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import math
import os
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from transformer_lens import HookedEncoder, HookedTransformerConfig
from transformer_lens import loading_from_pretrained as loading
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

In [3]:
# Remember the directory the kernel started in, so that re-running this cell
# always lands in the same parent directory instead of climbing one level
# higher on every execution.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = Path.cwd()
os.chdir(_NOTEBOOK_START_DIR.parent)
print('Changed working directory to parent directory')

# Expose the locally stored Hugging Face token through the HF_TOKEN environment variable.
with open(os.path.expanduser('~/.huggingface/token')) as f:
    os.environ['HF_TOKEN'] = f.read().strip()
    # Only the first three characters are echoed, to confirm the token was read.
    # The outer quotes are double quotes on purpose: reusing the same quote
    # character inside an f-string is a SyntaxError before Python 3.12.
    print(f"Hugging Face token loaded: {os.environ['HF_TOKEN'][:3]}...")

# The notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Changed working directory to parent directory
Hugging Face token loaded: hf_...
Using device: cpu


# 1. Download weights from Hugging Face

In [4]:
# Hugging Face repo id of the model under study. The same string is also used
# below as a relative local path when checking for / storing the downloaded files.
checkpoint = 'distilbert/distilbert-base-uncased-finetuned-sst-2-english'

In [10]:
# Download the repo snapshot only when no local copy exists at the relative
# path that mirrors the repo id (resolved against the current working directory).
# NOTE(review): `local_dir_use_symlinks` may be deprecated in recent
# huggingface_hub releases — confirm against the installed version.
checkpoint_dir = Path(checkpoint)
if not checkpoint_dir.exists():
    print(f'Downloading {checkpoint}...')
    local_dir = checkpoint_dir.resolve()
    snapshot_download(
        repo_id=checkpoint,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
    )
    print(f'Downloaded {checkpoint}.')

Downloading distilbert/distilbert-base-uncased-finetuned-sst-2-english...


Fetching 17 files:   0%|          | 0/17 [00:00<?, ?it/s]Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Fetching 17 files:  24%|██▎       | 4/17 [00:00<00:01,  7.45it/s]Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Fetching 17 files:  29%|██▉       | 5/17 [00:00<00:01,  6.86it/s]Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP downloa

Downloaded distilbert/distilbert-base-uncased-finetuned-sst-2-english.





# 2. Model setup

In [11]:
# Tokenizer that belongs to the checkpoint.
tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)

# Reference Hugging Face classifier: load it first, then move it to the selected device.
model = DistilBertForSequenceClassification.from_pretrained(checkpoint)
model = model.to(device)

In [12]:
# Translate the Hugging Face model config into a TransformerLens config.
hf_config = model.config
head_dim = hf_config.dim // hf_config.n_heads   # per-head width, reused for d_head and attn_scale

cfg = HookedTransformerConfig(
  # --- identity ---
  model_name=checkpoint,
  tokenizer_name=checkpoint,
  device=device,
  # --- dimensions, read from the Hugging Face config ---
  d_model=hf_config.dim,
  d_head=head_dim,
  n_heads=hf_config.n_heads,
  n_layers=hf_config.n_layers,
  n_ctx=hf_config.max_position_embeddings,
  d_mlp=hf_config.hidden_dim,
  d_vocab=hf_config.vocab_size,
  d_vocab_out=len(hf_config.id2label),          # one output per classification label
  # --- numerics ---
  act_fn=hf_config.activation,
  eps=1e-12,                                    # TODO confirm: author's unverified value for the LayerNorm epsilon
  use_attn_scale=True,                          # scaled attention as in "Attention is All You Need" -- may need adjusting
  attn_scale=math.sqrt(head_dim),               # author notes this equals the library default
  # --- optional hook points, left off by the author to save memory (may be tested later) ---
  use_hook_tokens=False,
  use_attn_result=False,
  use_split_qkv_input=False,
  use_hook_mlp_in=False,
  use_attn_in=False,
  ungroup_grouped_query_attention=False,        # TODO confirm: marked by the author as a guess
)

In [13]:
# Build a TransformerLens HookedEncoder from `cfg`.
# NOTE(review): nothing in the notebook up to this point copies the pretrained
# checkpoint weights into this model — confirm they are loaded before its
# outputs are interpreted.
hooked_model = HookedEncoder(cfg)

Moving model to device:  cpu


In [16]:
# Ask TransformerLens to convert the checkpoint's Hugging Face weights into its own naming.
# NOTE(review): the recorded run of this cell failed with
#   "ValueError: Loading weights from the architecture is not currently supported: None"
# i.e. `cfg` reached the loader without an architecture it knows how to convert.
# Presumably DistilBERT needs a hand-written weight mapping — confirm before
# relying on `hooked_model`. The error is deliberately left to propagate so that
# later cells cannot silently run on unloaded weights.
state_dict = loading.get_pretrained_state_dict(checkpoint, cfg)

# Show parameter names and shapes only; echoing the raw tensors floods the notebook output.
{name: tuple(weights.shape) for name, weights in state_dict.items()}

You are using a model of type distilbert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Some weights of BertForPreTraining were not initialized from the model checkpoint at /Users/marcosf/Desktop/research/mech_interp/distilbert/distilbert-base-uncased-finetuned-sst-2-english and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.word_embeddings.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.atten

ValueError: Loading weights from the architecture is not currently supported: None, generated from model name distilbert/distilbert-base-uncased-finetuned-sst-2-english. Feel free to open an issue on GitHub to request this feature.

In [None]:
# Inspect the hooked model's parameters as name -> shape rather than dumping
# every tensor into the cell output.
{name: tuple(weights.shape) for name, weights in hooked_model.state_dict().items()}

In [None]:
# Tokenize a single positive-sentiment probe sentence.
# The sequence is padded/truncated to the model's context length, taken from
# `cfg.n_ctx` (itself read from the model config) instead of a hard-coded 512.
inputs = tokenizer(
  'I am in love with it.',
  return_tensors='pt',
  padding='max_length',
  truncation=True,
  max_length=cfg.n_ctx,
).to(device)

In [None]:
import torch.nn.functional as F

In [None]:
# The prompt was padded to max_length, so the tokenizer's attention mask must be
# passed along; otherwise the encoder attends to the padding positions as if
# they were real tokens.
logits = hooked_model(
  inputs['input_ids'],
  one_zero_attention_mask=inputs['attention_mask'],
)
logits

# Sources

1. [Ground truth - Arena::Indirect Object Identification](https://arena-chapter1-transformer-interp.streamlit.app/[1.4.1]_Indirect_Object_Identification)
2. [Interpretability in the wild: A circuit for indirect object identification in GPT-2 small, by Wang, K., et al.](https://arxiv.org/pdf/2211.00593)
3. [NOTEBOOK - Exploratory Analysis Demo, by Neel Nanda](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=WXktSe0CvBdh)
4. [An analogy for understanding transformers, by Callum McDougall](https://www.lesswrong.com/posts/euam65XjigaCJQkcN/an-analogy-for-understanding-transformers)