In [1]:
import datetime

games = [{'Event': 'Rated Classical game',
 'Site': 'https://lichess.org/j1dkb5dw',
 'White': 'BFG9k',
 'Black': 'mamalak',
 'Result': '1-0',
 'WhiteTitle': None,
 'BlackTitle': None,
 'WhiteElo': 1639,
 'BlackElo': 1403,
 'WhiteRatingDiff': 5,
 'BlackRatingDiff': -8,
 'UTCDate': datetime.date(2012, 12, 31),
 'UTCTime': datetime.time(23, 1, 3),
 'ECO': 'C00',
 'Opening': 'French Defense: Normal Variation',
 'Termination': 'Normal',
 'TimeControl': '600+8',
 'movetext': '1. e4 e6 2. d4 b6 3. a3 Bb7 4. Nc3 Nh6 5. Bxh6 gxh6 6. Be2 Qg5 7. Bg4 h5 8. Nf3 Qg6 9. Nh4 Qg5 10. Bxh5 Qxh4 11. Qf3 Kd8 12. Qxf7 Nc6 13. Qe8# 1-0'}]

loading dataset


In [2]:
import huggingface_hub
import torch
import datasets
import os

huggingface_hub.login(os.environ['HF_TOKEN'])

  from .autonotebook import tqdm as notebook_tqdm
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [3]:
ds = datasets.load_dataset('youngchiller40/chess-combined-dataset')

Downloading data: 100%|██████████| 23/23 [00:23<00:00,  1.03s/files]
Generating train split: 100%|██████████| 40885661/40885661 [00:26<00:00, 1551716.07 examples/s]


In [4]:
# Split into train/test (e.g., 90/10)
split_dataset = ds['train'].train_test_split(test_size=0.2, seed=42)

train_ds = split_dataset["train"]
test_ds = split_dataset["test"]

# legacy training block

In [8]:
import torch
import torch.nn.functional as F

# ──────────────────────────────────────────────
# 1.  Hyper‑params & device
# ──────────────────────────────────────────────
EMB_D    = 1024
BATCH_SZ = 512
LR       = 2e-4
device   = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ──────────────────────────────────────────────
# 2.  Rotary Positional Embedding
# ──────────────────────────────────────────────
def build_rope_tables(seq_len: int, dim: int, device):
    half = dim // 2
    inv_freq = 1.0 / (10000 ** (torch.arange(half, device=device) / half))
    ang = torch.arange(seq_len, device=device).float().unsqueeze(1) * inv_freq[None, :]
    return ang.sin(), ang.cos()

rope_sin, rope_cos = build_rope_tables(64, EMB_D, device)
rope_sin.requires_grad_(False)
rope_cos.requires_grad_(False)

def apply_rope(x, sin, cos):
    sin = sin.unsqueeze(0)
    cos = cos.unsqueeze(0)
    x_even = x[..., 0::2]
    x_odd  = x[..., 1::2]
    rot_even = x_even * cos - x_odd * sin
    rot_odd  = x_even * sin + x_odd * cos
    x[..., 0::2] = rot_even
    x[..., 1::2] = rot_odd
    return x

# ──────────────────────────────────────────────
# 3.  Learnable Weights
# ──────────────────────────────────────────────
piece_tensors = torch.nn.ParameterList([
    torch.nn.Parameter(torch.randn(EMB_D, device=device)) for _ in range(15)
])
turn_weights = torch.nn.Parameter(torch.randn(EMB_D, EMB_D, device=device))
queryW       = torch.nn.Parameter(torch.randn(EMB_D, EMB_D, device=device))
keyW         = torch.nn.Parameter(torch.randn(EMB_D, EMB_D, device=device))
valueW       = torch.nn.Parameter(torch.randn(EMB_D, EMB_D, device=device))

# MLP Head: outputs 64 logits per square
mlp_head = torch.nn.Sequential(
    torch.nn.LayerNorm(EMB_D),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.1),
    torch.nn.Linear(EMB_D, 64)
).to(device)

# Optimizer
model_params = list(piece_tensors) + [turn_weights, queryW, keyW, valueW] + list(mlp_head.parameters())
optimizer    = torch.optim.Adam(model_params, lr=LR)

# ──────────────────────────────────────────────
# 4.  Helper Functions
# ──────────────────────────────────────────────
piece_label_to_tensor = {i: piece_tensors[i] for i in range(15)}

def board_to_tensor(board_ids):
    return torch.stack([piece_label_to_tensor[i] for i in board_ids])

def apply_turn_mask(board_ids, board_tensor, turn):
    mask = [(1 <= p <= 7) if turn == "white" else (8 <= p <= 14) for p in board_ids]
    if any(mask):
        board_tensor = board_tensor.clone()
        board_tensor[mask] = board_tensor[mask] @ turn_weights
    return board_tensor

def minmax_norm(t):
    return (t - t.min()) / (t.max() - t.min() + 1e-6)

