In [1]:
import chess
import numpy as np
from network import ChessNetV2 as ChessNet
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from concurrent.futures import ProcessPoolExecutor
from replay_buffer import ReplayBuffer
from state import move_to_index, move_mask
import mcts
import util
import gc
import time
from device import device
import os

In [2]:
def train_step(replay_buffer: ReplayBuffer, batch_size, net, optimizer, scheduler):
    """Perform a single gradient update on one batch drawn from the replay buffer.

    The network is switched to training mode first. If the buffer holds fewer
    than ``batch_size`` entries, no update is performed.

    Args:
        replay_buffer: Source of training samples; must provide ``size()`` and
            ``sample(batch_size)`` returning
            ``(boards, positions, target_policies, target_values)``.
        batch_size: Number of samples to draw for this update.
        net: Model called as ``net(positions)`` and returning
            ``(policy_logits, values)``.
        optimizer: Optimizer holding ``net``'s parameters.
        scheduler: LR scheduler stepped with the total loss after the update.

    Returns:
        Tuple ``(loss, policy_loss, value_loss)`` of Python floats, where
        ``loss = policy_loss + 0.5 * value_loss``. Returns ``(0.0, 0.0, 0.0)``
        when the buffer is too small to sample from.
    """
    net.train()

    # Guard: not enough data yet, skip the update entirely.
    if replay_buffer.size() < batch_size:
        return 0.0, 0.0, 0.0

    boards, positions, target_policies, target_values = replay_buffer.sample(batch_size)

    optimizer.zero_grad()
    policy_logits, value_preds = net(positions)

    # One float mask per sampled board, stacked into a single batch tensor.
    per_board_masks = []
    for board in boards:
        per_board_masks.append(torch.tensor(move_mask(board), dtype=torch.float32))
    masks = torch.stack(per_board_masks).to(device)

    # (masks - 1) is 0 where the mask is 1 and -1 where it is 0, so masked-out
    # entries have 10 subtracted from their logit (a soft penalty, not -inf).
    # NOTE(review): presumably mask == 1 marks legal moves - confirm in state.move_mask.
    penalised_logits = policy_logits + (masks - 1) * 10

    policy_loss = F.cross_entropy(penalised_logits, target_policies)
    value_loss = F.mse_loss(value_preds.squeeze(-1), target_values)
    total_loss = policy_loss + 0.5 * value_loss

    total_loss.backward()
    optimizer.step()
    scheduler.step(total_loss)

    return total_loss.item(), policy_loss.item(), value_loss.item()

In [3]:
# --- Checkpoint / output paths -------------------------------------------------
model_checkpoint_path = 'model.pth'
optimizer_checkpoint_path = 'optimizer.pth'
replay_buffer_checkpoint_path = 'replay-buffer.pkl'
game_gif_path = 'game.gif'

# --- Model ---------------------------------------------------------------------
net = ChessNet(n_moves=len(move_to_index))
# Resume from the saved weights when a checkpoint file is present.
# `map_location` belongs to torch.load (Module.load_state_dict has no `device`
# argument); it remaps the stored tensors onto the current device.
# Only a *missing* file is treated as "start fresh": a corrupt or incompatible
# checkpoint raises here instead of being silently replaced by a new model that
# would later be saved over it.
if os.path.isfile(model_checkpoint_path):
    net.load_state_dict(torch.load(model_checkpoint_path, map_location=device))
net = net.to(device)

# --- Optimizer / LR schedule ---------------------------------------------------
# Created after the model is on its device so the optimizer references the
# final parameter tensors.
optimizer = optim.Adam(net.parameters(), lr=1e-3)
if os.path.isfile(optimizer_checkpoint_path):
    optimizer.load_state_dict(torch.load(optimizer_checkpoint_path, map_location=device))

# Halve the learning rate after 5 scheduler steps without improvement in the
# monitored value (lower is better).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# --- Training schedule ---------------------------------------------------------
n_epochs = 50
n_selfplay_games = 12  # self-play games generated per epoch

# --- Replay buffer -------------------------------------------------------------
replay_buffer = ReplayBuffer(1000000)
# Same policy as above: a missing file means a new buffer, any other failure
# surfaces rather than discarding previously collected games.
if os.path.isfile(replay_buffer_checkpoint_path):
    replay_buffer.load(replay_buffer_checkpoint_path)
batch_size = 64

# --- Self-play / MCTS hyperparameters (forwarded to mcts.selfplay_wrapper) -----
n_sims = 800
c_puct = 1.5
temperature = 1.5
alpha = 0.1
epsilon = 0.15

In [None]:
# Persist the current state before any self-play starts: the workers are handed
# `model_checkpoint_path` (not the in-memory `net`), so the file must exist and
# reflect the current weights.
torch.save(net.state_dict(), model_checkpoint_path)
torch.save(optimizer.state_dict(), optimizer_checkpoint_path)
replay_buffer.save(replay_buffer_checkpoint_path)

for epoch in range(n_epochs):
    # Per-epoch output directory for the game records. makedirs also creates the
    # parent 'games' directory and tolerates an already-existing epoch directory,
    # while still raising on real failures (e.g. permissions).
    epoch_dir = os.path.join('games', f'epoch{epoch + 1}')
    os.makedirs(epoch_dir, exist_ok=True)

    start = time.time()

    # Generate self-play games in parallel worker processes.
    with ProcessPoolExecutor(max_workers=12) as executor:
        futures = [
            executor.submit(
                mcts.selfplay_wrapper,
                model_checkpoint_path,
                n_sims,
                os.path.join(epoch_dir, f'game{i + 1}.txt'),
                c_puct,
                temperature,
                alpha,
                epsilon,
            ) for i in range(n_selfplay_games)
        ]

        # Collect in submission order; result() re-raises any exception that
        # occurred inside a worker.
        for future in futures:
            boards, positions, policies, values = future.result()
            replay_buffer.add_game(boards, positions, policies, values)

    # Reported time covers self-play generation only, not the training step.
    end = time.time()

    # Disabled: no `board` variable is defined in the visible cells - TODO confirm before re-enabling.
    # util.save_game_gif(board, game_gif_path)

    loss, ploss, vloss = train_step(replay_buffer, batch_size, net, optimizer, scheduler)
    print(f'Epoch {epoch + 1}: Loss: {loss:.4f}, P_Loss: {ploss:.4f}, V_loss: {vloss:.4f}, Time: {(end - start):.2f}s')

    # Checkpoint after every epoch so the next round of self-play workers loads
    # the updated weights and an interrupted run can be resumed.
    torch.save(net.state_dict(), model_checkpoint_path)
    torch.save(optimizer.state_dict(), optimizer_checkpoint_path)
    replay_buffer.save(replay_buffer_checkpoint_path)
    gc.collect()