In [12]:
pip install python-chess

Note: you may need to restart the kernel to use updated packages.


In [13]:
# Cell 1: Imports
import os
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset # Added TensorDataset
import pickle
import random
import math
import copy # For deep copying model
from collections import deque # Efficient buffer
from tqdm.notebook import tqdm
import gc
import matplotlib.pyplot as plt
from IPython.display import display, SVG

# Chess imports
import chess
from chess import Board, Move

print("Imports successful!")

Imports successful!


In [14]:
# Cell 2: Configuration (Updated documentation for visualize flag)

# --- Paths ---
PRETRAINED_MODEL_PATH = "/kaggle/input/chessv6/pytorch/default/1/chess_model_v5_resnet_dropout_best.pth" # Load your best SL model
MAPPINGS_PATH = "/kaggle/input/chessv6/pytorch/default/1/chess_mappings_v5.pkl"
RL_CHECKPOINT_DIR = "/kaggle/working/rl_checkpoints/" # Directory to save RL models
os.makedirs(RL_CHECKPOINT_DIR, exist_ok=True)

# --- RL Hyperparameters ---
# MCTS
MCTS_SIMULATIONS = 100     # Number of simulations per move (Start low: 50-100, increase to 800, 1600+)
C_PUCT = 1.5               # Exploration constant (adjust between 1.0 and 4.0)
DIRICHLET_ALPHA = 0.3      # Alpha for Dirichlet noise at the root node
DIRICHLET_EPSILON = 0.25   # Weight of Dirichlet noise

# Self-Play
TEMPERATURE_TAU_START = 1.0 # Exploration temperature for early moves
TEMPERATURE_TAU_END = 1e-3  # Near-greedy temperature for later moves
TAU_SWITCH_MOVE = 30      # Move number to switch temperature
SELF_PLAY_GAMES_PER_ITER = 20 # Games generated per RL iteration (Start low: ~10-50)
MAX_MOVES_PER_GAME = 300 # Safety break for excessively long games

# Training
DATA_BUFFER_MAX_GAMES = 500 # Store data from approx last N games (adjust based on memory)
TRAINING_EPOCHS_PER_ITER = 1 # Epochs over the buffer per RL iteration
RL_BATCH_SIZE = 256       # Batch size for RL training (can be smaller than SL)
RL_LEARNING_RATE = 0.0002 # Learning rate (often lower for RL fine-tuning)
WEIGHT_DECAY = 1e-4       # L2 Regularization
POLICY_LOSS_WEIGHT = 1.0
VALUE_LOSS_WEIGHT = 1.0   # Often equal weights in RL

# Evaluation
EVALUATION_GAMES = 20      # Games to play between new and best network (increase for reliability)
WIN_THRESHOLD = 0.55       # Win rate needed for new model to become best (e.g., > 55%)
EVAL_MCTS_SIMULATIONS = MCTS_SIMULATIONS # Can use same or different sims for eval

# Loop Control
MAX_RL_ITERATIONS = 1000   # Total number of RL iterations (generation + training)
CHECKPOINT_SAVE_FREQ = 5   # Save RL model every N iterations

# --- Debugging/Visualization ---
VISUALIZE_SELF_PLAY = False # <<< SET TO True TO SEE TEXT BOARD OUTPUT DURING SELF-PLAY (Still slows down generation!) <<<

# --- Model Architecture (MUST MATCH SAVED V5 MODEL) ---
NUM_RES_BLOCKS = 9
NUM_CHANNELS = 128
DROPOUT_RATE = 0.3 # Must match the dropout used when saving the V5 model

print("Configuration set (Visualization uses text output).")

Configuration set (Visualization uses text output).


In [15]:
# Cell 3: Helper Functions (Board Representation - Copied from V5)

# --- Board Representation (Ensure this matches EXACTLY your V5 notebook) ---
def board_to_matrix_v4(board: Board, flip: bool = False):
    """
    Converts a board state into a matrix representation (multi-channel).
    Version 4: Includes piece positions, turn, castling rights, and optional horizontal flip.
    Output shape: (18, 8, 8) - float32 numpy array
    Channels:
    - 0-5: White pieces (P, N, B, R, Q, K)
    - 6-11: Black pieces (P, N, B, R, Q, K)
    - 12: White King Castling Right (1 if True)
    - 13: White Queen Castling Right (1 if True)
    - 14: Black King Castling Right (1 if True)
    - 15: Black Queen Castling Right (1 if True)
    - 16: White's Turn (1 if White's turn)
    - 17: Constant plane of 1s
    """
    matrix = np.zeros((18, 8, 8), dtype=np.float32)
    current_board = board.copy() # Work on a copy

    if flip: # Flipping probably not needed/used in standard AlphaZero RL
        raise NotImplementedError("Flipping during RL self-play not standard/implemented here.")
        # current_board = current_board.transform(chess.flip_horizontal) # If needed later

    piece_map = current_board.piece_map()
    for square, piece in piece_map.items():
        row, col = divmod(square, 8)
        piece_idx = piece.piece_type - 1
        color_offset = 0 if piece.color else 6 # White=0, Black=6
        matrix[piece_idx + color_offset, row, col] = 1

    # Castling rights (relative to the original board perspective)
    if board.has_kingside_castling_rights(True): matrix[12, :, :] = 1
    if board.has_queenside_castling_rights(True): matrix[13, :, :] = 1
    if board.has_kingside_castling_rights(False): matrix[14, :, :] = 1
    if board.has_queenside_castling_rights(False): matrix[15, :, :] = 1

    # Turn (relative to the current board)
    if current_board.turn: # True if White's turn
        matrix[16, :, :] = 1

    # Constant plane
    matrix[17, :, :] = 1

    return matrix

print("Board representation function defined.")

Board representation function defined.


In [16]:
# Cell 4: Model Definition (Copied EXACTLY from V5)

