In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import chess
import random


class ChessBoard:
    """Thin convenience wrapper around a ``chess.Board``.

    On top of the wrapped board it offers a numeric tensor encoding of the
    position, a plain-text rendering, and a win/draw/undecided result code.
    """

    def __init__(self):
        # Start from whatever position ``chess.Board()`` creates by default.
        self.board = chess.Board()

    def get_state(self):
        """Encode the position as a float32 tensor, one entry per square.

        Squares are visited in ``chess.SQUARES`` order. An occupied square
        contributes the piece's ``piece_type``, negated when ``piece.color``
        is falsy; an empty square contributes 0.
        """

        def encode(piece):
            if not piece:
                return 0
            sign = 1 if piece.color else -1
            return piece.piece_type * sign

        values = [encode(self.board.piece_at(sq)) for sq in chess.SQUARES]
        return torch.tensor(values, dtype=torch.float32)

    def __str__(self):
        """Render the board as text, rank 8 first, with a file legend below."""
        lines = []
        for rank_idx in range(7, -1, -1):
            cells = []
            for file_idx in range(8):
                piece = self.board.piece_at(chess.square(file_idx, rank_idx))
                if not piece:
                    cells.append(".")
                elif piece.color == chess.WHITE:
                    cells.append(piece.symbol().upper())
                else:
                    cells.append(piece.symbol().lower())
            # Every cell is followed by two spaces, including the last one.
            row_text = "".join(cell + "  " for cell in cells)
            lines.append(f"{rank_idx + 1} | {row_text}")
        lines.append("  | a  b  c  d  e  f  g  h")
        return "\n".join(lines)

    def print_board(self):
        """Write the text rendering of the board to stdout."""
        print(str(self))

    def copy(self):
        """Return a new ``ChessBoard`` wrapping a copy of the current board."""
        duplicate = ChessBoard()
        duplicate.board = self.board.copy()
        return duplicate

    def get_legal_moves(self):
        """Return the wrapped board's legal moves as a list."""
        return [*self.board.legal_moves]

    def make_move(self, move):
        """Push ``move`` onto the wrapped board."""
        self.board.push(move)

    def is_game_over(self):
        """Report whether the wrapped board considers the game finished."""
        return self.board.is_game_over()

    def get_result(self):
        """Return a numeric result code for the current position.

        Returns:
            1 if the position is checkmate with ``chess.BLACK`` to move,
            -1 if it is checkmate otherwise, 0 for any of the draw
            conditions checked below, and ``None`` when none of these apply.
        """
        position = self.board
        if position.is_checkmate():
            return 1 if position.turn == chess.BLACK else -1

        drawn = (
            position.is_stalemate()
            or position.is_insufficient_material()
            or position.is_seventyfive_moves()
            or position.is_fivefold_repetition()
            or position.is_variant_draw()
        )
        return 0 if drawn else None


class ChessNet(nn.Module):
    """Small fully connected network mapping a 64-value input to one score.

    Layer widths are 64 -> 128 -> 64 -> 1 with ReLU after the first two
    layers and tanh on the output, so every score lies in [-1, 1].  The
    module moves itself to CUDA when available (CPU otherwise) and records
    the chosen device on ``self.device``.
    """

    def __init__(self):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

        self.to(self.device)

    def forward(self, x):
        """Return the tanh-squashed score(s) for input ``x``.

        ``x`` is moved to the module's device before the first layer.
        """
        hidden = self.fc1(x.to(self.device)).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden).tanh()


