# Welcome to Modal notebooks!

Write Python code and collaborate in real time. Your code runs in Modal's
**serverless cloud**, and anyone in the same workspace can join.

This notebook comes with some common Python libraries installed. Run
cells with `Shift+Enter`.

In [1]:
# ===================== SAFE INSTALL (KAGGLE-COMPATIBLE) =====================
!pip install -U transformers tokenizers accelerate python-chess tqdm --quiet

# ===================== ENV =====================
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# ===================== IMPORTS =====================
import torch
import chess.pgn
import io
import random
import json
from tqdm import tqdm
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup

print("Torch:", torch.__version__)
print("CUDA:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

# ===================== CONFIG =====================
@dataclass
class Config:
    base_model = "lazy-guy12/chess-llama"
    pgn_path = "/root/Anand.pgn"
    output_dir = "/root/anand_chess_modelv3"

    max_seq_length = 128
    batch_size = 5
    epochs = 15
    lr = 1e-5
    warmup_ratio = 0.1
    weight_decay = 0.01

    device = "cuda" if torch.cuda.is_available() else "cpu"

cfg = Config()

# ===================== LOAD PGN =====================
def load_anand_games(path):
    games = []
    with open(path, "r") as f:
        pgn_io = io.StringIO(f.read())

    while True:
        game = chess.pgn.read_game(pgn_io)
        if game is None:
            break

        h = game.headers
        if "Anand" not in h.get("White","") and "Anand" not in h.get("Black",""):
            continue

        moves = [m.uci() for m in game.mainline_moves()]
        if len(moves) < 5:
            continue

        games.append(f'[Result "{h.get("Result","*")}"] ' + " ".join(moves))

    return games

games = load_anand_games(cfg.pgn_path)
random.shuffle(games)

split = int(0.85 * len(games))
train_data = games[:split]
val_data = games[split:]

print("Games used:", len(games))

# ===================== DATASET =====================
class ChessDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        enc = self.tokenizer(
            self.data[idx],
            max_length=cfg.max_seq_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )

        input_ids = enc["input_ids"].squeeze()
        attention_mask = enc["attention_mask"].squeeze()

        labels = input_ids.clone()
        labels[attention_mask == 0] = -100   # critical fix

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

# ===================== TOKENIZER =====================
tokenizer = AutoTokenizer.from_pretrained(
    cfg.base_model,
    trust_remote_code=True
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ===================== MODEL (SAFE LOAD) =====================
model = AutoModelForCausalLM.from_pretrained(
    cfg.base_model,
    torch_dtype=torch.float32,
    trust_remote_code=True,
    device_map=None
)

model = model.to(cfg.device)
torch.cuda.empty_cache()

print("Model loaded successfully")

# ===================== LOADERS =====================
train_loader = DataLoader(
    ChessDataset(train_data, tokenizer),
    batch_size=cfg.batch_size,
    shuffle=True
)

val_loader = DataLoader(
    ChessDataset(val_data, tokenizer),
    batch_size=cfg.batch_size
)

optimizer = AdamW(
    model.parameters(),
    lr=cfg.lr,
    weight_decay=cfg.weight_decay
)

total_steps = len(train_loader) * cfg.epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    int(total_steps * cfg.warmup_ratio),
    total_steps
)

# ===================== TRAIN =====================
best_val = float("inf")

for epoch in range(cfg.epochs):
    model.train()
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} Train"):
        batch = {k:v.to(cfg.device) for k,v in batch.items()}
        out = model(**batch)
        loss = out.loss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            batch = {k:v.to(cfg.device) for k,v in batch.items()}
            val_loss += model(**batch).loss.item()

    val_loss /= len(val_loader)
    print(f"Epoch {epoch+1} | Val Loss: {val_loss:.4f}")

    if val_loss < best_val:
        best_val = val_loss
        model.save_pretrained(cfg.output_dir)
        tokenizer.save_pretrained(cfg.output_dir)

