In [1]:
import chess
import numpy as np
from network import ChessNet
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from concurrent.futures import ProcessPoolExecutor
from replay_buffer import ReplayBuffer
from state import move_to_index
import mcts
import util
import os
import time

In [2]:
# Pick the compute device: Apple's Metal (MPS) backend when PyTorch reports it
# usable, otherwise fall back to the CPU. Announce the choice either way.
mps_available = torch.backends.mps.is_available()
device = torch.device('mps' if mps_available else 'cpu')
print('Using MPS device' if mps_available else 'MPS not available, using CPU')

Using MPS device


In [3]:
def train_step(replay_buffer: ReplayBuffer, batch_size, net, optimizer):
    """Run a single optimisation step on one minibatch sampled from the replay buffer.

    The network is put in training mode first. If the buffer holds fewer than
    ``batch_size`` entries, no sample is drawn and no update is made.

    Loss = policy loss + value loss, where
      * policy loss is the cross-entropy between the target policies and the
        log-softmax of the predicted logits (summed over moves, averaged over
        the batch). Target policies are presumably probability distributions;
        confirm against ReplayBuffer.
      * value loss is the MSE between the predicted values and the target
        values reshaped to a column with ``unsqueeze(1)``. This assumes the
        net's value head outputs shape (batch, 1); TODO confirm.

    Args:
        replay_buffer: provides ``size()`` and ``sample(batch_size)`` returning
            ``(positions, target_policies, target_values)``.
        batch_size: number of samples drawn for the step.
        net: module whose forward returns ``(policy_logits, values)``.
        optimizer: optimizer over ``net``'s parameters.

    Returns:
        ``(total_loss, policy_loss, value_loss)`` as Python floats, or
        ``(0.0, 0.0, 0.0)`` when the buffer is too small to sample from.
    """
    net.train()

    if replay_buffer.size() < batch_size:
        # Not enough data to form a batch yet; report zero losses.
        return 0.0, 0.0, 0.0

    states, pi_target, z_target = replay_buffer.sample(batch_size)
    optimizer.zero_grad()
    policy_logits, value_pred = net(states)

    n_samples = states.size(0)
    policy_loss = -(pi_target * F.log_softmax(policy_logits, dim=1)).sum() / n_samples
    value_loss = F.mse_loss(value_pred, z_target.unsqueeze(1))
    total = policy_loss + value_loss

    total.backward()
    optimizer.step()
    return total.item(), policy_loss.item(), value_loss.item()

In [None]:
model_checkpoint_path = 'model.pth'
optimizer_checkpoint_path = 'optimizer.pth'
replay_buffer_checkpoint_path = 'replay-buffer.pkl'
game_gif_path = 'game.gif'

# Model: resume from the checkpoint if one exists, otherwise start from fresh weights.
# Only a *missing* file means "start fresh"; a corrupt or incompatible checkpoint
# raises instead of being silently ignored.
net = ChessNet(n_moves=len(move_to_index))
if os.path.exists(model_checkpoint_path):
    # load_state_dict() has no `device` parameter; device placement is done by
    # map_location here and by .to(device) below.
    net.load_state_dict(torch.load(model_checkpoint_path, map_location=device))
    print(f'Loaded model checkpoint from {model_checkpoint_path}')
net = net.to(device)

optimizer = optim.Adam(net.parameters(), lr=1e-3)
if os.path.exists(optimizer_checkpoint_path):
    optimizer.load_state_dict(torch.load(optimizer_checkpoint_path, map_location=device))
    print(f'Loaded optimizer checkpoint from {optimizer_checkpoint_path}')

n_epochs = 75
n_selfplay_games = 8  # self-play games generated per epoch

