In [None]:
import itertools
import pickle

import torch
from leela_interp import Lc0sight, LeelaBoard
from leela_interp.tools.attention import attention_attribution, top_k_attributions
from leela_interp.tools.patching import activation_patch

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the Lc0sight model wrapper from `lc0.onnx` (a relative path) on that device.
model = Lc0sight("lc0.onnx", device=device)

In [None]:
# NOTE: unpickling can execute arbitrary code embedded in the file, so only load an
# `interesting_puzzles.pkl` that comes from a source you trust.
with open("interesting_puzzles.pkl", "rb") as f:
    puzzles = pickle.load(f)
# Last expression of the cell: shows how many puzzles were loaded.
len(puzzles)

In [None]:
def find_sacrifice_puzzles(puzzle):
    """Return whether all even-indexed principal-variation moves share a destination.

    Only the moves at indices 0, 2, 4, ... of ``puzzle.principal_variation`` are
    considered (presumably the solving side's moves -- confirm against the
    dataset). For each of them, characters 2-4 of the move string are taken as
    the destination square (e.g. ``"g6"`` in ``"f4g6"``).

    Args:
        puzzle: Any object with a ``principal_variation`` attribute holding a
            sequence of move strings; here, one row of the puzzles table.

    Returns:
        True if those destination squares are all identical. This is also
        True when there are no such moves or only one.
    """
    destinations = set()
    for ply, move in enumerate(puzzle.principal_variation):
        if ply % 2:
            # Odd-indexed moves are ignored.
            continue
        destinations.add(move[2:4])
    # Zero or one distinct destination means no two considered moves disagree.
    return len(destinations) <= 1

# Evaluate the predicate once per row (axis=1). Despite the `_idx` name, the result
# holds one boolean per puzzle rather than integer positions.
sacrifice_puzzles_idx = puzzles.apply(find_sacrifice_puzzles, axis=1)
# Keep only the puzzles for which the predicate returned True.
sacrifice_puzzles = puzzles[sacrifice_puzzles_idx]

In [None]:
# Display the puzzles that passed the filter.
sacrifice_puzzles

Let's look at one of these puzzles:

In [None]:
# Use the second filtered puzzle (position 1) as the running example for the rest of
# the notebook.
puzzle = sacrifice_puzzles.iloc[1]
board = LeelaBoard.from_puzzle(puzzle)
# Last expression of the cell: shows the board.
board

The *principal variation* is the best sequence of moves for both sides:

In [None]:
# Show the move sequence stored with this puzzle.
puzzle.principal_variation

Leela solves this puzzle correctly:

In [None]:
# Show what the model plays in the puzzle position (presumably its top moves and
# evaluation -- see `pretty_play` for the exact output).
model.pretty_play(board)

# Visualizing attention patterns

Next, let's look at some attention patterns. These are 64 x 64 arrays, with one entry for each pair of squares.

In [None]:
# Layer and head whose attention pattern we want to inspect.
layer = 9
head = 5

# We're using nnsight to cache activations and do interventions. There's also an interface
# based directly on pytorch hooks if you prefer that, see Lc0Model.capturing().
with model.trace(board):
    # Take index 0 on the first axis (presumably the batch axis; only one board is
    # traced) and then the chosen head. `.save()` keeps the value available after the
    # trace block exits.
    attention = model.attention_scores(layer).output[0, head].save()
# Last expression of the cell: shows the shape of the captured pattern.
attention.shape

We can plot slices of this attention pattern:

In [None]:
# Query square whose row of the attention pattern we want to visualize.
square = "b3"
# This converts a square in chess notation to the index inside Leela's activations for
# that square. Note that the input to Leela is flipped depending on the current player's
# color.
idx = board.sq2idx(square)
# attention has shape (query_dim, key_dim); indexing into the first one gives us a slice
# of the attention pattern with fixed query.
board.plot(attention[idx], caption=f"L{layer}H{head} attention with query={square}")