# ===================== GENERATION (SAFE) =====================
def generate_move(moves):
    text = '[Result "*"] ' + " ".join(moves)
    inputs = tokenizer(text, return_tensors="pt")
    inputs = {k:v.to(cfg.device) for k,v in inputs.items() if k != "token_type_ids"}

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=6,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id
        )

    gen = tokenizer.decode(out[0], skip_special_tokens=True).split()
    for m in gen[len(text.split()):]:
        if len(m) == 4 and m.isalnum():
            return m
    return None

print("Test move:", generate_move(["e2e4","e7e5","g1f3","b8c6"]))





[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Torch: 2.8.0+cu129
CUDA: True
GPU: NVIDIA A100 80GB PCIe
Games used: 4177


tokenizer_config.json:   0%|          | 0.00/945 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/681 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/92.0M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

Model loaded successfully


Epoch 1 Train:   0%|                                                                   | 0/710 [00:00<?, ?it/s]Epoch 1 Train:   0%|                                                           | 1/710 [00:00<10:00,  1.18it/s]Epoch 1 Train:   1%|▎                                                          | 4/710 [00:00<02:18,  5.11it/s]Epoch 1 Train:   1%|▌                                                          | 7/710 [00:01<01:19,  8.85it/s]Epoch 1 Train:   1%|▊                                                         | 10/710 [00:01<00:57, 12.16it/s]Epoch 1 Train:   2%|█                                                         | 13/710 [00:01<00:46, 14.86it/s]Epoch 1 Train:   2%|█▎                                                        | 16/710 [00:01<00:40, 17.08it/s]Epoch 1 Train:   3%|█▌                                                        | 19/710 [00:01<00:36, 18.84it/s]Epoch 1 Train:   3%|█▊                                                        | 22/710 [00:01<00:34, 20

Epoch 1 | Val Loss: 1.8038


Epoch 2 Train:   0%|                                                                   | 0/710 [00:00<?, ?it/s]Epoch 2 Train:   0%|▏                                                          | 3/710 [00:00<00:30, 23.13it/s]Epoch 2 Train:   1%|▍                                                          | 6/710 [00:00<00:29, 23.64it/s]Epoch 2 Train:   1%|▋                                                          | 9/710 [00:00<00:29, 23.89it/s]Epoch 2 Train:   2%|▉                                                         | 12/710 [00:00<00:29, 23.93it/s]Epoch 2 Train:   2%|█▏                                                        | 15/710 [00:00<00:29, 23.70it/s]Epoch 2 Train:   3%|█▍                                                        | 18/710 [00:00<00:29, 23.72it/s]Epoch 2 Train:   3%|█▋                                                        | 21/710 [00:00<00:28, 23.85it/s]Epoch 2 Train:   3%|█▉                                                        | 24/710 [00:01<00:28, 23

Epoch 2 | Val Loss: 1.6988


Epoch 3 Train:   0%|                                                                   | 0/710 [00:00<?, ?it/s]Epoch 3 Train:   0%|▏                                                          | 3/710 [00:00<00:30, 23.26it/s]Epoch 3 Train:   1%|▍                                                          | 6/710 [00:00<00:29, 23.65it/s]Epoch 3 Train:   1%|▋                                                          | 9/710 [00:00<00:29, 23.76it/s]Epoch 3 Train:   2%|▉                                                         | 12/710 [00:00<00:29, 23.86it/s]Epoch 3 Train:   2%|█▏                                                        | 15/710 [00:00<00:29, 23.89it/s]Epoch 3 Train:   3%|█▍                                                        | 18/710 [00:00<00:28, 23.88it/s]Epoch 3 Train:   3%|█▋                                                        | 21/710 [00:00<00:28, 23.85it/s]Epoch 3 Train:   3%|█▉                                                        | 24/710 [00:01<00:28, 23

Epoch 3 | Val Loss: 1.6438


Epoch 4 Train:   0%|                                                                   | 0/710 [00:00<?, ?it/s]Epoch 4 Train:   0%|▏                                                          | 3/710 [00:00<00:30, 23.48it/s]Epoch 4 Train:   1%|▍                                                          | 6/710 [00:00<00:29, 23.75it/s]Epoch 4 Train:   1%|▋                                                          | 9/710 [00:00<00:29, 23.85it/s]Epoch 4 Train:   2%|▉                                                         | 12/710 [00:00<00:29, 23.94it/s]Epoch 4 Train:   2%|█▏                                                        | 15/710 [00:00<00:29, 23.90it/s]Epoch 4 Train:   3%|█▍                                                        | 18/710 [00:00<00:28, 23.92it/s]Epoch 4 Train:   3%|█▋                                                        | 21/710 [00:00<00:28, 23.85it/s]Epoch 4 Train:   3%|█▉                                                        | 24/710 [00:01<00:28, 23

Epoch 4 | Val Loss: 1.6203


Epoch 5 Train:   0%|                                                                   | 0/710 [00:00<?, ?it/s]Epoch 5 Train:   0%|▏                                                          | 3/710 [00:00<00:30, 23.21it/s]Epoch 5 Train:   1%|▍                                                          | 6/710 [00:00<00:29, 23.52it/s]Epoch 5 Train:   1%|▋                                                          | 9/710 [00:00<00:29, 23.53it/s]Epoch 5 Train:   2%|▉                                                         | 12/710 [00:00<00:29, 23.63it/s]Epoch 5 Train:   2%|█▏                                                        | 15/710 [00:00<00:29, 23.73it/s]Epoch 5 Train:   3%|█▍                                                        | 18/710 [00:00<00:29, 23.78it/s]Epoch 5 Train:   3%|█▋                                                        | 21/710 [00:00<00:28, 23.85it/s]Epoch 5 Train:   3%|█▉                                                        | 24/710 [00:01<00:28, 23

Epoch 5 | Val Loss: 1.6089


Epoch 6 Train:   0%|                                                                   | 0/710 [00:00<?, ?it/s]Epoch 6 Train:   0%|▏                                                          | 3/710 [00:00<00:30, 23.11it/s]Epoch 6 Train:   1%|▍                                                          | 6/710 [00:00<00:29, 23.58it/s]Epoch 6 Train:   1%|▋                                                          | 9/710 [00:00<00:29, 23.63it/s]Epoch 6 Train:   2%|▉                                                         | 12/710 [00:00<00:29, 23.77it/s]Epoch 6 Train:   2%|█▏                                                        | 15/710 [00:00<00:29, 23.65it/s]Epoch 6 Train:   3%|█▍                                                        | 18/710 [00:00<00:29, 23.71it/s]Epoch 6 Train:   3%|█▋                                                        | 21/710 [00:00<00:29, 23.76it/s]Epoch 6 Train:   3%|█▉                                                        | 24/710 [00:01<00:28, 23

Epoch 6 | Val Loss: 1.6028


Epoch 7 Train:   0%|                                                                   | 0/710 [00:00<?, ?it/s]Epoch 7 Train:   0%|▏                                                          | 3/710 [00:00<00:30, 23.21it/s]Epoch 7 Train:   1%|▍                                                          | 6/710 [00:00<00:29, 23.68it/s]Epoch 7 Train:   1%|▋                                                          | 9/710 [00:00<00:29, 23.86it/s]Epoch 7 Train:   2%|▉                                                         | 12/710 [00:00<00:29, 23.89it/s]Epoch 7 Train:   2%|█▏                                                        | 15/710 [00:00<00:29, 23.91it/s]Epoch 7 Train:   3%|█▍                                                        | 18/710 [00:00<00:28, 23.93it/s]Epoch 7 Train:   3%|█▋                                                        | 21/710 [00:00<00:28, 23.99it/s]Epoch 7 Train:   3%|█▉                                                        | 24/710 [00:01<00:28, 23

Epoch 7 | Val Loss: 1.5971


Epoch 8 Train:   0%|                                                                   | 0/710 [00:00<?, ?it/s]Epoch 8 Train:   0%|▏                                                          | 3/710 [00:00<00:30, 23.08it/s]Epoch 8 Train:   1%|▍                                                          | 6/710 [00:00<00:29, 23.50it/s]Epoch 8 Train:   1%|▋                                                          | 9/710 [00:00<00:29, 23.66it/s]Epoch 8 Train:   2%|▉                                                         | 12/710 [00:00<00:29, 23.76it/s]Epoch 8 Train:   2%|█▏                                                        | 15/710 [00:00<00:29, 23.78it/s]Epoch 8 Train:   3%|█▍                                                        | 18/710 [00:00<00:29, 23.81it/s]Epoch 8 Train:   3%|█▋                                                        | 21/710 [00:00<00:28, 23.81it/s]Epoch 8 Train:   3%|█▉                                                        | 24/710 [00:01<00:28, 23

Epoch 8 | Val Loss: 1.5950


Epoch 9 Train:   0%|                                                                   | 0/710 [00:00<?, ?it/s]Epoch 9 Train:   0%|▏                                                          | 3/710 [00:00<00:30, 23.51it/s]Epoch 9 Train:   1%|▍                                                          | 6/710 [00:00<00:29, 23.66it/s]Epoch 9 Train:   1%|▋                                                          | 9/710 [00:00<00:29, 23.79it/s]Epoch 9 Train:   2%|▉                                                         | 12/710 [00:00<00:29, 23.72it/s]Epoch 9 Train:   2%|█▏                                                        | 15/710 [00:00<00:29, 23.75it/s]Epoch 9 Train:   3%|█▍                                                        | 18/710 [00:00<00:29, 23.75it/s]Epoch 9 Train:   3%|█▋                                                        | 21/710 [00:00<00:28, 23.77it/s]Epoch 9 Train:   3%|█▉                                                        | 24/710 [00:01<00:29, 23

Epoch 9 | Val Loss: 1.5943


Epoch 10 Train:   0%|                                                                  | 0/710 [00:00<?, ?it/s]Epoch 10 Train:   0%|▏                                                         | 3/710 [00:00<00:31, 22.64it/s]Epoch 10 Train:   1%|▍                                                         | 6/710 [00:00<00:30, 22.84it/s]Epoch 10 Train:   1%|▋                                                         | 9/710 [00:00<00:30, 22.98it/s]Epoch 10 Train:   2%|▉                                                        | 12/710 [00:00<00:30, 23.02it/s]Epoch 10 Train:   2%|█▏                                                       | 15/710 [00:00<00:29, 23.25it/s]Epoch 10 Train:   3%|█▍                                                       | 18/710 [00:00<00:29, 23.16it/s]Epoch 10 Train:   3%|█▋                                                       | 21/710 [00:00<00:29, 23.32it/s]Epoch 10 Train:   3%|█▉                                                       | 24/710 [00:01<00:29, 23

Epoch 10 | Val Loss: 1.5933


Epoch 11 Train:   0%|                                                                  | 0/710 [00:00<?, ?it/s]Epoch 11 Train:   0%|▏                                                         | 3/710 [00:00<00:30, 23.37it/s]Epoch 11 Train:   1%|▍                                                         | 6/710 [00:00<00:29, 23.61it/s]Epoch 11 Train:   1%|▋                                                         | 9/710 [00:00<00:29, 23.72it/s]Epoch 11 Train:   2%|▉                                                        | 12/710 [00:00<00:29, 23.69it/s]Epoch 11 Train:   2%|█▏                                                       | 15/710 [00:00<00:29, 23.69it/s]Epoch 11 Train:   3%|█▍                                                       | 18/710 [00:00<00:29, 23.74it/s]Epoch 11 Train:   3%|█▋                                                       | 21/710 [00:00<00:28, 23.85it/s]Epoch 11 Train:   3%|█▉                                                       | 24/710 [00:01<00:28, 23

Epoch 11 | Val Loss: 1.5922


Epoch 12 Train:   0%|                                                                  | 0/710 [00:00<?, ?it/s]Epoch 12 Train:   0%|▏                                                         | 3/710 [00:00<00:30, 23.12it/s]Epoch 12 Train:   1%|▍                                                         | 6/710 [00:00<00:30, 23.35it/s]Epoch 12 Train:   1%|▋                                                         | 9/710 [00:00<00:29, 23.58it/s]Epoch 12 Train:   2%|▉                                                        | 12/710 [00:00<00:29, 23.55it/s]Epoch 12 Train:   2%|█▏                                                       | 15/710 [00:00<00:29, 23.60it/s]Epoch 12 Train:   3%|█▍                                                       | 18/710 [00:00<00:29, 23.61it/s]Epoch 12 Train:   3%|█▋                                                       | 21/710 [00:00<00:29, 23.74it/s]Epoch 12 Train:   3%|█▉                                                       | 24/710 [00:01<00:28, 23

Epoch 12 | Val Loss: 1.5927


Epoch 13 Train:   0%|                                                                  | 0/710 [00:00<?, ?it/s]Epoch 13 Train:   0%|▏                                                         | 3/710 [00:00<00:30, 23.48it/s]Epoch 13 Train:   1%|▍                                                         | 6/710 [00:00<00:30, 23.35it/s]Epoch 13 Train:   1%|▋                                                         | 9/710 [00:00<00:29, 23.37it/s]Epoch 13 Train:   2%|▉                                                        | 12/710 [00:00<00:29, 23.39it/s]Epoch 13 Train:   2%|█▏                                                       | 15/710 [00:00<00:29, 23.50it/s]Epoch 13 Train:   3%|█▍                                                       | 18/710 [00:00<00:29, 23.47it/s]Epoch 13 Train:   3%|█▋                                                       | 21/710 [00:00<00:29, 23.44it/s]Epoch 13 Train:   3%|█▉                                                       | 24/710 [00:01<00:29, 23

Epoch 13 | Val Loss: 1.5920


Epoch 14 Train:   0%|                                                                  | 0/710 [00:00<?, ?it/s]Epoch 14 Train:   0%|▏                                                         | 3/710 [00:00<00:30, 23.24it/s]Epoch 14 Train:   1%|▍                                                         | 6/710 [00:00<00:29, 23.73it/s]Epoch 14 Train:   1%|▋                                                         | 9/710 [00:00<00:29, 23.74it/s]Epoch 14 Train:   2%|▉                                                        | 12/710 [00:00<00:29, 23.83it/s]Epoch 14 Train:   2%|█▏                                                       | 15/710 [00:00<00:29, 23.81it/s]Epoch 14 Train:   3%|█▍                                                       | 18/710 [00:00<00:28, 23.86it/s]Epoch 14 Train:   3%|█▋                                                       | 21/710 [00:00<00:28, 23.82it/s]Epoch 14 Train:   3%|█▉                                                       | 24/710 [00:01<00:28, 23

Epoch 14 | Val Loss: 1.5916


Epoch 15 Train:   0%|                                                                  | 0/710 [00:00<?, ?it/s]Epoch 15 Train:   0%|▏                                                         | 3/710 [00:00<00:30, 23.23it/s]Epoch 15 Train:   1%|▍                                                         | 6/710 [00:00<00:29, 23.50it/s]Epoch 15 Train:   1%|▋                                                         | 9/710 [00:00<00:29, 23.61it/s]Epoch 15 Train:   2%|▉                                                        | 12/710 [00:00<00:29, 23.64it/s]Epoch 15 Train:   2%|█▏                                                       | 15/710 [00:00<00:29, 23.77it/s]Epoch 15 Train:   3%|█▍                                                       | 18/710 [00:00<00:29, 23.75it/s]Epoch 15 Train:   3%|█▋                                                       | 21/710 [00:00<00:28, 23.77it/s]Epoch 15 Train:   3%|█▉                                                       | 24/710 [00:01<00:28, 23

Epoch 15 | Val Loss: 1.5915
Test move: b5a4


IsADirectoryError: [Errno 21] Is a directory: '/root'