In [66]:
from transformer import GPT
from utils import CfgNode as CN
import json
import torch
from safetensors.torch import load_model
import re
import chess
from stockfish import Stockfish
from IPython.display import display, clear_output
import pickle
import time

In [None]:
# Paths to the model hyper-parameter config (JSON) and the tokenizer metadata (pickle).
conf = 'config.json'
meta_path = 'meta.pkl'

# Context length handed to the GPT config.
# NOTE(review): hard-coded instead of read from config.json — confirm it matches
# the block_size the checkpoint was trained with.
BLOCK_SIZE = 768

C = CN()

with open(conf) as f:
    config = json.load(f)

# SECURITY: pickle.load can execute arbitrary code; only load a meta.pkl you
# produced yourself.
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)

# Copy the architecture / dropout hyper-parameters verbatim from the config file.
for key in ("n_layer", "n_embd", "n_head", "embd_pdrop", "resid_pdrop", "attn_pdrop"):
    setattr(C, key, config[key])
C.block_size = BLOCK_SIZE
C.vocab_size = meta['vocab_size']

# Character -> token-id and token-id -> character tables used by encode/decode.
stoi = meta['stoi']
itos = meta['itos']

def encode(s):
    """Convert a string into a list of token ids.

    Every character of ``s`` is looked up in the module-level ``stoi`` table;
    an unknown character raises ``KeyError``.
    """
    token_ids = []
    for ch in s:
        token_ids.append(stoi[ch])
    return token_ids
def decode(l):
    """Convert a sequence of token ids back into a string.

    Each id is looked up in the module-level ``itos`` table and the resulting
    characters are concatenated (intended as the inverse of ``encode``).
    """
    chars = (itos[token_id] for token_id in l)
    return "".join(chars)

In [None]:
# Build the GPT from the config assembled above, switch it to inference mode,
# then load the trained weights from the safetensors checkpoint.
model = GPT(C)
model.eval()
load_model(model, "out/final.pt/model.safetensors")

In [68]:
def get_next_llm(history, round_number, max_new_tokens=10):
    """Sample the LLM's move for ``round_number`` given the game so far.

    ``history`` is encoded with ``encode``, extended by ``model.generate`` with
    ``max_new_tokens`` sampled tokens, decoded with ``decode``, and the move
    written right after ``"{round_number}."`` is extracted.

    Args:
        history: Game record so far (the notebook's loop ends it with
            ``"{round_number}."`` so the model continues with the move).
        round_number: Move number whose move should be extracted.
        max_new_tokens: Number of tokens to sample (default 10).

    Returns:
        The move text (e.g. ``"e4"``), or ``""`` if the sampled text holds no
        complete move for this round (e.g. the token budget ran out before
        the terminating space). ``""`` is not valid SAN, so callers can treat
        it like any other illegal move and resample.
    """
    prompt = torch.tensor(encode(history), dtype=torch.int64).unsqueeze(0)
    # Pure inference: no autograd graph is needed while sampling.
    with torch.no_grad():
        output = model.generate(prompt, max_new_tokens, do_sample=True).squeeze()
    decoded_text = decode(output.tolist())
    # (?<!\d) stops "1." from matching inside "11."; \S+ rejects an empty move;
    # the trailing space ensures the move was not cut off by the token budget.
    match = re.search(rf"(?<!\d){round_number}\.(\S+) ", decoded_text)
    if match is None:
        return ""
    return match.group(1)