# NOTE(review): 1_000_000 is presumably the buffer capacity; confirm against ReplayBuffer.
replay_buffer = ReplayBuffer(1_000_000)
if os.path.exists(replay_buffer_checkpoint_path):
    # NOTE(review): the .pkl extension suggests a pickle; only load buffers you wrote yourself.
    replay_buffer.load(replay_buffer_checkpoint_path)
    print(f'Loaded replay buffer from {replay_buffer_checkpoint_path} (size {replay_buffer.size()})')
batch_size = 64

# Self-play / MCTS settings, all forwarded to mcts.selfplay_wrapper.
# NOTE(review): alpha/epsilon look like root Dirichlet-noise parameters and c_puct
# like the PUCT exploration constant; confirm against mcts.
n_sims = 400
c_puct = 1.5
temperature = 1.0
alpha = 0.1
epsilon = 0.15

In [5]:
def save_checkpoints():
    """Write the model, optimizer and replay-buffer checkpoints to their configured paths.

    The model file is also how self-play workers get the current weights: each
    worker is handed ``model_checkpoint_path``, not the in-memory ``net``.
    """
    torch.save(net.state_dict(), model_checkpoint_path)
    torch.save(optimizer.state_dict(), optimizer_checkpoint_path)
    replay_buffer.save(replay_buffer_checkpoint_path)


# Gradient steps per epoch; 1 matches the original one-minibatch-per-epoch schedule.
# NOTE(review): one batch of `batch_size` per `n_selfplay_games` games is very little
# sample reuse. Consider raising this.
n_train_steps = 1
assert n_train_steps >= 1, 'n_train_steps must be at least 1'

# Workers load weights from disk, so write the current net before the first epoch.
save_checkpoints()

for epoch in range(n_epochs):
    start = time.time()

    with ProcessPoolExecutor(max_workers=8) as executor:
        futures = [
            executor.submit(
                mcts.selfplay_wrapper, model_checkpoint_path, n_sims, c_puct, temperature, alpha, epsilon
            ) for _ in range(n_selfplay_games)
        ]

        # Consume results in submission order. `board` is left bound to the last
        # game's first return value, which the GIF cell below renders.
        for future in futures:
            board, positions, policies, values = future.result()
            replay_buffer.add_game(positions, policies, values)

    # The `Time` field in the log line covers self-play only, not training or checkpointing.
    end = time.time()

    step_losses = [train_step(replay_buffer, batch_size, net, optimizer) for _ in range(n_train_steps)]
    # Average the (total, policy, value) losses over this epoch's gradient steps.
    loss, ploss, vloss = (sum(column) / n_train_steps for column in zip(*step_losses))
    print(f'Epoch {epoch + 1}: Loss: {loss:.4f}, P_Loss: {ploss:.4f}, V_loss: {vloss:.4f}, Time: {(end - start):.2f}s')

    # Persist every epoch so the next epoch's workers play with the updated weights.
    save_checkpoints()

Epoch 1: Loss: 9.4483, P_Loss: 8.4195, V_loss: 1.0287, Time: 269.62s


In [6]:
# `board` is left over from the most recent self-play game in the training-loop cell.
# On a fresh kernel, fail with a clear message instead of a bare NameError.
if 'board' not in globals():
    raise RuntimeError('No self-play game available yet: run the training loop cell first.')
util.save_game_gif(board, game_gif_path)

Can't handle color: url(#check_gradient)
Can't handle color: url(#check_gradient)
Can't handle color: url(#check_gradient)
Can't handle color: url(#check_gradient)
Can't handle color: url(#check_gradient)
Can't handle color: url(#check_gradient)
Can't handle color: url(#check_gradient)
Can't handle color: url(#check_gradient)
Can't handle color: url(#check_gradient)
Can't handle color: url(#check_gradient)
Can't handle color: url(#check_gradient)


GIF successfully saved to game.gif


In [7]:
# Save the current weights to the configured checkpoint path, the same file the
# training loop and self-play workers use, rather than a hard-coded copy of it.
torch.save(net.state_dict(), model_checkpoint_path)