In [None]:
import chess
import numpy as np
import torch
import torch.nn as nn
import random
from collections import deque
import torch.optim as optim

class ChessEnvironment:
    """Two-player chess environment around ``chess.Board`` for Q-learning.

    States are ``(13, 8, 8)`` float32 arrays:

    * channels 0-5 mark White's pieces (``piece_type - 1``);
    * channels 6-11 mark Black's pieces;
    * channel 12 is all ones when White is to move and zeros otherwise.

    Actions are ``(row1, col1, row2, col2)`` tuples in board coordinates,
    where row 0 is rank 8 and column 0 is file a.
    """

    def __init__(self):
        self.board = chess.Board()
        self.state_size = (13, 8, 8)  # must match the array built in _get_state
        self.action_size = 8 * 8 * 8 * 8  # one entry per (from, to) coordinate pair
        self.current_player = chess.WHITE

    def reset(self, fen=None):
        """Reset to the starting position, or to ``fen`` if given, and return the state."""
        if fen:
            self.board.set_fen(fen)
        else:
            self.board.reset()
        self.current_player = self.board.turn
        return self._get_state()

    def _get_state(self):
        """Encode the current position as a contiguous ``(13, 8, 8)`` float32 array.

        NOTE(review): the planes are rotated 180 degrees when
        ``self.current_player`` is Black. ``current_player`` is only assigned
        in ``__init__`` and ``reset``, so the rotation is fixed for a whole
        episode by whichever side was to move at reset. Action coordinates
        from ``get_legal_moves``/``step`` are never rotated. Confirm this is
        intended.
        """
        state = np.zeros((13, 8, 8), dtype=np.float32)
        for square in chess.SQUARES:
            piece = self.board.piece_at(square)
            if piece is not None:
                row, col = self._square_to_coords(square)
                state[self._piece_to_channel(piece), row, col] = 1.0

        # Side-to-move plane.
        state[12, :, :] = 1.0 if self.board.turn == chess.WHITE else 0.0

        if self.current_player == chess.BLACK:
            state = np.flip(state, axis=(1, 2))  # Flip board for Black perspective

        # np.flip returns a negatively-strided view, which torch.tensor
        # rejects. Hand back a contiguous array so callers can convert directly.
        return np.ascontiguousarray(state)

    def _piece_to_channel(self, piece):
        """Map a piece to its plane index: 0-5 for White, 6-11 for Black."""
        piece_type = piece.piece_type - 1
        color_offset = 0 if piece.color == chess.WHITE else 6
        return piece_type + color_offset

    def _square_to_coords(self, square):
        """Convert a python-chess square index to ``(row, col)``, with row 0 at rank 8."""
        row = 7 - chess.square_rank(square)
        col = chess.square_file(square)
        return row, col

    def get_legal_moves(self):
        """Return the legal moves as ``(row1, col1, row2, col2)`` tuples.

        Promotion pieces are not encoded, so under-promotions of the same
        pawn move produce duplicate tuples.
        """
        return [
            (*self._square_to_coords(move.from_square), *self._square_to_coords(move.to_square))
            for move in self.board.legal_moves
        ]

    def step(self, action):
        """Play ``action`` and return ``(new_state, reward, done)``.

        Pawn moves that reach the last rank always promote to a queen.

        :param action: ``(row1, col1, row2, col2)`` tuple.
        :return: The new state, ``calculate_reward``'s ``[white, black]`` list,
            and whether the game is over.
        :raises ValueError: If the resulting move is not legal.
        """
        row1, col1, row2, col2 = action
        from_square = self._coords_to_square(row1, col1)
        to_square = self._coords_to_square(row2, col2)

        piece = self.board.piece_at(from_square)
        promotion = None
        if piece is not None and piece.piece_type == chess.PAWN:
            # Row 0 is rank 8 (White's promotion rank), row 7 is rank 1 (Black's).
            if (piece.color == chess.WHITE and row2 == 0) or (piece.color == chess.BLACK and row2 == 7):
                promotion = chess.QUEEN
        move = chess.Move(from_square, to_square, promotion=promotion)

        if move not in self.board.legal_moves:
            raise ValueError(f"Illegal move attempted: {move}")

        self.board.push(move)
        new_state = self._get_state()
        done = self.board.is_game_over()
        reward = self.calculate_reward(done)

        return new_state, reward, done

    def _coords_to_square(self, row, col):
        """Inverse of ``_square_to_coords``."""
        return chess.square(col, 7 - row)

    def calculate_reward(self, blah=None):
        """Return ``[white_reward, black_reward]`` for the current position.

        A checkmate scores +1 for the mating side and -1 for the mated side.
        Every other position (draws and unfinished games alike) scores 0.

        :param blah: Unused; accepted so existing ``calculate_reward(done)``
            calls keep working.
        :return: A two-element list ``[white_reward, black_reward]``.
        """
        if self.board.is_checkmate():
            # The side to move is the side that has been mated.
            if self.board.turn == chess.BLACK:
                return [1.0, -1.0]  # White wins, Black loses
            return [-1.0, 1.0]  # Black wins, White loses
        return [0.0, 0.0]  # Draw or game still ongoing

    def render(self):
        """Print an ASCII diagram of the board."""
        print(self.board)

