In [1]:
import os
# Set KMP_DUPLICATE_LIB_OK in this first cell, before the torch/numpy/ray
# imports of the next cell run.
# NOTE(review): presumably a workaround for the OpenMP "duplicate runtime"
# abort seen with some conda/torch installs -- confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [2]:
from ray.rllib.models.torch.misc import SlimFC, normc_initializer
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.env.env_context import EnvContext
from ray.tune.registry import register_env
from ray.rllib.algorithms.ppo import PPO
from datetime import datetime
import gymnasium as gym
from torch import nn
import numpy as np
import torch
import chess
import ray
import os

In [None]:
# `torch` itself is already imported in the imports cell above; this line only
# makes sure the `torch.version` submodule is loaded as well.
import torch.version


# Sanity check: as the last expression of the cell, the boolean returned here
# is shown as the cell output. main() below asks for "num_gpus": 1, so this is
# worth checking before starting a training run.
torch.cuda.is_available()

In [4]:
class RayChessEnvironment(gym.Env):
    """Single-agent Gymnasium chess environment built on ``python-chess``.

    Every call to :meth:`step` plays exactly one half-move on a shared
    ``chess.Board``; the same agent therefore moves for both colours.

    Action space
        ``Discrete(len(move_table))``. The table (see :meth:`build_move_table`)
        lists every ``from != to`` square pair first, followed by all pawn
        promotion moves. An action that is not legal in the current position
        is replaced by a uniformly random legal move.

    Observation space
        ``Box(0, 1, shape=(19, 8, 8), float32)``:

        * planes 0-11  -- one-hot piece placement, indexed by piece symbol in
          the order ``P N B R Q K p n b r q k``;
        * plane 12     -- ``board.turn``;
        * planes 13-16 -- kingside/queenside castling rights for colour
          ``True``, then for colour ``False``;
        * plane 17     -- whether the side to move is in check;
        * plane 18     -- ``halfmove_clock / 100`` clipped to 1.0.
    """

    # Promotion pieces, in the order they appear in the action table.
    _PROMOTION_PIECES = (chess.QUEEN, chess.ROOK, chess.BISHOP, chess.KNIGHT)

    # Piece symbol -> index of its observation plane. Defined once here instead
    # of being rebuilt for every piece on every observation.
    _PIECE_PLANES = {
        'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
        'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
    }

    def __init__(self, config: EnvContext):
        """Create a fresh board and derive the action/observation spaces.

        Args:
            config: RLlib environment context. It is accepted for
                compatibility with ``register_env`` but not read.
        """
        super().__init__()
        self.board = chess.Board()
        self.move_table = self.build_move_table()
        self.move_lookup = {move: idx for idx, move in enumerate(self.move_table)}
        self.action_space = gym.spaces.Discrete(len(self.move_table))
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(19, 8, 8), dtype=np.float32
        )

    def build_move_table(self):
        """Build the fixed list that maps action indices to ``chess.Move``.

        The ordering is part of the policy's contract (a trained network's
        output index ``i`` means ``move_table[i]``), so it must stay stable:

        1. all 64 * 63 plain moves, ordered by from-square then to-square;
        2. promotions from rank 1 to rank 0, then from rank 6 to rank 7; per
           file: straight push, capture toward the lower file, capture toward
           the higher file, each expanded over ``_PROMOTION_PIECES``.

        Returns:
            list[chess.Move]: the action table.
        """
        move_table = [
            chess.Move(from_square, to_square)
            for from_square in range(64)
            for to_square in range(64)
            if from_square != to_square
        ]

        promotion_ranks = ((1, 0), (6, 7))
        for from_rank, to_rank in promotion_ranks:
            for file in range(8):
                from_square = from_rank * 8 + file
                # Same order as before: straight, file - 1, file + 1; targets
                # that would fall off the board are skipped.
                for to_file in (file, file - 1, file + 1):
                    if not 0 <= to_file <= 7:
                        continue
                    to_square = to_rank * 8 + to_file
                    for promotion in self._PROMOTION_PIECES:
                        move_table.append(chess.Move(from_square, to_square, promotion))

        return move_table

    def reset(self, *, seed=None, options=None):
        """Reset the board to the starting position.

        Args:
            seed: Optional seed. It is forwarded to ``gym.Env.reset`` so that
                it seeds this environment's own ``self.np_random`` generator
                rather than the process-wide NumPy RNG, which is shared by
                every environment living in the same worker process.
            options: Unused.

        Returns:
            tuple: ``(observation, info)`` with an empty ``info`` dict.
        """
        super().reset(seed=seed)
        self.board.reset()
        return self.get_observation(), {}

    def step(self, action):
        """Play one half-move and return the Gymnasium 5-tuple.

        Args:
            action: Index into ``move_table``.

        Returns:
            tuple: ``(obs, reward, terminated, truncated, info)``.
            ``terminated`` is ``board.is_game_over()``; ``truncated`` is
            always ``False`` and ``info`` is always empty.

        Raises:
            Exception: anything raised while applying the move or building the
                observation is printed and re-raised unchanged.
        """
        move = self.move_table[action]

        try:
            if move in self.board.legal_moves:
                self.board.push(move)
            else:
                legal_moves = list(self.board.legal_moves)
                if legal_moves:
                    # Illegal action: fall back to a uniformly random legal
                    # move, drawn from the env-local generator so the choice is
                    # reproducible from the seed passed to reset().
                    pick = int(self.np_random.integers(len(legal_moves)))
                    self.board.push(legal_moves[pick])

            obs = self.get_observation()
            assert (obs >= 0).all() and (obs <= 1).all(), "Observation out of bounds"

            reward = self.get_reward()
            done = self.board.is_game_over()
            return obs, reward, done, False, {}

        except Exception as e:
            print(f"Error in step: {e}")
            raise

    def get_observation(self):
        """Encode the current board as a ``(19, 8, 8)`` float32 array.

        See the class docstring for the meaning of each plane.
        """
        board_tensor = np.zeros((19, 8, 8), dtype=np.float32)

        # Piece planes (0-11)
        for square, piece in self.board.piece_map().items():
            rank, file = divmod(square, 8)
            board_tensor[self._PIECE_PLANES[piece.symbol()], rank, file] = 1.0

        # State planes (12-18): one scalar per plane
        state_planes = np.array([
            float(self.board.turn),
            float(self.board.has_kingside_castling_rights(True)),
            float(self.board.has_queenside_castling_rights(True)),
            float(self.board.has_kingside_castling_rights(False)),
            float(self.board.has_queenside_castling_rights(False)),
            float(self.board.is_check()),
            min(1.0, self.board.halfmove_clock / 100.0)  # Clip to ensure <= 1.0
        ], dtype=np.float32)

        # Broadcast each scalar over its whole 8x8 plane
        board_tensor[12:19] = state_planes[:, None, None]

        return board_tensor

    def get_reward(self):
        """Return the reward for the position reached by the last move.

        ``+1.0`` if the position is checkmate and ``board.turn`` is falsy,
        ``-1.0`` if it is checkmate and ``board.turn`` is truthy, ``0.0``
        otherwise (including draws).

        NOTE(review): the sign depends on the colour that got mated, not on
        who made the move, while step() plays moves for both colours with the
        same agent -- confirm this colour-relative reward is intended.
        """
        if self.board.is_checkmate():
            return 1.0 if not self.board.turn else -1.0
        return 0.0

    @staticmethod
    def env_creator(config: EnvContext):
        """Factory with the signature ``register_env`` expects."""
        return RayChessEnvironment(config)

