In [2]:
import torch 
import torch.nn as nn 
from torch.nn import functional as F
import chess 
import chess.pgn
import math 
from dataclasses import dataclass 
import numpy as np 
import os 
import time 
from torch.utils.data import Dataset, DataLoader

In [11]:
class ChessDataset(Dataset): 
    def __init__(self, white, black): 
        super().__init__()

        assert os.path.exists(f"data/{white}.pgn"), f"data/{white} does not exist" 
        assert os.path.exists(f"data/{white}_fen.txt"), f"data/{white}_fen.txt does not exist" 

        self.white_uci = self.load_uci_moves(white)
        self.white_fen = self.load_fen_moves(white)

        assert len(self.white_uci) == len(self.white_fen), f"white_uci and white_fen are not the same length"

        assert os.path.exists(f"data/{black}.pgn"), f"data/{black} does not exist"
        assert os.path.exists(f"data/{black}_fen.txt"), f"data/{black}_fen.txt does not exist"

        self.black_uci = self.load_uci_moves(black)
        self.black_fen = self.load_fen_moves(black)

        assert len(self.black_uci) == len(self.black_fen), f"black_uci and black_fen are not the same length"

        self.pgn_vocab = self.load_pgn_vocab() 
         
        self.pgn_embedding = nn.Embedding(131072, 768)
        self.uci_embedding = nn.Embedding(8192, 768)

    def load_pgn_vocab(self): 
        vocab = open("data/PGNVocab.txt", 'r')
        vocab_dict = dict()
        for (idx,line) in enumerate(vocab.readlines()): 
            vocab_dict[line.strip("\n")] = idx
            
        vocab.close() 

        return vocab_dict
        
    def load_fen_moves(self, file_name): 
        with open(f"data/{file_name}_fen.txt", 'r') as f: 
            lines = f.readlines() 
            games = []
            for line in lines: 
                games.append(line.strip("\n").split(";")[:-1])
        f.close() 

        return games

    def load_uci_moves(self, file_name): 
        pgn = open(f"data/{file_name}.pgn", 'r')
        games = [] 

        while True: 
            game = chess.pgn.read_game(pgn)
            if game is None: 
                break
            board = game.board() 
            moves = []
            for move in game.mainline_moves(): 
                moves.append(str(move))
            games.append(moves)
        
        return games 

    def load_pgn_moves(self, file_name): 
        pgn = open(f"data/{file_name}.pgn", "r")
        games = [] 

        pgn_content = pgn.read() 

        games = pgn_content.split("\n\n")[:-1]

        game_move_list = []

        for game in games: 
            moves = [move[max(0,move.find(".")+1):].strip("\n") for move in game.split(" ")[:-2]]
            # print(moves)
            game_move_list.append(moves)

        return game_move_list

    def __len__(self): 
        return len(self.white_uci) + len(self.black_uci)

    def test(self): 
        white_uci_game = self.white_uci[0]
        white_fen_game = self.white_fen[0]
        game = {"UCI": white_uci_game, "FEN": white_fen_game, "color": "white"}

        return game 

    def __getitem__(self, idx): 
        white_uci_game = self.white_uci[idx]
        white_fen_game = self.white_fen[idx]

        game = {"UCI": white_uci_game, "FEN": white_fen_game, "color": "white"}

        if idx >= len(self.white_uci): 
            idx = idx - len(self.white_uci)
            black_uci_game = self.black_uci[idx]
            black_fen_game = self.black_fen[idx]
            game = {"UCI": black_uci_game, "FEN": black_fen_game, "color": "black"}
        
        return game 


In [None]:
class Chess: 
    def __init__(self)
        self.engine = Stockfish(path="/Users/austinliu/Stockfish/src/stockfish")
        self.board = chess.Board()
        

In [4]:
class LayerNorm(nn.Module):
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

# masked attention 
class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) # in features, out features. ok this makes sense because it is a n_embd size matrix. 
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # hs = hidden state 
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) # .view() returns a new tensor of the same data but with a different shape specified by the args 
        # reshape() is a more robust version of view(). has to do with the tensor being contiguous or not. view() is more efficient if the tensor is contiguous. 
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class CausalCrossAttention(CausalSelfAttention): 

    def __init__(self, config): 
        super().__init__()
        self.c_attn = nn.Linear(config.n_embd, config.n_embd * 2, bias=config.bias)
        
    def forward(self, x1, x2): 
        # x1 provides the Q and K 
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        q, k = self.c_attn(x1).split(self.n_embd, dim=2)
        v = self.c_attn(x2)

        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        if self.flash: 
            y = torch.nn.functional.scaled_dot_product_attention(q, k , v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
        else: 
            att = (q @ k.transpose(-2,-1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 

        # output projection 
        y = self.resid_dropout(self.c_proj(y))
        return y 

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) # classifier fully connected
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) # classifier projection 
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = new_gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class EncoderBlock(nn.Module):
    # uses pre layer norm 
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class DecoderBlock(nn.Module): 
    def __init__(self, config): 
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn_1 = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn_2 = CausalCrossAttention(config)
        self.ln_3 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)
    
    def forward(self, x1, x2): 
        x2 = x2 + self.attn_1(self.ln_1(x2))
        x2 = x2 + self.attn_2(self.ln_2(x2), x1)
        x2 = x2 + self.mlp(self.ln_3(x2))
        return x2

