In [1]:
import numpy as np
import torch

In [None]:
from datasets import load_dataset

# Load using streaming mode
dataset = load_dataset("Lichess/standard-chess-games", split="train", streaming=True)

# Get a small sample of 10 games
sample = []
for i, row in enumerate(dataset):
    sample.append(row)
    if i == 10000:
        break

In [2]:
WP = torch.randn(1000, requires_grad=True, device='cuda')

WR = torch.randn(1000, requires_grad=True, device='cuda')

WN = torch.randn(1000, requires_grad=True, device='cuda')

WB1 = torch.randn(1000, requires_grad=True, device='cuda')
WB2 = torch.randn(1000, requires_grad=True, device='cuda') 

WQ = torch.randn(1000, requires_grad=True, device='cuda')
WK = torch.randn(1000, requires_grad=True, device='cuda')

BP = torch.randn(1000, requires_grad=True, device='cuda')

BR = torch.randn(1000, requires_grad=True, device='cuda')

BN = torch.randn(1000, requires_grad=True, device='cuda')

BB1 = torch.randn(1000, requires_grad=True, device='cuda') 
BB2 = torch.randn(1000, requires_grad=True, device='cuda')

BQ = torch.randn(1000, requires_grad=True, device='cuda')
BK = torch.randn(1000, requires_grad=True, device='cuda')   

E = torch.randn(1000, requires_grad=True, device='cuda')   

In [3]:
turn_weights = torch.randn([1000, 1000], requires_grad=True, device='cuda')

In [4]:
queryW = torch.randn([1000, 1000], requires_grad=True, device='cuda')
keyW = torch.randn([1000, 1000], requires_grad=True, device='cuda')
valueW = torch.randn([1000, 64], requires_grad=True, device='cuda')

In [19]:
# final_emb = (query_embeddings @ key_embeddings.T) + value_embeddings
# final_emb.shape

In [65]:
import datetime

games = [{'Event': 'Rated Classical game',
 'Site': 'https://lichess.org/j1dkb5dw',
 'White': 'BFG9k',
 'Black': 'mamalak',
 'Result': '1-0',
 'WhiteTitle': None,
 'BlackTitle': None,
 'WhiteElo': 1639,
 'BlackElo': 1403,
 'WhiteRatingDiff': 5,
 'BlackRatingDiff': -8,
 'UTCDate': datetime.date(2012, 12, 31),
 'UTCTime': datetime.time(23, 1, 3),
 'ECO': 'C00',
 'Opening': 'French Defense: Normal Variation',
 'Termination': 'Normal',
 'TimeControl': '600+8',
 'movetext': '1. e4 e6 2. d4 b6 3. a3 Bb7 4. Nc3 Nh6 5. Bxh6 gxh6 6. Be2 Qg5 7. Bg4 h5 8. Nf3 Qg6 9. Nh4 Qg5 10. Bxh5 Qxh4 11. Qf3 Kd8 12. Qxf7 Nc6 13. Qe8# 1-0'}]

In [14]:
import re
import chess
import torch
from datetime import datetime

# ------------------- CONFIG: integer IDs for piece types ------------------- #
piece_label_to_id = {
    "WR": 1, "WN": 2, "WB1": 3, "WQ": 4, "WK": 5, "WB2": 6, "WP": 7,
    "BP": 8, "BR": 9, "BN": 10, "BB1": 11, "BQ": 12, "BK": 13, "BB2": 14,
}

# ------------------- INITIAL BOARD ------------------- #
initial_map = {
    0: "WR", 1: "WN", 2: "WB1", 3: "WQ", 4: "WK", 5: "WB2", 6: "WN", 7: "WR",
    8: "WP", 9: "WP", 10: "WP", 11: "WP", 12: "WP", 13: "WP", 14: "WP", 15: "WP",
    48: "BP", 49: "BP", 50: "BP", 51: "BP", 52: "BP", 53: "BP", 54: "BP", 55: "BP",
    56: "BR", 57: "BN", 58: "BB1", 59: "BQ", 60: "BK", 61: "BB2", 62: "BN", 63: "BR",
}

