In [1]:
import os

DATASET_DIR = os.path.expanduser("~/Data/lichess")
DATASETS = [
    "standard_rated_2013-01_filtered.jsonl.gz",  # 7.1 MB
    # "standard_rated_2017-02_filtered.jsonl.gz",  # 682 MB
    # "standard_rated_2019-03_filtered.jsonl.gz",  # 3.8 GB
    # "standard_rated_2019-10_filtered.jsonl.gz",  # 4.3 GB
]

SEED = 1337
VAL_FRAC = 0.1
CHARS_PER_TOKEN = 1.2

In [2]:
import torch

torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
device_type = "cuda" if device.startswith("cuda") else "cpu"
print(f"using device: {device}")

using device: mps


## Load data

In [3]:
import gzip
from chesslm.lichess import LichessPGNEntry

entries: list[LichessPGNEntry] = []
for dataset in DATASETS:
    with gzip.open(os.path.join(DATASET_DIR, dataset), "rb") as f:
        entries += list(map(LichessPGNEntry.model_validate_json, f.readlines()))

display(n_entries := len(entries))
print((example := entries[0].sequence))

37938

1. e4 e6 2. d4 b6 3. a3 Bb7 4. Nc3 Nh6 5. Bxh6 gxh6 6. Be2 Qg5 7. Bg4 h5 8. Nf3 Qg6 9. Nh4 Qg5 10. Bxh5 Qxh4 11. Qf3 Kd8 12. Qxf7 Nc6 13. Qe8# 1-0


## Tokenization examples

In [4]:
from chesslm.tokenizer import PGNTokenizer

encoder = PGNTokenizer()
example_tokens = encoder.encode(example)[:24 + 1]

print(example_tokens)
print(encoder.decode(example_tokens))

[0, 5, 4, 72, 62, 4, 72, 64, 4, 9, 4, 71, 62, 4, 69, 64, 4, 10, 4, 68, 61, 4, 76, 69, 65]
<start>1. e4 e6 2. d4 b6 3. a3 Bb7


In [5]:
example_buf = torch.tensor(example_tokens)
example_x = example_buf[:-1].view(4, 6)
example_y = example_buf[1:].view(4, 6)  # predict the next token

print(example_x)
print(example_y)

tensor([[ 0,  5,  4, 72, 62,  4],
        [72, 64,  4,  9,  4, 71],
        [62,  4, 69, 64,  4, 10],
        [ 4, 68, 61,  4, 76, 69]])
tensor([[ 5,  4, 72, 62,  4, 72],
        [64,  4,  9,  4, 71, 62],
        [ 4, 69, 64,  4, 10,  4],
        [68, 61,  4, 76, 69, 65]])


## Train / test split

In [6]:
n_val_entries = int(VAL_FRAC * n_entries)

def get_sequences(entries: list[LichessPGNEntry]) -> list[str]:
    return [encoder.add_special_tokens(entry.plain_sequence) for entry in entries]

train_text = " ".join(get_sequences(entries[:-n_val_entries]))
val_text = " ".join(get_sequences(entries[-n_val_entries:]))

print(train_text[:200])

<start>1. e4 e6 2. d4 b6 3. a3 Bb7 4. Nc3 Nh6 5. Bxh6 gxh6 6. Be2 Qg5 7. Bg4 h5 8. Nf3 Qg6 9. Nh4 Qg5 10. Bxh5 Qxh4 11. Qf3 Kd8 12. Qxf7 Nc6 13. Qe8# 1-0<end> <start>1. e4 g6 2. d4 d6 3. Nf3 c6 4. h3 


## Utils

In [7]:
import re

from chesslm.tokenizer import UNK_TOKEN_ID