# ──────────────────────────────────────────────
# 5.  Training Loop
# ──────────────────────────────────────────────
def train_model(train_ds, epochs=5):
    valid_cards = [c for c in train_ds if c["winner"] == c["turn"]]

    for epoch in range(1, epochs + 1):
        total_loss = 0
        for i in range(0, len(valid_cards), BATCH_SZ):
            batch = valid_cards[i:i+BATCH_SZ]
            boards = []
            targets = []

            for card in batch:
                turn = card["turn"]
                ids = card["input"]
                row, col = card["output"]

                bt = board_to_tensor(ids)
                bt = apply_turn_mask(ids, bt, turn)
                bt = minmax_norm(bt).to(device)

                boards.append(bt)
                targets.append(row * 64 + col)

            boards  = torch.stack(boards).to(device)               # [B, 64, EMB_D]
            targets = torch.tensor(targets, device=device)         # [B]

            Q = minmax_norm(boards @ queryW)
            K = minmax_norm(boards @ keyW)
            V = boards @ valueW

            Q = apply_rope(Q, rope_sin, rope_cos)
            K = apply_rope(K, rope_sin, rope_cos)

            # Normalize
            Q = torch.nn.functional.layer_norm(Q, (EMB_D,))
            K = torch.nn.functional.layer_norm(K, (EMB_D,))
            V = torch.nn.functional.layer_norm(V, (EMB_D,))

            attn_weights = torch.bmm(Q, K.transpose(1, 2)) / EMB_D**0.5
            attn_out = torch.bmm(attn_weights.softmax(dim=-1), V)

            logits_per_square = mlp_head(attn_out)      # [B, 64, 64]
            logits = logits_per_square.view(boards.size(0), -1)  # [B, 4096]

            loss = F.cross_entropy(logits, targets)
            total_loss += loss.item() * boards.size(0)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            print(f"Loss: {loss.item():.4f}")

        avg_loss = total_loss / len(valid_cards)
        print(f"Epoch {epoch} | Avg Loss: {avg_loss:.4f}")


In [None]:
train_model(train_ds, 1)

# trial blocks 

In [10]:

model = ChessMoveModel()
model.load_state_dict(torch.hub.load_state_dict_from_url(
    'https://huggingface.co/youngchiller40/chess-move-transformer/resolve/main/chess_move_model.pt',
    map_location="cpu"
))
model.eval()


Downloading: "https://huggingface.co/youngchiller40/chess-move-transformer/resolve/main/chess_move_model.pt" to /root/.cache/torch/hub/checkpoints/chess_move_model.pt


100%|██████████| 293M/293M [00:07<00:00, 41.5MB/s] 


ChessMoveModel(
  (piece_emb): Embedding(15, 1024)
  (square_emb): Embedding(64, 1024)
  (blocks): ModuleList(
    (0-5): 6 x ChessBlock(
      (ln1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
      )
      (ln2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (ffn): Sequential(
        (0): Linear(in_features=1024, out_features=4096, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=4096, out_features=1024, bias=True)
      )
    )
  )
  (head): Sequential(
    (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=1024, out_features=64, bias=True)
  )
)

In [13]:
# chess_move_transformer_split_bishops_hf.py
import torch, random
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict
from datasets import Dataset  # type hint
import tqdm

# ---------------- Hyper‑params ----------------
D_MODEL, N_HEADS, DEPTH = 1024, 8, 6
FFN_MULT, BATCH_SZ, LR, EPOCHS = 4, 1111, 2e-4, 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# =============== Model definition ===============
class ChessBlock(nn.Module):
    def __init__(self, d=D_MODEL, heads=N_HEADS, ffn_mult=FFN_MULT):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = nn.MultiheadAttention(d, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d)
        self.ffn = nn.Sequential(
            nn.Linear(d, ffn_mult * d), nn.GELU(), nn.Linear(ffn_mult * d, d)
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x), self.ln1(x), self.ln1(x))[0]
        x = x + self.ffn(self.ln2(x))
        return x


class ChessMoveModel(nn.Module):
    def __init__(self, d=D_MODEL, heads=N_HEADS, layers=DEPTH):
        super().__init__()
        self.piece_emb = nn.Embedding(15, d)
        self.square_emb = nn.Embedding(64, d)
        self.turn_weights = nn.Parameter(torch.randn(d, d))
        self.blocks = nn.ModuleList([ChessBlock(d, heads) for _ in range(layers)])
        self.head = nn.Sequential(nn.LayerNorm(d), nn.Linear(d, 64))

    def forward(self, board_ids, turn_mask):          # board_ids [B,64]
        sq = torch.arange(64, device=board_ids.device)
        x = self.piece_emb(board_ids) + self.square_emb(sq)         # [B,64,d]
        x = torch.where(turn_mask.unsqueeze(-1), x @ self.turn_weights, x)
        for blk in self.blocks:
            x = blk(x)
        return self.head(x).view(x.size(0), -1)                     # [B,4096]


# =============== Helper utilities ===============
def make_turn_mask(board_batch: torch.Tensor, turns):
    white_owned = (board_batch >= 1) & (board_batch <= 7)
    black_owned = (board_batch >= 8)
    masks = [
        white_owned[i] if (t == "white" or t == 0) else black_owned[i]
        for i, t in enumerate(turns)
    ]
    return torch.stack(masks)