# ------------------- Move Parser ------------------- #
_move_num = re.compile(r"^\d+\.(\.\.)?$")

def san_stream(movetext: str):
    for tok in movetext.replace("\n", " ").split():
        if _move_num.match(tok) or tok in {"1-0", "0-1", "1/2-1/2", "*"}:
            continue
        yield tok

# ------------------- Convert mapping to tensor ------------------- #
def mapping_to_tensor(mapping: dict) -> torch.LongTensor:
    return torch.tensor([
        piece_label_to_id.get(mapping.get(i), 0) for i in range(64)
    ], dtype=torch.long)

# ------------------- Determine winner string ------------------- #
def result_to_winner(result: str) -> str:
    if result == "1-0":
        return "white"
    elif result == "0-1":
        return "black"
    else:
        return "draw"
    
    
def create_dataset_from_games(game_dicts: list[dict]) -> list[dict]:
    full_dataset = []

    for game in game_dicts:
        try:
            board = chess.Board()
            mapping = initial_map.copy()
            states = [mapping_to_tensor(mapping)]
            turns = ["white"]  # starting with white
            move_vectors = []

            for san in san_stream(game["movetext"]):
                move = board.parse_san(san)
                from_sq, to_sq = move.from_square, move.to_square
                move_vector = [from_sq, to_sq]

                # Update mapping
                moving_id = mapping.pop(from_sq, None)

                if to_sq in mapping:
                    mapping.pop(to_sq)

                if board.is_en_passant(move):
                    ep_target = to_sq + (-8 if board.turn else 8)
                    mapping.pop(ep_target, None)

                mapping[to_sq] = moving_id

                # Skip rook move, just rely on king's move in castling
                board.push(move)

                states.append(mapping_to_tensor(mapping))
                turns.append("white" if board.turn else "black")
                move_vectors.append(move_vector)

        except Exception:
            continue

        winner = result_to_winner(game["Result"])

        for i in range(len(states) - 1):
            full_dataset.append({
                "input": states[i],
                "output": torch.tensor(move_vectors[i], dtype=torch.long),  # output is [from_sq, to_sq]
                "turn": turns[i],
                "winner": winner
            })

    return full_dataset


In [4]:
import os


# chess_token = os.environ['HF_TOKEN']

In [6]:
import huggingface_hub
import datasets

huggingface_hub.login(HF_TOKEN)

  from .autonotebook import tqdm as notebook_tqdm
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [None]:
# ds.push_to_hub('youngchiller40/chessset')

In [7]:
ds = datasets.load_dataset('youngchiller40/chessset')

training block

In [8]:
# Split into train/test (e.g., 90/10)
split_dataset = ds['train'].train_test_split(test_size=0.2, seed=42)

train_ds = split_dataset["train"]
test_ds = split_dataset["test"]

outputW = torch.randn([4096, 64], requires_grad=True)

In [80]:
small_ds = train_ds.select(range(1000))


In [9]:
import torch

# assumes E, WR, … BB2 are already defined integers 0‑14
piece_label_to_id = {                 # <‑‑ same as you showed
    0: E, 1: WR, 2: WN, 3: WB1, 4: WQ, 5: WK, 6: WB2, 7: WP,
    8: BP, 9: BR, 10: BN, 11: BB1, 12: BQ, 13: BK, 14: BB2,
}

def board_to_tensor(board, turn, turn_weights):
    device = turn_weights.device
    board_tensor = torch.stack([
        torch.tensor(piece_label_to_id[i], device=device) for i in board
    ])
    return board_tensor


def turn_mask(board, board_tensor, turn, turn_weights):
    device = board_tensor.device  # infer device from input

    mask_ids = []
    if turn == 'white':
        mask_ids = [ix for ix, tile in enumerate(board) if 1 <= tile <= 7]
    else:
        mask_ids = [ix for ix, tile in enumerate(board) if 8 <= tile <= 14]

    turn_weights = turn_weights.to(device)  # move to same device

    board = torch.stack([
        tile @ turn_weights if ix in mask_ids else tile
        for ix, tile in enumerate(board_tensor)
    ])

    return board

