In this notebook we will explore several language model interpretability methods using Easy-Transformer (https://github.com/neelnanda-io/Easy-Transformer), a neural network library Neel Nanda wrote for this purpose.

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import plotly.io as pio
pio.renderers.default

'plotly_mimetype+notebook'

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm

import random
import time

from pathlib import Path
import pickle
import os

import matplotlib.pyplot as plt

import plotly.express as px
import plotly.graph_objects as go

from torch.utils.data import DataLoader

from functools import *
import pandas as pd
import gc
import collections
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets

In [4]:
from easy_transformer.utils import (
    gelu_new,
    to_numpy,
    get_corner,
    lm_cross_entropy_loss,
)  # Helper functions
from easy_transformer.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from easy_transformer import EasyTransformer, EasyTransformerConfig
import easy_transformer
from easy_transformer.experiments import (
    ExperimentMetric,
    AblationConfig,
    EasyAblation,
    EasyPatching,
    PatchingConfig
)

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is present

## Hook Points
In their Garcon write-up (https://transformer-circuits.pub/2021/garcon/index.html), the good team at Anthropic released a walkthrough of their internal tool, Garcon. Easy-Transformer follows the style of that library by defining a `HookPoint` class: a layer that can wrap any activation within the model. A HookPoint acts as an identity function, but lets us attach PyTorch hooks to access and edit the relevant activation. By wrapping every interesting activation in a HookPoint, we can insert access points throughout any model.
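The following is a rough illustration of the idea, not Easy-Transformer's actual implementation. It shows that an identity module combined with PyTorch's built-in `register_forward_hook` is enough to read an activation without changing the forward pass. The class name `SimpleHookPoint` and the `cache` dictionary are hypothetical.

```python
import torch
import torch.nn as nn


class SimpleHookPoint(nn.Module):
    """Hypothetical stand-in for a HookPoint: an identity layer that hooks attach to."""

    def forward(self, x):
        # Identity: the value passes through unchanged unless a hook replaces it.
        return x


hook_point = SimpleHookPoint()
cache = {}


def save_activation(module, inputs, output):
    # Standard PyTorch forward-hook signature: (module, inputs, output).
    # Returning None leaves the output unchanged.
    cache["activation"] = output.detach().clone()


handle = hook_point.register_forward_hook(save_activation)
x = torch.tensor([1.0, 2.0, 3.0])
y = hook_point(x)
handle.remove()  # detach the hook so later calls are unaffected

print(torch.equal(y, x))  # True: the hook point is an identity
print(cache["activation"])  # tensor([1., 2., 3.])
```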

There is also a `HookedRootModule` class. This is a utility class that the root module (the model we run) should inherit from, and it provides several utility functions for working with hooks.

The default interface is the `run_with_hooks` function on the root module. It runs a forward pass on the model with a list of hooks attached, where each hook is paired with the name of the hook point it should run on.

The syntax for a hook is `function(activation, hook)`, where `activation` is the activation the hook point wraps, and `hook` is the `HookPoint` instance the function is attached to. There are three cases:

- If the function returns a new activation, that replaces the old one.
- If it edits the activation _in place_, the edit persists.
- If it returns None without editing anything, the activation remains as is.
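To make the three cases concrete, here is a hedged sketch that adapts PyTorch's forward hooks to the `(activation, hook)` signature described above. `NamedHookPoint` and `add_hook` are hypothetical names. Easy-Transformer's own `HookPoint` handles this wrapping internally, possibly in a different way.

```python
import torch
import torch.nn as nn


class NamedHookPoint(nn.Module):
    """Hypothetical identity layer whose hooks use the (activation, hook) signature."""

    def __init__(self, name):
        super().__init__()
        self.name = name

    def forward(self, x):
        # Identity: returns the very same tensor object it receives.
        return x

    def add_hook(self, fn):
        # Adapt PyTorch's (module, inputs, output) hook signature to (activation, hook).
        # PyTorch uses a returned tensor as the new output; None keeps the original.
        return self.register_forward_hook(lambda module, inputs, output: fn(output, module))


def replace_hook(activation, hook):
    return torch.zeros_like(activation)  # the returned tensor replaces the activation


def inplace_hook(activation, hook):
    activation.mul_(10)  # edits in place and returns None; the edit still takes effect


def read_only_hook(activation, hook):
    print(f"{hook.name} saw {activation.tolist()}")  # returns None, activation unchanged


hp = NamedHookPoint("hook_demo")
results = {}
for fn in (replace_hook, inplace_hook, read_only_hook):
    handle = hp.add_hook(fn)
    results[fn.__name__] = hp(torch.tensor([1.0, 2.0])).tolist()
    handle.remove()

print(results)
# {'replace_hook': [0.0, 0.0], 'inplace_hook': [10.0, 20.0], 'read_only_hook': [1.0, 2.0]}
```

Because this hook point is an identity that returns its input object, an in-place edit also mutates the tensor the caller passed in.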

### HookPoints Example
Here's a simple example of how to use the classes:

We define a basic network with two layers, each of which takes a scalar input, squares it, and adds a constant: $x_0 = x$, $x_1 = x_0^2 + 3$, $x_2 = x_1^2 - 4$.
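As a quick check of these formulas, take an input of $x = 2$. The network computes the values below. The two `hook_square` points would see $4$ and $49$, `hook_mid` would see $7$, and `hook_out` would see $45$.

$$
\begin{aligned}
x_0 &= 2,\\
x_1 &= x_0^2 + 3 = 4 + 3 = 7,\\
x_2 &= x_1^2 - 4 = 49 - 4 = 45.
\end{aligned}
$$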
We wrap the input, each layer's output, and the intermediate value of each layer (the square) in a hook point.

In [7]:
from easy_transformer.hook_points import HookedRootModule, HookPoint

In [10]:
class SquareAdd(nn.Module):
    def __init__(self, offset):
        super().__init__()
        self.offset = nn.Parameter(torch.tensor(offset))
        self.hook_square = HookPoint()
    
    def forward(self, x):
        # The hook_square doesn't change the value, but lets us access it
        square = self.hook_square(x * x)
        return self.offset + square
    

class TwoLayerModel(HookedRootModule):
    def __init__(self):
        super().__init__()
        self.layer1 = SquareAdd(3.0)
        self.layer2 = SquareAdd(-4.0)
        self.hook_in = HookPoint()
        self.hook_mid = HookPoint()
        self.hook_out = HookPoint()
        
        # We need to call the setup function of HookedRootModule to build an
        # internal dictionary of modules and hooks, and to give each hook a name.
        super().setup()
    
    def forward(self, x):
        # We wrap the input and each layer's output in a hook - they leave the
        # value unchanged (unless there's a hook added to explicitly change it),
        # but allow us to access it.
        x_in = self.hook_in(x)
        x_mid = self.hook_mid(self.layer1(x_in))
        x_out = self.hook_out(self.layer2(x_mid))
        return x_out

model = TwoLayerModel()
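The sketch below previews the names that `setup()` is expected to assign. It mirrors `TwoLayerModel` in plain PyTorch, with a hypothetical `IdentityHook` module standing in for `HookPoint`, and lists the hook points via `named_modules()`. We assume, without checking the library source here, that `HookedRootModule.setup()` names hook points by these same dotted module paths.

```python
import torch
import torch.nn as nn


class IdentityHook(nn.Module):
    """Hypothetical stand-in for HookPoint: passes values through unchanged."""

    def forward(self, x):
        return x


class SquareAddMirror(nn.Module):
    def __init__(self, offset):
        super().__init__()
        self.offset = nn.Parameter(torch.tensor(offset))
        self.hook_square = IdentityHook()

    def forward(self, x):
        return self.offset + self.hook_square(x * x)


class TwoLayerMirror(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = SquareAddMirror(3.0)
        self.layer2 = SquareAddMirror(-4.0)
        self.hook_in = IdentityHook()
        self.hook_mid = IdentityHook()
        self.hook_out = IdentityHook()

    def forward(self, x):
        x_mid = self.hook_mid(self.layer1(self.hook_in(x)))
        return self.hook_out(self.layer2(x_mid))


mirror = TwoLayerMirror()
# named_modules() yields dotted paths in registration order, recursing into submodules.
hook_names = [name for name, m in mirror.named_modules() if isinstance(m, IdentityHook)]
print(hook_names)
# ['layer1.hook_square', 'layer2.hook_square', 'hook_in', 'hook_mid', 'hook_out']
print(mirror(torch.tensor(2.0)).item())  # 45.0
```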

We can add a cache to save the