In [1]:
from safetensors.torch import load_file
import torch
# Load the safetensors weights
weights = load_file("chess_model_3k.safetensors", device='cuda')


In [8]:
# Piece tensors (EMB_D must match!)
piece_tensors = torch.nn.ParameterList(
    [torch.nn.Parameter(weights[f"piece_{i}"]) for i in range(15)]
)

# Unpack if needed
E, WR, WN, WB1, WQ, WK, WB2, WP, BP, BR, BN, BB1, BQ, BK, BB2 = piece_tensors
piece_label_to_tensor = {
    0: E,  1: WR, 2: WN, 3: WB1, 4: WQ, 5: WK, 6: WB2, 7: WP,
    8: BP, 9: BR, 10: BN, 11: BB1, 12: BQ, 13: BK, 14: BB2,
}

# Other weights
turn_weights = torch.nn.Parameter(weights["turn_weights"])
queryW       = torch.nn.Parameter(weights["queryW"])
keyW         = torch.nn.Parameter(weights["keyW"])
valueW       = torch.nn.Parameter(weights["valueW"])
rope_sin     = weights["rope_sin"]
rope_cos     = weights["rope_cos"]


In [9]:
def build_rope_tables(seq_len: int, dim: int, device):
    """
    Return tensors  sin, cos  of shape  [seq_len, dim//2].
    dim must be even.  These are kept out of autograd.
    """
    half = dim // 2
    inv_freq = 1.0 / (10000 ** (torch.arange(half, device=device) / half))
    ang = torch.arange(seq_len, device=device).float().unsqueeze(1) * inv_freq[None, :]
    return ang.sin(), ang.cos()          # each  [seq_len, half]


def apply_rope(x, sin, cos):
    """
    x   : [B, 64, D]
    sin : [64, D//2]   cos : [64, D//2]
    Rotates every (even,odd) channel pair in‑place and returns x.
    """
    sin = sin.unsqueeze(0)               # → [1, 64, D//2] for broadcast
    cos = cos.unsqueeze(0)

    x_even = x[..., 0::2]                # [B, 64, D//2]
    x_odd  = x[..., 1::2]                # [B, 64, D//2]

    rot_even = x_even * cos - x_odd * sin
    rot_odd  = x_even * sin + x_odd * cos

    # write back
    x[..., 0::2] = rot_even
    x[..., 1::2] = rot_odd
    return x

# ──────────────────────────────────────────────
# 4.  Helper fns
# ──────────────────────────────────────────────
def board_to_tensor(board_ids):
    # board_ids: list/tuple length 64 of ints 0‑14
    return torch.stack([piece_label_to_tensor[i] for i in board_ids])  # [64, EMB_D]

def apply_turn_mask(board_ids, board_tensor, turn):
    # only multiply pieces that belong to the side to move
    mask = [(1 <= p <= 7) if turn == "white" else (8 <= p <= 14) for p in board_ids]
    if any(mask):
        board_tensor = board_tensor.clone()
        board_tensor[mask] = board_tensor[mask] @ turn_weights
    return board_tensor

def minmax_norm(t):
    return (t - t.min()) / (t.max() - t.min() + 1e-6)


In [10]:
from torch.nn import functional as F

def model_move(board_ids, turn):
    board_tensor = board_to_tensor(board_ids)
    board_tensor = apply_turn_mask(board_ids, board_tensor, turn)
    board_tensor = minmax_norm(board_tensor).to('cuda').unsqueeze(0)  # [1, 64, EMB_D]

    Q = minmax_norm(board_tensor @ queryW)
    K = minmax_norm(board_tensor @ keyW)

    Q = apply_rope(Q, rope_sin, rope_cos)
    K = apply_rope(K, rope_sin, rope_cos)

    V = minmax_norm(board_tensor @ valueW)         # [1, 64, 64]

    attn = torch.bmm(Q, K.transpose(1, 2))          # [1, 64, 64]
    logits = (attn + V).view(1, -1)                 # [1, 4096]
    probs = F.softmax(logits, dim=-1).squeeze()

    top_idx = torch.argmax(probs).item()
    row, col = divmod(top_idx, 64)
    return row, col


In [11]:
import chess

# Mapping from python-chess pieces to your integer labels
piece_symbol_to_label = {
    None: 0,       # Empty square
    "R": 1,  "N": 2,  "B": 3,  "Q": 4,  "K": 5,  "B2": 6, "P": 7,   # White
    "p": 8,  "r": 9,  "n":10,  "b":11, "q":12, "k":13, "b2":14      # Black
}

def convert_board_to_tensor_indices(board: chess.Board):
    board_ids = []
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece is None:
            board_ids.append(0)
        else:
            symbol = piece.symbol()
            if symbol.lower() == "b":
                # Hack for double bishop slots if you want (optional)
                square_color = (square // 8 + square % 2) % 2  # just for diversity
                label = 3 if piece.color == chess.WHITE and square_color == 0 else \
                        6 if piece.color == chess.WHITE else \
                        11 if square_color == 0 else 14
            else:
                label = piece_symbol_to_label[symbol]
            board_ids.append(label)
    return board_ids  # length 64

import chess

def decode_model_move(board: chess.Board, from_idx: int, to_idx: int) -> str:
    from_square = chess.square(from_idx % 8, 7 - from_idx // 8)  # file, rank
    to_square   = chess.square(to_idx % 8, 7 - to_idx // 8)

    move = chess.Move(from_square, to_square)
    if move in board.legal_moves:
        return move.uci()
    else:
        # fallback: pick the first legal move
        print(f"Illegal move predicted: {move.uci()} — falling back.")
        return list(board.legal_moves)[0].uci()



In [6]:
import chess

board = chess.Board()

while not board.is_game_over():
    print(board)

    if board.turn == chess.WHITE:
        move = input("Your move (e.g., e2e4): ")
        try:
            board.push_uci(move)
        except:
            print("Invalid move. Try again.")
    else:
        board_ids = convert_board_to_tensor_indices(board)
        row, col = model_move(board_ids, "black")  # your inference fn
        ai_move = decode_model_move(board, row, col)
        print("AI move:", ai_move)
        board.push_uci(ai_move)

print("Game over:", board.result())


r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
P P P P P P P P
R N B Q K B N R
r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . P . . .
P P P P . P P P
R N B Q K B N R
Illegal move predicted: d1d1 — falling back.
AI move: g8h6
r n b q k b . r
p p p p p p p p
. . . . . . . n
. . . . . . . .
. . . . . . . .
. . . . P . . .
P P P P . P P P
R N B Q K B N R
r n b q k b . r
p p p p p p p p
. . . . . . . n
. . . . . . . .
. . . . P . . .
. . . . . . . .
P P P P . P P P
R N B Q K B N R
Illegal move predicted: h3g3 — falling back.
AI move: h8g8
r n b q k b r .
p p p p p p p p
. . . . . . . n
. . . . . . . .
. . . . P . . .
. . . . . . . .
P P P P . P P P
R N B Q K B N R