# Keep everything else unchanged


In [None]:
class WhiteChessCNN(nn.Module):
    """Q-network for White: two 3x3 convolutions followed by two linear layers.

    Maps a batch of ``(in_channels, 8, 8)`` board encodings to one Q-value
    per action index.

    :param action_size: Number of outputs (one Q-value per action index).
    :param in_channels: Channels in the input encoding. Defaults to 13, the
        number of planes ``ChessEnvironment._get_state`` produces.
    """

    def __init__(self, action_size, in_channels=13):
        super().__init__()
        # padding=1 keeps the 8x8 spatial size, so fc1 sees 64 * 8 * 8 features.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, action_size)

    def forward(self, x):
        """Map ``[batch_size, in_channels, 8, 8]`` inputs to ``[batch_size, action_size]`` Q-values."""
        x = torch.relu(self.conv1(x))  # [batch_size, 32, 8, 8]
        x = torch.relu(self.conv2(x))  # [batch_size, 64, 8, 8]
        x = torch.flatten(x, start_dim=1)  # [batch_size, 64 * 8 * 8]
        x = torch.relu(self.fc1(x))  # [batch_size, 128]
        return self.fc2(x)

class BlackChessCNN(nn.Module):
    """Q-network for Black: two 3x3 convolutions followed by two linear layers.

    Maps a batch of ``(in_channels, 8, 8)`` board encodings to one Q-value
    per action index.

    :param action_size: Number of outputs (one Q-value per action index).
    :param in_channels: Channels in the input encoding. Defaults to 13, the
        number of planes ``ChessEnvironment._get_state`` produces.
    """

    def __init__(self, action_size, in_channels=13):
        super().__init__()
        # padding=1 keeps the 8x8 spatial size, so fc1 sees 64 * 8 * 8 features.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, action_size)

    def forward(self, x):
        """Map ``[batch_size, in_channels, 8, 8]`` inputs to ``[batch_size, action_size]`` Q-values."""
        x = torch.relu(self.conv1(x))  # [batch_size, 32, 8, 8]
        x = torch.relu(self.conv2(x))  # [batch_size, 64, 8, 8]
        x = torch.flatten(x, start_dim=1)  # [batch_size, 64 * 8 * 8]
        x = torch.relu(self.fc1(x))  # [batch_size, 128]
        return self.fc2(x)



