In [199]:
import torch 
import torch.nn as nn 
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

import chess 
import chess.pgn 
import chess.engine
from stockfish import Stockfish

import math 
from dataclasses import dataclass 
import numpy as np 
import os 
import time 
import asyncio

print("Chess", chess.__version__)
print("Torch", torch.__version__)

Chess 1.9.4
Torch 2.0.0


## Data Preprocessing

Only need to run once. Create a data folder before running. 

### 1. Generate PGN and FEN files for Black and White Colors

In [2]:
# file_path = "data/Tal.pgn"

# # save the games based on if tal played white or black
# games = {
#     "white": [], 
#     "black": []
# }

# with open(file_path) as f: 
#     lines = f.readlines() 
#     tal_color = None 
#     on_pgn_line = False  
#     pgn = ""
#     for line in lines:
#         if on_pgn_line:
#             if line.startswith("[Event"): 
#                 on_pgn_line = False 
#                 games[tal_color].append(pgn)
#                 pgn = ""
#             else: 
#                 pgn += line 
#         if line.startswith("[White "): 
#             player_white = line.split('"')[1]
#             if player_white == "Tal, Mihail":
#                 tal_color = "white"
#         elif line.startswith("[Black "):
#             player_black = line.split('"')[1]
#             if player_black == "Tal, Mihail":
#                 tal_color = "black"
#         elif line.startswith("1."): 
#             on_pgn_line = True 
#             pgn += line 
#         else: 
#             # print(tal_color)
#             continue 


# f.close()  

# with open("data/tal_white_games.pgn", "w") as f: 
#     for game in games["white"]:
#         f.write(game)

# f.close()

# with open("data/tal_black_games.pgn", "w") as f:
#     for game in games["black"]:
#         f.write(game)

# f.close()

### 2. Generate UCI Vocab

In [3]:
# from itertools import product

# files = ["a", "b", "c", "d", "e", "f", "g", "h"]
# ranks = ["1", "2", "3", "4", "5", "6", "7", "8"]
# promotion = ["Q", "R", "B", "N"]

# # gets an upper bound for the UCI vocab

# with open("data/UCIvocab.txt", "w") as f:
#     for file1, rank1, file2, rank2 in product(files, ranks, files, ranks):
#         if file1 == file2 and rank1 == rank2:
#             continue 
#         if rank2 == "8" and rank1 == "7" or rank2 == "1" and rank1 == "2":
#             idx1 = files.index(file1)
#             idx2 = files.index(file2)
#             if abs(idx1 - idx2) > 1:
#                 continue
#             for promotion_piece in promotion:
#                 f.write(f'{file1}{rank1}{file2}{rank2}{promotion_piece}\n') 
#         f.write(f'{file1}{rank1}{file2}{rank2}\n')

# f.close() 

### 3. Generate PGN Vocab

In [115]:
# pieces = ["K", "Q", "R", "B", "N"] 
# print(pieces[1:])
# ranks = ["1", "2", "3", "4", "5", "6", "7", "8"]
# for rank in ranks[1:7]:
#     print(rank)

['Q', 'R', 'B', 'N']
2
3
4
5
6
7


In [128]:
# pieces = ["K", "Q", "R", "B", "N"] 
# pawns = ["a", "b", "c", "d", "e", "f", "g", "h"]
# files = ["a", "b", "c", "d", "e", "f", "g", "h"]
# ranks = ["1", "2", "3", "4", "5", "6", "7", "8"]

# moves = set()

