In [11]:
# import torch
# tensor = torch.load('X_all.pt')
# print("Wczytany tensor:", tensor)
# print("Wymiary wczytanego tensora:", tensor.shape)

In [12]:
# print(tensor[0])

In [13]:
# flat = nn.Flatten()
# flatten = flat(tensor[:10])
# print(flatten.shape)
# out = rl(flatten)
# print(out.shape)
# print(out[0])
# print(max(out[0]))
# arg = torch.argmax(out[0])
# print(arg)
# print(out[0][arg])

In [14]:
# input_tensor = tensor[:1024]
# input_tensor.shape
# nonzero_mask = (input_tensor != 0).float()
# multiplied_indices = torch.arange(1, 7, device=input_tensor.device).unsqueeze(-1).unsqueeze(-1)
# result_tensor = input_tensor * multiplied_indices * nonzero_mask
# result_tensor_summed = torch.sum(result_tensor, dim=1)
# print(result_tensor_summed.shape) 
# result_tensor_summed[0]

In [15]:
# import chess

# boards = []

# layer_to_piece = {
#     1: chess.Piece(chess.PAWN, chess.WHITE),
#     2: chess.Piece(chess.KNIGHT, chess.WHITE),
#     3: chess.Piece(chess.BISHOP, chess.WHITE),
#     4: chess.Piece(chess.ROOK, chess.WHITE),
#     5: chess.Piece(chess.QUEEN, chess.WHITE),
#     6: chess.Piece(chess.KING, chess.WHITE),
#    -1: chess.Piece(chess.PAWN, chess.BLACK),
#    -2: chess.Piece(chess.KNIGHT, chess.BLACK),
#    -3: chess.Piece(chess.BISHOP, chess.BLACK),
#    -4: chess.Piece(chess.ROOK, chess.BLACK),
#    -5: chess.Piece(chess.QUEEN, chess.BLACK),
#    -6: chess.Piece(chess.KING, chess.BLACK),
#     0: None
# }

# for tensor in result_tensor_summed:
#     board = chess.Board.empty()
    
#     for row_idx, row in enumerate(tensor):
#         for col_idx, val in enumerate(row):
#             piece = layer_to_piece[int(val.item())]
            
#             if piece is not None:
#                 square = chess.square(col_idx, 7 - row_idx)
#                 board.set_piece_at(square, piece)
    
#     boards.append(board)


In [16]:
import torch.nn as nn
import torch

