In [2]:
import os
# NOTE(review): KMP_DUPLICATE_LIB_OK=TRUE tells the Intel OpenMP runtime to tolerate being
# loaded more than once (e.g. by torch and an MKL-backed numpy). It is a common workaround
# for the "OMP: Error #15" abort, but it can mask a genuine library conflict. It must be
# set before the libraries are imported, which is why this cell runs first.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [3]:
import gym
import torch as T
import torch.nn as nn
import numpy as np
import pygame
import torch.optim as optim
import cv2
import torch.nn.functional as F
from abc import ABC, abstractmethod
from gym import spaces
from typing import Union
from collections import deque
from pygame.font import Font
from pygame.surface import Surface
from torch.distributions.categorical import Categorical
from tqdm import tqdm

In [4]:
# first let's make sure you have internet enabled
import requests
# Displays True when google.com answers with a non-error HTTP status; a connection
# failure raises an exception rather than displaying False.
requests.get('http://www.google.com',timeout=10).ok

True

#### If you don't have internet access (the cell above did not display `True`)
1. Make sure your account is phone-verified in [account settings](https://www.kaggle.com/settings).
2. Make sure internet access is enabled for this notebook (Settings → Turn on internet).

In [4]:
%%capture
# ensure we are on the latest version of kaggle-environments
# NOTE(review): unpinned upgrade — the installed version (and thus results) can change
# between runs; consider pinning, e.g. `%pip install kaggle-environments==<version>`.
!pip install --upgrade kaggle-environments

In [None]:
# Now let's set up the chess environment!
from kaggle_environments import make
# `env` is the Kaggle chess environment used by the next cell.
# NOTE(review): debug=True presumably surfaces agent errors instead of hiding them —
# confirm against the kaggle_environments documentation.
env = make("chess", debug=True)

In [None]:
# this should run a game in the environment between two random bots
# NOTE: each game starts from a randomly selected opening
# `result` holds whatever env.run returns (not used further in the visible cells).
result = env.run(["random", "random"])
# Render the finished game inline as a 1000x1000 replay.
env.render(mode="ipython", width=1000, height=1000) 

### Creating your first agent
Now let's create your first agent! The environment has the [Chessnut](https://github.com/cgearhart/Chessnut) pip package installed and we'll use that to parse the board state and generate moves.

In [None]:
from Chessnut.moves import MOVES
# Inspection only: prints the 'K' entry of Chessnut's MOVES table
# (presumably the precomputed king move lookup — see the Chessnut source).
print (MOVES['K']) 

In [8]:
class Learning(nn.Module, ABC):
    """Abstract base class for learning agents attached to a gym environment.

    Records the observation size (first dimension of the observation space shape),
    the number of discrete actions, the training hyper-parameters and the torch
    device. Subclasses implement acting, learning, experience storage and saving.
    """

    def __init__(
        self,
        environment: gym.Env,
        epochs: int,
        gamma: float,
        learning_rate: float
    ) -> None:
        super().__init__()
        obs_space, act_space = environment.observation_space, environment.action_space
        self.state_dim = obs_space.shape[0]
        self.action_dim = act_space.n

        self.gamma, self.epochs, self.learning_rate = gamma, epochs, learning_rate

        use_cuda = T.cuda.is_available()
        self.device = T.device("cuda:0" if use_cuda else "cpu")

    @abstractmethod
    def take_action(self, state: np.ndarray, *args):
        """Choose an action for ``state``; extra arguments are subclass-specific."""
        ...

    @abstractmethod
    def learn(self):
        """Run one optimisation phase on the stored experience."""
        ...

    @abstractmethod
    def remember(self, *args):
        """Store experience for later learning."""
        ...

    @abstractmethod
    def save(self, folder: str):
        """Persist the agent into ``folder``."""
        ...

##### types

In [7]:
import numpy as np

# A board square in one player's own frame, as (row, col). Every caller unpacks two
# ints (`row, col = pos`), so the alias must be a 2-tuple, not `tuple[int]` (a 1-tuple).
Cell = tuple[int, int]
# A move: (source cell, destination cell).
Action = tuple[Cell, Cell]
# Presumably (observation, reward, done, info), as in a classic gym step — not used
# by the visible cells.
Trajectory = tuple[np.ndarray, float, bool, dict]

##### pieces

In [10]:
# Piece codes stored in the board planes; EMPTY (0) marks a vacant square.
EMPTY = 0
PAWN = 1
BISHOP = 2
KNIGHT = 3
ROOK = 4
QUEEN = 5
KING = 6

# Color indices, used as the first axis of the (2, 8, 8) board array.
BLACK = 0
WHITE = 1

# Display glyphs indexed as ASCIIS[color][piece - 1] (pawn..king order).
# Row 0 (BLACK) uses the outline chess symbols, row 1 (WHITE) the filled ones.
# NOTE(review): the filled-pawn literal appears to carry a trailing variation selector,
# which is why get_ascii keeps only the first character of each entry.
ASCIIS = (
    ("♙", "♗", "♘", "♖", "♕", "♔"),
    ("♟︎", "♝", "♞", "♜", "♛", "♚"),
)

def get_ascii(color: int, piece: int):
    """Return the single display character for ``piece`` (PAWN..KING) of side ``color``.

    Only the first character of the glyph string is returned. EMPTY is not a valid
    input: ``piece - 1`` would wrap around to the king glyph.
    """
    glyphs_for_color = ASCIIS[color]
    glyph = glyphs_for_color[piece - 1]
    return glyph[0]

##### movements

In [11]:
# Reward magnitudes for movement / check / checkmate events.
# NOTE: these must not be named CHECK_WIN / CHECK_LOSE / CHECK_MATE_WIN /
# CHECK_MATE_LOSE — the "Info keys" cell binds exactly those names to string tags,
# which would silently replace these numbers. They use the same *_reward names and
# values as the "Rewards" cell, so running either cell yields consistent globals.
MOVE = -1

CHECK_WIN_reward = 10
CHECK_LOSE_reward = -CHECK_WIN_reward

CHECK_MATE_WIN_reward = 100
CHECK_MATE_LOSE_reward = -CHECK_MATE_WIN_reward

##### Info keys

In [12]:
# String tags added to the per-player `infos` sets returned by Chess.step
# (Chess.update_checks adds CHECK_WIN / CHECK_LOSE, Chess.update_check_mates adds
# CHECK_MATE_WIN / CHECK_MATE_LOSE).
# NOTE(review): WRONG_MOVE and EMPTY_SELECT are not referenced by any visible cell.
WRONG_MOVE = "wrong_move"
EMPTY_SELECT = "empty_select"
CHECK_WIN = "check_win"
CHECK_LOSE = "check_lose"
CHECK_MATE_WIN = "check_mate_win"
CHECK_MATE_LOSE = "check_mate_lose"

##### Colors

In [13]:
# RGB triples used by the pygame renderer (background, dark squares, glyph color).
BLACK_color = (0, 0, 0)
GRAY_color = (128, 128, 128)
WHITE_color = (255, 255, 255)

##### moves

In [14]:
# Number of action slots reserved per piece category. Summed over one side's 16
# pieces (8*56 + 2*28 + 2*8 + 2*28 + 56 + 8) this is 640, the size of the Chess
# action space. Pawns reserve as many slots as a queen, so a promoted pawn can
# reuse its slots for queen moves.
POSSIBLE_MOVES = {
    "king": 8,
    "knight": 8,
    "rook": 7 * 4,
    "bishop": 7 * 4,
    "queen": 7 * 4 * 2,
    "pawn": 7 * 4 * 2,
}

# All offsets below are (d_row, d_col) in the moving side's own frame. The position of
# an offset inside its tuple is its action index within the piece's slot range
# (move generators use enumerate()), so the ordering must stay stable.

KING_moves = ((-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1))

# Rays (+row, -row, +col, -col), each stepping 1..7 squares.
ROOK_moves = tuple(
    (step * d_row, step * d_col)
    for d_row, d_col in ((1, 0), (-1, 0), (0, 1), (0, -1))
    for step in range(1, 8)
)

# Historical order: the two forward single steps first, then the rest of the +/+ and
# +/- diagonals, then the full -/+ and -/- diagonals.
BISHOP_moves = (
    ((1, 1), (1, -1))
    + tuple((step, step) for step in range(2, 8))
    + tuple((step, -step) for step in range(2, 8))
    + tuple((-step, step) for step in range(1, 8))
    + tuple((-step, -step) for step in range(1, 8))
)

KNIGHT_moves = (
    (2, 1),
    (2, -1),
    (-2, 1),
    (-2, -1),
    (1, 2),
    (1, -2),
    (-1, 2),
    (-1, -2),
)

QUEEN_moves = BISHOP_moves + ROOK_moves

# The first four entries are the real pawn moves (push, double push, two captures) and
# the only ones consulted by Chess.get_actions_for_pawn. The rest pad the tuple to
# POSSIBLE_MOVES["pawn"] (56); the first two rook and first two bishop offsets are
# skipped because they duplicate those four moves. (The padding previously used
# BISHOP_moves[1:], which repeated (1, -1) and gave 57 entries.)
PAWN_moves = ((1, 0), (2, 0), (1, 1), (1, -1)) + ROOK_moves[2:] + BISHOP_moves[2:]

# Indexed by piece code (EMPTY=0 has no moves).
PIECE_MOVE = [
    None,
    PAWN_moves,
    BISHOP_moves,
    KNIGHT_moves,
    ROOK_moves,
    QUEEN_moves,
    KING_moves
]

In [None]:
# Scratch cell: quick look at the piece constants / board layout.
# NOTE(review): the `board` created here is not used by any visible cell; consider
# deleting this cell from the final notebook.
print(QUEEN)
board = np.zeros((2, 8, 8), dtype=np.uint8)
board[:, 0, 3] = QUEEN

##### utils

In [16]:
def build_base_model(
    input_size: int,
    hidden_layers: tuple[int, ...],
    output_size: int,
    last_activation: nn.Module = None,
) -> nn.Module:
    """Build an MLP: Linear+ReLU for every hidden layer, then Linear + ``last_activation``.

    Args:
        input_size: number of input features.
        hidden_layers: widths of the hidden layers; may be empty, giving a single
            Linear(input_size, output_size) followed by ``last_activation``.
        output_size: number of output features.
        last_activation: module appended after the output layer. Defaults to a fresh
            ``nn.Identity()`` per call, instead of one instance shared by every model
            built with the default.

    Returns:
        An ``nn.Sequential`` with the layers in order.
    """
    if last_activation is None:
        last_activation = nn.Identity()

    sizes = [input_size, *hidden_layers]
    layers = []
    # Layers are created in the same order as before, so seeded initialisation is unchanged.
    for in_features, out_features in zip(sizes[:-1], sizes[1:]):
        layers += [nn.Linear(in_features, out_features), nn.ReLU()]

    layers += [nn.Linear(sizes[-1], output_size), last_activation]
    return nn.Sequential(*layers)


def make_batch_ids(n: int, batch_size: int, shuffle: bool = True) -> list[np.ndarray]:
    """Split the indices ``0..n-1`` into consecutive mini-batches.

    Args:
        n: number of samples.
        batch_size: maximum batch length; the last batch may be shorter.
        shuffle: shuffle the indices in place (global NumPy RNG) before splitting.

    Returns:
        A list of int64 index arrays (empty when ``n == 0``). This is a list, not an
        ndarray as the previous annotation claimed.
    """
    indices = np.arange(n, dtype=np.int64)
    if shuffle:
        np.random.shuffle(indices)
    return [indices[start : start + batch_size] for start in range(0, n, batch_size)]


def tensor_to_numpy(x: T.Tensor) -> np.ndarray:
    """Detach ``x`` from autograd, move it to the CPU and return it as a NumPy array."""
    detached = x.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()


def save_to_video(path: str, frames: np.ndarray, fps: int = 2):
    """Write a stack of frames to an MP4 file (``mp4v`` codec) with OpenCV.

    Args:
        path: output file path.
        frames: array indexed as (frame, height, width, channels), e.g. stacked
            ``Chess.render()`` outputs in ``rgb_array`` mode.
        fps: frames per second of the output video.

    NOTE(review): ``VideoWriter.write`` expects BGR channel order, while
    ``Chess.render()`` returns RGB. Because the board palette is grayscale this is
    invisible today; convert with ``cv2.cvtColor`` if colored frames are added.
    """
    height, width = frames.shape[1:3]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # OpenCV takes frameSize as (width, height) — the reverse of NumPy's shape order.
    # Passing shape[1:3] directly produced broken videos for non-square frames.
    out = cv2.VideoWriter(path, fourcc, fps, (width, height))
    try:
        for frame in frames:
            out.write(frame)
    finally:
        # Always finalise the file, even if a write fails.
        out.release()
    

##### Rewards

In [17]:
# Reward shaping used by the Chess environment: every move costs MOVE_reward, giving
# (or receiving) check is worth +/-10 and checkmate +/-100.
MOVE_reward = -1

CHECK_WIN_reward, CHECK_MATE_WIN_reward = 10, 100
CHECK_LOSE_reward, CHECK_MATE_LOSE_reward = -CHECK_WIN_reward, -CHECK_MATE_WIN_reward

##### Fen to board

In [10]:
# Lower-case FEN piece letter -> piece code (same numbering as PAWN..KING).
piece_dic = {
    "p": 1,
    "b": 2,
    "n": 3,
    "r": 4,
    "q": 5,
    "k": 6,
}


def fen_board(fen: str):
    """Convert a FEN string into a (2, 8, 8) float tensor of piece codes.

    Plane 0 belongs to the side to move (FEN active-color field) and plane 1 to its
    opponent. Each plane is expressed in its owner's own frame, with row 0 as that
    side's back rank, as in ``Chess.init_board``. Empty squares are 0.

    Raises:
        ValueError: if the piece-placement field contains an unknown character.
    """
    fields = fen.split(' ')
    placement, turn = fields[0], fields[1]

    white_pos = T.zeros(size=(8, 8))
    black_pos = T.zeros(size=(8, 8))

    # FEN lists ranks from 8 down to 1, so rank line i is already row i in black's
    # frame; white's plane is flipped afterwards so its back rank becomes row 0.
    for i, rank in enumerate(placement.split("/")[:8]):
        col = 0
        for ch in rank:
            if ch.isdigit():
                col += int(ch)  # run of empty squares
                continue
            code = piece_dic.get(ch.lower())
            if code is None:
                raise ValueError(f"unexpected character {ch!r} in FEN placement {placement!r}")
            if ch.isupper():
                white_pos[i][col] = code
            else:
                black_pos[i][col] = code
            col += 1

    # Previously only the 'w' branch flipped white's plane; with black to move, white
    # was left in the wrong frame.
    white_pos = white_pos.flip(dims=[0])
    planes = (white_pos, black_pos) if turn == 'w' else (black_pos, white_pos)
    return T.stack(planes, dim=0)

##### Environment

In [11]:
class Chess(gym.Env):
    """Two-player chess environment with a fixed 640-slot discrete action space.

    State:
      * ``self.board`` is a (2, 8, 8) uint8 array indexed ``board[color, row, col]``
        (BLACK=0, WHITE=1). Each plane is stored in its owner's own frame, with row 0 as
        that side's back rank (see ``init_board``). Square (r, c) in one side's frame
        is (7 - r, c) in the other side's frame.
      * ``self.pieces[color]`` maps piece names ("pawn_1", ..., "king") to their
        current (row, col) in the owner's frame, or None once captured.

    Actions: ``get_all_actions(turn)`` concatenates, in ``self.pieces[turn]`` order,
    ``POSSIBLE_MOVES[category]`` slots per piece; ``step(action)`` indexes into that
    concatenation.
    """

    metadata: dict = {
        # gym's documented key is "render_modes"; "render_mode" is kept so existing
        # readers of this dict keep working.
        "render_modes": ("human", "rgb_array"),
        "render_mode": ("human", "rgb_array"),
    }

    def __init__(
        self,
        max_steps: int = 6000,
        render_mode: str = "human",
        window_size: int = 800,
    ) -> None:
        # 640 = per-piece slot counts from POSSIBLE_MOVES summed over one side's 16 pieces.
        self.action_space = spaces.Discrete(640)
        # Flattened (2, 8, 8) board planes of piece codes.
        self.observation_space = spaces.Box(0, 7, (128,), dtype=np.int32)

        self.board: np.ndarray = self.init_board()
        self.pieces: list[dict] = self.init_pieces()
        self.pieces_names: list[str] = self.get_pieces_names()

        self.turn: int = WHITE
        self.done: bool = False
        self.steps: int = 0
        self.checked: list[bool] = [False, False]
        self.max_steps: int = max_steps

        # pygame objects are created lazily by init_pygame().
        self.font: Font = None
        self.cell_size: int = window_size // 8
        self.screen: Surface = None
        self.window_size: int = window_size
        self.render_mode: str = render_mode

    def init_board(self) -> np.ndarray:
        """Return the starting position; both planes are identical in their own frames."""
        board = np.zeros((2, 8, 8), dtype=np.uint8)
        board[:, 0, 3] = QUEEN
        board[:, 0, 4] = KING
        board[:, 1, :] = PAWN
        board[:, 0, (0, 7)] = ROOK
        board[:, 0, (1, 6)] = KNIGHT
        board[:, 0, (2, 5)] = BISHOP
        return board

    def init_pieces(self):
        """Return [BLACK, WHITE] name -> (row, col) maps for the starting position."""
        pieces = {
            "pawn_1": (1, 0),
            "pawn_2": (1, 1),
            "pawn_3": (1, 2),
            "pawn_4": (1, 3),
            "pawn_5": (1, 4),
            "pawn_6": (1, 5),
            "pawn_7": (1, 6),
            "pawn_8": (1, 7),
            "rook_1": (0, 0),
            "rook_2": (0, 7),
            "knight_1": (0, 1),
            "knight_2": (0, 6),
            "bishop_1": (0, 2),
            "bishop_2": (0, 5),
            "queen": (0, 3),
            "king": (0, 4),
        }

        return [pieces.copy(), pieces.copy()]

    def get_state(self, turn: int) -> np.ndarray:
        """Return the flattened (128,) board with the plane of side ``turn`` first."""
        arr = self.board.copy()
        if turn == WHITE:
            # WHITE's plane is index 1; swap it to the front.
            arr[[0, 1]] = arr[[1, 0]]
        return arr.flatten()

    def draw_cells(self):
        for y in range(8):
            for x in range(8):
                self.draw_cell(x, y)

    def draw_pieces(self):
        for y in range(8):
            for x in range(8):
                self.draw_piece(x, y)

    def render(self) -> Union[None, np.ndarray]:
        """Draw the board; return an (H, W, 3) RGB array in "rgb_array" mode, else None."""
        self.init_pygame()
        self.screen.fill(BLACK_color)
        self.draw_cells()
        self.draw_pieces()

        if self.render_mode == "human":
            pygame.display.flip()
        else:
            # surfarray is indexed (x, y); transpose to (row, col, channel).
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
            )

    def init_pygame(self) -> None:
        """Create the pygame screen and font on first use (no-op afterwards)."""
        if self.screen is not None:
            return
        pygame.init()
        pygame.font.init()
        # NOTE(review): font file path is relative to the working directory.
        self.font = pygame.font.Font("./seguisym.ttf", self.cell_size // 2)
        if self.render_mode == "human":
            pygame.display.init()
            self.screen = pygame.display.set_mode((self.window_size,) * 2)
            pygame.display.set_caption("Chess RL Environment")
        else:
            self.screen = pygame.Surface((self.window_size,) * 2)

    def get_cell_color(self, x: int, y: int) -> tuple[int, int, int]:
        if (x + y) % 2 == 0:
            return GRAY_color
        return BLACK_color

    def get_left_top(self, x: int, y: int, offset: float = 0) -> tuple[int, int]:
        return self.cell_size * x + offset, self.cell_size * y + offset

    def draw_cell(self, x: int, y: int) -> None:
        pygame.draw.rect(
            self.screen,
            self.get_cell_color(x, y),
            pygame.Rect((*self.get_left_top(x, y), self.cell_size, self.cell_size)),
        )

    def draw_piece(self, x: int, y: int) -> None:
        """Draw the pieces found at plane coordinates (row=y, col=x) for both colors.

        BLACK's piece is drawn at screen row y and WHITE's at screen row 7 - y,
        mirroring WHITE's own frame.
        """
        row, col = y, x
        for color in [BLACK, WHITE]:

            if self.is_empty((row, col), color):
                continue

            yy = abs((color * 7) - y)
            text = self.font.render(
                get_ascii(color, self.board[color, row, col]),
                True,
                WHITE_color,
                self.get_cell_color(x, yy),
            )
            rect = text.get_rect()
            rect.center = self.get_left_top(x, yy, offset=self.cell_size // 2)
            self.screen.blit(text, rect)

    def close(self) -> None:
        """Shut pygame down; a later ``render`` call re-initialises it."""
        if self.screen is None:
            return
        pygame.display.quit()
        pygame.quit()
        # Drop stale handles so init_pygame() starts fresh on the next render.
        self.screen = None
        self.font = None

    def reset(self) -> np.ndarray:
        """Restore the starting position and return the observation for WHITE (to move)."""
        self.done = False
        self.turn = WHITE
        self.steps = 0
        self.board = self.init_board()
        self.pieces = self.init_pieces()
        self.checked = [False, False]
        return self.get_state(self.turn)

    def get_pieces_names(self) -> list[str]:
        return list(self.pieces[0].keys())

    def is_in_range(self, pos: Cell) -> bool:
        row, col = pos
        return 0 <= row <= 7 and 0 <= col <= 7

    def get_size(self, name: str):
        """Number of action slots reserved for piece category ``name``."""
        return POSSIBLE_MOVES[name]

    def get_empty_actions(self, name: str):
        """Zeroed (slots, 2) destination array and (slots,) mask for category ``name``."""
        size = self.get_size(name)
        possibles = np.zeros((size, 2), dtype=np.int32)
        actions_mask = np.zeros((size), dtype=np.int32)
        return possibles, actions_mask

    def is_path_empty(self, current_pos: Cell, next_pos: Cell, turn: int) -> bool:
        """True if every square strictly between the two cells is empty on both planes.

        Intended for straight or diagonal lines (sliding pieces and pawn pushes).
        """
        next_row, next_col = next_pos
        current_row, current_col = current_pos

        diff_row = next_row - current_row
        diff_col = next_col - current_col
        sign_row = np.sign(next_row - current_row)
        sign_col = np.sign(next_col - current_col)

        size = max(abs(diff_row), abs(diff_col)) - 1
        # A zero difference means the coordinate stays constant along the path.
        rows = np.zeros(size, dtype=np.int32) + next_row
        cols = np.zeros(size, dtype=np.int32) + next_col

        if diff_row:
            rows = np.arange(current_row + sign_row, next_row, sign_row, dtype=np.int32)

        if diff_col:
            cols = np.arange(current_col + sign_col, next_col, sign_col, dtype=np.int32)

        for pos in zip(rows, cols):
            if not self.both_side_empty(tuple(pos), turn):
                return False

        return True

    def piece_can_jump(self, pos: Cell, turn: int) -> bool:
        """True for pieces whose moves skip the path check (knight, single-step king)."""
        jumps = {KNIGHT, KING}
        piece = self.board[turn, pos[0], pos[1]]
        return piece in jumps

    def general_validation(
        self,
        current_pos: Cell,
        next_pos: Cell,
        turn: int,
        deny_enemy_king: bool,
    ) -> bool:
        """Board-level legality, ignoring whether the move exposes the own king.

        Rejects off-board targets and targets occupied by the mover's own pieces. When
        ``deny_enemy_king`` is False, a target holding the enemy king is also rejected.
        Non-jumping pieces additionally need an empty path.
        """
        if not self.is_in_range(next_pos):
            return False

        if not self.is_empty(next_pos, turn):
            return False

        if self.is_enemy_king(next_pos, turn) and (not deny_enemy_king):
            return False

        if (not self.piece_can_jump(current_pos, turn)) and (
            not self.is_path_empty(current_pos, next_pos, turn)
        ):
            return False

        return True

    def is_valid_move(
        self,
        current_pos: Cell,
        next_pos: Cell,
        turn: int,
        deny_enemy_king: bool,
    ) -> bool:
        """``general_validation`` plus: the move must not leave the own king in check."""
        if not self.general_validation(current_pos, next_pos, turn, deny_enemy_king):
            return False
        if self.is_lead_to_check(current_pos, next_pos, turn):
            return False
        return True

    def is_lead_to_check(self, current_pos: Cell, next_pos: Cell, turn: int) -> bool:
        """Play the move on a throw-away copy and report whether ``turn`` is then in check."""
        temp = Chess(render_mode="rgb_array")
        temp.board = np.copy(self.board)
        temp.move_piece(current_pos, next_pos, turn)
        return temp.is_check(temp.get_pos_king(turn), turn)

    def get_actions_for_bishop(
        self, pos: Cell, turn: int, deny_enemy_king: bool = False
    ):
        possibles, actions_mask = self.get_empty_actions("bishop")
        if pos is None:
            return possibles, actions_mask

        row, col = pos
        for i, (r, c) in enumerate(BISHOP_moves):
            next_pos = (row + r, col + c)

            if not self.is_valid_move(pos, next_pos, turn, deny_enemy_king):
                continue

            possibles[i] = next_pos
            actions_mask[i] = 1

        return possibles, actions_mask

    def get_actions_for_rook(self, pos: Cell, turn: int, deny_enemy_king: bool = False):
        possibles, actions_mask = self.get_empty_actions("rook")
        if pos is None:
            return possibles, actions_mask

        row, col = pos
        for i, (r, c) in enumerate(ROOK_moves):
            next_pos = (row + r, col + c)

            if not self.is_valid_move(pos, next_pos, turn, deny_enemy_king):
                continue

            possibles[i] = next_pos
            actions_mask[i] = 1

        return possibles, actions_mask

    def get_action_for_queen(self, pos: Cell, turn: int, deny_enemy_king: bool = False):
        """Queen slots: the 28 bishop slots followed by the 28 rook slots."""
        possibles_rook, actions_mask_rook = self.get_actions_for_rook(
            pos, turn, deny_enemy_king
        )
        possibles_bishop, actions_mask_bishop = self.get_actions_for_bishop(
            pos, turn, deny_enemy_king
        )
        possibles = np.concatenate([possibles_bishop, possibles_rook])
        actions_mask = np.concatenate([actions_mask_bishop, actions_mask_rook])

        return possibles, actions_mask

    def get_actions_for_pawn(self, pos: Cell, turn: int, deny_enemy_king: bool = False):
        """Pawn actions (push, double push from row 1, diagonal captures).

        A promoted pawn (now a QUEEN on the board) uses queen move generation; both
        layouts have 56 slots.
        """
        possibles, actions_mask = self.get_empty_actions("pawn")
        if pos is None:
            return possibles, actions_mask

        row, col = pos
        if self.board[turn, row, col] == QUEEN:
            # Forward deny_enemy_king (it used to be dropped here).
            return self.get_action_for_queen(pos, turn, deny_enemy_king)

        for i, (r, c) in enumerate(PAWN_moves[:4]):
            next_pos = (row + r, col + c)
            if not self.is_in_range(next_pos):
                continue

            # Cheap pawn-specific rules first; is_valid_move simulates the move on a
            # board copy, so only run it for candidates that can actually be played.
            pawn_rule_ok = (
                (r == 1 and c == 0 and self.both_side_empty(next_pos, turn))
                or (r == 2 and row == 1 and self.both_side_empty(next_pos, turn))
                or (r == 1 and abs(c) == 1 and self.check_for_enemy(next_pos, turn))
                # TODO: EN PASSANT
            )
            if not pawn_rule_ok:
                continue

            if not self.is_valid_move(pos, next_pos, turn, deny_enemy_king):
                continue

            possibles[i] = next_pos
            actions_mask[i] = 1

        return possibles, actions_mask

    def get_actions_for_knight(
        self, pos: Cell, turn: int, deny_enemy_king: bool = False
    ):
        possibles, actions_mask = self.get_empty_actions("knight")

        if pos is None:
            return possibles, actions_mask

        row, col = pos
        for i, (r, c) in enumerate(KNIGHT_moves):
            next_pos = (row + r, col + c)
            if not self.is_valid_move(pos, next_pos, turn, deny_enemy_king):
                continue

            possibles[i] = next_pos
            actions_mask[i] = 1

        return possibles, actions_mask

    def get_actions_for_king(self, pos: Cell, turn: int):
        """King steps; squares adjacent to the enemy king are excluded."""
        row, col = pos
        possibles, actions_mask = self.get_empty_actions("king")

        for i, (r, c) in enumerate(KING_moves):
            next_pos = (row + r, col + c)

            if not self.is_valid_move(pos, next_pos, turn, False):
                continue

            if self.is_neighbor_enemy_king(next_pos, turn):
                continue

            possibles[i] = next_pos
            actions_mask[i] = 1

        return possibles, actions_mask

    def get_source_pos(self, name: str, turn: int):
        """Piece position repeated once per action slot; captured pieces use (0, 0)."""
        cat = name.split("_")[0]
        pos = self.pieces[turn][name]
        if pos is None:
            pos = (0, 0)
        size = self.get_size(cat)
        return np.array([pos] * size)

    def get_actions_for(self, name: str, turn: int, deny_enemy_king: bool = False):
        """Return (source positions, destinations, mask) for the piece called ``name``."""
        assert name in self.pieces_names, f"name not in {self.pieces_names}"
        piece_cat = name.split("_")[0]
        piece_pos = self.pieces[turn][name]
        src_poses = self.get_source_pos(name, turn)

        if piece_cat == "king":
            # The king generator takes no deny_enemy_king argument.
            return (src_poses, *self.get_actions_for_king(piece_pos, turn))

        generators = {
            "pawn": self.get_actions_for_pawn,
            "knight": self.get_actions_for_knight,
            "rook": self.get_actions_for_rook,
            "bishop": self.get_actions_for_bishop,
            "queen": self.get_action_for_queen,
        }
        return (src_poses, *generators[piece_cat](piece_pos, turn, deny_enemy_king))

    def get_all_actions(self, turn: int, deny_enemy_king: bool = False):
        """Concatenate the per-piece action arrays of ``turn`` in ``self.pieces`` order."""
        all_possibles = []
        all_source_pos = []
        all_actions_mask = []
        for name in self.pieces[turn].keys():
            # DENY ENEMY KING == FOR CHECKMATE VALIDATION ONLY SO ....
            if name == "king" and deny_enemy_king:
                continue

            source_pos, possibles, actions_mask = self.get_actions_for(
                name, turn, deny_enemy_king
            )
            all_source_pos.append(source_pos)
            all_possibles.append(possibles)
            all_actions_mask.append(actions_mask)

        return (
            np.concatenate(all_source_pos),
            np.concatenate(all_possibles),
            np.concatenate(all_actions_mask),
        )

    def check_for_enemy(self, pos: Cell, turn: int) -> bool:
        """True if an enemy piece stands on ``pos`` (given in ``turn``'s frame)."""
        r, c = pos
        return not self.is_empty((7 - r, c), 1 - turn)

    def is_empty(self, pos: Cell, turn: int) -> bool:
        """True if plane ``turn`` has no piece at ``pos`` (in that plane's own frame)."""
        return self.board[turn, pos[0], pos[1]] == EMPTY

    def is_enemy_king(self, pos: Cell, turn: int) -> bool:
        r, c = pos
        return self.board[1 - turn, 7 - r, c] == KING

    def both_side_empty(self, pos: Cell, turn: int) -> bool:
        """True if ``pos`` (in ``turn``'s frame) is free of pieces of either color."""
        r, c = pos
        return self.is_empty(pos, turn) and self.is_empty((7 - r, c), 1 - turn)

    def get_pos_king(self, turn: int) -> Cell:
        row, col = np.where(self.board[turn] == KING)
        return row[0], col[0]

    def is_neighbor_enemy_king(self, pos: Cell, turn: int) -> bool:
        """True if ``pos`` (in ``turn``'s frame) touches the enemy king, diagonals included."""
        row, col = pos
        row_enemy_king, col_enemy_king = self.get_pos_king(1 - turn)
        row_enemy_king = 7 - row_enemy_king
        diff_row = abs(row - row_enemy_king)
        diff_col = abs(col - col_enemy_king)
        return diff_row <= 1 and diff_col <= 1

    def is_check(self, king_pos: Cell, turn: int) -> bool:
        """True if square ``king_pos`` (in ``turn``'s frame) is attacked by the enemy.

        Sliding attacks are traced ray by ray. A ray stops at the first occupied
        square of either color, so a blocked rook/bishop/queen does not give check.
        (Previously the diagonal ``break`` only left the two-direction inner loop, and
        enemy non-attackers never blocked a line.) Enemy-king adjacency is not
        considered here; kings use ``is_neighbor_enemy_king``.
        """
        rk, ck = king_pos
        enemy = 1 - turn

        def enemy_piece(r: int, c: int) -> int:
            # Enemy pieces live in the enemy's own frame: mirror the row.
            return self.board[enemy, 7 - r, c]

        orthogonal = (ROOK, QUEEN)
        diagonal = (BISHOP, QUEEN)
        rays = (
            ((1, 0), orthogonal),
            ((-1, 0), orthogonal),
            ((0, 1), orthogonal),
            ((0, -1), orthogonal),
            ((1, 1), diagonal),
            ((1, -1), diagonal),
            ((-1, 1), diagonal),
            ((-1, -1), diagonal),
        )
        for (dr, dc), attackers in rays:
            r, c = rk + dr, ck + dc
            while self.is_in_range((r, c)):
                if not self.is_empty((r, c), turn):
                    break  # own piece blocks the ray
                p = enemy_piece(r, c)
                if p != EMPTY:
                    if p in attackers:
                        return True
                    break  # any other enemy piece blocks the ray
                r, c = r + dr, c + dc

        # Enemy pawns advance toward lower rows of this frame, so an enemy pawn one
        # row above the king on an adjacent file attacks it.
        for c in (ck - 1, ck + 1):
            if self.is_in_range((rk + 1, c)) and enemy_piece(rk + 1, c) == PAWN:
                return True

        # KNIGHTS
        for dr, dc in KNIGHT_moves:
            r, c = rk + dr, ck + dc
            if self.is_in_range((r, c)) and enemy_piece(r, c) == KNIGHT:
                return True

        return False

    def update_checks(self, rewards: list[int] = None, infos: list[set] = None):
        """Add check rewards/info tags; only the first checked side (BLACK first) counts.

        NOTE: the loop stops at the first checked side, so ``self.checked`` is not
        refreshed for the side after it.
        """
        rewards = [0, 0] if rewards is None else rewards
        infos = [set(), set()] if infos is None else infos

        for turn in range(2):
            king_pos = self.get_pos_king(turn)
            is_check = self.is_check(king_pos, turn)
            self.checked[turn] = is_check
            if is_check:
                rewards[turn] += CHECK_LOSE_reward
                rewards[1 - turn] += CHECK_WIN_reward

                infos[turn].add(CHECK_LOSE)
                infos[1 - turn].add(CHECK_WIN)
                break
        return rewards, infos

    def update_check_mates(self, rewards: list[int] = None, infos: list[set] = None):
        """End the game when a side has no legal action and score it as checkmate.

        NOTE(review): a side with no legal action is scored as checkmated even when its
        king is not in check (stalemate is not distinguished).
        """
        rewards = [0, 0] if rewards is None else rewards
        infos = [set(), set()] if infos is None else infos

        for turn in range(2):
            _, _, actions = self.get_all_actions(turn)
            if np.sum(actions) == 0:
                self.done = True
                rewards[turn] += CHECK_MATE_LOSE_reward
                rewards[1 - turn] += CHECK_MATE_WIN_reward

                infos[turn].add(CHECK_MATE_LOSE)
                infos[1 - turn].add(CHECK_MATE_WIN)
                break

        return rewards, infos

    def move_piece(self, current_pos: Cell, next_pos: Cell, turn: int):
        """Move a piece of ``turn``, capturing whatever enemy piece is on the target.

        Returns the base per-move rewards ([BLACK, WHITE] indexed): MOVE_reward for the
        mover and twice that for the opponent, plus empty info sets.
        """
        next_row, next_col = next_pos
        current_row, current_col = current_pos
        self.board[turn, next_row, next_col] = self.board[
            turn, current_row, current_col
        ]

        self.promote_pawn(next_pos, turn)
        self.board[turn, current_row, current_col] = EMPTY
        # Capture: clear the mirrored target square on the opponent's plane.
        self.board[1 - turn, 7 - next_row, next_col] = EMPTY

        for (key, value) in self.pieces[turn].items():
            if value == tuple(current_pos):
                self.pieces[turn][key] = tuple(next_pos)

        for (key, value) in self.pieces[1 - turn].items():
            if value == (7 - next_pos[0], next_pos[1]):
                self.pieces[1 - turn][key] = None

        rewards = [MOVE_reward, MOVE_reward]
        rewards[1 - turn] *= 2

        return rewards, [set(), set()]

    def is_game_done(self):
        return self.done or (self.steps >= self.max_steps)

    def promote_pawn(self, pos: Cell, turn: int):
        """Turn a pawn on the far row (row 7 of its own frame) into a queen."""
        row, col = pos
        if self.board[turn, row, col] == PAWN and row == 7:
            self.board[turn, row, col] = QUEEN

    def step(self, action: int):
        """Apply action slot ``action`` for the side to move.

        Returns (rewards, done, infos); rewards and infos are indexed by color.
        """
        assert not self.is_game_done(), "the game is finished reset"
        assert action < 640, "action number must be less than 640"

        source_pos, possibles, actions_mask = self.get_all_actions(self.turn)
        assert actions_mask[action], f"Cannot Take This Action = {action}"
        rewards, infos = self.move_piece(
            source_pos[action], possibles[action], self.turn
        )
        rewards, infos = self.update_checks(rewards, infos)
        rewards, infos = self.update_check_mates(rewards, infos)

        self.turn = 1 - self.turn
        self.steps += 1
        return rewards, self.is_game_done(), infos

##### Buffer

In [19]:
class Buffer(ABC):
    """Abstract experience buffer.

    Stores the capacity (``max_size``), the mini-batch size and whether sampling
    shuffles. ``len(buffer)`` delegates to ``get_len``.
    """

    def __init__(self, max_size: int, batch_size: int, shuffle: bool = True) -> None:
        super().__init__()
        self.max_size, self.batch_size, self.shuffle = max_size, batch_size, shuffle

    @abstractmethod
    def add(self, *args) -> None:
        """Insert new experience."""
        ...

    @abstractmethod
    def clear(self) -> None:
        """Drop all stored experience."""
        ...

    @abstractmethod
    def get_len(self) -> int:
        """Number of stored items."""
        ...

    @abstractmethod
    def sample(self):
        """Return stored experience for training."""
        ...

    def __len__(self):
        return self.get_len()

In [20]:
class Episode:
    """Step-by-step storage of one episode for on-policy training.

    Parallel lists hold, per step, the state, reward, action and done flag
    (``goals``), plus optional log-prob (``probs``), critic value (``values``) and
    action mask (``masks``), which are appended only when provided.
    """

    def __init__(self) -> None:
        self.goals = []
        self.probs = []
        self.masks = []
        self.values = []
        self.states = []
        self.rewards = []
        self.actions = []

    def add(
        self,
        state: np.ndarray,
        reward: float,
        action,
        goal: bool,
        prob: float = None,
        value: float = None,
        masks: np.ndarray = None,
    ):
        """Append one step; optional fields are skipped when None."""
        self.goals.append(goal)
        self.states.append(state)
        self.rewards.append(reward)
        self.actions.append(action)

        if prob is not None:
            self.probs.append(prob)
        if value is not None:
            self.values.append(value)
        if masks is not None:
            self.masks.append(masks)

    def calc_advantage(self, gamma: float, gae_lambda: float) -> list:
        """Generalised Advantage Estimation over the stored steps.

        With ``delta_k = r_k + gamma * V_{k+1} * (1 - done_k) - V_k``, returns
        ``A_t = sum_{k=t}^{n-2} (gamma * gae_lambda)^(k-t) * delta_k`` for every step.
        The last step has no successor value, so its advantage is 0. The previous code
        applied the discount to ``r_k + gamma * V_{k+1}`` but not to ``-V_k``, which
        gave wrong advantages; this uses an O(n) backward recursion instead of O(n^2).

        Requires ``values`` to be recorded for every step.

        Returns:
            A list of length ``len(self.rewards)``.
        """
        n = len(self.rewards)
        advantages = np.zeros(n)
        decay = gamma * gae_lambda
        running = 0.0
        for t in range(n - 2, -1, -1):
            delta = (
                self.rewards[t]
                + gamma * self.values[t + 1] * (1 - int(self.goals[t]))
                - self.values[t]
            )
            running = delta + decay * running
            advantages[t] = running
        return list(advantages)

    def __len__(self):
        return len(self.goals)

    def total_reward(self) -> float:
        return sum(self.rewards)

In [21]:
class BufferPPO(Buffer):
    """Rollout buffer for PPO holding up to ``max_size`` episodes.

    Advantages are computed once per episode when it is added (GAE with the
    buffer's ``gamma`` / ``gae_lambda``). The oldest episodes drop out when the
    deques are full.
    """

    def __init__(
        self,
        max_size: int,
        batch_size: int,
        gamma: float,
        gae_lambda: float,
        shuffle: bool = True,
    ) -> None:
        super().__init__(max_size, batch_size, shuffle)
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.episodes = deque(maxlen=max_size)
        self.advantages = deque(maxlen=max_size)

    def add(self, episode: Episode):
        self.episodes.append(episode)
        self.advantages.append(episode.calc_advantage(self.gamma, self.gae_lambda))

    def clear(self) -> None:
        self.episodes.clear()
        self.advantages.clear()

    def get_len(self) -> int:
        return len(self.episodes)

    @staticmethod
    def _concat(lists) -> list:
        """Concatenate an iterable of lists in linear time (``sum(x, [])`` is quadratic)."""
        from itertools import chain

        return list(chain.from_iterable(lists))

    def sample(self):
        """Flatten all stored episodes into step-aligned arrays plus mini-batch index lists.

        Returns:
            (states, actions, rewards, goals, probs, values, masks, advantages, batches)
            where ``batches`` comes from ``make_batch_ids`` over the number of states.
        """
        probs = self._concat(ep.probs for ep in self.episodes)
        goals = self._concat(ep.goals for ep in self.episodes)
        masks = self._concat(ep.masks for ep in self.episodes)
        values = self._concat(ep.values for ep in self.episodes)
        states = self._concat(ep.states for ep in self.episodes)
        actions = self._concat(ep.actions for ep in self.episodes)
        rewards = self._concat(ep.rewards for ep in self.episodes)
        advantages = self._concat(self.advantages)

        batches = make_batch_ids(
            n=len(states), batch_size=self.batch_size, shuffle=self.shuffle
        )

        return (
            np.array(states),
            np.array(actions),
            np.array(rewards),
            np.array(goals),
            np.array(probs),
            np.array(values),
            np.array(masks),
            np.array(advantages),
            batches,
        )

##### Agent

###### Actor

In [22]:
class Actor(nn.Module):
    """Policy network returning a masked categorical distribution over actions.

    The base MLP ends in a softmax over ``action_dim`` outputs. In ``forward``, the
    probability mass of illegal actions (``action_mask == 0``) is spread evenly over
    the legal ones, and the illegal entries are then zeroed.
    """

    def __init__(
        self, state_dim: int, action_dim: int, hidden_layers: tuple[int]
    ) -> None:
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_layers = hidden_layers
        self.base_model = build_base_model(
            state_dim, hidden_layers, action_dim, nn.Softmax(dim=1)
        )

    def forward(self, states: T.Tensor, action_mask: T.Tensor):
        probs = self.base_model(states)
        n_legal = action_mask.sum(dim=1)
        illegal_mass = (probs * (1 - action_mask)).sum(dim=1)
        # Per-row share of the illegal mass handed to each legal action.
        share = (illegal_mass / n_legal).unsqueeze(1)
        return Categorical((probs + share) * action_mask)

###### Critic

In [23]:
class Critic(nn.Module):
    """State-value network: one output unit per state (the value estimate)."""

    def __init__(self, state_dim: int, hidden_layers: tuple[int]) -> None:
        super().__init__()
        self.state_dim = state_dim
        self.hidden_layers = hidden_layers
        # Output width 1: a single scalar value per input row.
        self.model = build_base_model(state_dim, hidden_layers, 1)

    def forward(self, state: T.Tensor):
        value = self.model(state)
        return value

###### PPO Agent

In [24]:
class PPO(Learning):
    """Proximal Policy Optimization learner with separate actor and critic networks.

    ``remember`` stores finished episodes in a ``BufferPPO``. ``learn`` runs
    ``epochs`` passes of clipped-surrogate updates over the buffer and then
    empties it.
    """

    def __init__(
        self,
        environment: gym.Env,
        hidden_layers: tuple[int],
        epochs: int,
        buffer_size: int,
        batch_size: int,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        policy_clip: float = 0.2,
        learning_rate: float = 0.003,
    ) -> None:
        super().__init__(environment, epochs, gamma, learning_rate)

        self.gae_lambda = gae_lambda
        self.policy_clip = policy_clip
        self.buffer = BufferPPO(
            gamma=gamma,
            max_size=buffer_size,
            batch_size=batch_size,
            gae_lambda=gae_lambda,
        )

        self.hidden_layers = hidden_layers
        self.actor = Actor(self.state_dim, self.action_dim, hidden_layers)
        self.critic = Critic(self.state_dim, hidden_layers)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=learning_rate)

        # self.device is presumably set by the Learning base class -- not visible here.
        self.to(self.device)

    def take_action(self, state: np.ndarray, action_mask: np.ndarray):
        """Sample one action for a single (unbatched) state.

        Returns:
            ``(action, log_prob, value)`` as Python scalars: the sampled action
            index, its log-probability under the current policy, and the
            critic's value estimate for ``state``.
        """
        # Rollout only: nothing here is back-propagated, so skip the autograd graph.
        with T.no_grad():
            state = T.Tensor(state).unsqueeze(0).to(self.device)
            action_mask = T.Tensor(action_mask).unsqueeze(0).to(self.device)
            dist = self.actor(state, action_mask)
            action = dist.sample()
            probs = T.squeeze(dist.log_prob(action)).item()
            value = T.squeeze(self.critic(state)).item()
        action = T.squeeze(action).item()
        return action, probs, value

    def epoch(self):
        """One optimisation pass over every minibatch currently in the buffer.

        Actor loss: clipped-surrogate objective on the ratio
        ``exp(new_log_prob - old_log_prob)``. Critic loss: squared error against
        the return target ``advantages + old values``. Both networks step on
        ``actor_loss + 0.5 * critic_loss``.
        """
        (
            states_arr,
            actions_arr,
            _rewards_arr,
            _goals_arr,
            old_probs_arr,
            values_arr,
            masks_arr,
            advantages_arr,
            batches,
        ) = self.buffer.sample()

        for batch in batches:
            masks = T.Tensor(masks_arr[batch]).to(self.device)
            values = T.Tensor(values_arr[batch]).to(self.device)
            states = T.Tensor(states_arr[batch]).to(self.device)
            actions = T.Tensor(actions_arr[batch]).to(self.device)
            old_probs = T.Tensor(old_probs_arr[batch]).to(self.device)
            advantages = T.Tensor(advantages_arr[batch]).to(self.device)

            dist = self.actor(states, masks)
            critic_value = T.squeeze(self.critic(states))

            new_probs = dist.log_prob(actions)
            prob_ratio = (new_probs - old_probs).exp()

            weighted_probs = advantages * prob_ratio
            weighted_clipped_probs = (
                T.clamp(prob_ratio, 1 - self.policy_clip, 1 + self.policy_clip)
                * advantages
            )

            actor_loss = -T.min(weighted_probs, weighted_clipped_probs).mean()
            critic_loss = ((advantages + values - critic_value) ** 2).mean()
            total_loss = actor_loss + 0.5 * critic_loss

            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            total_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()

    def learn(self):
        """Run ``self.epochs`` update passes, then discard the collected episodes."""
        for _ in tqdm(range(self.epochs), desc="PPO Learning...", ncols=64, leave=False):
            self.epoch()
        self.buffer.clear()

    def remember(self, episode: Episode):
        """Queue a finished episode for the next ``learn`` call."""
        self.buffer.add(episode)

    def save(self, folder: str, name: str):
        """Pickle the whole learner to ``<folder>/<name>.pt``, creating ``folder``.

        ``.pt`` is always appended, so ``name="x.pt"`` produces ``x.pt.pt``.
        """
        os.makedirs(folder, exist_ok=True)
        T.save(self, os.path.join(folder, f"{name}.pt"))
        

##### Base Agent

In [25]:
class BaseAgent(ABC):
    """Self-play training loop shared by concrete chess agents.

    Subclasses decide how finished episodes reach their learners
    (``add_episodes``), how learning is triggered (``learn``) and how learners
    are persisted (``save_learners``). Per-episode statistics live in
    ``(2, episodes)`` arrays whose first index is the side (``turn``).
    """

    def __init__(
        self,
        env: Chess,
        learner: Learning,
        episodes: int,
        train_on: int,
        result_folder: str,
    ) -> None:
        super().__init__()
        self.env = env
        self.learner = learner
        self.episodes = episodes
        # learn() (and optionally save()) runs after every `train_on` episodes.
        self.train_on = train_on
        self.current_ep = 0
        self.result_folder = result_folder

        self.moves = np.zeros((2, episodes), dtype=np.uint32)
        self.rewards = np.zeros((2, episodes))
        self.mates_win = np.zeros((2, episodes), dtype=np.uint32)
        self.checks_win = np.zeros((2, episodes), dtype=np.uint32)
        self.mates_lose = np.zeros((2, episodes), dtype=np.uint32)
        self.checks_lose = np.zeros((2, episodes), dtype=np.uint32)

    def update_stats(self, infos: list[dict]):
        """Count check / checkmate events in ``infos`` (indexed by side) for the current episode."""
        for turn, info in enumerate(infos):
            if CHECK_MATE_WIN in info:
                self.mates_win[turn, self.current_ep] += 1

            if CHECK_MATE_LOSE in info:
                self.mates_lose[turn, self.current_ep] += 1

            if CHECK_WIN in info:
                self.checks_win[turn, self.current_ep] += 1

            if CHECK_LOSE in info:
                self.checks_lose[turn, self.current_ep] += 1

    def take_action(self, turn: int, episode: Episode):
        """Let side ``turn`` pick and play one move, recording it in ``episode``.

        Returns:
            ``(done, transition)`` where ``transition`` is
            ``[state, rewards, action, goal, prob, value, mask]``. Note that
            ``rewards`` is the env's full per-side reward container, not just
            this side's entry.
        """
        mask = self.env.get_all_actions(turn)[-1]
        state = self.env.get_state(turn)

        action, prob, value = self.learner.take_action(state, mask)
        rewards, done, infos = self.env.step(action)
        self.moves[turn, self.current_ep] += 1

        self.update_stats(infos)
        goal = CHECK_MATE_WIN in infos[turn]
        episode.add(state, rewards[turn], action, goal, prob, value, mask)

        return done, [state, rewards, action, goal, prob, value, mask]

    def update_enemy(self, prev: list, episode: Episode, reward: int):
        """Re-add the opponent's previous transition with the reward caused by this reply.

        NOTE(review): ``take_action`` already added that transition with its
        own immediate reward, so each move appears twice in the episode --
        confirm this duplication is intended.
        """
        if prev is None:
            return
        prev[1] = reward
        episode.add(*prev)

    def train_episode(self, render: bool):
        """Play one full self-play game and hand both sides' episodes to the learner(s)."""
        renders = []

        def render_fn():
            # Frames are only collected for non-"human" render modes.
            if self.env.render_mode != "human":
                renders.append(self.env.render())

        self.env.reset()
        episode_white = Episode()
        episode_black = Episode()
        white_data: list = None
        black_data: list = None
        render_fn()

        while True:
            done, white_data = self.take_action(WHITE, episode_white)
            self.update_enemy(black_data, episode_black, white_data[1][BLACK])
            render_fn()
            if done:
                break

            done, black_data = self.take_action(BLACK, episode_black)
            self.update_enemy(white_data, episode_white, black_data[1][WHITE])
            render_fn()
            if done:
                break

        self.add_episodes(episode_white, episode_black)
        self.rewards[BLACK, self.current_ep] = episode_black.total_reward()
        self.rewards[WHITE, self.current_ep] = episode_white.total_reward()

        if (render or self.env.done) and self.env.render_mode != "human":
            path = os.path.join(self.result_folder, "renders", f"episode_{self.current_ep}.mp4")
            # Make sure the renders sub-folder exists before writing the video.
            os.makedirs(os.path.dirname(path), exist_ok=True)
            save_to_video(path, np.array(renders))

    def log(self, episode: int):
        """Print the recorded statistics of ``episode`` (row 0 first, then row 1)."""
        print(
            f"+ Episode {episode} Results [B | w]:",
            f"\t- Moves  = {self.moves[:, episode]}",
            f"\t- Reward = {self.rewards[:, episode]}",
            f"\t- Checks = {self.checks_win[:, episode]}",
            f"\t- Mates  = {self.mates_win[:, episode]}",
            "-" * 64,
            sep="\n",
        )

    def tqdm_postfix(self, episode: int):
        """Statistics of ``episode`` formatted for ``tqdm.set_postfix``."""
        return {
            "episode": episode,
            "moves": self.moves[:, episode],
            "rewards": self.rewards[:, episode],
            "checks": self.checks_win[:, episode],
            "mates": self.mates_win[:, episode]
        }

    def train(self, render_each: int, save_on_learn: bool = True):
        """Play ``self.episodes`` games, learning after every ``train_on`` of them.

        Episode ``ep`` is saved as video when ``ep % render_each == 0`` and for
        the final episode.
        """
        for ep in (pbar := tqdm(range(self.episodes))):
            self.train_episode(ep % render_each == 0 or ep == self.episodes - 1)
            self.current_ep += 1
            pbar.set_postfix(self.tqdm_postfix(ep))
            if (ep + 1) % self.train_on == 0:
                self.learn()
                if save_on_learn:
                    self.save()

    def save(self):
        """Write every statistics array to ``result_folder`` (created if missing), then the learners."""
        folder = self.result_folder
        # np.save raises FileNotFoundError when the target folder does not exist.
        os.makedirs(folder, exist_ok=True)
        stats = {
            "moves.npy": self.moves,
            "rewards.npy": self.rewards,
            "mates_win.npy": self.mates_win,
            "mates_lose.npy": self.mates_lose,
            "checks_win.npy": self.checks_win,
            "checks_lose.npy": self.checks_lose,
        }
        for filename, values in stats.items():
            np.save(os.path.join(folder, filename), values)
        self.save_learners()

    @abstractmethod
    def save_learners(self):
        pass

    @abstractmethod
    def learn(self):
        pass

    @abstractmethod
    def add_episodes(self, white: Episode, black: Episode) -> None:
        pass

##### Single Agent 

In [26]:

class SingleAgentChess(BaseAgent):
    """Self-play trainer in which one shared learner plays both colours."""

    def __init__(
        self,
        env: Chess,
        learner: Learning,
        episodes: int,
        train_on: int,
        result_folder: str,
    ) -> None:
        super().__init__(env, learner, episodes, train_on, result_folder)

    def add_episodes(self, white: Episode, black: Episode) -> None:
        # The single learner sees both sides' experience: white first, then black.
        for side_episode in (white, black):
            self.learner.remember(side_episode)

    def learn(self):
        self.learner.learn()

    def save_learners(self):
        # With the PPO learner, save() appends ".pt" itself, so this writes
        # "single_agent_ppo.pt.pt"; the evaluation cell loads exactly that path.
        self.learner.save(self.result_folder, "single_agent_ppo.pt")

##### Train

In [27]:
# Games played between learning steps (`train_on`). Every game yields two
# Episodes (white and black), so the PPO buffer below is sized at twice this.
buffer_size: int = 32

In [None]:
# Self-play training: one PPO learner plays (and learns from) both colours.
RESULT_FOLDER = "./results/SingleAgent"
# np.save / T.save raise if the target folder is missing, so create it (and the
# renders sub-folder used for episode videos) before training starts.
os.makedirs(os.path.join(RESULT_FOLDER, "renders"), exist_ok=True)

chess = Chess(window_size=512, max_steps=128, render_mode="rgb_array")
chess.reset()
ppo = PPO(
    chess,
    hidden_layers=(2048,) * 4,
    epochs=100,
    buffer_size=buffer_size * 2,  # two Episodes (white + black) per game
    batch_size=128,
)

print(ppo.device)
print(ppo)
print("-" * 64)

agent = SingleAgentChess(
    env=chess,
    learner=ppo,
    episodes=2000,
    train_on=buffer_size,
    result_folder=RESULT_FOLDER,
)
agent.train(render_each=20, save_on_learn=True)
agent.save()
chess.close()

In [12]:
def init_pieces():
    """Return the starting square of every piece, one dict per side.

    Squares are ``(row, col)`` in that side's own frame (row 0 is its back
    rank), so both sides share the same layout. Key order matters: callers
    iterate it to lay out the flat action vector.
    """
    layout = {f"pawn_{col + 1}": (1, col) for col in range(8)}
    layout.update(
        rook_1=(0, 0),
        rook_2=(0, 7),
        knight_1=(0, 1),
        knight_2=(0, 6),
        bishop_1=(0, 2),
        bishop_2=(0, 5),
        queen=(0, 3),
        king=(0, 4),
    )
    return [layout.copy(), layout.copy()]


pieces = init_pieces()

def get_pieces_names() -> list[str]:
    """Return the piece identifiers (e.g. ``"pawn_1"``, ``"king"``) in dict order.

    The result is a list, not a set: order matters to callers. Side 0's names
    are used; ``init_pieces`` gives both sides identical keys.
    """
    return list(pieces[0].keys())


pieces_names = get_pieces_names()

def init_board() -> np.ndarray:
    """Build the ``(2, 8, 8)`` starting board, one 8x8 plane per side.

    Each plane is in that side's own frame: row 0 is the back rank, row 1 the
    pawns, rows 2-7 empty (zero).
    """
    start = np.zeros((2, 8, 8), dtype=np.uint8)
    back_rank = (ROOK, KNIGHT, BISHOP, QUEEN, KING, BISHOP, KNIGHT, ROOK)
    start[:, 0, :] = back_rank
    start[:, 1, :] = PAWN
    return start


board = init_board()

def get_size(name: str):
    """Number of action slots reserved for piece category ``name`` (e.g. ``"rook"``)."""
    slots = POSSIBLE_MOVES[name]
    return slots

def get_source_pos(name: str, turn: int):
    """Repeat piece ``name``'s square once per action slot of its category.

    A piece with no position (``None``) is reported at ``(0, 0)``.
    """
    category = name.split("_")[0]
    origin = pieces[turn][name]
    origin = (0, 0) if origin is None else origin
    return np.array([origin] * get_size(category))


def get_empty_actions(name: str):
    """Allocate zeroed ``(destinations, mask)`` arrays for piece category ``name``.

    ``destinations`` has shape ``(slots, 2)`` and ``mask`` shape ``(slots,)``,
    both ``int32``.
    """
    n_slots = get_size(name)
    destinations = np.zeros((n_slots, 2), dtype=np.int32)
    mask = np.zeros(n_slots, dtype=np.int32)
    return destinations, mask

def is_in_range(pos: Cell) -> bool:
    """True when ``pos`` lies on the 8x8 board."""
    row, col = pos
    return 0 <= row <= 7 and 0 <= col <= 7

def is_empty(pos: Cell, turn: int) -> bool:
    """True if side ``turn`` has no piece on ``pos`` (given in that side's frame)."""
    r, c = pos[0], pos[1]
    return board[turn, r, c] == EMPTY

def is_enemy_king(pos: Cell, turn: int) -> bool:
    """True if the opponent's king stands on ``pos`` (given in ``turn``'s frame).

    The opponent's plane is stored in its own frame, so the row is mirrored
    (``7 - row``); the column is not.
    """
    row, col = pos
    opponent = 1 - turn
    return board[opponent, 7 - row, col] == KING

def piece_can_jump(pos: Cell, turn: int) -> bool:
    """True if side ``turn``'s piece on ``pos`` skips the path check (knight or king)."""
    jumpers = (KNIGHT, KING)
    return board[turn, pos[0], pos[1]] in jumpers

def both_side_empty(pos: Cell, turn: int) -> bool:
    """True if neither side occupies ``pos`` (``pos`` is in ``turn``'s frame)."""
    row, col = pos
    mirrored = (7 - row, col)
    return is_empty(pos, turn) and is_empty(mirrored, 1 - turn)

def is_path_empty(current_pos: Cell, next_pos: Cell, turn: int) -> bool:
    """True if every square strictly between the two cells is free on both sides.

    Intended for straight or diagonal moves; both cells are in ``turn``'s
    frame and neither endpoint is checked.
    """
    start_row, start_col = current_pos
    end_row, end_col = next_pos

    d_row = end_row - start_row
    d_col = end_col - start_col
    between = max(abs(d_row), abs(d_col)) - 1

    # Moving along an axis: the stationary coordinate stays fixed at the target's value.
    if d_row:
        step = np.sign(d_row)
        rows = np.arange(start_row + step, end_row, step, dtype=np.int32)
    else:
        rows = np.zeros(between, dtype=np.int32) + end_row

    if d_col:
        step = np.sign(d_col)
        cols = np.arange(start_col + step, end_col, step, dtype=np.int32)
    else:
        cols = np.zeros(between, dtype=np.int32) + end_col

    return all(both_side_empty(tuple(cell), turn) for cell in zip(rows, cols))

def general_validation(
    current_pos: Cell,
    next_pos: Cell,
    turn: int,
    deny_enemy_king: bool,
) -> bool:
    """Basic legality of moving side ``turn``'s piece from ``current_pos`` to ``next_pos``.

    Does not test for self-check. The move must satisfy all of:

    - the target is on the board;
    - the target is not occupied by one of ``turn``'s own pieces;
    - the target is not the enemy king's square, unless ``deny_enemy_king``
      is True;
    - for non-jumping pieces, no piece of either side stands in between.
    """
    return bool(
        is_in_range(next_pos)
        and is_empty(next_pos, turn)
        and not (is_enemy_king(next_pos, turn) and not deny_enemy_king)
        and (
            piece_can_jump(current_pos, turn)
            or is_path_empty(current_pos, next_pos, turn)
        )
    )

def is_lead_to_check(current_pos: Cell, next_pos: Cell, turn: int) -> bool:
    """True if moving ``current_pos`` -> ``next_pos`` would leave side ``turn`` in check.

    The move is simulated on a throw-away ``Chess`` game whose board is a copy
    of the module-level ``board``, so the real board is not modified.

    NOTE(review): only ``board`` is copied into the scratch game. If
    ``Chess.move_piece`` / ``get_pos_king`` also read that game's ``pieces``
    state, it comes from a fresh game, not the current position -- confirm.
    NOTE(review): a new ``Chess`` is built on every call, once per candidate
    move, which is expensive.
    """
    temp = Chess(render_mode="rgb_array")
    temp.board = np.copy(board)
    temp.move_piece(current_pos, next_pos, turn)
    return temp.is_check(temp.get_pos_king(turn), turn)

def is_valid_move(
    current_pos: Cell,
    next_pos: Cell,
    turn: int,
    deny_enemy_king: bool,
) -> bool:
    """Full legality: ``general_validation`` passes and the move does not self-check."""
    return bool(
        general_validation(current_pos, next_pos, turn, deny_enemy_king)
        and not is_lead_to_check(current_pos, next_pos, turn)
    )

def check_for_enemy(pos: Cell, turn: int) -> bool:
    """True if the opponent has a piece on ``pos`` (given in ``turn``'s frame)."""
    row, col = pos
    enemy_view = (7 - row, col)
    return not is_empty(enemy_view, 1 - turn)

def get_pos_king(turn: int) -> Cell:
    """Locate side ``turn``'s king in that side's own frame.

    Returns the first match in row-major order; raises IndexError if the
    plane holds no king.
    """
    is_king = board[turn] == KING
    king_rows, king_cols = np.where(is_king)
    return king_rows[0], king_cols[0]

def is_neighbor_enemy_king(pos: Cell, turn: int) -> bool:
    """True if ``pos`` (``turn``'s frame) is within one square of the enemy king, diagonals included."""
    row, col = pos
    enemy_row, enemy_col = get_pos_king(1 - turn)
    # Bring the enemy king's row into this side's frame.
    enemy_row = 7 - enemy_row
    return abs(row - enemy_row) <= 1 and abs(col - enemy_col) <= 1


def get_actions_for_bishop(
    pos: Cell, turn: int, deny_enemy_king: bool = False
):
    """Legal destinations for a bishop of side ``turn`` on ``pos``.

    Returns ``(destinations, mask)`` with one slot per ``BISHOP_moves`` entry.
    A slot holds the target square and mask 1 when the move passes
    ``is_valid_move``; otherwise it stays zero. ``pos is None`` yields all-zero arrays.
    """
    destinations, mask = get_empty_actions("bishop")
    if pos is not None:
        row, col = pos
        for slot, (d_row, d_col) in enumerate(BISHOP_moves):
            target = (row + d_row, col + d_col)
            if is_valid_move(pos, target, turn, deny_enemy_king):
                destinations[slot] = target
                mask[slot] = 1
    return destinations, mask

def get_actions_for_rook(pos: Cell, turn: int, deny_enemy_king: bool = False):
    """Legal destinations for a rook of side ``turn`` on ``pos``.

    Returns ``(destinations, mask)`` with one slot per ``ROOK_moves`` entry.
    A slot holds the target square and mask 1 when the move passes
    ``is_valid_move``; otherwise it stays zero. ``pos is None`` yields all-zero arrays.
    """
    destinations, mask = get_empty_actions("rook")
    if pos is not None:
        row, col = pos
        for slot, (d_row, d_col) in enumerate(ROOK_moves):
            target = (row + d_row, col + d_col)
            if is_valid_move(pos, target, turn, deny_enemy_king):
                destinations[slot] = target
                mask[slot] = 1
    return destinations, mask

def get_action_for_queen(pos: Cell, turn: int, deny_enemy_king: bool = False):
    """Queen actions: the bishop slots followed by the rook slots for ``pos``."""
    rook_dst, rook_mask = get_actions_for_rook(pos, turn, deny_enemy_king)
    bishop_dst, bishop_mask = get_actions_for_bishop(pos, turn, deny_enemy_king)
    return (
        np.concatenate([bishop_dst, rook_dst]),
        np.concatenate([bishop_mask, rook_mask]),
    )

def get_actions_for_pawn(pos: Cell, turn: int, deny_enemy_king: bool = False):
    """Legal destinations for a pawn of side ``turn`` standing on ``pos``.

    Only the first four ``PAWN_moves`` entries are tried. A move is kept when
    it passes ``is_valid_move`` and is one of:

    - a single step onto an empty square;
    - a double step from row 1 onto an empty square (the square in between is
      covered by ``is_valid_move``'s path check);
    - a diagonal step onto an enemy piece.

    If the board square holds ``QUEEN`` (a promoted pawn), the queen
    generator is used instead. Returns ``(destinations, mask)``; ``pos is None``
    yields all-zero arrays.
    """
    possibles, actions_mask = get_empty_actions("pawn")
    if pos is None:
        return possibles, actions_mask

    row, col = pos
    if board[turn, row, col] == QUEEN:
        # Forward deny_enemy_king so checkmate validation treats a promoted
        # pawn exactly like a real queen.
        return get_action_for_queen(pos, turn, deny_enemy_king)

    for i, (r, c) in enumerate(PAWN_moves[:4]):
        next_pos = (row + r, col + c)

        if not is_valid_move(pos, next_pos, turn, deny_enemy_king):
            continue

        can_moves = (
            # single step forward onto an empty square
            (r == 1 and c == 0 and both_side_empty(next_pos, turn)),
            # double step from the starting row onto an empty square
            (r == 2 and row == 1 and both_side_empty(next_pos, turn)),
            # diagonal capture of an enemy piece
            (r == 1 and abs(c) == 1 and check_for_enemy(next_pos, turn)),
            # TODO: EN PASSANT
        )

        if any(can_moves):
            possibles[i] = next_pos
            actions_mask[i] = 1

    return possibles, actions_mask

def get_actions_for_knight(
    pos: Cell, turn: int, deny_enemy_king: bool = False
):
    """Legal destinations for a knight of side ``turn`` on ``pos``.

    Returns ``(destinations, mask)`` with one slot per ``KNIGHT_moves`` entry.
    A slot holds the target square and mask 1 when the move passes
    ``is_valid_move``; otherwise it stays zero. ``pos is None`` yields all-zero arrays.
    """
    destinations, mask = get_empty_actions("knight")
    if pos is not None:
        row, col = pos
        for slot, (d_row, d_col) in enumerate(KNIGHT_moves):
            target = (row + d_row, col + d_col)
            if is_valid_move(pos, target, turn, deny_enemy_king):
                destinations[slot] = target
                mask[slot] = 1
    return destinations, mask

def get_actions_for_king(pos: Cell, turn: int):
    """Legal destinations for side ``turn``'s king on ``pos``.

    A move is kept when it passes ``is_valid_move`` (landing on the enemy
    king's square is always rejected) and the target is not adjacent to the
    enemy king. Unlike the other generators there is no ``pos is None`` guard
    and no ``deny_enemy_king`` parameter. Returns ``(destinations, mask)``.
    """
    row, col = pos
    possibles, actions_mask = get_empty_actions("king")

    for i, (r, c) in enumerate(KING_moves):
        next_pos = (row + r, col + c)

        if not is_valid_move(pos, next_pos, turn, False):
            continue

        if is_neighbor_enemy_king(next_pos, turn):
            continue

        possibles[i] = next_pos
        actions_mask[i] = 1

    return possibles, actions_mask

def get_actions_for(name: str, turn: int, deny_enemy_king: bool = False):
    """Return ``(sources, destinations, mask)`` for piece ``name`` of side ``turn``.

    ``sources`` repeats the piece's square once per action slot of its
    category. The king's generator ignores ``deny_enemy_king``.

    NOTE(review): a promoted pawn returns queen-sized arrays, so the rows only
    line up if POSSIBLE_MOVES["pawn"] equals the queen's slot count -- confirm.
    """
    assert name in pieces_names, f"name not in {pieces_names}"
    category = name.split("_")[0]
    square = pieces[turn][name]
    sources = get_source_pos(name, turn)

    generators = {
        "pawn": get_actions_for_pawn,
        "knight": get_actions_for_knight,
        "rook": get_actions_for_rook,
        "bishop": get_actions_for_bishop,
        "queen": get_action_for_queen,
    }
    if category in generators:
        return (sources, *generators[category](square, turn, deny_enemy_king))
    if category == "king":
        return (sources, *get_actions_for_king(square, turn))
    return None

def get_all_actions(turn: int, deny_enemy_king: bool = False):
    """Concatenate every piece's ``(sources, destinations, mask)`` for side ``turn``.

    Pieces are visited in ``pieces[turn]`` key order, which fixes the layout
    of the flat action vector.
    """
    sources, destinations, masks = [], [], []
    for name in pieces[turn]:
        # deny_enemy_king is only used for checkmate validation, where the
        # king's own moves are left out.
        if deny_enemy_king and name == "king":
            continue

        src, dst, mask = get_actions_for(name, turn, deny_enemy_king)
        sources.append(src)
        destinations.append(dst)
        masks.append(mask)

    return tuple(np.concatenate(part) for part in (sources, destinations, masks))

NameError: name 'QUEEN' is not defined

In [None]:
%%writefile main.py
from Chessnut import Game
import random

def chess_bot(obs):
    """
    Simple chess bot that prioritizes checkmates, then captures, queen promotions, then randomly moves.

    Only the first 10 legal moves are tried for an immediate checkmate
    (presumably a cap that keeps the per-turn cost small).

    Args:
        obs: An object with a 'board' attribute representing the current board state as a FEN string.

    Returns:
        A string representing the chosen move in UCI notation (e.g., "e2e4")
    """
    # 0. Parse the current board state and generate legal moves using Chessnut library
    game = Game(obs.board)
    moves = list(game.get_moves())

    # 1. Check (a bounded number of) moves for an immediate checkmate
    for move in moves[:10]:
        g = Game(obs.board)
        g.apply_move(move)
        if g.status == Game.CHECKMATE:
            return move

    # 2. Check for captures: the destination square (move[2:4]) is occupied
    for move in moves:
        if game.board.get_piece(Game.xy2i(move[2:4])) != ' ':
            return move

    # 3. Check for queen promotions
    for move in moves:
        if "q" in move.lower():
            return move

    # 4. Random move if no checkmates or captures
    return random.choice(moves)
    

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

A = 0.25  # default transparency for raw (un-smoothed) curves
B = "tab:blue"  # colour for row 0 of the stats arrays, labelled "Black"
W = "tab:orange"  # colour for row 1 of the stats arrays, labelled "White"


def ma(arr, count):
    """Trailing moving average: element ``k`` is the mean of ``arr[k : k + count]``.

    The window ending at the last element is not included, so the result has
    ``len(arr) - count`` entries (empty when ``count >= len(arr)``).
    """
    return np.array(
        [np.mean(arr[start:start + count]) for start in range(len(arr) - count)]
    )


def plot(ax, arr, title, episodes=-1, alpha=A, legend=True):
    """Draw one line per side of ``arr[:, :episodes]`` on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        arr: 2-D array; row ``i`` is labelled "White" if ``i`` else "Black".
        title: Axes title.
        episodes: Number of leading episodes to show. A negative value (the
            default) means all columns.
        alpha: Line transparency.
        legend: Draw the legend and the grid. ``plot_check_mates`` passes False
            to keep only its twin-axis grid.

    Returns:
        ``ax``.
    """
    if episodes < 0:
        # The -1 sentinel used to slice off the last episode and set an
        # inverted x-range [0, -1]; treat negatives as "show everything".
        episodes = arr.shape[1]
    ax.set_title(title)
    ax.set_xlim([0, episodes])
    ax.set_xlabel("Episode")
    ax.set_ylabel("Value")
    for i in range(2):
        l = "White" if i else "Black"
        c = W if i else B
        ax.plot(arr[i, :episodes], label=l, alpha=alpha, c=c)
    if legend:
        ax.legend()
        ax.grid()
    return ax


def plot_ma(ax, arr, episodes=-1, count: int = 50):
    """Overlay a ``count``-episode trailing moving average for each side on ``ax``.

    A negative ``episodes`` (the default) means all columns of ``arr``. Returns ``ax``.
    """
    if episodes < 0:
        # With the raw -1 sentinel, range(count, -1) is empty while ma() is
        # not, so x and y lengths would mismatch.
        episodes = arr.shape[1]
    for i in range(2):
        c = W if i else B
        ax.plot(range(count, episodes), ma(arr[i, :episodes], count), c=c, alpha=1)
    return ax


def bar(ax, arr, title, episodes, alpha=A):
    """Bar chart of the first ``episodes`` values of each side in ``arr``.

    Colours match ``plot``: row 0 ("Black") uses ``B``, row 1 ("White") uses ``W``.
    Returns ``ax``.
    """
    ax.set_title(title)
    ax.set_xlabel("Episode")
    # The x-range spans the plotted episodes (previously the undefined name `lst`).
    ax.set_xlim([0, episodes])
    ax.set_ylabel("Value")

    x = range(episodes)
    for i in range(2):
        l = "White" if i else "Black"
        h = arr[i, :episodes]
        ax.bar(x, h, label=l, alpha=alpha, color=W if i else B)
    ax.legend()
    ax.grid()
    return ax


def plot_moves(ax, moves, episodes, count: int = 50):
    """Plot moves per game (both sides summed) plus a ``count``-episode trailing average."""
    totals = moves.sum(axis=0)[:episodes]
    smoothed = ma(totals, count)
    ax.plot(totals, alpha=A, c=B)
    ax.plot(range(count, episodes), smoothed, alpha=1, c=B)
    ax.set_title("Total Moves")
    ax.set_xlim([0, episodes])
    ax.set_xlabel("Episode")
    ax.grid()


def density(arr, count, episode):
    """Rolling event rate over the first ``episode`` episodes.

    Each episode takes the max over sides (axis 0). Entry ``i`` is the sum of
    those maxima over the up-to-``count`` episodes before ``i`` (``i`` itself
    excluded), divided by ``count``.
    """
    per_episode = np.max(arr, axis=0)
    rates = []
    for end in range(episode):
        start = max(0, end - count)
        rates.append(np.sum(per_episode[start:end]) / count)
    return rates


def plot_check_mates(
    ax, check_mates_arr: np.ndarray, episodes: int, count_density: int
):
    """Per-side checkmate counts on ``ax`` plus their rolling rate on a twin y-axis."""
    rate_ax = ax.twinx()
    rates = density(check_mates_arr, count_density, episodes)
    rate_ax.plot(
        range(episodes),
        rates,
        color="tab:green",
        alpha=1,
        label=f"total check mates rate for {count_density} episodes",
        linewidth=2,
    )
    rate_ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    rate_ax.legend()
    rate_ax.grid()
    # legend=False also skips plot()'s grid, so only the twin axis grid is drawn.
    plot(ax, check_mates_arr, "Check Mates", episodes, alpha=0.25, legend=False)


import os

ALPHA = 0.25
COUNT = 512  # window, in episodes, of the rolling checkmate rate
for name in ["Double Agents", "Single Agent"]:
    folder = "".join(name.split(" "))
    folder = f"results/{folder}"
    if not os.path.exists(f"{folder}/moves.npy"):
        # e.g. a run that was never trained in this notebook
        print(name, "... skipped: no saved statistics")
        continue
    print(name, "...")
    moves = np.load(f"{folder}/moves.npy")
    mates = np.load(f"{folder}/mates_win.npy")
    checks = np.load(f"{folder}/checks_win.npy")
    rewards = np.load(f"{folder}/rewards.npy")
    if not moves[0].any():
        # np.max below would fail on an empty index set
        print(name, "... skipped: no episodes recorded")
        continue
    # Episodes actually played = index of the last non-zero move count + 1.
    episodes = np.max(np.where(moves[0] != 0)) + 1

    fig, axs = plt.subplots(2, 2, figsize=(20, 12), dpi=200)
    fig.suptitle(f"{name} | {episodes} Episodes")

    plot(axs[0, 0], rewards, "Rewards", episodes, alpha=ALPHA)
    plot_ma(axs[0, 0], rewards, episodes, count=32)

    plot_moves(axs[0, 1], moves, episodes, count=32)

    plot(axs[1, 0], checks, "Checks", episodes, alpha=ALPHA)
    plot_ma(axs[1, 0], checks, episodes, count=32)

    plot_check_mates(axs[1, 1], mates, episodes, COUNT)

    fig.tight_layout()
    fig.savefig(f"{folder}/plots.jpeg")
    

##### Testing

In [None]:
import pygame
from time import sleep
import numpy as np
import random
import sys
# Watch the trained single-agent PPO policy play both colours in a pygame window.
# A distinct name keeps the Kaggle `env` created near the top of the notebook
# (needed by the "Testing your agent" cell below) from being overwritten.
chess_env = Chess(window_size=600)
chess_env.render()

# PPO.save pickles the whole PPO module ("single_agent_ppo.pt" plus the ".pt"
# that PPO.save appends), so T.load returns a PPO instance directly.
# NOTE(review): newer torch versions default T.load to weights_only=True, which
# refuses full-module pickles; pass weights_only=False (trusted files only) if needed.
ppo = T.load("results/SingleAgent/single_agent_ppo.pt.pt")
actor = ppo.actor
# Send inputs to wherever the loaded weights live instead of hard-coding "cuda".
device = next(actor.parameters()).device

running = True
while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False
    if not running:
        # Window closed: stop before stepping or rendering again.
        break

    turn = chess_env.turn
    print("White" if turn else "Black")

    # Inference only: no autograd graph is needed.
    with T.no_grad():
        state = T.Tensor(chess_env.get_state(turn)).unsqueeze(0).to(device)
        action_mask = T.Tensor(chess_env.get_all_actions(turn)[-1]).unsqueeze(0).to(device)
        action = T.squeeze(actor(state, action_mask).sample()).item()

    rewards, done, infos = chess_env.step(action)
    chess_env.render()
    if done:
        # Game over: stop watching.
        break

### Testing your agent

Now let's see how your agent does against the random agent!

In [None]:
# Play the submitted agent (main.py) against Kaggle's built-in random agent.
result = env.run(["main.py", "random"])
print("Agent exit status/reward/time left: ")
# look at the generated replay.json and print out the agent info.
# The loop variable is `final` rather than `agent`, so the SingleAgentChess
# trainer bound to `agent` earlier in the notebook is not overwritten.
for final in result[-1]:
    print("\t", final.status, "/", final.reward, "/", final.observation.remainingOverageTime)
print("\n")
# render the game
env.render(mode="ipython", width=1000, height=1000)

# To Submit:
1. Download (or save) main.py
2. Go to the [submissions page](https://www.kaggle.com/competitions/fide-google-efficiency-chess-ai-challenge/submissions) and click "Submit Agent"
3. Upload main.py
4. Press Submit!

No doubt you are already thinking of ways this bot could be improved! Go ahead and fork this notebook and get started! ♟️

# Submitting Multiple files 
### (or compressing your main.py)

Set up your directory structure like this:
```
kaggle_submissions/
  main.py
  <other files as desired>
```

You can run `tar -czf submission.tar.gz -C kaggle_submissions .` and upload `submission.tar.gz`