# for piece in pieces: 
#     for file in files: 
#         for rank in ranks: 
#             moves.add(f"{piece}{file}{rank}")
#             moves.add(f"{piece}{file}{rank}+")
#             moves.add(f"{piece}{file}{rank}#")
#             moves.add(f"{piece}x{file}{rank}")
#             moves.add(f"{piece}x{file}{rank}+")
#             moves.add(f"{piece}x{file}{rank}#")
#             # 2 pieces can move to the same square 
#             if piece == "R" or piece == "N" or piece == "B" or piece == "Q": 
#                 for rank1 in ranks: 
#                     moves.add(f"{piece}{rank1}{file}{rank}")
#                     moves.add(f"{piece}{rank1}{file}{rank}+")
#                     moves.add(f"{piece}{rank1}{file}{rank}#")
#                     moves.add(f"{piece}{rank1}x{file}{rank}")
#                     moves.add(f"{piece}{rank1}x{file}{rank}+")
#                     moves.add(f"{piece}{rank1}x{file}{rank}#")
#                 for file1 in files: 
#                     moves.add(f"{piece}{file1}{file}{rank}")
#                     moves.add(f"{piece}{file1}{file}{rank}+")
#                     moves.add(f"{piece}{file1}{file}{rank}#")
#                     moves.add(f"{piece}{file1}x{file}{rank}")
#                     moves.add(f"{piece}{file1}x{file}{rank}+")
#                     moves.add(f"{piece}{file1}x{file}{rank}#")
#                 # if there are 3 pieces due to promotion, need to specify file and rank 
#                 for rank1 in ranks: 
#                     for file1 in files: 
#                         moves.add(f"{piece}{file1}{rank1}{file}{rank}")
#                         moves.add(f"{piece}{file1}{rank1}{file}{rank}+")
#                         moves.add(f"{piece}{file1}{rank1}{file}{rank}#")
#                         moves.add(f"{piece}{file1}{rank1}x{file}{rank}")
#                         moves.add(f"{piece}{file1}{rank1}x{file}{rank}+")
#                         moves.add(f"{piece}{file1}{rank1}x{file}{rank}#")