def get_next_stockfish(next_move):
    """Play the LLM's move on the shared board, then answer it with Stockfish.

    Operates on the module-level ``board`` (python-chess) and ``stockfish``
    (engine wrapper) objects and keeps both positions in sync.

    Args:
        next_move: The LLM's move in SAN, e.g. ``"Nf3"``.

    Returns:
        ``(next_move, stockfish_san)``. ``stockfish_san`` is Stockfish's reply
        in SAN, or ``None`` if ``next_move`` could not be parsed as a legal
        move on ``board`` (in that case neither ``board`` nor ``stockfish``
        is changed).

    Raises:
        RuntimeError: if ``next_move`` is legal but Stockfish has no reply
            (``get_best_move()`` returned ``None``, e.g. the LLM delivered
            checkmate or stalemate). The LLM's move has already been applied
            to ``board`` and ``stockfish`` at that point.
    """
    try:
        board_move = board.parse_san(next_move)
    except ValueError as err:
        # python-chess reports invalid, illegal and ambiguous SAN with
        # ValueError subclasses; any other exception is a real bug and propagates.
        print(f"Illegal move {err=}, {type(err)=}")
        print(f"{next_move} is not correct")
        return (next_move, None)

    board.push(board_move)
    # The stockfish wrapper takes moves as UCI strings (e.g. "e2e4"),
    # the same format get_best_move() returns.
    stockfish.make_moves_from_current_position([board_move.uci()])

    best = stockfish.get_best_move()
    if best is None:
        raise RuntimeError(
            f"Stockfish has no reply after {next_move}; game over ({board.result()})"
        )
    stockfish.make_moves_from_current_position([best])
    reply = chess.Move.from_uci(best)
    # SAN must be computed before pushing, while the move is still legal.
    stockfish_san = board.san(reply)
    board.push(reply)
    clear_output()
    return (next_move, stockfish_san)
    

In [69]:
class ChessFight:
    """State for one LLM-vs-Stockfish game: model, prompt history, board and engine.

    NOTE(review): nothing visible in this notebook instantiates this class;
    the game loop uses module-level ``board`` / ``stockfish`` objects instead.
    """

    def __init__(self, model, elo=1500, stockfish_path="/usr/games/stockfish", depth=3):
        """Create a fresh game.

        Args:
            model: Language model that will choose moves (stored as-is).
            elo: Value passed to Stockfish as the ``UCI_Elo`` parameter.
            stockfish_path: Filesystem path of the Stockfish binary.
            depth: Search depth handed to the Stockfish wrapper.
        """
        self.model = model
        # Prompt text, starting at move 1.
        # NOTE(review): the notebook's game loop starts its prompt with
        # ";3#1-0#1." instead — confirm which prefix the model expects.
        self.history = "1."
        self.board = chess.Board()
        self.stockfish = Stockfish(
            path=stockfish_path,
            depth=depth,
            parameters={
                "Threads": 2,
                "Minimum Thinking Time": 30,
                "UCI_Elo": elo,
            },
        )

In [None]:
# One game: the LLM moves first (White) on a fresh board, Stockfish replies.
board = chess.Board()

elo = 1500
stockfish = Stockfish(
    path="/usr/games/stockfish",
    depth=3,
    parameters={
        "Threads": 2,
        "Minimum Thinking Time": 30,
        "UCI_Elo": elo,
    },
)
# NOTE(review): the ";3#1-0#" prefix is presumably the header format the model
# was trained on; its meaning is not visible here — confirm.
history = ";3#1-0#1."
round_number = 1
play = True
# Stop instead of resampling forever when the LLM keeps producing unusable moves.
max_attempts_per_move = 20

while play:
    move_stockfish = None
    for _attempt in range(max_attempts_per_move):
        next_move = get_next_llm(history, round_number)
        _, move_stockfish = get_next_stockfish(next_move)
        if move_stockfish:
            break
    else:
        print(f"No legal LLM move after {max_attempts_per_move} attempts; stopping.")
        break
    # Ends on Stockfish's checkmate and also on stalemate, insufficient material,
    # the 75-move rule and fivefold repetition, which a "#"-suffix check misses.
    # NOTE(review): if the LLM's own move ends the game, get_next_stockfish
    # raises before returning, so that outcome is not handled here.
    if board.is_game_over():
        play = False
    round_number += 1
    history = f"{history}{next_move} {move_stockfish} {round_number}."
    display(board)
    time.sleep(0.5)

print(history)