In [None]:
import torch
import numpy as np
from fancy_einsum import einsum
import chess
import numpy as np
import pickle
import logging
import plotly.graph_objects as go
from functools import partial

import chess_utils
import train_test_chess

The cells below handle setup: they load the probe, the model, and a few games into tensors that we can feed to the model.

In [None]:
# This notebook only runs inference, so switch autograd off globally.
torch.set_grad_enabled(mode=False)

In [None]:
# Logging verbosity switches: debug_mode takes precedence over info_mode;
# with both off, only warnings and above are shown.
debug_mode = False
info_mode = True

log_level = logging.DEBUG if debug_mode else (logging.INFO if info_mode else logging.WARNING)

# Configure the root handler once and grab a logger for this notebook.
logging.basicConfig(level=log_level)
logger = logging.getLogger(__name__)

Here you can select which probe and model to use. By default, model_setup.py downloads a lichess 8-layer model. We can then select a probe from linear_probes/saved_probes/; ideally this should also be a lichess probe. The code below then automatically populates its parameters from the probe's state dict.

To reproduce the paper / blog post figures, set USE_16_LAYER to True and run model_setup.py for the lichess 16-layer model.

In [None]:
MODEL_DIR = "models/"
DATA_DIR = "data/"
PROBE_DIR = "linear_probes/"
SAVED_PROBE_DIR = "linear_probes/saved_probes/"
SPLIT = "test"

DEVICE = "cpu"
logger.info(f"Using device: {DEVICE}")

# Set to True (after fetching the 16-layer lichess model with model_setup.py) to reproduce the paper figures.
USE_16_LAYER = False

if USE_16_LAYER:
    LAYER = 11
    base_probe_name = "tf_lens_lichess_16layers_ckpt_no_optimizer_chess_piece_probe_layer_0.pth"
else:
    LAYER = 5
    base_probe_name = "tf_lens_lichess_8layers_ckpt_no_optimizer_chess_piece_probe_layer_0.pth"

# Saved probe filenames encode their layer; swap "layer_0" for the layer we want to inspect.
probe_to_test = base_probe_name.replace("layer_0", f"layer_{LAYER}")

num_games = 10  # games loaded from the dataset before a single game is selected below
sample_size = 1  # games kept after that selection
modes = 1  # size of the leading "modes" dimension of the probe and the state stacks

probe_file_location = f"{SAVED_PROBE_DIR}{probe_to_test}"
# torch.load unpickles the checkpoint, so only load probe files from a trusted source.
# Loading by path means no file handle has to stay open while the config is processed.
state_dict = torch.load(probe_file_location, map_location=torch.device(DEVICE))
print(state_dict.keys())
for key, value in state_dict.items():
    if key != "linear_probe":  # skip printing the probe weight tensor itself
        print(key, value)

config = chess_utils.find_config_by_name(state_dict["config_name"])
layer = state_dict["layer"]
if layer != LAYER:
    # The filename was built from LAYER, so a mismatch means the checkpoint disagrees with its name.
    logger.warning(f"Probe {probe_to_test} reports layer {layer}, expected {LAYER}")
model_name = state_dict["model_name"]
dataset_prefix = state_dict["dataset_prefix"]
column_name = state_dict["column_name"]
config.pos_start = state_dict["pos_start"]
# levels_of_interest is optional in the checkpoint; default to None when absent.
levels_of_interest = state_dict.get("levels_of_interest")
config.levels_of_interest = levels_of_interest
indexing_function_name = state_dict["indexing_function_name"]
n_layers = state_dict["n_layers"]

split = SPLIT
input_dataframe_file = f"{DATA_DIR}{dataset_prefix}{split}.csv"
config = chess_utils.set_config_min_max_vals_and_column_name(
    config, input_dataframe_file, dataset_prefix
)
misc_logging_dict = {
    "split": split,
    "dataset_prefix": dataset_prefix,
    "model_name": model_name,
    "n_layers": n_layers,
}

At the end of the cell below, we keep only 1 of the num_games games (game_of_interest). We do this because, with a large number of games, storing all the resid_posts and state_stacks quickly grows to many gigabytes of VRAM.

In [None]:
probe_data = train_test_chess.construct_linear_probe_data(
    input_dataframe_file,
    dataset_prefix,
    n_layers,
    model_name,
    config,
    num_games,
    DEVICE,
)
if DEVICE == "cpu":
    probe_data.model.cpu()

game_of_interest = 3
assert 0 <= game_of_interest < num_games, f"game_of_interest={game_of_interest} must be in [0, {num_games})"

# All pgn strings share one length (checked via the state stack shape below), so game 0 is representative.
game_length_in_chars = len(probe_data.board_seqs_string[0])

state_stacks_all_chars = chess_utils.create_state_stacks(
    probe_data.board_seqs_string[:num_games], config.custom_board_state_function
)
logger.info(f"state_stack shape: {state_stacks_all_chars.shape}")
assert state_stacks_all_chars.shape == (modes, num_games, game_length_in_chars, config.num_rows, config.num_cols)
white_move_indices = probe_data.custom_indices[:num_games]
print(white_move_indices.shape)
num_white_moves = white_move_indices.shape[1]
assert white_move_indices.shape == (num_games, num_white_moves)


def _print_selection_shapes():
    """Print the shapes of the data that gets narrowed down to game_of_interest."""
    print(probe_data.board_seqs_int.shape)
    print(state_stacks_all_chars.shape)
    print(white_move_indices.shape)
    print(len(probe_data.board_seqs_string), len(probe_data.board_seqs_string[0]))


print("\nSelecting the game of interest")
_print_selection_shapes()

# Keep only game_of_interest to bound memory; slicing with g:g+1 preserves the batch dimension.
selected_game = slice(game_of_interest, game_of_interest + 1)
probe_data.board_seqs_int = probe_data.board_seqs_int[selected_game]
probe_data.board_seqs_string = [probe_data.board_seqs_string[game_of_interest]]
white_move_indices = white_move_indices[selected_game]
probe_data.custom_indices = white_move_indices
state_stacks_all_chars = state_stacks_all_chars[:, selected_game]

_print_selection_shapes()

Here is an explanation of all the data we just generated:

In [None]:
# Walk through every tensor produced by the setup cells, one line per tensor.
game_pgn = probe_data.board_seqs_string[0]
game_ints = probe_data.board_seqs_int
data_summary = [
    f"All pgn strings are of length {game_length_in_chars}",
    f"For game {game_of_interest}, the pgn string is {game_pgn}",
    f"Using our encode functions, it's represented as ints that are fed as input to the GPT model with shape {game_ints.shape}",
    f"The first 30 characters of board_seqs_ints looks like this: {game_ints[:, :30]}",
    f"state_stacks_all_chars contains the board state at every char index in the pgn string with shape {state_stacks_all_chars.shape}",
    f"white_move_indices contains the index of every white move in the pgn string with shape {white_move_indices.shape}",
    f"That means there are {num_white_moves} white moves in the game",
    f"For example, in {game_pgn[:14]}, the white move indices are {white_move_indices[:, :2]} (the indices of each period)",
]
for summary_line in data_summary:
    print(summary_line)



Important note: at the bottom of the cell below, I apply log_softmax so that the probe outputs are log-probabilities. Comment that line out to view raw logits instead.

In this cell, we feed board_seqs_int to the GPT to obtain resid_post, the intermediate activations after our layer of interest. We index into resid_post using white_move_indices. These indexed resid_posts are then input to the linear probe, which outputs probe_out: for every square on the board, a (log-)probability distribution over the square's possible states.

In [None]:
# The probe weights were already loaded with the checkpoint in the configuration cell.
linear_probe = state_dict["linear_probe"]
print(linear_probe.shape)


one_hot_range = config.max_val - config.min_val + 1

board_seqs_int = probe_data.board_seqs_int.to(DEVICE)
assert board_seqs_int.shape == (1, game_length_in_chars)

# Advanced-indexing helpers: batch_ids (sample_size, 1) broadcasts against
# move_ids (sample_size, num_white_moves) to pick every white-move position of every game.
move_ids = white_move_indices[:sample_size]
batch_ids = torch.arange(sample_size, device=move_ids.device).unsqueeze(1)

# Ground-truth board at each white move: (modes, sample_size, num_white_moves, rows, cols).
state_stack_white_moves = state_stacks_all_chars[:, batch_ids, move_ids]

print("state stack shapes")
print(state_stack_white_moves.shape)
print(state_stacks_all_chars.shape)

with torch.inference_mode():
    _, cache = probe_data.model.run_with_cache(board_seqs_int[:, :-1], return_type=None)
    resid_post = cache["resid_post", layer]

# The final character is dropped from the model input, so there is one fewer position than characters.
assert resid_post.shape == (sample_size, game_length_in_chars - 1, linear_probe.shape[1])

# Keep only the activations at white-move positions: (sample_size, num_white_moves, d_model).
resid_post = resid_post[batch_ids, move_ids].to(DEVICE)
print("Resid post", resid_post.shape)
probe_out = einsum(
    "batch pos d_model, modes d_model rows cols options -> modes batch pos rows cols options",
    resid_post,
    linear_probe,
)
# Log-probabilities over each square's one-hot options; delete this line to inspect raw logits.
probe_out = probe_out.log_softmax(-1)
print(f"Probe out shape: {probe_out.shape}")
assert probe_out.shape == (modes, sample_size, white_move_indices.shape[1], config.num_rows, config.num_cols, one_hot_range)

Here you can select which move you want to visualize (move_of_interest).

In [None]:
move_of_interest = 11
GAME_IDX = 0  # After refactoring to discard unused games, this is always 0
assert 0 <= move_of_interest < num_white_moves, (
    f"move_of_interest={move_of_interest} must be in [0, {num_white_moves})"
)
# Character index of this white move's "." in the pgn string; used to cut pgn prefixes later.
move_of_interest_index = white_move_indices[GAME_IDX][move_of_interest]
move_of_interest_state = state_stack_white_moves[0][GAME_IDX][move_of_interest]
print(move_of_interest_state.shape)
print(move_of_interest_state)

Now we one-hot encode the board states and store the state at move_of_interest in move_of_interest_state_one_hot.

In [None]:
# One-hot encode the ground-truth board at every white move so it lines up with probe_out.
state_stacks_one_hot = chess_utils.state_stack_to_one_hot(
    modes,
    config.num_rows,
    config.num_cols,
    config.min_val,
    config.max_val,
    DEVICE,
    state_stack_white_moves,
)
print(state_stacks_one_hot.shape)
expected_one_hot_shape = (modes, sample_size, num_white_moves, config.num_rows, config.num_cols, one_hot_range)
assert state_stacks_one_hot.shape == expected_one_hot_shape
move_of_interest_state_one_hot = state_stacks_one_hot[0][GAME_IDX][move_of_interest]
print(move_of_interest_state_one_hot.shape)

We get the argmax of each square's probe probability distribution and store it in state_stacks_probe_outputs for easy graphing.

In [None]:
print(move_of_interest_state_one_hot.shape)
print(state_stacks_one_hot.shape)
# Decode the probe output back into per-square piece integers. torch.as_tensor accepts
# whatever one_hot_to_state_stack returns (array or tensor) without the copy-construct
# warning torch.tensor emits when handed a tensor.
state_stacks_probe_outputs = torch.as_tensor(
    chess_utils.one_hot_to_state_stack(probe_out, config.min_val)
)
print(state_stacks_probe_outputs.shape)
assert state_stacks_probe_outputs.shape == (modes, sample_size, num_white_moves, config.num_rows, config.num_cols)
print(state_stacks_probe_outputs[0][GAME_IDX][move_of_interest])

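Change BLANK_INDEX, white_pawn_index, or black_king_index to visualize the probe's view of other pieces. These are positions in the probe's one-hot output, so look them up through PIECE_TO_ONE_HOT_MAPPING. For example, to see the black queen (piece integer -5; refer to INT_TO_CHAR for the mapping), set black_king_index = PIECE_TO_ONE_HOT_MAPPING[-5].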

In [None]:

# Unicode glyph for every piece integer (negative = black, positive = white, 0 = empty square).
INT_TO_CHAR = {
    -6: "\u265a",
    -5: "\u265b",
    -4: "\u265c",
    -3: "\u265d",
    -2: "\u265e",
    -1: "\u265f",
    0: ".",
    1: "\u2659",
    2: "\u2658",
    3: "\u2657",
    4: "\u2656",
    5: "\u2655",
    6: "\u2654",
}

# Piece integer -> index along the probe's last ("options") axis.
# Duplicated from chess_utils.py for easy reference: the values -6..6 map to 0..12 in order.
PIECE_TO_ONE_HOT_MAPPING = {piece_int: piece_int + 6 for piece_int in range(-6, 7)}

# python-chess piece type -> positive (white) piece integer used in the state stacks.
PIECE_TO_INT = {
    chess.PAWN: 1,
    chess.KNIGHT: 2,
    chess.BISHOP: 3,
    chess.ROOK: 4,
    chess.QUEEN: 5,
    chess.KING: 6,
}

# Reverse lookup: piece integer -> python-chess piece type.
INT_TO_PIECE = dict(zip(PIECE_TO_INT.values(), PIECE_TO_INT.keys()))

# One-hot option indices of the pieces highlighted in the plots below.
BLANK_INDEX, white_pawn_index, black_king_index = (
    PIECE_TO_ONE_HOT_MAPPING[piece_int] for piece_int in (0, 1, -6)
)

def plot_board_state(board_state: torch.Tensor, clip_size: int = 200, show_scale: bool = False):
    """Build a grayscale plotly heatmap trace for one 2-D board plane.

    Args:
        board_state: 2-D tensor (or array) of per-square values, e.g. a one-hot
            plane or a plane of probe log-probabilities. Tensors on any device are accepted.
        clip_size: values are clipped to [-clip_size, clip_size] so outliers
            don't wash out the color scale.
        show_scale: whether to draw the color bar.

    Returns:
        A ``go.Heatmap`` trace; the caller adds it to a figure.
    """
    if isinstance(board_state, torch.Tensor):
        # detach().cpu() works for every device; checking only is_cuda missed e.g. MPS tensors.
        board_state = board_state.detach().cpu().numpy()
    clipped = np.clip(np.asarray(board_state), -clip_size, clip_size)
    return go.Heatmap(z=clipped, colorscale="gray", showscale=show_scale)

print(move_of_interest_state_one_hot[:, :, white_pawn_index])

move_of_interest_probe_out = probe_out[0][0][move_of_interest]
print(move_of_interest_probe_out.shape)

# The probe's white-pawn log-probability for every square at move_of_interest.
white_pawn_heatmap = plot_board_state(move_of_interest_probe_out[:, :, white_pawn_index], show_scale=True)
white_pawn_layout = go.Layout(
    title="Chess board white pawns",
    xaxis=dict(ticks='', nticks=8),
    yaxis=dict(ticks='', nticks=8),
    autosize=False,
    width=600,
    height=600,
)
fig = go.Figure(data=[white_pawn_heatmap], layout=white_pawn_layout)
fig.show()

In [None]:
def tensor_to_text(board_state: torch.Tensor, char_map: dict = None) -> np.ndarray:
    """Convert a board tensor of piece integers into a grid of display characters.

    Args:
        board_state: tensor of piece integers (any device; it is copied to CPU).
        char_map: optional mapping from piece integer to character. Defaults to
            INT_TO_CHAR. Values missing from the mapping are shown as ``str(value)``.

    Returns:
        An object-dtype numpy array of strings with the same shape as ``board_state``.
    """
    mapping = INT_TO_CHAR if char_map is None else char_map
    board_array = board_state.detach().cpu().numpy()

    # dtype=object keeps whole strings; dtype=str ('<U1') silently truncated
    # multi-character fallback labels such as "10" or "-7" to their first character.
    text_array = np.empty(board_array.shape, dtype=object)
    for index, value in np.ndenumerate(board_array):
        text_array[index] = mapping.get(value, str(value))

    return text_array

def plot_board_state_with_text(board_state: torch.Tensor):
    """Build a plotly heatmap trace that labels every square with its piece glyph.

    Cells are white except for a narrow gray band in the middle of the color range.
    ``zmid=0`` pins that middle to 0 (empty squares), so empty squares stay gray even
    when the board's values are not symmetric around 0 (e.g. a decoded probe board
    missing one of the kings).
    """
    text_matrix = tensor_to_text(board_state)

    colorscale = [
        [0, 'white'],   # Negative values
        [0.49, 'white'],
        [0.5, 'grey'],  # Zero
        [0.51, 'white'],
        [1, 'white']    # Positive values
    ]

    heatmap = go.Heatmap(
        # Copy to CPU so tensors from any device can be plotted.
        z=board_state.detach().cpu().numpy(),
        text=text_matrix,
        showscale=False,
        colorscale=colorscale,
        zmid=0,
        texttemplate="%{text}",
        textfont=dict(size=48),
    )

    return heatmap
# Render the ground-truth board at move_of_interest with piece glyphs.
board_text_layout = go.Layout(
    title="Chess board state with text",
    xaxis=dict(ticks='', nticks=8),
    yaxis=dict(ticks='', nticks=8),
    autosize=False,
    width=600,
    height=600,
)
fig = go.Figure(
    data=[plot_board_state_with_text(move_of_interest_state)],
    layout=board_text_layout,
)
fig.show()


def matrix_to_fen(board_state: torch.Tensor):
    """Return the piece-placement field of a FEN string for a board tensor.

    Squares are turned into glyphs with tensor_to_text, glyphs into FEN letters,
    and runs of empty squares into digits. The row order is reversed because FEN
    starts from the 8th rank.
    """
    glyph_rows = tensor_to_text(board_state)
    fen_mapping = {
        '♔': 'K', '♕': 'Q', '♖': 'R', '♗': 'B', '♘': 'N', '♙': 'P',
        '♚': 'k', '♛': 'q', '♜': 'r', '♝': 'b', '♞': 'n', '♟': 'p',
        '.': '.'
    }
    ranks = []
    for glyph_row in glyph_rows:
        rank_text = ""
        blanks = 0
        for glyph in glyph_row:
            letter = fen_mapping[glyph]
            if letter == '.':
                blanks += 1
                continue
            if blanks:
                rank_text += str(blanks)
                blanks = 0
            rank_text += letter
        if blanks:
            rank_text += str(blanks)
        ranks.append(rank_text)
    ranks.reverse()
    return '/'.join(ranks)

In [None]:
import chess
import chess.svg
import re

def write_svg_to_file(filename, content):
    """Write SVG markup to ``filename``, creating its parent directory if needed.

    Args:
        filename: destination path, e.g. "images/board.svg".
        content: SVG markup string (e.g. the output of chess.svg.board).
    """
    import os

    parent_dir = os.path.dirname(filename)
    if parent_dir:
        # The notebook writes into images/, which may not exist yet.
        os.makedirs(parent_dir, exist_ok=True)
    # Explicit UTF-8 so the output doesn't depend on the platform's default encoding.
    with open(filename, "w", encoding="utf-8") as file:
        file.write(content)
        
def get_fen_after_moves(moves, num_plies=None):
    """Replay the start of a pgn move string and return the resulting FEN.

    Args:
        moves: pgn text such as "1.e4 e5 2.Nf3 ..."; move numbers like "12." are stripped.
        num_plies: number of half-moves to apply. Defaults to ``2 * move_of_interest``
            (the notebook-level global), i.e. the first move_of_interest full moves.

    Returns:
        The FEN string of the board after those half-moves.
    """
    if num_plies is None:
        num_plies = 2 * move_of_interest
    board = chess.Board()
    san_moves = re.sub(r'\d+\.', '', moves).split()
    for san in san_moves[:num_plies]:
        board.push_san(san)
    return board.fen()

game_pgn = probe_data.board_seqs_string[0]
san_moves = game_pgn[1:]  # the first character is the ";" delimiter
print(san_moves)
fen_output = get_fen_after_moves(san_moves)
print(fen_output)
# To render this position: write_svg_to_file(path, chess.svg.board(chess.Board(fen_output), size=350)).

In [None]:
from plotly.subplots import make_subplots

move_of_interest_probe_out = probe_out[0][0][move_of_interest]
print(move_of_interest_probe_out.shape)

fig_rows = 4
fig_cols = 3
fig = make_subplots(rows=fig_rows, cols=fig_cols, subplot_titles=[
    "Ground truth blank squares", "Predicted blank squares", "Confidence gradient blank squares",
    "Ground truth white pawn positions", "Predicted white pawn positions", "Confidence gradient white pawn positions",
    "Ground truth black king position", "Predicted black king position", "Confidence gradient black king position",
    "Ground truth state", "Predicted board state", "Redundant probe output board state"
])

# Specify the size of each plot
plot_size = 400  # You can adjust this size

# Rows 1-3: one piece type each -> ground truth, probe output clipped to +-2, unclipped probe output.
# Only the white-pawn row draws a color bar.
piece_rows = [(BLANK_INDEX, False), (white_pawn_index, True), (black_king_index, False)]
for row, (option_index, show_scale) in enumerate(piece_rows, start=1):
    fig.add_trace(plot_board_state(move_of_interest_state_one_hot[:, :, option_index]), row=row, col=1)
    fig.add_trace(plot_board_state(move_of_interest_probe_out[:, :, option_index], clip_size=2), row=row, col=2)
    fig.add_trace(plot_board_state(move_of_interest_probe_out[:, :, option_index], show_scale=show_scale), row=row, col=3)

# Row 4: full boards as glyphs. The third panel ("Redundant probe output board state")
# repeats the probe output; it must go to col=3, otherwise that panel stays empty.
predicted_board = state_stacks_probe_outputs[0][0][move_of_interest]
fig.add_trace(plot_board_state_with_text(move_of_interest_state), row=4, col=1)
fig.add_trace(plot_board_state_with_text(predicted_board), row=4, col=2)
fig.add_trace(plot_board_state_with_text(predicted_board), row=4, col=3)

# Adjust the overall size of the figure
fig.update_layout(height=fig_rows * plot_size * 1.3, width=fig_cols * plot_size)
fig.update_annotations(dict(font=dict(size=18)))

fig.show()

This checks the percentage of squares in the sample (sample_size defaults to 1 game) where the ground truth matches the probe output.
As a sanity check, I also round-trip the ground truth through the one-hot encoding and decoding, which should match 100%.

In [None]:
def calculate_matching_percentage(state_stacks: torch.Tensor, probe_outputs: torch.Tensor) -> float:
    """
    Calculate the percentage of cells that are equal in two same-shaped tensors.

    :param state_stacks: Reference tensor, e.g. ground-truth state stacks of shape
        (modes, sample_size, num_white_moves, rows, cols).
    :param probe_outputs: Tensor of the same shape to compare, e.g. decoded probe outputs.
    :return: The percentage (0-100) of cells that match.
    :raises ValueError: if the shapes differ (broadcasting would silently skew the count,
        even past 100%) or the tensors are empty.
    """
    if state_stacks.shape != probe_outputs.shape:
        raise ValueError(
            f"Shape mismatch: {tuple(state_stacks.shape)} vs {tuple(probe_outputs.shape)}"
        )

    total_elements = state_stacks.numel()
    if total_elements == 0:
        raise ValueError("Cannot compute a matching percentage of empty tensors")

    num_matches = (state_stacks == probe_outputs).sum().item()

    percentage = (num_matches / total_elements) * 100
    print(f"Out of {total_elements} elements, {num_matches} matched, {percentage}%")

    return percentage
assert state_stacks_probe_outputs.shape == state_stack_white_moves.shape
print("Linear probe accuracy on all board squares in sample size:", calculate_matching_percentage(state_stack_white_moves, state_stacks_probe_outputs))

# Round trip: ground truth -> one-hot -> decoded state stack; this must reproduce the ground truth exactly.
round_trip_one_hot = chess_utils.state_stack_to_one_hot(
    modes, config.num_rows, config.num_cols, config.min_val, config.max_val, DEVICE, state_stack_white_moves
)
# torch.as_tensor accepts an array or a tensor without torch.tensor's copy-construct warning.
round_trip = torch.as_tensor(chess_utils.one_hot_to_state_stack(round_trip_one_hot, config.min_val))
print(round_trip.shape)
print(state_stack_white_moves.shape)
assert round_trip.shape == (modes, sample_size, num_white_moves, config.num_rows, config.num_cols)
assert round_trip.shape == state_stack_white_moves.shape
matching_percentage = calculate_matching_percentage(round_trip, state_stack_white_moves)
assert matching_percentage == 100.0, f"Round trip should be lossless, got {matching_percentage}%"
print(f"Round trip matching percentage: {matching_percentage}%")

Now, we can perform interventions on the model's internals and view the modified probe outputs. We can also verify the model produces legal moves under the modified state of the board.

First, we perform a sanity check to ensure that our interventions on model activations are working correctly. In this case, diff should roughly equal flip_dir.

Note that I'm only intervening on one layer here. By modifying the first for loop and training additional probes, we can easily intervene on an arbitrary number of layers. If we were to intervene on multiple layers, we could only check that torch.allclose(diff, flip_dirs[layer], atol=1e-6) holds for the first layer that we intervene on.

In [None]:
probe_data.model.reset_hooks()

# Baseline (unmodified) activations at the probed layer.
_, cache = probe_data.model.run_with_cache(board_seqs_int.to(DEVICE)[:, :-1], return_type=None)
resid_post = cache["resid_post", layer][:, :]

# Board square (row, column) whose probe directions are used for this sanity check.
r = 0
c = 0

# Only one layer here; widen the range (and train matching probes) to intervene on more layers.
probe_names = {i: base_probe_name.replace("layer_0", f"layer_{i}") for i in range(layer, layer + 1)}

probes = {}
for layer, probe_name in probe_names.items():
    probe_file_location = f"{SAVED_PROBE_DIR}{probe_name}"
    checkpoint = torch.load(probe_file_location, map_location=torch.device(DEVICE))
    linear_probe = checkpoint["linear_probe"]
    probes[layer] = linear_probe


flip_dirs = {}

piece1 = BLANK_INDEX
piece2 = black_king_index

for layer, linear_probe in probes.items():
    piece1_probe = linear_probe[:, :, r, c, piece1].squeeze()
    piece2_probe = linear_probe[:, :, r, c, piece2].squeeze()
    # Tensor.to() is not in-place, so its result must be assigned to take effect.
    flip_dir = (piece2_probe - piece1_probe).to(DEVICE)
    flip_dirs[layer] = flip_dir

def flip_hook(resid, hook, flip_dir: torch.Tensor):
    """Subtract ``flip_dir`` in place from every position of game GAME_IDX in ``resid``.

    ``hook`` is unused; the hook returns None and relies on the in-place edit. We could
    intervene on a single position only, but intervening on all of them does no harm.
    """
    resid[GAME_IDX].sub_(flip_dir)

probe_data.model.reset_hooks()

for layer, flip_dir in flip_dirs.items():
    temp_hook_fn = partial(flip_hook, flip_dir=flip_dir)
    hook_name = f"blocks.{layer}.hook_resid_post"
    probe_data.model.add_hook(hook_name, temp_hook_fn)

# Don't call probe_data.model.cpu() here: it would silently move the model off DEVICE.
# try/finally guarantees the hooks are removed even if the forward pass fails.
try:
    _, modified_cache = probe_data.model.run_with_cache(board_seqs_int.to(DEVICE)[:, :-1])
finally:
    probe_data.model.reset_hooks()
modified_resid_post = modified_cache["resid_post", layer][:, :]

print(resid_post.shape)
print(modified_resid_post.shape)

# The hook subtracts flip_dir, so clean - modified should equal flip_dir at any position (10 is arbitrary).
diff = resid_post[GAME_IDX, 10, :] - modified_resid_post[GAME_IDX, 10, :]

assert torch.allclose(diff, flip_dirs[layer], atol=1e-6)
print("Flip hook test passed")

Next, we load the model's vocab.

In [None]:
# Model vocabulary: stoi maps characters to token ids, itos maps them back.
# pickle.load can execute arbitrary code, so meta.pkl must come from a trusted source.
with open(f"{MODEL_DIR}meta.pkl", "rb") as f:
    meta = pickle.load(f)
stoi = meta["stoi"]
itos = meta["itos"]
def encode_string(s: str) -> list[int]:
    """Map each character of ``s`` to its token id using the ``stoi`` vocabulary."""
    return list(map(stoi.__getitem__, s))


def decode_list(l: list[int]) -> str:
    """Map each token id in ``l`` back to its character via ``itos`` and concatenate them."""
    return "".join(map(itos.__getitem__, l))

Later, we generate 10 characters using the model to determine the model's next move. Note that we use argmax instead of a temperature-based approach, so this will always return the most likely move.

One annoying problem we deal with: in chess, the 0th row is at the bottom, which is how print(chess_board) displays everything. But for our state stack (and any array), the 0th row is at the top.

Now, we get a pgn string up to the current move and convert it to a chess board. We use it to create an encoded model_input as well.

In [None]:
print(move_of_interest_state)

# Game prefix up to and including the "." of move_of_interest: exactly what the model sees.
pgn_string = probe_data.board_seqs_string[GAME_IDX][:move_of_interest_index + 1]
model_input = torch.tensor(encode_string(pgn_string)).unsqueeze(0).to(DEVICE)
print(model_input.shape)
board = chess_utils.pgn_string_to_board(pgn_string)

print(board)
print(board.legal_moves)

write_svg_to_file("images/board.svg", chess.svg.board(board, size=350))

We generate a move using the model on the original board and check that the move is legal. Next, we determine which piece was moved, and which row / column the source square of the move was.

In [None]:
model_move = chess_utils.get_model_move(probe_data.model, meta, model_input)
# parse_san yields a move object (not a string) despite the variable's _san suffix.
model_move_san = board.parse_san(model_move)
assert model_move_san in board.legal_moves

origin_square = model_move_san.from_square
moved_piece = board.piece_at(origin_square)
moved_piece_int = PIECE_TO_INT[moved_piece.piece_type]
moved_piece_probe_index = PIECE_TO_ONE_HOT_MAPPING[moved_piece_int]
source_square = chess.square_name(origin_square)

# State-stack (row, column) of the square the piece moved from.
r, c = chess_utils.square_to_coordinate(origin_square)
print(r, c)

print(f"Model move: {model_move_san}, moved piece: {moved_piece}, moved piece int: {moved_piece_int}, moved piece probe index: {moved_piece_probe_index}, source square: {source_square}")

Now, we create a modified board where the source square of the model's original move is blank.

In [None]:
# Ground truth with the moved piece's origin square emptied (0 is the blank piece integer).
modified_state_stack = state_stack_white_moves.clone()
modified_state_stack[0, GAME_IDX, move_of_interest, r, c] = 0
modified_move_of_interest_state = modified_state_stack[0, GAME_IDX, move_of_interest]
modified_state_stacks_one_hot = chess_utils.state_stack_to_one_hot(
    modes, config.num_rows, config.num_cols, config.min_val, config.max_val, DEVICE, modified_state_stack
)
modified_move_of_interest_state_one_hot = modified_state_stacks_one_hot[0][GAME_IDX][move_of_interest]

# The matching python-chess board with the same square emptied.
modified_board = board.copy()
modified_board.set_piece_at(model_move_san.from_square, None)
print(modified_board)
print(modified_board.legal_moves)

write_svg_to_file("images/modified_board.svg", chess.svg.board(modified_board, size=350))

for modified_tensor, reference_tensor in (
    (modified_move_of_interest_state_one_hot, move_of_interest_state_one_hot),
    (modified_state_stack, state_stack_white_moves),
    (modified_state_stacks_one_hot, state_stacks_one_hot),
):
    assert modified_tensor.shape == reference_tensor.shape

Next, we compute flip_dir = piece_probe * piece_coefficient - blank_probe * blank_coefficient, normalized to unit length. In practice, I find that it works best when blank_coefficient is 0. We subtract this flip_dir from the model's activations at every token. We generate 10 new characters using the model, and verify that the new move under this modified state is legal according to the modified state. We also save a copy of the modified activations and generate modified probe outputs.

In [None]:
_, cache = probe_data.model.run_with_cache(board_seqs_int.to(DEVICE)[:, :-1], return_type=None)
resid_post = cache["resid_post", layer][:, :]

# Steering direction: piece_probe * piece_coefficient - blank_probe * blank_coefficient,
# normalized to unit length and multiplied by `scale` inside the hook.
# In practice blank_coefficient = 0 works best.
piece_coefficient = 1.0
blank_coefficient = 0.0
scale = 1.0

piece1 = BLANK_INDEX
piece2 = moved_piece_probe_index

# Each layer's direction comes from that layer's own probe and is passed to the hook
# explicitly. The hook must not look up a global `layer` at call time: with several
# hooked layers it would apply the last layer's direction everywhere.
flip_dirs = {}
for layer, linear_probe in probes.items():
    piece1_probe = linear_probe[:, :, r, c, piece1].squeeze()
    piece2_probe = linear_probe[:, :, r, c, piece2].squeeze()
    flip_dir = (piece2_probe * piece_coefficient) - (piece1_probe * blank_coefficient)
    flip_dirs[layer] = (flip_dir / flip_dir.norm()).to(DEVICE)


def flip_hook(resid, hook, flip_dir: torch.Tensor, scale: float = 1.0):
    """Subtract ``scale * flip_dir`` in place from every position of game GAME_IDX.

    ``hook`` is unused. Intervening on every position is harmless even though only
    one position would strictly be needed.
    """
    resid[GAME_IDX, :] -= scale * flip_dir


probe_data.model.reset_hooks()

for layer, flip_dir in flip_dirs.items():
    temp_hook_fn = partial(flip_hook, flip_dir=flip_dir, scale=scale)
    hook_name = f"blocks.{layer}.hook_resid_post"
    probe_data.model.add_hook(hook_name, temp_hook_fn)
try:
    _, modified_cache = probe_data.model.run_with_cache(board_seqs_int.to(DEVICE)[:, :-1])
    modified_board_model_move = chess_utils.get_model_move(probe_data.model, meta, model_input)
finally:
    # Always detach the hooks so an error here cannot leak the intervention into later cells.
    probe_data.model.reset_hooks()
modified_resid_post = modified_cache["resid_post", layer][:, :]


print(modified_board_model_move)
# Report legality instead of asserting it, so an illegal move doesn't abort the notebook.
try:
    modified_board_model_move_san = modified_board.parse_san(modified_board_model_move)
    modified_move_is_legal = modified_board_model_move_san in modified_board.legal_moves
except ValueError:  # python-chess raises ValueError (subclasses) for illegal/invalid SAN
    modified_move_is_legal = False
print(f"Legal on the modified board: {modified_move_is_legal}")

In [None]:
# Shapes of the steering direction and the clean / modified activations.
for shown_tensor in (flip_dirs[layer], resid_post, modified_resid_post):
    print(shown_tensor.shape)

In [None]:
# Probe weights of the intervened layer, looked up explicitly rather than relying on
# `linear_probe` leaking out of an earlier loop.
layer_probe = probes[layer]

# Modified activations at every white-move position: (sample_size, num_white_moves, d_model).
move_ids = white_move_indices
batch_ids = torch.arange(move_ids.size(0), device=move_ids.device).unsqueeze(1)
stacked_modified_resid_post = modified_resid_post[batch_ids, move_ids].to(DEVICE)

assert stacked_modified_resid_post.shape == (sample_size, num_white_moves, layer_probe.shape[1])

modified_probe_out = einsum(
    "batch pos d_model, modes d_model rows cols options -> modes batch pos rows cols options",
    stacked_modified_resid_post,
    layer_probe,
)
# Same log_softmax as probe_out, so original and modified heatmaps are on the same scale.
# It is monotonic, so the decoded (argmax) board is unchanged.
modified_probe_out = modified_probe_out.log_softmax(-1)
modified_state_stacks_probe_outputs = torch.as_tensor(
    chess_utils.one_hot_to_state_stack(modified_probe_out, config.min_val)
)

Now, we can graph the original and modified board states and probe outputs.

In [None]:
from plotly.subplots import make_subplots

move_of_interest_probe_out = probe_out[0][0][move_of_interest]
move_of_interest_probe_out_modified = modified_probe_out[0][0][move_of_interest]
print(move_of_interest_probe_out.shape)

fig_rows = 6
fig_cols = 3
fig = make_subplots(rows=fig_rows, cols=fig_cols, subplot_titles=[
    "Chess board blank squares", "Probe output blank squares clip=2", "Probe output blank squares no clipping",
    "Chess board original piece", "Probe output original piece clip=5", "Probe output original piece no clipping",
    "Modified chess board blank squares", "Probe output blank squares clip=2", "Probe output blank squares no clipping",
    "Modified chess board original piece", "Probe output original piece clip=5", "Probe output original piece no clipping",
    "Chess board state", "Probe output board state", "Redundant probe output board state",
    "Modified chess board state", "Probe output board state", "Redundant probe output board state"
])


# Specify the size of each plot
plot_size = 400  # You can adjust this size

# Rows 1-4: blank squares and the moved piece, before and after the intervention.
# Columns: ground-truth one-hot plane, clipped probe output, unclipped probe output.
heatmap_rows = [
    (move_of_interest_state_one_hot, move_of_interest_probe_out, BLANK_INDEX, 2),
    (move_of_interest_state_one_hot, move_of_interest_probe_out, moved_piece_probe_index, 5),
    (modified_move_of_interest_state_one_hot, move_of_interest_probe_out_modified, BLANK_INDEX, 2),
    (modified_move_of_interest_state_one_hot, move_of_interest_probe_out_modified, moved_piece_probe_index, 5),
]
for row, (truth_one_hot, probe_planes, option_index, clip) in enumerate(heatmap_rows, start=1):
    fig.add_trace(plot_board_state(truth_one_hot[:, :, option_index]), row=row, col=1)
    fig.add_trace(plot_board_state(probe_planes[:, :, option_index], clip_size=clip), row=row, col=2)
    fig.add_trace(plot_board_state(probe_planes[:, :, option_index]), row=row, col=3)

# Rows 5-6: full boards as glyphs. The third panel ("Redundant probe output board state")
# repeats the probe output and must go to col=3, otherwise that panel stays empty.
board_rows = [
    (move_of_interest_state, state_stacks_probe_outputs[0][0][move_of_interest]),
    (modified_move_of_interest_state, modified_state_stacks_probe_outputs[0][0][move_of_interest]),
]
for row, (truth_board, probe_board) in enumerate(board_rows, start=5):
    fig.add_trace(plot_board_state_with_text(truth_board), row=row, col=1)
    fig.add_trace(plot_board_state_with_text(probe_board), row=row, col=2)
    fig.add_trace(plot_board_state_with_text(probe_board), row=row, col=3)

# Adjust the overall size of the figure
fig.update_layout(height=fig_rows * plot_size, width=fig_cols * plot_size)

# Show the figure
fig.show()

# Save SVG renders of the four boards. A dedicated variable name keeps the real `board`
# (and `modified_board`) from earlier cells intact when those cells are re-run.
svg_boards = [
    (move_of_interest_state, "images/chess_board_state.svg"),
    (state_stacks_probe_outputs[0][0][move_of_interest], "images/probe_output_board_state.svg"),
    (modified_move_of_interest_state, "images/modified_chess_board_state.svg"),
    (modified_state_stacks_probe_outputs[0][0][move_of_interest], "images/probe_output_modified_board_state.svg"),
]
for state_to_render, svg_path in svg_boards:
    rendered_board = chess.Board(matrix_to_fen(state_to_render))
    write_svg_to_file(svg_path, chess.svg.board(rendered_board, size=350))