In [None]:
from src.utils import CfgNode as CN
from src.valid_chess import GPT_valid_chess
import json
import torch
from safetensors.torch import load_model
import re
import chess
from stockfish import Stockfish
from IPython.display import display, clear_output
import pickle

# Path of the JSON file that stores the model hyper-parameters.
conf = 'config.json'

# Empty configuration node that is filled in below and later handed to the model.
C = CN()

with open(conf) as f:
    config = json.load(f)

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a meta.pkl that comes from a trusted source.
with open("meta.pkl", "rb") as f:
    meta = pickle.load(f)

# Copy the architecture and dropout settings from the JSON config onto the node.
for _key in ("n_layer", "n_embd", "n_head", "embd_pdrop", "resid_pdrop", "attn_pdrop"):
    setattr(C, _key, config[_key])

# The context length is fixed here rather than read from the config file.
C.block_size = 768

# Character-level vocabulary: a token id is the position of the entry in this list.
# '[PAD]' is a multi-character entry, so `encode` (which walks a string one
# character at a time) never produces its id.
chars = ['0', '5', '2', 'Q', 'e', 'a', 'R', 'd', '9', '8', '1', 'N', 'x', 'f', '6', '+', 'c', '=', 'h', 'O', 'B', '.', '7', '/', '4', '3', '*', 'b', '-', ' ', 'K', 'g', '#', ';', '[PAD]']

# Lookup tables between vocabulary entries and integer token ids.
stoi = dict(zip(chars, range(len(chars))))
itos = dict(enumerate(chars))


def encode(s):
    """Encoder: map every character of the string ``s`` to its token id.

    Raises ``KeyError`` for a character that is not in the vocabulary.
    """
    token_ids = []
    for symbol in s:
        token_ids.append(stoi[symbol])
    return token_ids


def decode(l):
    """Decoder: concatenate the vocabulary entries for the token ids in ``l``."""
    return ''.join(itos[token_id] for token_id in l)

# Vocabulary size stored on the config node: one id per entry of the lookup table.
C.vocab_size = len(stoi)

# Use the GPU when PyTorch reports one as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Expose the character list to the model through the config node.
C.vocabulary = chars

# Build the network, call .eval() on it, then load the stored weights from disk.
MODEL_WEIGHTS = "models/stockfish-16/model.safetensors"
model = GPT_valid_chess(C).eval()
load_model(model, MODEL_WEIGHTS)

In [None]:
def _encode_history(history, device):
    """Encode the history string as a ``(1, len(history))`` int64 tensor on ``device``."""
    return torch.tensor(encode(history), dtype=torch.int64).unsqueeze(0).to(device)


def _extract_move(output, round_number):
    """Decode generated token ids and pull out the move written after ``"<round_number>."``.

    Returns the first run of move characters that follows the first occurrence
    of the round marker, or ``""`` when the marker is not present in the decoded
    text. The empty string is not valid SAN, so callers that parse the move see
    an ordinary invalid move instead of an ``AttributeError`` on a failed match.
    """
    decoded_text = decode(output.squeeze().tolist())
    found = re.search(fr"{round_number}\.([a-zA-Z0-9-\+\#\=]*)", decoded_text)
    if found is None:
        return ""
    return found.group(1).strip()


def get_next_llm(history, round_number, device="cpu"):
    """Sample a continuation of ``history`` and return the move for ``round_number``.

    Uses the notebook-global ``model``. The ``10`` passed to ``generate`` is
    presumably the number of new tokens to sample — TODO confirm against the model.

    Returns:
        The move text as a string, or ``""`` if no move could be extracted.
    """
    output = model.generate(_encode_history(history, device), 10, do_sample=True, temperature=0.1)
    return _extract_move(output, round_number)


def get_next_llm_valid_chess(history, board, round_number, device="cpu"):
    """Like ``get_next_llm`` but generates through ``model.generate_valid_chess_move``.

    ``board`` is forwarded to the model unchanged; presumably it is used to
    restrict sampling to legal moves — TODO confirm against the model.

    Returns:
        The move text as a string, or ``""`` if no move could be extracted.
    """
    output = model.generate_valid_chess_move(
        _encode_history(history, device), board, do_sample=True, temperature=0.1
    )
    return _extract_move(output, round_number)


