In [1]:
import chess
import numpy as np
from network import ChessNetV2 as ChessNet
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from concurrent.futures import ProcessPoolExecutor
from replay_buffer import ReplayBuffer
from state import move_to_index, move_mask
import mcts
import util
import gc
import time

In [2]:
# Pick the compute device once for the whole notebook: Apple's Metal (MPS)
# backend when PyTorch reports it as available, otherwise the CPU.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print('Using MPS device' if device.type == 'mps' else 'MPS not available, using CPU')

Using MPS device


In [3]:
def train_step(replay_buffer: ReplayBuffer, batch_size, net, optimizer):
    """Run one optimisation step on a mini-batch drawn from the replay buffer.

    The policy output is trained with a cross-entropy loss against the target
    policies, restricted to the legal moves of each sampled board. The value
    output is trained with mean-squared error against the target values. The
    two losses are summed with equal weight.

    Args:
        replay_buffer: Buffer exposing ``size()`` and ``sample(batch_size)``;
            ``sample`` must return ``(boards, positions, target_policies,
            target_values)``.
        batch_size: Number of positions to sample for this step.
        net: Model that maps ``positions`` to ``(policy_output, value_output)``.
        optimizer: Optimizer that updates ``net``'s parameters.

    Returns:
        A ``(loss, policy_loss, value_loss)`` tuple of Python floats. If the
        buffer holds fewer than ``batch_size`` entries no update is performed
        and ``(0.0, 0.0, 0.0)`` is returned.
    """
    net.train()

    # Not enough data yet: skip the update entirely.
    if replay_buffer.size() < batch_size:
        return 0.0, 0.0, 0.0

    # NOTE(review): the sampled tensors are assumed to already live on the
    # same device as `net` -- confirm against ReplayBuffer.sample.
    boards, positions, target_policies, target_values = replay_buffer.sample(batch_size)

    optimizer.zero_grad()

    pred_policies, pred_values = net(positions)

    # Boolean legality mask, one row per sampled board. It follows the device
    # of the network output instead of relying on a module-level global.
    masks = torch.stack(
        [torch.tensor(move_mask(board), dtype=torch.float32) for board in boards]
    ).to(pred_policies.device)
    legal = masks > 0

    # NOTE(review): assumes the policy head emits unnormalised logits --
    # confirm against the network definition.
    # Illegal moves are excluded by pushing their logits to a very large
    # negative value *before* a single log-softmax. The previous code
    # multiplied by the mask, divided by the sum and then applied log-softmax,
    # which left illegal moves with non-zero probability and could divide by a
    # near-zero or negative sum. A finite fill value (rather than -inf) keeps
    # a row with no legal moves from turning into NaN.
    masked_logits = pred_policies.masked_fill(~legal, torch.finfo(pred_policies.dtype).min)
    log_probs = F.log_softmax(masked_logits, dim=1)

    # Cross-entropy against the target distribution; illegal entries are
    # zeroed explicitly so they can never contribute to the loss.
    per_move = (target_policies * log_probs).masked_fill(~legal, 0.0)
    policy_loss = -torch.sum(per_move) / batch_size

    value_loss = F.mse_loss(pred_values.squeeze(-1), target_values)

    loss = policy_loss + value_loss

    loss.backward()
    optimizer.step()

    return loss.item(), policy_loss.item(), value_loss.item()

In [4]:
# --- Checkpoint / artefact locations -----------------------------------------
model_checkpoint_path = 'model.pth'
optimizer_checkpoint_path = 'optimizer.pth'
replay_buffer_checkpoint_path = 'replay-buffer.pkl'
game_gif_path = 'game.gif'


def _restore_checkpoint(label, restore):
    """Call ``restore()`` and report whether the checkpoint was loaded.

    A missing file is the expected first-run case. Any other failure is
    printed instead of being silently swallowed, and the caller carries on
    with freshly initialised state. ``KeyboardInterrupt`` and ``SystemExit``
    are deliberately not caught.

    Args:
        label: Human-readable name used in the printed message.
        restore: Zero-argument callable that performs the actual load.

    Returns:
        ``True`` if ``restore()`` completed, ``False`` otherwise.
    """
    try:
        restore()
    except FileNotFoundError:
        print(f'No {label} checkpoint found, starting fresh')
        return False
    except Exception as exc:
        print(f'Failed to load {label} checkpoint ({exc!r}), starting fresh')
        return False
    return True


