In [1]:
import chess
import numpy as np
from network import ChessNetV2 as ChessNet
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from concurrent.futures import ProcessPoolExecutor
from replay_buffer import ReplayBuffer
from state import move_to_index, move_mask
import mcts
import util
import gc
import time
from device import device
import os

In [2]:
def train_step(replay_buffer: ReplayBuffer, batch_size, net, optimizer):
    """Perform a single gradient update on one batch drawn from the replay buffer.

    Args:
        replay_buffer: Source of training samples; must provide ``size()`` and
            ``sample(batch_size)``, the latter yielding
            ``(boards, positions, target_policies, target_values)``.
        batch_size: Number of samples to draw for this step.
        net: Model called as ``net(positions)`` and returning
            ``(policy_logits, values)``.
        optimizer: Optimizer that owns ``net``'s parameters.

    Returns:
        ``(loss, policy_loss, value_loss)`` as Python floats. When the buffer
        holds fewer than ``batch_size`` entries no update is performed and
        ``(0.0, 0.0, 0.0)`` is returned.
    """
    net.train()

    # Not enough data for a full batch yet: skip the update entirely.
    n_available = replay_buffer.size()
    if n_available < batch_size:
        return 0.0, 0.0, 0.0

    batch = replay_buffer.sample(batch_size)
    boards, positions, policy_targets, value_targets = batch

    # One mask row per board. Where the mask is 1 the offset is 0; where it is
    # 0 the offset is -1e9, which drives that move's probability to ~0 after
    # the softmax (presumably these are the illegal moves -- see move_mask).
    mask_rows = [torch.tensor(move_mask(b), dtype=torch.float32) for b in boards]
    move_masks = torch.stack(mask_rows).to(device)
    logit_offset = (move_masks - 1) * 1e9

    optimizer.zero_grad()
    policy_logits, value_out = net(positions)

    log_probs = F.log_softmax(policy_logits + logit_offset, dim=1)

    # Cross-entropy between the target policy distribution and the masked
    # prediction, averaged over the batch.
    policy_loss = -(policy_targets * log_probs).sum(dim=1).mean()

    # Drop the trailing dimension of the value output before comparing it
    # with the targets.
    value_loss = F.mse_loss(value_out.squeeze(-1), value_targets)

    combined = policy_loss + value_loss
    combined.backward()

    # Cap the global gradient norm at 1.0 before applying the update.
    torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    optimizer.step()

    return combined.item(), policy_loss.item(), value_loss.item()

In [None]:
# --- Checkpoint / artifact locations ---------------------------------------
model_checkpoint_path = 'model.pth'
optimizer_checkpoint_path = 'optimizer.pth'
replay_buffer_checkpoint_path = 'replay-buffer.pkl'
game_gif_path = 'game.gif'

# --- Network ----------------------------------------------------------------
# Resume from the saved weights when a checkpoint file is present; otherwise
# start from a freshly initialised model.
#
# The device mapping belongs to torch.load(map_location=...).
# Module.load_state_dict() has no `device` parameter, so passing one raises
# TypeError. Only a missing file is treated as "start fresh" here; any other
# failure (corrupt file, architecture mismatch, ...) is allowed to surface
# instead of silently discarding the trained weights.
net = ChessNet(n_moves=len(move_to_index))
if os.path.exists(model_checkpoint_path):
    net.load_state_dict(torch.load(model_checkpoint_path, map_location=device))
    print(f'Loaded model checkpoint from {model_checkpoint_path}')
else:
    print(f'No model checkpoint at {model_checkpoint_path}; starting with a new model')
net = net.to(device)

# --- Optimizer --------------------------------------------------------------
# Created after the model has been moved so it references the on-device
# parameters.
optimizer = optim.Adam(net.parameters(), lr=3e-4)
if os.path.exists(optimizer_checkpoint_path):
    optimizer.load_state_dict(torch.load(optimizer_checkpoint_path, map_location=device))
    print(f'Loaded optimizer checkpoint from {optimizer_checkpoint_path}')
else:
    print(f'No optimizer checkpoint at {optimizer_checkpoint_path}; starting with a new optimizer')

# Multiplies the learning rate by 0.5 once the monitored value (mode='min')
# has stopped improving for more than `patience` scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# --- Training schedule ------------------------------------------------------
n_epochs = 50
n_selfplay_games = 12

# --- Replay buffer ----------------------------------------------------------
# NOTE(review): the .pkl extension suggests ReplayBuffer.load unpickles the
# file -- only point this at files produced by this project.
replay_buffer = ReplayBuffer(1000000)
if os.path.exists(replay_buffer_checkpoint_path):
    replay_buffer.load(replay_buffer_checkpoint_path)
    print(f'Loaded replay buffer from {replay_buffer_checkpoint_path}')
else:
    print(f'No replay buffer at {replay_buffer_checkpoint_path}; starting with an empty buffer')
batch_size = 64

# --- Self-play / search settings ---------------------------------------------
# All six values are forwarded positionally to mcts.selfplay_wrapper by the
# training loop; their authoritative meaning is defined in the mcts module.
n_sims = 400                # presumably MCTS simulations per move -- confirm in mcts
c_puct = 1.5                # presumably the PUCT exploration constant
temperature = 1.0           # presumably the move-sampling temperature
temperature_threshold = 15  # presumably the move count after which temperature changes
alpha = 0.1                 # presumably the Dirichlet noise concentration
epsilon = 0.15              # presumably the weight of root exploration noise

