In [6]:
import chess
import chess.pgn
import numpy as np
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import sys

In [7]:
class result_prediction(nn.Module):
    def __init__(self):
        super().__init__()
    
        self.both = nn.Sequential(
            nn.Linear(12 * 64 + 1, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, x1):
        x = torch.flatten( x1)
        return self.both(x)


def to_numpy(board):
    b = np.zeros((64)) 
    s = str(board)

    i = 0
    
    for x in s:
        
        if x == "K":
            b[i] = 1
        if x == "Q":
            b[i] = 2
        if x == "R":
            b[i] = 3
        if x == "B":
            b[i] = 4
        if x == "N":
            b[i] = 5
        if x == "P":
            b[i] = 6

        if x == "k":
            b[i] = 7
        if x == "q":
            b[i] = 8
        if x == "r":
            b[i] = 9
        if x == "b":
            b[i] = 10
        if x == "n":
            b[i] = 11
        if x == "p":
            b[i] = 12

        if x != "\n" and x != " ":    
            i = i + 1
    return b

def to_binary(bo, white, res):
    side_to_move = np.zeros(( 769 ))

    a = [1,2,3,4,5,6]
    b = [7,8,9,10,11,12]

    for x in range(0, len(a)):
        for y in range(64):
            side_to_move[x*64+y] = bo[y] == a[x]

    for x in range(0, len(b)):
        for y in range(64):
            side_to_move[(x+6)*64+y] = bo[y] == a[x]

    side_to_move[768] = white
    return res, side_to_move
def nth(i, x): 
    for a,b in enumerate(x):
        if a == i:
            return b

def count(i):
    return sum(1 for e in i)

def make_x_random_moves(n,board):
    b = board.copy()

    for x in range(n):
        lm = b.legal_moves
        c = lm.count()
        if c == 0:
            break
        r = random.randint(0,c)
        move = nth(r, lm) 
        move = chess.Move.from_uci(str(move)) 
        b.push(move)
    
    return b

pgn = open("datasets/lichess_db_standard_rated_2015-05.pgn", encoding="utf-8")

In [8]:
device = "cpu"
model = result_prediction()
model.load_state_dict(torch.load("model.pth", weights_only=True))
model.eval()

sd = "1/2-1/2"
sl = "0-1"
sw = "1-0"

In [9]:
for x in range(10):
    game = chess.pgn.read_game(pgn)
    board = game.board()


    total_move_num = count(game.mainline_moves())
    whitetomove = True

    for i, move in enumerate(game.mainline_moves()):
        board.push(move)
        whitetomove = whitetomove == False
        if i >= total_move_num-1:
            b = to_numpy(board)
            result = None
            if game.headers["Result"] == sd:
                result = 0
            if game.headers["Result"] == sl:
                result = -1
            if game.headers["Result"] == sw:
                result = 1

            y, stm = to_binary(b, whitetomove, result)

            stm = torch.tensor(stm).to(device).float()
            print(round(model(stm).item(),2))
            print(result)
            print()

-0.42
-1

-0.68
-1

1.02
1

-0.35
-1

0.89
1

-0.44
-1

-0.54
-1

-0.69
-1

0.97
1

0.59
1



In [10]:
from cairosvg import svg2png
import cv2

out = cv2.VideoWriter('project.mp4', -1, 15, (500,500))
print(out)
 
board = chess.Board()
whitetomove = True
counter = 0
while not board.is_checkmate() or not board.is_stalemate:
    lm = [ str(x) for x in board.legal_moves ]
    moves = []
    
    for x in lm:
        bo = board.copy()
        bo.push_san(x)
        b = to_numpy(board)

        y, stm = to_binary(b, whitetomove, 0)
        stm = torch.tensor(stm).to(device).float()

        moves.append(model(stm).item())
    
    whitetomove = whitetomove == False    
    if random.randint(0,10) < 9:
        idx = moves.index(max(moves))
        board.push_san(lm[idx])
    else:    
        board.push_san(lm[random.randint(0,len(lm)-1)])
    
    boardsvg = chess.svg.board(board=board)
    svg2png(bytestring=boardsvg, write_to='temp.png')
    png = cv2.imread('temp.png', 0) 
    
    for x in range(24):
        out.write(png)

    counter += 1

    if counter > 150:
        break


print(counter)
print(board.is_stalemate())
print(board.is_checkmate())
out.release()

< cv2.VideoWriter 000001C5C7344C90>
151
False
False