def hf_batch_iter(ds: Dataset, bs: int):
    """Iterate over a HF Dataset in shuffled mini‑batches."""
    ds = ds.shuffle(seed=random.randint(0, 1_000_000))
    for start in range(0, len(ds), bs):
        yield ds.select(range(start, min(start + bs, len(ds))))


def list_batch_iter(data: List[Dict], bs: int):
    random.shuffle(data)
    for i in range(0, len(data), bs):
        yield data[i : i + bs]


# =============== Training loop ===============
def train(model: nn.Module, dataset):
    optim = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    model.train()

    # If it's a HF Dataset, ask it to output torch tensors for input/output
    if isinstance(dataset, Dataset):
        dataset = dataset.with_format(type="torch", columns=["input", "output"])

    for epoch in range(1, EPOCHS + 1):
        total, correct, agg_loss = 0, 0, 0.0

        batcher = hf_batch_iter(dataset, BATCH_SZ) if isinstance(dataset, Dataset) \
                  else list_batch_iter(dataset, BATCH_SZ)

        total_batches = len(dataset) // BATCH_SZ
        pbar = tqdm.tqdm(batcher, desc=f"Epoch {epoch}/{EPOCHS}", total=total_batches)
        for batch_num, batch in enumerate(pbar):
            if isinstance(batch, Dataset):               # Hugging‑Face case
                boards  = batch["input"].to(DEVICE)      # already LongTensor [B,64]
                outs    = batch["output"].to(DEVICE)     # [B,2]
                turns   = batch["turn"]                  # still python list/ndarray
            else:                                        # list-of-dicts case
                boards  = torch.stack([b["input"] for b in batch]).to(DEVICE)
                outs    = torch.stack([b["output"] for b in batch]).to(DEVICE)
                turns   = [b["turn"] for b in batch]

            targets = (outs[:, 0] * 64 + outs[:, 1]).to(DEVICE)          # [B]
            turn_mask = make_turn_mask(boards, turns).to(DEVICE)

            logits = model(boards, turn_mask)
            loss   = F.cross_entropy(logits, targets)

            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()

            agg_loss += loss.item() * boards.size(0)
            total    += boards.size(0)
            correct  += (logits.argmax(dim=-1) == targets).sum().item()
            
            # Update progress bar with current loss and percentage complete
            percent_complete = 100 * (batch_num + 1) / total_batches
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'complete': f'{percent_complete:.1f}%'})

        print(f"Epoch {epoch}/{EPOCHS} | loss {agg_loss/total:.4f} | acc {100*correct/total:.2f}%")

In [11]:
model.to('cuda')

ChessMoveModel(
  (piece_emb): Embedding(15, 1024)
  (square_emb): Embedding(64, 1024)
  (blocks): ModuleList(
    (0-5): 6 x ChessBlock(
      (ln1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
      )
      (ln2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (ffn): Sequential(
        (0): Linear(in_features=1024, out_features=4096, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=4096, out_features=1024, bias=True)
      )
    )
  )
  (head): Sequential(
    (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=1024, out_features=64, bias=True)
  )
)

In [14]:
train(model, train_ds)

Epoch 1/1:  20%|██        | 6019/29440 [4:10:25<16:14:26,  2.50s/it, loss=2.7011, complete=20.4%]


KeyboardInterrupt: 

In [15]:
# Save model weights (recommended)
torch.save(model.state_dict(), "chess_move_model.pt")


In [16]:
model_config = {
    "d_model": D_MODEL,
    "n_heads": N_HEADS,
    "depth": DEPTH,
    "ffn_mult": FFN_MULT
}

import json
with open("config.json", "w") as f:
    json.dump(model_config, f)


In [17]:
from huggingface_hub import HfApi, upload_file

# Upload model weights
upload_file(
    path_or_fileobj="chess_move_model.pt",
    path_in_repo="chess_move_model.pt",
    repo_id="youngchiller40/chess-move-transformer",
    repo_type="model"
)

# Upload config
upload_file(
    path_or_fileobj="config.json",
    path_in_repo="config.json",
    repo_id="youngchiller40/chess-move-transformer",
    repo_type="model"
)


chess_move_model.pt: 100%|██████████| 307M/307M [00:11<00:00, 26.3MB/s] 
No files have been modified since last commit. Skipping to prevent empty commit.


CommitInfo(commit_url='https://huggingface.co/youngchiller40/chess-move-transformer/commit/d38c786a7604c0c595fd6f10eaf320d04cb2cc2f', commit_message='Upload config.json with huggingface_hub', commit_description='', oid='d38c786a7604c0c595fd6f10eaf320d04cb2cc2f', pr_url=None, repo_url=RepoUrl('https://huggingface.co/youngchiller40/chess-move-transformer', endpoint='https://huggingface.co', repo_type='model', repo_id='youngchiller40/chess-move-transformer'), pr_revision=None, pr_num=None)