In [10]:
def normalize_tensor(tensor):
    min_val = tensor.min()
    max_val = tensor.max()
    if max_val == min_val:
        return torch.zeros_like(tensor)  # or tensor.clone() if you want to preserve values
    return (tensor - min_val) / (max_val - min_val)

In [12]:
import torch
import torch.nn.functional as F

# ──────────────────────────────────────────────
# 1.  Hyper‑params & device
# ──────────────────────────────────────────────
EMB_D    = 2000      # length of each piece‑embedding vector
BATCH_SZ = 256        # drop if you still OOM
LR       = 2e-4      # tweak as needed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ──────────────────────────────────────────────
# 0.  Rotary positional embedding  (64 squares, 2000‑d vectors)
# ──────────────────────────────────────────────
def build_rope_tables(seq_len: int, dim: int, device):
    """
    Return tensors  sin, cos  of shape  [seq_len, dim//2].
    dim must be even.  These are kept out of autograd.
    """
    half = dim // 2
    inv_freq = 1.0 / (10000 ** (torch.arange(half, device=device) / half))
    ang = torch.arange(seq_len, device=device).float().unsqueeze(1) * inv_freq[None, :]
    return ang.sin(), ang.cos()          # each  [seq_len, half]

rope_sin, rope_cos = build_rope_tables(64, EMB_D, device)
rope_sin.requires_grad_(False)
rope_cos.requires_grad_(False)

def apply_rope(x, sin, cos):
    """
    x   : [B, 64, D]
    sin : [64, D//2]   cos : [64, D//2]
    Rotates every (even,odd) channel pair in‑place and returns x.
    """
    sin = sin.unsqueeze(0)               # → [1, 64, D//2] for broadcast
    cos = cos.unsqueeze(0)

    x_even = x[..., 0::2]                # [B, 64, D//2]
    x_odd  = x[..., 1::2]                # [B, 64, D//2]

    rot_even = x_even * cos - x_odd * sin
    rot_odd  = x_even * sin + x_odd * cos

    # write back
    x[..., 0::2] = rot_even
    x[..., 1::2] = rot_odd
    return x


# ──────────────────────────────────────────────
# 2.  Learnable piece embeddings (E, WP, WR, …)
#     – 15 distinct pieces including empty
# ──────────────────────────────────────────────
piece_tensors = torch.nn.ParameterList(
    [torch.nn.Parameter(torch.randn(EMB_D, device=device))
     for _ in range(15)]
)
E, WR, WN, WB1, WQ, WK, WB2, WP, BP, BR, BN, BB1, BQ, BK, BB2 = piece_tensors

piece_label_to_tensor = {
    0: E,  1: WR, 2: WN, 3: WB1, 4: WQ, 5: WK, 6: WB2, 7: WP,
    8: BP, 9: BR, 10: BN, 11: BB1, 12: BQ, 13: BK, 14: BB2,
}

# ──────────────────────────────────────────────
# 3.  Other learnable weights
# ──────────────────────────────────────────────
turn_weights = torch.nn.Parameter(torch.randn(EMB_D, EMB_D, device=device))
queryW       = torch.nn.Parameter(torch.randn(EMB_D, EMB_D, device=device))
keyW         = torch.nn.Parameter(torch.randn(EMB_D, EMB_D, device=device))
valueW       = torch.nn.Parameter(torch.randn(EMB_D,  64,    device=device))  # 64 = value dim

# bundle everything for the optimizer
model_params = list(piece_tensors) + [turn_weights, queryW, keyW, valueW]
optimizer    = torch.optim.Adam(model_params, lr=LR)

# ──────────────────────────────────────────────
# 4.  Helper fns
# ──────────────────────────────────────────────
def board_to_tensor(board_ids):
    # board_ids: list/tuple length 64 of ints 0‑14
    return torch.stack([piece_label_to_tensor[i] for i in board_ids])  # [64, EMB_D]