def select_move(board, model, epsilon=0.1):
    """Choose a move for the side to move with epsilon-greedy 1-ply lookahead.

    With probability ``epsilon`` a uniformly random legal move is returned.
    Otherwise every candidate move is played on a copy of the board, the
    resulting position is scored by ``model``, and the move with the highest
    score (White to move) or lowest score (Black to move) is returned.

    Only non-promotions and queen promotions are scored in the greedy branch.
    Earlier the promotion piece was overwritten with a queen *after* the
    position had been scored with the original piece, so the returned move
    could carry a score that belonged to a different move.

    Args:
        board: Wrapper exposing ``get_legal_moves()``, ``copy()``,
            ``make_move()``, ``get_state()`` and ``board.turn``.
        model: Callable scoring a state tensor; must expose ``device`` and
            return a single-element tensor.
        epsilon: Probability of playing a random legal move instead of the
            greedy one.  Defaults to 0.1.

    Returns:
        The chosen move, or ``None`` when there are no legal moves.
    """
    legal_moves = board.get_legal_moves()
    if not legal_moves:
        return None

    # Exploration: occasionally ignore the model entirely.
    if random.random() < epsilon:
        return random.choice(legal_moves)

    # Drop under-promotions; fall back to the full list if that would leave
    # nothing to choose from.
    candidates = [
        move for move in legal_moves if move.promotion in (None, chess.QUEEN)
    ] or legal_moves

    white_to_move = board.board.turn == chess.WHITE
    best_move = None
    best_score = None

    # Pure inference: no autograd graph is needed for move selection.
    with torch.no_grad():
        for move in candidates:
            successor = board.copy()
            successor.make_move(move)
            score = model(successor.get_state().to(model.device)).item()

            # The first candidate is always accepted so that a NaN score
            # cannot leave us without a move.
            if best_move is None:
                improved = True
            elif white_to_move:
                improved = score > best_score
            else:
                improved = score < best_score

            if improved:
                best_move, best_score = move, score

    return best_move