class DataLoaderLite:
    def __init__(
        self,
        B: int,
        T: int,
        text: str,
        loop: bool = False,
    ):
        self.B: int = B
        self.T: int = T
        self.text: str = text
        self.loop: bool = loop
        self.reset()

    @property
    def BT(self) -> int:
        return self.B * self.T

    def reset(self):
        self._buf = []
        self._pos = 0

    def next_tokens(self, add_pred_token: bool = False) -> list[int]:
        BT = self.BT
        pos_step = int((BT + 1) * CHARS_PER_TOKEN)

        while len(self._buf) < BT + 1:
            segment = self.text[self._pos : self._pos + pos_step]
            if not(segment):
                raise RuntimeError("no tokens remaining")

            # e.g. "6. Be2 O-O 7. O-"
            if match := re.search(r" +", segment[::-1]):
                segment = segment[: len(segment) - match.end() + 1]
            tokens = encoder.encode(segment, add_special_tokens=False)

            if UNK_TOKEN_ID in tokens:
                unk_token_idx = len(
                    encoder.decode(tokens[: tokens.index(UNK_TOKEN_ID) - 1])
                )
                unk_token_text = segment[unk_token_idx - 50 : unk_token_idx]
                unk_token_text += " <unk> "
                unk_token_text += segment[unk_token_idx : unk_token_idx + 50]
                raise ValueError(f"<unk> found: '{unk_token_text}'")

            self._buf.extend(tokens)
            self._pos += len(segment)

        tokens = self._buf[: BT + add_pred_token]  # we want this many tokens
        self._buf = self._buf[BT :]  # remove BT tokens (not BT + 1!)
        return tokens

    def next_batch(self) -> tuple[torch.Tensor, torch.Tensor]:
        batch_tokens = torch.as_tensor(self.next_tokens(add_pred_token=True))
        try:
            x = (batch_tokens[:-1]).view(self.B, self.T) # inputs
            y = (batch_tokens[1:]).view(self.B, self.T) # targets
        except RuntimeError:
            raise RuntimeError("no more batches remaining")
        return x, y

    def __iter__(self):
        return self

    def __next__(self) -> tuple[torch.Tensor, torch.Tensor]:
        try:
            return self.next_batch()
        except RuntimeError:
            if self.loop:
                self.reset()
                try:
                    return self.next_batch()
                except RuntimeError:
                    raise StopIteration
            else:
                raise StopIteration

In [8]:
example = train_text[:200]
dataloader = DataLoaderLite(3, 6, example)

def print_next_tokens(add_pred_token: bool):
    print(dataloader.next_tokens(add_pred_token=add_pred_token))

def print_next_segment(add_pred_token: bool):
    print(encoder.decode(dataloader.next_tokens(add_pred_token=add_pred_token)).replace("\n", "\\n"))

dataloader.reset()
print_next_tokens(True)
print_next_tokens(True)
print_next_tokens(True)
print_next_tokens(True)
print()
dataloader.reset()
print_next_segment(True)
print_next_segment(True)
print_next_segment(True)
print_next_segment(True)
print()
dataloader.reset()
print_next_tokens(False)
print_next_tokens(False)
print_next_tokens(False)
print_next_tokens(False)
print()
dataloader.reset()
print_next_segment(False)
print_next_segment(False)
print_next_segment(False)
print_next_segment(False)

print()
print(example)

[0, 5, 4, 72, 62, 4, 72, 64, 4, 9, 4, 71, 62, 4, 69, 64, 4, 10, 4]
[4, 68, 61, 4, 76, 69, 65, 4, 11, 4, 78, 70, 61, 4, 78, 75, 64, 4, 12]
[12, 4, 76, 84, 75, 64, 4, 74, 84, 75, 64, 4, 13, 4, 76, 72, 60, 4, 79]
[79, 74, 63, 4, 14, 4, 76, 74, 62, 4, 75, 63, 4, 15, 4, 78, 73, 61, 4]

<start>1. e4 e6 2. d4 b6 3. 
 a3 Bb7 4. Nc3 Nh6 5.
5. Bxh6 gxh6 6. Be2 Q
Qg5 7. Bg4 h5 8. Nf3 

[0, 5, 4, 72, 62, 4, 72, 64, 4, 9, 4, 71, 62, 4, 69, 64, 4, 10]
[4, 68, 61, 4, 76, 69, 65, 4, 11, 4, 78, 70, 61, 4, 78, 75, 64, 4]
[12, 4, 76, 84, 75, 64, 4, 74, 84, 75, 64, 4, 13, 4, 76, 72, 60, 4]
[79, 74, 63, 4, 14, 4, 76, 74, 62, 4, 75, 63, 4, 15, 4, 78, 73, 61]

