## Install relevant dependencies
I make the assumption that Pytorch is already installed

In [1]:
!pip install python-chess
!pip install tqdm



## Download chess game data
Can skip this step if your own pgns to use

In [2]:
!curl https://database.lichess.org/standard/lichess_db_standard_rated_2013-01.pgn.zst --output games.pgn.zst
!zstd --decompress games.pgn.zst

# !mkdir pgns/ # uncomment if pgns dir does not exist
!mv games.pgn pgns/
!rm games.pgn.zst

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 16.9M  100 16.9M    0     0  5758k      0  0:00:03  0:00:03 --:--:-- 5760k
games.pgn.zst       : 92811021 bytes                                           


## Create list of games

In [3]:
import chess.pgn

pgn = open("pgns/games.pgn", "r", encoding="utf-8")

all_games= []

# while True: 
for i in range(100000): # increase this limit for a better model
    game = chess.pgn.read_game(pgn)
    if game is None:
        break  # End of games
        
    all_games.append(game)

pgn.close()
print(f"{len(all_games)} games parsed")

100000 games parsed


## Create list of distinct chess positions
Goal of this is to create diverse set of chess FENs that can be used to create the training dataset

In [4]:
import random

all_positions = set()

for game in all_games:
    board = game.board()
    moves = list(game.mainline_moves())
    positions = []
    
    for move in moves:
        board.push(move)
        positions.append(board.fen())
    
    random_positions = random.sample(positions, min(10, len(moves)) // 7)
    all_positions.update(random_positions)

all_positions = list(all_positions)
print(f"{len(all_positions)} unique positions")

89690 unique positions


## Define functions to convert between tensor and FEN string
This will let us encode chess positions in a way the NNs can use

In [5]:
import torch

piece_to_idx = {'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
                'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11}

def pos_to_tensor(fen, device="cpu"):
    parts = fen.split(" ")
    wtm = parts[1] == "w"
    castling_rights = parts[2]

    board = chess.Board(fen)
    tensor = torch.zeros(15, 8, 8, device=device)

    for row in range(8): 
        for col in range(8):
            sqr = chess.square(col, 7 - row)
            piece = board.piece_at(sqr)
            if piece != None:
                p = piece.symbol()
                idx = piece_to_idx[p]
                tensor[idx, row, col] = 1 if p.isupper() else -1
                  
    # Encode castling rights
    if 'K' in castling_rights:
        tensor[12, 0, 0] = 1
    if 'Q' in castling_rights:
        tensor[12, 0, 7] = 1
    if 'k' in castling_rights:
        tensor[13, 7, 0] = -1  
    if 'q' in castling_rights:
        tensor[13, 7, 7] = -1
          
    # Encode side to move
    tensor[14] = 1 if wtm else -1
      
    return tensor

def tensor_to_pos(tensor):
    board = chess.Board(None)
    piece_symbols = list(piece_to_idx.keys())
    
    # Decode the board pieces
    for idx, piece_symbol in enumerate(piece_symbols[:12]):
        mask = tensor[idx].abs() > 0
        positions = mask.nonzero(as_tuple=True)
        for row, col in zip(*positions):
            square = chess.square(col, 7 - row)
            board.set_piece_at(square, chess.Piece.from_symbol(piece_symbol))
    
    # Decode castling rights
    castling_rights = ''
    if tensor[12, 0, 0] == 1:
        castling_rights += 'K'
    if tensor[12, 0, 7] == 1:
        castling_rights += 'Q'
    if tensor[13, 7, 0] == -1:
        castling_rights += 'k'
    if tensor[13, 7, 7] == -1:
        castling_rights += 'q'
    board.set_castling_fen(castling_rights)
    
    # Decode side to move
    side_to_move = 'w' if tensor[14].mean() > 0 else 'b'
    board.turn = True if side_to_move == 'w' else False
    
    return board.fen()
    
print(f"Shape of encoded chess position tensor: {pos_to_tensor(all_positions[0]).shape}")

Shape of encoded chess position tensor: torch.Size([15, 8, 8])


  tensor = torch.zeros(15, 8, 8, device=device)


## Set pytorch device to cuda if available

In [6]:
device = "mps" if torch.backends.mps.is_available else "cpu"
torch.set_default_device("cpu")
print(f"using {device}")

using mps


## Create datasets for training and testing

In [7]:
from torch.utils.data import Dataset, DataLoader

class PositionDataset(Dataset):
  def __init__(self, tensors):
        self.tensors = tensors

  def __len__(self):
      return len(self.tensors)

  def __getitem__(self, idx):
        return self.tensors[idx]

tensors = [pos_to_tensor(pos, device) for pos in all_positions]
random.shuffle(tensors)

# Calculate the indices for splitting
total_tensors = len(tensors)
train_end = int(total_tensors * 0.8)
val_end = int(total_tensors * 0.9)

# Split the tensors into train, validation, and test sets
train_tensors = tensors[:train_end]
val_tensors = tensors[train_end:val_end]
test_tensors = tensors[val_end:]

# Create datasets for each split
train_dataset = PositionDataset(train_tensors)
val_dataset = PositionDataset(val_tensors)
test_dataset = PositionDataset(test_tensors)

print(f"len training set: {len(train_dataset)}")
print(f"len validation set: {len(val_dataset)}")
print(f"len test set: {len(test_dataset)}")

len training set: 71752
len validation set: 8969
len test set: 8969


## Define the structure of the NN
We are training an autoencoder that will learn to deconstruct, then reconstruct chess positions.\
Once trained, we can use the encoder to generate our embeddings

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, hyperparams):
        super(Encoder, self).__init__()

        channels = hyperparams["position_channels"]
        n_embed = hyperparams["n_embed"]
        filters = hyperparams["filters"]
        fc_size = hyperparams["fc_size"]
        
        self.conv1 = nn.Conv2d(channels, filters, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(filters, filters * 2, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(filters * 2, filters * 4, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(filters * 4 * 1 * 1, fc_size)
        self.fc2 = nn.Linear(fc_size, n_embed)  # Compressed representation

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))  
        x = self.pool(x)
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class Decoder(nn.Module):
    def __init__(self, hyperparams):        
        super(Decoder, self).__init__()

        channels = hyperparams["position_channels"]
        n_embed = hyperparams["n_embed"]
        filters = hyperparams["filters"]
        fc_size = hyperparams["fc_size"]

        
        self.fc1 = nn.Linear(n_embed, fc_size)
        self.fc2 = nn.Linear(fc_size, filters * 4 * 1 * 1)
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(filters * 4 , 1, 1))
        self.deconv1 = nn.ConvTranspose2d(filters * 4, filters * 2, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(filters * 2, filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(filters, channels, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.unflatten(x)
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = self.deconv3(x)
        return x

class PositionAutoEncoder(nn.Module):
    def __init__(self, hyperparams):
        super(PositionAutoEncoder, self).__init__()
        self.encoder = Encoder(hyperparams)
        self.decoder = Decoder(hyperparams)

    
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)

        return x

    @torch.no_grad()
    def embed(self, x):
        code = self.encoder(x)
        return code

## Define model hyperparameter

In [9]:
hyperparams = {
    "batch_size": 32,
    "n_epochs": 50,
    "learning_rate": 17e-4,
    "dropout_rate": 0,
    "position_channels": 15,
    "n_embed": 128,
    "filters": 32,
    "fc_size": 256,
    "version": 6
}

batch_size = hyperparams["batch_size"]
n_epochs = hyperparams["n_epochs"]
learning_rate = hyperparams["learning_rate"]

## Initialize the model and optimizer

In [10]:
from torch.optim import AdamW, lr_scheduler

# init dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

#init model
model = PositionAutoEncoder(hyperparams)
model.to(device)

# init optimizer
optimizer = AdamW(model.parameters(), lr=learning_rate)
num_params = sum(p.numel() for p in model.parameters())/1e6
print(f"{num_params:.2f}M parameters")

# init lr scheduler
num_training_steps = n_epochs * len(train_loader)
scheduler = lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=learning_rate, total_steps=num_training_steps)
print(f"training iterations: {num_training_steps}")


0.33M parameters
training iterations: 112150


## Run the training loop
We train by minimizing MSE loss on the reconstructed posittion encoding.\
Validation loss is calculated after each epoch to ensure learning

In [11]:
from tqdm.auto import tqdm

criterion = nn.MSELoss()
progress_bar = tqdm(range(num_training_steps))

for epoch in range(n_epochs):
    model.train() # switch model to training mode

    for batch in train_loader:
        batch = batch.to(device)
        outputs = model(batch)
        loss = criterion(outputs, batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        progress_bar.update(1)
    
    print(f"finished epcoh: {epoch}")
    with torch.no_grad():
        # evaluate validation loss
        model.eval() # switch model to evaluation mode
        losses = torch.zeros(len(val_loader), device=device)
        k = 0
        for batch in val_loader:
            batch = batch.to(device)
            outputs = model(batch)
            loss = criterion(outputs, batch)
                
            losses[k] = loss.item()
            k += 1

        avg_val_loss = losses.mean()
        # -----------------------------
        
        # evaluate training loss
        losses =  torch.zeros(len(train_loader), device=device)
        k = 0
        for batch in train_loader:
            batch = batch.to(device)
            outputs = model(batch)
            loss = criterion(outputs, batch)
                
            losses[k] = loss.item()
            k += 1
            
            if(k == len(train_loader)):
                break
        
        avg_train_loss = losses.mean()
        # ------------------------------
        print(f"learning rate: {optimizer.param_groups[0]['lr']}")
        print(f"val loss: {avg_val_loss}")
        print(f"train loss: {avg_train_loss}")
    

  from .autonotebook import tqdm as notebook_tqdm
  2%|▏         | 2240/112150 [00:45<33:55, 54.01it/s]  

finished epcoh: 0


  2%|▏         | 2249/112150 [01:04<23:54:00,  1.28it/s]

learning rate: 8.583261395319545e-05
val loss: 0.01817934773862362
train loss: 0.018181797116994858


  4%|▍         | 4483/112150 [01:47<33:53, 52.96it/s]   

finished epcoh: 1


  4%|▍         | 4495/112150 [02:04<17:38:51,  1.69it/s]

learning rate: 0.00013855103885100813
val loss: 0.01700441539287567
train loss: 0.01699102483689785


  6%|▌         | 6729/112150 [02:35<26:30, 66.29it/s]   

finished epcoh: 2


  6%|▌         | 6738/112150 [02:55<19:58:13,  1.47it/s]

learning rate: 0.00022385109008700877
val loss: 0.01514926552772522
train loss: 0.015156994573771954


  8%|▊         | 8969/112150 [03:24<23:15, 73.92it/s]   

finished epcoh: 3


  8%|▊         | 8985/112150 [03:37<9:43:04,  2.95it/s] 

learning rate: 0.00033800452529874725
val loss: 0.014940446242690086
train loss: 0.01494923047721386


 10%|▉         | 11207/112150 [04:03<19:53, 84.59it/s] 

finished epcoh: 4


 10%|█         | 11222/112150 [04:17<10:25:26,  2.69it/s]

learning rate: 0.0004760219961000428
val loss: 0.014342368580400944
train loss: 0.014351881109178066


 12%|█▏        | 13450/112150 [04:45<20:09, 81.60it/s]   

finished epcoh: 5


 12%|█▏        | 13467/112150 [05:02<11:58:05,  2.29it/s]

learning rate: 0.0006318711194624738
val loss: 0.013624461367726326
train loss: 0.013641721569001675


 14%|█▍        | 15694/112150 [05:29<19:22, 82.98it/s]   

finished epcoh: 6


 14%|█▍        | 15710/112150 [05:43<9:06:44,  2.94it/s] 

learning rate: 0.0007987401374157798
val loss: 0.01338213961571455
train loss: 0.013401484116911888


 16%|█▌        | 17940/112150 [06:10<18:55, 82.98it/s]  

finished epcoh: 7


 16%|█▌        | 17957/112150 [06:25<9:30:00,  2.75it/s] 

learning rate: 0.0009693356411905733
val loss: 0.012907050549983978
train loss: 0.012922242283821106


 18%|█▊        | 20183/112150 [06:52<18:53, 81.10it/s]  

finished epcoh: 8


 18%|█▊        | 20200/112150 [07:09<10:32:00,  2.42it/s]

learning rate: 0.0011362013470589943
val loss: 0.0123378224670887
train loss: 0.012342734262347221


 20%|█▉        | 22425/112150 [07:38<18:18, 81.68it/s]   

finished epcoh: 9


 20%|██        | 22439/112150 [08:07<18:59:03,  1.31it/s]

learning rate: 0.001292043991014232
val loss: 0.011012295261025429
train loss: 0.01100356224924326


 22%|██▏       | 24671/112150 [08:34<17:57, 81.19it/s]   

finished epcoh: 10


 22%|██▏       | 24682/112150 [09:01<18:16:04,  1.33it/s]

learning rate: 0.0014300520982839843
val loss: 0.010114243254065514
train loss: 0.010090316645801067


 24%|██▍       | 26911/112150 [09:28<17:26, 81.43it/s]   

finished epcoh: 11


 24%|██▍       | 26925/112150 [09:49<12:57:04,  1.83it/s]

learning rate: 0.001544193695095328
val loss: 0.009641195647418499
train loss: 0.009614516980946064


 26%|██▌       | 29159/112150 [10:17<17:09, 80.59it/s]   

finished epcoh: 12


 26%|██▌       | 29168/112150 [10:47<21:19:40,  1.08it/s]

learning rate: 0.001629479950487631
val loss: 0.008884647861123085
train loss: 0.00885975081473589


 28%|██▊       | 31400/112150 [11:15<16:32, 81.39it/s]   

finished epcoh: 13


 28%|██▊       | 31411/112150 [11:47<20:40:24,  1.08it/s]

learning rate: 0.0016821832250787253
val loss: 0.008300947956740856
train loss: 0.008260425180196762


 30%|██▉       | 33640/112150 [12:15<16:04, 81.43it/s]   

finished epcoh: 14


 30%|███       | 33654/112150 [12:37<12:49:58,  1.70it/s]

learning rate: 0.0016999999993193994
val loss: 0.008042393252253532
train loss: 0.008009368553757668


 32%|███▏      | 35880/112150 [13:04<15:30, 81.94it/s]   

finished epcoh: 15


 32%|███▏      | 35897/112150 [13:27<11:52:35,  1.78it/s]

learning rate: 0.0016965751138295405
val loss: 0.00765839172527194
train loss: 0.007615339942276478


 34%|███▍      | 38125/112150 [13:55<15:07, 81.54it/s]   

finished epcoh: 16


 34%|███▍      | 38140/112150 [14:16<11:12:03,  1.84it/s]

learning rate: 0.0016863341306760328
val loss: 0.0073494683019816875
train loss: 0.007315434515476227


 36%|███▌      | 40367/112150 [14:43<14:37, 81.82it/s]   

finished epcoh: 17


 36%|███▌      | 40383/112150 [14:55<5:57:51,  3.34it/s]

learning rate: 0.001669359504233628
val loss: 0.011752900667488575
train loss: 0.011746767908334732


 38%|███▊      | 42610/112150 [15:22<14:12, 81.59it/s]  

finished epcoh: 18


 38%|███▊      | 42626/112150 [15:31<4:12:03,  4.60it/s]

learning rate: 0.0016457879042135217
val loss: 0.010148672387003899
train loss: 0.010129549540579319


 40%|███▉      | 44854/112150 [15:58<13:51, 80.93it/s]  

finished epcoh: 19


 40%|████      | 44870/112150 [16:05<3:12:32,  5.82it/s]

learning rate: 0.0016158091152791987
val loss: 0.009312155656516552
train loss: 0.009291520342230797


 42%|████▏     | 47099/112150 [16:33<13:15, 81.75it/s]  

finished epcoh: 20


 42%|████▏     | 47114/112150 [16:39<3:15:47,  5.54it/s]

learning rate: 0.0015796645090119165
val loss: 0.00890581775456667
train loss: 0.008884284645318985


 44%|████▍     | 49339/112150 [17:07<13:09, 79.61it/s]  

finished epcoh: 21


 44%|████▍     | 49354/112150 [17:13<3:11:45,  5.46it/s]

learning rate: 0.0015376451005286547
val loss: 0.008553639985620975
train loss: 0.008535314351320267


 46%|████▌     | 51585/112150 [17:41<11:56, 84.48it/s]  

finished epcoh: 22


 46%|████▌     | 51602/112150 [17:47<2:49:40,  5.95it/s]

learning rate: 0.0014900892053995082
val loss: 0.008280747570097446
train loss: 0.008247309364378452


 48%|████▊     | 53830/112150 [18:14<11:43, 82.92it/s]  

finished epcoh: 23


 48%|████▊     | 53847/112150 [18:21<2:50:09,  5.71it/s]

learning rate: 0.0014373797157296434
val loss: 0.008054747246205807
train loss: 0.008031242527067661


 50%|████▉     | 56068/112150 [18:47<10:56, 85.49it/s]  

finished epcoh: 24


 50%|█████     | 56085/112150 [18:55<2:50:47,  5.47it/s]

learning rate: 0.0013799410173372048
val loss: 0.007971185259521008
train loss: 0.007947842590510845


 52%|█████▏    | 58316/112150 [19:21<10:37, 84.43it/s]  

finished epcoh: 25


 52%|█████▏    | 58333/112150 [19:28<2:35:49,  5.76it/s]

learning rate: 0.0013182355728482394
val loss: 0.007910186424851418
train loss: 0.007879339158535004


 54%|█████▍    | 60555/112150 [19:54<10:04, 85.35it/s]  

finished epcoh: 26


 54%|█████▍    | 60572/112150 [20:01<2:32:42,  5.63it/s]

learning rate: 0.0012527601982195596
val loss: 0.007755425292998552
train loss: 0.007727895397692919


 56%|█████▌    | 62803/112150 [20:28<09:46, 84.17it/s]  

finished epcoh: 27


 56%|█████▌    | 62812/112150 [20:37<4:10:20,  3.28it/s]

learning rate: 0.0011840420626687779
val loss: 0.007508154027163982
train loss: 0.007473407778888941


 58%|█████▊    | 65043/112150 [21:03<09:20, 84.05it/s]  

finished epcoh: 28


 58%|█████▊    | 65060/112150 [21:12<2:58:53,  4.39it/s]

learning rate: 0.0011126344442177402
val loss: 0.007435886655002832
train loss: 0.007396027911454439


 60%|█████▉    | 67288/112150 [21:39<08:51, 84.39it/s]  

finished epcoh: 29


 60%|██████    | 67305/112150 [21:47<2:31:08,  4.95it/s]

learning rate: 0.0010391122750232145
val loss: 0.00737511133775115
train loss: 0.007333476562052965


 62%|██████▏   | 69527/112150 [22:13<08:24, 84.43it/s]  

finished epcoh: 30


 62%|██████▏   | 69544/112150 [22:22<2:38:43,  4.47it/s]

learning rate: 0.00096406751236122
val loss: 0.007263731677085161
train loss: 0.007230900693684816


 64%|██████▍   | 71774/112150 [22:49<07:56, 84.81it/s]  

finished epcoh: 31


 64%|██████▍   | 71791/112150 [22:59<2:54:22,  3.86it/s]

learning rate: 0.0008881043725351046
val loss: 0.00715663330629468
train loss: 0.00712421303614974


 66%|██████▌   | 74012/112150 [23:27<07:51, 80.86it/s]  

finished epcoh: 32


 66%|██████▌   | 74028/112150 [23:37<2:35:43,  4.08it/s]

learning rate: 0.0008118344660811459
val loss: 0.007115764543414116
train loss: 0.007086172234266996


 68%|██████▊   | 76262/112150 [24:05<07:19, 81.70it/s]  

finished epcoh: 33


 68%|██████▊   | 76271/112150 [24:14<3:09:55,  3.15it/s]

learning rate: 0.0007358718734401317
val loss: 0.007037002593278885
train loss: 0.006998453754931688


 70%|██████▉   | 78499/112150 [24:42<06:47, 82.57it/s]  

finished epcoh: 34


 70%|███████   | 78516/112150 [24:52<2:08:16,  4.37it/s]

learning rate: 0.0006608282007427268
val loss: 0.006988768000155687
train loss: 0.006953968666493893


 72%|███████▏  | 80746/112150 [25:19<06:19, 82.75it/s]  

finished epcoh: 35


 72%|███████▏  | 80755/112150 [25:28<2:42:36,  3.22it/s]

learning rate: 0.0005873076555165257
val loss: 0.006895107217133045
train loss: 0.0068594892509281635


 74%|███████▍  | 82984/112150 [25:55<05:51, 83.08it/s]  

finished epcoh: 36


 74%|███████▍  | 83001/112150 [26:04<1:54:15,  4.25it/s]

learning rate: 0.0005159021819623092
val loss: 0.00682377303019166
train loss: 0.00679322425276041


 76%|███████▌  | 85230/112150 [26:31<05:28, 81.98it/s]  

finished epcoh: 37


 76%|███████▌  | 85247/112150 [26:40<1:42:06,  4.39it/s]

learning rate: 0.0004471866949673972
val loss: 0.006787765771150589
train loss: 0.006757634691894054


 78%|███████▊  | 87475/112150 [27:07<05:02, 81.62it/s]  

finished epcoh: 38


 78%|███████▊  | 87492/112150 [27:16<1:30:06,  4.56it/s]

learning rate: 0.00038171445122902074
val loss: 0.00675263861194253
train loss: 0.006719158962368965


 80%|████████  | 89720/112150 [27:43<04:29, 83.09it/s]  

finished epcoh: 39


 80%|████████  | 89729/112150 [27:52<1:54:11,  3.27it/s]

learning rate: 0.00032001259475670536
val loss: 0.0067175524309277534
train loss: 0.006680437829345465


 82%|████████▏ | 91959/112150 [28:19<04:04, 82.71it/s]  

finished epcoh: 40


 82%|████████▏ | 91976/112150 [28:27<1:07:55,  4.95it/s]

learning rate: 0.00026257791261866234
val loss: 0.006665370427072048
train loss: 0.0066290562972426414


 84%|████████▍ | 94203/112150 [28:54<03:36, 82.84it/s]  

finished epcoh: 41


 84%|████████▍ | 94220/112150 [29:03<1:03:36,  4.70it/s]

learning rate: 0.00020987283510442019
val loss: 0.0066396314650774
train loss: 0.006602284498512745


 86%|████████▌ | 96447/112150 [29:30<03:13, 81.29it/s]  

finished epcoh: 42


 86%|████████▌ | 96464/112150 [29:38<54:59,  4.75it/s]  

learning rate: 0.00016232171250804063
val loss: 0.006611838936805725
train loss: 0.006571582052856684


 88%|████████▊ | 98687/112150 [30:05<02:42, 83.02it/s]

finished epcoh: 43


 88%|████████▊ | 98704/112150 [30:13<42:44,  5.24it/s]

learning rate: 0.00012030739850906216
val loss: 0.006591229233890772
train loss: 0.006549554876983166


 90%|████████▉ | 100934/112150 [30:40<02:15, 82.76it/s]

finished epcoh: 44


 90%|█████████ | 100943/112150 [30:48<49:02,  3.81it/s]

learning rate: 8.416816765978795e-05
val loss: 0.006577224005013704
train loss: 0.006534148007631302


 92%|█████████▏| 103174/112150 [31:15<01:48, 83.02it/s]

finished epcoh: 45


 92%|█████████▏| 103191/112150 [31:22<27:54,  5.35it/s]

learning rate: 5.419499179749265e-05
val loss: 0.0065634530037641525
train loss: 0.006520238239318132


 94%|█████████▍| 105421/112150 [31:49<01:21, 82.92it/s]

finished epcoh: 46


 94%|█████████▍| 105430/112150 [31:57<29:15,  3.83it/s]

learning rate: 3.0629197310277676e-05
val loss: 0.006552454549819231
train loss: 0.006507308688014746


 96%|█████████▌| 107658/112150 [32:24<00:53, 83.40it/s]

finished epcoh: 47


 96%|█████████▌| 107675/112150 [32:31<13:56,  5.35it/s]

learning rate: 1.3660522118891148e-05
val loss: 0.006548347417265177
train loss: 0.006502307020127773


 98%|█████████▊| 109904/112150 [32:59<00:27, 82.79it/s]

finished epcoh: 48


 98%|█████████▊| 109921/112150 [33:06<06:44,  5.51it/s]

learning rate: 3.4255880185547446e-06
val loss: 0.0065455022267997265
train loss: 0.006499065551906824


100%|█████████▉| 112143/112150 [33:34<00:00, 78.29it/s]

finished epcoh: 49
learning rate: 6.800680600540224e-09
val loss: 0.006544893607497215
train loss: 0.006498592905700207


## Save the trained model weights and metadata

In [13]:
# !mkdir models  # uncomment to create models/ dir
checkpoint = {
    "model": model.state_dict(),
    "train_set": train_dataset,
    "val_set": val_dataset,
    "test_set": test_dataset,
    "hyperparameters": hyperparams
}
torch.save(checkpoint, f"models/v0.pt")

## Load a saved model

In [14]:
chkp = torch.load("models/v0.pt")
emb_model = PositionAutoEncoder(chkp["hyperparameters"]).to(device)
emb_model.eval()
emb_model.load_state_dict(chkp["model"])

train_dataset = chkp["train_set"]
val_dataset = chkp["val_set"]
test_dataset = chkp["test_set"]
embed_data = list(train_dataset + val_dataset + test_dataset)

## Embed collection of chess positions

In [15]:
batches = [embed_data[i:i + 256] for i in range(0, len(embed_data), 256)]
embeds = torch.cat([emb_model.embed(torch.stack(batch)) for batch in batches]).unsqueeze(1)
embeds.shape

torch.Size([89690, 1, 128])

## Seacrh similar positions

In [16]:
# embed a query position
query = emb_model.embed(test_dataset[0].unsqueeze(0)).unsqueeze(0)

# calculate similarities and find top matches
similarities = F.cosine_similarity(embeds, query, dim=2)
top_matches = torch.topk(similarities, 10, dim=0)

# print(top_matches.values)

# convert matches to FEN strings
top_tensors = torch.stack(list(embed_data))[top_matches.indices].squeeze(1)
top_tensors = list(torch.split(top_tensors, 1, dim=0))
positions = [tensor_to_pos(t.squeeze(0)) for t in top_tensors]
query_pos = tensor_to_pos(test_dataset[0])

print(f"query position: {query_pos}")
print(f"similar positions: {positions}")

query position: r2qkbnr/pbpppppp/1pn5/8/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 0 1
similar positions: ['r2qkbnr/pbpppppp/1pn5/8/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 0 1', 'rn1qkbnr/pb1ppppp/1p6/2p5/4P3/5N2/PPPPBPPP/RNBQK2R w KQkq - 0 1', 'rn1qkbnr/pbpppppp/1p6/8/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 0 1', 'rn1qkbnr/p1pppppp/1p6/1B6/4b3/5N2/PPPP1PPP/RNBQK2R w KQkq - 0 1', 'r2qkbnr/pb1ppppp/1pn5/1Bp5/4P3/2P2N2/PP1P1PPP/RNBQK2R w KQkq - 0 1', 'rn1qkbnr/pbpppppp/1p6/8/2B1P3/8/PPPP1PPP/RNBQK1NR w KQkq - 0 1', 'rn1qkbnr/pbpppppp/1p6/8/4P3/8/PPPPBPPP/RNBQK1NR w KQkq - 0 1', 'rn1qkbnr/p2ppppp/bpp5/3B4/4P3/8/PPPP1PPP/RNBQK1NR w KQkq - 0 1', 'rn1qkbnr/pp2pppp/8/2p5/3pP1b1/3B1N2/PPPP1PPP/RNBQK2R w KQkq - 0 1', 'rn1qkbnr/p1pppppp/1pb5/8/4P3/3P1N2/PPP2PPP/RNBQKB1R w KQkq - 0 1']


## Inspect similar positions
Pass a FEN string to chess.Board() to view it

In [None]:
chess.Board(query_pos)

In [None]:
chess.Board(positions[1])