In [1]:
import numpy as np
import torch

In [None]:
from datasets import load_dataset

# Load using streaming mode
dataset = load_dataset("Lichess/standard-chess-games", split="train", streaming=True)

# Get a small sample of 10 games
sample = []
for i, row in enumerate(dataset):
    sample.append(row)
    if i == 10000:
        break

In [2]:
WP = torch.randn(1000, requires_grad=True, device='cuda')

WR = torch.randn(1000, requires_grad=True, device='cuda')

WN = torch.randn(1000, requires_grad=True, device='cuda')

WB1 = torch.randn(1000, requires_grad=True, device='cuda')
WB2 = torch.randn(1000, requires_grad=True, device='cuda') 

WQ = torch.randn(1000, requires_grad=True, device='cuda')
WK = torch.randn(1000, requires_grad=True, device='cuda')

BP = torch.randn(1000, requires_grad=True, device='cuda')

BR = torch.randn(1000, requires_grad=True, device='cuda')

BN = torch.randn(1000, requires_grad=True, device='cuda')

BB1 = torch.randn(1000, requires_grad=True, device='cuda') 
BB2 = torch.randn(1000, requires_grad=True, device='cuda')

BQ = torch.randn(1000, requires_grad=True, device='cuda')
BK = torch.randn(1000, requires_grad=True, device='cuda')   

E = torch.randn(1000, requires_grad=True, device='cuda')   

In [3]:
turn_weights = torch.randn([1000, 1000], requires_grad=True, device='cuda')

In [4]:
queryW = torch.randn([1000, 1000], requires_grad=True, device='cuda')
keyW = torch.randn([1000, 1000], requires_grad=True, device='cuda')
valueW = torch.randn([1000, 64], requires_grad=True, device='cuda')

In [19]:
# final_emb = (query_embeddings @ key_embeddings.T) + value_embeddings
# final_emb.shape

In [65]:
import datetime

games = [{'Event': 'Rated Classical game',
 'Site': 'https://lichess.org/j1dkb5dw',
 'White': 'BFG9k',
 'Black': 'mamalak',
 'Result': '1-0',
 'WhiteTitle': None,
 'BlackTitle': None,
 'WhiteElo': 1639,
 'BlackElo': 1403,
 'WhiteRatingDiff': 5,
 'BlackRatingDiff': -8,
 'UTCDate': datetime.date(2012, 12, 31),
 'UTCTime': datetime.time(23, 1, 3),
 'ECO': 'C00',
 'Opening': 'French Defense: Normal Variation',
 'Termination': 'Normal',
 'TimeControl': '600+8',
 'movetext': '1. e4 e6 2. d4 b6 3. a3 Bb7 4. Nc3 Nh6 5. Bxh6 gxh6 6. Be2 Qg5 7. Bg4 h5 8. Nf3 Qg6 9. Nh4 Qg5 10. Bxh5 Qxh4 11. Qf3 Kd8 12. Qxf7 Nc6 13. Qe8# 1-0'}]

In [14]:
import re
import chess
import torch
from datetime import datetime

# ------------------- CONFIG: integer IDs for piece types ------------------- #
piece_label_to_id = {
    "WR": 1, "WN": 2, "WB1": 3, "WQ": 4, "WK": 5, "WB2": 6, "WP": 7,
    "BP": 8, "BR": 9, "BN": 10, "BB1": 11, "BQ": 12, "BK": 13, "BB2": 14,
}

# ------------------- INITIAL BOARD ------------------- #
initial_map = {
    0: "WR", 1: "WN", 2: "WB1", 3: "WQ", 4: "WK", 5: "WB2", 6: "WN", 7: "WR",
    8: "WP", 9: "WP", 10: "WP", 11: "WP", 12: "WP", 13: "WP", 14: "WP", 15: "WP",
    48: "BP", 49: "BP", 50: "BP", 51: "BP", 52: "BP", 53: "BP", 54: "BP", 55: "BP",
    56: "BR", 57: "BN", 58: "BB1", 59: "BQ", 60: "BK", 61: "BB2", 62: "BN", 63: "BR",
}

# ------------------- Move Parser ------------------- #
_move_num = re.compile(r"^\d+\.(\.\.)?$")

def san_stream(movetext: str):
    for tok in movetext.replace("\n", " ").split():
        if _move_num.match(tok) or tok in {"1-0", "0-1", "1/2-1/2", "*"}:
            continue
        yield tok

