In [1]:
import torch
import chess
from models.transformer import Transformer

In [17]:
# thank you chatgpt (:

def board_to_tensor_vectorized(board: chess.Board) -> torch.Tensor:
    """
    Converts a python-chess board into an 8x8x12 PyTorch tensor using vectorized operations.

    Args:
        board (chess.Board): The chess board.

    Returns:
        torch.Tensor: An 8x8x12 tensor representing the board state.
    """
    # Initialize the empty tensor
    tensor = torch.zeros((8, 8, 12), dtype=torch.float32)

    # Get the piece map from python-chess
    piece_map = board.piece_map()  # Returns a dictionary {square: piece}

    # Extract square indices and corresponding pieces
    squares = torch.tensor(list(piece_map.keys()), dtype=torch.long)
    pieces = torch.tensor([piece.piece_type for piece in piece_map.values()], dtype=torch.long)
    colors = torch.tensor([piece.color for piece in piece_map.values()], dtype=torch.long)

    # Compute ranks (rows) and files (columns) from square indices
    ranks = squares // 8
    files = squares % 8

    # Compute the channels based on piece type and color
    channels = pieces - 1 + colors * 6  # Map pieces to channels: 0-5 (white), 6-11 (black)

    # Assign 1.0 to the corresponding positions in the tensor
    tensor[ranks, files, channels] = 1.0

    return tensor

# Example usage
board = chess.Board()
tensor = board_to_tensor_vectorized(board)
print(tensor.shape)  # Output: torch.Size([8, 8, 12])


torch.Size([8, 8, 12])


In [3]:
tensor

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [6]:
model = Transformer(
    piece_dim=12,
    board_size=64,
    n_blocks=4,
    n_heads=4,
    proj_dim=64,
)

In [7]:
model(tensor, is_white=True)

tensor([[0.4986]], grad_fn=<SigmoidBackward0>)