# for (idx, pawn) in enumerate(pawns): 
#     for rank in ranks[1:7]: #2,3,4,5,6,7
#         if pawn == "a": 
#             if rank == "7": 
#                 for piece in pieces[1:]: 
#                     moves.add(f"{pawn}xb{str(int(rank)+1)}={piece}")
#                     moves.add(f"{pawn}xb{str(int(rank)+1)}={piece}+")
#                     moves.add(f"{pawn}xb{str(int(rank)+1)}={piece}#")
#                     moves.add(f"{pawn}{str(int(rank)+1)}={piece}")
#                     moves.add(f"{pawn}{str(int(rank)+1)}={piece}+")
#                     moves.add(f"{pawn}{str(int(rank)+1)}={piece}#")
#             elif rank == "2":
#                 for piece in pieces[1:]:
#                     moves.add(f"{pawn}xb{str(int(rank)-1)}={piece}")
#                     moves.add(f"{pawn}xb{str(int(rank)-1)}={piece}+")
#                     moves.add(f"{pawn}xb{str(int(rank)-1)}={piece}#")
#                     moves.add(f"{pawn}{str(int(rank)-1)}={piece}")
#                     moves.add(f"{pawn}{str(int(rank)-1)}={piece}+")
#                     moves.add(f"{pawn}{str(int(rank)-1)}={piece}#")
#             else: 
#                 moves.add(f"{pawn}xb{str(int(rank)+1)}")
#                 moves.add(f"{pawn}xb{str(int(rank)+1)}+")
#                 moves.add(f"{pawn}xb{str(int(rank)+1)}#")
#                 moves.add(f"{pawn}xb{str(int(rank)-1)}")
#                 moves.add(f"{pawn}xb{str(int(rank)-1)}+")
#                 moves.add(f"{pawn}xb{str(int(rank)-1)}#")
#                 moves.add(f"{pawn}{str(int(rank)+1)}")
#                 moves.add(f"{pawn}{str(int(rank)+1)}+")
#                 moves.add(f"{pawn}{str(int(rank)+1)}#")
#                 moves.add(f"{pawn}{str(int(rank)-1)}")
#                 moves.add(f"{pawn}{str(int(rank)-1)}+")
#                 moves.add(f"{pawn}{str(int(rank)-1)}#")
#         elif pawn == "h": 
#             if rank == "7":
#                 for piece in pieces[1:]:
#                     moves.add(f"{pawn}xg{str(int(rank)+1)}={piece}")
#                     moves.add(f"{pawn}xg{str(int(rank)+1)}={piece}+")
#                     moves.add(f"{pawn}xg{str(int(rank)+1)}={piece}#")
#                     moves.add(f"{pawn}{str(int(rank)+1)}={piece}")
#                     moves.add(f"{pawn}{str(int(rank)+1)}={piece}+")
#                     moves.add(f"{pawn}{str(int(rank)+1)}={piece}#")
#             elif rank == "2":
#                 for piece in pieces[1:]:
#                     moves.add(f"{pawn}xg{str(int(rank)-1)}={piece}")
#                     moves.add(f"{pawn}xg{str(int(rank)-1)}={piece}+")
#                     moves.add(f"{pawn}xg{str(int(rank)-1)}={piece}#")
#                     moves.add(f"{pawn}{str(int(rank)-1)}={piece}")
#                     moves.add(f"{pawn}{str(int(rank)-1)}={piece}+")
#                     moves.add(f"{pawn}{str(int(rank)-1)}={piece}#")
#             else: 
#                 moves.add(f"{pawn}xg{str(int(rank)+1)}")
#                 moves.add(f"{pawn}xg{str(int(rank)+1)}+")
#                 moves.add(f"{pawn}xg{str(int(rank)+1)}#")
#                 moves.add(f"{pawn}xg{str(int(rank)-1)}")
#                 moves.add(f"{pawn}xg{str(int(rank)-1)}+")
#                 moves.add(f"{pawn}xg{str(int(rank)-1)}#")
#                 moves.add(f"{pawn}{str(int(rank)+1)}")
#                 moves.add(f"{pawn}{str(int(rank)+1)}+")
#                 moves.add(f"{pawn}{str(int(rank)+1)}#")
#                 moves.add(f"{pawn}{str(int(rank)-1)}")
#                 moves.add(f"{pawn}{str(int(rank)-1)}+")
#                 moves.add(f"{pawn}{str(int(rank)-1)}#")
#         else: 
#             if rank == "7":
#                 for piece in pieces[1:]: 
#                     moves.add(f"{pawn}{str(int(rank)+1)}={piece}")
#                     moves.add(f"{pawn}{str(int(rank)+1)}={piece}+")
#                     moves.add(f"{pawn}{str(int(rank)+1)}={piece}#")
#                     moves.add(f"{pawn}x{pawns[idx-1]}{str(int(rank)+1)}={piece}")
#                     moves.add(f"{pawn}x{pawns[idx-1]}{str(int(rank)+1)}={piece}+")
#                     moves.add(f"{pawn}x{pawns[idx-1]}{str(int(rank)+1)}={piece}#")
#                     moves.add(f"{pawn}x{pawns[idx+1]}{str(int(rank)+1)}={piece}")
#                     moves.add(f"{pawn}x{pawns[idx+1]}{str(int(rank)+1)}={piece}+")
#                     moves.add(f"{pawn}x{pawns[idx+1]}{str(int(rank)+1)}={piece}#")
#             elif rank == "2":
#                 for piece in pieces[1:]: 
#                     moves.add(f"{pawn}{str(int(rank)-1)}={piece}")
#                     moves.add(f"{pawn}{str(int(rank)-1)}={piece}+")
#                     moves.add(f"{pawn}{str(int(rank)-1)}={piece}#")
#                     moves.add(f"{pawn}x{pawns[idx-1]}{str(int(rank)-1)}={piece}")
#                     moves.add(f"{pawn}x{pawns[idx-1]}{str(int(rank)-1)}={piece}+")
#                     moves.add(f"{pawn}x{pawns[idx-1]}{str(int(rank)-1)}={piece}#")
#                     moves.add(f"{pawn}x{pawns[idx+1]}{str(int(rank)-1)}={piece}")
#                     moves.add(f"{pawn}x{pawns[idx+1]}{str(int(rank)-1)}={piece}+")
#                     moves.add(f"{pawn}x{pawns[idx+1]}{str(int(rank)-1)}={piece}#")
#             else: 
#                 moves.add(f"{pawn}x{pawns[idx-1]}{str(int(rank)+1)}")
#                 moves.add(f"{pawn}x{pawns[idx-1]}{str(int(rank)+1)}+")
#                 moves.add(f"{pawn}x{pawns[idx-1]}{str(int(rank)+1)}#")
#                 moves.add(f"{pawn}x{pawns[idx+1]}{str(int(rank)+1)}")
#                 moves.add(f"{pawn}x{pawns[idx+1]}{str(int(rank)+1)}+")
#                 moves.add(f"{pawn}x{pawns[idx+1]}{str(int(rank)+1)}#")
#                 moves.add(f"{pawn}x{pawns[idx-1]}{str(int(rank)-1)}")
#                 moves.add(f"{pawn}x{pawns[idx-1]}{str(int(rank)-1)}+")
#                 moves.add(f"{pawn}x{pawns[idx-1]}{str(int(rank)-1)}#")
#                 moves.add(f"{pawn}x{pawns[idx+1]}{str(int(rank)-1)}")
#                 moves.add(f"{pawn}x{pawns[idx+1]}{str(int(rank)-1)}+")
#                 moves.add(f"{pawn}x{pawns[idx+1]}{str(int(rank)-1)}#")
#                 moves.add(f"{pawn}{str(int(rank)+1)}")
#                 moves.add(f"{pawn}{str(int(rank)+1)}+")
#                 moves.add(f"{pawn}{str(int(rank)+1)}#")
#                 moves.add(f"{pawn}{str(int(rank)-1)}")
#                 moves.add(f"{pawn}{str(int(rank)-1)}+")
#                 moves.add(f"{pawn}{str(int(rank)-1)}#")