# ------------------- Convert mapping to tensor ------------------- #
def mapping_to_tensor(mapping: dict) -> torch.LongTensor:
    return torch.tensor([
        piece_label_to_id.get(mapping.get(i), 0) for i in range(64)
    ], dtype=torch.long)

# ------------------- Determine winner string ------------------- #
def result_to_winner(result: str) -> str:
    if result == "1-0":
        return "white"
    elif result == "0-1":
        return "black"
    else:
        return "draw"
    
    
def create_dataset_from_games(game_dicts: list[dict]) -> list[dict]:
    full_dataset = []

    for game in game_dicts:
        try:
            board = chess.Board()
            mapping = initial_map.copy()
            states = [mapping_to_tensor(mapping)]
            turns = ["white"]  # starting with white
            move_vectors = []

            for san in san_stream(game["movetext"]):
                move = board.parse_san(san)
                from_sq, to_sq = move.from_square, move.to_square
                move_vector = [from_sq, to_sq]

                # Update mapping
                moving_id = mapping.pop(from_sq, None)

                if to_sq in mapping:
                    mapping.pop(to_sq)

                if board.is_en_passant(move):
                    ep_target = to_sq + (-8 if board.turn else 8)
                    mapping.pop(ep_target, None)

                mapping[to_sq] = moving_id

                # Skip rook move, just rely on king's move in castling
                board.push(move)

                states.append(mapping_to_tensor(mapping))
                turns.append("white" if board.turn else "black")
                move_vectors.append(move_vector)

        except Exception:
            continue

        winner = result_to_winner(game["Result"])

        for i in range(len(states) - 1):
            full_dataset.append({
                "input": states[i],
                "output": torch.tensor(move_vectors[i], dtype=torch.long),  # output is [from_sq, to_sq]
                "turn": turns[i],
                "winner": winner
            })

    return full_dataset


loading dataset


In [1]:
import huggingface_hub
import torch
import datasets
import os

huggingface_hub.login(os.environ['HF_TOKEN'])

  from .autonotebook import tqdm as notebook_tqdm
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [2]:
ds = datasets.load_dataset('youngchiller40/chessset')

In [3]:
# Split into train/test (e.g., 90/10)
split_dataset = ds['train'].train_test_split(test_size=0.2, seed=42)

train_ds = split_dataset["train"]
test_ds = split_dataset["test"]

training block

In [8]:
import torch
import torch.nn.functional as F

# ──────────────────────────────────────────────
# 1.  Hyper‑params & device
# ──────────────────────────────────────────────
EMB_D    = 1024
BATCH_SZ = 512
LR       = 2e-4
device   = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ──────────────────────────────────────────────
# 2.  Rotary Positional Embedding
# ──────────────────────────────────────────────
def build_rope_tables(seq_len: int, dim: int, device):
    half = dim // 2
    inv_freq = 1.0 / (10000 ** (torch.arange(half, device=device) / half))
    ang = torch.arange(seq_len, device=device).float().unsqueeze(1) * inv_freq[None, :]
    return ang.sin(), ang.cos()

rope_sin, rope_cos = build_rope_tables(64, EMB_D, device)
rope_sin.requires_grad_(False)
rope_cos.requires_grad_(False)

def apply_rope(x, sin, cos):
    sin = sin.unsqueeze(0)
    cos = cos.unsqueeze(0)
    x_even = x[..., 0::2]
    x_odd  = x[..., 1::2]
    rot_even = x_even * cos - x_odd * sin
    rot_odd  = x_even * sin + x_odd * cos
    x[..., 0::2] = rot_even
    x[..., 1::2] = rot_odd
    return x

# ──────────────────────────────────────────────
# 3.  Learnable Weights
# ──────────────────────────────────────────────
piece_tensors = torch.nn.ParameterList([
    torch.nn.Parameter(torch.randn(EMB_D, device=device)) for _ in range(15)
])
turn_weights = torch.nn.Parameter(torch.randn(EMB_D, EMB_D, device=device))
queryW       = torch.nn.Parameter(torch.randn(EMB_D, EMB_D, device=device))
keyW         = torch.nn.Parameter(torch.randn(EMB_D, EMB_D, device=device))
valueW       = torch.nn.Parameter(torch.randn(EMB_D, EMB_D, device=device))