<start>1. e4 e6 2. d4 b6 3.
 a3 Bb7 4. Nc3 Nh6 
5. Bxh6 gxh6 6. Be2 
Qg5 7. Bg4 h5 8. Nf3

<start>1. e4 e6 2. d4 b6 3. a3 Bb7 4. Nc3 Nh6 5. Bxh6 gxh6 6. Be2 Qg5 7. Bg4 h5 8. Nf3 Qg6 9. Nh4 Qg5 10. Bxh5 Qxh4 11. Qf3 Kd8 12. Qxf7 Nc6 13. Qe8# 1-0<end> <start>1. e4 g6 2. d4 d6 3. Nf3 c6 4. h3 


In [9]:
import math
from pydantic import BaseModel

class CosineLRSchedule(BaseModel):
    max_lr: float
    min_lr: float
    warmup_steps: float
    max_steps: float

    def __call__(self, it: int) -> float:
        # 1) linear warmup for warmup_iters steps
        if it < self.warmup_steps:
            return self.max_lr * (it + 1) / self.warmup_steps
        # 2) if it > lr_decay_iters, return min learning rate
        if it > self.max_steps:
            return self.min_lr
        # 3) in between, use cosine decay down to min learning rate
        decay_ratio = (it - self.warmup_steps) / (self.max_steps - self.warmup_steps)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # starts at 1 and goes to 0
        return self.min_lr + coeff * (self.max_lr - self.min_lr)

In [10]:
from chesslm.gpt import GPT
from torch.nn import functional as F

def evaluate_model(model: GPT, dataloader: DataLoaderLite, val_steps: int = 20) -> float:
    model.eval()
    dataloader.reset()
    with torch.no_grad():
        val_loss = 0.0
        for _ in range(val_steps):
            x, y = dataloader.next_batch()
            x, y = x.to(device), y.to(device)
            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                _, loss = model(x, y)
            loss = loss / val_steps
            val_loss += loss.detach().item()
    return val_loss

def write_model_checkpoint(
        model: GPT,
        step: int,
        val_loss: float,
        fpath: str
    ):
    checkpoint = {
        'model': model.state_dict(),
        'config': model.config,
        'step': step,
        'val_loss': val_loss
    }
    # you might also want to add optimizer.state_dict() and
    # rng seeds etc., if you wanted to more exactly resume training
    torch.save(checkpoint, fpath)

def generate_sequences(model: GPT, prefix: str, num_sequences: int, max_length: int) -> list[str]:
    tokens = encoder.encode(prefix)
    tokens = torch.tensor(tokens, dtype=torch.long)
    tokens = tokens.unsqueeze(0).repeat(num_sequences, 1)
    xgen = tokens.to(device)

    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(SEED)

    while xgen.size(1) < max_length:
        # forward the model to get the logits
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, _ = model(xgen) # (B, T, vocab_size)

        # take the logits at the last position
        logits = logits[:, -1, :] # (B, vocab_size)

        # get the probabilities
        probs = F.softmax(logits, dim=-1)

        # do top-k sampling of 50 (huggingface pipeline default)
        # topk_probs here becomes (5, 50), topk_indices is (5, 50)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)

        # select a token from the top-k probabilities
        # note: multinomial does not demand the input to sum to 1
        ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)

        # gather the corresponding indices
        xcol = torch.gather(topk_indices, -1, ix) # (B, 1)

        # append to the sequence
        xgen = torch.cat((xgen, xcol), dim=1)

    return [encoder.decode(x[:max_length].tolist()) for x in xgen]

In [11]:
from chesslm.gpt import GPTConfig

MAX_STEPS = 100
WARMUP_STEPS = 10

PRINT_TRAIN_INTERVAL = 5
EVAL_INTERVAL = 25
EVAL_NUM_SEQUENCES = 4
EVAL_SEQUENCE_MAX_LENGTH = 32

MAX_LR = 6e-4
MIN_LR = MAX_LR * 0.1
WEIGHT_DECAY = 0.1
GRAD_ACCUM_STEPS = 2

T = 512
GPT_CONFIG = GPTConfig(
    block_size=T,
    vocab_size=encoder.n_vocab,
    n_layer=4,
    n_head=4,
    n_embd=384,
)

B = 4
BT = B * T