# # add castling         
# moves.add("O-O")
# moves.add("O-O+")
# moves.add("O-O#")
# moves.add("O-O-O")
# moves.add("O-O-O+")
# moves.add("O-O-O#")

# with open("data/PGNVocab.txt", "w") as f: 
#     f.write("<eos>\n1-0\n0-1\n1/2-1/2\n\n*\n")
#     for move in moves: 
#         f.write(move + "\n")

# f.close() 

In [129]:
# file_names = ["tal_black_games", "tal_white_games"]

# for file in file_names:
#     pgn = open(f"data/{file}.pgn", 'r')
#     iter = 0 
#     while iter < 1: 
#         iter += 1 
#         game = chess.pgn.read_game(pgn)
#         if game is None: 
#             break
#         board = game.board()
#         # print(type(board)) 
#         for move in game.mainline_moves():
#             # print(type(move))
#             # print(move)
#             # print(board.san(move)) 
#             board.push(move) 
#             # print(move) #uci 
#             # print(board.fen()) #fen
#             # print(board.san(move))
#             # how to get pgn? 

#     pgn.close()

### 4. Generate PGNVocab Stripped of + and \#  

In [130]:
# moves = []
# move_set = set()

# with open("data/PGNVocab.txt", "r") as f: 
#     lines = f.readlines() 
#     for line in lines:
#         line = line.strip("+#\n")
#         if line not in move_set:
#             moves.append(line)
#             move_set.add(line)

# f.close() 

# with open("data/PGNVocabStrip.txt", "w") as f: 
#     for move in moves: 
#         f.write(move + "\n")

## Chess Agent used to interface with Stockfish engine

In [298]:
@dataclass 
class ChessAgentConfig: 
    depth: int = 20 # engine depth 
    mate_score: int = 10000 # engine mate score

class ChessAgent: 
    def __init__(self, config):
        self.engine = chess.engine.SimpleEngine.popen_uci("stockfish")
        self.config = config 

    def fen_and_pgn_from_uci(self, uci, fen): 
        board = chess.Board(fen)
        move = chess.Move.from_uci(uci)
        pgn = board.san(move)
        board.push(move)
        fen = board.fen()
        
        return fen, pgn 

    async def async_position_info(self, fen): 
        board = chess.Board(fen)
        return self.engine.analyse(board, chess.engine.Limit(depth=20))

    async def async_position_eval(self, fen): 
        info = await self.async_position_info(fen)
        return info["score"].white().score(mate_score=self.config.mate_score)

    def position_info(self, fen): 
        board = chess.Board(fen)
        return self.engine.analyse(board, chess.engine.Limit(depth=20))

    def position_eval(self, fen): 
        info = self.position_info(fen)
        return info["score"].white().score(mate_score=self.config.mate_score)

    def quit(self): 
        self.engine.quit()

asyncio.set_event_loop_policy(chess.engine.EventLoopPolicy())

In [299]:
chessConfig = ChessAgentConfig()

chessAgent = ChessAgent(chessConfig) 

fen, pgn = chessAgent.fen_and_pgn_from_uci("e2e4","rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq -")

In [133]:
black_mate_in_1 = "rnbqkbnr/pppp1ppp/4p3/8/5PP1/8/PPPPP2P/RNBQKBNR b KQkq - 0 2"
white_mate_in_1 = "rnbqkbnr/ppppp2p/8/5pp1/4P3/8/PPPPQPPP/RNB1KBNR w KQkq - 0 3"

print(chess.Board(black_mate_in_1))
eval_black = await chessAgent.position_eval(black_mate_in_1)
print("Black Eval: ", eval_black)
print("\n")
print(chess.Board(white_mate_in_1))
eval_white = await chessAgent.position_eval(white_mate_in_1)
print("White Eval: ", eval_white)

chessAgent.quit()