In [None]:
def select_action(state, q_network, env, epsilon):
    """Pick a legal move with an epsilon-greedy policy.

    :param state: Input for ``q_network``. Its output is reshaped to 8**4
        values, so it must describe exactly one position (batch size 1).
    :param q_network: Callable returning 8**4 Q-values laid out as an
        ``(8, 8, 8, 8)`` grid indexed by ``(row1, col1, row2, col2)``.
    :param env: Object whose ``get_legal_moves()`` returns
        ``(row1, col1, row2, col2)`` tuples.
    :param epsilon: Probability of playing a uniformly random legal move.
    :return: The chosen ``(row1, col1, row2, col2)`` tuple. Ties in Q-value
        go to the earliest legal move.
    :raises ValueError: If there are no legal moves (in either branch).
    """
    legal_moves = env.get_legal_moves()
    if not legal_moves:
        raise ValueError("No legal moves available — this shouldn't happen!")

    if random.random() < epsilon:
        # Exploration: choose a random legal move.
        return random.choice(legal_moves)

    # Exploitation: best legal move according to the Q-network.
    with torch.no_grad():
        # reshape to exactly 8**4 still rejects batches larger than one.
        q_values = q_network(state).reshape(8 * 8 * 8 * 8)

    # Row-major flat index of (r1, c1, r2, c2) in the (8, 8, 8, 8) grid.
    # Gathering all legal entries at once avoids one host sync per move.
    indices = torch.tensor(
        [((r1 * 8 + c1) * 8 + r2) * 8 + c2 for r1, c1, r2, c2 in legal_moves],
        dtype=torch.long,
        device=q_values.device,
    )
    # argmax returns the first maximal index, matching max() over the list.
    best = int(torch.argmax(q_values[indices]))
    return legal_moves[best]

