In [1]:
!pip install python-chess

Collecting python-chess
  Downloading python_chess-1.999-py3-none-any.whl.metadata (776 bytes)
Collecting chess<2,>=1 (from python-chess)
  Downloading chess-1.11.2.tar.gz (6.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m36.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading python_chess-1.999-py3-none-any.whl (1.4 kB)
Building wheels for collected packages: chess
  Building wheel for chess (setup.py) ... [?25l[?25hdone
  Created wheel for chess: filename=chess-1.11.2-py3-none-any.whl size=147775 sha256=d86a0c56d5cf5b926488707b1f7108f9676f2f6db6584909e434516efc4f9bca
  Stored in directory: /root/.cache/pip/wheels/fb/5d/5c/59a62d8a695285e59ec9c1f66add6f8a9ac4152499a2be0113
Successfully built chess
Installing collected packages: chess, python-chess
Successfully installed chess-1.11.2 python-chess-1.999


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import chess

# define piece names
piece_names = {
  'p': 'black pawn',   'P': 'white pawn',
  'r': 'black rook',   'R': 'white rook',
  'n': 'black knight', 'N': 'white knight',
  'b': 'black bishop', 'B': 'white bishop',
  'q': 'black queen',  'Q': 'white queen',
  'k': 'black king',   'K': 'white king'
}

# convert FEN string to description of board state
def fenToDescription(fen_string):

  description = ""

  # load chess board with the FEN
  board = chess.Board(fen_string)

  isFirstPiece = True
  for square in chess.SQUARES:
      piece = board.piece_at(square)
      if piece:
          piece_name = piece_names[piece.symbol()]

          if isFirstPiece:
            piece_name = piece_name[0].upper() + piece_name[1:]
            isFirstPiece = False

          description += f"{piece_name} on {chess.square_name(square)}, "

  # replace comma at end with period
  description = description[:-2] + '.'
  return description


# testing with example FEN
fen_example = "r1bqk2r/pPp2pbp/2np2p1/4p3/2B1n3/5N2/1PPP1PPP/RNBQK2R w KQkq - 1 8"
print(fenToDescription(fen_example))

White rook on a1, white knight on b1, white bishop on c1, white queen on d1, white king on e1, white rook on h1, white pawn on b2, white pawn on c2, white pawn on d2, white pawn on f2, white pawn on g2, white pawn on h2, white knight on f3, white bishop on c4, black knight on e4, black pawn on e5, black knight on c6, black pawn on d6, black pawn on g6, black pawn on a7, white pawn on b7, black pawn on c7, black pawn on f7, black bishop on g7, black pawn on h7, black rook on a8, black bishop on c8, black queen on d8, black king on e8, black rook on h8.


In [4]:
# determine all legal moves in a board state
def generate_legal_moves(fen_string):
    board = chess.Board(fen_string)
    return [move.uci() for move in board.legal_moves]

# test
legal_moves = generate_legal_moves(fen_example)
print(legal_moves)

['c4f7', 'c4e6', 'c4a6', 'c4d5', 'c4b5', 'c4d3', 'c4b3', 'c4e2', 'c4a2', 'c4f1', 'f3g5', 'f3e5', 'f3h4', 'f3d4', 'f3g1', 'h1g1', 'h1f1', 'e1e2', 'e1f1', 'd1e2', 'b1c3', 'b1a3', 'a1a7', 'a1a6', 'a1a5', 'a1a4', 'a1a3', 'a1a2', 'e1g1', 'b7c8q', 'b7c8r', 'b7c8b', 'b7c8n', 'b7a8q', 'b7a8r', 'b7a8b', 'b7a8n', 'b7b8q', 'b7b8r', 'b7b8b', 'b7b8n', 'h2h3', 'g2g3', 'd2d3', 'c2c3', 'b2b3', 'h2h4', 'g2g4', 'd2d4', 'b2b4']


In [5]:
# convert move in chess notation into a textual description
def describe_move(fen_string, move_uci):
    board = chess.Board(fen_string)

    move = chess.Move.from_uci(move_uci)

    piece = board.piece_at(move.from_square)
    if not piece:
        return "Invalid move: no piece on source square."

    piece_names = {
        'p': 'pawn', 'n': 'knight', 'b': 'bishop',
        'r': 'rook', 'q': 'queen', 'k': 'king'
    }

    color = "White" if piece.color == chess.WHITE else "Black"
    name = piece_names[piece.symbol().lower()]
    from_sq = chess.square_name(move.from_square)
    to_sq = chess.square_name(move.to_square)

    # copy current board to apply the move
    board_copy = board.copy()
    board_copy.push(move)

    # castling
    if board.is_castling(move):
        side = "kingside" if move.to_square > move.from_square else "queenside"
        desc = f"{color} castles {side}"
    # captures
    elif board.is_capture(move):
        if board.is_en_passant(move):
            captured = 'pawn (en passant)'
        else:
            captured_piece = board.piece_at(move.to_square)
            captured = piece_names[captured_piece.symbol().lower()] if captured_piece else "unknown piece"
        desc = f"{color} {name} captures {captured} on {to_sq}"

        # if it's a promotion too
        if move.promotion:
            promo_name = piece_names[chess.Piece(move.promotion, piece.color).symbol().lower()]
            desc += f" and promotes to a {promo_name}"
    # promotions
    elif move.promotion:
        promo_name = piece_names[chess.Piece(move.promotion, piece.color).symbol().lower()]
        desc = f"{color} {name} moves from {from_sq} to {to_sq} and promotes to a {promo_name}"
    # normal move otherwise
    else:
        desc = f"{color} {name} moves from {from_sq} to {to_sq}"

    # checks or checkmates
    if board_copy.is_checkmate():
        desc += " and delivers checkmate"
    elif board_copy.is_check():
        desc += " and delivers check"

    return desc + '.'

for move in legal_moves:
  print(f"Move: {move}. Description: {describe_move(fen_example, move)}")

Move: c4f7. Description: White bishop captures pawn on f7 and delivers check.
Move: c4e6. Description: White bishop moves from c4 to e6.
Move: c4a6. Description: White bishop moves from c4 to a6.
Move: c4d5. Description: White bishop moves from c4 to d5.
Move: c4b5. Description: White bishop moves from c4 to b5.
Move: c4d3. Description: White bishop moves from c4 to d3.
Move: c4b3. Description: White bishop moves from c4 to b3.
Move: c4e2. Description: White bishop moves from c4 to e2.
Move: c4a2. Description: White bishop moves from c4 to a2.
Move: c4f1. Description: White bishop moves from c4 to f1.
Move: f3g5. Description: White knight moves from f3 to g5.
Move: f3e5. Description: White knight captures pawn on e5.
Move: f3h4. Description: White knight moves from f3 to h4.
Move: f3d4. Description: White knight moves from f3 to d4.
Move: f3g1. Description: White knight moves from f3 to g1.
Move: h1g1. Description: White rook moves from h1 to g1.
Move: h1f1. Description: White rook mov

In [6]:
# ONLY RUN THIS CELL FOR CHESS-SPECIFIC BERT

from transformers import AutoTokenizer, AutoModel
import torch

# load chess-specific BERT model
bert_path_trained = "/content/drive/MyDrive/chess-bert-mlm/final_model"

In [7]:
# DEFINING THE DRRN MODEL

"""Inspiration from Singh et. al. from https://arxiv.org/pdf/2107.08408.
Their DRRN framework: https://github.com/Exploration-Lab/IFG-Pretrained-LM/blob/main/dbert_drrn/model.py
My framework is inspired from them,  but simipler and modified for chess"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from transformers import BertTokenizer, BertModel
from transformers import AutoTokenizer, AutoModel

class DRRN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, lm_path="bert-base-uncased", custom_model=False):
        super(DRRN, self).__init__()

        # if not using custom model
        if custom_model == False:
            self.tokenizer = BertTokenizer.from_pretrained(lm_path)
            self.bert = BertModel.from_pretrained(lm_path)

        # if custom model, user different tokenizer and load model differently
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(lm_path)
            self.bert = AutoModel.from_pretrained(lm_path)

        self.bert.eval()
        self.state_proj = nn.Linear(self.bert.config.hidden_size, hidden_dim)
        self.action_proj = nn.Linear(self.bert.config.hidden_size, hidden_dim)
        self.q_proj = nn.Linear(hidden_dim, 1)

    # encode states and actions using the BERT [cls] token embedding
    def encode(self, texts):
        encoding = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        input_ids, attention_mask = encoding['input_ids'].to(device), encoding['attention_mask'].to(device)
        with torch.no_grad():
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

    def forward(self, state_batch, act_batch):
        state_texts = [s[0] for s in state_batch]
        state_vecs = self.state_proj(self.encode(state_texts))

        q_values = []
        for s_vec, acts in zip(state_vecs, act_batch):
            act_vecs = self.action_proj(self.encode(acts))
            scores = self.q_proj(torch.tanh(s_vec.unsqueeze(0) * act_vecs)).squeeze(-1)
            q_values.append(scores)
        return q_values

    # decide moves to play with greedy-epsilon
    def act(self, state_batch, act_batch, epsilon=0.1):
        q_values = self.forward(state_batch, act_batch)
        chosen_idxs = []
        for qv in q_values:
            if random.random() < epsilon:
                chosen_idxs.append(random.randint(0, len(qv) - 1))
            else:
                chosen_idxs.append(torch.argmax(qv).item())
        return chosen_idxs, q_values

In [8]:
# for simulating games. This function is a random move generator that is able to play against the model.
def apply_move(fen_string, move):
  board = chess.Board(fen_string)
  game_over = False
  result = None

  move_obj = chess.Move.from_uci(move)

  board.push(move_obj)

  # print(board)

  # check if game is over after model plays a move
  if board.is_game_over():
      game_over = True
      board_result = board.result()
      if board_result == "1/2-1/2":
        result = 0
        print("DRAW")
      else: # if a player won after the bot played a move, it must mean the bot performed checkmate
        result = 1
        print("BOT WINS") # Model wins

      return None, fen_string, game_over, result

  # get all legal moves in the current position
  legal_moves = list(board.legal_moves)

  # select a random move from the legal moves
  random_move = random.choice(legal_moves)

  # apply the random move to the board
  board.push(random_move)

  new_fen_state = board.fen()

  if board.is_game_over():
      game_over = True
      board_result = board.result()
      if board_result == "1/2-1/2":
        result = 0
        print("DRAW")

      # if a player won after the user (random move generator) played a move,
      # it must mean the user performed checkmate
      else:
        result = -1
        print("USER WINS")


  return str(random_move), new_fen_state, game_over, result

In [9]:
import chess
import chess.engine

# install Stockfish chess engine into Colab environment
!apt-get install stockfish

engine = chess.engine.SimpleEngine.popen_uci("/usr/games/stockfish")

def get_best_move(fen):
  board = chess.Board(fen)

  # Get Stockfish's best move
  result = engine.play(board, chess.engine.Limit(time=0.1))  # Play with a time limit (e.g., 2 seconds)
  return str(result.move)

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
Suggested packages:
  polyglot xboard | scid
The following NEW packages will be installed:
  stockfish
0 upgraded, 1 newly installed, 0 to remove and 34 not upgraded.
Need to get 24.8 MB of archives.
After this operation, 47.4 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 stockfish amd64 14.1-1 [24.8 MB]
Fetched 24.8 MB in 4s (6,056 kB/s)
Selecting previously unselected package stockfish.
(Reading database ... 126102 files and directories currently installed.)
Preparing to unpack .../stockfish_14.1-1_amd64.deb ...
Unpacking stockfish (14.1-1) ...
Setting up stockfish (14.1-1) ...
Processing triggers for man-db (2.10.2-1) ...


In [10]:
get_best_move("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1")

'c7c5'

In [11]:
# create a PGN string from a list of moves in a chess game
def move_list_to_pgn(move_list):
  # create a new game and get the board from it
  game = chess.pgn.Game()
  node = game

  board = game.board()

  # apply moves
  for uci_move in move_list:
      move = board.parse_uci(uci_move)
      board.push(move)
      node = node.add_variation(move)

  # return pgn
  return(game)


def move_to_pgn_with_comments(move_comment_list, result):
  # convert result int into string
  result_str = ""
  if result == 1:
    result_str = "1-0"
  elif result == 0:
    result_str = "1/2-1/2"
  else:
    result_str = "0-1"

  # create game and get board
  game = chess.pgn.Game()
  node = game
  board = game.board()

  # add moves with comments
  for uci, comment in move_comment_list:
      move = board.parse_uci(uci)
      board.push(move)
      node = node.add_variation(move)
      node.comment = comment

  game.headers["White"] = "BOT"
  game.headers["Black"] = "RANDOM"
  game.headers["Result"] = result_str

  # return pgn
  return(game)


import os
os.makedirs("/content/drive/MyDrive/chess_training", exist_ok=True)

def add_game_to_pgn_file(game_pgn):
  with open("many_games_trained.pgn", "a") as f:
      f.write(str(game_pgn) + "\n\n")

  # also add/update file to drive to make sure data is not lost
  # drive_path = "/content/drive/MyDrive/chess_training/many_games_vanilla.pgn" # ONLY USE FOR BERT-BASE
  drive_path = "/content/drive/MyDrive/chess_training/many_games_trained.pgn" # ONLY USE FOR CHESS-PRETRAINED MODEL


  with open(drive_path, "a") as f:
      f.write(str(game_pgn) + "\n\n")

In [12]:
# functions to continually save model while training

import os

# checkpoint_dir = "/content/drive/MyDrive/checkpoints_DRRN/base"  # for vanilla BERT-base model
checkpoint_dir = "/content/drive/MyDrive/checkpoints_DRRN/chess"  # for chess-specific BERT model
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, "drrn_checkpoint.pth")

def save_checkpoint(model, optimizer, episode, path):
    torch.save({
        'episode': episode,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)
    print(f"Checkpoint saved to {path}")

def load_checkpoint(model, optimizer, path):
    if os.path.exists(path):
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_episode = checkpoint['episode'] + 1
        print(f"Loaded checkpoint from episode {checkpoint['episode']}")
    else:
        start_episode = 0
        print("No checkpoint found. Starting from scratch.")
    return start_episode

In [None]:
import chess
import chess.engine
import torch
import random
import chess.pgn

# USE THIS ONE FOR BERT-BASE
# drrn = DRRN(vocab_size=None, embedding_dim=256, hidden_dim=128).to(device)

# USE THIS ONE FOR THE PRETRAINED BERT ON CHESS TEXTS
drrn = DRRN(vocab_size=None, embedding_dim=256, hidden_dim=128, lm_path=bert_path_trained, custom_model=True).to(device)


optimizer = torch.optim.Adam(drrn.parameters(), lr=1e-4)
engine = chess.engine.SimpleEngine.popen_uci("/usr/games/stockfish")

# load saved model if we are continuing to train it
start_episode = load_checkpoint(drrn, optimizer, checkpoint_path)

def get_best_move(fen):
    board = chess.Board(fen)
    result = engine.play(board, chess.engine.Limit(time=0.1))  # give Stockfish 0.1 seconds to make a move
    return result.move.uci()

def compute_reward(recorded_state, chosen_move, best_move):
  # evaluate position before move
  board = chess.Board(recorded_state)
  info_before = engine.analyse(board, chess.engine.Limit(time=0.1))
  score_before = info_before["score"].white().score(mate_score=1000)  # get score in centipawns
  if abs(score_before) > 1000:
    score_before = (1000 / abs(score_before)) * score_before  # max score is +/- 10.0

  # play chosen move
  move = chess.Move.from_uci(chosen_move)
  board.push(move)

  # evaluate position after move
  info_after = engine.analyse(board, chess.engine.Limit(time=0.1))
  score_after = info_after["score"].white().score(mate_score=1000)
  if abs(score_after) > 1000:
    score_after = (1000 / abs(score_after)) * score_after  # max score is +/- 10.0

  # calculate centipawn loss
  cpl = (score_before - score_after) if board.turn == chess.BLACK else (score_before - score_after)
  cpl = max(cpl, 0)  # only include loss if move worsened position


  # if centipawn loss is over 100, we want to start penalizing
  reward = -(cpl - 100) / (100 * 20)

  # increase reward extra of cps is under 100
  if cpl < 100:
    reward = 0.5 - .005 * cpl

  # so reward ranges from -1 to 0.5 b/c max cpl is 2000

  return(reward)


# TRAINING
game_results = []

for episode in range(300):  # number of games to train on
    game_memory = []
    state = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
    done = False

    game_moves = []
    while not done:
        state_desc = fenToDescription(state)
        legal_uci_moves = generate_legal_moves(state)
        move_descs = [describe_move(state, move) for move in legal_uci_moves]

        state_batch = [(state_desc, "", "", "")]
        act_batch = [move_descs]

        # the model chooses a move
        chosen_idx, _ = drrn.act(state_batch, act_batch, epsilon=0.2)
        chosen_move = legal_uci_moves[chosen_idx[0]]
        game_moves.append(chosen_move)

        # store the move and the state for later comparison with Stockfish evaluations
        game_memory.append((state_desc, legal_uci_moves, move_descs, chosen_idx[0], state))

        # apply the move to the chess board
        opp_move, state, done, result = apply_move(state, chosen_move)
        if not (opp_move == None):
          game_moves.append(opp_move)

    game_pgn = move_list_to_pgn(game_moves)

    final_reward = result
    game_results.append(result)

    # after the game ends, get Stockfish's evaluations for each recorded position
    moves_with_comments = []
    for i in range(len(game_memory)):
        state_desc, legal_uci_moves, move_descs, chosen_idx, recorded_state = game_memory[i]

        # get Stockfish's best move for the recorded state
        best_move = get_best_move(recorded_state)

        # get the move played by the model
        chosen_move = legal_uci_moves[chosen_idx]

        # calculate the reward based on the model's move compared to Stockfish evaluations
        reward = compute_reward(recorded_state, chosen_move, best_move)

        moves_with_comments.append((game_moves[2 * i], f"Best move: {best_move}. Reward: {reward}."))

        # update memory with the reward
        game_memory[i] = (*game_memory[i][:-1], reward)

        # add move from opponent
        if len(game_moves) > 2 * i + 1:
          moves_with_comments.append((game_moves[2 * i + 1], ""))


    # get PGN of the chess game and add it to our PGN file
    pgn_with_comments = move_to_pgn_with_comments(moves_with_comments, result)
    add_game_to_pgn_file(pgn_with_comments)

    # update the model with the rewards for all moves in the game
    for state_desc, legal_uci_moves, move_descs, chosen_idx, reward in game_memory:
        state_batch = [(state_desc, "", "", "")]
        act_batch = [move_descs]

        q_values = drrn(state_batch, act_batch)[0]
        q_value = q_values[chosen_idx]

        # calculate loss and backpropagate
        combined_reward = 0.5 * reward + 0.5 * final_reward
        loss = (q_value - combined_reward) ** 2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Episode {episode} done. Reward: {final_reward}")

    # save after every episode
    save_checkpoint(drrn, optimizer, episode, checkpoint_path)

Some weights of BertModel were not initialized from the model checkpoint at /content/drive/MyDrive/chess-bert-mlm/final_model and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded checkpoint from episode 127
DRAW
Episode 0 done. Reward: 0
Checkpoint saved to /content/drive/MyDrive/checkpoints_DRRN/chess/drrn_checkpoint.pth
DRAW
Episode 1 done. Reward: 0
Checkpoint saved to /content/drive/MyDrive/checkpoints_DRRN/chess/drrn_checkpoint.pth
DRAW
Episode 2 done. Reward: 0
Checkpoint saved to /content/drive/MyDrive/checkpoints_DRRN/chess/drrn_checkpoint.pth
DRAW
Episode 3 done. Reward: 0
Checkpoint saved to /content/drive/MyDrive/checkpoints_DRRN/chess/drrn_checkpoint.pth
BOT WINS
Episode 4 done. Reward: 1
Checkpoint saved to /content/drive/MyDrive/checkpoints_DRRN/chess/drrn_checkpoint.pth
DRAW
Episode 5 done. Reward: 0
Checkpoint saved to /content/drive/MyDrive/checkpoints_DRRN/chess/drrn_checkpoint.pth
DRAW
Episode 6 done. Reward: 0
Checkpoint saved to /content/drive/MyDrive/checkpoints_DRRN/chess/drrn_checkpoint.pth
DRAW
Episode 7 done. Reward: 0
Checkpoint saved to /content/drive/MyDrive/checkpoints_DRRN/chess/drrn_checkpoint.pth
DRAW
Episode 8 done. Rewa