class ResidualBlock(nn.Module):
    """ Standard Residual Block """
    def __init__(self, num_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        out = self.relu(out)
        return out

class ChessModelV5(nn.Module):
    """ V5: Dual-headed ResNet-style model with Dropout """
    def __init__(self, num_policy_classes, num_res_blocks=9, num_channels=128, dropout_rate=0.3):
        super(ChessModelV5, self).__init__()
        input_channels = 18 # From board_to_matrix_v4
        self.dropout_rate = dropout_rate

        self.conv_in = nn.Sequential(
            nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_channels),
            nn.ReLU(inplace=True)
        )
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(num_channels) for _ in range(num_res_blocks)]
        )

        # Policy Head
        self.policy_conv = nn.Conv2d(num_channels, 32, kernel_size=1, bias=False)
        self.policy_bn = nn.BatchNorm2d(32)
        self.policy_relu = nn.ReLU(inplace=True)
        self.policy_flatten = nn.Flatten()
        # Calculate flat size dynamically (important!)
        dummy_input = torch.zeros(1, input_channels, 8, 8)
        policy_flat_size = self._get_flat_size(self.policy_relu(self.policy_bn(self.policy_conv(self.res_blocks(self.conv_in(dummy_input))))) )
        self.policy_dropout = nn.Dropout(p=self.dropout_rate)
        self.policy_fc = nn.Linear(policy_flat_size, num_policy_classes)

        # Value Head
        self.value_conv = nn.Conv2d(num_channels, 16, kernel_size=1, bias=False)
        self.value_bn = nn.BatchNorm2d(16)
        self.value_relu = nn.ReLU(inplace=True)
        self.value_flatten = nn.Flatten()
        value_flat_size = self._get_flat_size(self.value_relu(self.value_bn(self.value_conv(self.res_blocks(self.conv_in(dummy_input))))))
        self.value_fc1 = nn.Linear(value_flat_size, 64)
        self.value_relu2 = nn.ReLU(inplace=True)
        self.value_dropout1 = nn.Dropout(p=self.dropout_rate)
        self.value_fc2 = nn.Linear(64, 1)
        self.value_tanh = nn.Tanh() # Output between -1 and 1

        print(f"ChessModelV5 initialized for RL (Input: {input_channels}, ResBlocks: {num_res_blocks}, Channels: {num_channels}, Dropout: {dropout_rate}, Policy Classes: {num_policy_classes})")
        # Don't call _initialize_weights here, we'll load them

    def _get_flat_size(self, x):
        """ Helper to get flattened size after convolutions """
        return x.view(1, -1).size(1)

    def forward(self, x):
        features = self.conv_in(x)
        features = self.res_blocks(features)

        # Policy Head
        policy = self.policy_relu(self.policy_bn(self.policy_conv(features)))
        policy = self.policy_flatten(policy)
        policy = self.policy_dropout(policy)
        policy_logits = self.policy_fc(policy)

        # Value Head
        value = self.value_relu(self.value_bn(self.value_conv(features)))
        value = self.value_flatten(value)
        value = self.value_relu2(self.value_fc1(value))
        value = self.value_dropout1(value)
        value = self.value_fc2(value)
        value_output = self.value_tanh(value)

        return policy_logits, value_output

print("Model Definition loaded.")

Model Definition loaded.


In [17]:
# Cell 5: Load Mappings, Device Setup, Load Pre-trained Model

# --- Device Setup ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    gpu_count = torch.cuda.device_count()
    print(f"Using {gpu_count} GPU(s):")
    for i in range(gpu_count):
         print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    # Set default device (important for tensor creation within MCTS if not specified)
    torch.cuda.set_device(0)
else:
    device = torch.device("cpu")
    gpu_count = 0
    print("Using CPU.")

# --- Load Mappings ---
loaded_mappings = None
loaded_move_to_int = None
loaded_int_to_move = None
loaded_num_classes = None
mappings_loaded = False
try:
    print(f"\nLoading mappings from: {MAPPINGS_PATH}")
    with open(MAPPINGS_PATH, "rb") as f:
        loaded_mappings = pickle.load(f)
    loaded_move_to_int = loaded_mappings['move_to_int']
    loaded_int_to_move = loaded_mappings['int_to_move']
    loaded_num_classes = loaded_mappings['num_classes']
    mappings_loaded = True
    print(f"Mappings loaded successfully. Num policy classes: {loaded_num_classes}")
except Exception as e:
    print(f"Error loading mappings: {e}. Cannot proceed.")

# --- Initialize and Load Model ---
model = None
best_model = None # Model used for self-play generation
model_loaded = False

if mappings_loaded:
    try:
        print(f"\nInitializing model structure with {loaded_num_classes} classes...")
        # Ensure these params match the saved V5 model!
        model = ChessModelV5(
            num_policy_classes=loaded_num_classes,
            num_res_blocks=NUM_RES_BLOCKS,
            num_channels=NUM_CHANNELS,
            dropout_rate=DROPOUT_RATE
        )

        print(f"Loading pre-trained weights from: {PRETRAINED_MODEL_PATH}")
        # Load state dict onto CPU first to handle potential DataParallel prefix
        state_dict = torch.load(PRETRAINED_MODEL_PATH, map_location='cpu')

        # Remove 'module.' prefix if it exists (saved from DataParallel)
        if all(key.startswith('module.') for key in state_dict.keys()):
             print("  Removing 'module.' prefix from state dict keys.")
             from collections import OrderedDict
             new_state_dict = OrderedDict()
             for k, v in state_dict.items():
                 name = k[7:] # remove `module.`
                 new_state_dict[name] = v
             state_dict = new_state_dict

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval() # Start in evaluation mode

        # Create the 'best_model' used for generating games, initially same as loaded model
        best_model = copy.deepcopy(model)
        best_model.to(device)
        best_model.eval()

        model_loaded = True
        print("Pre-trained model loaded successfully onto device.")
        print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    except FileNotFoundError:
        print(f"Error: Pre-trained model file not found at {PRETRAINED_MODEL_PATH}")
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Ensure model definition parameters in Cell 4 match the saved model.")
else:
    print("Skipping model loading due to mapping load failure.")

