In [None]:
import chess
import numpy as np
from network import ChessNet
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from concurrent.futures import ProcessPoolExecutor
from replay_buffer import ReplayBuffer
from state import move_to_index
import mcts
import util
import tempfile
import os

In [2]:
def train_step(replay_buffer, batch_size, net, optimizer):
    """Run a single optimisation step on one mini-batch from the replay buffer.

    Args:
        replay_buffer: Object exposing ``size()`` and ``sample(batch_size)``;
            ``sample`` must return ``(positions, target_policies, target_values)``.
        batch_size: Number of samples to draw. If the buffer holds fewer
            samples than this, no update is performed.
        net: Model called as ``net(positions)`` and returning
            ``(policy_logits, values)``. It is switched to training mode.
        optimizer: Optimizer whose gradients are zeroed and then stepped.

    Returns:
        ``(total_loss, policy_loss, value_loss)`` as Python floats, or
        ``(0.0, 0.0, 0.0)`` when the buffer is still smaller than ``batch_size``.
    """
    net.train()

    # Not enough experience collected yet: skip the update entirely.
    if replay_buffer.size() < batch_size:
        return 0.0, 0.0, 0.0

    positions, target_policies, target_values = replay_buffer.sample(batch_size)

    optimizer.zero_grad()

    pred_policies, pred_values = net(positions)

    # Cross-entropy against a (soft) target distribution, averaged over the batch.
    log_probs = F.log_softmax(pred_policies, dim=1)
    policy_loss = -torch.sum(target_policies * log_probs) / positions.size(0)

    # Flatten both sides before the MSE. If the value head emits (B, 1) while
    # the targets are (B,), F.mse_loss would broadcast to a (B, B) pairwise
    # error and only emit a warning. When the shapes already agree this is a
    # no-op for the result.
    value_loss = F.mse_loss(pred_values.reshape(-1), target_values.reshape(-1))

    loss = policy_loss + value_loss

    loss.backward()
    optimizer.step()

    return loss.item(), policy_loss.item(), value_loss.item()

In [3]:
# --- Training schedule -------------------------------------------------------
n_epochs = 100          # outer iterations of (self-play -> one train step)
n_selfplay_games = 5    # self-play games submitted per epoch
batch_size = 64         # samples drawn from the replay buffer per train step

# --- Search settings forwarded to mcts.selfplay_wrapper ----------------------
n_sims = 100
c_puct = 1.5
temperature = 1.5
alpha = 0.2      # presumably the root-noise concentration -- TODO confirm in mcts
epsilon = 0.2    # presumably the root-noise mixing weight -- TODO confirm in mcts

# --- Model, optimiser and experience storage ---------------------------------
net = ChessNet(n_moves=len(move_to_index))
optimizer = optim.Adam(net.parameters(), lr=0.001)
replay_buffer = ReplayBuffer(1_000_000)

In [4]:
# Main loop: each epoch snapshots the current weights, plays self-play games
# in worker processes, then performs one training step.
for epoch in range(n_epochs):
    # Reserve a unique path for this epoch's checkpoint; the handle is closed
    # before anything writes to the path so the file is never opened twice.
    with tempfile.NamedTemporaryFile(suffix='.pth', delete=False) as tmp: # save net for this epoch to parallelize MCTS
        net_path = tmp.name

    try:
        torch.save(net.state_dict(), net_path)

        # NOTE(review): arguments passed to executor.submit are pickled into
        # the worker process, so each worker receives its own copy of
        # `replay_buffer`. Unless ReplayBuffer is backed by shared memory,
        # samples added inside the worker never reach this process -- confirm
        # that the values discarded from future.result() below do not need to
        # be merged into the buffer here.
        with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
            futures = [
                executor.submit(
                    mcts.selfplay_wrapper, net_path, replay_buffer, n_sims, c_puct, temperature, alpha, epsilon
                ) for _ in range(n_selfplay_games)
            ]

            # Only the board of the last finished game is kept (used by the
            # GIF-export cell further down).
            for future in futures:
                board, _, _, _ = future.result()
    finally:
        # Always remove the checkpoint, including when self-play raises or the
        # cell is interrupted, so temp files do not pile up across runs.
        os.unlink(net_path)

    loss, ploss, vloss = train_step(replay_buffer, batch_size, net, optimizer)
    print(f"Epoch {epoch + 1}: Loss: {loss:.4f}, Policy: {ploss:.4f}, Value: {vloss:.4f}")

Process SpawnProcess-2:
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/Users/abhinavsrivatsa/projects/chess/env/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Users/abhinavsrivatsa/projects/chess/env/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/abhinavsrivatsa/projects/chess/env/lib/python3.11/concurrent/futures/process.py", line 249, in _process_worker
    call_item = call_queue.get(block=True)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/abhinavsrivatsa/projects/chess/env/lib/python3.11/multiprocessing/queues.py", line 102, in get
    with self._rlock:
  File "/Users/abhinavsrivatsa/projects/chess/env/lib/python3.11/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt
Traceback (most recent call last):
  File "/Users/abhinavsrivatsa/proje

KeyboardInterrupt: 

In [None]:
# Export a game as a GIF. `board` is the kernel global assigned in the
# self-play loop (the result of the last future read there), so this cell only
# works after that loop has completed at least one game.
util.save_game_gif(board, 'game.gif')

GIF successfully saved to game.gif