def get_next_llm_kv(history, round_number, kv_cache, device="cpu"):
    """Like ``get_next_llm`` but passes ``kv_cache`` to the model and returns the new cache.

    Returns:
        A ``(move_text, kv_cache)`` tuple; ``move_text`` is ``""`` if no move
        could be extracted, and ``kv_cache`` is whatever the model returned.
    """
    output, kv_cache = model.generate(
        _encode_history(history, device),
        10,
        do_sample=True,
        temperature=0.1,
        kv_cache=kv_cache,
        return_kv_cache=True,
    )
    return _extract_move(output, round_number), kv_cache


def get_next_llm_kv_valid_chess(history, board, round_number, kv_cache, device="cpu"):
    """Like ``get_next_llm_valid_chess`` but passes ``kv_cache`` and returns the new cache.

    Returns:
        A ``(move_text, kv_cache)`` tuple; ``move_text`` is ``""`` if no move
        could be extracted, and ``kv_cache`` is whatever the model returned.
    """
    output, kv_cache = model.generate_valid_chess_move(
        _encode_history(history, device),
        board,
        do_sample=True,
        temperature=0.1,
        kv_cache=kv_cache,
        return_kv_cache=True,
    )
    return _extract_move(output, round_number), kv_cache
    

In [None]:
class ChessFight:
    """Run a game in which the language model and Stockfish alternate moves.

    In every round the model moves first on a board that starts from the
    standard position (so it plays the white pieces) and Stockfish replies.
    The move text accumulated in ``self.history`` is the prompt given to the
    model on the next round.

    Result codes shared by ``check_win`` and ``play_round``:
        -1  the model's move could not be parsed (``play_round`` only)
         0  the game continues
         1  the game ended in a draw (stalemate or another drawn ending)
         2  Stockfish delivered checkmate
         3  the model delivered checkmate

    NOTE(review): moves are generated by the module-level ``get_next_llm*``
    helpers, which use the notebook-global ``model`` rather than ``self.model``;
    make sure both refer to the same object.
    """

    def __init__(self, model, elo=1500, with_kv_cache=False, with_valid_move=False, device="cpu",
                 stockfish_path="/usr/games/stockfish"):
        """Prepare the model, the board and the Stockfish engine.

        Args:
            model: network to put in eval mode and move to ``device``.
            elo: value sent to Stockfish as the ``UCI_Elo`` option.
            with_kv_cache: use the KV-cache variants of the generation helpers.
            with_valid_move: use the ``generate_valid_chess_move`` variants.
            device: torch device string used for the model and the prompts.
            stockfish_path: location of the Stockfish executable.
        """
        self.device = device
        self.model = model
        self.model.eval()
        self.model.to(device)
        # Prompt prefix followed by the marker of the first round.
        # NOTE(review): ";1-0#" presumably conditions the model on a white win — confirm.
        self.history = ";1-0#1."
        self.board = chess.Board()
        self.stockfish = Stockfish(
            path=stockfish_path,
            depth=3,
            parameters={
                "Threads": 2,
                "Minimum Thinking Time": 30,
                # Stockfish only honours UCI_Elo while UCI_LimitStrength is on.
                # NOTE(review): confirm the installed wrapper enables it when
                # UCI_Elo is supplied, and that `elo` is within the engine's range.
                "UCI_Elo": elo,
                },
            )
        self.round_number = 1
        self.play = True
        self.with_kv_cache = with_kv_cache
        self.with_valid_move = with_valid_move
        self.kv_cache = None

    def parse_move(self, board_move):
        """Parse SAN text against the current board; return the move or ``None`` on failure."""
        try:
            next_move = self.board.parse_san(board_move)
        except Exception as e:
            # Best effort: any parsing problem is reported and treated as an invalid move.
            print("Error parsing move:", e)
            return None
        return next_move

    def check_win(self, player="stockfish"):
        """Report whether the game is over after ``player`` has just moved.

        Args:
            player: ``"LLM"`` or ``"stockfish"``, the side that made the last move.

        Returns:
            3 if the model mated, 2 if Stockfish mated, 1 for a draw, else 0.
        """
        if self.board.is_checkmate():
            if player == "LLM":
                print("Checkmate! LLM wins!")
                return 3
            if player == "stockfish":
                print("Checkmate! Stockfish wins!")
                return 2
            # Unknown player label: keep reporting "no result".
            return 0
        if self.board.is_stalemate():
            print("Stalemate!")
            return 1
        # Remaining automatic endings (insufficient material, seventy-five-move
        # rule, fivefold repetition) are draws; without this the loop would
        # keep playing a finished game.
        if self.board.is_game_over():
            print("Draw!")
            return 1
        return 0

    def update_history(self, board_move, player="LLM"):
        """Append a played move to the prompt text.

        The model's move is followed by a space; Stockfish's move is followed
        by a space and the marker of the next round (``"<n+1>."``).
        """
        if player == "LLM":
            self.history += board_move + " "
        else:
            self.history += board_move + " " + str(self.round_number+1) + "."

    def play_round(self):
        """Play one model move and one Stockfish reply.

        Returns:
            -1 if the model's move is invalid (``self.play`` is set to False),
            a positive ``check_win`` code if the game ended, otherwise 0.
        """
        print("Round number:",self.round_number)

        if self.with_kv_cache:
            if self.with_valid_move:
                move_llm_txt, new_kv_cache = get_next_llm_kv_valid_chess(self.history, self.board, self.round_number, self.kv_cache, device=self.device)
            else:
                move_llm_txt, new_kv_cache = get_next_llm_kv(self.history, self.round_number, self.kv_cache, device=self.device)
            # Keep only the leading part of the returned cache: the previous
            # cached length plus the characters of the extracted move.
            # NOTE(review): this assumes dim 1 of each cache tensor is the
            # sequence axis, and the cache is stored before the move has been
            # validated — confirm both against the model's cache layout.
            new_len = self.kv_cache[0][0].size(1)+len(move_llm_txt) if self.kv_cache else len(move_llm_txt)
            self.kv_cache = [(elt[0][:,:new_len,:], elt[1][:,:new_len,:]) for elt in new_kv_cache]
        else:
            if self.with_valid_move:
                move_llm_txt = get_next_llm_valid_chess(self.history, self.board, self.round_number, device=self.device)
            else:
                move_llm_txt = get_next_llm(self.history, self.round_number, device=self.device)

        move_llm = self.parse_move(move_llm_txt)
        if not move_llm:
            self.play = False
            print("Invalid move from LLM")
            return -1

        # Record the canonical SAN of the move (computed before it is pushed).
        self.update_history(self.board.san(move_llm), player="LLM")
        self.board.push(move_llm)

        game_val = self.check_win(player="LLM")
        if game_val > 0:
            self.play = False
            return game_val

        # The Stockfish wrapper works with UCI strings, so pass one explicitly
        # instead of relying on the Move object's string conversion.
        self.stockfish.make_moves_from_current_position([move_llm.uci()])
        best = self.stockfish.get_best_move()
        self.stockfish.make_moves_from_current_position([best])
        move_stockfish = chess.Move.from_uci(best)
        self.update_history(self.board.san(move_stockfish), player="stockfish")
        self.board.push(move_stockfish)

        game_val = self.check_win(player="stockfish")
        if game_val > 0:
            self.play = False
            return game_val

        return 0

    def start(self):
        """Play rounds until the game ends or the model keeps producing invalid moves.

        Up to five retries (shared across the whole game) are granted when the
        model's move is invalid. The board is displayed after every round.
        """
        count = 5
        while self.play:
            val = self.play_round()
            while val == -1 and count > 0:
                print("Invalid move from LLM, retrying...")
                count -= 1
                # play_round cleared the flag when it rejected the move; restore
                # it so that a successful retry lets the game carry on.
                self.play = True
                val = self.play_round()
            if val != -1:
                # Only advance when a move was actually played, so the round
                # marker at the end of the history stays in sync.
                self.round_number += 1
            # time.sleep(0.5)
            if val == 0:
                # Refresh the output between rounds; keep the final messages
                # (result or invalid-move report) visible when the game stops.
                clear_output()
            display(self.board)

In [None]:
# Set up a game that uses both the KV cache and the valid-move generation
# path, then run the game loop.
chess_fight = ChessFight(
    model,
    elo=1200,
    with_kv_cache=True,
    with_valid_move=True,
    device=device,
)
chess_fight.start()