# MLP Head: outputs 64 logits per square
mlp_head = torch.nn.Sequential(
    torch.nn.LayerNorm(EMB_D),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.1),
    torch.nn.Linear(EMB_D, 64)
).to(device)

# Optimizer
model_params = list(piece_tensors) + [turn_weights, queryW, keyW, valueW] + list(mlp_head.parameters())
optimizer    = torch.optim.Adam(model_params, lr=LR)

# ──────────────────────────────────────────────
# 4.  Helper Functions
# ──────────────────────────────────────────────
piece_label_to_tensor = {i: piece_tensors[i] for i in range(15)}

def board_to_tensor(board_ids):
    return torch.stack([piece_label_to_tensor[i] for i in board_ids])

def apply_turn_mask(board_ids, board_tensor, turn):
    mask = [(1 <= p <= 7) if turn == "white" else (8 <= p <= 14) for p in board_ids]
    if any(mask):
        board_tensor = board_tensor.clone()
        board_tensor[mask] = board_tensor[mask] @ turn_weights
    return board_tensor

def minmax_norm(t):
    return (t - t.min()) / (t.max() - t.min() + 1e-6)

# ──────────────────────────────────────────────
# 5.  Training Loop
# ──────────────────────────────────────────────
def train_model(train_ds, epochs=5):
    valid_cards = [c for c in train_ds if c["winner"] == c["turn"]]

    for epoch in range(1, epochs + 1):
        total_loss = 0
        for i in range(0, len(valid_cards), BATCH_SZ):
            batch = valid_cards[i:i+BATCH_SZ]
            boards = []
            targets = []

            for card in batch:
                turn = card["turn"]
                ids = card["input"]
                row, col = card["output"]

                bt = board_to_tensor(ids)
                bt = apply_turn_mask(ids, bt, turn)
                bt = minmax_norm(bt).to(device)

                boards.append(bt)
                targets.append(row * 64 + col)

            boards  = torch.stack(boards).to(device)               # [B, 64, EMB_D]
            targets = torch.tensor(targets, device=device)         # [B]

            Q = minmax_norm(boards @ queryW)
            K = minmax_norm(boards @ keyW)
            V = boards @ valueW

            Q = apply_rope(Q, rope_sin, rope_cos)
            K = apply_rope(K, rope_sin, rope_cos)

            # Normalize
            Q = torch.nn.functional.layer_norm(Q, (EMB_D,))
            K = torch.nn.functional.layer_norm(K, (EMB_D,))
            V = torch.nn.functional.layer_norm(V, (EMB_D,))

            attn_weights = torch.bmm(Q, K.transpose(1, 2)) / EMB_D**0.5
            attn_out = torch.bmm(attn_weights.softmax(dim=-1), V)

            logits_per_square = mlp_head(attn_out)      # [B, 64, 64]
            logits = logits_per_square.view(boards.size(0), -1)  # [B, 4096]

            loss = F.cross_entropy(logits, targets)
            total_loss += loss.item() * boards.size(0)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            print(f"Loss: {loss.item():.4f}")

        avg_loss = total_loss / len(valid_cards)
        print(f"Epoch {epoch} | Avg Loss: {avg_loss:.4f}")


In [9]:
train_model(train_ds, 1)

Loss: 8.3911
Loss: 8.3788
Loss: 8.3576
Loss: 8.3013
Loss: 8.2600
Loss: 8.2483
Loss: 8.2566
Loss: 8.1601
Loss: 8.2021
Loss: 8.1980
Loss: 8.1554
Loss: 8.1453
Loss: 8.1204
Loss: 8.1052
Loss: 8.2068
Loss: 8.1087
Loss: 8.1614
Loss: 8.1249
Loss: 8.1497
Loss: 8.0826
Loss: 8.1276
Loss: 8.1154
Loss: 8.1464
Loss: 8.1264
Loss: 8.1286
Loss: 8.0979
Loss: 8.0896
Loss: 8.1569
Loss: 8.0828
Loss: 8.0853
Loss: 8.0954
Loss: 8.0041
Loss: 8.0696
Loss: 8.0422
Loss: 8.1046
Loss: 8.0586
Loss: 8.0431
Loss: 8.0100
Loss: 8.0700
Loss: 8.0167
Loss: 8.0161
Loss: 8.0215
Loss: 8.0480
Loss: 8.0402
Loss: 8.0179
Loss: 8.0089
Loss: 8.0258
Loss: 8.0087
Loss: 7.9814
Loss: 8.0021
Loss: 7.9700
Loss: 7.9736
Loss: 7.9737
Loss: 7.8579
Loss: 7.9301
Loss: 7.9264
Loss: 7.9307
Loss: 7.9264
Loss: 7.8791
Loss: 7.9198
Loss: 7.9122
Loss: 7.8644
Loss: 7.8889
Loss: 7.8929
Loss: 7.8948
Loss: 7.8834
Loss: 7.8417
Loss: 7.8200
Loss: 7.8103
Loss: 7.8478
Loss: 7.8485
Loss: 7.8144
Loss: 7.7908
Loss: 7.8188
Loss: 7.7358
Loss: 7.7207
Loss: 7.7730