r n b q k b n r
p p p p . p p p
. . . . p . . .
. . . . . . . .
. . . . . P P .
. . . . . . . .
P P P P P . . P
R N B Q K B N R
Black Eval:  -9999


r n b q k b n r
p p p p p . . p
. . . . . . . .
. . . . . p p .
. . . . P . . .
. . . . . . . .
P P P P Q P P P
R N B . K B N R
White Eval:  9999


## PyTorch Custom Dataset Class to be used in Dataloader

In [159]:
class ChessDataset(Dataset): 
    def __init__(self, white, black): 
        super().__init__()

        assert os.path.exists(f"data/{white}.pgn"), f"data/{white} does not exist" 
        assert os.path.exists(f"data/{white}_fen.txt"), f"data/{white}_fen.txt does not exist" 

        self.white_uci = self.load_uci_moves(white)
        self.white_fen = self.load_fen_moves(white)
        self.white_pgn = self.load_pgn_moves(white)

        assert len(self.white_uci) == len(self.white_fen), f"white_uci and white_fen are not the same length"

        assert os.path.exists(f"data/{black}.pgn"), f"data/{black} does not exist"
        assert os.path.exists(f"data/{black}_fen.txt"), f"data/{black}_fen.txt does not exist"

        self.black_uci = self.load_uci_moves(black)
        self.black_fen = self.load_fen_moves(black)
        self.black_pgn = self.load_pgn_moves(black)

        assert len(self.black_uci) == len(self.black_fen), f"black_uci and black_fen are not the same length"

        self.pgn_vocab = self.load_pgn_vocab() 
         
        self.pgn_embedding = nn.Embedding(131072, 768)
        self.uci_embedding = nn.Embedding(8192, 768)

        self.pad_game_length()

        self.to_numpy()

    def load_pgn_vocab(self): 
        vocab = open("data/PGNVocab.txt", 'r')
        vocab_dict = dict()
        for (idx,line) in enumerate(vocab.readlines()): 
            vocab_dict[line.strip("\n")] = idx
            
        vocab.close() 

        return vocab_dict
        
    def load_fen_moves(self, file_name): 
        with open(f"data/{file_name}_fen.txt", 'r') as f: 
            lines = f.readlines() 
            games = []
            for line in lines: 
                games.append(line.strip("\n").split(";")[:-1])
        f.close() 

        return games

    def load_uci_moves(self, file_name): 
        pgn = open(f"data/{file_name}.pgn", 'r')
        games = [] 

        while True: 
            game = chess.pgn.read_game(pgn)
            if game is None: 
                break
            board = game.board() 
            moves = []
            for move in game.mainline_moves(): 
                moves.append(str(move))
            games.append(moves)
        
        return games 

    def load_pgn_moves(self, file_name): 
        pgn = open(f"data/{file_name}.pgn", "r")
        games = [] 

        pgn_content = pgn.read() 

        games = pgn_content.split("\n\n")[:-1]

        game_move_list = []

        for game in games: 
            moves = [move[max(0,move.find(".")+1):].strip("\n") for move in game.split(" ")]
            game_move_list.append(moves)

        return game_move_list

    def __len__(self): 
        return len(self.white_uci) + len(self.black_uci)

    # for pgn moves 
    def to_numpy(self): 
        stoi = { move:i for i, move in enumerate(self.pgn_vocab) }
        itos = { i:move for i, move in enumerate(self.pgn_vocab) }

        def encode(pgn): 
            return [stoi[move] for move in pgn]
        def decode(l): 
            return [itos[i] for i in l]

        # print(self.white_pgn[0:10])
        
        self.white_pgn_np = np.array([encode(pgn) for pgn in self.white_pgn])
        self.black_pgn_np = np.array([encode(pgn) for pgn in self.black_pgn])



    # required for batch loading 
    def pad_game_length(self): 
        max_len = 512 
        lists = [self.white_uci, self.black_uci, self.white_fen, self.black_fen, self.white_pgn, self.black_pgn]
        for game_list in lists:
            for game in game_list:
                game_len = len(game)
                while game_len < max_len:
                    game.append("<eos>")
                    game_len += 1

    def game_item(self, idx): 
        if idx < len(self.white_uci):
            game = {"UCI": self.white_uci[idx], "PGN": self.white_pgn[idx], "FEN": self.white_fen[idx], "np": self.white_np[idx], "color": "white"}
        else:
            game = {"UCI": self.black_uci[idx - len(self.white_uci)], "PGN": self.black_pgn[idx - len(self.white_uci)], "FEN": self.black_fen[idx - len(self.white_uci)], "np": self.black_np[idx - len(self.white_uci)],  "color": "black"}

        return game

    def __getitem__(self, idx):
        return self.game_item(idx)