def apply_turn_mask(board_ids, board_tensor, turn):
    # only multiply pieces that belong to the side to move
    mask = [(1 <= p <= 7) if turn == "white" else (8 <= p <= 14) for p in board_ids]
    if any(mask):
        board_tensor = board_tensor.clone()
        board_tensor[mask] = board_tensor[mask] @ turn_weights
    return board_tensor

def minmax_norm(t):
    return (t - t.min()) / (t.max() - t.min() + 1e-6)

# ──────────────────────────────────────────────
# 5.  Training loop (batched)
# ──────────────────────────────────────────────
valid_cards = [c for c in train_ds if c["winner"] == c["turn"]]

for epoch in range(1, 6):  # run 5 epochs for demo
    total_loss = 0
    for i in range(0, len(valid_cards), BATCH_SZ):
        batch = valid_cards[i:i+BATCH_SZ]

        boards   = []
        targets  = []

        for card in batch:
            turn  = card["turn"]
            ids   = card["input"]   # list of 64 ints
            row, col = card["output"]

            bt = board_to_tensor(ids)           # [64, EMB_D]
            bt = apply_turn_mask(ids, bt, turn) # [64, EMB_D]
            bt = minmax_norm(bt).to(device)

            boards.append(bt)
            targets.append(row * 64 + col)      # scalar 0‑4095

        boards  = torch.stack(boards).to(device)   # [B, 64, EMB_D]
        targets = torch.tensor(targets, device=device)  # [B]

        # -------- forward
        Q = minmax_norm(boards @ queryW)
        K = minmax_norm(boards @ keyW)

        Q = apply_rope(Q, rope_sin, rope_cos)
        K = apply_rope(K, rope_sin, rope_cos)

        V = minmax_norm(boards @ valueW)         # [B, 64, 64]

        attn  = torch.bmm(Q, K.transpose(1, 2))  # [B, 64, 64]
        logits = (attn + V).view(boards.size(0), -1)

        loss = F.cross_entropy(logits, targets)
        total_loss += loss.item() * boards.size(0)

        # -------- backward + update
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        print(loss.item())

    avg_loss = total_loss / len(valid_cards)
    print(f"Epoch {epoch}  |  avg CE loss: {avg_loss:.4f}")

152.8197021484375
142.6471405029297
141.97628784179688
137.90016174316406
128.45675659179688
133.25181579589844
127.67395782470703
125.36485290527344
122.2296142578125
121.32923126220703
113.3207015991211
114.26789093017578
110.00043487548828
108.06700134277344
107.53797912597656
105.87357330322266
107.51192474365234
101.77266693115234
107.65882110595703
100.88972473144531
99.20147705078125
97.83201599121094
94.4180679321289
92.87915802001953
93.69343566894531
87.99940490722656
87.1670913696289
88.82514190673828
84.86175537109375
84.10306549072266
85.72481536865234
80.89652252197266
79.33987426757812
79.94363403320312
76.838134765625
74.49542999267578
72.05389404296875
73.59241485595703
73.65058135986328
72.05203247070312
69.20565032958984
69.42005157470703
68.80998229980469
64.0354995727539
64.20487976074219
68.46539306640625
67.48515319824219
64.95450592041016
62.84821319580078
61.16017532348633
59.86363220214844
61.855133056640625
59.40317153930664
59.85856628417969
59.3476486206054

KeyboardInterrupt: 

In [13]:

import safetensors


In [14]:
from safetensors.torch import save_file

# Bundle all parameters in a dict with string keys
state_dict = {
    "turn_weights": turn_weights,
    "queryW": queryW,
    "keyW": keyW,
    "valueW": valueW,
    "rope_sin": rope_sin,
    "rope_cos": rope_cos,
}

# Add each piece tensor manually
for idx, tensor in enumerate(piece_tensors):
    state_dict[f"piece_{idx}"] = tensor

# Save to file
save_file(state_dict, "chess_model.safetensors")