# --- Optimizer ---
optimizer = None
if model_loaded:
    optimizer = optim.Adam(model.parameters(), lr=RL_LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    print(f"Optimizer initialized: Adam (LR={RL_LEARNING_RATE}, WD={WEIGHT_DECAY})")

# --- Loss Functions ---
policy_criterion = nn.CrossEntropyLoss()
value_criterion = nn.MSELoss()
print("Loss functions defined.")

Using 2 GPU(s):
  GPU 0: Tesla T4
  GPU 1: Tesla T4

Loading mappings from: /kaggle/input/chessv6/pytorch/default/1/chess_mappings_v5.pkl
Mappings loaded successfully. Num policy classes: 1888

Initializing model structure with 1888 classes...
ChessModelV5 initialized for RL (Input: 18, ResBlocks: 9, Channels: 128, Dropout: 0.3, Policy Classes: 1888)
Loading pre-trained weights from: /kaggle/input/chessv6/pytorch/default/1/chess_model_v5_resnet_dropout_best.pth
Pre-trained model loaded successfully onto device.
Model parameters: 6,620,225
Optimizer initialized: Adam (LR=0.0002, WD=0.0001)
Loss functions defined.


  state_dict = torch.load(PRETRAINED_MODEL_PATH, map_location='cpu')


In [18]:
# Cell 6: MCTS Implementation (Cleaned - Overflow fix retained)

import math
import numpy as np
import torch
import torch.nn.functional as F
from chess import Board, Move
import time
import random # Ensure random is imported

class MCTSNode:
    def __init__(self, state: Board, parent=None, move=None, prior_p=0.0):
        self.state = state
        self.parent = parent
        self.move = move
        self.children = {}
        self.is_expanded = False
        self.is_terminal = state.is_game_over(claim_draw=True)
        self.terminal_value = None
        if self.is_terminal:
            result = state.result(claim_draw=True)
            if result == '1-0': self.terminal_value = 1.0
            elif result == '0-1': self.terminal_value = -1.0
            else: self.terminal_value = 0.0

        self.N = 0
        self.W = 0.0
        self.Q = 0.0
        self.P = prior_p

    def select_child(self, c_puct):
        best_score = -float('inf')
        best_move_uci = None
        best_child = None
        sqrt_total_N = math.sqrt(self.N) if self.N > 0 else 1.0

        for move_uci, child in self.children.items():
            q_value = -child.Q
            u_value = c_puct * child.P * sqrt_total_N / (1 + child.N)
            score = q_value + u_value
            if score > best_score:
                best_score = score
                best_move_uci = move_uci
                best_child = child

        if best_child is None and self.children:
            best_move_uci = random.choice(list(self.children.keys()))
            best_child = self.children[best_move_uci]

        return best_move_uci, best_child

    def expand(self, policy_probs, int_to_move_map):
        if self.is_expanded or self.is_terminal:
            return
        self.is_expanded = True
        legal_moves = list(self.state.legal_moves)
        if not legal_moves:
            self.is_terminal = True
            result = self.state.result(claim_draw=True)
            if result == '1-0': self.terminal_value = 1.0
            elif result == '0-1': self.terminal_value = -1.0
            else: self.terminal_value = 0.0
            return

        temp_move_to_int = {uci: idx for idx, uci in int_to_move_map.items()}
        for move in legal_moves:
            move_uci = move.uci()
            move_idx = temp_move_to_int.get(move_uci, -1)
            if move_idx != -1 and move_idx < len(policy_probs):
                prior_p = policy_probs[move_idx]
                next_state_board = self.state.copy()
                next_state_board.push(move)
                self.children[move_uci] = MCTSNode(next_state_board, parent=self, move=move, prior_p=prior_p)

    def backpropagate_refined(self, value_from_leaf):
        current_node = self
        current_value = -value_from_leaf
        while current_node is not None:
             current_node.N += 1
             current_node.W += current_value
             current_node.Q = current_node.W / current_node.N if current_node.N > 0 else 0.0
             current_value *= -1
             current_node = current_node.parent

# --- MCTS Main Function ---
def run_mcts(root_board: Board, nn: ChessModelV5, int_to_move_map, simulations: int, c_puct: float, device, add_noise=False, dirichlet_alpha=0.3, dirichlet_epsilon=0.25):
    root_node = MCTSNode(state=root_board.copy())
    policy_probs_np = None

    if not root_node.is_terminal:
        try:
            board_matrix = board_to_matrix_v4(root_node.state, flip=False)
            X_tensor = torch.tensor(board_matrix, dtype=torch.float32).unsqueeze(0).to(device)
            nn.eval()
            with torch.no_grad():
                policy_logits, value_estimate = nn(X_tensor)
                policy_probs_torch = torch.softmax(policy_logits, dim=1).squeeze(0)
            policy_probs_np = policy_probs_torch.cpu().numpy()

            if add_noise:
                 temp_move_to_int = {uci: idx for idx, uci in int_to_move_map.items()}
                 legal_moves_indices = [temp_move_to_int.get(m.uci(), -1) for m in root_node.state.legal_moves]
                 legal_moves_indices = [idx for idx in legal_moves_indices if idx != -1 and idx < len(policy_probs_np)]
                 if legal_moves_indices:
                     noise = np.random.dirichlet([dirichlet_alpha] * len(legal_moves_indices))
                     noisy_policy = policy_probs_np.copy()
                     legal_probs = noisy_policy[legal_moves_indices]
                     legal_probs = (1 - dirichlet_epsilon) * legal_probs + dirichlet_epsilon * noise
                     noisy_policy[legal_moves_indices] = legal_probs
                     policy_probs_np = noisy_policy / np.sum(noisy_policy)

            root_node.expand(policy_probs_np, int_to_move_map)
        except Exception as e:
            print(f"ERROR during root evaluation/expansion: {e}")
            return root_node, np.zeros(loaded_num_classes, dtype=np.float32)
    else:
         return root_node, np.zeros(loaded_num_classes, dtype=np.float32)

    # --- Run Simulations ---
    for _ in range(simulations):
        node = root_node
        search_path = [node]
        while node.is_expanded and not node.is_terminal:
            move_uci, next_node = node.select_child(c_puct)
            if next_node is None:
                 node = None
                 break
            node = next_node
            search_path.append(node)
        if node is None: continue

        leaf_value = 0.0
        if node.is_terminal:
            leaf_value = node.terminal_value
        elif not node.is_expanded:
            try:
                board_matrix = board_to_matrix_v4(node.state, flip=False)
                X_tensor = torch.tensor(board_matrix, dtype=torch.float32).unsqueeze(0).to(device)
                nn.eval()
                with torch.no_grad():
                    policy_logits, value_estimate = nn(X_tensor)
                    policy_probs = torch.softmax(policy_logits, dim=1).squeeze(0).cpu().numpy()
                    leaf_value = value_estimate.item()
                node.expand(policy_probs, int_to_move_map)
            except Exception as e:
                print(f"ERROR during leaf evaluation/expansion: {e}")
                continue
        try:
             node.backpropagate_refined(leaf_value)
        except Exception as e:
             print(f"ERROR during backpropagation: {e}")

    # --- Calculate Policy Target ---
    policy_target = np.zeros(loaded_num_classes, dtype=np.float32)
    if root_node.N > 0:
        temp_move_to_int = {uci: idx for idx, uci in int_to_move_map.items()}
        total_child_visits = sum(child.N for child in root_node.children.values())
        if total_child_visits > 0:
            for move_uci, child in root_node.children.items():
                move_idx = temp_move_to_int.get(move_uci, -1)
                if move_idx != -1:
                    policy_target[move_idx] = child.N / total_child_visits
            target_sum = np.sum(policy_target)
            if target_sum > 1e-6: policy_target /= target_sum

    return root_node, policy_target

# --- Helper to Select Move (Keep corrected version) ---
def get_mcts_move(root_node: MCTSNode, temperature: float):
    if not root_node.children: return None
    if temperature < 1e-2:
        try:
             best_move_uci = max(root_node.children, key=lambda m_uci: root_node.children[m_uci].N)
             return best_move_uci if best_move_uci in root_node.children else random.choice(list(root_node.children.keys()))
        except ValueError: return None
    else:
        visit_counts = np.array([child.N for child in root_node.children.values()], dtype=np.float64)
        moves_uci = list(root_node.children.keys())
        total_visits = np.sum(visit_counts)
        if total_visits <= 0: return random.choice(moves_uci)
        exponent = 1.0 / temperature
        if exponent > 70:
            best_move_idx = np.argmax(visit_counts)
            return moves_uci[best_move_idx]
        try:
            visit_powers = visit_counts ** exponent
            total_power = np.sum(visit_powers)
            if total_power <= 1e-9 or not np.isfinite(total_power):
                 best_move_idx = np.argmax(visit_counts)
                 return moves_uci[best_move_idx]
            probabilities = visit_powers / total_power
            probabilities /= np.sum(probabilities)
            return np.random.choice(moves_uci, p=probabilities)
        except (OverflowError, Exception) as e:
            # print(f"Warning during move sampling (temp={temperature}): {e}. Selecting greedily.") # Debug
            best_move_idx = np.argmax(visit_counts)
            return moves_uci[best_move_idx]

print("MCTS Node and run_mcts function defined (Cleaned).")

MCTS Node and run_mcts function defined (Cleaned).


In [19]:
# Cell 7: Self-Play Function (Visualization Changed to Text)

from IPython.display import display, SVG, clear_output # Make sure clear_output is imported
import chess

def run_self_play_game(best_nn: ChessModelV5, mcts_simulations: int, c_puct: float,
                       tau_switch_move: int, device, int_to_move_map,
                       add_noise=True, dirichlet_alpha=0.3, dirichlet_epsilon=0.25,
                       visualize=False, max_moves=300): # Added visualize and max_moves args
    """ Plays one game using MCTS and returns training data """
    game_data = []
    board = Board()
    move_count = 0

    while not board.is_game_over(claim_draw=True):
        move_count += 1
        current_player_is_white = board.turn

        temperature = TEMPERATURE_TAU_START if move_count < tau_switch_move else TEMPERATURE_TAU_END

        # Run MCTS
        root_node, policy_target = run_mcts(
            board, best_nn, int_to_move_map, mcts_simulations, c_puct, device,
            add_noise=add_noise, dirichlet_alpha=dirichlet_alpha, dirichlet_epsilon=dirichlet_epsilon
        )

        # Store data point for training
        state_matrix = board_to_matrix_v4(board, flip=False)
        game_data.append((state_matrix, policy_target, 1 if current_player_is_white else -1))

        # Select move to play
        move_uci = get_mcts_move(root_node, temperature)
        if move_uci is None:
            print(f"    Warning: get_mcts_move returned None in self-play! Board state:\n{board}")
            break
        move = Move.from_uci(move_uci)

        # Play move
        try:
            board.push(move)

            # ***** VISUALIZATION PART (Changed to Text) *****
            if visualize:
                 clear_output(wait=True) # Clear previous board/text
                 print(f"Self-Play Game - Move {board.fullmove_number}{'.' if board.turn else '...'} ({'White' if not board.turn else 'Black'}) played {move_uci}")
                 print(board) # Print the text representation of the board
                 print("-" * 20)
                 time.sleep(0.1) # Shorter pause for text
            # ************************************************

        except Exception as e:
            print(f"    ERROR pushing move {move_uci}: {e}")
            break

        del root_node
        gc.collect()

        if move_count > max_moves:
             # Removed print warning to reduce output, just break
             break

    result = board.result(claim_draw=True)
    if result == '1-0': game_outcome_z = 1.0
    elif result == '0-1': game_outcome_z = -1.0
    else: game_outcome_z = 0.0

    training_samples = []
    for state_matrix, pi_target, player_marker in game_data:
        value_target = game_outcome_z * player_marker
        training_samples.append((state_matrix, pi_target, np.float32(value_target)))

    return training_samples, move_count

print("Self-play function defined (Visualization uses text output).")

Self-play function defined (Visualization uses text output).


In [20]:
# Cell 8: Data Buffer & Training Step

# Use deque for efficient adding/removing from both ends
data_buffer = deque(maxlen=DATA_BUFFER_MAX_GAMES * 100) # Store *positions*, adjust multiplier based on avg game length

def add_game_to_buffer(game_samples, buffer):
    """Adds samples from a completed game to the buffer."""
    for sample in game_samples:
        buffer.append(sample)

def sample_batch(buffer, batch_size):
    """Samples a batch of data from the buffer."""
    if len(buffer) < batch_size:
        return None # Not enough data yet

    batch_indices = np.random.choice(len(buffer), batch_size, replace=False)
    batch = [buffer[i] for i in batch_indices]

    states = np.array([s[0] for s in batch], dtype=np.float32)
    policy_targets = np.array([s[1] for s in batch], dtype=np.float32)
    value_targets = np.array([s[2] for s in batch], dtype=np.float32)

    # Convert to tensors
    states_tensor = torch.tensor(states)
    policy_targets_tensor = torch.tensor(policy_targets)
    value_targets_tensor = torch.tensor(value_targets).unsqueeze(1) # Add channel dim for MSELoss

    return states_tensor, policy_targets_tensor, value_targets_tensor

def train_step(model_to_train, optimizer, buffer, batch_size, device, gpu_count):
    """Performs one training step."""
    batch = sample_batch(buffer, batch_size)
    if batch is None:
        return None, None, None # Not enough data

    states, pi_targets, z_targets = batch
    states = states.to(device)
    pi_targets = pi_targets.to(device)
    z_targets = z_targets.to(device)

    model_to_train.train() # Set to training mode (enables dropout)

    # Use DataParallel for training step if multiple GPUs
    train_model = model_to_train
    if gpu_count > 1 and not isinstance(model_to_train, nn.DataParallel):
        print("Wrapping model with DataParallel for training step.")
        train_model = nn.DataParallel(model_to_train)
    elif gpu_count <= 1 and isinstance(model_to_train, nn.DataParallel):
         print("Unwrapping model from DataParallel for training step.")
         train_model = model_to_train.module # Access the underlying model

    pi_logits, v_preds = train_model(states)

    # Calculate loss
    value_loss = VALUE_LOSS_WEIGHT * value_criterion(v_preds, z_targets)
    # Policy loss needs logits vs target distribution
    # Use log_softmax on logits and NLLLoss, or CrossEntropy with raw logits
    policy_loss = POLICY_LOSS_WEIGHT * policy_criterion(pi_logits, pi_targets) # CrossEntropy expects logits

    total_loss = value_loss + policy_loss

    # Backpropagation
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Unwrap model if necessary before returning
    if gpu_count > 1 and isinstance(train_model, nn.DataParallel):
        model_to_train = train_model.module

    model_to_train.eval() # Set back to eval mode after training step

    return total_loss.item(), policy_loss.item(), value_loss.item()


print("Data buffer and training step functions defined.")

Data buffer and training step functions defined.


In [21]:
# Cell 9: Evaluation Function

def evaluate_networks(nn_new: ChessModelV5, nn_old: ChessModelV5, num_games: int, mcts_sims: int, c_puct: float, device, int_to_move_map):
    """ Plays games between two networks and returns the win rate of nn_new """
    print(f"\n--- Evaluating Networks ({num_games} games, {mcts_sims} sims/move) ---")
    nn_new_wins = 0
    nn_old_wins = 0
    draws = 0

    nn_new.eval()
    nn_old.eval()

    for i in tqdm(range(num_games), desc="Evaluation Games"):
        board = Board()
        game_over = False
        players = {1: nn_new, -1: nn_old} if i % 2 == 0 else {-1: nn_new, 1: nn_old} # Alternate starting player

        while not game_over:
            player_to_move = 1 if board.turn else -1 # 1 for White, -1 for Black
            current_nn = players[player_to_move]

            # Run MCTS for the current player, NO noise, GREEDY move selection (temp=0)
            root_node, _ = run_mcts(
                board, current_nn, int_to_move_map, mcts_sims, c_puct, device,
                add_noise=False # No noise during evaluation
            )
            move_uci = get_mcts_move(root_node, temperature=0) # Greedy move selection

            if move_uci is None:
                print("Warning: Eval game ended due to no move returned by MCTS.")
                game_over = True
                break # Should indicate a draw or loss for the player unable to move?

            move = Move.from_uci(move_uci)
            board.push(move)

            if board.is_game_over(claim_draw=True):
                game_over = True
                result = board.result(claim_draw=True)
                if result == '1-0': # White wins
                    if players[1] == nn_new: nn_new_wins += 1
                    else: nn_old_wins += 1
                elif result == '0-1': # Black wins
                    if players[-1] == nn_new: nn_new_wins += 1
                    else: nn_old_wins += 1
                else: # Draw
                    draws += 1

            del root_node # Cleanup

    win_rate_new = nn_new_wins / num_games if num_games > 0 else 0
    print(f"Evaluation Result: New Wins: {nn_new_wins}, Old Wins: {nn_old_wins}, Draws: {draws}")
    print(f"Win Rate for New Network: {win_rate_new:.2f}")
    return win_rate_new

print("Evaluation function defined.")

Evaluation function defined.


In [None]:
# Cell 10: Main Reinforcement Learning Loop (Ensure calls updated function)

if not model_loaded or not optimizer:
    print("Model not loaded or optimizer not initialized. Cannot start RL loop.")
else:
    print("\n--- Starting Reinforcement Learning Loop ---")
    start_time = time.time()
    training_history = {'iteration': [], 'total_loss': [], 'policy_loss': [], 'value_loss': [], 'win_rate': [], 'buffer_size': []}

    # Use deque for efficient adding/removing from both ends
    avg_moves_per_game_estimate = 80 # Estimate, adjust as needed
    buffer_max_positions = DATA_BUFFER_MAX_GAMES * avg_moves_per_game_estimate
    data_buffer = deque(maxlen=buffer_max_positions)
    print(f"Initialized data buffer with max capacity for ~{DATA_BUFFER_MAX_GAMES} games (~{buffer_max_positions:,} positions).")


    for iteration in range(1, MAX_RL_ITERATIONS + 1):
        iter_start_time = time.time()
        print(f"\n===== Iteration {iteration}/{MAX_RL_ITERATIONS} =====")

        # --- 1. Self-Play ---
        print("Generating self-play games...")
        new_samples_count = 0
        games_generated = 0
        best_model.eval() # Ensure best model is in eval mode for generation
        self_play_bar = tqdm(range(SELF_PLAY_GAMES_PER_ITER), desc="Self-Play", leave=False) # Use leave=False for cleaner output
        for game_idx in self_play_bar:
            visualize_this_game = VISUALIZE_SELF_PLAY and (game_idx == 0) # Visualize only the first game per iteration

            game_samples, game_len = run_self_play_game(
                best_model, MCTS_SIMULATIONS, C_PUCT, TAU_SWITCH_MOVE, device, loaded_int_to_move,
                add_noise=True, dirichlet_alpha=DIRICHLET_ALPHA, dirichlet_epsilon=DIRICHLET_EPSILON,
                visualize=visualize_this_game, # Pass the flag
                max_moves=MAX_MOVES_PER_GAME # Pass the safety limit
            )

            if visualize_this_game:
                 clear_output(wait=True) # Clear the last board state after the visualized game finishes
                 print(f"--- Visualization for game {game_idx+1} complete ---")
                 # Turn off for subsequent games in this iteration if needed
                 # VISUALIZE_SELF_PLAY = False # Example: Turn off after first

            add_game_to_buffer(game_samples, data_buffer)
            new_samples_count += len(game_samples)
            games_generated += 1
            self_play_bar.set_postfix({"BufferPos": len(data_buffer), "LastGameLen": game_len})

        buffer_current_positions = len(data_buffer)
        print(f"Self-play complete. Generated {games_generated} games ({new_samples_count} positions). Buffer size: {buffer_current_positions} positions.")

        # --- 2. Training ---
        if buffer_current_positions < RL_BATCH_SIZE:
            print(f"Buffer size ({buffer_current_positions}) < Batch Size ({RL_BATCH_SIZE}). Skipping training.")
            training_history['iteration'].append(iteration)
            training_history['total_loss'].append(None)
            training_history['policy_loss'].append(None)
            training_history['value_loss'].append(None)
            training_history['buffer_size'].append(buffer_current_positions)
            training_history['win_rate'].append(None)
            continue

        print("Training network...")
        model.train()
        total_iter_loss = 0
        policy_iter_loss = 0
        value_iter_loss = 0
        batches_trained = 0
        steps_per_epoch = max(1, buffer_current_positions // RL_BATCH_SIZE)
        total_training_steps = steps_per_epoch * TRAINING_EPOCHS_PER_ITER

        train_bar = tqdm(range(total_training_steps), desc="Training", leave=False) # Use leave=False
        for _ in train_bar:
             # NOTE: train_step internally handles DataParallel wrapping/unwrapping
             loss_vals = train_step(model, optimizer, data_buffer, RL_BATCH_SIZE, device, gpu_count)
             if loss_vals[0] is not None:
                 t_loss, p_loss, v_loss = loss_vals
                 total_iter_loss += t_loss
                 policy_iter_loss += p_loss
                 value_iter_loss += v_loss
                 batches_trained += 1
                 train_bar.set_postfix({"Loss": f"{total_iter_loss/batches_trained:.4f}"})

        model.eval()

        avg_total_loss = total_iter_loss / batches_trained if batches_trained > 0 else 0
        avg_policy_loss = policy_iter_loss / batches_trained if batches_trained > 0 else 0
        avg_value_loss = value_iter_loss / batches_trained if batches_trained > 0 else 0
        print(f"Training complete. Avg Loss: {avg_total_loss:.4f} (P: {avg_policy_loss:.4f}, V: {avg_value_loss:.4f})")

        training_history['iteration'].append(iteration)
        training_history['total_loss'].append(avg_total_loss)
        training_history['policy_loss'].append(avg_policy_loss)
        training_history['value_loss'].append(avg_value_loss)
        training_history['buffer_size'].append(buffer_current_positions)

        # --- 3. Evaluation ---
        win_rate = None
        if iteration % CHECKPOINT_SAVE_FREQ == 0 or iteration == MAX_RL_ITERATIONS:
            win_rate = evaluate_networks(model, best_model, EVALUATION_GAMES, EVAL_MCTS_SIMULATIONS, C_PUCT, device, loaded_int_to_move)

            if win_rate > WIN_THRESHOLD:
                print(f"New network IS BETTER ({win_rate:.2f} > {WIN_THRESHOLD}). Updating best model.")
                best_model.load_state_dict(model.state_dict()) # Update best model weights
                best_model_path = os.path.join(RL_CHECKPOINT_DIR, f"best_model_iter_{iteration}.pth")
                # Save the state_dict directly from best_model
                torch.save(best_model.state_dict(), best_model_path)
                print(f"Saved new best model to {best_model_path}")
            else:
                print(f"New network DID NOT improve significantly ({win_rate:.2f} <= {WIN_THRESHOLD}). Keeping previous best model.")
        training_history['win_rate'].append(win_rate)

        # --- 4. Save Checkpoint ---
        if iteration % CHECKPOINT_SAVE_FREQ == 0 or iteration == MAX_RL_ITERATIONS:
            checkpoint_path = os.path.join(RL_CHECKPOINT_DIR, f"checkpoint_iter_{iteration}.pth")
            # Save the currently training model state_dict
            # Correctly handle DataParallel: save the module's state_dict if wrapped
            state_to_save = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
            torch.save(state_to_save, checkpoint_path)
            print(f"Saved training checkpoint to {checkpoint_path}")

        iter_end_time = time.time()
        print(f"Iteration {iteration} finished in {iter_end_time - iter_start_time:.2f} seconds.")
        print("-" * 20)
        gc.collect()
        if device.type == 'cuda': torch.cuda.empty_cache()


    total_run_time = time.time() - start_time
    print(f"\n--- RL Training Loop Finished in {total_run_time / 3600:.2f} hours ---")


--- Starting Reinforcement Learning Loop ---
Initialized data buffer with max capacity for ~500 games (~40,000 positions).

===== Iteration 1/1000 =====
Generating self-play games...


Self-Play:   0%|          | 0/20 [00:00<?, ?it/s]

In [19]:
# Cell 11: Plot RL Training Progress (Optional)

if training_history['iteration']:
    print("\n--- Plotting RL Training Progress ---")
    iters = training_history['iteration']

    plt.style.use('seaborn-v0_8-darkgrid')
    fig, ax1 = plt.subplots(figsize=(12, 6))

    color = 'tab:red'
    ax1.set_xlabel('RL Iteration')
    ax1.set_ylabel('Loss', color=color)
    ax1.plot(iters, training_history['total_loss'], label='Total Loss', color='red', marker='.', linestyle='-')
    ax1.plot(iters, training_history['policy_loss'], label='Policy Loss', color='lightcoral', marker='x', linestyle=':')
    ax1.plot(iters, training_history['value_loss'], label='Value Loss', color='salmon', marker='s', linestyle='--')
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.legend(loc='upper left')
    ax1.set_title('RL Training Loss and Evaluation Win Rate')

    # Instantiate a second axes that shares the same x-axis for Win Rate
    ax2 = ax1.twinx()
    color = 'tab:blue'
    ax2.set_ylabel('Evaluation Win Rate vs Best', color=color)
    # Plot win rate only where it was calculated
    eval_iters = [i for i, wr in zip(iters, training_history['win_rate']) if wr is not None]
    eval_wrs = [wr for wr in training_history['win_rate'] if wr is not None]
    ax2.plot(eval_iters, eval_wrs, label='New Model Win Rate', color=color, marker='*', linestyle='-.')
    ax2.axhline(y=WIN_THRESHOLD, color='grey', linestyle='--', label=f'Win Threshold ({WIN_THRESHOLD})')
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.set_ylim(0, 1.05) # Win rate is between 0 and 1
    ax2.legend(loc='upper right')

    fig.tight_layout()
    plt.show()
else:
    print("No RL training history to plot.")

No RL training history to plot.


In [None]:
# Cell 12: Prediction Setup (Load Trained RL Model)

# --- Load Mappings (should already be loaded, but check) ---
if not mappings_loaded:
     try:
         print(f"\nRe-Loading mappings from: {MAPPINGS_PATH}")
         with open(MAPPING_SAVE_PATH, "rb") as f: # Use save path if modified
             loaded_mappings = pickle.load(f)
         loaded_move_to_int = loaded_mappings['move_to_int']
         loaded_int_to_move = loaded_mappings['int_to_move']
         loaded_num_classes = loaded_mappings['num_classes']
         mappings_loaded = True
         print(f"Mappings loaded. Num classes: {loaded_num_classes}")
     except Exception as e:
         print(f"Error: Could not load mapping file from {MAPPING_SAVE_PATH}: {e}")
         loaded_mappings = None

# --- Device Setup ---
if torch.cuda.is_available():
    prediction_device = torch.device("cuda")
    print(f"\nUsing GPU for prediction.")
else:
    prediction_device = torch.device("cpu")
    print("\nUsing CPU for prediction.")

# --- Load the BEST RL Model for Prediction ---
rl_prediction_model = None
rl_model_loaded = False

if mappings_loaded:
    try:
        # Find the latest "best_model" checkpoint if available, otherwise latest checkpoint
        checkpoint_files = [f for f in os.listdir(RL_CHECKPOINT_DIR) if f.startswith("best_model_iter_") and f.endswith(".pth")]
        if checkpoint_files:
            checkpoint_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
            model_path_to_load = os.path.join(RL_CHECKPOINT_DIR, checkpoint_files[-1])
            print(f"\nLoading BEST RL model weights from: {model_path_to_load}")
        else:
            # Fallback to latest training checkpoint if no 'best' model saved
            checkpoint_files = [f for f in os.listdir(RL_CHECKPOINT_DIR) if f.startswith("checkpoint_iter_") and f.endswith(".pth")]
            if checkpoint_files:
                 checkpoint_files.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
                 model_path_to_load = os.path.join(RL_CHECKPOINT_DIR, checkpoint_files[-1])
                 print(f"\nLoading LATEST RL checkpoint weights from: {model_path_to_load}")
            else:
                 print("\nNo RL checkpoints found. Attempting to load original SL model for prediction.")
                 model_path_to_load = PRETRAINED_MODEL_PATH # Fallback to original SL

        print(f"Initializing model structure...")
        rl_prediction_model = ChessModelV5(
            num_policy_classes=loaded_num_classes,
            num_res_blocks=NUM_RES_BLOCKS,
            num_channels=NUM_CHANNELS,
            dropout_rate=DROPOUT_RATE # Use same dropout rate (eval mode will disable it anyway)
        )

        state_dict = torch.load(model_path_to_load, map_location='cpu')
        if all(key.startswith('module.') for key in state_dict.keys()):
             print("  Removing 'module.' prefix.")
             from collections import OrderedDict
             new_state_dict = OrderedDict()
             for k, v in state_dict.items():
                 name = k[7:]
                 new_state_dict[name] = v
             state_dict = new_state_dict

        rl_prediction_model.load_state_dict(state_dict)
        rl_prediction_model.to(prediction_device)
        rl_prediction_model.eval() # IMPORTANT: Set to eval mode for prediction
        rl_model_loaded = True
        print("RL Prediction model loaded successfully.")

    except FileNotFoundError:
        print(f"Error: Model file not found at {model_path_to_load}.")
    except Exception as e:
        print(f"Error loading RL model state dict: {e}")
        rl_prediction_model = None
else:
    print("Skipping RL model loading.")

In [None]:
# Cell 13: Prediction Function (Using RL Model)

def predict_move_rl(board: Board, model_to_use, device_to_use, int_to_move_map, use_mcts=True, mcts_sims=50, c_puct=1.0):
    """ Predicts move using the RL model, optionally with MCTS for stronger play """

    if not rl_model_loaded or not model_to_use or not int_to_move_map:
        print("Error: RL Model or mappings not ready in predict_move_rl.")
        return None, 0.0

    model_to_use.eval() # Ensure model is in eval mode

    if use_mcts:
        # Use MCTS for prediction (stronger, slower) - NO noise, GREEDY selection
        root_node, _ = run_mcts(
            board, model_to_use, int_to_move_map, mcts_sims, c_puct, device_to_use, add_noise=False
        )
        best_move_uci = get_mcts_move(root_node, temperature=0) # Greedy
        predicted_value = -root_node.Q # Root Q is from opponent's perspective after root move
        if best_move_uci:
            best_move_obj = Move.from_uci(best_move_uci)
            return best_move_obj, predicted_value
        else: # No legal moves
            return None, predicted_value # Or terminal value if root is terminal
    else:
        # Use raw network output for prediction (faster, weaker)
        board_matrix = board_to_matrix_v4(board, flip=False)
        X_tensor = torch.tensor(board_matrix, dtype=torch.float32).unsqueeze(0).to(device_to_use)
        best_legal_move_obj = None
        predicted_value = 0.0

        with torch.no_grad():
            try:
                policy_logits, value_output = model_to_use(X_tensor)
                predicted_value = value_output.item()
                policy_probs = torch.softmax(policy_logits.squeeze(0), dim=0).cpu().numpy()

                legal_moves = list(board.legal_moves)
                if not legal_moves: return None, predicted_value

                best_prob = -1.0
                for move in legal_moves:
                    move_uci = move.uci()
                    move_idx = -1
                    # Find index (can be slow, better to use move_to_int if available & consistent)
                    for idx, uci in int_to_move_map.items():
                        if uci == move_uci:
                            move_idx = idx
                            break
                    if move_idx != -1:
                        prob = policy_probs[move_idx]
                        if prob > best_prob:
                            best_prob = prob
                            best_legal_move_obj = move

                if not best_legal_move_obj: # If all legal moves had 0 prob (unlikely)
                    best_legal_move_obj = random.choice(legal_moves)

            except Exception as e:
                print(f"Error during raw NN prediction: {e}")
                legal_moves = list(board.legal_moves)
                if legal_moves: best_legal_move_obj = random.choice(legal_moves)

        return best_legal_move_obj, predicted_value


print("RL Prediction function defined.")

In [None]:
# Cell 14: Prediction Example (with SVG)

if rl_model_loaded:
    board = Board()
    print("\n--- Starting RL Prediction Example with SVG Output ---")
    MAX_PREDICTION_MOVES = 50 # Play N moves
    USE_MCTS_IN_EXAMPLE = True # Use MCTS for stronger moves in the example?
    EXAMPLE_MCTS_SIMS = MCTS_SIMULATIONS # Use same sims as self-play, or adjust

    move_counter = 0

    try:
        print("Initial Board:")
        display(SVG(board._repr_svg_()))

        while move_counter < MAX_PREDICTION_MOVES:
            move_counter += 1
            current_player = "White" if board.turn else "Black"
            print(f"\n--- Move {board.fullmove_number}. {'...' if not board.turn else ''}{current_player} to Play ---")

            if board.is_game_over(claim_draw=True):
                print(f"Game Over! Result: {board.result(claim_draw=True)}")
                break

            # Use the RL prediction function
            ai_move, predicted_val = predict_move_rl(
                board, rl_prediction_model, prediction_device, loaded_int_to_move,
                use_mcts=USE_MCTS_IN_EXAMPLE, mcts_sims=EXAMPLE_MCTS_SIMS, c_puct=C_PUCT
            )

            if ai_move:
                move_uci = ai_move.uci()
                print(f"RL AI {'(MCTS)' if USE_MCTS_IN_EXAMPLE else '(Raw NN)'} suggests: {move_uci} (Predicted Value: {predicted_val:.3f})")
                board.push(ai_move)
                print(f"\nBoard after {board.fullmove_number-1 if not board.turn else board.fullmove_number}.{'..' if board.turn else ''}{move_uci}:")
                display(SVG(board._repr_svg_()))
            else:
                print(f"RL AI could not suggest a valid move for {current_player}.")
                if board.is_game_over(claim_draw=True):
                     print(f"Game Over! Result: {board.result(claim_draw=True)}")
                break

    except Exception as e:
         print(f"\nAn unexpected error occurred during the prediction simulation: {e}")
         import traceback
         traceback.print_exc()

    print("\n--- Prediction Example Finished ---")
    print("\nFinal Board State:")
    display(SVG(board._repr_svg_()))
    if board.is_game_over(claim_draw=True):
         print(f"Final Result: {board.result(claim_draw=True)}")

else:
    print("\nRL prediction model or mappings not loaded. Skipping prediction example.")