In [23]:
import chess
import chess.pgn
import numpy as np
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import sys

In [24]:
class result_prediction(nn.Module):
    def __init__(self):
        super().__init__()
    
        self.both = nn.Sequential(
            nn.Linear(12 * 64 + 1, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, x1):
        x = torch.flatten( x1)
        return self.both(x)

def to_numpy(board):
    b = np.zeros((64)) 
    s = str(board)

    i = 0
    
    for x in s:
        
        if x == "K":
            b[i] = 1
        if x == "Q":
            b[i] = 2
        if x == "R":
            b[i] = 3
        if x == "B":
            b[i] = 4
        if x == "N":
            b[i] = 5
        if x == "P":
            b[i] = 6

        if x == "k":
            b[i] = 7
        if x == "q":
            b[i] = 8
        if x == "r":
            b[i] = 9
        if x == "b":
            b[i] = 10
        if x == "n":
            b[i] = 11
        if x == "p":
            b[i] = 12

        if x != "\n" and x != " ":    
            i = i + 1
    return b

def to_binary(bo, white, res):
    side_to_move = np.zeros(( 769 ))

    a = [1,2,3,4,5,6]
    b = [7,8,9,10,11,12]

    for x in range(0, len(a)):
        for y in range(64):
            side_to_move[x*64+y] = bo[y] == a[x]

    for x in range(0, len(b)):
        for y in range(64):
            side_to_move[(x+6)*64+y] = bo[y] == a[x]

    side_to_move[768] = white
    return res, side_to_move
def nth(i, x): 
    for a,b in enumerate(x):
        if a == i:
            return b

def count(i):
    return sum(1 for e in i)

def make_x_random_moves(n,board):
    b = board.copy()

    for x in range(n):
        lm = b.legal_moves
        c = lm.count()
        if c == 0:
            break
        r = random.randint(0,c)
        move = nth(r, lm) 
        move = chess.Move.from_uci(str(move)) 
        b.push(move)
    
    return b

pgn = open("datasets/lichess_games.pgn", encoding="utf-8")

In [25]:
device = "cuda:0"
model = result_prediction()
model.load_state_dict(torch.load("self_play.pth", weights_only=True, map_location=torch.device(device)))
model.eval()
model.to(device)

sd = "1/2-1/2"
sl = "0-1"
sw = "1-0"

In [26]:
for x in range(10):
    game = chess.pgn.read_game(pgn)
    board = game.board()


    total_move_num = count(game.mainline_moves())
    whitetomove = True

    for i, move in enumerate(game.mainline_moves()):
        board.push(move)
        whitetomove = whitetomove == False
        if i >= total_move_num-1:
            b = to_numpy(board)
            result = None
            if game.headers["Result"] == sd:
                result = 0
            if game.headers["Result"] == sl:
                result = -1
            if game.headers["Result"] == sw:
                result = 1

            y, stm = to_binary(b, whitetomove, result)

            stm = torch.tensor(stm).to(device).float()
            print(round(model(stm).item(),2))
            print(result)
            print()

0.15
0

0.84
-1

-0.44
0

-1.08
0

-1.43
0

-1.17
-1

0.56
0

0.69
0

-0.99
0

-0.23
0



In [27]:
from cairosvg import svg2png
import cv2

out = cv2.VideoWriter('project.mp4', cv2.VideoWriter_fourcc(*'mp4v') , 15, (390,390))
 
board = chess.Board()
whitetomove = True
counter = 0
while not board.is_checkmate() or not board.is_stalemate():
    lm = [ str(x) for x in board.legal_moves ]
    moves = []
    if len(lm) == 0:
        break

    for x in lm:
        bo = board.copy()
        bo.push_san(x)
        b = to_numpy(board)

        y, stm = to_binary(b, whitetomove, 0)
        stm = torch.tensor(stm).to(device).float()

        moves.append(model(stm).item())
    
    whitetomove = whitetomove == False    

    if random.randint(0,10) < 5:
        if whitetomove:
            idx = moves.index(min(moves))
        else:
            idx = moves.index(max(moves))

        print(moves[idx])
        board.push_san(lm[idx])
    else:    
        board.push_san(lm[random.randint(0,len(lm)-1)])
    
    boardsvg = chess.svg.board(board=board)
    svg2png(bytestring=boardsvg, write_to='temp.jpg')
    png = cv2.imread('temp.jpg') 
    
    for x in range(5):
        out.write(png)

    counter += 1
    if counter == 300:
        break


print(counter)
print(board.is_stalemate())
print(board.is_checkmate())

out.release()

-0.7082256078720093
-0.0810806155204773
0.3677518367767334
-0.32326987385749817
-0.3949330449104309
-0.11798325181007385
-0.19963878393173218
-0.32326987385749817
-0.3949330449104309
-0.19963878393173218
-0.8361948728561401
-0.9968838691711426
-1.140926718711853
-0.7236970663070679
-0.8361948728561401
-0.9968838691711426
-1.140926718711853
0.43221986293792725
0.27382099628448486
1.139859914779663
0.9747387170791626
0.9204158782958984
0.7603559494018555
1.139859914779663
0.1792927384376526
-0.0494999885559082
-0.7741411328315735
-0.5839671492576599
0.4325202703475952
0.11784827709197998
-0.15200692415237427
-0.190618097782135
-0.3723065257072449
0.13312453031539917
-0.6930558085441589
-0.8965153694152832
-0.14931154251098633
-0.3919924199581146
0.019315123558044434
-0.1986645758152008
-0.18663913011550903
0.019315123558044434
0.04153287410736084
-0.14345532655715942
0.019315123558044434
-0.02812439203262329
-0.9915192723274231
-1.069309949874878
0.14815747737884521
-0.02812439203262329
