In [15]:
import datetime as dt
from dataclasses import dataclass
from pathlib import Path
from typing import Callable

import einops
import huggingface_hub as hf
import numpy as np
import torch as t
import wandb
from datasets import load_dataset
from tqdm import tqdm
from transformer_lens import ActivationCache, HookedTransformer, HookedTransformerConfig

from othello_gpt.data.vis import plot_game
from othello_gpt.model.nanoGPT import GPT, GPTConfig
from othello_gpt.util import (
    convert_nanogpt_to_transformer_lens_weights,
    get_all_squares,
    pad_batch,
)

In [16]:
# Project directories, resolved relative to the notebook's working directory
# (the repository root is taken to be three levels up).
root_dir = Path.cwd().parent.parent.parent
data_dir = root_dir / "data"
probe_dir = data_dir / "probes"
probe_dir.mkdir(parents=True, exist_ok=True)

# Strip surrounding whitespace: most editors append a trailing newline to
# secret.txt, and that newline must not become part of the access token.
hf.login((root_dir / "secret.txt").read_text().strip())
wandb.login()

True

In [17]:
# Choose the compute backend in order of preference: Apple MPS, then CUDA,
# falling back to the CPU when neither accelerator is available.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")
device

device(type='mps')

In [18]:
# Board side length; later cells build (size, size) boards and derive the
# model's vocabulary and context sizes from it.
size = 6
# Square indices from the project helper for this board size; later used to
# scatter per-token logits onto a flattened (size * size) board.
# NOTE(review): presumably the playable squares only — confirm in othello_gpt.util.
all_squares = get_all_squares(size)

In [19]:
# Load the Othello games dataset from the Hugging Face Hub; later cells read
# its "train" and "test" splits.
dataset_dict = load_dataset("awonga/othello-gpt")

Resolving data files:   0%|          | 0/21 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/21 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/21 [00:00<?, ?it/s]

In [20]:
class HubGPT(GPT, hf.PyTorchModelHubMixin):
    """nanoGPT model combined with the Hugging Face Hub mixin.

    Adds no behaviour of its own; the mixin supplies the ``from_pretrained``
    constructor used below to fetch the trained checkpoint.
    """


# nanoGPT hyperparameters used to instantiate the pretrained checkpoint.
nano_cfg = GPTConfig(
    n_layer=8,
    n_head=8,
    n_embd=128,
    # size * size - 4 move tokens plus one pad token.
    # (A variant with a pass token would use size * size - 4 + 2.)
    vocab_size=size**2 - 4 + 1,
    # Context is one shorter than the size * size - 4 move tokens.
    # (The pass-token variant would use (size * size - 4) * 2 - 1.)
    block_size=size**2 - 4 - 1,
    dropout=0.0,
    bias=True,
)

# Equivalent TransformerLens configuration, derived field by field from nano_cfg.
hooked_cfg = HookedTransformerConfig(
    d_model=nano_cfg.n_embd,
    n_layers=nano_cfg.n_layer,
    n_heads=nano_cfg.n_head,
    d_head=nano_cfg.n_embd // nano_cfg.n_head,
    d_vocab=nano_cfg.vocab_size,
    n_ctx=nano_cfg.block_size,
    normalization_type="LN",
    act_fn="gelu",
    device=device,
)

# Download the nanoGPT checkpoint and translate its weights into the
# TransformerLens naming scheme; the nanoGPT module itself is only needed
# for its state dict.
state_dict = convert_nanogpt_to_transformer_lens_weights(
    HubGPT.from_pretrained("awonga/othello-gpt", config=nano_cfg)
    .to(device)
    .state_dict(),
    nano_cfg,
    hooked_cfg,
)

# The hooked model is what the rest of the notebook refers to as `model`.
model = HookedTransformer(hooked_cfg)
model.load_and_process_state_dict(state_dict)
model.to(device)

number of parameters: 1.59M
Moving model to device:  mps


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-7): 8 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resi

In [21]:
def empty_mine_target(batch):
    """Build two-channel probe targets ``(empty, mine)`` from a batch of games.

    Reads ``batch["boards"]``, drops the final entry along dim 1, and returns a
    float tensor on the module-level ``device`` with a trailing channel axis:
    channel 0 is 1.0 where the square holds 0, channel 1 is 1.0 where the
    square belongs to "me".

    "Mine" alternates with position along dim 1: value +1 at even indices and
    value -1 at odd indices.
    NOTE(review): assumes dim 1 indexes successive positions in a game and
    that +1 / -1 encode the two colours — confirm against the dataset.
    """
    boards = t.tensor(batch["boards"])[:, :-1]
    is_empty = boards.eq(0)

    # Equivalent to negating the odd positions and testing for +1, without
    # copying the boards: at odd positions "mine" means the raw value is -1.
    is_mine = boards.eq(1)
    is_mine[:, 1::2] = boards[:, 1::2].eq(-1)

    return t.stack((is_empty, is_mine), dim=-1).to(t.float32).to(device)


