Load libraries and variables.

In [None]:
import os
import json
import chess
import platform
import collections
import torch.utils.data
from utils import *
from config import *
from datetime import date
from model import ChessTransformer
from s03_encode_data import encode
from IPython.display import clear_output, Markdown
from s01_prepare_data import get_board_status

Initialize model and load checkpoint.

In [None]:
# Model
# Read the vocabulary inside a context manager so the file handle is closed
# as soon as the JSON has been parsed.
with open(os.path.join(DATA_FOLDER, VOCAB_FILE), "r") as vocab_file:
    vocabulary = json.load(vocab_file)
# Number of entries in each vocabulary, keyed the same way as the vocabulary file
vocab_sizes = {name: len(entries) for name, entries in vocabulary.items()}
model = ChessTransformer(
    vocab_sizes=vocab_sizes,
    max_move_sequence_length=MAX_MOVE_SEQUENCE_LENGTH,
    d_model=D_MODEL,
    n_heads=N_HEADS,
    d_queries=D_QUERIES,
    d_values=D_VALUES,
    d_inner=D_INNER,
    n_layers=N_LAYERS,
    dropout=DROPOUT,
)
model = model.to(DEVICE)

# Checkpoint
# map_location remaps the stored tensors onto DEVICE, so a checkpoint that was
# saved on a GPU can still be loaded on a machine that only has a CPU.
checkpoint = torch.load(
    os.path.join(CHECKPOINT_FOLDER, FINAL_CHECKPOINT), map_location=DEVICE
)
model.load_state_dict(checkpoint["model_state_dict"])
print("\nLoaded checkpoint.\n")

# Compile the model
compiled_model = torch.compile(model, mode=COMPILE_MODE, dynamic=DYNAMIC_COMPILE)
compiled_model.eval()  # eval mode disables dropout

# "Move index to move" mapping
reverse_move_vocabulary = {v: k for k, v in vocabulary["output_sequence"].items()}

Define model and human moves.

In [None]:
def make_model_move(board):
    """
    Let the model act on the given board and show the result in the notebook.

    The board status from `get_board_status` is encoded with the vocabulary
    and fed to the compiled model. The model's candidates are then tried from
    highest to lowest score, and the first one that can be acted upon is used:
    a legal move is pushed onto the board, "<loss>" resigns, and "<draw>"
    claims a draw when the board allows it.

    Args:
        board (chess.Board): The current position, with the model to move.
            Modified in place when the model plays a move.

    Returns:
        tuple: `(board, outcome)`, where `outcome` is None if the game goes
        on, or a result string ("1-0", "0-1" or "1/2-1/2") if it ended.

    Raises:
        RuntimeError: If none of the model's candidates can be acted upon.
    """
    # Encode every component of the board status, each with a leading batch dimension
    board_status = get_board_status(board)
    encoded_board_status = dict()
    for status in board_status:
        encoded_board_status[status] = torch.IntTensor(
            [encode(board_status[status], vocabulary=vocabulary[status])]
        ).to(DEVICE)
        if encoded_board_status[status].dim() == 1:
            encoded_board_status[status] = encoded_board_status[status].unsqueeze(0)

    # Decoder input: a single "<move>" token, i.e. a move sequence of length 1
    moves = (
        torch.LongTensor([vocabulary["output_sequence"]["<move>"]]).unsqueeze(0).to(DEVICE)
    )
    lengths = torch.LongTensor([1]).unsqueeze(0).to(DEVICE)

    # Pure inference: no_grad avoids recording an autograd graph on every move
    with torch.no_grad():
        with torch.autocast(device_type=DEVICE.type, dtype=torch.float16, enabled=USE_AMP):
            predicted_moves = compiled_model(
                turns=encoded_board_status["turn"],
                white_kingside_castling_rights=encoded_board_status[
                    "white_kingside_castling_rights"
                ],
                white_queenside_castling_rights=encoded_board_status[
                    "white_queenside_castling_rights"
                ],
                black_kingside_castling_rights=encoded_board_status[
                    "black_kingside_castling_rights"
                ],
                black_queenside_castling_rights=encoded_board_status[
                    "black_queenside_castling_rights"
                ],
                can_claim_draw=encoded_board_status["can_claim_draw"],
                board_positions=encoded_board_status["board_position"],
                moves=moves,
                lengths=lengths,
            )  # (N, max_move_sequence_length, move_vocab_size)

    # Scores for the first (and only) predicted position
    predicted_moves = predicted_moves[:, 0, :].squeeze()
    legal_moves = {move.uci() for move in board.legal_moves}

    # Rank the entire move vocabulary, best candidate first
    _, model_move_indices = predicted_moves.topk(k=predicted_moves.shape[0])
    for model_move_index in model_move_indices.tolist():
        model_move = reverse_move_vocabulary[model_move_index]
        if model_move == "<loss>":
            # The model resigns; board.turn is the model's own colour here
            clear_output(wait=True)
            display(Markdown("# I resign."))
            display(board)
            return board, "0-1" if board.turn else "1-0"
        if model_move == "<draw>":
            # Only honoured when a draw can actually be claimed; otherwise
            # fall through to the next candidate
            if board.can_claim_draw():
                clear_output(wait=True)
                display(Markdown(" # I claim a draw."))
                display(board)
                return board, "1/2-1/2"
        if model_move in legal_moves:
            # Build the regular message before pushing, while the human's
            # move is still the last one on the stack
            if len(board.move_stack) == 0:
                message = " #  I played ***%s***." % model_move
            else:
                message = " #  You played ***%s***. I played ***%s***." % (str(board.move_stack[-1]), model_move)
            board.push_uci(model_move)
            clear_output(wait=True)
            if board.is_checkmate():
                display(Markdown("# I played ***%s***. I win! :)" % model_move))
                display(board)
                # After the push, board.turn is the side that has been mated
                return board, "0-1" if board.turn else "1-0"
            display(Markdown(message))
            display(board)
            return board, None

    # Previously this fell off the end and returned None, which made the
    # caller's tuple unpacking fail with an unrelated-looking TypeError
    raise RuntimeError("The model did not produce a move that can be played in this position.")