In [68]:
# need to do this for the dataloader to work 
def collate_fn(batch):
    # Combine individual samples into a list of dictionaries (batch)
    batch_list = []
    for sample in batch:
        batch_list.append(sample)

    return batch_list

In [69]:
dataloader1 = DataLoader(dataset1, batch_size=64, shuffle=True, num_workers=0, collate_fn=collate_fn)

In [72]:
batch = next(iter(dataloader1))
print(len(batch))

64


In [73]:
for batch in dataloader1: 
    print(len(batch))

64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
64
62


## Building Blocks for Transformer 

Sourced Heavily from https://github.com/karpathy/nanoGPT/blob/master/model.py

In [74]:
class LayerNorm(nn.Module):
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

# masked attention 
class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) # in features, out features. ok this makes sense because it is a n_embd size matrix. 
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # hs = hidden state 
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) # .view() returns a new tensor of the same data but with a different shape specified by the args 
        # reshape() is a more robust version of view(). has to do with the tensor being contiguous or not. view() is more efficient if the tensor is contiguous. 
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class CausalCrossAttention(CausalSelfAttention): 

    def __init__(self, config): 
        super().__init__(config)
        self.c_attn = nn.Linear(config.n_embd, config.n_embd * 2, bias=config.bias)
        
    def forward(self, x1, x2): 
        # x1 provides the Q and K 
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        q, k = self.c_attn(x1).split(self.n_embd, dim=2)
        v = self.c_attn(x2)

        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        if self.flash: 
            y = torch.nn.functional.scaled_dot_product_attention(q, k , v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
        else: 
            att = (q @ k.transpose(-2,-1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 

        # output projection 
        y = self.resid_dropout(self.c_proj(y))
        return y 

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) # classifier fully connected
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) # classifier projection 
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = new_gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class EncoderBlock(nn.Module):
    # uses pre layer norm 
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class DecoderBlock(nn.Module): 
    def __init__(self, config): 
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn_1 = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn_2 = CausalCrossAttention(config)
        self.ln_3 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)
    
    def forward(self, x1, x2): 
        x2 = x2 + self.attn_1(self.ln_1(x2))
        x2 = x2 + self.attn_2(self.ln_2(x2), x1)
        x2 = x2 + self.mlp(self.ln_3(x2))
        return x2

@dataclass 
class ChessGPTConfig: 
    block_size: int = 512 # longest chess game was 269 moves or 538 half moves. However, that's the only one. We can discard that as an outlier. 512 covers basically every chess game. 
    vocab_size: int = 131072 # using the PGNVocab.txt, there are 125736 words in the vocabulary. Not every vocabulary will be found in real games. Most will not because most of the moves are very rare. Some are also impossible. 
    n_layer: int = 12 # number of encoder and decoder blocks in the stack 
    n_head: int = 12 # multihead attention blocks 
    n_embd: int = 768 # size of the embedding... originally chosen by GPT-2 ... may have to modify later 
    dropout: float = 0.0 # none for now 
    bias: bool = True 

## Chess Transformer Module 

Also sourced heavily from https://github.com/karpathy/nanoGPT/blob/master/model.py

