In [2]:
from stockfish import Stockfish
import numpy as np

In [11]:
import os

# Location of the Stockfish 15 binary. Set the STOCKFISH_PATH environment variable
# to use a different install instead of editing this cell.
STOCKFISH_PATH = os.environ.get(
    "STOCKFISH_PATH",
    '/home/tmek1244/InternalProjects/stockfish_15_linux_x64_avx2/stockfish_15_x64_avx2',
)

# Deep, multi-threaded engine (depth 18); the Node class uses it to label positions.
strong_stockfish = Stockfish(
    path=STOCKFISH_PATH,
    depth=18,
    parameters={
        "Threads": 6,
        "Hash": 1024 * 4,
    },
)

# Shallow, single-threaded engine (depth 3); the Node class uses it to propose
# candidate moves and to read the board.
weak_stockfish = Stockfish(
    path=STOCKFISH_PATH,
    depth=3,
    parameters={
        "Threads": 1,
        "Hash": 16,
    },
)

In [12]:
class Node:
    """A node in a tree of chess positions grown from the initial position.

    Each node stores the move that leads to it from its parent, plus the full move
    list from the initial position (``all_moves``). Children are generated with the
    module-level ``weak_stockfish`` engine. Positions are evaluated with the
    module-level ``strong_stockfish`` engine.
    """

    # Lower-cased piece letter -> index along axis 1 of the (2, 6, 8, 8) encoding.
    _PIECE_PLANES = {'k': 0, 'q': 1, 'r': 2, 'b': 3, 'n': 4, 'p': 5}

    children: list['Node']
    all_moves: list[str]

    def __init__(self, move, parent=None, top_moves=3):
        """Create a node.

        Args:
            move: Move string (the ``'Move'`` entry returned by the engine's
                ``get_top_moves``) leading from ``parent`` to this node; ``None``
                for the root.
            parent: Parent node, or ``None`` for the root (initial position).
            top_moves: Maximum number of candidate moves requested when this
                node is expanded. Children inherit the same value.
        """
        self.move = move
        self.parent = parent
        self.children = []
        self.all_moves = self.get_moves()
        self.top_moves = top_moves

    def get_moves(self) -> list[str]:
        """Return a new list of the moves from the initial position to this node."""
        if self.parent is None:
            return []
        # The parent computed its own move list in __init__. Reuse it instead of
        # walking the whole ancestor chain again.
        return self.parent.all_moves + [self.move]

    def create_children(self, depth=1):
        """Recursively expand this node ``depth`` plies deep.

        At each level, ``weak_stockfish`` is asked for up to ``self.top_moves`` best
        moves. One child is created per returned move, in random order (NumPy
        global RNG). If the engine returns fewer moves than requested, or none at
        all, the node simply gets fewer children or none. A non-positive ``depth``
        does nothing.
        """
        if depth <= 0:
            return
        weak_stockfish.set_position(self.all_moves)
        # Materialise the candidates before recursing: the recursive calls below
        # move the shared engine to other positions.
        candidates = weak_stockfish.get_top_moves(self.top_moves)
        # Shuffle whatever was returned. Sampling exactly `top_moves` items without
        # replacement fails whenever fewer moves are available.
        for idx in np.random.permutation(len(candidates)):
            child = Node(candidates[idx]['Move'], self, self.top_moves)
            self.children.append(child)
            child.create_children(depth - 1)

    def get_curr_position(self) -> np.ndarray:
        """Encode this node's position as a (2, 6, 8, 8) array of 0/1 planes.

        Axes: colour (0 = upper-case/white, 1 = lower-case/black), piece type in the
        order k, q, r, b, n, p, rank (0 = rank 1), file (0 = 'a').
        """
        weak_stockfish.set_position(self.all_moves)
        # One FEN request instead of 64 per-square engine queries.
        return self._fen_to_planes(weak_stockfish.get_fen_position())

    @staticmethod
    def _fen_to_planes(fen: str) -> np.ndarray:
        """Convert the piece-placement field of a FEN string to the (2, 6, 8, 8) encoding."""
        result = np.zeros((2, 6, 8, 8))
        placement = fen.split()[0]
        # FEN lists ranks from 8 down to 1, each rank from file 'a' to file 'h'.
        for rank_idx, rank in enumerate(placement.split('/')):
            row = 7 - rank_idx
            col = 0
            for ch in rank:
                if ch.isdigit():
                    col += int(ch)  # run of empty squares
                    continue
                color = 1 if ch.islower() else 0
                result[color, Node._PIECE_PLANES[ch.lower()], row, col] = 1
                col += 1
        return result

    def get_eval(self, mate_score=100):
        """Return the strong engine's evaluation of this position as an int.

        Centipawn ("cp") scores are returned unchanged. Mate scores are clamped to
        ``+mate_score`` or ``-mate_score`` according to their sign. A "mate 0" score
        (the side to move is already checkmated) carries no sign, so the side to move
        from the FEN decides who lost.

        NOTE(review): this assumes get_evaluation reports scores from White's
        point of view; confirm against the installed stockfish package. Also, the
        default mate_score of 100 is smaller than many decisive cp scores; confirm
        that is intended.
        """
        strong_stockfish.set_position(self.all_moves)
        evaluation = strong_stockfish.get_evaluation()
        value = evaluation["value"]
        if evaluation["type"] == "cp":
            return value
        if value == 0:
            side_to_move = strong_stockfish.get_fen_position().split()[1]
            return -mate_score if side_to_move == 'w' else mate_score
        return mate_score if value > 0 else -mate_score

    def export(self, depth=1, file="data.csv"):
        """Append one CSV row for this node and for each descendant, down to ``depth``.

        Each row is the flattened (2, 6, 8, 8) board encoding followed by the integer
        evaluation from ``get_eval``. ``depth`` counts this node itself: depth=1
        writes only this node, and depth=2 also writes its children. Exporting a
        whole tree built with ``create_children(depth=d)`` therefore needs
        ``depth=d + 1``. The file is opened in append mode, so existing rows are
        kept. Prints one '.' per row written.
        """
        if depth <= 0:
            return
        board = self.get_curr_position().flatten()
        row = np.append(board, [self.get_eval()])  # get_eval positions the strong engine itself
        with open(file, "a") as f:
            np.savetxt(f, row.reshape(1, -1), delimiter=',', fmt='%d')
        print(".", end='')
        for child in self.children:
            child.export(depth - 1, file)
        

In [13]:
# Node.create_children orders children using NumPy's global RNG; seed it so the tree
# (and therefore the exported row order) is reproducible.
SEED = 0
np.random.seed(SEED)

root = Node(None, None, top_moves=2)
# Expand 15 plies below the initial position, with at most 2 candidate moves per ply.
root.create_children(depth=15)

In [14]:
from pathlib import Path

EXPORT_FILE = Path("data1.csv")
# Node.export appends, so remove rows from a previous run to keep this cell idempotent.
EXPORT_FILE.unlink(missing_ok=True)
# The tree cell calls create_children(depth=15), which builds the root plus 15 levels.
# Node.export's depth counts the node itself, so 16 is needed to include the leaf level.
root.export(depth=16, file=str(EXPORT_FILE))

........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................................