In [1]:
import chess.svg
import einops
from pathlib import Path
import torch

from chess_gnn.models import ChessBERT
from chess_gnn.utils import PGNBoardHelper

In [5]:
ckpt = torch.load('/Users/ray/models/chess/bert/7b961c05-55bf-45a0-8090-1409a883d676/final.ckpt', map_location="cpu")
model = ChessBERT(**ckpt['hyper_parameters'])
model.load_state_dict(ckpt['state_dict'])


  ckpt = torch.load('/Users/ray/models/chess/electra/c43db0f7-27b9-4b38-bc58-e82756a23447/last.ckpt', map_location="cpu")


In [None]:
pgn = PGNBoardHelper(Path('/Users/ray/Datasets/chess/Carlsen.pgn'))
board_fens = pgn.get_board_fens()

boards_in = []
for i, board_fen in enumerate(board_fens):
    board = chess.Board(board_fen)
    boards_in.append(board)

In [17]:
from chess_gnn.utils import process_board_string
from chess_gnn.tokenizers import SimpleChessTokenizer

def prep_model_inputs(chess_board: chess.Board):
    tokenizer = SimpleChessTokenizer()
    board = process_board_string(str(chess_board))
    print(board)
    board_tokens = torch.Tensor(tokenizer.tokenize(board)).long().unsqueeze(0)
    whose_move = torch.Tensor([int(not chess_board.turn)]).long()
    
    return board_tokens, whose_move

In [40]:
def bert_mask(model: ChessBERT, board_tokens: torch.Tensor, whose_move: torch.Tensor):
    out = model.forward_mask(board_tokens, whose_move)
    mlm_preds = model.mlm_head(out['tokens'])
    
    return board_tokens, torch.argmax(mlm_preds, dim=-1), out['masked_token_labels']

In [53]:
from chess_gnn.configuration import LocalHydraConfiguration
untrained_model = ChessBERT.from_hydra_configuration(LocalHydraConfiguration('/Users/ray/Projects/ChessGNN/configs/bert/training/bert.yaml'))

In [100]:
labels, preds, masked = bert_mask(model, *prep_model_inputs(chess_board=boards_in[28]))

r....rk..ppqbppp.nn.p.b.p..pP....P.P.B..P.P..N.P..QNBPP.R....RK.


In [101]:
einops.rearrange(torch.eq(labels, preds), "1 (h w) -> h w", h=8)

tensor([[ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True, False,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True, False,  True, False,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True]])

In [102]:
torch.sum(~torch.eq(labels, preds))

tensor(3)

In [103]:
einops.rearrange(labels, "1 (h w) -> h w", h=8)

tensor([[12,  0,  0,  0,  0, 12,  8,  0],
        [ 0, 10, 10, 11,  7, 10, 10, 10],
        [ 0,  9,  9,  0, 10,  0,  7,  0],
        [10,  0,  0, 10,  4,  0,  0,  0],
        [ 0,  4,  0,  4,  0,  1,  0,  0],
        [ 4,  0,  4,  0,  0,  3,  0,  4],
        [ 0,  0,  5,  3,  1,  4,  4,  0],
        [ 6,  0,  0,  0,  0,  6,  2,  0]])

In [104]:
einops.rearrange(preds, "1 (h w) -> h w", h=8)

tensor([[12,  0,  0,  0,  0, 12,  8,  0],
        [ 0, 10, 10, 11,  7, 10, 10, 10],
        [ 0,  9,  9,  0, 10,  0,  0,  0],
        [10,  0,  0, 10,  4,  0,  0,  0],
        [ 0,  4,  0,  4,  0,  1,  0,  0],
        [ 4,  0,  4,  0,  0,  3,  0,  4],
        [ 0,  0,  0,  3,  5,  4,  4,  0],
        [ 6,  0,  0,  0,  0,  6,  2,  0]])

In [105]:
einops.rearrange(masked!=-100, "1 (h w) -> h w", h=8)

tensor([[False, False, False, False, False, False, False, False],
        [False, False,  True, False,  True,  True, False,  True],
        [False, False, False, False, False, False,  True, False],
        [False,  True,  True, False, False, False,  True, False],
        [False, False,  True, False, False, False, False,  True],
        [ True, False, False, False, False, False,  True, False],
        [False, False,  True, False,  True, False, False, False],
        [False, False, False, False, False, False, False, False]])

In [99]:
torch.sum(masked!=-100)

tensor(10)