In [324]:
class ChessGPT(nn.Module): 

    def __init__(self, config): 
        super().__init__() 
        self.config = config 

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd), 
            wpe = nn.Embedding(config.block_size * 128, config.n_embd),
            wtoe = nn.Embedding(config.vocab_size, config.n_embd), # target output embedding
            wpoe = nn.Embedding(config.block_size * 128, config.n_embd), # target output positional embedding
            drop = nn.Dropout(config.dropout), 
            enc = nn.ModuleList([EncoderBlock(config) for _ in range(config.n_layer)]),
            dec = nn.ModuleList([DecoderBlock(config) for _ in range(config.n_layer)]), 
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight
        # experiment with weight tying 
        self.transformer.wtoe.weight = self.lm_head.weight

        self.apply(self._init_weights)

        for pn, p in self.named_parameters(): 
            if pn.endswith('c_proj.weight'): 
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True): 
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding: 
            n_params -= self.transformer.wpe.weight.numel()
        return n_params 

    def _init_weights(self, module): 
        if isinstance(module, nn.Linear): 
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None: 
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # initializes module.weight with normal distribution from N(mean, std^2)

    def forward(self, pgn, fen, targets=None): 
        device = pgn.device
        b, t_pgn = pgn.size()
        b, t_fen, _ = fen.size()
        assert t_pgn <= self.config.block_size, f"Cannot forward sequence of length {t_pgn}, block size is only {self.config.block_size}"
        assert t_fen <= self.config.block_size, f"Cannot forward sequence of length {t_fen}, block size is only {self.config.block_size}"

        input_tok_emb = self.transformer.wte(pgn) # token embeddings of shape (b, t_pgn, n_embd = 768)
        input_pos_emb = self.transformer.wpe(pos_fen)  # token embeddings of shape (b, t_fen, n_embd = 768)

        x1 = self.transformer.drop(input_tok_emb + input_pos_emb)
        for block in self.transformer.enc:
            x1 = block(x1)

        output_tok_emb = self.transformer.wtoe(pgn)
        output_pos_emb = self.transformer.wpoe(pos_pgn)  # Use pos_pgn for positional encoding

        for block in self.transformer.dec: 
            x = block()
        x = self.transformer.ln_f(x)

        if targets is not None: 
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else: 
            logits = self.lm_head(x[:, [-1], :])
            loss = None 

        return logits, loss 

    @torch.no_grad()
    def generate(self, idx, max_new_tokens = 1, temperature=1.0, top_k=None): 
        for _ in range(max_new_tokens): 
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None: 
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx 

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
                # random note: because named_modules and named_parameters are recursive
                # we will see the same tensors p many many times. but doing it this way
                # allows us to know which parent module any tensor p belongs to...
                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they
        # will appear in the no_decay and decay sets respectively after the above.
        # In addition, because named_parameters() doesn't return duplicates, it
        # will only return the first occurence, key'd by 'transformer.wte.weight', below.
        # so let's manually remove 'lm_head.weight' from decay set. This will include
        # this tensor into optimization via transformer.wte.weight only, and not decayed.
        decay.remove('lm_head.weight')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
        use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters)
        print(f"using fused AdamW: {use_fused}")
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

## Training the Model

### 1. Load the Data

In [197]:
# -----------------------------------------------------------------------------
# default config values designed to train a gpt2 (124M) on OpenWebText
# I/O
out_dir = 'out'
eval_interval = 2000
log_interval = 1
eval_iters = 200
eval_only = False # if True, script exits right after the first eval
always_save_checkpoint = True # if True, always save a checkpoint after each eval
init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
# wandb logging
wandb_log = False # disabled by default
wandb_project = 'owt'
wandb_run_name = 'gpt2' # 'run' + str(time.time())
# data
dataset = 'openwebtext'
gradient_accumulation_steps = 5 # used to simulate larger batch sizes
batch_size = 64 # if gradient_accumulation_steps > 1, this is the micro-batch size
block_size = 512 
# model
n_layer = 12
n_head = 12
n_embd = 768
dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
bias = False # do we use bias inside LayerNorm and Linear layers?
# adamw optimizer
learning_rate = 6e-4 # max learning rate
max_iters = 600000 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
# learning rate decay settings
decay_lr = True # whether to decay the learning rate
warmup_iters = 2000 # how many steps to warm up for
lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
# DDP settings
backend = 'nccl' # 'nccl', 'gloo', etc.
# system
device = 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
compile = False # use PyTorch 2.0 to compile the model to be faster
device_type = 'cpu'
master_process = True 

In [210]:
dataset = ChessDataset("tal_white_games", "tal_black_games")
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

In [161]:
print(len(train_dataset))
print(len(val_dataset))

2187
243


In [181]:
train = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [300]:
chessAgentConfig = ChessAgentConfig()
chessAgent = ChessAgent(chessAgentConfig)

In [328]:
fen_evals = {}  # Dictionary to cache position evaluations

In [329]:
# def fen_to_number(fen):
#     if fen == '<eos>': 
#         return 0 

#     mapping = {
#         '1': 0, '2': 1, '3': 2, '4': 3, '5': 4, '6': 5, '7': 6, '8': 7,
#         'p': -1, 'n': -3, 'b': -3, 'r': -5, 'q': -9, 'k': 0,
#         'P': 1, 'N': 3, 'B': 3, 'R': 5, 'Q': 9, 'K': 0
#     }

