In [1]:
import datetime as dt
from dataclasses import dataclass
from pathlib import Path
from typing import Callable

import einops
import huggingface_hub as hf
import numpy as np
import torch as t
import wandb
from datasets import load_dataset
from tqdm import tqdm
from transformer_lens import HookedTransformer, HookedTransformerConfig
from eindex import eindex

from othello_gpt.data.vis import plot_game
from othello_gpt.model.nanoGPT import GPT, GPTConfig
from othello_gpt.util import (
    convert_nanogpt_to_transformer_lens_weights,
    get_all_squares,
    pad_batch,
)

In [2]:
root_dir = Path().cwd().parent.parent.parent
data_dir = root_dir / "data"
probe_dir = data_dir / "probes"
probe_dir.mkdir(parents=True, exist_ok=True)

hf.login((root_dir / "secret.txt").read_text())
wandb.login()

size = 6
all_squares = get_all_squares(size)
dataset_dict = load_dataset("awonga/othello-gpt")

device = t.device(
    "mps"
    if t.backends.mps.is_available()
    else "cuda"
    if t.cuda.is_available()
    else "cpu"
)

[34m[1mwandb[0m: Currently logged in as: [33malfredwong[0m ([33malfredwong-university-of-cambridge[0m). Use [1m`wandb login --relogin`[0m to force relogin


Resolving data files:   0%|          | 0/42 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/42 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/42 [00:00<?, ?it/s]

In [3]:
class HubGPT(GPT, hf.PyTorchModelHubMixin):
    pass


nano_cfg = GPTConfig(
    # block_size=(size * size - 4) * 2 - 1,
    block_size=(size * size - 4) - 1,
    # vocab_size=size * size - 4 + 2,  # pass and pad
    vocab_size=size * size - 4 + 1,  # pad
    n_layer=2,
    n_head=2,
    n_embd=128,
    dropout=0.0,
    bias=True,
)
hooked_cfg = HookedTransformerConfig(
    n_layers=nano_cfg.n_layer,
    d_model=nano_cfg.n_embd,
    n_ctx=nano_cfg.block_size,
    d_head=nano_cfg.n_embd // nano_cfg.n_head,
    n_heads=nano_cfg.n_head,
    d_vocab=nano_cfg.vocab_size,
    act_fn="gelu",
    normalization_type="LN",
    device=device,
)

model = HubGPT.from_pretrained("awonga/othello-gpt", config=nano_cfg).to(device)
state_dict = convert_nanogpt_to_transformer_lens_weights(
    model.state_dict(), nano_cfg, hooked_cfg
)
model = HookedTransformer(hooked_cfg)
model.load_and_process_state_dict(state_dict)
model.to(device)

number of parameters: 0.40M
Moving model to device:  mps


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

In [4]:
def theirs_empty_mine_target(batch):
    boards = t.tensor(batch["boards"])[:, :-1]
    boards[:, 1::2] *= -1
    boards += 1
    return boards


def legality_target(batch):
    legal = t.tensor(batch["legalities"])[:, 1:]
    return legal


def forward_probe(linear_probe, batch, target_fn):
    input_ids = pad_batch(batch["input_ids"], max_len=model.cfg.n_ctx + 1).to(device)
    _, cache = model.run_with_cache(
        input_ids[:, :-1],
        names_filter=lambda name: "hook_resid_p" in name or "ln_final.hook_scale" in name,
    )
    X = cache.accumulated_resid(apply_ln=True)
    y = target_fn(batch)

    preds = einops.einsum(X, linear_probe, "layer batch n_ctx d_model, d_model row col d_probe layer -> layer batch n_ctx row col d_probe")
    log_probs = preds.log_softmax(-1)

    correct_log_probs = eindex(log_probs, y, "layer batch n_ctx rows cols [batch n_ctx rows cols]")
    # loss = -einops.reduce(correct_log_probs, "layer batch n_ctx rows cols -> layer", "mean")
    loss = -correct_log_probs.mean()

    return log_probs, loss

In [5]:
## TRAIN LINEAR PROBES
# A linear probe maps residual vectors (n_batch, d_model) to e.g. board representations (n_batch, size, size)
# This helps us to discover interpretable directions in activation space

# Key concepts:
#  - training linear probes
#  - causal interventions
#  -


@dataclass
class LinearProbeTrainingArgs:
    n_epochs: int = 8
    lr: float = 5e-4
    batch_size: int = 1024
    n_steps_per_epoch: int = 500
    betas: tuple[float, float] = (0.9, 0.99)
    weight_decay: float = 1e-3
    use_wandb: bool = True
    wandb_project: str | None = "othello-gpt-probe"
    wandb_name: str | None = None
    warmup_steps: int = 100


def test_linear_probe(test_dataset, test_y, linear_probe, target_fn):
    with t.inference_mode():
        test_y_pred, test_loss = forward_probe(
            linear_probe, test_dataset, target_fn
        )
    test_accs = ((test_y_pred > np.log(0.5)).argmax(-1) == test_y)
    test_accs = einops.reduce(test_accs.float(), "layer batch pos row col -> layer", "mean")
    test_accs = test_accs.cpu().round(decimals=4)
    return test_loss, test_accs


def train_linear_probe(
    model: HookedTransformer,
    args: LinearProbeTrainingArgs,
    target_fn: Callable,
):
    n_test = 1000
    test_dataset = dataset_dict["test"].take(n_test)
    test_y = target_fn(test_dataset).to(device)
    d_probe = test_y.max().item() + 1
    n_probes = model.cfg.n_layers + 1

    linear_probe = t.randn((model.cfg.d_model, size, size, d_probe, n_probes)) / np.sqrt(
        model.cfg.d_model
    )
    linear_probe = linear_probe.to(device)
    linear_probe.requires_grad = True

    test_loss, test_accs = test_linear_probe(test_dataset, test_y, linear_probe, target_fn)

    batch_indices = t.randint(
        0,
        len(dataset_dict["train"]),
        (args.n_epochs, args.n_steps_per_epoch, args.batch_size),
    )

    optimizer = t.optim.AdamW(
        [linear_probe], lr=args.lr, betas=args.betas, weight_decay=args.weight_decay
    )

    if args.use_wandb:
        wandb.init(project=args.wandb_project, name=args.wandb_name, config=args)

    step = 0
    for i in range(args.n_epochs):
        for j in (pbar := tqdm(range(args.n_steps_per_epoch))):
            batch = dataset_dict["train"].select(batch_indices[i, j, :])
            _, loss = forward_probe(linear_probe, batch, target_fn)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            pbar.set_description(f"{loss=:.4f} {test_accs=}")
            if args.use_wandb and step >= args.warmup_steps:
                wandb.log({"train_loss": loss}, step=step)
            step += 1

        test_loss, test_accs = test_linear_probe(test_dataset, test_y, linear_probe, target_fn)

        if args.use_wandb:
            wandb.log({"eval_loss": test_loss}, step=step)
            wandb.log({f"eval_acc_{i}": test_accs[i].item() for i in range(n_probes)}, step=step)

    if args.use_wandb:
        wandb.finish()

    print(test_accs)

    return linear_probe


args = LinearProbeTrainingArgs()
# args = LinearProbeTrainingArgs(
#     use_wandb=False, n_epochs=2, n_steps_per_epoch=10, lr=1e-3
# )
linear_probe = train_linear_probe(model, args, theirs_empty_mine_target)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


loss=0.6706 test_accs=tensor([0.2451, 0.2479, 0.3015]): 100%|██████████| 500/500 [06:59<00:00,  1.19it/s]
loss=0.5463 test_accs=tensor([0.2872, 0.7606, 0.8224]): 100%|██████████| 500/500 [07:11<00:00,  1.16it/s]
loss=0.4844 test_accs=tensor([0.3582, 0.8513, 0.8872]): 100%|██████████| 500/500 [07:10<00:00,  1.16it/s]
loss=0.4471 test_accs=tensor([0.4518, 0.8817, 0.9067]): 100%|██████████| 500/500 [07:03<00:00,  1.18it/s]
loss=0.4237 test_accs=tensor([0.5233, 0.8934, 0.9150]): 100%|██████████| 500/500 [07:22<00:00,  1.13it/s]
loss=0.4088 test_accs=tensor([0.5663, 0.8985, 0.9190]): 100%|██████████| 500/500 [07:13<00:00,  1.15it/s]
loss=0.3979 test_accs=tensor([0.5907, 0.9010, 0.9215]): 100%|██████████| 500/500 [07:15<00:00,  1.15it/s]
loss=0.3904 test_accs=tensor([0.6043, 0.9023, 0.9231]): 100%|██████████| 500/500 [07:08<00:00,  1.17it/s]


0,1
eval_acc_0,▁▃▅▆▇███
eval_acc_1,▁▅▇█████
eval_acc_2,▁▅▇▇████
eval_loss,█▅▃▂▂▁▁▁
train_loss,██▇▅▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval_acc_0,0.6139
eval_acc_1,0.903
eval_acc_2,0.9241
eval_loss,0.38981
train_loss,0.39039


tensor([0.6139, 0.9030, 0.9241])


In [6]:
t.save(
    linear_probe,
    probe_dir / f"linear_probe_{dt.datetime.now().strftime('%Y%m%d_%H%M%S')}.pt",
)
# linear_probe = t.load(probe_dir / "linear_probe_20250201_084310.pt", weights_only=True)

In [7]:
n_focus = 50
focus_games = dataset_dict["test"].take(n_focus)
focus_input_ids = pad_batch(focus_games["input_ids"], max_len=model.cfg.n_ctx + 1).to(
    device
)
focus_logits, focus_cache = model.run_with_cache(focus_input_ids[:, :-1])
focus_logit_boards = t.full((n_focus, focus_logits.shape[1], size, size), 0.0)
focus_logit_boards.flatten(2)[..., all_squares] = focus_logits[..., 1:].detach().cpu()
focus_probs = focus_logits.softmax(-1)
focus_prob_boards = t.full((n_focus, focus_logits.shape[1], size, size), 0.0)
focus_prob_boards.flatten(2)[..., all_squares] = focus_probs[..., 1:].detach().cpu()

X = focus_cache.accumulated_resid()
y = theirs_empty_mine_target(focus_games)
y_pred, _ = forward_probe(linear_probe, focus_games, theirs_empty_mine_target)
y_pred = t.exp(y_pred)

In [11]:
layer = 1
test_index = 1
test_pred_model = {
    "boards": focus_prob_boards[test_index].detach().cpu(),
    "legalities": focus_games[test_index]["legalities"],
    "moves": focus_games[test_index]["moves"],
}
test_pred_theirs = {
    "boards": y_pred[layer, test_index, ..., 0].detach().cpu(),
    "legalities": y[test_index, ...].detach().cpu() == 0,
    "moves": focus_games[test_index]["moves"],
}
test_pred_empty = {
    "boards": y_pred[layer, test_index, ..., 1].detach().cpu(),
    "legalities": y[test_index, ...].detach().cpu() == 1,
    "moves": focus_games[test_index]["moves"],
}
test_pred_mine = {
    "boards": y_pred[layer, test_index, ..., 2].detach().cpu(),
    "legalities": y[test_index, ...].detach().cpu() == 2,
    "moves": focus_games[test_index]["moves"],
}

plot_game(focus_games[test_index], title="Ground truth board states and legal moves")
plot_game(
    test_pred_model,
    reversed=False,
    textcolor="red",
    hovertext=test_pred_model["boards"],
    title="Model predictions for legal moves",
)
plot_game(
    test_pred_theirs,
    reversed=False,
    textcolor="red",
    hovertext=test_pred_theirs["boards"],
    shift_legalities=False,
    title=f"Layer {layer} probe predictions for 'their' squares",
)
plot_game(
    test_pred_empty,
    reversed=False,
    textcolor="red",
    hovertext=test_pred_empty["boards"],
    shift_legalities=False,
    title=f"Layer {layer} probe predictions for empty squares",
)
plot_game(
    test_pred_mine,
    reversed=False,
    textcolor="red",
    hovertext=test_pred_mine["boards"],
    shift_legalities=False,
    title=f"Layer {layer} probe predictions for 'my' squares",
)

In [9]:
# No need for model to learn mine/theirs for last move because it's always the empty square! (because we filtered out pass games)

## ANALYSE NEURON ACTIVATIONS
# Identify direct circuits (direct logit attribution) e.g. neuron out -> W_out -> W_U
# Identify max activating datasets
# Identify statistically interesting neurons
# Decompompose entire logit components?
# Analyse neuron specificity (spectrum plots, maybe identify polysemanticity?)

# Probe flipped pieces?

# Is mapping a 128 dim space to a 6x6 board impressive?
# Can we construct a NN that does the same job?

In [12]:
import plotly.graph_objects as go

# linear_probe.shape

fig = go.Figure()
fig.add_trace(
    go.Heatmap(
        z=linear_probe[..., 1].flatten(1, -1).detach().cpu().T,
        colorscale="gray",
    )
)