In [1]:
import chess.svg
import einops
from pathlib import Path
import torch

from chess_gnn.models import ChessBERT
from chess_gnn.utils import PGNBoardHelper

In [2]:
ckpt = torch.load('/Users/ray/models/chess/bert/7b961c05-55bf-45a0-8090-1409a883d676/final.ckpt', map_location="cpu")
model = ChessBERT(**ckpt['hyper_parameters'])
model.load_state_dict(ckpt['state_dict'])


  ckpt = torch.load('/Users/ray/models/chess/bert/7b961c05-55bf-45a0-8090-1409a883d676/final.ckpt', map_location="cpu")
/Users/ray/miniconda3/envs/ChessGNN/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:209: Attribute 'block' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['block'])`.
/Users/ray/miniconda3/envs/ChessGNN/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:209: Attribute 'mask_handler' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['mask_handler'])`.


<All keys matched successfully>

In [3]:
pgn = PGNBoardHelper(Path('/Users/ray/Datasets/chess/Carlsen.pgn'))
pgn.get_game()
board_fens = pgn.get_board_fens()

boards_in = []
for i, board_fen in enumerate(board_fens):
    board = chess.Board(board_fen)
    boards_in.append(board)

In [4]:
from chess_gnn.utils import process_board_string
from chess_gnn.tokenizers import SimpleChessTokenizer

def prep_model_inputs(chess_board: chess.Board):
    tokenizer = SimpleChessTokenizer()
    board = process_board_string(str(chess_board))
    print(board)
    board_tokens = torch.Tensor(tokenizer.tokenize(board)).long().unsqueeze(0)
    whose_move = torch.Tensor([int(not chess_board.turn)]).long()
    
    return board_tokens, whose_move

In [5]:
def bert_mask(model: ChessBERT, board_tokens: torch.Tensor, whose_move: torch.Tensor):
    out = model.forward_mask(board_tokens, whose_move)
    mlm_preds = model.mlm_head(out['tokens'])
    
    return board_tokens, torch.argmax(mlm_preds, dim=-1), out['masked_token_labels']

In [6]:
from chess_gnn.configuration import LocalHydraConfiguration
untrained_model = ChessBERT.from_hydra_configuration(LocalHydraConfiguration('/Users/ray/Projects/ChessGNN/configs/bert/training/bert.yaml'))

/Users/ray/miniconda3/envs/ChessGNN/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:209: Attribute 'block' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['block'])`.
/Users/ray/miniconda3/envs/ChessGNN/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:209: Attribute 'mask_handler' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['mask_handler'])`.


In [8]:
tokenizer = SimpleChessTokenizer()
tokenizer.vocab

['.', 'B', 'K', 'N', 'P', 'Q', 'R', 'b', 'k', 'n', 'p', 'q', 'r']

In [7]:
labels, preds, masked = bert_mask(model, *prep_model_inputs(chess_board=boards_in[28]))

...r.rk.pp..qppp..pb.n......n.....P.......NQPN..PB...PPP...RR.K.


In [9]:
einops.rearrange(torch.eq(labels, preds), "1 (h w) -> h w", h=8)

tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])

In [10]:
torch.sum(~torch.eq(labels, preds))

tensor(0)

In [11]:
einops.rearrange(labels, "1 (h w) -> h w", h=8)

tensor([[ 0,  0,  0, 12,  0, 12,  8,  0],
        [10, 10,  0,  0, 11, 10, 10, 10],
        [ 0,  0, 10,  7,  0,  9,  0,  0],
        [ 0,  0,  0,  0,  9,  0,  0,  0],
        [ 0,  0,  4,  0,  0,  0,  0,  0],
        [ 0,  0,  3,  5,  4,  3,  0,  0],
        [ 4,  1,  0,  0,  0,  4,  4,  4],
        [ 0,  0,  0,  6,  6,  0,  2,  0]])

In [12]:
einops.rearrange(preds, "1 (h w) -> h w", h=8)

tensor([[ 0,  0,  0, 12,  0, 12,  8,  0],
        [10, 10,  0,  0, 11, 10, 10, 10],
        [ 0,  0, 10,  7,  0,  9,  0,  0],
        [ 0,  0,  0,  0,  9,  0,  0,  0],
        [ 0,  0,  4,  0,  0,  0,  0,  0],
        [ 0,  0,  3,  5,  4,  3,  0,  0],
        [ 4,  1,  0,  0,  0,  4,  4,  4],
        [ 0,  0,  0,  6,  6,  0,  2,  0]])

In [13]:
einops.rearrange(masked!=-100, "1 (h w) -> h w", h=8)

tensor([[False, False, False, False, False, False, False,  True],
        [False,  True, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False],
        [False, False, False, False,  True, False, False, False],
        [False, False, False, False, False, False, False, False],
        [False,  True, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False],
        [ True, False, False, False, False, False, False, False]])

In [99]:
torch.sum(masked!=-100)

tensor(10)