In [1]:
import chess
import numpy as np
from network import ChessNetV2 as ChessNet
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from concurrent.futures import ProcessPoolExecutor
from replay_buffer import ReplayBuffer
from state import move_to_index, move_mask
import mcts
import util
import gc
import time
from device import device
import os

In [2]:
def train_step(replay_buffer: ReplayBuffer, batch_size, net, optimizer):
    net.train()
    if replay_buffer.size() < batch_size:
        return 0.0, 0.0, 0.0

    boards, positions, target_policies, target_values = replay_buffer.sample(batch_size)
    optimizer.zero_grad()

    pred_policies, pred_values = net(positions)

    masks = torch.stack([torch.tensor(move_mask(board), dtype=torch.float32) for board in boards]).to(device)
    masked_logits = pred_policies + (masks - 1) * 1e9
    pred_probs = F.log_softmax(masked_logits, dim=1)

    policy_loss = -torch.sum(target_policies * pred_probs, dim=1).mean()

    value_loss = F.mse_loss(pred_values.squeeze(-1), target_values)
    loss = policy_loss + value_loss

    loss.backward()
    
    grad_norm = torch.nn.utils.clip_grad_norm_(net.parameters(), 5.0)
    
    optimizer.step()

    return loss.item(), policy_loss.item(), value_loss.item(), grad_norm

In [3]:
model_checkpoint_path = 'model.pth'
optimizer_checkpoint_path = 'optimizer.pth'
replay_buffer_checkpoint_path = 'replay-buffer.pkl'
game_gif_path = 'game.gif'

net = ChessNet(n_moves=len(move_to_index))
try:
    net.load_state_dict(torch.load(model_checkpoint_path), device=device)
except: # checkpoint doesn't exist, continue with new model
    pass
net = net.to(device)

optimizer = optim.Adam(net.parameters(), lr=5e-5, weight_decay=1e-5)
try:
    optimizer.load_state_dict(torch.load(optimizer_checkpoint_path))
except: # checkpoint doesn't exist, continue with new optimizer
    pass

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=0.7,
    patience=3,
    min_lr=1e-6,
    verbose=True
)

n_epochs = 50
n_selfplay_games = 24
n_train_steps = 1000

replay_buffer = ReplayBuffer(500000, pct_recent=0.3, pct_recent_util=0.6)
try:
    replay_buffer.load(replay_buffer_checkpoint_path)
except: # checkpoint doesn't exist, continue with new replay buffer
    pass
batch_size = 64

n_sims = 800
c_puct = 1.5
temperature = 1.5
temperature_threshold = 15
alpha = 0.15
epsilon = 0.15



In [None]:
torch.save(net.state_dict(), model_checkpoint_path)
torch.save(optimizer.state_dict(), optimizer_checkpoint_path)
replay_buffer.save(replay_buffer_checkpoint_path)

best_loss = float('inf')
patience_counter = 0
max_patience = 15

for epoch in range(n_epochs):
    epoch_dir = os.path.join('games', f'epoch{epoch + 1}')
    os.makedirs(epoch_dir, exist_ok=True)

    start = time.time()

    print(f"\nEpoch {epoch + 1}: Starting selfplay...")

    with ProcessPoolExecutor(max_workers=12) as executor:
        futures = [
            executor.submit(
                mcts.selfplay_wrapper, 
                model_checkpoint_path,
                n_sims, 
                os.path.join(epoch_dir, f'game{i + 1}.txt'),
                c_puct, 
                temperature,
                temperature_threshold,
                alpha, 
                epsilon,
            ) for i in range(n_selfplay_games)
        ]

        positions_added = 0
        game_results = {'1-0': 0, '0-1': 0, '1/2-1/2': 0}

        for future in futures:
            boards, positions, policies, values = future.result()
            replay_buffer.add_game(boards, positions, policies, values)
            positions_added += len(positions)

            result = boards[-1].result()
            if result in game_results:
                game_results[result] += 1
            
    print(f"Added {positions_added} positions. Buffer: {replay_buffer.size()}")
    print(f"Win Results - White: {game_results.get('1-0', 0)}, "
          f"Black: {game_results.get('0-1', 0)}, "
          f"Draw: {game_results.get('1/2-1/2', 0)}")
    
    total_loss = 0
    total_ploss = 0
    total_vloss = 0
    total_grad_norm = 0
    
    for step in range(n_train_steps):
        loss, ploss, vloss, grad_norm = train_step(replay_buffer, batch_size, net, optimizer)
        total_loss += loss
        total_ploss += ploss
        total_vloss += vloss
        total_grad_norm += grad_norm

    
    avg_loss = total_loss / n_train_steps
    avg_ploss = total_ploss / n_train_steps
    avg_vloss = total_vloss / n_train_steps
    avg_grad_norm = total_grad_norm / n_train_steps
    
    end = time.time()
    
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(avg_loss)
    current_lr = optimizer.param_groups[0]['lr']
    
    print(f'\nEpoch {epoch + 1}: Loss: {avg_loss:.4f}, P_Loss: {avg_ploss:.4f}, '
          f'V_loss: {avg_vloss:.4f}, Avg_GradNorm: {avg_grad_norm:.2f}, '
          f'LR: {current_lr:.6f}, Time: {(end - start)/60:.2f}m')
    
    if avg_loss < best_loss - 0.01:
        best_loss = avg_loss
        patience_counter = 0
        print(f"New best loss! Saving checkpoint...")
        torch.save(net.state_dict(), model_checkpoint_path)
        torch.save(optimizer.state_dict(), optimizer_checkpoint_path)
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{max_patience}")
        if patience_counter >= max_patience:
            print("Early stopping triggered!")
            break
    
    replay_buffer.save(replay_buffer_checkpoint_path)

    if current_lr != old_lr:
        print(f"Learning rate reduced: {old_lr:.6f} -> {current_lr:.6f}")

    gc.collect()


Epoch 1: Starting selfplay...
Added 6266 positions. Buffer: 6266
Win Results - White: 0, Black: 0, Draw: 0


TypeError: ReduceLROnPlateau.step() missing 1 required positional argument: 'metrics'