class ChessModel(TorchModelV2, nn.Module):
    """Convolutional policy/value network for the 19-plane chess observation.

    A three-layer 3x3 convolutional trunk (padding 1, so the 8x8 board size is
    kept) feeds two independent fully connected heads: ``fc_layers`` produces
    ``num_outputs`` action logits, ``value_branch`` produces one scalar value
    per sample.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        def dense(n_in, n_out, std, act):
            """Build one SlimFC layer with a normc initializer of scale ``std``."""
            return SlimFC(
                in_size=n_in,
                out_size=n_out,
                initializer=normc_initializer(std),
                activation_fn=act,
            )

        # Convolutional trunk: 19 -> 64 -> 128 -> 256 channels, ReLU after each.
        channels = (19, 64, 128, 256)
        trunk = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            trunk.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            trunk.append(nn.ReLU())
        self.conv_layers = nn.Sequential(*trunk)

        # Spatial size is unchanged by the padded convolutions.
        flat_features = channels[-1] * 8 * 8

        # Policy head
        self.fc_layers = nn.Sequential(
            dense(flat_features, 1024, 1.0, "relu"),
            dense(1024, 512, 1.0, "relu"),
            dense(512, num_outputs, 0.01, None),
        )

        # Value head
        self.value_branch = nn.Sequential(
            dense(flat_features, 512, 1.0, "relu"),
            dense(512, 1, 0.01, None),
        )

        # Filled in by forward(), read back by value_function().
        self._value_out = None

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Compute action logits and cache the value estimate.

        Reads ``input_dict["obs"]``, casts it to float and returns
        ``(logits, state)`` with ``state`` passed through untouched.
        """
        board_planes = input_dict["obs"].float()  # Shape: [B, 19, 8, 8]
        features = self.conv_layers(board_planes).flatten(start_dim=1)

        policy_logits = self.fc_layers(features)
        self._value_out = self.value_branch(features).squeeze(1)

        return policy_logits, state

    @override(TorchModelV2)
    def value_function(self):
        """Return the value estimate cached by the most recent forward()."""
        return self._value_out
    
