In [7]:
import chess.svg
import einops
import numpy as np
from pathlib import Path

from chess_gnn.inference import ChessBoardPredictor
from chess_gnn.models import ChessXAttnEngine, ChessBERT
from chess_gnn.utils import PGNBoardHelper, ChessPoint

In [2]:
class AttentionMapGetter:
    """Expose 8x8 board-shaped views of a single attention matrix.

    Token index 0 is treated as the CLS token; board squares are looked up at
    ``ChessPoint.to_str_position() + 1``.
    """

    def __init__(self, attn_matrix: np.ndarray):
        # Kept as passed in; no copy is made.
        # assumes a square (1 + 64) x (1 + 64) token-by-token matrix — TODO confirm
        self.attn_matrix = attn_matrix

    def get_attention_map(self, query: str):
        """Return the attention associated with ``query`` as an 8x8 grid.

        Args:
            query: ``'cls'`` (any casing) for the CLS token, otherwise a square
                name accepted by ``ChessPoint.from_square``.

        Returns:
            The ``query`` row of the transposed attention matrix, with the
            leading (CLS) entry removed, reshaped to 8 rows and flipped
            vertically.
        """
        # NOTE(review): indexing the *transposed* matrix selects a column of
        # `attn_matrix`; whether that is "attention from" or "attention to" the
        # token depends on the layout the predictor returns — confirm.
        token_idx = (
            0
            if query.lower() == 'cls'
            else ChessPoint.from_square(query).to_str_position() + 1
        )

        per_token = self.attn_matrix.transpose()[token_idx]
        board_only = per_token[1:]  # drop the CLS slot, keep the board tokens
        grid = einops.rearrange(board_only, '(h w) -> h w', h=8)
        return np.flipud(grid)
        

In [3]:
import dash
from dash import dcc, html, Input, Output, State
import plotly.graph_objects as go
import chess
import chess.svg
import base64