def train_model(model, num_episodes=1000, learning_rate=0.001):
    """Train ``model`` on self-play games using the final result as target.

    Each episode plays one game in which ``select_move`` picks the moves for
    both sides.  The position before every move is recorded, and after the
    game each recorded position is regressed (squared error, one Adam step
    per position) towards the game result.

    The target is the same for every position of a game: the value from
    ``ChessBoard.get_result()`` (1 White win, -1 Black win, 0 draw), with an
    undecided ``None`` treated as 0.  The sign is deliberately NOT alternated
    per ply.  ``select_move`` lets White maximise and Black minimise the
    score, and the state encoding has no side-to-move feature, so the score
    is an absolute White-perspective evaluation.  Alternating targets would
    contradict that.

    Args:
        model: Network to train; must expose ``parameters()`` and ``device``.
        num_episodes: Number of self-play games to play.
        learning_rate: Learning rate passed to Adam.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for episode in range(num_episodes):
        board = ChessBoard()
        states = []
        while not board.is_game_over():
            # Record the position the mover saw, i.e. before the move.
            state = board.get_state().to(model.device)
            move = select_move(board, model)
            if move is None:
                break
            board.make_move(move)
            states.append(state)

        result = board.get_result()
        target = torch.tensor(
            0.0 if result is None else float(result), dtype=torch.float32
        ).to(model.device)

        for state in states:
            score = model(state)
            loss = (score - target) ** 2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(
            f"Episode {episode + 1}/{num_episodes}, Result: {result}, Device: {model.device}"
        )


def play_game(model1, model2):
    """Play one game between two models and print the outcome.

    ``model1`` picks the move whenever ``board.board.turn == chess.WHITE``
    and ``model2`` otherwise.  The board is printed after every move.  The
    game stops when the board reports it is over or when no move is returned.
    """
    board = ChessBoard()
    while not board.is_game_over():
        mover = model1 if board.board.turn == chess.WHITE else model2
        move = select_move(board, mover)
        if move is None:
            break
        board.make_move(move)
        board.print_board()

    # Any result other than 1 or -1 (including None) is announced as a draw.
    announcements = {1: "White wins!", -1: "Black wins!"}
    print(announcements.get(board.get_result(), "Draw!"))


def _prompt_promotion_piece():
    """Ask until the user names a promotion piece; return its piece type."""
    choices = {
        "Q": chess.QUEEN,
        "R": chess.ROOK,
        "B": chess.BISHOP,
        "N": chess.KNIGHT,
    }
    while True:
        piece_str = input("Choose promotion piece (Q, R, B, N):").strip().upper()
        if piece_str in choices:
            return choices[piece_str]
        print("Invalid promotion piece. Choose Q, R, B, or N.")


def _read_human_move(board):
    """Prompt until the user enters a legal move on ``board``; return it.

    Syntax and legality are checked separately so that a well-formed but
    illegal move is reported as illegal rather than as a format error.  A
    promotion entered without a piece suffix (e.g. ``e7e8``) triggers the
    promotion prompt; a suffix that is already present is respected.
    """
    while True:
        move_str = input("Your move (white): ").strip()
        try:
            # Syntax only; legality is checked against the position below.
            move = chess.Move.from_uci(move_str)
        except ValueError:
            print("Invalid move format. Use UCI notation (e.g., e2e4).")
            continue

        if move and move.promotion is None:
            # Would this move be legal as a promotion?  Then the user left
            # out the piece suffix, so ask for it.
            as_queen = chess.Move(
                move.from_square, move.to_square, promotion=chess.QUEEN
            )
            if board.board.is_legal(as_queen):
                move = chess.Move(
                    move.from_square,
                    move.to_square,
                    promotion=_prompt_promotion_piece(),
                )

        # A null move ("0000") is falsy and is never accepted.
        if move and board.board.is_legal(move):
            return move
        print("Illegal move. Please try again.")


def play_against_human(model):
    """Play an interactive game: the user moves on White's turn, ``model`` on Black's.

    The board is printed before every turn and once more at the end, followed
    by the outcome.  User moves are read from stdin in UCI notation.  The
    game stops when the board reports it is over or when the model returns
    no move.
    """
    board = ChessBoard()
    while not board.is_game_over():
        board.print_board()
        if board.board.turn == chess.WHITE:
            board.make_move(_read_human_move(board))
        else:
            model_move = select_move(board, model)
            if model_move is None:
                break
            board.make_move(model_move)
            print(f"Model's move (black): {model_move.uci()}")

    board.print_board()
    result = board.get_result()
    if result == 1:
        print("White wins!")
    elif result == -1:
        print("Black wins!")
    else:
        print("Draw!")


def main():
    """Create a fresh ``ChessNet`` and train it for 100 episodes."""
    net = ChessNet()
    train_model(net, num_episodes=100)

    # Optional follow-ups, currently disabled:
    # play_game(net, ChessNet())
    # play_against_human(net)

# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Episode 1/100, Result: 0, Device: cuda
Episode 2/100, Result: 0, Device: cuda
Episode 3/100, Result: 0, Device: cuda
Episode 4/100, Result: 0, Device: cuda
Episode 5/100, Result: 0, Device: cuda
Episode 6/100, Result: 0, Device: cuda
Episode 7/100, Result: 0, Device: cuda
Episode 8/100, Result: 0, Device: cuda
Episode 9/100, Result: 0, Device: cuda
Episode 10/100, Result: 0, Device: cuda
Episode 11/100, Result: 0, Device: cuda
Episode 12/100, Result: 0, Device: cuda
Episode 13/100, Result: 0, Device: cuda
Episode 14/100, Result: 0, Device: cuda
Episode 15/100, Result: 0, Device: cuda
Episode 16/100, Result: 0, Device: cuda
Episode 17/100, Result: 0, Device: cuda
Episode 18/100, Result: -1, Device: cuda
Episode 19/100, Result: 0, Device: cuda
Episode 20/100, Result: 0, Device: cuda
Episode 21/100, Result: 0, Device: cuda
Episode 22/100, Result: 0, Device: cuda
Episode 23/100, Result: 0, Device: cuda
Episode 24/100, Result: 0, Device: cuda
Episode 25/100, Result: 0, Device: cuda
Episode 