class rlClassifier(nn.Module):
    """Fully connected policy network that outputs action probabilities.

    The network is ``Flatten -> [Linear -> Dropout -> ReLU] * k -> Linear ->
    Softmax``, so ``forward`` returns probabilities (not logits).

    Args:
        input_size: Number of features after flattening one state.
        output_size: Number of discrete actions.
        layer_sizes: Width of each hidden ``Linear`` layer, in order.
        dropout: Dropout probability used after every hidden ``Linear``.
    """

    def __init__(
            self,
            input_size: int,
            output_size: int,
            layer_sizes: list[int],
            dropout: float=0.1):
        super().__init__()

        # Flatten everything except the batch dimension added in forward().
        layers = [nn.Flatten(start_dim=1)]
        old_size = input_size
        for layer in layer_sizes:
            layers.append(nn.Linear(old_size, layer))
            layers.append(nn.Dropout(dropout))
            layers.append(nn.ReLU())
            old_size = layer

        layers.append(nn.Linear(old_size, output_size))
        layers.append(nn.Softmax(dim=-1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return action probabilities of shape ``(1, output_size)``.

        A leading batch dimension of size 1 is always added, so ``x`` is
        treated as a single, un-batched state.
        """
        x = torch.unsqueeze(x, 0)
        return self.model(x)

    def get_action(self, state):
        """Sample one action index from the policy for ``state``.

        Returns:
            An integer tensor of shape ``(1,)`` holding the sampled action.
        """
        # Sampling is not differentiable, so there is no point in recording
        # an autograd graph for this forward pass.
        with torch.no_grad():
            probs = self.forward(state)
        actions = torch.multinomial(probs, 1).squeeze(1)
        return actions

    def reinforce_update(self, log_probs, rewards, optimizer):
        """Run one REINFORCE step: minimise ``-sum(log_prob * reward)``.

        Args:
            log_probs: List of tensors (one per step) with the log-probability
                of the action that was taken; they must carry gradients.
            rewards: Sequence of per-step rewards, same length as ``log_probs``.
            optimizer: Optimizer over this module's parameters.
        """
        if len(log_probs) == 0:
            # Nothing was collected; torch.stack would raise on an empty list.
            return

        log_probs = torch.stack(log_probs)
        # Follow the device of the log-probs instead of assuming CUDA, so the
        # update also works when the model lives on the CPU.
        batch_rewards = torch.tensor(
            rewards, dtype=torch.float32, device=log_probs.device)

        # Calculate the loss
        action_loss = -(log_probs * batch_rewards).sum()

        # Perform backpropagation
        optimizer.zero_grad()
        action_loss.backward()
        optimizer.step()


In [17]:
from functools import lru_cache
import chess

@lru_cache(maxsize=None)
def number_to_move(number) -> tuple[int, int]:
    """Decode an action index into a ``(from_square, to_square)`` pair.

    The index is read as ``start * 64 + end``. Both ``start`` and ``end``
    count board cells row by row, with row 0 mapped to the highest rank
    (rank index 7) and column 0 to file index 0.

    Args:
        number: Action index in ``0..4095`` inclusive. Pass a plain ``int``
            so that the cache can match repeated calls by value.

    Returns:
        Tuple of two square indices as produced by ``chess.square``. This is
        not a ``chess.Move``; promotion and legality are resolved by the caller.

    Raises:
        ValueError: If ``number`` is outside ``0..4095``.
    """
    if number < 0 or number > 4095:
        raise ValueError("Number not from range(4096).")
    start_index = number // 64
    end_index = number % 64
    # Rows are numbered from the top of the board, ranks from the bottom.
    start_rank = 7 - (start_index // 8)
    start_file = start_index % 8
    end_rank = 7 - (end_index // 8)
    end_file = end_index % 8
    from_square = chess.square(start_file, start_rank)
    to_square = chess.square(end_file, end_rank)
    return from_square, to_square


In [18]:
class ChessEnv:
    """Single-sided chess environment backed by a layered piece tensor.

    The state tensor has one layer per piece type on dim 0 (layer ``k`` holds
    piece type ``k + 1``); positive entries become white pieces and negative
    entries black pieces when the board is built. After every legal move the
    side to move is forced back to white, so the agent always plays white.
    """

    def __init__(self) -> None:
        # Signed piece code -> python-chess piece (0 means empty square).
        self.layer_to_piece = {
            1: chess.Piece(chess.PAWN, chess.WHITE),
            2: chess.Piece(chess.KNIGHT, chess.WHITE),
            3: chess.Piece(chess.BISHOP, chess.WHITE),
            4: chess.Piece(chess.ROOK, chess.WHITE),
            5: chess.Piece(chess.QUEEN, chess.WHITE),
            6: chess.Piece(chess.KING, chess.WHITE),
            -1: chess.Piece(chess.PAWN, chess.BLACK),
            -2: chess.Piece(chess.KNIGHT, chess.BLACK),
            -3: chess.Piece(chess.BISHOP, chess.BLACK),
            -4: chess.Piece(chess.ROOK, chess.BLACK),
            -5: chess.Piece(chess.QUEEN, chess.BLACK),
            -6: chess.Piece(chess.KING, chess.BLACK),
            0: None
        }

    def correct_move(self, move: tuple):
        """Return the first legal move matching ``(from_square, to_square)``.

        Returns ``None`` when no legal move on the current board has these
        two squares. When several legal moves share the squares (promotions),
        the first one yielded by ``board.legal_moves`` is used.
        """
        correct_move = None
        for legal_move in self.board.legal_moves:
            if move[0] == legal_move.from_square and move[1] == legal_move.to_square:
                correct_move = legal_move
                break
        return correct_move

    def push_move(self, move: chess.Move):
        """Play ``move`` on the internal board."""
        self.board.push(move)

    def reset(self, tensor):
        """Start a new episode from the position encoded by ``tensor``."""
        self.board = self.generate_board(tensor)
        self.tensor = tensor

    def update_tensors(self, move):
        """Mirror an already pushed ``move`` into ``self.tensor``.

        Must be called after ``push_move`` because the moved piece is read
        from its destination square. Both the origin and the destination
        square are cleared in every layer before the moved piece is written,
        so a captured piece does not linger in its own layer and promotions
        need no special case.

        NOTE(review): rook relocation on castling and the pawn removed by an
        en-passant capture are not mirrored here.
        """
        new_tensor = self.tensor.clone()
        moved_piece = self.board.piece_at(move.to_square)
        layer = moved_piece.piece_type - 1
        # Same sign convention generate_board() decodes: + white, - black.
        value = 1 if moved_piece.color == chess.WHITE else -1
        rank_beg = 7 - chess.square_rank(move.from_square)
        file_beg = chess.square_file(move.from_square)
        rank_end = 7 - chess.square_rank(move.to_square)
        file_end = chess.square_file(move.to_square)
        new_tensor[:, rank_beg, file_beg] = 0
        new_tensor[:, rank_end, file_end] = 0
        new_tensor[layer, rank_end, file_end] = value
        self.tensor = new_tensor

    def generate_board(self, in_tensor) -> chess.Board:
        """Build a ``chess.Board`` from a layered piece tensor.

        Each layer is scaled by its 1-based index and the layers are summed,
        giving one signed piece code per cell that is looked up in
        ``self.layer_to_piece``. Row 0 of the tensor is placed on rank index 7.
        """
        nonzero_mask = (in_tensor != 0).float()
        layer_codes = torch.arange(1, 7, device=in_tensor.device).unsqueeze(-1).unsqueeze(-1)
        coded_layers = in_tensor * layer_codes * nonzero_mask
        signed_board = torch.sum(coded_layers, dim=0)
        board = chess.Board.empty()
        for row_idx, row in enumerate(signed_board):
            for col_idx, val in enumerate(row):
                piece = self.layer_to_piece[int(val.item())]
                if piece is not None:
                    square = chess.square(col_idx, 7 - row_idx)
                    board.set_piece_at(square, piece)
        return board

    def step(self, action):
        """Apply one action and return ``(state_tensor, reward)``.

        Args:
            action: Action index in ``0..4095``; a plain ``int`` or a
                one-element tensor.

        Returns:
            The current state tensor and ``-20`` if the action is not a legal
            move (state unchanged), otherwise the updated tensor and ``5``.
        """
        illegal_rew = -20
        legal_rew = 5
        # Collapse one-element tensors to a plain int: number_to_move() is
        # memoised, so passing an int lets its cache match repeated actions
        # by value instead of storing a new tensor key on every call.
        fields = number_to_move(int(action))
        correct_move: chess.Move = self.correct_move(fields)
        if correct_move is None:
            return self.tensor, illegal_rew
        self.push_move(correct_move)
        # push() hands the move to black; give it straight back to white.
        self.board.turn = chess.WHITE
        self.update_tensors(correct_move)
        return self.tensor, legal_rew
        



In [19]:
def train_model(env, model, optimizer, batch, num_episodes=200, max_steps_per_episode=100):
    """Train ``model`` with REINFORCE on episodes that all start from ``batch``.

    Every episode resets ``env`` to a copy of ``batch``, plays
    ``max_steps_per_episode`` sampled actions, and then calls
    ``model.reinforce_update`` with the collected log-probabilities and rewards.

    Args:
        env: Environment exposing ``reset(state)`` and ``step(action)`` that
            returns ``(next_state, reward)``.
        model: Policy exposing ``forward(state)`` (action probabilities of
            shape ``(1, n_actions)``) and ``reinforce_update``.
        optimizer: Passed through to ``model.reinforce_update``.
        batch: Starting state tensor.
        num_episodes: Highest episode index; episodes ``0..num_episodes``
            inclusive are run, i.e. ``num_episodes + 1`` in total.
        max_steps_per_episode: Number of actions taken per episode.
    """
    # Run on whatever device the model lives on instead of assuming CUDA.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = batch.device

    for i in range(num_episodes + 1):
        state = batch.clone().to(device)

        env.reset(state)
        batch_log_probs = []
        batch_rewards = []

        for _ in range(max_steps_per_episode):
            # One forward pass serves both sampling and the log-probability,
            # so the gradient belongs to the distribution the action was
            # actually drawn from (a second pass would use a new dropout mask).
            probs = model.forward(state)
            action = torch.multinomial(probs, 1).squeeze(1)
            next_states, reward = env.step(int(action.item()))

            # The model already ends in a softmax, so take a plain log here;
            # applying log_softmax again would distort the log-probability.
            log_prob = torch.log(probs.gather(1, action.unsqueeze(1)))
            batch_log_probs.append(log_prob.squeeze())
            batch_rewards.append(reward)

            state = next_states
        print(f"move reward: {sum(batch_rewards)}")
        if batch_log_probs:
            model.reinforce_update(batch_log_probs, batch_rewards, optimizer)

        if i % 10 == 0:
            print(f"Episode {i}/{num_episodes}")

    print("Training finished.")

In [20]:
from torch import optim
# Load the saved data from 'X_all.pt'.
# NOTE(review): depending on the torch version / weights_only setting, torch.load
# may unpickle arbitrary objects — only load files from a trusted source.
tensor = torch.load('X_all.pt')
# Policy network: 6*8*8 = 384 inputs, 4096 outputs, hidden layers of 512, 1024
# and 2048 units, moved to the GPU.
rl = rlClassifier(6*8*8, 4096, [512, 1024, 2048]).cuda()
# Keep only entry 3443 of the loaded data as the starting state; `tensor` is
# rebound from the full load to this single entry (presumably one encoded
# board position — TODO confirm).
tensor = tensor[3443]
tensor = tensor.cuda()
# NOTE(review): autograd anomaly detection is a debugging aid and adds overhead
# to every backward pass; consider turning it off for real training runs.
torch.autograd.set_detect_anomaly(True)
# NOTE(review): lr=0.05 looks large for Adam; the recorded output stays at
# -2000 per episode (100 steps x -20), i.e. every sampled move was illegal.
optimizer = optim.Adam(rl.parameters(), lr=0.05)
env = ChessEnv()

# Uses the defaults: num_episodes=200, max_steps_per_episode=100.
train_model(env, rl, optimizer, tensor)

move reward: -2000
Episode 0/200
move reward: -2000
move reward: -2000
move reward: -2000
move reward: -2000
move reward: -2000


KeyboardInterrupt: 