def empty_mine_legal_target(batch):
    """Build four-channel probe targets ``(empty, mine, legal_prev, legal_next)``.

    Reads ``batch["boards"]`` and ``batch["legalities"]`` and returns a float
    tensor on the module-level ``device`` with a trailing channel axis:

    - channel 0: 1.0 where the square holds 0 (empty);
    - channel 1: 1.0 where the square is "mine" (+1 after negating every odd
      index along dim 1);
    - channel 2: ``legalities[:, :-1]``;
    - channel 3: ``legalities[:, 1:]``.

    NOTE(review): assumes dim 1 indexes successive positions in a game and
    that boards and legalities have matching shapes — confirm against the
    dataset.
    """
    boards = t.tensor(batch["boards"])[:, :-1]
    empty = boards == 0

    # Negate every other position so that +1 always means "mine".
    boards_flipped = boards.clone()
    boards_flipped[:, 1::2] *= -1
    mine = boards_flipped == 1

    # Convert the nested lists to a tensor once; both legality channels are
    # just shifted views of the same data.
    legalities = t.tensor(batch["legalities"])
    legal_prev = legalities[:, :-1]
    legal_next = legalities[:, 1:]

    y = t.stack([empty, mine, legal_prev, legal_next], dim=-1).float().to(device)
    return y


def forward_probe(linear_probe, batch, layer, target_fn):
    """Apply a linear probe to one layer's residual stream and score it.

    Uses the module-level ``model`` and ``device``.

    Args:
        linear_probe: tensor whose first axis is ``d_model``; the remaining
            axes are the probe's output shape.
        batch: mapping with an ``"input_ids"`` entry plus whatever
            ``target_fn`` reads.
        layer: index of the block whose ``hook_resid_post`` activations are
            probed.
        target_fn: callable mapping ``batch`` to the target tensor.

    Returns:
        ``(y_pred, loss)`` where ``y_pred`` has shape
        ``(*X.shape[:2], *linear_probe.shape[1:])`` and ``loss`` is the mean
        squared error against ``target_fn(batch)``.
    """
    input_ids = pad_batch(batch["input_ids"], max_len=model.cfg.n_ctx + 1).to(device)
    hook_name = f"blocks.{layer}.hook_resid_post"

    # The transformer is frozen: only the probe is trained. Running it without
    # autograd keeps loss.backward() from backpropagating through the model
    # and accumulating unused gradients on its parameters. Only the single
    # activation that is read below is cached.
    with t.no_grad():
        _, cache = model.run_with_cache(
            input_ids[:, :-1],
            stop_at_layer=layer + 1,
            names_filter=lambda name: name == hook_name,
        )

    X = cache[hook_name]
    y = target_fn(batch)
    # Same contraction as einsum "b n d, d ... -> b n ...": flatten the probe's
    # output axes, matmul, then restore them.
    y_pred = (X @ linear_probe.flatten(1)).reshape(
        *X.shape[:2], *linear_probe.shape[1:]
    )
    loss = t.square(y - y_pred).mean()

    return y_pred, loss

In [8]:
## TRAIN LINEAR PROBES
# A linear probe maps residual vectors (n_batch, d_model) to e.g. board representations (n_batch, size, size)
# This helps us to discover interpretable directions in activation space

# Key concepts:
#  - training linear probes
#  - causal interventions
#  -


@dataclass
class LinearProbeTrainingArgs:
    """Hyperparameters and logging options read by ``train_linear_probe``."""

    # Number of epochs; the probe is evaluated on the test subset after each one.
    n_epochs: int = 3
    # AdamW learning rate.
    lr: float = 5e-4
    # Number of games sampled per optimisation step.
    batch_size: int = 512
    # Optimisation steps per epoch.
    n_steps_per_epoch: int = 500
    # AdamW beta coefficients.
    betas: tuple[float, float] = (0.9, 0.99)
    # AdamW weight decay.
    weight_decay: float = 1e-3
    # When True, a wandb run is started and losses/accuracies are logged to it.
    use_wandb: bool = True
    # Project and run name passed to wandb.init.
    wandb_project: str | None = "othello-gpt-probe"
    wandb_name: str | None = None
    # Train loss is only logged to wandb once this many steps have elapsed
    # (a logging threshold; the learning rate is not warmed up).
    warmup_steps: int = 500


def _evaluate_probe(linear_probe, dataset, layer, target_fn, y_true):
    """Return ``(loss, per-channel accuracy)`` of the probe on ``dataset``.

    Accuracy thresholds predictions at 0.5, compares them with ``y_true`` and
    averages over every axis except the last; it is returned on the CPU,
    rounded to 4 decimals.
    """
    with t.inference_mode():
        y_pred, loss = forward_probe(linear_probe, dataset, layer, target_fn)
    accs = ((y_pred > 0.5) == y_true).float().flatten(0, -2).mean(0)
    return loss, accs.cpu().round(decimals=4)