def setup_logging(log_dir="/kaggle/working/chess_training_logs"):
    """Create the log directory and return a new timestamped log file path.

    Args:
        log_dir: Directory that will hold logs (and, in main(), checkpoints).
            Defaults to the Kaggle working directory used originally; pass a
            different path when running elsewhere. Missing parent directories
            are created.

    Returns:
        str: ``<log_dir>/training_log_<YYYYmmdd_HHMMSS>.txt``. The file itself
        is not created here; it appears on the first append.
    """
    # Create logs directory if it doesn't exist
    os.makedirs(log_dir, exist_ok=True)

    # Create a timestamped log file name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return os.path.join(log_dir, f"training_log_{timestamp}.txt")

def save_log(log_file, episode, result):
    """Append a short metrics summary for one training iteration.

    Args:
        log_file: Path of the text file to append to.
        episode: Counter written in the ``Episode <n>:`` header line.
        result: Mapping with a ``.get`` method (the dict returned by
            ``trainer.train()``); metrics that are absent are written as
            ``None``.
    """
    metric_names = (
        "episode_reward_mean",
        "episode_len_mean",
        "training_iteration",
        "timesteps_total",
        "time_total_s",
    )
    with open(log_file, "a") as handle:
        handle.write(f"Episode {episode}:\n")
        summary = {name: result.get(name) for name in metric_names}
        # One indented "key: value" line per summarized metric
        handle.writelines(f"  {name}: {value}\n" for name, value in summary.items())

def main(num_iterations=1024, log_every=10, checkpoint_every=100):
    """Train PPO on ``RayChessEnv`` and write logs and checkpoints.

    Args:
        num_iterations: Number of ``trainer.train()`` calls to run.
        log_every: Write a metrics summary every this many iterations
            (must be positive).
        checkpoint_every: Save a checkpoint every this many iterations
            (must be positive).

    The trainer is stopped and Ray is shut down even when training raises, so
    a failed run does not leave Ray running inside the notebook kernel.
    """
    # Register the custom environment
    register_env("RayChessEnv", RayChessEnvironment.env_creator)

    # Define PPO training configuration
    config = {
        "env": "RayChessEnv",
        "framework": "torch",
        "num_workers": 4,
        "num_envs_per_worker": 16,
        "num_gpus": 1,
        "env_config": {},

        "_enable_new_api_stack": False,

        "model": {
            "custom_model": ChessModel,
            "custom_model_config": {},
        },

        "lr": 1e-4,
        "train_batch_size": 1024,
        "sgd_minibatch_size": 128,
        "num_sgd_iter": 10,
        "gamma": 0.99,
        "lambda": 0.95,
        "clip_param": 0.2,
        "vf_clip_param": 10.0,
        "entropy_coeff": 0.01,

        "num_gpus_per_worker": 0.05,
        "rollout_fragment_length": 'auto',
        "batch_mode": "truncate_episodes",
    }

    # Initialize Ray; tolerate a Ray instance left over from an earlier run of
    # this cell in the same kernel instead of raising on re-initialisation.
    ray.init(ignore_reinit_error=True)
    try:
        log_file = setup_logging()
        log_dir = os.path.dirname(log_file)  # Directory for logs and checkpoints

        # Create PPO trainer
        trainer = PPO(config=config)
        try:
            # Training loop with reduced logging
            for episode in range(num_iterations):
                result = trainer.train()

                # Log periodically
                if episode % log_every == 0:
                    save_log(log_file, episode, result)
                    print(f"Episode {episode}: {result.get('episode_reward_mean')} mean reward")

                # Save checkpoints periodically
                if episode % checkpoint_every == 0:
                    checkpoint = trainer.save(log_dir)
                    print(f"Checkpoint saved at: {checkpoint}")
                    with open(log_file, "a") as f:
                        f.write(f"\nCheckpoint saved at: {checkpoint}\n")
        finally:
            # Release rollout workers and other trainer resources
            trainer.stop()
    finally:
        # Cleanup, also on failure
        ray.shutdown()


In [None]:
# Entry point: start training when this cell is executed. The guard keeps
# main() from running if this code is imported as a module instead.
if __name__ == "__main__":
    main()

[36m(RolloutWorker pid=67688)[0m Exception raised in creation task: The actor died because of an error raised in its creation task, [36mray::RolloutWorker.__init__()[39m (pid=67688, ip=127.0.0.1, actor_id=de374220ac91a9d41733e15401000000, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x000001C32837FF50>)
[36m(RolloutWorker pid=67688)[0m   File "python\ray\_raylet.pyx", line 1879, in ray._raylet.execute_task
[36m(RolloutWorker pid=67688)[0m   File "python\ray\_raylet.pyx", line 1820, in ray._raylet.execute_task.function_executor
[36m(RolloutWorker pid=67688)[0m   File "c:\Users\yasse\chess\.conda\Lib\site-packages\ray\_private\function_manager.py", line 696, in actor_method_executor
[36m(RolloutWorker pid=67688)[0m     return method(__ray_actor, *args, **kwargs)
[36m(RolloutWorker pid=67688)[0m            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[36m(RolloutWorker pid=67688)[0m   File "c:\Users\yasse\chess\.conda\Lib\site-packages\ray\util\tracing\tracing