def make_human_move(board):
    """
    Ask the human for a move until a usable answer is given, then apply it.

    Accepted answers are a legal move in UCI notation, or one of the commands
    "exit", "resign" and "draw". Answers are stripped of surrounding
    whitespace and lower-cased before matching, so "E2E4" or " e2e4 " are
    accepted just like "e2e4".

    Args:
        board (chess.Board): The current position, with the human to move.
            Modified in place when the human plays a move.

    Returns:
        tuple: `(board, outcome)`, where `outcome` is None if the game goes
        on, or a result string ("1-0", "0-1" or "1/2-1/2") if it ended.
    """
    legal_moves = {m.uci() for m in board.legal_moves}
    while True:
        answer = input("What move would you like to play? (UCI notation; 'exit' and 'resign' are options.)")
        # Normalize once so that moves and commands are matched the same way
        human_move = answer.strip().lower()
        clear_output(wait=True)

        if human_move in legal_moves:
            board.push_uci(human_move)
            if board.is_checkmate():
                display(Markdown("# You played ***%s***. You win! :(" % human_move))
                display(board)
                # After the push, board.turn is the side that has been mated
                return board, "0-1" if board.turn else "1-0"
            display(Markdown(" # You played ***%s***." % human_move))
            display(board)
            return board, None
        elif human_move == "exit":
            # Leaving the game is scored as a loss for the human (side to move)
            display(Markdown("# You stopped playing."))
            display(board)
            return board, "0-1" if board.turn else "1-0"
        elif human_move == "resign":
            display(Markdown("# You resigned."))
            display(board)
            return board, "0-1" if board.turn else "1-0"
        elif human_move == "draw":
            if board.can_claim_draw():
                display(Markdown("# You claimed a draw."))
                display(board)
                return board, "1/2-1/2"
            # Not claimable: say so and ask again
            display(Markdown("# You can't claim a draw right now."))
            display(board)
        else:
            # Neither a legal move nor a known command: say so and ask again
            display(Markdown("# ***%s*** isn't a valid move." % human_move))
            display(board)

Run the model once to trigger compilation. This takes a few moments.

In [None]:
%%capture
# Warm-up: have the model make one move on a fresh board so that compilation is
# triggered here rather than during the game. The returned board and outcome are
# discarded, and %%capture keeps the cell's output out of the notebook.
_, __ = make_model_move(chess.Board())

Play a game of chess with the model.

In [None]:
def play():
    """
    Play one game between the human and the model.

    The human is asked for a colour until "w" or "b" is given (surrounding
    whitespace and letter case are ignored). White moves first; from then on
    the model and the human alternate until a move function reports an
    outcome or the board says the game is over.

    Returns:
        tuple: `(board, outcome, human_color)`, where `board` is the final
        chess.Board, `outcome` is a result string ("1-0", "0-1" or "1/2-1/2")
        or None if no move function reported one, and `human_color` is
        always the lower-case "w" or "b".
    """
    # Normalize the answer so later comparisons against "w"/"b" work even if
    # the user typed "W" or "B"
    human_color = None
    while human_color not in ("w", "b"):
        human_color = input("Do you want to play white (w) or black (b)?").strip().lower()

    board = chess.Board()
    outcome = None

    # If the human has the white pieces, they open the game
    if human_color == "w":
        display(Markdown("# Make the first move."))
        display(board)
        board, outcome = make_human_move(board)

    # Alternate: model first, then the human, while the game is still going
    while outcome is None and not board.is_game_over():
        board, outcome = make_model_move(board)
        if outcome is None and not board.is_game_over():
            board, outcome = make_human_move(board)

    return board, outcome, human_color


board, outcome, human_color = play()

Print the game in PGN format in case you wish to save it.

In [None]:
def get_pgn():
    """
    Build a PGN game record from the finished game.

    Reads the notebook-level variables `board`, `outcome` and `human_color`
    that the game cell assigns.

    Returns:
        chess.pgn.Game: The game with all moves from `board`'s move stack and
        the Result, Event, Site, Date, Round, White and Black headers filled
        in. Printing it yields the PGN text.
    """
    # `import chess` alone does not load the pgn submodule, so import it here
    import chess.pgn

    # Creates the game from the board's move stack; the board is left as it was
    game = chess.pgn.Game.from_board(board)

    # Prefer the outcome reported during play (resignation, claimed draw, ...)
    # over the result the board can derive from the position alone
    game.headers["Result"] = outcome if outcome else board.result()
    game.headers["Event"] = "You vs. Chess Transformer"
    game.headers["Site"] = platform.node()
    # PGN dates are written as YYYY.MM.DD
    game.headers["Date"] = date.today().strftime("%Y.%m.%d")
    game.headers["Round"] = "1"
    # Compare case-insensitively so that an answer of "W" or "B" is attributed correctly
    color = human_color.lower()
    game.headers["White"] = "You" if color == "w" else "Chess Transformer"
    game.headers["Black"] = "You" if color == "b" else "Chess Transformer"

    return game


print(get_pgn())