def train_linear_probe(
    model: HookedTransformer,
    layer: int,
    args: LinearProbeTrainingArgs,
    target_fn: Callable,
) -> t.Tensor:
    """Train a linear probe on one layer's residual stream.

    Uses the module-level ``dataset_dict`` and ``device``. The probe is a
    tensor of shape ``(d_model, *target_shape)`` where ``target_shape`` is
    taken from ``target_fn`` applied to the first 1000 test games. Training
    draws random batches (with replacement) from the train split and
    minimises the loss returned by ``forward_probe`` with AdamW.

    Args:
        model: supplies ``cfg.d_model`` for sizing the probe.
        layer: block index whose residual stream is probed.
        args: optimisation and logging settings.
        target_fn: maps a batch of games to the target tensor.

    Returns:
        The trained probe tensor (``requires_grad=True``) on ``device``.
    """
    test_dataset = dataset_dict["test"].take(1000)
    test_y = target_fn(test_dataset)
    probe_out_shape = test_y.shape[2:]

    # Scale the initial weights by 1/sqrt(d_model).
    linear_probe = t.randn((model.cfg.d_model, *probe_out_shape)) / np.sqrt(
        model.cfg.d_model
    )
    linear_probe = linear_probe.to(device)
    linear_probe.requires_grad = True

    # Baseline accuracy of the untrained probe, shown during the first epoch.
    _, test_accs = _evaluate_probe(linear_probe, test_dataset, layer, target_fn, test_y)

    batch_indices = t.randint(
        0,
        len(dataset_dict["train"]),
        (args.n_epochs, args.n_steps_per_epoch, args.batch_size),
    )

    optimizer = t.optim.AdamW(
        [linear_probe], lr=args.lr, betas=args.betas, weight_decay=args.weight_decay
    )

    if args.use_wandb:
        wandb.init(project=args.wandb_project, name=args.wandb_name, config=args)

    try:
        step = 0
        for epoch in range(args.n_epochs):
            for batch_step in (pbar := tqdm(range(args.n_steps_per_epoch))):
                batch = dataset_dict["train"].select(batch_indices[epoch, batch_step, :])
                _, loss = forward_probe(linear_probe, batch, layer, target_fn)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                pbar.set_description(f"{loss=:.4f} {test_accs=}")
                if args.use_wandb and step >= args.warmup_steps:
                    # Log a plain float so wandb never holds the live tensor.
                    wandb.log({"train_loss": loss.item()}, step=step)
                step += 1

            test_loss, test_accs = _evaluate_probe(
                linear_probe, test_dataset, layer, target_fn, test_y
            )

            if args.use_wandb:
                eval_metrics = {"eval_loss": test_loss.item()}
                eval_metrics.update(
                    {f"eval_acc_{k}": acc.item() for k, acc in enumerate(test_accs)}
                )
                wandb.log(eval_metrics, step=step)
    except BaseException:
        # Close the run (marked as failed) even on errors or a manual
        # interrupt, so a later wandb.init does not attach to a stale run.
        if args.use_wandb:
            wandb.finish(exit_code=1)
        raise

    if args.use_wandb:
        wandb.finish()

    print(test_accs)

    return linear_probe


