# Introduction

This notebook is based on this [colab notebook](https://neelnanda.io/othello-notebook) by Neel Nanda.



## Setup (Don't Read This)

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install transformer_lens==1.2.1
    %pip install git+https://github.com/neelnanda-io/neel-plotly
    %pip install circuitsvis
    
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook

In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [3]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy
import circuitsvis as cv

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [4]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fc240535d90>

Plotting helper functions:

In [6]:
from neel_plotly import line, scatter, imshow, histogram

# Othello GPT

<details><summary>I was in a massive rush when I made this codebase, so it's a bit of a mess, sorry! This Colab is an attempt to put a smiley face on the underlying shoggoth, but if you want more of the guts, here's an info dump</summary>

This codebase is a bit of a mess! This colab is an attempt to be a pretty mask on top of the shoggoth, but if it helps, here's an info dump I wrote for someone about interpreting this codebase:

Technical details:

-   Games are 60 moves, but the model can only take in 59. It's trained to predict the next move, so they give it the first 59 moves (0<=...<59) and evaluate the predictions for each next move (1<=...<60). There is no Beginning of Sequence token, and the model never tries to predict the first move of the game

-   This means that, while in Othello black plays first, here white plays "first" because first is actually second

-   You can get code to load the synthetic model (ie trained to play uniformly random legal moves) into TransformerLens here: [https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Othello\_GPT.ipynb](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Othello_GPT.ipynb) 
-   You can load in their synthetically generated games [from their Github](https://github.com/likenneth/othello_world) (there's a google drive link)
-   Their model has 8 layers, residual stream width 512, 8 heads per layer and 2048 neurons per layer.

-   The vocab size is 61. 0 is -100, which I *think* means pass, I just filtered out the rare games that include that move and ignore it. 1 to 60 (inclusive) means the board moves in lexicographic order (A0, A1, ..., A7, B0, ...) but *skipping* D3, D4, E3 and E4. These are at the center of the board and so can never be played, because Othello starts with them filled.

-   There are 3 ways to denote a board cell. I call them "int", "string" and "label" (which is terrible notation, sorry).

-   "label" means the label for a board cell, \["A0", ..., "A7", ..., "H7"\] (I index at 0 not 1, sorry!).
-   "int" means "part of the model vocabulary", so 1 means A0, we then count up but miss the center squares, so 27 is D2, 28 is D5, 33 is E2 and 34 is E5. 
-   "string" means "the input format of the OthelloBoardState class". These are integers (sorry!) from 0 to 63, and exactly correspond to labels A0, ..., H7, without skipping any center cells. OthelloBoardState is a class in data/othello.py that can play the Othello game and tell you the board state and valid moves (created by the authors, not me)
-   I have utility functions to\_int, to\_string, str\_to\_label and int\_to\_label in tl\_othello\_utils.py to do this

-   The embedding and unembedding are untied (ie, in contrast to most language models, the map W\_U from final residual to the logits is *not* the transpose of W\_E, the map from tokens to the initial residual. They're unrelated matrices)
-   tl\_othello\_utils.py is my utils file, with various functions to load games, etc. \`board\_seqs\_string\` and \`board\_seqs\_int\` are massive saved tensors with every move across all 4.5M synthetic games in both string and int format, these are 2.3GB so I haven't attached them lol. You can recreate them from the synthetic games they provide. It also provides a bunch of plotting functions to make nice othello board states, and some random other utilities
-   \`tl\_probing.py\` is my probe training file. But it was used to train a *second* probe, linear\_probe\_L4\_blank\_vs\_color\_v1.pth . This probe actually didn't work very well for analysing the model (despite getting great accuracy) and I don't know why - it was trained on layer 4, to do a binary classification on blank vs not blank, and on my color vs their color *conditional* on not being blank (ie not evaluated if blank). For some reason, the "this cell is my color" direction has a significant dot product with the "is blank" direction, and this makes it much worse for eg interpreting neurons. I don't know why!
-   \`tl\_scratch.py\` is where I did some initial exploration, including activation patching between different final moves
-   \`tl\_exploration.py\` is where I did my most recent exploration, verifying that the probe works, doing probe interventions (CTRL F for \`newly\_legal\`) and using the probe to interpret neurons

</details>
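The three cell formats described in the info dump above can be sketched in a few lines. This is a minimal illustration only: the `sketch_*` names are hypothetical and are not the notebook's own `to_int` / `to_string` helpers, which live in the utils file and may differ in detail.

```python
# Hypothetical re-implementation for illustration; not the notebook's utils.
ALPHA = "ABCDEFGH"
CENTER = {27, 28, 35, 36}  # D3, D4, E3, E4 in "string" (0..63) format

# "string" ids of the 60 playable cells, in vocabulary order
PLAYABLE = [s for s in range(64) if s not in CENTER]


def sketch_string_to_label(s):
    """0..63 -> 'A0'..'H7' (rows are letters, columns are digits)."""
    return f"{ALPHA[s // 8]}{s % 8}"


def sketch_int_to_string(i):
    """Vocabulary id 1..60 -> 0..63. Id 0 is reserved (pass) and not handled here."""
    return PLAYABLE[i - 1]


def sketch_string_to_int(s):
    """0..63 -> vocabulary id 1..60 (raises ValueError for a center cell)."""
    return PLAYABLE.index(s) + 1


for i in (1, 27, 28, 33, 34, 60):
    print(i, sketch_string_to_label(sketch_int_to_string(i)))
```

The loop prints the pairs quoted in the list above (27 is D2, 28 is D5, 33 is E2, 34 is E5), which is a quick way to check that the skipped center cells are handled correctly.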


## Loading the model

This loads a conversion of the authors' synthetic model checkpoint to TransformerLens format. See [this notebook](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Othello_GPT.ipynb) for how.

In [7]:
import transformer_lens.utils as utils
cfg = HookedTransformerConfig(
    n_layers = 8,
    d_model = 512,
    d_head = 64,
    n_heads = 8,
    d_mlp = 2048,
    d_vocab = 61,
    n_ctx = 59,
    act_fn="gelu",
    normalization_type="LNPre"
)
model = HookedTransformer(cfg)
model.use_attn_result = True

In [8]:

sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "synthetic_model.pth")
# champion_ship_sd = utils.download_file_from_hf("NeelNanda/Othello-GPT-Transformer-Lens", "championship_model.pth")
model.load_state_dict(sd)

Downloading synthetic_model.pth:   0%|          | 0.00/101M [00:00<?, ?B/s]

<All keys matched successfully>

## Loading Othello Content
Boring setup code to load in 100K sample Othello games, the linear probe, and some utility functions

In [9]:

if IN_COLAB:
    !git clone https://github.com/likenneth/othello_world
    OTHELLO_ROOT = Path("/content/othello_world/")
    import sys
    sys.path.append(str(OTHELLO_ROOT/"mechanistic_interpretability"))
    from mech_interp_othello_utils import plot_single_board, to_string, to_int, int_to_label, string_to_label, OthelloBoardState
else:
    OTHELLO_ROOT = Path("/workspace/othello_world/")
    from tl_othello_utils import plot_single_board, to_string, to_int, int_to_label, string_to_label, OthelloBoardState


Cloning into 'othello_world'...
remote: Enumerating objects: 66, done.
remote: Counting objects: 100% (23/23), done.
remote: Compressing objects: 100% (13/13), done.
remote: Total 66 (delta 13), reused 10 (delta 10), pack-reused 43
Unpacking objects: 100% (66/66), 10.01 MiB | 5.34 MiB/s, done.


We load in a big tensor of 100,000 games, each with 60 moves. This is in the format the model wants, with 1-60 representing the 60 playable cells, and 0 representing pass.

We also load in the same set of games, in the same order, but in "string" format - still a tensor of ints but referring to moves with numbers from 0 to 63 rather than in the model's compressed format of 1 to 60.

In [10]:
board_seqs_int = torch.tensor(np.load(OTHELLO_ROOT/"board_seqs_int_small.npy"), dtype=torch.long)
board_seqs_string = torch.tensor(np.load(OTHELLO_ROOT/"board_seqs_string_small.npy"), dtype=torch.long)

num_games, length_of_game = board_seqs_int.shape
print("Number of games:", num_games,)
print("Length of game:", length_of_game)

Number of games: 100000
Length of game: 60


In [11]:
stoi_indices = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
]
alpha = "ABCDEFGH"


def to_board_label(i):
    return f"{alpha[i//8]}{i%8}"


board_labels = list(map(to_board_label, stoi_indices))

## Running the Model

The model's context length is 59, not 60, because it's trained to receive the first 59 moves and predict the final 59 moves (ie `[0:-1]` and `[1:]`). Let's run the model on the first 30 moves of game 0!
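As a quick illustration of that input/target shift, here is a toy sketch using a plain Python list as a stand-in for one game. The token values are placeholders, not real moves.

```python
# Toy stand-in for one game: 60 move tokens (placeholder values, not real moves)
game = list(range(1, 61))

inputs = game[:-1]   # moves 0..58, what the model sees
targets = game[1:]   # moves 1..59, what it is trained to predict

assert len(inputs) == len(targets) == 59

# The prediction made at position t is scored against the move at position t + 1
for t in (0, 28, 58):
    print(f"position {t}: input token {inputs[t]} -> target token {targets[t]}")
```

The first move never appears as a target and the last move never appears as an input, which is why 59 positions are enough.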

In [12]:
moves_int = board_seqs_int[0, :30]

# This is implicitly converted to a batch of size 1
logits = model(moves_int)
print("logits:", logits.shape)

logits: torch.Size([1, 30, 61])


We take the final vector of logits. We convert it to log probs and then remove the first element (corresponding to passing, and we've filtered out all games with passing) to get the 60 log probs. This is 64-4 because the model's vocab is compressed, since the center 4 squares can't be played.

We then convert it to an 8 x 8 grid and plot it, with some tensor magic

In [13]:
logit_vec = logits[0, -1]
log_probs = logit_vec.log_softmax(-1)
# Remove passing
log_probs = log_probs[1:]
assert len(log_probs)==60

temp_board_state = torch.zeros(64, device=logit_vec.device)
# Set all cells to -13 by default, for a very negative log prob - this means the middle cells don't show up as mattering
temp_board_state -= 13.
temp_board_state[stoi_indices] = log_probs

We can now plot this as a board state! We see a crisp distinction from a set of moves that the model clearly thinks are valid (at near uniform probabilities), and a bunch that aren't. Note that by training the model to predict a *uniformly* chosen next move, we incentivise it to be careful about making all valid logits be uniform!
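To put a number on "near uniform": if the model spread its probability evenly over the legal moves and gave nothing to illegal ones, every legal move would have log prob `-log(n)`. A minimal sketch of that baseline, where `uniform_log_prob` is a hypothetical helper and not part of the notebook:

```python
import math


def uniform_log_prob(num_valid_moves):
    """Log prob of each legal move if probability mass is spread evenly over them."""
    return -math.log(num_valid_moves)


# 13 is the number of valid moves printed for this position later in the notebook
print(round(uniform_log_prob(13), 3))  # -2.565
```

Legal cells in the plot should therefore sit close to this value, with illegal cells far below it.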

In [14]:
def plot_square_as_board(state, diverging_scale=True, **kwargs):
    """Takes a square input (8 by 8) and plot it as a board. Can do a stack of boards via facet_col=0"""
    if diverging_scale:
        imshow(state, y=[i for i in alpha], x=[str(i) for i in range(8)], color_continuous_scale="RdBu", color_continuous_midpoint=0., aspect="equal", **kwargs)
    else:
        imshow(state, y=[i for i in alpha], x=[str(i) for i in range(8)], color_continuous_scale="Blues", color_continuous_midpoint=None, aspect="equal", **kwargs)
plot_square_as_board(temp_board_state.reshape(8, 8), zmax=0, diverging_scale=False, title="Example Log Probs")

## Exploring Game Play

For comparison, let's plot the board state after 30 moves, and the valid moves

`plot_single_board` is a helper function to plot the board state after a series of moves. It takes in moves in the label format, ie A0 to H7, so we convert. We can see on inspection that our model has correctly identified the valid moves!

We can also compute this explicitly, using the OthelloBoardState class (thanks to Kenneth Li for writing this one and saving me a bunch of tedious effort!)

In [15]:
board = OthelloBoardState()
board.update(to_string(moves_int))
plot_square_as_board(board.state, title="Example Board State (+1 is Black, -1 is White)")

And we can get out a list of valid moves:

In [16]:
print("Valid moves:", string_to_label(board.get_valid_moves()))

Valid moves: ['A3', 'A5', 'A6', 'B2', 'C7', 'D2', 'E6', 'F7', 'G6', 'H2', 'H3', 'H4', 'H6']


## Making some utilities

At this point, I'll stop and get some aggregate data that will be useful later - a tensor of valid moves, of board states, and a cache of all model activations across 50 games (in practice, you want as much as will comfortably fit into GPU memory). It's really convenient to have the ability to quickly run an experiment across a bunch of games! And one of the great things about small models on algorithmic tasks is that you just can do stuff like this. 

For want of creativity, let's call these the **focus games**

In [17]:
num_games = 50
focus_games_int = board_seqs_int[:num_games]
focus_games_string = board_seqs_string[:num_games]

A big stack of each move's board state and a big stack of the valid moves in each game (one hot encoded to be a nice tensor)

In [18]:
def one_hot(list_of_ints, num_classes=64):
    out = torch.zeros((num_classes,), dtype=torch.float32)
    out[list_of_ints] = 1.
    return out
focus_states = np.zeros((num_games, 60, 8, 8), dtype=np.float32)
focus_valid_moves = torch.zeros((num_games, 60, 64), dtype=torch.float32)
for i in (range(num_games)):
    board = OthelloBoardState()
    for j in range(60):
        board.umpire(focus_games_string[i, j].item())
        focus_states[i, j] = board.state
        focus_valid_moves[i, j] = one_hot(board.get_valid_moves())
print("focus states:", focus_states.shape)
print("focus_valid_moves", focus_valid_moves.shape)


focus states: (50, 60, 8, 8)
focus_valid_moves torch.Size([50, 60, 64])


A cache of every model activation and the logits

In [19]:
model.use_attn_result = True
focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1].cuda())


## Using the probe

The training of this probe was kind of a mess, and I'd do a bunch of things differently if doing it again.

<details><summary>Info dump of technical details:</summary>

mode==0 was trained on black to play, ie odd moves, and the classes are \[blank, white, black\] ie \[blank, their colour, my colour\] (I *think*, they could easily be the other way round. This should be easy to sanity check)

mode==1 was trained on white to play, ie even moves, and the classes are \[blank, black, white\] ie \[blank, their colour, my colour\] (I *think*)

mode==2 was trained on all moves, and just doesn't work very well.


The probes were trained on moves 5 to 54 (ie not the early or late moves, because these are weird). I literally did AdamW against cross-entropy loss for each board cell, nothing fancy. You really didn't need to train on 4M games lol, it plateaued well before the end. Which is to be expected, it's just logistic regression!

</details>

But it works!


Let's load in the probe. The shape is [modes, d_model, row, col, options]. The 3 modes are "black to play/odd moves", "white to play/even moves", and "all moves". The 3 options are empty, white and black in that order.

We'll just focus on the black to play probe - it basically just works for the even moves too, once you realise that it's detecting my colour vs their colour!

This means that the options are "empty", "theirs" and "mine" in that order.

In [20]:
full_linear_probe = torch.load(OTHELLO_ROOT/"main_linear_probe.pth")

On move 29 of game 1, we can apply the probe to the model's residual stream after layer 4 (the values set in the plotting cell below). Move 29 is black to play.

In [21]:
rows = 8
cols = 8 
options = 3
black_to_play_index = 0
white_to_play_index = 1
blank_index = 0
their_index = 1
my_index = 2
linear_probe = torch.zeros(cfg.d_model, rows, cols, options, device="cuda")
linear_probe[..., blank_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 0] + full_linear_probe[white_to_play_index, ..., 0])
linear_probe[..., their_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 1] + full_linear_probe[white_to_play_index, ..., 2])
linear_probe[..., my_index] = 0.5 * (full_linear_probe[black_to_play_index, ..., 2] + full_linear_probe[white_to_play_index, ..., 1])

In [22]:
layer = 4
game_index = 1
move = 29
def plot_probe_outputs(layer, game_index, move, **kwargs):
    residual_stream = focus_cache["resid_post", layer][game_index, move]
    print("residual_stream", residual_stream.shape)
    probe_out = einops.einsum(residual_stream, linear_probe, "d_model, d_model row col options -> row col options")
    probabilities = probe_out.softmax(dim=-1)
    plot_square_as_board(probabilities, facet_col=2, facet_labels=["P(Empty)", "P(Their's)", "P(Mine)"], height=400, width=600, **kwargs)
plot_probe_outputs(layer, game_index, move, title="Example probe outputs after move 29 (black to play)")

plot_single_board(int_to_label(focus_games_int[game_index, :move+1]))

residual_stream torch.Size([512])



I got the best results intervening after layer 4, but interestingly here the model has *almost* figured out the board state by then, but is missing the fact that C5 and C6 are black and is confident that they're white. My guess is that the board state calculation circuits haven't quite finished and are doing some iterative reasoning - if those cells have been taken several times, maybe it needs a layer to track the next earliest time it was taken? I don't know, and figuring this out would be a great starter project if you want to explore!

In [24]:
plot_probe_outputs(layer=4, game_index=game_index, move=move, title="Example probe outputs at layer 4 after move 29 (Black to play)")

residual_stream torch.Size([512])


We can now look at move 30, and we see that the representations totally flip! We also see that the model gets the corner wrong - it's not a big deal, but interesting!

In [25]:
plot_probe_outputs(layer=6, game_index=game_index, move=30, title="Example probe outputs at layer 6 after move 30 (White to play)")

plot_single_board(focus_games_string[game_index, :31])

residual_stream torch.Size([512])


Fascinatingly, the white to play probe gets the corner right! A fact about Othello is that a piece in the corners can never be flanked and thus will never change colour once placed - perhaps the model has decided to cut corners and have a different and less symmetric circuit for these?

In [26]:
residual_stream = focus_cache["resid_post", layer][game_index, 30]
white_to_play_probe = full_linear_probe[1]
probe_out = einops.einsum(residual_stream, white_to_play_probe, "d_model, d_model row col options -> row col options")
probabilities = probe_out.softmax(dim=-1)
plot_square_as_board(probabilities, facet_col=2, facet_labels=["P(Empty)", "P(White)", "P(Black)"], title="Probabilities after move 30 for the white to play probe - correct A7 label!")

### Computing Accuracy

Hopefully I've convinced you anecdotally that a linear probe works. But to be more rigorous, let's check accuracy across our 50 games.

In [27]:
def state_stack_to_one_hot(state_stack):
    one_hot = torch.zeros(
        state_stack.shape[0], # num games
        state_stack.shape[1], # num moves
        8, # rows
        8, # cols
        3, # the three options (empty, white, black)
        device=state_stack.device,
        dtype=torch.int,
    )
    one_hot[..., 0] = state_stack == 0 # empty
    one_hot[..., 1] = state_stack == -1 # white
    one_hot[..., 2] = state_stack == 1 # black
    
    return one_hot

# We first convert the board states to be in terms of my (+1) and their (-1)
alternating = np.array([-1 if i%2 == 0 else 1 for i in range(focus_games_int.shape[1])])
flipped_focus_states = focus_states * alternating[None, :, None, None]

# We now convert to one hot
focus_states_flipped_one_hot = state_stack_to_one_hot(torch.tensor(flipped_focus_states))

# Take the argmax
focus_states_flipped_value = focus_states_flipped_one_hot.argmax(dim=-1)

Apply the probe to the residual stream for every move! (Taken after layer 6)

In [28]:
probe_out = einops.einsum(focus_cache["resid_post", 6], linear_probe, "game move d_model, d_model row col options -> game move row col options")
probe_out_value = probe_out.argmax(dim=-1)


Taking the average accuracy across all games and the middle moves (5:-5), we see an extremely low error rate on black to play moves, and a fairly low error rate when applying a zero shot transfer to all moves (by flipping the labels). The error rate is worse near the corners!

In [29]:
correct_middle_odd_answers = (probe_out_value.cpu() == focus_states_flipped_value[:, :-1])[:, 5:-5:2]
accuracies_odd = einops.reduce(correct_middle_odd_answers.float(), "game move row col -> row col", "mean")
correct_middle_answers = (probe_out_value.cpu() == focus_states_flipped_value[:, :-1])[:, 5:-5]
accuracies = einops.reduce(correct_middle_answers.float(), "game move row col -> row col", "mean")

plot_square_as_board(1 - torch.stack([accuracies_odd, accuracies], dim=0), title="Average Error Rate of Linear Probe", facet_col=0, facet_labels=["Black to Play moves", "All Moves"], zmax=0.25, zmin=-0.25)

## Intervening with the probe

One of the really exciting consequences of a linear probe is that it gives us a set of interpretable directions in the residual stream! And with this, we can not only interpret the model's representations, but we can also intervene in the model's reasoning. This is a good proof of concept that if you can *really* understand a model, you can get precise and detailed control over its behaviour.

The first step is to convert our probe to meaningful directions. Each square's probe has 3 vectors, but the logits go into a softmax, which is translation invariant, so this only has two degrees of freedom. A natural-ish way to convert it into two vectors is taking `blank - (mine + theirs)/2`, giving an "is this cell empty or not" direction, and `mine - theirs`, giving a "conditional on not being blank, is this my colour vs theirs" direction.

Having a single meaningful direction is important, because it allows us to interpret a feature or intervene on it. The original three directions have one redundant degree of freedom (adding the same vector to all three leaves the softmax output unchanged), so each direction is arbitrary on its own.
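The translation invariance is easy to check numerically. The sketch below uses made-up logits for a single cell, ordered `[blank, theirs, mine]`, and a hand-rolled `softmax` so that it runs without torch. It illustrates the argument and is not code from the notebook.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a short list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy probe logits for one cell, ordered [blank, theirs, mine] (made-up values)
logits = [0.3, -1.2, 2.0]
shifted = [x + 5.0 for x in logits]  # add the same constant to all three

# The probabilities are unchanged by the shift
p, q = softmax(logits), softmax(shifted)
assert all(abs(a - b) < 1e-9 for a, b in zip(p, q))


def blank_vs_filled(l):
    return l[0] - 0.5 * (l[1] + l[2])


def mine_vs_theirs(l):
    return l[2] - l[1]


# Both summaries are unchanged by the shift, so they capture the two real degrees of freedom
print(blank_vs_filled(logits), blank_vs_filled(shifted))
print(mine_vs_theirs(logits), mine_vs_theirs(shifted))
```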

In [30]:
blank_probe = linear_probe[..., 0] - linear_probe[..., 1] * 0.5 - linear_probe[..., 2] * 0.5
my_probe = linear_probe[..., 2] - linear_probe[..., 1]

In [31]:
pos = 20
game_index = 0
moves = focus_games_string[game_index, :pos+1]
plot_single_board(moves)
state = torch.zeros((64,), dtype=torch.float32, device="cuda") - 10.
state[stoi_indices] = focus_logits[game_index, pos].log_softmax(dim=-1)[1:]


We now flip cell F4 from black to white. This makes D2 into a legal move and G4 into an illegal move.

In [32]:
cell_r = 5
cell_c = 4
print(f"Flipping the color of cell {'ABCDEFGH'[cell_r]}{cell_c}")

board = OthelloBoardState()
board.update(moves.tolist())
board_state = board.state.copy()
valid_moves = board.get_valid_moves()
flipped_board = copy.deepcopy(board)
flipped_board.state[cell_r, cell_c] *= -1
flipped_valid_moves = flipped_board.get_valid_moves()

newly_legal = [string_to_label(move) for move in flipped_valid_moves if move not in valid_moves]
newly_illegal = [string_to_label(move) for move in valid_moves if move not in flipped_valid_moves]
print("newly_legal", newly_legal)
print("newly_illegal", newly_illegal)


Flipping the color of cell F4
newly_legal ['D2']
newly_illegal ['G4']


We can now intervene on the model's residual stream using the "my colour vs their colour" direction. I get the best results intervening after layer 4. This is a **linear intervention** - we are just changing the residual stream's component along a single direction and keeping everything orthogonal to it unchanged. This is a fairly simple intervention, and it's striking that it works!

I apply the fairly janky technique of taking the current coordinate in the given direction, negating it, and then multiplying by a hyperparameter called `scale` (scale between 1 and 8 tends to work best - small isn't enough and big tends to break things). I haven't tried hard to optimise this and I'm sure it can be improved! Eg by replacing the model's coordinate by a constant rather than scaling it. I also haven't dug into the best scale parameters, or which ones work best in which contexts - plausibly different cells have different activation scales on their world models and need different behaviour!
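Concretely, subtracting `(scale + 1)` times the component along the unit direction sends a coordinate `c` to `-scale * c`, and leaves everything orthogonal to the direction alone. Here is a small pure-Python sketch of that arithmetic on a toy 3-dimensional vector. `flip_along_direction` is a hypothetical helper that mirrors the update used in the hook below, written with lists rather than tensors.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def flip_along_direction(resid, direction, scale):
    """Negate and rescale the component of `resid` along `direction`."""
    norm = math.sqrt(dot(direction, direction))
    unit = [d / norm for d in direction]
    coeff = dot(resid, unit)
    # Subtracting (scale + 1) copies of the component maps coeff -> -scale * coeff
    return [r - (scale + 1) * coeff * u for r, u in zip(resid, unit)]


resid = [1.0, 2.0, 3.0]      # toy residual vector
direction = [0.0, 0.0, 2.0]  # toy probe direction (not unit norm)

flipped = flip_along_direction(resid, direction, scale=2)
print(flipped)  # [1.0, 2.0, -6.0]
```

Note that `scale=0` just removes the component along the direction, so the first entry in the sweep below acts as an ablation rather than a flip.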

In [33]:
flip_dir = my_probe[:, cell_r, cell_c]

big_flipped_states_list = []
layer = 4
scales = [0, 1, 2, 4, 8, 16]

for scale in scales:

    def flip_hook(resid, hook):
        # print(resid.shape) # [1, 21, 512] # batch, pos, emb
        coeff = resid[0, pos] @ flip_dir/flip_dir.norm()
        print(coeff.shape)
        # if coeff.item() > 0:
        resid[0, pos] -= (scale+1) * coeff * flip_dir/flip_dir.norm()

    flipped_logits = model.run_with_hooks(focus_games_int[game_index:game_index+1, :pos+1],
                        fwd_hooks=[
                        #  ("blocks.3.hook_resid_post", flip_hook),
                        (f"blocks.{layer}.hook_resid_post", flip_hook),
                        #  ("blocks.5.hook_resid_post", flip_hook),
                        #  ("blocks.6.hook_resid_post", flip_hook),
                        #  ("blocks.7.hook_resid_post", flip_hook),
                        ]
                        ).log_softmax(dim=-1)[0, pos]

    flip_state = torch.zeros((64,), dtype=torch.float32, device="cuda") - 10.
    flip_state[stoi_indices] = flipped_logits.log_softmax(dim=-1)[1:]
    big_flipped_states_list.append(flip_state)
flip_state_big = torch.stack(big_flipped_states_list)
state_big = einops.repeat(state, "d -> b d", b=len(scales))
color = torch.zeros((len(scales), 64)).cuda() + 0.2
for s in newly_legal:
    color[:, to_string(s)] = 1
for s in newly_illegal:
    color[:, to_string(s)] = -1
scatter(y=state_big, x=flip_state_big, title=f"Original vs Flipped {string_to_label(8*cell_r+cell_c)} at Layer {layer}", xaxis="Flipped", yaxis="Original", hover=[f"{r}{c}" for r in "ABCDEFGH" for c in range(8)], facet_col=0, facet_labels=[f"Translate by {i}x" for i in scales], color=color, color_name="Newly Legal", color_continuous_scale="Geyser")

torch.Size([])
torch.Size([])
torch.Size([])
torch.Size([])
torch.Size([])
torch.Size([])


# My Space

## Section 1: Current move not blank

In [34]:
layer = 1

for stage in ['resid_pre', 'attn_out', 'resid_mid', 'mlp_out', 'resid_post']:
    print(f'Using {stage}')
    for layer in range(1):
        acc = []
        # cell = [[0 for _ in range(8)] for _ in range(8)]
        state = torch.zeros(8, 8)
        state_count = torch.zeros(8, 8)
        # state.flatten()[stoi_indices] = w_out @ model.W_U[:, 1:]
        # plot_square_as_board(state, title=f"Output weights of Neuron L{layer}N{neuron} in the output logit basis", width=600)

        for game_index in range(50):
            for move in range(3, 59):
                current_int = focus_games_int[game_index, move]
                current_string = focus_games_string[game_index, move]
                current_label = int_to_label(current_int)
                residual_stream = focus_cache[f"{stage}", layer][game_index, move]
                # print("residual_stream", residual_stream.shape)

                probe_out = einops.einsum(residual_stream, linear_probe, "d_model, d_model row col options -> row col options")
                probabilities = probe_out.softmax(dim=-1)
                probe_out_value = probabilities.argmax(dim=-1)
                probe_out_value = torch.clamp(probe_out_value, max=1)

                if probe_out_value.flatten()[current_string] != 0: # which means current cell not blank
                    success = 1
                else:
                    success = 0
                
                acc.append(success)
                state.flatten()[current_string] += success
                state_count.flatten()[current_string] += 1

        state /= state_count
        # plot_square_as_board(state)
        print(f'Layer: {layer} Accuracy: {sum(acc)/len(acc)}')
    print()


Using resid_pre
Layer: 0 Accuracy: 0.99

Using attn_out
Layer: 0 Accuracy: 0.855

Using resid_mid
Layer: 0 Accuracy: 0.9957142857142857

Using mlp_out
Layer: 0 Accuracy: 1.0

Using resid_post
Layer: 0 Accuracy: 1.0



In [35]:
blank_probe = linear_probe[..., 0] - linear_probe[..., 1] * 0.5 - linear_probe[..., 2] * 0.5
my_probe = linear_probe[..., 2] - linear_probe[..., 1]

# Scale the probes down to be unit norm per cell
blank_probe_normalised = blank_probe / blank_probe.norm(dim=0, keepdim=True)
my_probe_normalised = my_probe / my_probe.norm(dim=0, keepdim=True)
# Set the center blank probes to 0, since they're never blank so the probe is meaningless
blank_probe_normalised[:, [3, 3, 4, 4], [3, 4, 3, 4]] = 0.

W_E = model.W_E[1:, :] # tokens(61->60), emb
W_U = model.W_U[:, 1:].T # tokens(61->60), emb

# blank probe
blank_probe_normalised_60 = einops.rearrange(blank_probe_normalised, 'd_model row col -> d_model (row col)')
blank_probe_normalised_60 = blank_probe_normalised_60[:, stoi_indices]
blank_probe_normalised_60.shape

probe_out = einops.einsum(W_E, blank_probe_normalised_60, "token d_model, d_model token -> token")

state = torch.zeros(8, 8, device="cuda")
state.flatten()[stoi_indices] = probe_out
plot_square_as_board(state, title='Cosine Similarity of Blank Probe and W_E', zmax=0.8, zmin=-0.8, width=600)

probe_out = einops.einsum(W_U, blank_probe_normalised_60, "token d_model, d_model token -> token")

state = torch.zeros(8, 8, device="cuda")
state.flatten()[stoi_indices] = probe_out
plot_square_as_board(state, title='Cosine Similarity of Blank Probe and W_U', zmax=0.8, zmin=-0.8, width=600)

# my probe
my_probe_normalised_60 = einops.rearrange(my_probe_normalised, 'd_model row col -> d_model (row col)')
my_probe_normalised_60 = my_probe_normalised_60[:, stoi_indices]
my_probe_normalised_60.shape

probe_out = einops.einsum(W_E, my_probe_normalised_60, "token d_model, d_model token -> token")

state = torch.zeros(8, 8, device="cuda")
state.flatten()[stoi_indices] = probe_out
plot_square_as_board(state, title='Cosine Similarity of My Probe and W_E', zmax=0.8, zmin=-0.8, width=600)

probe_out = einops.einsum(W_U, my_probe_normalised_60, "token d_model, d_model token -> token")

state = torch.zeros(8, 8, device="cuda")
state.flatten()[stoi_indices] = probe_out
plot_square_as_board(state, title='Cosine Similarity of My Probe and W_U', zmax=0.8, zmin=-0.8, width=600)
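
A note on the plot titles above: as written, the cell normalises only the probe directions, so the plotted values are projections of `W_E` / `W_U` onto unit probe vectors. They are strict cosine similarities only if the embedding rows are also unit norm. The sketch below shows one way to normalise both sides. It uses NumPy with random stand-in arrays, and the helper name `cosine_similarity_per_token` is hypothetical, not part of the notebook.

```python
import numpy as np

def cosine_similarity_per_token(W, probe):
    """Cosine similarity between row i of W and column i of probe.

    W:     [n_tokens, d_model]  (stand-in for W_E or W_U.T)
    probe: [d_model, n_tokens]  (stand-in for one probe direction per cell)
    """
    # Normalise both sides so each dot product lies in [-1, 1]
    W_unit = W / np.linalg.norm(W, axis=1, keepdims=True)
    probe_unit = probe / np.linalg.norm(probe, axis=0, keepdims=True)
    # Only matching token/cell pairs are compared, as in the einsum above
    return np.einsum("td,dt->t", W_unit, probe_unit)

rng = np.random.default_rng(0)
W = rng.normal(size=(60, 512))       # random stand-in, not the real embedding
probe = rng.normal(size=(512, 60))   # random stand-in, not the real probe
sims = cosine_similarity_per_token(W, probe)
```

With both sides normalised, a fixed colour range such as `zmin=-0.8, zmax=0.8` has the same meaning for every cell.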

## Section 2: Blank/Not-Blank Board State

In [36]:
def state_stack_to_one_hot(state_stack):
    one_hot = torch.zeros(
        state_stack.shape[0], # num games
        state_stack.shape[1], # num moves
        8, # rows
        8, # cols
        3, # the three options: empty, white, black
        device=state_stack.device,
        dtype=torch.int,
    )
    one_hot[..., 0] = state_stack == 0 # empty
    one_hot[..., 1] = state_stack == -1 # white
    one_hot[..., 2] = state_stack == 1 # black
    
    return one_hot
    
# We first convert the board states to be in terms of my (+1) and their (-1)
alternating = np.array([-1 if i%2 == 0 else 1 for i in range(focus_games_int.shape[1])])
flipped_focus_states = focus_states * alternating[None, :, None, None]

# We now convert to one hot
focus_states_flipped_one_hot = state_stack_to_one_hot(torch.tensor(flipped_focus_states))

# Take the argmax
focus_states_flipped_value = focus_states_flipped_one_hot.argmax(dim=-1)
# clamp the maximum to 1 so the labels are binary: 0 means blank, 1 means not blank
focus_states_flipped_value = torch.clamp(focus_states_flipped_value, max=1)

n_layers = 4
act_name_dict = {
    'resid_pre': 'Residual Pre',
    'attn_out': 'Attn Out',
    'resid_mid': 'Residual Mid',
    'mlp_out': 'MLP out',
    'resid_post': 'Residual Post'
}

for act in ['resid_pre', 'attn_out', 'resid_mid', 'mlp_out', 'resid_post']:
    accuracies_lst = []
    for layer in range(n_layers):
        probe_out = einops.einsum(focus_cache[act, layer], linear_probe, "game move d_model, d_model row col options -> game move row col options")
        probe_out_value = probe_out.argmax(dim=-1)
        # set maximum to 1
        probe_out_value = torch.clamp(probe_out_value, max=1)

        correct_middle_answers = (probe_out_value.cpu() == focus_states_flipped_value[:, :-1])[:, 5:-5]
        accuracies = einops.reduce(correct_middle_answers.float(), "game move row col -> row col", "mean")
        accuracies_lst.append(accuracies)    

    act_name = act_name_dict[act]
    plot_square_as_board(torch.round(1 - torch.stack(accuracies_lst, dim=0), decimals=2), title=f"Average Error Rate of Blank Probe (Blank vs Not Blank) using {act_name}", facet_col=0, facet_labels=[f'Layer {l}' for l in range(n_layers)], yaxis=act_name, zmax=0.25, zmin=-0.25, text_auto=True)
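
The clamp step above turns the three-way probe output into a blank / not-blank label. A minimal sketch on toy data, assuming the class coding 0 = empty, 1 = theirs, 2 = mine used by the probe. NumPy's `np.minimum` stands in for `torch.clamp(..., max=1)`, and the arrays are made up for illustration.

```python
import numpy as np

# Toy labels with the class coding 0 = empty, 1 = theirs, 2 = mine
true_classes = np.array([[0, 1, 2],
                         [2, 0, 1]])
pred_classes = np.array([[0, 2, 2],
                         [1, 1, 1]])

# Clamp to a maximum of 1 so both occupied classes map to 1 ("not blank")
true_binary = np.minimum(true_classes, 1)
pred_binary = np.minimum(pred_classes, 1)

three_way_accuracy = (true_classes == pred_classes).mean()
binary_accuracy = (true_binary == pred_binary).mean()
```

Confusing "mine" with "theirs" counts as an error in the three-way score but not in the binary one, so the binary accuracy is never lower than the three-way accuracy.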


In [37]:
# Baseline: blindly guess that every cell is blank

# Take the argmax
focus_states_flipped_value = focus_states_flipped_one_hot.argmax(dim=-1)
focus_states_flipped_value = torch.clamp(focus_states_flipped_value, max=1)


# probe_out is only used for its shape here; `act` and `layer` are left over from the loop in the previous cell
probe_out = einops.einsum(focus_cache[act, layer], linear_probe, "game move d_model, d_model row col options -> game move row col options")
probe_out_value = probe_out.argmax(dim=-1)
# clamp the maximum to 0, i.e. predict every cell as blank
probe_out_value = torch.clamp(probe_out_value, max=0)

correct_middle_answers = (probe_out_value.cpu() == focus_states_flipped_value[:, :-1])[:, 5:-5]
accuracies = einops.reduce(correct_middle_answers.float(), "game move row col -> row col", "mean")

act_name = act_name_dict[act]
plot_square_as_board(torch.round(1 - torch.stack([accuracies], dim=0), decimals=2), title=f"Average Error Rate of Blank Probe (Blank vs Not Blank) Baseline", facet_col=0, zmax=0.25, zmin=-0.25, text_auto=True)
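
One way to read this baseline: if every cell is predicted blank, the per-cell error rate is simply the fraction of positions in which that cell is occupied. A small sketch on hypothetical random data (the occupancy probability of 0.3 is arbitrary):

```python
import numpy as np

# Hypothetical binary board states [game, move, row, col]; 1 = occupied, 0 = blank
rng = np.random.default_rng(0)
occupied = (rng.random((4, 10, 8, 8)) < 0.3).astype(int)

# The baseline predicts blank (0) everywhere
all_blank_guess = np.zeros_like(occupied)

# Per-cell error rate of the baseline, averaged over games and moves
error_rate = (all_blank_guess != occupied).mean(axis=(0, 1))
# Per-cell occupancy frequency
occupancy = occupied.mean(axis=(0, 1))
```

A probe is only informative for a cell to the extent that its error rate falls below this occupancy figure.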


## Section 3: Attention Pattern

Look at the attention pattern of each layer, averaged over the focus games and restricted to the first 10 positions.

In [38]:
attnD = {}

for layer in range(8):
    tmp_attn = torch.mean(focus_cache[f'blocks.{layer}.attn.hook_pattern'][:, :, :10, :10], dim=0)
    attnD[f'Layer: {layer}'] = cv.attention.attention_patterns(tokens=[''], attention=tmp_attn)

for key, html in attnD.items():
    attnD[key] = f'<div><span style="background-color:yellow;">{key}</span>' + str(html) + '</div>'


display(HTML(''.join(attnD.values())))
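
A quick check on what the averaged pattern means, sketched with random scores rather than the real cache. Assuming causal attention, the mean over games of row-stochastic patterns is still row-stochastic, and slicing to the first 10 positions keeps each row summing to 1, because those destinations can only attend to the first 10 sources. The `softmax` helper below is written out by hand for the sketch.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax; masked entries (-inf) become exactly 0
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n_games, n_heads, n_pos = 5, 8, 10
scores = rng.normal(size=(n_games, n_heads, n_pos, n_pos))

# Causal mask: destination i may only attend to sources <= i
mask = np.tril(np.ones((n_pos, n_pos), dtype=bool))
scores = np.where(mask, scores, -np.inf)
pattern = softmax(scores, axis=-1)     # [game, head, dest, src]

# Average over the game dimension, as torch.mean(..., dim=0) does above
mean_pattern = pattern.mean(axis=0)    # [head, dest, src]
```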

## Section 4: MLP Layers
* The code and graphs in this section are taken entirely from Neel Nanda's notebook.

In [39]:
game_index = 1
move = 20
layer = 4
plot_single_board(focus_games_string[game_index, :move+1])
plot_probe_outputs(layer, game_index, move)

residual_stream torch.Size([512])


In [40]:

imshow([(focus_cache["attn_out", l][game_index, move][:, None, None] * my_probe).sum(0) for l in range(layer+1)], facet_col=0, y=[i for i in "ABCDEFGH"], facet_name="Layer", title=f"Attention Layer Contributions to my vs their (Game {game_index} Move {move})", aspect="equal")
imshow([(focus_cache["mlp_out", l][game_index, move][:, None, None] * my_probe).sum(0) for l in range(layer+1)], facet_col=0, y=[i for i in "ABCDEFGH"], facet_name="Layer", title=f"MLP Layer Contributions to my vs their (Game {game_index} Move {move})", aspect="equal")
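
The per-layer plots above rely on the probe being linear: projecting each layer's output separately and adding the results gives the same answer as projecting the summed vector. A minimal sketch with random stand-ins for `my_probe` and the layer outputs (the sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_layers = 16, 3
probe = rng.normal(size=(d_model, 8, 8))              # stand-in for my_probe
layer_outputs = rng.normal(size=(n_layers, d_model))  # stand-in for attn_out / mlp_out at one position

# Same expression as the cell above: broadcast, multiply, sum over d_model
contributions = np.stack(
    [(out[:, None, None] * probe).sum(0) for out in layer_outputs]
)

# Projection of the summed outputs, computed in one go
total = (layer_outputs.sum(0)[:, None, None] * probe).sum(0)
```

This additivity is what justifies reading each facet as that layer's share of the "my vs their" direction for a cell.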

## Logit Lens

In [41]:
logits, cache = model.run_with_cache(torch.tensor([20, 21, 28, 23, 29, 35]))
for layer in range(8):
    res_post = cache[f'blocks.{layer}.hook_resid_post'][:, -1, :]
    scaled_res_post = cache.apply_ln_to_stack(res_post, layer=layer + 1, pos_slice=-1)
    custom_logits = scaled_res_post[0] @ model.W_U
    custom_logits.shape # [61]

    logit_vec = custom_logits
    log_probs = logit_vec.log_softmax(-1)
    log_probs = log_probs[1:]
    assert len(log_probs)==60

    temp_board_state = torch.zeros(64, device=logit_vec.device)
    # Set all cells to -13 by default, a very negative log prob, so the middle cells don't show up as mattering
    temp_board_state -= 13.
    temp_board_state[stoi_indices] = log_probs

    plot_square_as_board(temp_board_state.reshape(8, 8), zmax=0, diverging_scale=False, title=f"Logit Lens Layer {layer}")
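
The logit lens step in the loop above can be sketched on its own: take an intermediate residual vector, apply the final layer norm, multiply by the unembedding, and read off log probabilities. This is a simplified illustration with random weights. The `layer_norm` here has no learned scale or bias, which is an assumption of the sketch and not a description of what `apply_ln_to_stack` does internally.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Plain layer norm without learned parameters (simplification)
    x = x - x.mean(-1, keepdims=True)
    return x / np.sqrt(x.var(-1, keepdims=True) + eps)

def log_softmax(x):
    # Numerically stable log-softmax over the last axis
    x = x - x.max(-1, keepdims=True)
    return x - np.log(np.exp(x).sum(-1, keepdims=True))

rng = np.random.default_rng(0)
d_model, d_vocab = 32, 61
W_U = rng.normal(size=(d_model, d_vocab))   # random stand-in for model.W_U
resid = rng.normal(size=(d_model,))         # random stand-in for resid_post at the last position

logits = layer_norm(resid) @ W_U            # [61]
log_probs = log_softmax(logits)[1:]         # drop index 0, as the cell above does
```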


## Attention Head Ablation
* Zero-ablate some attention heads and see how the probe result changes.
* This section is not finished.
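
The idea in the bullets above, sketched on a toy array: zero out the per-head outputs for a chosen set of heads and leave the others untouched. The shape `[batch, pos, head, d_head]` is assumed to match `hook_z`, and `zero_ablate_heads` is a hypothetical helper that copies its input instead of editing it in place.

```python
import numpy as np

def zero_ablate_heads(z, heads):
    """Return a copy of z with the given heads zeroed.

    z: [batch, pos, head, d_head]
    """
    z = z.copy()
    z[:, :, list(heads), :] = 0.0
    return z

rng = np.random.default_rng(0)
z = rng.normal(size=(1, 6, 8, 64))   # random stand-in for hook_z activations
ablated = zero_ablate_heads(z, heads=(1, 3, 4, 5, 6))
```

Working on a copy keeps the original activations available for comparing the probe output before and after ablation.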

In [42]:
def ablate_attn_layer_output(activations, hook, head):
    # zero-ablate the chosen heads in hook_z; assigning a scalar avoids hard-coding the device
    activations[:, :, head, :] = 0.
    # attn_out[0, -1, :] = clean_cache["hook_z", layer][0, -1, :]
    return activations

def show_patched_probe(resid_mid, hook):
    residual_stream = resid_mid[0, -1]
    print("residual_stream", residual_stream.shape)
    probe_out = einops.einsum(residual_stream, linear_probe, "d_model, d_model row col options -> row col options")
    probabilities = probe_out.softmax(dim=-1)
    plot_square_as_board(probabilities, facet_col=2, facet_labels=["P(Empty)", "P(Their's)", "P(Mine)"], height=400, width=600)
    return resid_mid

layer = 0
head = (1,3,4,5,6)
game_index = 0
move = 5

# prompt = torch.tensor([20, 21])
prompt = focus_games_int[game_index, :move+1]
logits = model.run_with_hooks(
    prompt,
    fwd_hooks=[
        # ('blocks.7.attn.hook_z', partial(ablate_attn_layer_output, head=head)),
        # ('blocks.6.attn.hook_z', partial(ablate_attn_layer_output, head=head)),
        # ('blocks.5.attn.hook_z', partial(ablate_attn_layer_output, head=head)),
        # ('blocks.4.attn.hook_z', partial(ablate_attn_layer_output, head=head)),
        # ('blocks.3.attn.hook_z', partial(ablate_attn_layer_output, head=head)),
        # ('blocks.2.attn.hook_z', partial(ablate_attn_layer_output, head=head)),
        # ('blocks.1.attn.hook_z', partial(ablate_attn_layer_output, head=head)),
        ('blocks.0.attn.hook_z', partial(ablate_attn_layer_output, head=head)),
        (f'blocks.{layer}.hook_resid_mid', partial(show_patched_probe))
    ]
)

# logit_vec = logits[0, -1]
# log_probs = logit_vec.log_softmax(-1)
# log_probs = log_probs[1:]
# assert len(log_probs)==60

# temp_board_state = torch.zeros(64, device=logit_vec.device)
# # Set all cells to -13 by default, a very negative log prob, so the middle cells don't show up as mattering
# temp_board_state -= 13.
# temp_board_state[stoi_indices] = log_probs

plot_probe_outputs(layer, game_index, move)
plot_single_board(int_to_label(focus_games_int[game_index, :move+1]))
# plot_square_as_board(temp_board_state.reshape(8, 8), zmax=0, diverging_scale=False)


residual_stream torch.Size([512])


residual_stream torch.Size([512])