In [None]:
def _save_checkpoints():
    """Write the model, optimizer and replay buffer to their checkpoint paths."""
    torch.save(net.state_dict(), model_checkpoint_path)
    torch.save(optimizer.state_dict(), optimizer_checkpoint_path)
    replay_buffer.save(replay_buffer_checkpoint_path)


# Persist the starting state before any self-play begins: the worker processes
# are handed model_checkpoint_path rather than the in-memory network.
_save_checkpoints()

for epoch in range(n_epochs):
    # Game records for this epoch go under games/epoch<N>/.
    epoch_dir = os.path.join('games', f'epoch{epoch + 1}')
    os.makedirs(epoch_dir, exist_ok=True)

    start = time.time()

    # ---- Self-play: one task per game, spread over worker processes ----
    with ProcessPoolExecutor(max_workers=12) as executor:
        futures = []
        for i in range(n_selfplay_games):
            game_record_path = os.path.join(epoch_dir, f'game{i + 1}.txt')
            pending = executor.submit(
                mcts.selfplay_wrapper,
                model_checkpoint_path,
                n_sims,
                game_record_path,
                c_puct,
                temperature,
                temperature_threshold,
                alpha,
                epsilon,
            )
            futures.append(pending)

        # Collect in submission order so the buffer is filled deterministically
        # with respect to game index.
        for future in futures:
            boards, positions, policies, values = future.result()
            replay_buffer.add_game(boards, positions, policies, values)

    # ---- Training: fixed number of optimisation steps per epoch ----
    n_train_steps = 200
    total_loss = total_ploss = total_vloss = 0

    for step in range(n_train_steps):
        loss, ploss, vloss = train_step(replay_buffer, batch_size, net, optimizer)
        total_loss = total_loss + loss
        total_ploss = total_ploss + ploss
        total_vloss = total_vloss + vloss

    avg_loss, avg_ploss, avg_vloss = (
        running_total / n_train_steps
        for running_total in (total_loss, total_ploss, total_vloss)
    )

    # The reported time covers self-play plus training, but not the scheduler
    # update or checkpointing below.
    end = time.time()

    scheduler.step(avg_loss)

    summary = f'Epoch {epoch + 1}: Loss: {avg_loss:.4f}, P_Loss: {avg_ploss:.4f}, V_loss: {avg_vloss:.4f}, Time: {(end - start):.2f}s'
    print(summary)

    # Publish the updated weights so the next epoch's workers pick them up.
    _save_checkpoints()
    gc.collect()

Epoch 1: Loss: 2.2166, P_Loss: 2.1980, V_loss: 0.0186, Time: 1223.74s
Epoch 2: Loss: 2.5781, P_Loss: 2.3176, V_loss: 0.2604, Time: 1469.88s
Epoch 3: Loss: 2.2715, P_Loss: 2.2452, V_loss: 0.0263, Time: 2568.52s
Epoch 4: Loss: 2.2144, P_Loss: 2.1946, V_loss: 0.0198, Time: 1139.67s
Epoch 5: Loss: 2.2445, P_Loss: 2.2287, V_loss: 0.0159, Time: 1416.72s
Epoch 6: Loss: 2.2849, P_Loss: 2.2745, V_loss: 0.0105, Time: 863.93s
Epoch 7: Loss: 2.3242, P_Loss: 2.2403, V_loss: 0.0839, Time: 1241.08s
Epoch 8: Loss: 2.2202, P_Loss: 2.2089, V_loss: 0.0113, Time: 1383.95s
Epoch 9: Loss: 2.3462, P_Loss: 2.2463, V_loss: 0.0999, Time: 1488.04s
Epoch 10: Loss: 2.2841, P_Loss: 2.2443, V_loss: 0.0397, Time: 1947.19s
Epoch 11: Loss: 2.2534, P_Loss: 2.2426, V_loss: 0.0108, Time: 2963.55s
Epoch 12: Loss: 2.2593, P_Loss: 2.2404, V_loss: 0.0189, Time: 1335.29s
Epoch 13: Loss: 2.2448, P_Loss: 2.2355, V_loss: 0.0093, Time: 1327.39s
Epoch 14: Loss: 2.2481, P_Loss: 2.2413, V_loss: 0.0068, Time: 2076.30s
Epoch 15: Loss: 

Process SpawnProcess-297:
Process SpawnProcess-292:
Process SpawnProcess-290:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/abhinavsrivatsa/projects/chess-bot/env/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Users/abhinavsrivatsa/projects/chess-bot/env/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/abhinavsrivatsa/projects/chess-bot/env/lib/python3.11/concurrent/futures/process.py", line 249, in _process_worker
    call_item = call_queue.get(block=True)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/abhinavsrivatsa/projects/chess-bot/env/lib/python3.11/multiprocessing/queues.py", line 102, in get
    with self._rlock:
  File "/Users/abhinavsrivatsa/projects/chess-bot/env/lib/python3.11/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
      

KeyboardInterrupt: 