# Notebook for project IEinAI course

authors: Ana Ivaschescu, Luc Straeter, Eric Zila

A large part of this code comes from Neel Nanda's notebook in the original repository that we forked.


## Setup (Don't Read This)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install transformer_lens==1.2.1
    %pip install git+https://github.com/neelnanda-io/neel-plotly

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/neel-plotly
  Cloning https://github.com/neelnanda-io/neel-plotly to /tmp/pip-req-build-kio47e7u
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly /tmp/pip-req-build-kio47e7u
  Resolved https://github.com/neelnanda-io/neel-plotly to commit 6dc24b26f8dec991908479d7445dae496b3430b7
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab:
# "notebook_connected" is only used for local development outside Colab,
# every other combination falls back to the "colab" renderer.
import plotly.io as pio

pio.renderers.default = (
    "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
)
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f6b56e0fb50>

Plotting helper functions:

In [6]:
from neel_plotly import line, scatter, imshow, histogram

## Loading the model

This loads a conversion of the author's synthetic model checkpoint to TransformerLens format. See [this notebook](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Othello_GPT.ipynb) for how.

In [7]:
# Architecture of the Othello-GPT model. The weights are filled in afterwards by
# loading the checkpoint's state dict into `model`.
import transformer_lens.utils as utils
cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,  # 60 playable squares + 1 token for "pass"
    n_ctx = 59,  # the model receives the first 59 moves and predicts the final 59
    act_fn="gelu",
    normalization_type="LNPre"
)
model = HookedTransformer(cfg)

In [8]:

# Fetch the synthetic-model weights (already converted to TransformerLens format)
# from the Hugging Face hub and load them into `model`.
sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)

<All keys matched successfully>

Code to load and convert one of the author's checkpoints to TransformerLens:

Testing code for the synthetic checkpoint giving the correct outputs

In [9]:
# An example input
sample_input = torch.tensor([[20, 19, 18, 10, 2, 1, 27, 3, 41, 42, 34, 12, 4, 40, 11, 29, 43, 13, 48, 56, 33, 39, 22, 44, 24, 5, 46, 6, 32, 36, 51, 58, 52, 60, 21, 53, 26, 31, 37, 9, 25, 38, 23, 50, 45, 17, 47, 28, 35, 30, 54, 16, 59, 49, 57, 14, 15, 55, 7]])
# The expected argmax of the output (ie the most likely next move from each position)
sample_output = torch.tensor([[21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5,
         33,  5, 52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59,
         50, 28, 14, 28, 28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15,
         14, 15,  8,  7,  8]])
sample_prediction = model(sample_input).argmax(dim=-1)
# Actually compare against the reference instead of relying on eyeballing the cell output.
# The prediction lives on the model's device, the reference on the CPU.
assert torch.equal(sample_prediction.cpu(), sample_output), "Loaded checkpoint does not reproduce the reference predictions"
sample_prediction

tensor([[21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5,
         33,  5, 52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59,
         50, 28, 14, 28, 28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15,
         14, 15,  8,  7,  8]], device='cuda:0')

## Loading Othello Content
Boring setup code to load in 100K sample Othello games, the linear probe, and some utility functions

In [10]:

if IN_COLAB:
    #!git clone https://github.com/likenneth/othello_world
    !git clone https://github.com/zilaeric/othello-gpt-probing
    OTHELLO_ROOT = Path("/content/othello-gpt-probing/")
    import sys
    sys.path.append(str(OTHELLO_ROOT/"mechanistic_interpretability"))
    from mech_interp_othello_utils import plot_single_board, to_string, to_int, int_to_label, string_to_label, OthelloBoardState
else:
    OTHELLO_ROOT = Path("/workspace/othello_world/")
    from tl_othello_utils import plot_single_board, to_string, to_int, int_to_label, string_to_label, OthelloBoardState


fatal: destination path 'othello-gpt-probing' already exists and is not an empty directory.


We load in a big tensor of 100,000 games, each with 60 moves. This is in the format the model wants, with 1-60 representing the 60 moves, and 0 representing pass.

We also load in the same set of games, in the same order, but in "string" format - still a tensor of ints, but referring to moves with numbers from 0 to 63 rather than in the model's compressed format of 1 to 60.

