# Imports

In [1]:
!pip install wandb -qU
%pip install chess

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m35.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m301.8/301.8 kB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting chess
  Downloading chess-1.10.0-py3-none-any.whl.metadata (19 kB)
Downloading chess-1.10.0-py3-none-any.whl (154 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.4/154.4 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: chess
Successfully installed chess-1.10.0


In [2]:
import wandb
wandb.require("core")

wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
import chess
from chess import pgn

In [5]:
import os # to delete old files

# Variables

In [6]:
DIR = "drive/MyDrive/Colab Notebooks"
DATA_DIR = f"{DIR}/tensors"
SAVE_DIR = f"{DIR}/chess_ai"

# Matrices to games

In [7]:
def tensor_to_board(tensor):
    board = chess.Board()
    board.clear()
    idx_to_piece = [
        chess.PAWN,
        chess.KNIGHT,
        chess.BISHOP,
        chess.ROOK,
        chess.QUEEN,
        chess.KING,
    ]
    for row in range(8):
        for col in range(8):
            for piece_idx in range(6):
                white_offset = 0
                piece_is_here = tensor[row][col][piece_idx + white_offset]
                if piece_is_here:
                    piece = idx_to_piece[piece_idx]
                    board.set_piece_at(
                        chess.square(col, row), chess.Piece(piece, chess.WHITE)
                    )

                black_offset = 6
                piece_is_here = tensor[row][col][piece_idx + black_offset]
                if piece_is_here:
                    piece = idx_to_piece[piece_idx]
                    board.set_piece_at(
                        chess.square(col, row), chess.Piece(piece, chess.BLACK)
                    )
    return board


def info_to_move(from_square, to_square, promotion=None):
    move = chess.Move(from_square, to_square)
    if promotion:
        if promotion == chess.QUEEN:
            move.promotion = chess.QUEEN
        elif promotion == chess.BISHOP:
            move.promotion = chess.BISHOP
        elif promotion == chess.ROOK:
            move.promotion = chess.ROOK
        elif promotion == chess.KNIGHT:
            move.promotion = chess.KNIGHT
    return move


def tensor_to_move(tensor, color):
    tensor_8_8_76_form = tensor.view(8, 8, 76)
    for row in range(8):
        for col in range(8):
            promotion = None
            from_square = chess.square(col, row)
            for i in range(64, 67):
                if tensor_8_8_76_form[row][col][i]:
                    promotion = chess.KNIGHT
                    to_square = chess.square(
                        col - 65 + i, row + 1 if color == chess.WHITE else row - 1
                    )
                    return info_to_move(from_square, to_square, promotion)
            for i in range(67, 70):
                if tensor_8_8_76_form[row][col][i]:
                    promotion = chess.ROOK
                    to_square = chess.square(
                        col - 68 + i, row + 1 if color == chess.WHITE else row - 1
                    )
                    return info_to_move(from_square, to_square, promotion)
            for i in range(70, 73):
                if tensor_8_8_76_form[row][col][i]:
                    promotion = chess.BISHOP
                    to_square = chess.square(
                        col - 71 + i, row + 1 if color == chess.WHITE else row - 1
                    )
                    return info_to_move(from_square, to_square, promotion)
            for i in range(73, 76):
                if tensor_8_8_76_form[row][col][i]:
                    promotion = chess.QUEEN
                    to_square = chess.square(
                        col - 74 + i, row + 1 if color == chess.WHITE else row - 1
                    )
                    return info_to_move(from_square, to_square, promotion)
            for i in range(7, 56, 8):
                if tensor_8_8_76_form[row][col][i]:
                    to_square = chess.square(col + i // 8 + 1, row + i // 8 + 1)
                    return info_to_move(from_square, to_square)
            for i in range(3, 52, 8):
                if tensor_8_8_76_form[row][col][i]:
                    to_square = chess.square(col - i // 8 - 1, row - i // 8 - 1)
                    return info_to_move(from_square, to_square)
            for i in range(1, 50, 8):
                if tensor_8_8_76_form[row][col][i]:
                    to_square = chess.square(col - i // 8 - 1, row + i // 8 + 1)
                    return info_to_move(from_square, to_square)
            for i in range(5, 54, 8):
                if tensor_8_8_76_form[row][col][i]:
                    to_square = chess.square(col + i // 8 + 1, row - i // 8 - 1)
                    return info_to_move(from_square, to_square)
            for i in range(6, 55, 8):
                if tensor_8_8_76_form[row][col][i]:
                    to_square = chess.square(col + i // 8 + 1, row)
                    return info_to_move(from_square, to_square)
            for i in range(2, 51, 8):
                if tensor_8_8_76_form[row][col][i]:
                    to_square = chess.square(col - i // 8 - 1, row)
                    return info_to_move(from_square, to_square)
            for i in range(0, 49, 8):
                if tensor_8_8_76_form[row][col][i]:
                    to_square = chess.square(col, row + i // 8 + 1)
                    return info_to_move(from_square, to_square)
            for i in range(4, 53, 8):
                if tensor_8_8_76_form[row][col][i]:
                    to_square = chess.square(col, row - i // 8 - 1)
                    return info_to_move(from_square, to_square)
            if tensor_8_8_76_form[row][col][56]:
                to_square = chess.square(col - 1, row + 2)
                return info_to_move(from_square, to_square)
            if tensor_8_8_76_form[row][col][57]:
                to_square = chess.square(col - 2, row + 1)
                return info_to_move(from_square, to_square)
            if tensor_8_8_76_form[row][col][58]:
                to_square = chess.square(col - 2, row - 1)
                return info_to_move(from_square, to_square)
            if tensor_8_8_76_form[row][col][59]:
                to_square = chess.square(col - 1, row - 2)
                return info_to_move(from_square, to_square)
            if tensor_8_8_76_form[row][col][60]:
                to_square = chess.square(col + 1, row - 2)
                return info_to_move(from_square, to_square)
            if tensor_8_8_76_form[row][col][61]:
                to_square = chess.square(col + 2, row - 1)
                return info_to_move(from_square, to_square)
            if tensor_8_8_76_form[row][col][62]:
                to_square = chess.square(col + 2, row + 1)
                return info_to_move(from_square, to_square)
            if tensor_8_8_76_form[row][col][63]:
                to_square = chess.square(col + 1, row + 2)
                return info_to_move(from_square, to_square)
    print("Can't convert a tensor to a move")

# Games to matrices

In [8]:
ELO_RANGES = [
    800,
    # 1200,
    # 1600,
    # 2000,
    # 2400,
]
NUM_FOR_SINGLE_OF_ELO_RANGE = 4000
ELO_RANGE_MUL = [
    1,
    # 2,
    # 4,
    # 4,
    # 3,
]


def game_to_tensor(game: pgn.GameNode) -> tuple[torch.Tensor, torch.Tensor]:
    # Initialize an 8x8x12 tensor with zeros
    tensor = torch.zeros(8, 8, 12)

    # Mapping of pieces to tensor indices
    piece_to_idx = {
        chess.PAWN: 0,
        chess.KNIGHT: 1,
        chess.BISHOP: 2,
        chess.ROOK: 3,
        chess.QUEEN: 4,
        chess.KING: 5,
    }

    # Current player color
    color = game.turn()

    # Iterate over the board and set the tensor values
    for i in range(64):
        piece = game.board().piece_at(i)
        if piece:
            color_offset = 0 if piece.color == color else 6
            piece_idx = piece_to_idx[piece.piece_type]
            row, col = divmod(i, 8)
            tensor[row, col, color_offset + piece_idx] = 1

    # Add the additional binary features
    queen_castling_right = bool(
        game.board().castling_rights & (chess.BB_A1 if color else chess.BB_A8)
    )
    king_castling_right = bool(
        game.board().castling_rights & (chess.BB_H1 if color else chess.BB_H8)
    )
    consts = torch.tensor(
        [
            queen_castling_right,
            king_castling_right,
            1 if game.turn() == chess.WHITE else 0,
        ]
    )

    return tensor, consts


def move_to_tensor(move: chess.Move) -> torch.Tensor:
    from_sq, to_sq = move.from_square, move.to_square
    from_row, from_col = divmod(from_sq, 8)
    to_row, to_col = divmod(to_sq, 8)

    promotion_tensor = torch.zeros(8, 8, 12)
    queen_moves_tensor = torch.zeros(8, 8, 56)
    knight_moves_tensor = torch.zeros(8, 8, 8)
    if move.promotion:
        if move.promotion == chess.QUEEN:
            promotion_tensor[from_row][from_col][10 + to_col - from_col] = 1
        elif move.promotion == chess.BISHOP:
            promotion_tensor[from_row][from_col][7 + to_col - from_col] = 1
        elif move.promotion == chess.ROOK:
            promotion_tensor[from_row][from_col][4 + to_col - from_col] = 1
        elif move.promotion == chess.KNIGHT:
            promotion_tensor[from_row][from_col][1 + to_col - from_col] = 1
    elif from_row - to_row == from_col - to_col:
        if from_row < to_row:
            queen_moves_tensor[from_row][from_col][
                7 + (abs(from_col - to_col) - 1) * 8
            ] = 1
        else:
            queen_moves_tensor[from_row][from_col][
                3 + (abs(from_col - to_col) - 1) * 8
            ] = 1
    elif from_row - to_row == to_col - from_col:
        if from_row < to_row:
            queen_moves_tensor[from_row][from_col][
                1 + (abs(from_col - to_col) - 1) * 8
            ] = 1
        else:
            queen_moves_tensor[from_row][from_col][
                5 + (abs(from_col - to_col) - 1) * 8
            ] = 1
    elif from_row == to_row:
        if from_col < to_col:
            queen_moves_tensor[from_row][from_col][
                6 + (abs(from_col - to_col) - 1) * 8
            ] = 1
        else:
            queen_moves_tensor[from_row][from_col][
                2 + (abs(from_col - to_col) - 1) * 8
            ] = 1
    elif from_col == to_col:
        if from_row < to_row:
            queen_moves_tensor[from_row][from_col][
                0 + (abs(from_row - to_row) - 1) * 8
            ] = 1
        else:
            queen_moves_tensor[from_row][from_col][
                4 + (abs(from_row - to_row) - 1) * 8
            ] = 1
    elif from_row - to_row == -2:
        if from_col - to_col == -1:
            knight_moves_tensor[from_row][from_col][7] = 1
        else:
            knight_moves_tensor[from_row][from_col][0] = 1
    elif from_row - to_row == 2:
        if from_col - to_col == -1:
            knight_moves_tensor[from_row][from_col][4] = 1
        else:
            knight_moves_tensor[from_row][from_col][3] = 1
    elif from_row - to_row == -1:
        if from_col - to_col == -2:
            knight_moves_tensor[from_row][from_col][6] = 1
        else:
            knight_moves_tensor[from_row][from_col][1] = 1
    else:
        if from_col - to_col == -2:
            knight_moves_tensor[from_row][from_col][5] = 1
        else:
            knight_moves_tensor[from_row][from_col][2] = 1

    one_hot = torch.cat(
        (queen_moves_tensor, knight_moves_tensor, promotion_tensor),
        dim=2,
    )

    return torch.argmax(one_hot)


# Model creation

In [9]:
class ChessAI(nn.Module):
    def __init__(self):
        super(ChessAI, self).__init__()
        self.conv1 = nn.Conv2d(12, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 512, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(512)
        self.conv3 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(1024)
        self.fc1 = nn.Linear(1024 * 2 * 2, 2048)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(2048 + 3, 1024)  # Adding 3 for additional parameters
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(1024, 4864)

    def forward(self, x, params):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, (2, 2))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, (2, 2))
        x = F.relu(self.bn3(self.conv3(x)))
        x = x.view(-1, 1024 * 2 * 2)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = torch.cat((x, params), 1)  # Combine board representation with additional parameters
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x


class ChessAISmaller(nn.Module):
    def __init__(self):
        super(ChessAISmaller, self).__init__()
        self.conv1 = nn.Conv2d(12, 64, kernel_size=3, padding=1)
        self.ln1 = nn.LayerNorm([64, 8, 8])
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.ln2 = nn.LayerNorm([128, 4, 4])
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.ln3 = nn.LayerNorm([256, 2, 2])
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(256 + 3, 512)  # Adding 3 for additional parameters
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 4864)

    def forward(self, x, params):
        x = F.relu(self.ln1(self.conv1(x)))
        x = F.max_pool2d(x, (2, 2))

        x = F.relu(self.ln2(self.conv2(x)))
        x = F.max_pool2d(x, (2, 2))

        x = F.relu(self.ln3(self.conv3(x)))
        x = self.global_avg_pool(x).view(-1, 256)
        x = torch.cat((x, params), 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        torch.nn.init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0.01)

model = ChessAISmaller()
model.apply(init_weights)
model.to(device)

ChessAISmaller(
  (conv1): Conv2d(12, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (ln1): LayerNorm((64, 8, 8), eps=1e-05, elementwise_affine=True)
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (ln2): LayerNorm((128, 4, 4), eps=1e-05, elementwise_affine=True)
  (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (ln3): LayerNorm((256, 2, 2), eps=1e-05, elementwise_affine=True)
  (global_avg_pool): AdaptiveAvgPool2d(output_size=1)
  (fc1): Linear(in_features=259, out_features=512, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=512, out_features=4864, bias=True)
)

In [10]:
# print the total number of parameters
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')

3,018,688 total parameters.


In [11]:
#optimizer = optim.Adam(model.parameters(), lr=0.0001)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
loss_function = nn.CrossEntropyLoss()

# Data preparation

In [46]:
RATING = 1600
PART_OF_THIS_RATING = 1

RATING_RANGE = f"{RATING}-{RATING+400}{'' if not PART_OF_THIS_RATING else f'-{PART_OF_THIS_RATING}'}"
print(RATING_RANGE)

1600-2000-1


In [47]:
states = torch.load(f"{DATA_DIR}/states_tensors_{RATING_RANGE}.pt", map_location=device)
print(f"{states.size() = }")
states_consts = torch.load(f"{DATA_DIR}/states_consts_tensors_{RATING_RANGE}.pt", map_location=device)
print(f"{states_consts.size() = }")
moves = torch.load(f"{DATA_DIR}/moves_tensors_{RATING_RANGE}.pt", map_location=device)
print(f"{moves.size() = }")

states.size() = torch.Size([1132411, 8, 8, 12])
states_consts.size() = torch.Size([1132411, 3])
moves.size() = torch.Size([1132411])


In [48]:
class ChessDataset(Dataset):
    def __init__(self, states, states_consts, moves):
        self.states = states
        self.states_consts = states_consts
        self.moves = moves

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        # one_hot_move = torch.zeros(4864)
        # one_hot_move[self.moves[idx]] = 1
        state = self.states[idx].permute(2, 0, 1)  # Change to [12, 8, 8]
        state_const = self.states_consts[idx]
        move = self.moves[idx]
        return state, state_const, move

# Create dataset
chess_dataset = ChessDataset(states, states_consts, moves)

# DataLoader
dataloader = DataLoader(chess_dataset, batch_size=32, shuffle=True, drop_last=True)

# Train

In [49]:
EPOCHS = 4
STEP_TO_SAVE = 5000
STEP_TO_PRINT_LOSS = 2500

INITIAL_EPOCH = 4

In [15]:
wandb.init(
    # set the wandb project where this run will be logged
    project="chess_ai_2",

    # track hyperparameters and run metadata
    config={
        "learning_rate": optimizer.param_groups[0]["lr"],
        "architecture": "CNN",
        "dataset": "lichess 2020.11",
        "model": "ChessAISmaller"
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mrad1an[0m ([33mrad1an-personal[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [44]:
def save_model(model, optimizer, epoch="latest", batch="latest"):
    print("Saving model...", end=" ")
    torch.save(model.state_dict(), f"{SAVE_DIR}/model_parameters_{RATING_RANGE}_epoch{epoch}_batch{batch}.pth")
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'epoch': epoch,
        'batch': batch,
    }, f'{SAVE_DIR}/model_and_optimizer_{RATING_RANGE}_epoch{epoch}.pth')
    print("Done!")
    if batch % STEP_TO_SAVE != 0:
        batch = batch - batch%STEP_TO_SAVE + STEP_TO_SAVE
    if batch != "latest" and batch - STEP_TO_SAVE > 0:
        print(f"Trying to remove {SAVE_DIR}/model_parameters_{RATING_RANGE}_epoch{epoch}_batch{batch-STEP_TO_SAVE}.pth")
        try:
            os.remove(f"{SAVE_DIR}/model_parameters_{RATING_RANGE}_epoch{epoch}_batch{batch-STEP_TO_SAVE}.pth")
            print("Removed")
        except FileNotFoundError:
            print("File not found")
    if batch != "latest" and epoch > 1:
        print(f"Trying to remove {SAVE_DIR}/model_and_optimizer_{RATING_RANGE}_epoch{int(epoch)-1}.pth")
        try:
            os.remove(f'{SAVE_DIR}/model_and_optimizer_{RATING_RANGE}_epoch{int(epoch)-1}.pth')
            print("Removed")
        except FileNotFoundError:
            print("File not found")

def top_k_accuracy(output, target, k=10):
    _, pred = output.topk(k, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return correct[:k].float().sum(0, keepdim=True).mean().item()

In [None]:
wandb.log({"rating": RATING})
wandb.log({"learning_rate": optimizer.param_groups[0]['lr']})
for epoch in range(INITIAL_EPOCH, EPOCHS+INITIAL_EPOCH):
    wandb.log({"epoch": epoch})
    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        inputs, consts, labels = data[0].to(device), data[1].to(device), data[2].to(device)

        optimizer.zero_grad()

        outputs = model(inputs, consts)
        loss = loss_function(outputs, labels.long())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        top_1_acc = top_k_accuracy(outputs, labels, k=1)
        top_5_acc = top_k_accuracy(outputs, labels, k=5)

        wandb.log({"loss": loss.item(), "top_1_accuracy": top_1_acc, "top_5_accuracy": top_5_acc})

        running_loss += loss.item()

        if i % STEP_TO_PRINT_LOSS == 0 and i != 0:
            print(f"Epoch: {epoch: >2} | Batch: {i: >5} | loss: {running_loss / STEP_TO_PRINT_LOSS:.3f}")
            running_loss = 0.0

        if i % STEP_TO_SAVE == 0 and i != 0:
            save_model(model, optimizer, epoch, i)
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    wandb.log({"learning_rate": current_lr})
    save_model(model, optimizer, epoch=epoch, batch=i)
print("Finished Training")

Epoch:  4 | Batch:  2500 | loss: 3.506
Epoch:  4 | Batch:  5000 | loss: 3.505
Saving model... Done!
Trying to remove drive/MyDrive/Colab Notebooks/chess_ai/model_and_optimizer_1600-2000-1_epoch3.pth
Removed
Epoch:  4 | Batch:  7500 | loss: 3.502
Epoch:  4 | Batch: 10000 | loss: 3.492
Saving model... Done!
Trying to remove drive/MyDrive/Colab Notebooks/chess_ai/model_parameters_1600-2000-1_epoch4_batch5000.pth
Removed
Trying to remove drive/MyDrive/Colab Notebooks/chess_ai/model_and_optimizer_1600-2000-1_epoch3.pth
File not found
Epoch:  4 | Batch: 12500 | loss: 3.520
Epoch:  4 | Batch: 15000 | loss: 3.492
Saving model... Done!
Trying to remove drive/MyDrive/Colab Notebooks/chess_ai/model_parameters_1600-2000-1_epoch4_batch10000.pth
Removed
Trying to remove drive/MyDrive/Colab Notebooks/chess_ai/model_and_optimizer_1600-2000-1_epoch3.pth
File not found
Epoch:  4 | Batch: 17500 | loss: 3.483
Epoch:  4 | Batch: 20000 | loss: 3.490
Saving model... Done!
Trying to remove drive/MyDrive/Colab

In [None]:
save_model(model, optimizer)

Saving model... Done!


# To load saved model

In [22]:
# Load saved model
model = ChessAISmaller()
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
loss_function = nn.CrossEntropyLoss()

In [23]:
RATING_TO_LOAD = 800
PART_OF_RATING_TO_LOAD = 0
RATING_TO_LOAD_RANGE = f"{RATING_TO_LOAD}-{RATING_TO_LOAD+400}{'' if not PART_OF_RATING_TO_LOAD else f'-{PART_OF_RATING_TO_LOAD}'}"

EPOCH_TO_LOAD = 3

In [None]:
# Load saved state dict
BATCH_TO_LOAD = 15000
model.load_state_dict(torch.load(f"{SAVE_DIR}/model_parameters_{RATING_TO_LOAD_RANGE}_epoch{EPOCH_TO_LOAD}_batch{BATCH_TO_LOAD}.pth", map_location=device))

<All keys matched successfully>

In [24]:
# Load a model and the optimizer+scheduler state together
checkpoint = torch.load(f'{SAVE_DIR}/model_and_optimizer_{RATING_TO_LOAD_RANGE}_epoch{EPOCH_TO_LOAD}.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1, last_epoch=EPOCH_TO_LOAD-1) # if not saved with scheduler
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

# Use model

In [None]:
fen = input("Enter starting FEN: ")
board = chess.Board(fen=fen)
game = pgn.Game.from_board(board)

print(game.board())


def get_move():
    move_str = input("Enter move: ")
    if move_str == "q":
        exit()
    try:
        move = chess.Move.from_uci(move_str)
    except ValueError:
        print("Invalid move")
        return get_move()
    if move not in game.board().legal_moves:
        print("Illegal move")
        return get_move()
    return move


def get_bot_move():
    state, consts = game_to_tensor(game)
    state = (
        state.permute(2, 0, 1).unsqueeze(0).to(device)
    )  # Adds a batch dimension and moves to the device
    consts = consts.unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        raw_output = model(state, consts)
        probabilities = torch.softmax(raw_output, dim=1)
        # sort by probability and get the first legal move
        _, sorted_indices = torch.sort(probabilities, descending=True)
        for i in range(sorted_indices.size(1)):
            one_hot = torch.zeros(sorted_indices.size(1))
            one_hot[sorted_indices[0][i]] = 1
            move = tensor_to_move(one_hot, game.board().turn)
            print("Bot wants", move)
            if move in game.board().legal_moves:
                return move
            print("Bot move was illegal; trying next move")
    return None


def play_with_bot():
    global game
    while not game.board().is_game_over():
        move = get_bot_move()
        if not move:
            print("No legal moves found")
            break
        game.add_main_variation(move)
        game = game.next()
        if game.board().is_game_over():
            print("Game over")
            break
        print(game.board())
        print()

        move = get_move()
        game.add_main_variation(move)
        game = game.next()
        print(game.board())
        print()


def play_bot_with_bot(sleep_time=5):
    global game
    while not game.board().is_game_over():
        move = get_bot_move()
        if not move:
            print("No legal moves found")
            break
        game.add_main_variation(move)
        game = game.next()
        if game.board().is_game_over():
            print("Game over")
            break
        print(game.board())
        print()

        move = get_bot_move()
        if not move:
            print("No legal moves found")
            break
        game.add_main_variation(move)
        game = game.next()
        print(game.board())
        print()

Enter starting FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
P P P P P P P P
R N B Q K B N R


In [None]:
move = get_bot_move()
if not move:
    print("No legal moves found")
game.add_main_variation(move)
game = game.next()
if game.board().is_game_over():
    print("Game over")
game.board()

Bot wants 60 67


IndexError: list index out of range