# Train a layer-5 probe with the default settings. For a quick run without
# wandb, swap in e.g.
#   LinearProbeTrainingArgs(use_wandb=False, n_epochs=2, n_steps_per_epoch=100, lr=1e-3)
args = LinearProbeTrainingArgs()
linear_probe = train_linear_probe(
    model=model,
    layer=5,
    args=args,
    target_fn=empty_mine_legal_target,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


loss=0.5395 test_accs=tensor([0.5285, 0.4961, 0.5148, 0.4940]): 100%|██████████| 500/500 [08:41<00:00,  1.04s/it] 
loss=0.0466 test_accs=tensor([0.7700, 0.7326, 0.7053, 0.7286]): 100%|██████████| 500/500 [08:20<00:00,  1.00s/it]
loss=0.0456 test_accs=tensor([0.9967, 0.9514, 0.9123, 0.9638]): 100%|██████████| 500/500 [08:17<00:00,  1.00it/s]


0,1
eval_acc_0,▁██
eval_acc_1,▁██
eval_acc_2,▁██
eval_acc_3,▁██
eval_loss,█▁▁
train_loss,█▇▅▅▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
eval_acc_0,0.9968
eval_acc_1,0.952
eval_acc_2,0.9125
eval_acc_3,0.9655
eval_loss,0.04592
train_loss,0.04564


tensor([0.9968, 0.9520, 0.9125, 0.9655])


In [25]:
# To persist a freshly trained probe under a timestamped name, uncomment:
# t.save(
#     linear_probe,
#     probe_dir / f"linear_probe_{dt.datetime.now().strftime('%Y%m%d_%H%M%S')}.pt",
# )

# Reload a previously trained probe. map_location places the tensor on the
# active device, so a checkpoint saved on one backend (e.g. MPS) still loads
# on a machine that only has CUDA or the CPU.
linear_probe = t.load(
    probe_dir / "linear_probe_20250131_121102.pt",
    map_location=device,
    weights_only=True,
)

In [26]:
# Notes on the target encoding:
# mine + theirs + empty = 1
# (mine, theirs, empty) -> (empty, mine)
# mine 1 -> (0, 1), theirs -1 -> (0, 0), empty 0 -> (1, 0.5)
# should empty 0 -> (1, 0)? yes this boosts accuracy from 75% to 95%
# surely mine/theirs success => model tracks parity of seq len?

layer = 5
n_focus = 50
focus_games = dataset_dict["test"].take(n_focus)
focus_input_ids = pad_batch(focus_games["input_ids"], max_len=model.cfg.n_ctx + 1).to(
    device
)

# Inference only: run without autograd so the full activation cache does not
# keep a computation graph alive.
with t.no_grad():
    focus_logits, cache = model.run_with_cache(focus_input_ids[:, :-1])

# Scatter the logits onto (size, size) boards; squares not listed in
# all_squares stay at 0. Logit index 0 is skipped.
# NOTE(review): index 0 is presumably the pad token — confirm against the tokenizer.
focus_logit_boards = t.full((n_focus, focus_logits.shape[1], size, size), 0.0)
focus_logit_boards.flatten(2)[..., all_squares] = focus_logits[..., 1:].detach().cpu()
# focus_probs = focus_logits.softmax(-1)

X = cache[f"blocks.{layer}.hook_resid_post"]
y = empty_mine_target(focus_games)
y_pred, _ = forward_probe(linear_probe, focus_games, layer, empty_mine_legal_target)

In [27]:
# Visualise one focus game: ground truth, the model's logits, and the probe's
# predictions for three of its output channels.
test_index = 1
# Fetch the dataset row once instead of re-indexing it for every field.
test_game = focus_games[test_index]

test_pred_model = {
    "boards": focus_logit_boards[test_index].detach().cpu(),
    "legalities": test_game["legalities"],
    "moves": test_game["moves"],
}
# Probe channels follow the stacking order in empty_mine_legal_target:
# 0 = empty, 1 = mine, 2 = legal_prev, 3 = legal_next.
test_pred_legal = {
    "boards": y_pred[test_index, ..., 3].detach().cpu(),
    "legalities": test_game["legalities"],
    "moves": test_game["moves"],
}
# For the empty/mine plots the ground-truth targets are passed in the
# "legalities" slot so they are drawn alongside the predictions.
test_pred_empty = {
    "boards": y_pred[test_index, ..., 0].detach().cpu(),
    "legalities": y[test_index, ..., 0].detach().cpu(),
    "moves": test_game["moves"],
}
test_pred_mine = {
    "boards": y_pred[test_index, ..., 1].detach().cpu(),
    "legalities": y[test_index, ..., 1].detach().cpu() == 1,
    "moves": test_game["moves"],
}

plot_game(test_game, title="Ground truth board states and legal moves")
plot_game(
    test_pred_model,
    reversed=False,
    textcolor="red",
    # Hover over the model's own logits, matching the other plots, which each
    # hover over the values they display.
    hovertext=test_pred_model["boards"],
    title="Model predictions for legal moves",
)
plot_game(
    test_pred_legal,
    reversed=False,
    textcolor="red",
    hovertext=test_pred_legal["boards"],
    title=f"Layer {layer} probe predictions for legal moves",
)
plot_game(
    test_pred_empty,
    reversed=False,
    textcolor="red",
    hovertext=test_pred_empty["boards"],
    shift_legalities=False,
    title=f"Layer {layer} probe predictions for empty squares",
)
plot_game(
    test_pred_mine,
    reversed=False,
    textcolor="red",
    hovertext=test_pred_mine["boards"],
    shift_legalities=False,
    title=f"Layer {layer} probe predictions for 'my' squares",
)

In [12]:
## ANALYSE NEURON ACTIVATIONS
# Identify direct circuits (direct logit attribution) e.g. neuron out -> W_out -> W_U
# Identify max activating datasets
# Identify statistically interesting neurons
# Decompose entire logit components?
# Analyse neuron specificity (spectrum plots, maybe identify polysemanticity?)

# Is mapping a 128 dim space to a 6x6 board impressive?
# Can we construct a NN that does the same job?