# --- Model -------------------------------------------------------------------
net = ChessNet(n_moves=len(move_to_index))
# The target device belongs to torch.load (map_location), not to
# load_state_dict: the latter has no `device` keyword, so the previous call
# always raised TypeError, which the bare `except:` hid -- the saved weights
# were never actually restored.
_restore_checkpoint(
    'model',
    lambda: net.load_state_dict(torch.load(model_checkpoint_path, map_location=device)),
)
net = net.to(device)

# --- Optimizer ---------------------------------------------------------------
optimizer = optim.Adam(net.parameters(), lr=1e-3)
_restore_checkpoint(
    'optimizer',
    lambda: optimizer.load_state_dict(torch.load(optimizer_checkpoint_path, map_location=device)),
)

# --- Training schedule -------------------------------------------------------
n_epochs = 50
n_selfplay_games = 12  # self-play games submitted per epoch

# --- Replay buffer -----------------------------------------------------------
replay_buffer = ReplayBuffer(1000000)
# NOTE(review): the '.pkl' extension suggests pickle -- only load buffer files
# that were produced locally.
_restore_checkpoint(
    'replay buffer',
    lambda: replay_buffer.load(replay_buffer_checkpoint_path),
)
batch_size = 64

# --- Self-play parameters, forwarded to mcts.selfplay_wrapper ----------------
# NOTE(review): meanings inferred from the names only -- confirm in mcts.
n_sims = 400        # presumably MCTS simulations per move
c_puct = 1.5        # presumably the PUCT exploration constant
temperature = 1.5   # presumably the move-sampling temperature
alpha = 0.1         # presumably the Dirichlet noise concentration
epsilon = 0.15      # presumably the Dirichlet noise mixing weight

In [5]:
def _save_checkpoints():
    """Write model weights, optimizer state and the replay buffer to disk."""
    torch.save(net.state_dict(), model_checkpoint_path)
    torch.save(optimizer.state_dict(), optimizer_checkpoint_path)
    replay_buffer.save(replay_buffer_checkpoint_path)


# Persist the current state before any self-play starts: the worker processes
# are handed the checkpoint *path*, not the in-memory network.
_save_checkpoints()

# Positional arguments passed to every self-play job.
selfplay_args = (model_checkpoint_path, n_sims, c_puct, temperature, alpha, epsilon)

for epoch in range(n_epochs):
    start = time.time()

    # --- Self-play: a fresh pool of worker processes for this epoch ---------
    with ProcessPoolExecutor(max_workers=12) as executor:
        futures = [
            executor.submit(mcts.selfplay_wrapper, *selfplay_args)
            for _ in range(n_selfplay_games)
        ]

        # Drain in submission order; result() blocks until that game has
        # finished and re-raises any exception raised in the worker.
        for future in futures:
            boards, positions, policies, values = future.result()
            replay_buffer.add_game(boards, positions, policies, values)

    # Taken after the pool has shut down, so the reported time covers
    # self-play only, not the training step or checkpointing below.
    end = time.time()

    # (GIF export via util.save_game_gif is currently disabled.)

    # --- Training: one optimisation step per epoch --------------------------
    loss, ploss, vloss = train_step(replay_buffer, batch_size, net, optimizer)
    print(f'Epoch {epoch + 1}: Loss: {loss:.4f}, P_Loss: {ploss:.4f}, V_loss: {vloss:.4f}, Time: {(end - start):.2f}s')

    # Refresh the on-disk state so the next epoch's workers see the update.
    _save_checkpoints()
    gc.collect()

Process SpawnProcess-11:
Process SpawnProcess-3:
Traceback (most recent call last):
  File "/Users/abhinavsrivatsa/projects/chess/env/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Users/abhinavsrivatsa/projects/chess/env/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/abhinavsrivatsa/projects/chess/env/lib/python3.11/concurrent/futures/process.py", line 249, in _process_worker
    call_item = call_queue.get(block=True)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/abhinavsrivatsa/projects/chess/env/lib/python3.11/multiprocessing/queues.py", line 102, in get
    with self._rlock:
  File "/Users/abhinavsrivatsa/projects/chess/env/lib/python3.11/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt
Traceback (most recent call last):
  File "/Users/abhinavsrivatsa/proj

KeyboardInterrupt: 