def create_game_attention_app(boards: list[chess.Board], attn_maps: list[AttentionMapGetter]):
    """Build a Dash app for stepping through a game and inspecting attention maps.

    The app shows one board position at a time as a heatmap with the pieces
    drawn on top. "Prev"/"Next" move through ``boards``; clicking a heatmap cell
    or the "CLS" button selects which token's attention map is displayed.

    Args:
        boards: Board positions, one per step of the game.
        attn_maps: One ``AttentionMapGetter`` per entry of ``boards``.

    Returns:
        The configured ``dash.Dash`` application (not yet running).

    Raises:
        AssertionError: If ``boards`` and ``attn_maps`` differ in length.
        ValueError: If ``boards`` is empty.
    """
    assert len(boards) == len(attn_maps), "Boards and attention maps must be same length"
    if len(boards) == 0:
        # The layout renders boards[0] immediately, so an empty game cannot be shown.
        raise ValueError("At least one board is required to build the attention app")

    app = dash.Dash(__name__)
    num_moves = len(boards)
    default_square = "a1"
    default_idx = 0

    # A piece image depends only on the piece symbol, so each data URI is built
    # once and reused across squares, moves and callback invocations.
    piece_uri_cache: dict[str, str] = {}

    def square_label(square: str) -> str:
        """Text shown under the board for the currently selected token."""
        return f"Viewing attention from: {square}"

    def piece_data_uri(piece: chess.Piece) -> str:
        """Return a base64 SVG data URI for ``piece``, cached by piece symbol."""
        symbol = piece.symbol()
        uri = piece_uri_cache.get(symbol)
        if uri is None:
            svg_bytes = chess.svg.piece(piece).encode('utf-8')
            uri = f"data:image/svg+xml;base64,{base64.b64encode(svg_bytes).decode('utf-8')}"
            piece_uri_cache[symbol] = uri
        return uri

    def get_svg_images(board: chess.Board):
        """Build the Plotly layout images that draw every piece of ``board``."""
        images = []
        for idx in range(64):
            piece = board.piece_at(idx)
            if not piece:
                continue
            point = ChessPoint.from_1d(idx)
            images.append(dict(
                source=piece_data_uri(piece),
                xref="x", yref="y",
                x=point.x,
                y=point.y,
                sizex=0.9, sizey=0.9,
                xanchor="center", yanchor="middle",
                layer="above"
            ))
        return images

    def create_figure(board: chess.Board, attention: np.ndarray, highlight_square: str | None = None):
        """Render ``attention`` as a heatmap under the pieces of ``board``.

        ``highlight_square`` may be a square name (a red box is drawn on that
        cell), ``'cls'`` (a red band is drawn below the board) or ``None``.
        """
        fig = go.Figure(data=go.Heatmap(
            z=attention,
            x=list(range(8)),
            y=list(range(8)),
            colorscale='Viridis',
            hoverongaps=False,
            opacity=0.5
        ))
        fig.update_layout(
            yaxis=dict(scaleanchor="x", scaleratio=1),
            xaxis=dict(constrain='domain'),
            images=get_svg_images(board),
            shapes=[]
        )

        if not highlight_square:
            return fig

        # Match 'cls' case-insensitively, like AttentionMapGetter.get_attention_map does,
        # so both sides agree on what counts as the CLS selection.
        if highlight_square.lower() == 'cls':
            # Band below the board to indicate that CLS is selected.
            marker = dict(
                type="rect",
                x0=-0.5, y0=-1,
                x1=7.5, y1=-0.6,
                line=dict(color="red", width=2),
                fillcolor="rgba(255,0,0,0.2)",
                layer="above"
            )
        else:
            col, row = ChessPoint.from_square(highlight_square)
            # Plotly rect corners: x0,y0 is bottom left, x1,y1 is top right
            marker = dict(
                type="rect",
                x0=col - 0.5, y0=row - 0.5,
                x1=col + 0.5, y1=row + 0.5,
                line=dict(color="red", width=1),
                fillcolor="rgba(255,0,0,0.2)",
                layer="above"
            )
        fig.update_layout(shapes=[marker])

        return fig

    app.layout = html.Div([
        html.Div([
            html.Button("Prev", id="prev-btn", n_clicks=0),
            html.Button("Next", id="next-btn", n_clicks=0),
            html.Span(id="move-label", style={"marginLeft": "1rem"}),
        ], style={"marginBottom": "1rem"}),

        dcc.Store(id="current-index", data=0),
        dcc.Store(id="current-square", data=default_square),

        dcc.Graph(id='heatmap', figure=create_figure(
            boards[default_idx],
            attn_maps[default_idx].get_attention_map(default_square),
            highlight_square=default_square,
        )),

        html.Div([
            html.Button("CLS", id="cls-btn", n_clicks=0, style={"marginTop": "10px"})
        ]),

        # update_square skips its initial call, so seed the label here; otherwise
        # it stays blank until the first click.
        html.Div(id='click-output', children=square_label(default_square))
    ])

    @app.callback(
        Output('current-index', 'data'),
        Output('move-label', 'children'),
        Input('prev-btn', 'n_clicks'),
        Input('next-btn', 'n_clicks'),
        State('current-index', 'data')
    )
    def update_index(prev, nxt, current):
        """Step the move index backwards/forwards, clamped to the valid range."""
        ctx = dash.callback_context.triggered_id
        if ctx == 'prev-btn':
            current = max(0, current - 1)
        elif ctx == 'next-btn':
            current = min(num_moves - 1, current + 1)
        return current, f"Move: {current}/{num_moves - 1}"

    @app.callback(
        Output('current-square', 'data'),
        Output('click-output', 'children'),
        Input('heatmap', 'clickData'),
        Input('cls-btn', 'n_clicks'),
        State('current-square', 'data'),
        prevent_initial_call=True
    )
    def update_square(clickData, cls_clicks, current_square):
        """Select the token to inspect from a heatmap click or the CLS button."""
        trigger = dash.callback_context.triggered_id
        if trigger == 'cls-btn':
            return 'cls', square_label('cls')
        if trigger == 'heatmap' and clickData and 'points' in clickData:
            point = clickData['points'][0]
            # Heatmap x/y are 0-based file/rank indices; square names use 1-based ranks.
            file = chr(ord('a') + int(point['x']))
            rank = int(point['y']) + 1
            square = f"{file}{rank}"
            return square, square_label(square)
        # Nothing usable was clicked: keep the current selection.
        return current_square, square_label(current_square)

    @app.callback(
        Output('heatmap', 'figure'),
        Input('current-index', 'data'),
        Input('current-square', 'data')
    )
    def update_figure(move_idx, square):
        """Redraw the board and attention map for the selected move and token."""
        board = boards[move_idx]
        attention = attn_maps[move_idx].get_attention_map(square)
        return create_figure(board, attention, highlight_square=square)

    return app