KeyboardInterrupt: 

trial block

In [29]:
# chess_move_transformer_split_bishops_hf.py
import torch, random
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict
from datasets import Dataset  # type hint

# ---------------- Hyper‑params ----------------
D_MODEL, N_HEADS, DEPTH = 1024, 8, 6
FFN_MULT, BATCH_SZ, LR, EPOCHS = 4, 4, 2e-4, 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# =============== Model definition ===============
class ChessBlock(nn.Module):
    def __init__(self, d=D_MODEL, heads=N_HEADS, ffn_mult=FFN_MULT):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = nn.MultiheadAttention(d, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d)
        self.ffn = nn.Sequential(
            nn.Linear(d, ffn_mult * d), nn.GELU(), nn.Linear(ffn_mult * d, d)
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x), self.ln1(x), self.ln1(x))[0]
        x = x + self.ffn(self.ln2(x))
        return x


class ChessMoveModel(nn.Module):
    def __init__(self, d=D_MODEL, heads=N_HEADS, layers=DEPTH):
        super().__init__()
        self.piece_emb = nn.Embedding(15, d)
        self.square_emb = nn.Embedding(64, d)
        self.turn_weights = nn.Parameter(torch.randn(d, d))
        self.blocks = nn.ModuleList([ChessBlock(d, heads) for _ in range(layers)])
        self.head = nn.Sequential(nn.LayerNorm(d), nn.Linear(d, 64))

    def forward(self, board_ids, turn_mask):          # board_ids [B,64]
        sq = torch.arange(64, device=board_ids.device)
        x = self.piece_emb(board_ids) + self.square_emb(sq)         # [B,64,d]
        x = torch.where(turn_mask.unsqueeze(-1), x @ self.turn_weights, x)
        for blk in self.blocks:
            x = blk(x)
        return self.head(x).view(x.size(0), -1)                     # [B,4096]


# =============== Helper utilities ===============
def make_turn_mask(board_batch: torch.Tensor, turns):
    white_owned = (board_batch >= 1) & (board_batch <= 7)
    black_owned = (board_batch >= 8)
    masks = [
        white_owned[i] if (t == "white" or t == 0) else black_owned[i]
        for i, t in enumerate(turns)
    ]
    return torch.stack(masks)


def hf_batch_iter(ds: Dataset, bs: int):
    """Iterate over a HF Dataset in shuffled mini‑batches."""
    ds = ds.shuffle(seed=random.randint(0, 1_000_000))
    for start in range(0, len(ds), bs):
        yield ds.select(range(start, min(start + bs, len(ds))))


def list_batch_iter(data: List[Dict], bs: int):
    random.shuffle(data)
    for i in range(0, len(data), bs):
        yield data[i : i + bs]


# =============== Training loop ===============
def train(model: nn.Module, dataset):
    optim = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    model.train()

    # If it's a HF Dataset, ask it to output torch tensors for input/output
    if isinstance(dataset, Dataset):
        dataset = dataset.with_format(type="torch", columns=["input", "output"])

    for epoch in range(1, EPOCHS + 1):
        total, correct, agg_loss = 0, 0, 0.0

        batcher = hf_batch_iter(dataset, BATCH_SZ) if isinstance(dataset, Dataset) \
                  else list_batch_iter(dataset, BATCH_SZ)

        for batch in batcher:
            if isinstance(batch, Dataset):               # Hugging‑Face case
                boards  = batch["input"].to(DEVICE)      # already LongTensor [B,64]
                outs    = batch["output"].to(DEVICE)     # [B,2]
                turns   = batch["turn"]                  # still python list/ndarray
            else:                                        # list-of-dicts case
                boards  = torch.stack([b["input"] for b in batch]).to(DEVICE)
                outs    = torch.stack([b["output"] for b in batch]).to(DEVICE)
                turns   = [b["turn"] for b in batch]

            targets = (outs[:, 0] * 64 + outs[:, 1]).to(DEVICE)          # [B]
            turn_mask = make_turn_mask(boards, turns).to(DEVICE)

            logits = model(boards, turn_mask)
            loss   = F.cross_entropy(logits, targets)

            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()

            agg_loss += loss.item() * boards.size(0)
            print(f"Loss: {loss.item():.4f}")
            total    += boards.size(0)
            correct  += (logits.argmax(dim=-1) == targets).sum().item()

        print(f"Epoch {epoch}/{EPOCHS} | loss {agg_loss/total:.4f} | acc {100*correct/total:.2f}%")