EVAL_SEQUENCE = " ".join(val_text.split(" ")[:10])

print(len(train_text))
print(int((BT + 1) * CHARS_PER_TOKEN * MAX_STEPS))

15736652
245879


In [12]:
import time

model = GPT(GPT_CONFIG).to(device)

cosine_schedule = CosineLRSchedule(
    max_lr=MAX_LR,
    min_lr=MIN_LR,
    warmup_steps=WARMUP_STEPS,
    max_steps=MAX_STEPS,
)

train_loader = DataLoaderLite(B=B, T=T, text=train_text)
val_loader = DataLoaderLite(B=B, T=T, text=val_text, loop=True)

optimizer = model.configure_optimizers(
    weight_decay=WEIGHT_DECAY,
    learning_rate=MAX_LR,
    device_type=device_type
)

for step in range(MAX_STEPS):
    t0 = time.time()
    last_step = (step == MAX_STEPS - 1)

    # once in a while:
    # evaluate our validation loss
    # generate from the model and print
    if step % EVAL_INTERVAL == 0 or last_step:
        model.eval()
        with torch.no_grad():
            val_loss = evaluate_model(model, val_loader)
            sequences = generate_sequences(
                model,
                EVAL_SEQUENCE,
                EVAL_NUM_SEQUENCES,
                EVAL_SEQUENCE_MAX_LENGTH,
            )

        print("\n--------")
        print(f"Validation loss: {val_loss:.4f}")
        for i, seq in enumerate(sequences, 1):
            seq = seq.replace('\n', '\\n')
            print(f"Sample {i}: {seq}")
        print("--------\n")

    # do one step of the optimization
    model.train()
    optimizer.zero_grad()
    train_loss = 0.0
    for micro_step in range(GRAD_ACCUM_STEPS):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)

        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(x, y)

        loss = loss / GRAD_ACCUM_STEPS
        loss.backward()

        train_loss += loss.detach().item()

    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr = cosine_schedule(step)

    optimizer.step()

    if device_type == "cuda":
        torch.cuda.synchronize() # wait for the GPU to finish work

    t1 = time.time()
    dt = t1 - t0 # time difference in seconds
    tokens_processed = BT * GRAD_ACCUM_STEPS
    tokens_per_sec = tokens_processed / dt

    if step % PRINT_TRAIN_INTERVAL == 0:
        print(
            f"step {step:5d}"
            f" | loss: {train_loss:.6f}"
            f" | lr {lr:.4e}"
            f" | norm: {norm:.4f}"
            f" | dt: {dt*1000:.2f}ms"
            f" | tok/sec: {tokens_per_sec:.2f}"
        )

num decayed parameter tensors: 18, with 7,307,904 parameters
num non-decayed parameter tensors: 34, with 20,736 parameters
using fused AdamW: False

--------
Validation loss: 4.5987
Sample 1: <start>1. e4 f6 2. Nc3 e5 3. Qh5+ g6 4.<end>31.31.
Sample 2: <start>1. e4 f6 2. Nc3 e5 3. Qh5+ g6 4.<end>34.7.
Sample 3: <start>1. e4 f6 2. Nc3 e5 3. Qh5+ g6 4.<end>R16.
Sample 4: <start>1. e4 f6 2. Nc3 e5 3. Qh5+ g6 4.<end>49.R
--------

step     0 | loss: 4.603565 | lr 6.0000e-05 | norm: 19.5527 | dt: 1357.80ms | tok/sec: 3016.65
step     5 | loss: 3.266981 | lr 3.6000e-04 | norm: 3.1363 | dt: 149.89ms | tok/sec: 27327.06
step    10 | loss: 2.757719 | lr 6.0000e-04 | norm: 3.6970 | dt: 141.78ms | tok/sec: 28890.82
step    15 | loss: 2.229734 | lr 5.9590e-04 | norm: 1.2082 | dt: 141.34ms | tok/sec: 28979.17
step    20 | loss: 2.128807 | lr 5.8372e-04 | norm: 0.8028 | dt: 141.26ms | tok/sec: 28996.78

--------
Validation loss: 2.0854
Sample 1: <start>1. e4 f6 2. Nc3 e5 3. Qh5+ g6 4.<end> N
Sample 