In [8]:
# --- Configuration: everything a reader is likely to change lives here. ---
PGN_PATH = Path('/Users/ray/Datasets/chess/Carlsen.pgn')
CHECKPOINT_PATH = '/Users/ray/models/chess/engines/57797a5d-3f0c-4629-9db2-4fddf71cb7d7/last.ckpt'
GAMES_TO_READ = 2   # number of pgn.get_game() calls made before extracting FENs
ATTN_LAYER = 11     # encoder layer whose attention is visualised
ATTN_HEAD = 15      # attention head within that layer

pgn = PGNBoardHelper(PGN_PATH)

# Only the encoder of the checkpointed engine is used for attention extraction.
model = ChessXAttnEngine.load_from_checkpoint(CHECKPOINT_PATH)
encoder = model.get_encoder()
predictor = ChessBoardPredictor(encoder=encoder)

# presumably each get_game() call advances to the next game in the PGN, so the
# FENs below belong to game number GAMES_TO_READ — confirm against PGNBoardHelper
for _ in range(GAMES_TO_READ):
    pgn.get_game()

board_fens = pgn.get_board_fens()

# One board and one attention getter per position, kept index-aligned for the app.
boards_in = []
attns_in = []
for board_fen in board_fens:
    board = chess.Board(board_fen)
    attn_matrix = predictor.get_attn_at_head_and_layer(
        chess_board=board, layer=ATTN_LAYER, head=ATTN_HEAD)
    boards_in.append(board)
    attns_in.append(AttentionMapGetter(attn_matrix))


Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.