In [11]:
# The same games in two encodings: "int" is the model's token format,
# "string" refers to board squares directly.
board_seqs_int = torch.tensor(
    np.load(OTHELLO_ROOT/"mechanistic_interpretability/board_seqs_int_small.npy"),
    dtype=torch.long,
)
board_seqs_string = torch.tensor(
    np.load(OTHELLO_ROOT/"mechanistic_interpretability/board_seqs_string_small.npy"),
    dtype=torch.long,
)

# Rows are games, columns are the successive moves of each game
num_games, length_of_game = board_seqs_int.shape
print("Number of games:", num_games)
print("Length of game:", length_of_game)

Number of games: 100000
Length of game: 60


In [12]:
# Board squares (0-63, row-major) that have a move token: every square except
# the four centre ones (27, 28, 35, 36), which can never be played.
stoi_indices = [square for square in range(64) if square not in (27, 28, 35, 36)]
alpha = "ABCDEFGH"


def to_board_label(i):
    """Return the label of square index `i`: row letter A-H followed by column digit 0-7."""
    row_letter = alpha[i // 8]
    col_digit = i % 8
    return f"{row_letter}{col_digit}"


# Labels of the playable squares, in the same order as `stoi_indices`
board_labels = [to_board_label(square) for square in stoi_indices]

## Running the Model

The model's context length is 59, not 60, because it's trained to receive the first 59 moves and predict the final 59 moves (ie `[0:-1]` and `[1:]`). Let's run the model on the first 30 moves of game 0!

In [13]:
# The first 30 moves of game 0, in the model's token format
moves_int = board_seqs_int[0, :30]

# This is implicitly converted to a batch of size 1
logits = model(moves_int)
# One logit vector per input position: (batch, position, d_vocab)
print("logits:", logits.shape)

logits: torch.Size([1, 30, 61])


We take the final vector of logits. We convert it to log probs and we then remove the first element (corresponding to passing, and we've filtered out all games with passing) and get the 60 logits. This is 64-4 because the model's vocab is compressed, since the center 4 squares can't be played.

We then convert it to an 8 x 8 grid and plot it, with some tensor magic

In [14]:
# Logits predicted after the last move of the (single) sequence in the batch
logit_vec = logits[0, -1]
log_probs = logit_vec.log_softmax(-1)
# Remove passing (token 0), leaving one log prob per playable square
log_probs = log_probs[1:]
assert len(log_probs)==60

# Start every cell at -13, a very negative log prob - the four middle cells have no
# token and are never overwritten, so this means they don't show up as mattering
temp_board_state = torch.full((64,), -13., device=logit_vec.device)
temp_board_state[stoi_indices] = log_probs

We can now plot this as a board state! We see a crisp distinction from a set of moves that the model clearly thinks are valid (at near uniform probabilities), and a bunch that aren't. Note that by training the model to predict a *uniformly* chosen next move, we incentivise it to be careful about making all valid logits be uniform!

In [15]:
def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Takes a square input (8 by 8) and plots it as a board. Can do a stack of boards via facet_col=0.

    Args:
        state: the 8 x 8 values to draw (or a stack of such boards, see above).
        diverging_scale: if True, use the "RdBu" colour scale centred on 0;
            otherwise use the "Blues" scale with no midpoint.
        **kwargs: forwarded unchanged to `imshow` (e.g. title, zmin, zmax, facet_col).

    Returns:
        Whatever `imshow` returns, so that callers which assign the result
        (`fig = plot_square_as_board(...)`) receive it instead of always getting None.
    """
    # The two modes differ only in their colour-scale settings
    if diverging_scale:
        scale_kwargs = dict(color_continuous_scale="RdBu", color_continuous_midpoint=0.)
    else:
        scale_kwargs = dict(color_continuous_scale="Blues", color_continuous_midpoint=None)
    return imshow(
        state,
        y=[i for i in alpha],
        x=[str(i) for i in range(8)],
        aspect="equal",
        **scale_kwargs,
        **kwargs,
    )
# Trailing ";" keeps the notebook from echoing the return value as cell output
plot_square_as_board(temp_board_state.reshape(8, 8), zmax=0, diverging_scale=False, title="Example Log Probs");

## Exploring Game Play

For comparison, let's plot the board state after 30 moves, and the valid moves

`plot_single_board` is a helper function to plot the board state after a series of moves. It takes in moves in the label format, ie A0 to H7, so we convert. We can see on inspection that our model has correctly identified the valid moves!

In [16]:
plot_single_board(int_to_label(moves_int))

We can also compute this explicitly, using the OthelloBoardState class (thanks to Kenneth Li for writing this one and saving me a bunch of tedious effort!)

In [17]:
# Replay the same 30 moves on an explicit board to obtain the ground-truth state
board = OthelloBoardState()
# to_string presumably maps the model's token ids back to square indices 0-63 - confirm in the utils module
board.update(to_string(moves_int))
# NOTE: plot_square_as_board draws the figure itself; `fig` only holds whatever it returns
fig = plot_square_as_board(board.state, title="Example Board State (+1 is Black, -1 is White)")



And we can get out a list of valid moves:

In [18]:
print("Valid moves:", string_to_label(board.get_valid_moves()))

Valid moves: ['A3', 'A5', 'A6', 'B2', 'C7', 'D2', 'E6', 'F7', 'G6', 'H2', 'H3', 'H4', 'H6']


## Making some utilities

At this point, I'll stop and get some aggregate data that will be useful later - a tensor of valid moves, of board states, and a cache of all model activations across 50 games (in practice, you want as much as will comfortably fit into GPU memory). It's really convenient to have the ability to quickly run an experiment across a bunch of games! And one of the great things about small models on algorithmic tasks is that you just can do stuff like this.

For lack of a more creative name, let's call these the **focus games**.

In [19]:
# NOTE: this rebinds `num_games`, which previously held the number of games in the full dataset
num_games = 50
# The first `num_games` games, in both encodings (model token format and "string" format)
focus_games_int = board_seqs_int[:num_games]
focus_games_string = board_seqs_string[:num_games]

A big stack of each move's board state and a big stack of the valid moves in each game (one hot encoded to be a nice tensor)

In [20]:
def one_hot(list_of_ints, num_classes=64):
    """Return a float32 vector of length `num_classes` that is 1 at every index in `list_of_ints` and 0 elsewhere."""
    encoded = torch.zeros(num_classes, dtype=torch.float32)
    encoded[list_of_ints] = 1.0
    return encoded
# Replay every focus game and record, after each of its 60 moves:
#   focus_states      - the 8 x 8 board
#   focus_valid_moves - the moves that are valid in that position, one-hot over the 64 squares
focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = torch.zeros((num_games, 60, 64), dtype=torch.float32)
for i in range(num_games):
    board = OthelloBoardState()
    j = 0
    while j < 60:
        # Play move j of game i, then snapshot the resulting position
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())
        j += 1
    # Leave `j` at the last move index, as a plain `for j in range(60)` would
    j = 59

# Show the shapes and the first five positions of game 0
print("focus states:", focus_states.shape)
print(focus_states[0][:5])
print("focus_valid_moves", focus_valid_moves.shape)
print(focus_valid_moves[0][:5])


focus states: (50, 60, 8, 8)
[[[ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  1.  0.  0.  0.  0.]
  [ 0.  0.  0.  1.  1.  0.  0.  0.]
  [ 0.  0.  0.  1. -1.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.]]

 [[ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  1. -1.  0.  0.  0.]
  [ 0.  0.  0.  1. -1.  0.  0.  0.]
  [ 0.  0.  0.  1. -1.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.]]

 [[ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  1. -1.  0.  0.  0.]
  [ 0.  0.  0.  1.  1.  1.  0.  0.]
  [ 0.  0.  0.  1. -1.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.]]

 [[ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  1. -1.  0. -1

A cache of every model activation and the logits

In [21]:

# Run the model on the first 59 moves of every focus game (the final move is dropped to fit
# the context length), keeping both the logits and a cache of the intermediate activations.
# NOTE: `.cuda()` means this cell requires a GPU.
focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].cuda())
# Lists the names of all cached activations
print(focus_cache)
# A (name, layer) pair selects a single activation, here the residual stream after layer 6
focus_cache['resid_post', 6].shape


ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_re

torch.Size([50, 59, 512])

## Use probe Ana, Eric, and Luc


In [22]:
# One probe per (layer, activation) pair; names follow "<layer>_hook_<activation>"
probe_names = [
    "4_hook_resid_post", "6_hook_resid_post",
    "4_hook_resid_mid", "6_hook_resid_mid",
    "4_hook_attn_out", "6_hook_attn_out",
    "4_hook_mlp_out", "6_hook_mlp_out",
]

# Layout constants - identical for every probe, so they are defined once up front
rows = 8
cols = 8
options = 3
# Leading index of the stored probe: which player is about to move
black_to_play_index = 0
white_to_play_index = 1
# Last index of the probe we build: state of a square relative to the player to move
blank_index = 0
their_index = 1
my_index = 2

probe_dict = {}
for name in probe_names:
    # NOTE: torch.load unpickles the file - only load probes from a trusted source.
    # The stored probe is presumably (mode, d_model, row, col, option) - confirm against the training code.
    full_linear_probe = torch.load(OTHELLO_ROOT/f"probes/full_linear_probe_{name}.pth")
    black_mode = full_linear_probe[black_to_play_index]
    white_mode = full_linear_probe[white_to_play_index]

    # Average the two modes into a single probe. Option 0 is averaged as-is; options 1 and 2
    # are swapped for the white-to-play mode so they line up as "theirs" / "mine".
    linear_probe = torch.zeros(cfg.d_model, rows, cols, options, device="cuda")
    linear_probe[..., blank_index] = 0.5 * (black_mode[..., 0] + white_mode[..., 0])
    linear_probe[..., their_index] = 0.5 * (black_mode[..., 1] + white_mode[..., 2])
    linear_probe[..., my_index] = 0.5 * (black_mode[..., 2] + white_mode[..., 1])
    probe_dict[name] = linear_probe


In [23]:
def plot_probe_outputs(linear_probe, layer, game_index, move, **kwargs):
    residual_stream = focus_cache["resid_post", layer][game_index, move]
    # print("residual_stream", residual_stream.shape)
    probe_out = einops.einsum(residual_stream, linear_probe, "d_model, d_model row col options -> row col options")
    probabilities = probe_out.softmax(dim=-1)
    plot_square_as_board(probabilities, facet_col=2, facet_labels=["P(Empty)", "P(Their's)", "P(Mine)"], height=400, width=600, **kwargs)

In [24]:
game_index = 0
move = 20

plot_single_board(int_to_label(focus_games_int[game_index, :move+1]))

for name in probe_names:
  layer = int(name[0])
  plot_probe_outputs(probe_dict[name], layer, game_index, move, title=name)



In [25]:
def state_stack_to_one_hot(state_stack):
    """One-hot encode a stack of boards indexed (game, move, row, col).

    Returns an int tensor of shape (num games, num moves, 8, 8, 3) on the same
    device as `state_stack`. The last axis marks, in order, whether the cell
    equals 0 (empty), -1 (white) or 1 (black).
    """
    n_games = state_stack.shape[0]
    n_moves = state_stack.shape[1]
    encoded = torch.zeros(
        n_games, n_moves, 8, 8, 3,  # games, moves, rows, cols, the three options
        device=state_stack.device,
        dtype=torch.int,
    )
    # Channel k is set wherever the cell holds the k-th of these values
    for channel, cell_value in enumerate((0, -1, 1)):
        encoded[..., channel] = state_stack == cell_value

    return encoded

# Express the board states in terms of my (+1) and their (-1) pieces:
# positions with an even move index are multiplied by -1, odd ones are left as they are
alternating = np.where(np.arange(focus_games_int.shape[1]) % 2 == 0, -1, 1)
flipped_focus_states = focus_states * alternating[None, :, None, None]

# One-hot encode the flipped boards ...
focus_states_flipped_one_hot = state_stack_to_one_hot(torch.tensor(flipped_focus_states))

# ... and collapse them to a single class index per square
focus_states_flipped_value = focus_states_flipped_one_hot.argmax(dim=-1)

In [26]:
for name in probe_names:
    # Names look like "<layer>_hook_<activation>", e.g. "6_hook_resid_post". Splitting on the
    # separator (instead of fixed character offsets) also works for layers with several digits.
    layer_str, probe_name = name.split("_hook_", 1)
    layer = int(layer_str)
    probe_out = einops.einsum(focus_cache[probe_name, layer], probe_dict[name], "game move d_model, d_model row col options -> game move row col options")
    probe_out_value = probe_out.argmax(dim=-1)

    # The cache was built from moves [:-1], so drop the last position of the targets to line them up.
    # Computed once and sliced below, rather than recomputed per slice.
    correct_answers = probe_out_value.cpu() == focus_states_flipped_value[:, :-1]
    # Skip the first and last 5 positions; the "odd" variant keeps every other remaining position
    correct_middle_odd_answers = correct_answers[:, 5:-5:2]
    accuracies_odd = einops.reduce(correct_middle_odd_answers.float(), "game move row col -> row col", "mean")
    correct_middle_answers = correct_answers[:, 5:-5]
    accuracies = einops.reduce(correct_middle_answers.float(), "game move row col -> row col", "mean")

    # The plotted quantity is the fraction of correct predictions, so label it as accuracy
    plot_square_as_board(torch.stack([accuracies], dim=0), title="Average Accuracy of Linear Probe", facet_col=0, facet_labels=[name], zmax=1, zmin=0)

In [27]:
# Position under study: game 0 after its move with index 20
pos = 20
game_index = 0
moves = focus_games_string[game_index, :pos+1]
plot_single_board(moves)
# The model's log probs at this position laid out over all 64 squares; squares that are
# not in stoi_indices keep the filler value -10. NOTE: requires a GPU (device="cuda").
state = torch.zeros((64,), dtype=torch.float32, device="cuda") - 10.
state[stoi_indices] = focus_logits[game_index, pos].log_softmax(dim=-1)[1:]



# Row and column of the cell whose colour gets flipped
cell_r = 5
cell_c = 4
print(f"Flipping the color of cell {'ABCDEFGH'[cell_r]}{cell_c}")

# Ground truth: replay the moves, then flip the sign of that one cell on a copy of the board
board = OthelloBoardState()
board.update(moves.tolist())
# Snapshot of the state before the flip (not used further in this cell)
board_state = board.state.copy()
valid_moves = board.get_valid_moves()
flipped_board = copy.deepcopy(board)
flipped_board.state[cell_r, cell_c] *= -1
flipped_valid_moves = flipped_board.get_valid_moves()

# Moves whose validity differs between the original and the flipped board, as labels
newly_legal = [string_to_label(move) for move in flipped_valid_moves if move not in valid_moves]
newly_illegal = [string_to_label(move) for move in valid_moves if move not in flipped_valid_moves]
print("newly_legal", newly_legal)
print("newly_illegal", newly_illegal)

Flipping the color of cell F4
newly_legal ['D2']
newly_illegal ['G4']


In [28]:

for name in probe_names:
    # Names look like "<layer>_hook_<activation>", e.g. "4_hook_mlp_out"
    layer_str, probe_name = name.split("_hook_", 1)
    layer = int(layer_str)
    blank_probe = probe_dict[name][..., 0] - probe_dict[name][..., 1] * 0.5 - probe_dict[name][..., 2] * 0.5
    my_probe = probe_dict[name][..., 2] - probe_dict[name][..., 1]

    # Direction in activation space that separates "mine" from "theirs" for the chosen cell
    flip_dir = my_probe[:, cell_r, cell_c]
    flip_norm = flip_dir.norm()  # loop-invariant, so computed once

    big_flipped_states_list = []
    scales = [4,8,8,8]
    for scale in scales:
        def flip_hook(resid, hook):
            # Component of the activation along the (normalised) flip direction ...
            coeff = resid[0, pos] @ flip_dir/flip_norm
            # ... removed and then pushed `scale` times as far in the opposite direction (in place)
            resid[0, pos] -= (scale+1) * coeff * flip_dir/flip_norm
        # Log probs at `pos` with the intervention applied (already normalised here, so no
        # second log_softmax is needed below)
        flipped_logits = model.run_with_hooks(focus_games_int[game_index:game_index+1, :pos+1],
                            fwd_hooks=[(f"blocks.{layer}.hook_{probe_name}", flip_hook)]).log_softmax(dim=-1)[0, pos]
        flip_state = torch.zeros((64,), dtype=torch.float32, device="cuda") - 10.
        flip_state[stoi_indices] = flipped_logits[1:]
        big_flipped_states_list.append(flip_state)
    flip_state_big = torch.stack(big_flipped_states_list)
    # One copy of the unmodified log probs per scale (was hard-coded to 4)
    state_big = einops.repeat(state, "d -> b d", b=len(scales))
    # Colour code: 1 = newly legal, -1 = newly illegal, 0.2 = everything else
    color = torch.zeros((len(scales), 64)).cuda() + 0.2
    for s in newly_legal:
        color[:, to_string(s)] = 1
    for s in newly_illegal:
        color[:, to_string(s)] = -1
    # One facet per scale, so provide one label per scale (previously only two labels for four facets)
    scatter(y=state_big, x=flip_state_big, title=f"Original vs Flipped {string_to_label(8*cell_r+cell_c)} at Layer {layer}", xaxis="Flipped", yaxis="Original", hover=[f"{r}{c}" for r in "ABCDEFGH" for c in range(8)], facet_col=0, facet_labels=[f"{name} (scale {s})" for s in scales], color=color, color_name="Newly Legal", color_continuous_scale="Geyser")