Neel Nanda's TransformerLens library (https://github.com/neelnanda-io/TransformerLens) is great. It has far more functionality, is properly maintained and has much more efficient code. I use it and I love it, and you should too.

So why did I create this? TransformerLens works great if you're happy to use one of the models the library supports. If you have another model, perhaps one you've trained yourself, one that isn't a transformer-based LLM, or one that just happens not to be supported by TransformerLens yet, then it's not that easy to work with.

I put this together as a minimally invasive library that you can wrap around an existing model while keeping interpretability relatively easy and straightforward. It aims to be flexible (it can work with any PyTorch model) and easy to use. It doesn't currently have much in the way of functionality, but it's a platform that can be built on. The price of that flexibility is that the setup is a bit more involved than simply working inside a notebook.

This repo contains support for Andrej Karpathy's minGPT (https://github.com/karpathy/minGPT) as an example, but the idea is that you can hook this up to any model. The setup is roughly as follows:

1. Clone the repository, install requirements and copy the code for your model into a new file in the repo
2. In 'module.py' create a child class of 'BaseModule' and define custom dataloader, forward and loss functions for your model
3. Inside your model definition import the 'HookPoint' class from 'utils.py' and place hook modules at locations in your model that you wish to target
4. Open up a notebook and start hacking away
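
To make step 3 a little more concrete, here is a minimal sketch of the general idea behind a hook module: an identity layer that records whatever passes through it under one or more tags. This is an assumption about the pattern, not the code in 'utils.py'. The `store` argument, the `TinyModel` class and the "hidden" tag are all hypothetical names made up for illustration.

```python
import torch
import torch.nn as nn


class HookPoint(nn.Module):
    """Hypothetical stand-in for a hook module: passes its input through
    unchanged and saves a copy under each of its tags."""

    def __init__(self, tags, store):
        super().__init__()
        self.tags = list(tags)
        self.store = store  # shared dict mapping tag -> list of saved tensors

    def forward(self, x):
        for tag in self.tags:
            # detach + clone so the saved copy is independent of the graph
            self.store.setdefault(tag, []).append(x.detach().clone())
        return x  # identity: the model's behaviour is unchanged


class TinyModel(nn.Module):
    """Toy model with one hook placed after the hidden layer."""

    def __init__(self, store):
        super().__init__()
        self.linear = nn.Linear(4, 3)
        self.hook = HookPoint(tags=["hidden"], store=store)  # in init method
        self.out = nn.Linear(3, 2)

    def forward(self, x):
        h = torch.relu(self.linear(x))
        h = self.hook(h)  # in forward method
        return self.out(h)


store = {}
model = TinyModel(store)
y = model(torch.randn(5, 4))
hidden = store["hidden"][0]  # the activation recorded during the forward pass
```

Because the hook is an identity, adding or removing it does not change the model's outputs, which is what keeps the approach minimally invasive.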

In [1]:
from interpretability.module import GPTModule
from interpretability.utils import show_attention_head

In [2]:
# Create an instance of GPTModule. The dataloader, forward and backward passes, and loss function are already defined.
model = GPTModule(batch_size=1, gpt_version="gpt2")

loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True
number of parameters: 123.65M


In [3]:
# Run a forward pass. 
# Batch size defaults to 1. If inputs and targets are set to None then the module will iterate over the dataset defined
# in the dataloader method. You do not need to pass targets when passing custom inputs if you only want to do a forward pass.
# All (tagged) activations and parameters are saved to 'model.data', as are inputs, outputs and targets for that particular batch.

model.forward(inputs=None, targets=None)            

In [4]:
# Optionally run a backward pass.
# If you have defined a custom loss function and want to store gradients, you can run a backward pass.
# If you pass a target tensor to the backward method, it will override any target tensor passed during
# the forward method.
model.backward(targets=None)
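
For readers wondering how gradients at an intermediate activation can be stored at all, here is a rough sketch in plain PyTorch using `Tensor.register_hook`. It is not taken from this repo and may differ from what `model.backward` does internally. The layers, the `grads` dict and the "hidden" key are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer1 = nn.Linear(4, 3)
layer2 = nn.Linear(3, 2)
grads = {}  # hypothetical store mapping a tag to the gradient at that point

x = torch.randn(5, 4)
targets = torch.randint(0, 2, (5,))

# Forward pass, registering a hook on the intermediate activation.
h = torch.relu(layer1(x))
# The hook runs during backward and receives d(loss)/d(h).
# It returns None, so the gradient itself is left unmodified.
h.register_hook(lambda g: grads.__setitem__("hidden", g.detach().clone()))

logits = layer2(h)
loss = nn.functional.cross_entropy(logits, targets)

# Backward pass: parameter gradients land in `.grad`,
# the activation gradient lands in `grads`.
loss.backward()
```

After this runs, `grads["hidden"]` has the same shape as `h`, and `layer1.weight.grad` holds the usual parameter gradient.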

In [5]:
# Some helper methods are defined to make it easier to access specific activations. The tag name corresponds 
# to the tag name specified in the relevant HookPoint in the model definition. A reminder that everything is 
# stored inside 'model.data' if you can't find what you're looking for.

# `get_activation_values_by_tag` will return a list of all activations with the corresponding tag in 
# the order they appear in the model.

attn = model.get_activation_values_by_tag("attention")

# In order to save out activations and access them you need to define HookPoints within your model definition to
# specify which activations you want to save. Inside your model definition:

# from utils import HookPoint
# self.hook = HookPoint(tags=["attention"])        # in init method
# x = self.hook(x)                                 # in forward method 

In [6]:
# A few helper functions are defined in 'utils.py' to perform common interpretability analyses. There's pretty limited
# functionality at the moment but I hope to add more as I go. Of course, you can just hack away and do whatever you 
# want at this point.

# One such function is `show_attention_head`:
show_attention_head(attn, layer=0, head=0)
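
If you'd rather pull out a single head's attention pattern yourself, the slicing might look something like the sketch below. It assumes the activations come back as a list with one tensor per layer, each shaped (batch, n_heads, seq_len, seq_len). That shape is an assumption, so check it against what your own hook actually saves. The `select_head` helper and the random causal patterns are hypothetical stand-ins, not part of this repo.

```python
import torch

torch.manual_seed(0)

# Build fake causal attention patterns standing in for saved activations.
n_layers, batch, n_heads, seq_len = 2, 1, 4, 6
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

attn = []
for _ in range(n_layers):
    scores = torch.randn(batch, n_heads, seq_len, seq_len)
    # Mask out future positions before the softmax, as a causal model would.
    scores = scores.masked_fill(~causal_mask, float("-inf"))
    attn.append(torch.softmax(scores, dim=-1))


def select_head(attn, layer, head, batch_index=0):
    """Hypothetical helper: return the (seq_len, seq_len) pattern for one head."""
    return attn[layer][batch_index, head]


pattern = select_head(attn, layer=0, head=0)

# Sanity checks worth running on real data too: each query row should sum
# to 1, and a causal model should put no weight on future positions.
row_sums = pattern.sum(dim=-1)
future_weight = pattern.triu(diagonal=1).abs().sum()
```

The resulting 2D tensor can be passed to any heatmap plotting function you like.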