# Attention attribution

Let's look at L12H12 (layer 12, head 12) instead and do attention attribution, which approximates the effect of zero-ablating each individual attention weight. We'll then plot the entries with the highest attribution scores as arrows from key to query (i.e. in the direction of information flow).

In [None]:
# Compute attribution scores for L12H12. The function takes a list of boards, so we
# pass a single-element list and take element 0 of the result.
# `return_pt=True` presumably requests a PyTorch tensor -- confirm in the helper.
attribution = attention_attribution(
    [board], layer=12, head=12, model=model, return_pt=True
)[0]
# Select the 5 top-scoring entries; `colors` is what `board.plot` receives as arrows.
# `values` is not used further in this cell.
values, colors = top_k_attributions(attribution, board, k=5)
board.plot(arrows=colors)

In [None]:
# Destination squares (characters 2-4 of the move strings) of the first and third
# moves in the principal variation.
# NOTE(review): `puzzle` is taken from `sacrifice_puzzles`, whose filter only keeps
# puzzles where all even-indexed moves share the same destination. That makes these
# two squares (and hence query_idx and key_idx) identical here -- confirm that is
# intended. This also assumes the principal variation has at least three moves.
query_square = puzzle.principal_variation[0][2:4]
key_square = puzzle.principal_variation[2][2:4]
query_idx = board.sq2idx(query_square)
key_idx = board.sq2idx(key_square)

# Rerun the model with a single entry of the layer-12 attention overwritten with 0
# (first-axis index 0, head 12, the query/key pair above) and keep the model output.
with model.trace(board):
    model.attention_scores(12).output[0, 12, query_idx, key_idx] = 0
    output = model.output.save()

# output[0] is converted to move probabilities; output[1] is printed as the WDL
# (presumably win/draw/loss) values.
probs = model.logits_to_probs(board, output[0])[0]
policy = model.top_moves(board, probs, top_k=5)
print(policy)
print("WDL:", output[1])

With that single attention weight zeroed out, the previous top move, Ng6, drops to 4th place with a probability of only 16%. Leela also evaluates the position as worse: its win probability falls from 28.7% to 9.9%.

# Activation patching

Finally, let's do activation patching. Every puzzle in our dataset already has a "corrupted version" that we automatically generated. This is a very similar board position, but with a slight difference that makes the tactic no longer work. Note the new pawn on h6:

In [None]:
# Build the corrupted position from the FEN string stored with the puzzle.
corrupted_board = LeelaBoard.from_fen(puzzle.corrupted_fen)
# `display` is not imported in this file; it is the function Jupyter/IPython makes
# available in notebooks. It is called explicitly because only a cell's last
# expression is rendered automatically.
display(corrupted_board)
model.pretty_play(corrupted_board)

We could implement activation patching fairly easily with `nnsight`, but we'll instead introduce our patching helper function. Let's patch the output of L12H12 on every square:

In [None]:
# Patch the output of L12H12 at each of the 64 squares. The result of
# `activation_patch` is negated and stored as `log_odds_reductions`.
log_odds_reductions = -activation_patch(
    model=model,
    # We could also pass in board and corrupted_board manually instead
    puzzles=puzzle,
    module_func=model.headwise_attention_output,
    # One (layer, head, output square) triple per square, all for L12H12:
    locations=[(12, 12, square_idx) for square_idx in range(64)],
)
# Last expression of the cell: shows the shape of the result.
log_odds_reductions.shape

In [None]:
# Draw the per-square patching results on the board.
board.plot(log_odds_reductions, caption="Log odds reduction for each square")

As we can see, activation patching essentially only has a big effect on g6, the square to which L12H12 moves information from h4.

# Next steps
We demonstrated how to use several mechanistic interpretability techniques on a single board position. It's fairly straightforward to extend these to batches of positions; see the files in `scripts` for examples. `nnsight` also makes it quite easy to use other interpretability techniques that we didn't cover here.