#     try: 
#         board, turn, castling, en_passant, half_move, full_move = fen.split(' ')[:6]
#     except: 
#         print("error with fen: ", fen)
        
#     number = 0

#     for char in board:
#         if char == '/':
#             continue
#         elif char.isdigit():
#             number += int(char)
#         else:
#             number += mapping[char]

#     if turn in mapping:
#         number += mapping[turn] * 100
#     else:
#         number += 0  # Assigning a default value if turn is not recognized

#     number += castling_score(castling) * 1000

#     if en_passant != '-':
#         if en_passant in mapping:
#             number += mapping[en_passant] * 10000
#         else:
#             number += 0  # Assigning a default value if en_passant is not recognized

#     number += int(half_move) * 100000
#     number += int(full_move) * 1000000

#     # eval = pos_eval(fen)  # Evaluate the position using the pos_eval function
#     # number += eval * 10000000  # Incorporate the position evaluation

#     return number

In [361]:
def fen_to_tensor(fen): 
    if fen == '<eos>': 
        return torch.zeros(128, dtype=torch.long)

    fen_vocab = ' wPNBRQKpnbrqk12345678/KQkq-abcdefgh0123456789'

    stoi = { ch:i for i, ch in enumerate(fen_vocab) }
    itos = { i:ch for i, ch in enumerate(fen_vocab) }

    def encode(fen): 
        return [stoi[ch] for ch in fen]
    def decode(l): 
        return [itos[i] for i in l]

    encode_fen = torch.tensor(encode(fen), dtype=torch.long)
    padded_encode_fen = F.pad(encode_fen, (0, 128 - len(encode_fen)), value=0)

    return padded_encode_fen

In [362]:
print(len(fen_evals))

0


In [363]:
def pos_eval(fen):
    if fen in fen_evals: 
        return fen_evals[fen]
    eval = chessAgent.position_eval(fen) 
    fen_evals[fen] = eval

    return eval

In [383]:
def get_batch(split):
    dataloader = train if split == 'train' else val
    batch = next(iter(dataloader))
    fens = [[fen_to_tensor(fen) for fen in game['FEN']] for game in batch]
    x = torch.stack([torch.from_numpy((game['np']).astype(np.int64)) for game in batch]) # shape batch_size x block_size 
    y = torch.stack([torch.cat((torch.from_numpy((game['np'][1:]).astype(np.int64)), torch.tensor([0]))) for game in batch]) # shape batch_size x block_size
    p = torch.stack([torch.stack([fen_tensor for fen_tensor in game]) for game in fens]) # shape batch_size x block_size x 128 
    print(x.shape)
    print(y.shape)
    print(p.shape)
    x, y, p = x.to(device), y.to(device), p.to(device)
    return x, y, p

In [382]:
x, y, p  = get_batch('train')

torch.Size([64, 512])
torch.Size([64, 512])
torch.Size([64, 512, 128])


### 2. Run Training Epochs

In [322]:
iter_num = 0
best_val_loss = 1e9

In [325]:
chessGPTConfig = ChessGPTConfig()
model = ChessGPT(chessGPTConfig)

model.to(device)

number of parameters: 292.45M


ChessGPT(
  (transformer): ModuleDict(
    (wte): Embedding(131072, 768)
    (wpe): Embedding(512, 768)
    (wtoe): Embedding(131072, 768)
    (wpoe): Embedding(512, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (enc): ModuleList(
      (0-11): 12 x EncoderBlock(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (dec): ModuleList(
      (0-11): 12 x DecoderBlock(
        (ln_1): LayerNorm()
        (attn_1): CausalSelfAttention(
    

In [84]:
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

In [85]:
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)

using fused AdamW: False


In [200]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [201]:
# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

In [204]:
game_batch = next(iter(train))
t0 = time.time() 
local_iter_num = 0 
raw_model = model  
running_mfu = -1.0
ctx = nullcontext()

In [206]:
while True: 
    lr = get_lr(iter_num)
    for param_group in optimizer.param_groups: 
        param_group['lr'] = lr 

    if iter_num % eval_interval == 0 and master_process: 
        losses = estimate_loss() 
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    if iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        with ctx:
            logits, loss = model(X, Y)
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        game_batch = next(iter(train))
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point
        if local_iter_num >= 5: # let the training loop settle a bit
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

AttributeError: 'ModuleDict' object has no attribute 'h'