# # =============== Main ===============
# if __name__ == "__main__":
#     # ---------------------------------------------------------------------
#     # Supply either:
#     #   • dataset: List[Dict]    (python list)
#     #   • dataset: datasets.Dataset  (already built, with split bishops)
#     # ---------------------------------------------------------------------
#     dataset = ...  # ← insert your Dataset object or python‑list here
#     print("Dataset size:", len(dataset))

#     model = ChessMoveModel().to(DEVICE)
#     train(model, dataset)

#     # torch.save(model.state_dict(), "chess_split_bishops_transformer.pt")


In [22]:
model = ChessMoveModel().to('cuda')

In [28]:
train(model, train_ds)

Loss: 6.8266
Loss: 4.4609
Loss: 5.1473
Loss: 9.8373
Loss: 6.2641
Loss: 2.4945
Loss: 5.1563
Loss: 6.4276
Loss: 6.4022
Loss: 4.1561
Loss: 4.5891
Loss: 8.7551
Loss: 4.5865
Loss: 6.0529
Loss: 6.2152
Loss: 5.9307
Loss: 5.3208
Loss: 3.3181
Loss: 6.2891
Loss: 3.1643
Loss: 4.8770
Loss: 4.2761
Loss: 5.2825
Loss: 5.6304
Loss: 8.7402
Loss: 7.2991
Loss: 4.4057
Loss: 6.7966
Loss: 4.2684
Loss: 3.0531
Loss: 6.1960
Loss: 7.8238
Loss: 6.4094
Loss: 3.8770
Loss: 11.6176
Loss: 7.6597
Loss: 8.8364
Loss: 7.1482
Loss: 4.4741
Loss: 6.1272
Loss: 4.9766
Loss: 4.1329
Loss: 4.5895
Loss: 3.1525
Loss: 6.6231
Loss: 3.8012
Loss: 2.5324
Loss: 5.7185
Loss: 9.1834
Loss: 6.0745
Loss: 6.8239
Loss: 7.1409
Loss: 5.4639
Loss: 4.4346
Loss: 4.2987
Loss: 6.6895
Loss: 6.4287
Loss: 5.2054
Loss: 4.3897
Loss: 6.3267
Loss: 6.6747
Loss: 5.1607
Loss: 3.9339
Loss: 4.5935
Loss: 4.6025
Loss: 5.0308
Loss: 5.2208
Loss: 5.7789
Loss: 5.6781
Loss: 6.0125
Loss: 3.5026
Loss: 3.8318
Loss: 4.5714
Loss: 5.3498
Loss: 2.8557
Loss: 3.5446
Loss: 6.327

KeyboardInterrupt: 

saving model

In [10]:

import safetensors


In [12]:
from safetensors.torch import save_file, save_model

# Remove these lines (they cause shared memory issues)
# "E": E, "WR": WR, ..., "BB2": BB2

# Use only the indexed loop version:
state_dict = {
    "turn_weights": turn_weights,
    "queryW": queryW,
    "keyW": keyW,
    "valueW": valueW,
    "rope_sin": rope_sin,
    "rope_cos": rope_cos,
    "mlp_head": mlp_head
    
}

for idx, tensor in enumerate(piece_tensors):
    state_dict[f"piece_{idx}"] = tensor

save_file(state_dict, "chess_model_3k.safetensors")



ValueError: Key `mlp_head` is invalid, expected torch.Tensor but received <class 'torch.nn.modules.container.Sequential'>

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [16]:
import torch
torch.cuda.set_device(0)


In [17]:
torch.cuda.is_available()

True