Attribute 'decoder_layer' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['decoder_layer'])`.



In [9]:
# Build the Dash app from the index-aligned boards / attention getters prepared
# above, then start its server.
app = create_game_attention_app(boards_in, attns_in)
app.run()

In [17]:
import torch

from chess_gnn.utils import process_board_string
from chess_gnn.tokenizers import SimpleChessTokenizer

def prep_model_inputs(chess_board: chess.Board):
    """Convert a board into the ``(board_tokens, whose_move)`` tensors for the model.

    Args:
        chess_board: Position to encode.

    Returns:
        A tuple ``(board_tokens, whose_move)``:
        ``board_tokens`` is a long tensor of the tokenized board string with a
        leading batch dimension of 1; ``whose_move`` is a one-element long
        tensor holding ``int(not chess_board.turn)`` (python-chess uses ``True``
        for White, so White to move gives 0).

    Side effects:
        Prints the processed board string.
    """
    tok = SimpleChessTokenizer()
    flat_board = process_board_string(str(chess_board))
    print(flat_board)

    token_ids = tok.tokenize(flat_board)
    side_flag = int(not chess_board.turn)

    return (
        torch.Tensor(token_ids).long().unsqueeze(0),  # add the batch dimension
        torch.Tensor([side_flag]).long(),
    )

In [40]:
def bert_mask(model: "ChessBERT", board_tokens: torch.Tensor, whose_move: torch.Tensor):
    """Run the model's masked forward pass and decode its MLM predictions.

    The forward pass runs under ``torch.no_grad()``: this helper is only used
    for inspection, so no autograd graph needs to be recorded.

    Args:
        model: Object exposing ``forward_mask(board_tokens, whose_move)`` and
            ``mlm_head(tokens)``.
        board_tokens: Token ids for the board(s).
        whose_move: Side-to-move indicator passed through to ``forward_mask``.

    Returns:
        A tuple ``(board_tokens, predictions, masked_token_labels)`` where
        ``predictions`` is the argmax over the last dimension of the MLM head
        output and ``masked_token_labels`` is taken verbatim from the
        ``forward_mask`` result.
    """
    # NOTE(review): the model's train/eval mode is left untouched here; if the
    # model contains dropout, call model.eval() beforehand — confirm.
    with torch.no_grad():
        out = model.forward_mask(board_tokens, whose_move)
        mlm_preds = model.mlm_head(out['tokens'])
        predictions = torch.argmax(mlm_preds, dim=-1)

    return board_tokens, predictions, out['masked_token_labels']

In [53]:
from chess_gnn.configuration import LocalHydraConfiguration
# Instantiate a ChessBERT directly from the local Hydra training config
# (no checkpoint is loaded here).
# NOTE(review): `untrained_model` is not referenced by any later cell visible in
# this notebook — presumably meant as a baseline for bert_mask; confirm or remove.
untrained_model = ChessBERT.from_hydra_configuration(LocalHydraConfiguration('/Users/ray/Projects/ChessGNN/configs/bert/training/bert.yaml'))

In [100]:
# Mask-and-predict on position 28 of the loaded game. `labels` are the original
# board tokens, `preds` the argmax MLM predictions, `masked` the masked-token labels.
# NOTE(review): in the visible cells `model` is bound to a ChessXAttnEngine
# checkpoint, while bert_mask is annotated for ChessBERT — confirm this does not
# rely on `model` having been rebound in a cell that is no longer in the notebook.
labels, preds, masked = bert_mask(model, *prep_model_inputs(chess_board=boards_in[28]))

r....rk..ppqbppp.nn.p.b.p..pP....P.P.B..P.P..N.P..QNBPP.R....RK.


In [101]:
# Per-square agreement between the original tokens and the predictions, as an 8x8 grid.
einops.rearrange(torch.eq(labels, preds), "1 (h w) -> h w", h=8)

tensor([[ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True, False,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True, False,  True, False,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True]])

In [102]:
# Number of squares where the prediction differs from the original token.
torch.sum(~torch.eq(labels, preds))

tensor(3)

In [103]:
# Original board token ids laid out as an 8x8 grid.
einops.rearrange(labels, "1 (h w) -> h w", h=8)

tensor([[12,  0,  0,  0,  0, 12,  8,  0],
        [ 0, 10, 10, 11,  7, 10, 10, 10],
        [ 0,  9,  9,  0, 10,  0,  7,  0],
        [10,  0,  0, 10,  4,  0,  0,  0],
        [ 0,  4,  0,  4,  0,  1,  0,  0],
        [ 4,  0,  4,  0,  0,  3,  0,  4],
        [ 0,  0,  5,  3,  1,  4,  4,  0],
        [ 6,  0,  0,  0,  0,  6,  2,  0]])

In [104]:
# Predicted token ids laid out as an 8x8 grid.
einops.rearrange(preds, "1 (h w) -> h w", h=8)

tensor([[12,  0,  0,  0,  0, 12,  8,  0],
        [ 0, 10, 10, 11,  7, 10, 10, 10],
        [ 0,  9,  9,  0, 10,  0,  0,  0],
        [10,  0,  0, 10,  4,  0,  0,  0],
        [ 0,  4,  0,  4,  0,  1,  0,  0],
        [ 4,  0,  4,  0,  0,  3,  0,  4],
        [ 0,  0,  0,  3,  5,  4,  4,  0],
        [ 6,  0,  0,  0,  0,  6,  2,  0]])

In [105]:
# Squares whose masked-token label is not -100, as an 8x8 grid.
# presumably -100 is the "ignore" label for positions that were not masked — confirm
einops.rearrange(masked!=-100, "1 (h w) -> h w", h=8)

tensor([[False, False, False, False, False, False, False, False],
        [False, False,  True, False,  True,  True, False,  True],
        [False, False, False, False, False, False,  True, False],
        [False,  True,  True, False, False, False,  True, False],
        [False, False,  True, False, False, False, False,  True],
        [ True, False, False, False, False, False,  True, False],
        [False, False,  True, False,  True, False, False, False],
        [False, False, False, False, False, False, False, False]])

In [99]:
# Number of positions whose masked-token label is not -100.
# NOTE(review): this cell's execution count (99) predates the bert_mask call
# (100), so the displayed output comes from an earlier mask draw and does not
# match the 14 True entries in the mask shown at In [105] — re-run to refresh.
torch.sum(masked!=-100)

tensor(10)