def compute_loss(batch, q_network, target_network, gamma):
    """Compute the mean-squared DQN temporal-difference loss for a mini-batch.

    :param batch: Sequence of ``(state, action, reward, next_state, done)``
        tuples. States are arrays whose stacked form has a size-1 axis 1
        (removed with ``squeeze(1)``). Actions are
        ``(row1, col1, row2, col2)`` tuples.
    :param q_network: Network being trained.
    :param target_network: Network used for the bootstrap target (no grad).
    :param gamma: Discount factor.
    :return: Scalar loss tensor that backpropagates into ``q_network``.
    """
    # Put the batch on the network's own device rather than guessing, so
    # CPU-resident networks still work when CUDA happens to be available.
    try:
        device = next(q_network.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    states, actions, rewards, next_states, dones = zip(*batch)

    # Drop the per-sample batch axis added when the transition was stored.
    states = torch.tensor(np.array(states), dtype=torch.float32).squeeze(1).to(device)
    next_states = torch.tensor(np.array(next_states), dtype=torch.float32).squeeze(1).to(device)
    rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
    dones = torch.tensor(dones, dtype=torch.float32, device=device)

    # Q(s, a) for the actions actually taken. Each (r1, c1, r2, c2) is flattened
    # to its index in the (8, 8, 8, 8) action grid.
    action_indices = np.ravel_multi_index(tuple(np.asarray(actions).T), (8, 8, 8, 8))
    action_indices = torch.as_tensor(action_indices, dtype=torch.long, device=device).unsqueeze(1)
    q_values = q_network(states).gather(1, action_indices).squeeze(1)

    with torch.no_grad():
        # NOTE(review): the max covers all outputs, including moves that are
        # illegal in next_state; masking would need legal moves stored per transition.
        next_q_values = target_network(next_states).max(dim=1).values
        targets = rewards + gamma * next_q_values * (1 - dones)

    return nn.functional.mse_loss(q_values, targets)

In [None]:
# --- Environment, networks, optimizers and replay memory -------------------
env = ChessEnvironment()

# One Q-value per (from-square, to-square) coordinate pair.
action_size = 8 * 8 * 8 * 8

# Use the GPU when one is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Creation order (white online, white target, black online, black target)
# fixes the random weight initialisation; .to() returns the module itself.
white_q_network = WhiteChessCNN(action_size).to(device)
white_target_network = WhiteChessCNN(action_size).to(device)
# The target starts as an exact copy and is only used for inference.
white_target_network.load_state_dict(white_q_network.state_dict())
white_target_network.eval()

black_q_network = BlackChessCNN(action_size).to(device)
black_target_network = BlackChessCNN(action_size).to(device)
black_target_network.load_state_dict(black_q_network.state_dict())
black_target_network.eval()

# Optimizers are built after the move so they track the device-resident weights.
white_optimizer = optim.Adam(white_q_network.parameters(), lr=0.001)
black_optimizer = optim.Adam(black_q_network.parameters(), lr=0.001)

# Bounded experience replay, one buffer per side.
white_replay_buffer = deque(maxlen=100000)
black_replay_buffer = deque(maxlen=100000)

In [None]:
epsilon = .9  # Initial exploration rate
epsilon_min = 0.01  # Minimum exploration rate
epsilon_decay = 0.995  # Decay rate for exploration
gamma = 0.99  # Discount factor
batch_size = 64  # Mini-batch size
target_update_freq = 100  # Episodes between target-network syncs
num_episodes = 100  # Number of episodes to train
queens_gambit_fen = "rnbqkbnr/ppp1pppp/8/3p4/2PP4/8/PP2PPPP/RNBQKBNR w KQkq - 0 3"
ruy_lopez_fen = "r1bqkbnr/pppp1ppp/2n5/1B2p3/4P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 4"
italian_game_fen = "r1bqkbnr/pppp1ppp/2n5/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 4"

# Per-side components, keyed by the colour to move.
q_networks = {chess.WHITE: white_q_network, chess.BLACK: black_q_network}
replay_buffers = {chess.WHITE: white_replay_buffer, chess.BLACK: black_replay_buffer}

for episode in range(num_episodes):
    env = ChessEnvironment()
    state = env.reset(fen=italian_game_fen)
    # Add a batch dimension; .copy() guarantees a contiguous, positively
    # strided array for torch.tensor.
    state = np.expand_dims(state, axis=0).copy()
    done = False
    white_total_reward = 0
    black_total_reward = 0
    moves_made = {chess.WHITE: 0, chess.BLACK: 0}

    while not done:
        mover = env.board.turn
        action = select_action(torch.tensor(state, dtype=torch.float32).to(device), q_networks[mover], env, epsilon)
        next_state, _, done = env.step(action)
        next_state = np.expand_dims(next_state, axis=0).copy()  # Add batch dimension

        # Immediate reward is 0; the game result is written in once the game ends.
        replay_buffers[mover].append((state, action, 0.0, next_state, done))
        moves_made[mover] += 1
        state = next_state

    # Game result for each side: +1 win, -1 loss, 0 otherwise. This includes
    # draws by stalemate, insufficient material, 75-move rule or fivefold repetition.
    white_reward, black_reward = env.calculate_reward(done)

    # Write the result into the LAST transition each side made in this episode
    # and mark it terminal. The loser's final move therefore receives its -1 as
    # well. buffer[-1] is this episode's transition whenever that side moved,
    # and transitions stored by earlier episodes are left intact.
    for side, reward in ((chess.WHITE, white_reward), (chess.BLACK, black_reward)):
        if moves_made[side]:
            buffer = replay_buffers[side]
            last_state, last_action, _, last_next_state, _ = buffer[-1]
            buffer[-1] = (last_state, last_action, reward, last_next_state, True)

    # Update total rewards
    white_total_reward += white_reward
    black_total_reward += black_reward

    # One gradient step per side per episode (White first, then Black).
    # NOTE(review): one update per episode is very little training for DQN.
    for online_net, target_net, opt, buffer in (
        (white_q_network, white_target_network, white_optimizer, white_replay_buffer),
        (black_q_network, black_target_network, black_optimizer, black_replay_buffer),
    ):
        if len(buffer) > batch_size:
            batch = random.sample(buffer, batch_size)
            loss = compute_loss(batch, online_net, target_net, gamma)
            opt.zero_grad()
            loss.backward()
            opt.step()

    # Decay epsilon
    epsilon = max(epsilon_min, epsilon * epsilon_decay)

    # Update target networks.
    # NOTE(review): with num_episodes == target_update_freq this syncs only at episode 0.
    if episode % target_update_freq == 0:
        white_target_network.load_state_dict(white_q_network.state_dict())
        black_target_network.load_state_dict(black_q_network.state_dict())

    # Log progress
    print(f"Episode: {episode + 1}, White Reward: {white_total_reward}, Black Reward: {black_total_reward}, Epsilon: {epsilon:.2f}")