@dataclass 
class ChessGPTConfig: 
    block_size: int = 512 # longest chess game was 269 moves or 538 half moves. However, that's the only one. We can discard that as an outlier. 512 covers basically every chess game. 
    vocab_size: int = 131072 # using the PGNVocab.txt, there are 125394 words in the vocabulary. Not every vocabulary will be found in real games. Most will not because most of the moves are very rare. Some are also impossible. 
    n_layer: int = 12 # number of encoder and decoder blocks in the stack 
    n_head: int = 12 # multihead attention blocks 
    n_embd: int = 768 # size of the embedding... originally chosen by GPT-2 ... may have to modify later 
    dropout: float = 0.0 # none for now 
    bias: bool = True 

In [5]:
class ChessGPT(nn.Module): 

    def __init__(self, config): 
        super().__init__() 
        self.config = config 

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd), 
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout), 
            enc = nn.ModuleList([EncoderBlock(config) for _ in range(config.n_layer)]),
            dec = nn.ModuleList([DecoderBlock(config) for _ in range(config.n_layer)]), 
            ln_f = nn.LayerNorm(config.n_embd, bias = config.bias)
        ))

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight 

        self.apply(self._init_weights)

        for pn, p in self.named_parameters(): 
            if pn.endswith('c_proj.weight'): 
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True): 
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding: 
            n_params -= self.transformer.wpe.weight.numel()
        return n_params 

    def _init_weights(self, module): 
        if isinstance(module, nn.Linear): 
            torch.nn.init.normal_(module.weight, mena=0.0, std=0.02)
            if module.bias is not None: 
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # initializes module.weight with normal distribution from N(mean, std^2)

    def forward(self, idx, targets=None): 
        device = idx.device 
        b, t = idx.size() 
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h: 
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None: 
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else: 
            logits = self.lm_head(x[:, [-1], :])
            loss = None 

        return logits, loss 

    @torch.no_grad()
    def generate(self, idx, max_new_tokens = 1, temperature=1.0, top_k=None): 
        for _ in range(max_new_tokens): 
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None: 
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx 

In [16]:
dataset = ChessDataset("tal_white_games", "tal_black_games")

In [17]:
dataset.load_pgn_moves("tal_white_games")

[['d4',
  'f5',
  'e4',
  'fxe4',
  'Nc3',
  'Nf6',
  'f3',
  'd5',
  'fxe4',
  'dxe4',
  'Bc4',
  'Bf5',
  'Nge2',
  'Nc6',
  'O-O',
  'Bb5',
  'a6',
  'Ba4',
  'Qd7',
  'Bg5',
  'O-O-O',
  'Kh1',
  'Be7',
  'Bxf6',
  'Bxf6',
  'd5',
  'exd5',
  'Bxc6',
  'Nd4',
  'Bg4',
  'Qd2',
  'Qd6',
  'Nb3',
  'c5',
  'h3',
  'h5',
  'Na5',
  'e3',
  'Qxe3',
  'd4',
  'Qe4',
  'Qb7+',
  'Kd7',
  'Nc4',
  'Qd4',
  'Nb6+',
  'Ke8',
  'bxc3',
  'Qd6',
  'Nc4',
  'Qg3',
  'Rae1+',
  'Kf8',
  'Re3',
  'Qxc7',
  'Kg8',
  'Kg1',
  'Bc8',
  'Rf4',
  'Qg5',
  'Ref3',
  'Kh7',
  'h4',
  'Rd1+',
  'Kh2',
  'Qd5',
  'Rg3',
  'Rxf6',
  'Qxc4',
  'Rf5',
  'Kh6',
  'Rxh5+',
  'Kxh5',
  'Qxg7',
  'Qf4',
  'Qxh8+',
  'Kg6',
  'Qg8+',
  'Qc8+',
  'Ke5',
  'Qxg4',
  'Qxg4',
  'Rxg4',
  'Rd2',
  'Rg5+',
  'Kf4',
  'Rxc5',
  'Rxc2',
  'Rc6',
  'Rc4+',
  'Kh5',
  'a4',
  'Rd2',
  'Kh3',
  'Rd3+',
  'g3',
  'Rd6',
  'a5',
  'Rd5',
  'g4+',
  'Kg6',
  'Rc6+',
  'Rxa6',
  'Rd3+',
  'Kg2',
  'Rxc3',
  'h5',
  'Kg7',
  'R

In [53]:
print(type(dataset.test()['UCI'][0]))

<class 'str'>


In [8]:
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [9]:
for data in dataloader: 
    print(data)

IndexError: list index out of range