# A replication of methods used in the AlphaGo paper "Mastering the game of Go with deep neural networks and tree search

## Game Environment

A Python-based game env taken from https://github.com/maxpumperla/deep_learning_and_the_game_of_go


### Point, Player

In [2]:
import enum
from collections import namedtuple

class Point(namedtuple('Point', 'row col')):
    def neighbors(self):
        return [
            Point(self.row - 1, self.col),
            Point(self.row + 1, self.col),
            Point(self.row, self.col - 1),
            Point(self.row, self.col + 1),
        ]
class Player(enum.Enum):
    black = 1
    white = 2

    @property
    def other(self):
        return Player.black if self == Player.white else Player.white

### GameResult, Territory

In [4]:
from __future__ import absolute_import
from collections import namedtuple

class Territory(object):
    def __init__(self, territory_map):
        self.num_black_territory = 0
        self.num_white_territory = 0
        self.num_black_stones = 0
        self.num_white_stones = 0
        self.num_dame = 0
        self.dame_points = []
        for point, status in territory_map.items():
            if status == Player.black:
                self.num_black_stones += 1
            elif status == Player.white:
                self.num_white_stones += 1
            elif status == 'territory_b':
                self.num_black_territory += 1
            elif status == 'territory_w':
                self.num_white_territory += 1
            elif status == 'dame':
                self.num_dame += 1
                self.dame_points.append(point)

class GameResult(namedtuple('GameResult', 'b w komi')):
    @property
    def winner(self):
        if self.b > self.w + self.komi:
            return Player.black
        if self.b < self.w + self.komi:
            return Player.white
        return None

    @property
    def winning_margin(self):
        w = self.w + self.komi
        return abs(self.b - w)

    def __str__(self):
        w = self.w + self.komi
        if self.b > w:
            return 'B+%.1f' % (self.b - w,)
        return 'W+%.1f' % (w - self.b,)

def _collect_region(start_pos, board, visited=None):
    if visited is None:
        visited = {}
    if start_pos in visited:
        return [], set()
    all_points = [start_pos]
    all_borders = set()
    visited[start_pos] = True
    here = board.get(start_pos)
    deltas = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    for delta_r, delta_c in deltas:
        next_p = Point(row=start_pos.row + delta_r, col=start_pos.col + delta_c)
        if not board.is_on_grid(next_p):
            continue
        neighbor = board.get(next_p)
        if neighbor == here:
            points, borders = _collect_region(next_p, board, visited)
            all_points += points
            all_borders |= borders
        else:
            all_borders.add(neighbor)
    return all_points, all_borders

def evaluate_territory(board):
    status = {}
    for r in range(1, board.num_rows + 1):
        for c in range(1, board.num_cols + 1):
            p = Point(row=r, col=c)
            if p in status:
                continue
            stone = board.get(p)
            if stone is not None:
                status[p] = board.get(p)
            else:
                group, neighbors = _collect_region(p, board)
                if len(neighbors) == 1:
                    neighbor_stone = neighbors.pop()
                    stone_str = 'b' if neighbor_stone == Player.black else 'w'
                    fill_with = 'terrtory_' + stone_str
                else:
                    fill_with = 'dame'
                for pos in group:
                    status[pos] = fill_with
    return Territory(status)

def compute_game_result(game_state):
    territory = evaluate_territory(game_state.board)
    return GameResult(
        territory.num_black_territory + territory.num_black_stones,
        territory.num_white_territory + territory.num_white_stones,
        komi=7.5)

### GameBoard, GameState, Move, GoString

In [5]:
import copy

class Board():
    def __init__(self, num_rows, num_cols):
        self.num_rows = num_rows
        self.num_cols = num_cols
        self._grid = {}
        self._hash = EMPTY_BOARD

    def _replace_string(self, new_string):
        for point in new_string.stones:
            self._grid[point] = new_string

    def _remove_string(self, string):
        for point in string.stones:
            for neighbor in point.neighbors():
                neighbor_string = self._grid.get(neighbor)
                if neighbor_string is None:
                    continue
                if neighbor_string is not string:
                    self._replace_string(neighbor_string.with_liberty(point))
            self._grid[point] = None
            self._hash ^= HASH_CODE[point, string.color]

    def zobrist_hash(self):
        return self._hash

    def place_stone(self, player, point):
        assert self.is_on_grid(point)
        assert self._grid.get(point) is None
        adjacent_same_color = []
        adjacent_opposite_color = []
        liberties = []
        for neighbor in point.neighbors():
            if not self.is_on_grid(neighbor):
                continue
            neighbor_string = self._grid.get(neighbor)
            if neighbor_string is None:
                liberties.append(neighbor)
            elif neighbor_string.color == player:
                if neighbor_string not in adjacent_same_color:
                    adjacent_same_color.append(neighbor_string)
            else:
                if neighbor_string not in adjacent_opposite_color:
                    adjacent_opposite_color.append(neighbor_string)
        new_string = GoString(player, [point], liberties)
        for same_color_string in adjacent_same_color:
            new_string = new_string.merged_with(same_color_string)
        for new_string_point in new_string.stones:
            self._grid[new_string_point] = new_string
        self._hash ^= HASH_CODE[point, player]
        for other_color_string in adjacent_opposite_color:
            replacement = other_color_string.without_liberty(point)
            if replacement.num_liberties:
                self._replace_string(other_color_string.without_liberty(point))
            else:
                self._remove_string(other_color_string)

    def is_on_grid(self, point):
        return 1 <= point.row <= self.num_rows and \
            1 <= point.col <= self.num_cols

    def get(self, point):
        string = self._grid.get(point)
        if string is None:
            return None
        return string.color

    def get_go_string(self, point):
        string = self._grid.get(point)
        if string is None:
            return None
        return string

class Move():
    def __init__(self, point=None, is_pass=False, is_resign=False):
        assert (point is not None) ^ is_pass ^ is_resign
        self.point = point
        self.is_play = (self.point is not None)
        self.is_pass = is_pass
        self.is_resign = is_resign

    @classmethod
    def play(cls, point):
        return Move(point=point)

    @classmethod
    def pass_turn(cls):
        return Move(is_pass=True)

    @classmethod
    def resign(cls):
        return Move(is_resign=True)

# Chain of connected stones (used to e.g. efficiently check for liberties)
class GoString():
    def __init__(self, color, stones, liberties):
        self.color = color
        self.stones = frozenset(stones)
        self.liberties = frozenset(liberties)

    def without_liberty(self, point):
        new_liberties = self.liberties - set([point])
        return GoString(self.color, self.stones, new_liberties)

    def with_liberty(self, point):
        new_liberties = self.liberties | set([point])
        return GoString(self.color, self.stones, new_liberties)

    def merged_with(self, go_string):
        assert go_string.color == self.color
        combined_stones = self.stones | go_string.stones
        return GoString(
            self.color,
            combined_stones,
            (self.liberties | go_string.liberties) - combined_stones
        )

    @property
    def num_liberties(self):
        return len(self.liberties)

    def __eq__(self, other):
        return isinstance(other, GoString) and \
            self.color == other.color and \
            self.stones == other.stones and \
            self.liberties == other.liberties

class GameState():
    def __init__(self, board, next_player, previous, move):
        self.board = board
        self.next_player = next_player
        self.previous_state = previous
        if self.previous_state is None:
            self.previous_states = frozenset()
        else:
            self.previous_states = frozenset(
                previous.previous_states |
                {(previous.next_player, previous.board.zobrist_hash())})
        self.last_move = move

    def apply_move(self, move):
        if move.is_play:
            next_board = copy.deepcopy(self.board)
            next_board.place_stone(self.next_player, move.point)
        else:
            next_board = self.board
        return GameState(next_board, self.next_player.other, self, move)

    def is_over(self):
        if self.last_move is None:
            return False
        if self.last_move.is_resign:
            return True
        second_last_move = self.previous_state.last_move
        if second_last_move is None:
            return False
        return self.last_move.is_pass and second_last_move.is_pass

    def is_move_self_capture(self, player, move):
        if not move.is_play:
            return False
        next_board = copy.deepcopy(self.board)
        next_board.place_stone(player, move.point)
        new_string = next_board.get_go_string(move.point)
        return new_string.num_liberties == 0

    def is_valid_move(self, move):
        if self.is_over():
            return False
        if move.is_pass or move.is_resign:
            return True
        return (
            self.board.get(move.point) is None and
            not self.is_move_self_capture(self.next_player, move) and
            not self.does_move_violate_ko(self.next_player, move))
    
    def legal_moves(self):
        if self.is_over():
            return []
        moves = []
        for row in range(1, self.board.num_rows + 1):
            for col in range(1, self.board.num_cols + 1):
                move = Move.play(Point(row, col))
                if self.is_valid_move(move):
                    moves.append(move)
        moves.append(Move.pass_turn())
        moves.append(Move.resign())
        return moves
    
    def winner(self):
        if not self.is_over():
            return None
        if self.last_move.is_resign:
            return self.next_player
        game_result = compute_game_result(self)
        return game_result.winner

    @classmethod
    def new_game(cls, board_size):
        if isinstance(board_size, int):
            board_size = (board_size, board_size)
        board = Board(*board_size)
        return GameState(board, Player.black, None, None)

    @property
    def situation(self):
        return (self.next_player, self.board)

    def does_move_violate_ko(self, player, move):
        if not move.is_play:
            return False
        next_board = copy.deepcopy(self.board)
        next_board.place_stone(player, move.point)
        next_situation = (player.other, next_board.zobrist_hash())
        return next_situation in self.previous_states

## Zobrist Hash

In [3]:
HASH_CODE = {
    (Point(row=1, col=1), Player.black): 5393977873783756346,
    (Point(row=1, col=1), Player.white): 5809908186856798082,
    (Point(row=1, col=2), Player.black): 7529245096038496553,
    (Point(row=1, col=2), Player.white): 4203604640916720214,
    (Point(row=1, col=3), Player.black): 7689426742090425927,
    (Point(row=1, col=3), Player.white): 2086893975018221701,
    (Point(row=1, col=4), Player.black): 2519303819741611959,
    (Point(row=1, col=4), Player.white): 6759747478641010517,
    (Point(row=1, col=5), Player.black): 5923826413836607084,
    (Point(row=1, col=5), Player.white): 6335637203333882852,
    (Point(row=1, col=6), Player.black): 5214099692239048175,
    (Point(row=1, col=6), Player.white): 3614103592443371537,
    (Point(row=1, col=7), Player.black): 3849677082517236653,
    (Point(row=1, col=7), Player.white): 5657592154080615132,
    (Point(row=1, col=8), Player.black): 4690764937906290494,
    (Point(row=1, col=8), Player.white): 3960441528920116343,
    (Point(row=1, col=9), Player.black): 1528580023562037700,
    (Point(row=1, col=9), Player.white): 6795660655311091843,
    (Point(row=1, col=10), Player.black): 2038381746311752170,
    (Point(row=1, col=10), Player.white): 1123885171509414104,
    (Point(row=1, col=11), Player.black): 1971213527857795068,
    (Point(row=1, col=11), Player.white): 5013069788436838336,
    (Point(row=1, col=12), Player.black): 6532060403916218765,
    (Point(row=1, col=12), Player.white): 1313983340929306023,
    (Point(row=1, col=13), Player.black): 6482131094716469960,
    (Point(row=1, col=13), Player.white): 5060957524475622739,
    (Point(row=1, col=14), Player.black): 7708434382856268031,
    (Point(row=1, col=14), Player.white): 4502877348597339059,
    (Point(row=1, col=15), Player.black): 1363754886070945335,
    (Point(row=1, col=15), Player.white): 6441414251647916185,
    (Point(row=1, col=16), Player.black): 3675395389505183553,
    (Point(row=1, col=16), Player.white): 8024407047403175625,
    (Point(row=1, col=17), Player.black): 1666791651634627427,
    (Point(row=1, col=17), Player.white): 8553805033900708981,
    (Point(row=1, col=18), Player.black): 4484631451415237281,
    (Point(row=1, col=18), Player.white): 2979820244129849588,
    (Point(row=1, col=19), Player.black): 2876264275760326021,
    (Point(row=1, col=19), Player.white): 6242463375661141726,
    (Point(row=2, col=1), Player.black): 6709422836882881335,
    (Point(row=2, col=1), Player.white): 2742968805440332251,
    (Point(row=2, col=2), Player.black): 5506212339005219961,
    (Point(row=2, col=2), Player.white): 6164349948806951056,
    (Point(row=2, col=3), Player.black): 2533967792650062203,
    (Point(row=2, col=3), Player.white): 7700847486124359210,
    (Point(row=2, col=4), Player.black): 3867236819875477334,
    (Point(row=2, col=4), Player.white): 6299277939177615327,
    (Point(row=2, col=5), Player.black): 7146337618158669635,
    (Point(row=2, col=5), Player.white): 1964240974756327606,
    (Point(row=2, col=6), Player.black): 347767546781907151,
    (Point(row=2, col=6), Player.white): 658810794641889467,
    (Point(row=2, col=7), Player.black): 2552597538017343647,
    (Point(row=2, col=7), Player.white): 9026327557579943499,
    (Point(row=2, col=8), Player.black): 8358613095927693150,
    (Point(row=2, col=8), Player.white): 3842000290834543361,
    (Point(row=2, col=9), Player.black): 5656467551902666949,
    (Point(row=2, col=9), Player.white): 5694105665440562368,
    (Point(row=2, col=10), Player.black): 6828832104324704314,
    (Point(row=2, col=10), Player.white): 4905373805124530338,
    (Point(row=2, col=11), Player.black): 1353146281674316264,
    (Point(row=2, col=11), Player.white): 357432456999811959,
    (Point(row=2, col=12), Player.black): 7420879869806222125,
    (Point(row=2, col=12), Player.white): 852733351160817442,
    (Point(row=2, col=13), Player.black): 741195307496222566,
    (Point(row=2, col=13), Player.white): 5895687820668112841,
    (Point(row=2, col=14), Player.black): 133533041508766811,
    (Point(row=2, col=14), Player.white): 7852120722573876428,
    (Point(row=2, col=15), Player.black): 5449983063381231285,
    (Point(row=2, col=15), Player.white): 3939770961425961474,
    (Point(row=2, col=16), Player.black): 8818507212704304617,
    (Point(row=2, col=16), Player.white): 1868379306986161529,
    (Point(row=2, col=17), Player.black): 4922458507825236663,
    (Point(row=2, col=17), Player.white): 6855188129468292450,
    (Point(row=2, col=18), Player.black): 6594562755818111569,
    (Point(row=2, col=18), Player.white): 7076337354022138785,
    (Point(row=2, col=19), Player.black): 4792087523504882222,
    (Point(row=2, col=19), Player.white): 4967787996953711806,
    (Point(row=3, col=1), Player.black): 409559989713330046,
    (Point(row=3, col=1), Player.white): 974106538742651670,
    (Point(row=3, col=2), Player.black): 7214295382553334637,
    (Point(row=3, col=2), Player.white): 2243063899979941088,
    (Point(row=3, col=3), Player.black): 6279612204540518511,
    (Point(row=3, col=3), Player.white): 1059153980578709206,
    (Point(row=3, col=4), Player.black): 6805386743698354674,
    (Point(row=3, col=4), Player.white): 4698297706194128889,
    (Point(row=3, col=5), Player.black): 2863121949156059219,
    (Point(row=3, col=5), Player.white): 1050005787526343157,
    (Point(row=3, col=6), Player.black): 1070257652289356962,
    (Point(row=3, col=6), Player.white): 5256231082289767435,
    (Point(row=3, col=7), Player.black): 5912065481421843592,
    (Point(row=3, col=7), Player.white): 3986322931305438869,
    (Point(row=3, col=8), Player.black): 2912916693058227159,
    (Point(row=3, col=8), Player.white): 4252865205536430649,
    (Point(row=3, col=9), Player.black): 4152469923788670414,
    (Point(row=3, col=9), Player.white): 8816900713618664906,
    (Point(row=3, col=10), Player.black): 2630058779803692874,
    (Point(row=3, col=10), Player.white): 6077336711699791501,
    (Point(row=3, col=11), Player.black): 2473162956209103943,
    (Point(row=3, col=11), Player.white): 1560792612851443297,
    (Point(row=3, col=12), Player.black): 4865433962044860302,
    (Point(row=3, col=12), Player.white): 8986429821088216591,
    (Point(row=3, col=13), Player.black): 6244626043292736993,
    (Point(row=3, col=13), Player.white): 7327180209888835161,
    (Point(row=3, col=14), Player.black): 7361934649523265023,
    (Point(row=3, col=14), Player.white): 6673082472323031533,
    (Point(row=3, col=15), Player.black): 8330395830591870302,
    (Point(row=3, col=15), Player.white): 4038698861011080614,
    (Point(row=3, col=16), Player.black): 3254622122792372141,
    (Point(row=3, col=16), Player.white): 3422860185267100986,
    (Point(row=3, col=17), Player.black): 27171439856566070,
    (Point(row=3, col=17), Player.white): 3897257892663221275,
    (Point(row=3, col=18), Player.black): 7001511514934680202,
    (Point(row=3, col=18), Player.white): 2921562228995293149,
    (Point(row=3, col=19), Player.black): 9093285488265791964,
    (Point(row=3, col=19), Player.white): 8966327084048429928,
    (Point(row=4, col=1), Player.black): 8276181643064889311,
    (Point(row=4, col=1), Player.white): 6439885832557472148,
    (Point(row=4, col=2), Player.black): 6877887143511363226,
    (Point(row=4, col=2), Player.white): 1832593332972063210,
    (Point(row=4, col=3), Player.black): 4797628863973581833,
    (Point(row=4, col=3), Player.white): 3989888329872392169,
    (Point(row=4, col=4), Player.black): 614680584744613110,
    (Point(row=4, col=4), Player.white): 4353601408583045201,
    (Point(row=4, col=5), Player.black): 6567494846456204349,
    (Point(row=4, col=5), Player.white): 3332059749137543268,
    (Point(row=4, col=6), Player.black): 6339902497103120878,
    (Point(row=4, col=6), Player.white): 4671056937145414347,
    (Point(row=4, col=7), Player.black): 6869633312231023389,
    (Point(row=4, col=7), Player.white): 5168561737872003725,
    (Point(row=4, col=8), Player.black): 8784236346344547734,
    (Point(row=4, col=8), Player.white): 6894454448461817389,
    (Point(row=4, col=9), Player.black): 2216442754566989128,
    (Point(row=4, col=9), Player.white): 384790405088178748,
    (Point(row=4, col=10), Player.black): 4807404968645902936,
    (Point(row=4, col=10), Player.white): 7538950995131630435,
    (Point(row=4, col=11), Player.black): 6565828035490618255,
    (Point(row=4, col=11), Player.white): 560157743088555264,
    (Point(row=4, col=12), Player.black): 1364948911328279608,
    (Point(row=4, col=12), Player.white): 6822997728613499617,
    (Point(row=4, col=13), Player.black): 3059418406854576943,
    (Point(row=4, col=13), Player.white): 3143272298701658593,
    (Point(row=4, col=14), Player.black): 4956511710974584863,
    (Point(row=4, col=14), Player.white): 6237243466258988852,
    (Point(row=4, col=15), Player.black): 5510083046629441538,
    (Point(row=4, col=15), Player.white): 8883192098662136887,
    (Point(row=4, col=16), Player.black): 8752852582298458334,
    (Point(row=4, col=16), Player.white): 1043982339746051354,
    (Point(row=4, col=17), Player.black): 4720112292813366147,
    (Point(row=4, col=17), Player.white): 1936034210720901247,
    (Point(row=4, col=18), Player.black): 943256428683494257,
    (Point(row=4, col=18), Player.white): 651807249302252844,
    (Point(row=4, col=19), Player.black): 3931661889398113915,
    (Point(row=4, col=19), Player.white): 4597340776103258030,
    (Point(row=5, col=1), Player.black): 595194255482594152,
    (Point(row=5, col=1), Player.white): 2346664626707363982,
    (Point(row=5, col=2), Player.black): 2608956254571183248,
    (Point(row=5, col=2), Player.white): 4739958023738092203,
    (Point(row=5, col=3), Player.black): 6451481994892099546,
    (Point(row=5, col=3), Player.white): 5449111700251338465,
    (Point(row=5, col=4), Player.black): 1535132115423330175,
    (Point(row=5, col=4), Player.white): 7415705628037074502,
    (Point(row=5, col=5), Player.black): 2148331470303358161,
    (Point(row=5, col=5), Player.white): 6874172295131535339,
    (Point(row=5, col=6), Player.black): 1583124744582528453,
    (Point(row=5, col=6), Player.white): 2823661675395206720,
    (Point(row=5, col=7), Player.black): 1016838572143578056,
    (Point(row=5, col=7), Player.white): 2167636868483863463,
    (Point(row=5, col=8), Player.black): 6959161909260731551,
    (Point(row=5, col=8), Player.white): 3770737924020034312,
    (Point(row=5, col=9), Player.black): 8078984159677323907,
    (Point(row=5, col=9), Player.white): 1501002364703759380,
    (Point(row=5, col=10), Player.black): 3862680792396219822,
    (Point(row=5, col=10), Player.white): 7210971028588143757,
    (Point(row=5, col=11), Player.black): 8982620145096306885,
    (Point(row=5, col=11), Player.white): 9077845211666620910,
    (Point(row=5, col=12), Player.black): 3814824378855469442,
    (Point(row=5, col=12), Player.white): 9009998477517604850,
    (Point(row=5, col=13), Player.black): 1104378123828383537,
    (Point(row=5, col=13), Player.white): 5612681327488320150,
    (Point(row=5, col=14), Player.black): 6405485393395554292,
    (Point(row=5, col=14), Player.white): 7906705770454684298,
    (Point(row=5, col=15), Player.black): 333163786216313552,
    (Point(row=5, col=15), Player.white): 2074570223376919718,
    (Point(row=5, col=16), Player.black): 4069116839454068023,
    (Point(row=5, col=16), Player.white): 6740812917189108882,
    (Point(row=5, col=17), Player.black): 7397720748128852538,
    (Point(row=5, col=17), Player.white): 2819445626367081139,
    (Point(row=5, col=18), Player.black): 3946232768391836340,
    (Point(row=5, col=18), Player.white): 185957405779036489,
    (Point(row=5, col=19), Player.black): 2668367148429208536,
    (Point(row=5, col=19), Player.white): 5046849423023203233,
    (Point(row=6, col=1), Player.black): 5236780538240254953,
    (Point(row=6, col=1), Player.white): 1700933403518321526,
    (Point(row=6, col=2), Player.black): 4087975939719812724,
    (Point(row=6, col=2), Player.white): 7742737189100913351,
    (Point(row=6, col=3), Player.black): 6266041176462277111,
    (Point(row=6, col=3), Player.white): 7740179331720822825,
    (Point(row=6, col=4), Player.black): 3535621807865789967,
    (Point(row=6, col=4), Player.white): 2863902286232919613,
    (Point(row=6, col=5), Player.black): 8646527938910427377,
    (Point(row=6, col=5), Player.white): 3508705032761589085,
    (Point(row=6, col=6), Player.black): 1861529986495637235,
    (Point(row=6, col=6), Player.white): 5185610610805231467,
    (Point(row=6, col=7), Player.black): 7910934926251634294,
    (Point(row=6, col=7), Player.white): 1418259942735627675,
    (Point(row=6, col=8), Player.black): 3826993986756202088,
    (Point(row=6, col=8), Player.white): 8891058428243753213,
    (Point(row=6, col=9), Player.black): 7715952192948925002,
    (Point(row=6, col=9), Player.white): 1951610596526998228,
    (Point(row=6, col=10), Player.black): 6247445946551473868,
    (Point(row=6, col=10), Player.white): 6379303498343873298,
    (Point(row=6, col=11), Player.black): 6902199809165214188,
    (Point(row=6, col=11), Player.white): 2928810189082639091,
    (Point(row=6, col=12), Player.black): 5984798129835678887,
    (Point(row=6, col=12), Player.white): 4960339944258760426,
    (Point(row=6, col=13), Player.black): 6397730786943417060,
    (Point(row=6, col=13), Player.white): 5148391062805140437,
    (Point(row=6, col=14), Player.black): 4964731085005199471,
    (Point(row=6, col=14), Player.white): 6152744930515284842,
    (Point(row=6, col=15), Player.black): 2993132224282123225,
    (Point(row=6, col=15), Player.white): 2364208787297440432,
    (Point(row=6, col=16), Player.black): 2872926936884471917,
    (Point(row=6, col=16), Player.white): 7140521628548206767,
    (Point(row=6, col=17), Player.black): 8881082997175684589,
    (Point(row=6, col=17), Player.white): 3194242279437021731,
    (Point(row=6, col=18), Player.black): 833506010508081772,
    (Point(row=6, col=18), Player.white): 5499281829804540286,
    (Point(row=6, col=19), Player.black): 7510573601119746717,
    (Point(row=6, col=19), Player.white): 7688868713374694687,
    (Point(row=7, col=1), Player.black): 5134367864289547617,
    (Point(row=7, col=1), Player.white): 5546761652964616685,
    (Point(row=7, col=2), Player.black): 8361470650111405789,
    (Point(row=7, col=2), Player.white): 4600575817047215590,
    (Point(row=7, col=3), Player.black): 5920970251139551047,
    (Point(row=7, col=3), Player.white): 5170050938567328701,
    (Point(row=7, col=4), Player.black): 8447381837064326321,
    (Point(row=7, col=4), Player.white): 4496754018973119671,
    (Point(row=7, col=5), Player.black): 7759894504057097273,
    (Point(row=7, col=5), Player.white): 7362408676865259730,
    (Point(row=7, col=6), Player.black): 8091516812944720585,
    (Point(row=7, col=6), Player.white): 937681080301387970,
    (Point(row=7, col=7), Player.black): 6466007903763180389,
    (Point(row=7, col=7), Player.white): 1434137069818838275,
    (Point(row=7, col=8), Player.black): 2462483960127589370,
    (Point(row=7, col=8), Player.white): 1330621330070498087,
    (Point(row=7, col=9), Player.black): 8714692177860131522,
    (Point(row=7, col=9), Player.white): 5598487330039369816,
    (Point(row=7, col=10), Player.black): 6387751976134814198,
    (Point(row=7, col=10), Player.white): 4839175085293953742,
    (Point(row=7, col=11), Player.black): 4009768377537954211,
    (Point(row=7, col=11), Player.white): 7261498075580989092,
    (Point(row=7, col=12), Player.black): 1433994774896454601,
    (Point(row=7, col=12), Player.white): 6956127723123283037,
    (Point(row=7, col=13), Player.black): 9152276347174896334,
    (Point(row=7, col=13), Player.white): 7612321844704865546,
    (Point(row=7, col=14), Player.black): 5536492350038567412,
    (Point(row=7, col=14), Player.white): 5103780139885100354,
    (Point(row=7, col=15), Player.black): 5146199694256476573,
    (Point(row=7, col=15), Player.white): 1146440997431173123,
    (Point(row=7, col=16), Player.black): 2410872820356173150,
    (Point(row=7, col=16), Player.white): 8468967471631997173,
    (Point(row=7, col=17), Player.black): 5326657457342798880,
    (Point(row=7, col=17), Player.white): 5940049075980419871,
    (Point(row=7, col=18), Player.black): 3383702678844309492,
    (Point(row=7, col=18), Player.white): 782083488923036607,
    (Point(row=7, col=19), Player.black): 8940476901160387407,
    (Point(row=7, col=19), Player.white): 6636181769984252847,
    (Point(row=8, col=1), Player.black): 2175086052015880094,
    (Point(row=8, col=1), Player.white): 5182718210478909815,
    (Point(row=8, col=2), Player.black): 6769877419412543242,
    (Point(row=8, col=2), Player.white): 7710431355269059732,
    (Point(row=8, col=3), Player.black): 6831758894284463886,
    (Point(row=8, col=3), Player.white): 5023970341691177955,
    (Point(row=8, col=4), Player.black): 6741675309883887699,
    (Point(row=8, col=4), Player.white): 2197415534784593997,
    (Point(row=8, col=5), Player.black): 6777397725432524384,
    (Point(row=8, col=5), Player.white): 6378947886149682363,
    (Point(row=8, col=6), Player.black): 3677954927341352626,
    (Point(row=8, col=6), Player.white): 2642052992721139607,
    (Point(row=8, col=7), Player.black): 7429463078881783589,
    (Point(row=8, col=7), Player.white): 3531299991136405110,
    (Point(row=8, col=8), Player.black): 1730176684838230933,
    (Point(row=8, col=8), Player.white): 4742196632708194391,
    (Point(row=8, col=9), Player.black): 8272907418282925258,
    (Point(row=8, col=9), Player.white): 4568249133546646978,
    (Point(row=8, col=10), Player.black): 2425833485083966808,
    (Point(row=8, col=10), Player.white): 2464496008068452125,
    (Point(row=8, col=11), Player.black): 1166236825577579994,
    (Point(row=8, col=11), Player.white): 8755696017873810055,
    (Point(row=8, col=12), Player.black): 3360954741871645211,
    (Point(row=8, col=12), Player.white): 1579160809117095473,
    (Point(row=8, col=13), Player.black): 3240483061421295842,
    (Point(row=8, col=13), Player.white): 6157732470412641843,
    (Point(row=8, col=14), Player.black): 5612749803550580195,
    (Point(row=8, col=14), Player.white): 872997176094403986,
    (Point(row=8, col=15), Player.black): 5038817978473765731,
    (Point(row=8, col=15), Player.white): 6849032463251242149,
    (Point(row=8, col=16), Player.black): 3126333456611605385,
    (Point(row=8, col=16), Player.white): 1607668012356518301,
    (Point(row=8, col=17), Player.black): 6130204063940933829,
    (Point(row=8, col=17), Player.white): 7276174816580560888,
    (Point(row=8, col=18), Player.black): 8302727268571218780,
    (Point(row=8, col=18), Player.white): 5107152615440002984,
    (Point(row=8, col=19), Player.black): 3408038075209981553,
    (Point(row=8, col=19), Player.white): 1799009080741760463,
    (Point(row=9, col=1), Player.black): 7543863305589459818,
    (Point(row=9, col=1), Player.white): 4831143822733028670,
    (Point(row=9, col=2), Player.black): 9146650584516432477,
    (Point(row=9, col=2), Player.white): 4777800004539369306,
    (Point(row=9, col=3), Player.black): 831008639309065582,
    (Point(row=9, col=3), Player.white): 6483569622522938257,
    (Point(row=9, col=4), Player.black): 919962239975731612,
    (Point(row=9, col=4), Player.white): 4494833599050192972,
    (Point(row=9, col=5), Player.black): 2705818628478499768,
    (Point(row=9, col=5), Player.white): 4854366284343923762,
    (Point(row=9, col=6), Player.black): 2186344428202592969,
    (Point(row=9, col=6), Player.white): 4717347223326583882,
    (Point(row=9, col=7), Player.black): 1344560774247871853,
    (Point(row=9, col=7), Player.white): 4792569208284827688,
    (Point(row=9, col=8), Player.black): 6072230198406231713,
    (Point(row=9, col=8), Player.white): 7690153008628290874,
    (Point(row=9, col=9), Player.black): 788962555121341532,
    (Point(row=9, col=9), Player.white): 7747934504986053130,
    (Point(row=9, col=10), Player.black): 7189795430232810600,
    (Point(row=9, col=10), Player.white): 4659700620955871126,
    (Point(row=9, col=11), Player.black): 5241598291302408127,
    (Point(row=9, col=11), Player.white): 2350505170769143398,
    (Point(row=9, col=12), Player.black): 625290367975994669,
    (Point(row=9, col=12), Player.white): 2619722814727204466,
    (Point(row=9, col=13), Player.black): 3833490124974817734,
    (Point(row=9, col=13), Player.white): 7057204456609825728,
    (Point(row=9, col=14), Player.black): 5043467040834593920,
    (Point(row=9, col=14), Player.white): 90641293599267669,
    (Point(row=9, col=15), Player.black): 2166311380672564335,
    (Point(row=9, col=15), Player.white): 1533194082477017293,
    (Point(row=9, col=16), Player.black): 1479942284986511491,
    (Point(row=9, col=16), Player.white): 7147945142977386594,
    (Point(row=9, col=17), Player.black): 6069951347394519832,
    (Point(row=9, col=17), Player.white): 9176919624143976444,
    (Point(row=9, col=18), Player.black): 4213364162046731365,
    (Point(row=9, col=18), Player.white): 140010096985573503,
    (Point(row=9, col=19), Player.black): 3559160881771774117,
    (Point(row=9, col=19), Player.white): 6420622308789233217,
    (Point(row=10, col=1), Player.black): 8193769264964446698,
    (Point(row=10, col=1), Player.white): 2945755644309646803,
    (Point(row=10, col=2), Player.black): 3055325339654451473,
    (Point(row=10, col=2), Player.white): 1819491752349900221,
    (Point(row=10, col=3), Player.black): 2075862157928515937,
    (Point(row=10, col=3), Player.white): 3790360661753336822,
    (Point(row=10, col=4), Player.black): 4307970204159696538,
    (Point(row=10, col=4), Player.white): 8551079606919301724,
    (Point(row=10, col=5), Player.black): 4006923773478274141,
    (Point(row=10, col=5), Player.white): 5527203000164812679,
    (Point(row=10, col=6), Player.black): 2216926060345003549,
    (Point(row=10, col=6), Player.white): 8585904118663887844,
    (Point(row=10, col=7), Player.black): 2136756164451110808,
    (Point(row=10, col=7), Player.white): 4881316629073290743,
    (Point(row=10, col=8), Player.black): 824604388971042841,
    (Point(row=10, col=8), Player.white): 7059131224257977686,
    (Point(row=10, col=9), Player.black): 8627852865962149814,
    (Point(row=10, col=9), Player.white): 8853818995444247566,
    (Point(row=10, col=10), Player.black): 1383226153637045749,
    (Point(row=10, col=10), Player.white): 149398191736682054,
    (Point(row=10, col=11), Player.black): 6155487804714620510,
    (Point(row=10, col=11), Player.white): 6834731035458783322,
    (Point(row=10, col=12), Player.black): 2926261996259094686,
    (Point(row=10, col=12), Player.white): 5997534056087441484,
    (Point(row=10, col=13), Player.black): 2854320166994049636,
    (Point(row=10, col=13), Player.white): 5245848157740311906,
    (Point(row=10, col=14), Player.black): 83348804910913138,
    (Point(row=10, col=14), Player.white): 644001774932798361,
    (Point(row=10, col=15), Player.black): 8991260279131768694,
    (Point(row=10, col=15), Player.white): 4942827517903700190,
    (Point(row=10, col=16), Player.black): 8275187656768181735,
    (Point(row=10, col=16), Player.white): 8262819563015115636,
    (Point(row=10, col=17), Player.black): 5241028636492453279,
    (Point(row=10, col=17), Player.white): 8759821912357491041,
    (Point(row=10, col=18), Player.black): 9130614826839327071,
    (Point(row=10, col=18), Player.white): 3165549531458223316,
    (Point(row=10, col=19), Player.black): 3828703932740890806,
    (Point(row=10, col=19), Player.white): 4442392114657903309,
    (Point(row=11, col=1), Player.black): 8364720543836484044,
    (Point(row=11, col=1), Player.white): 7110563149578423107,
    (Point(row=11, col=2), Player.black): 1740230811027023719,
    (Point(row=11, col=2), Player.white): 8750233206086649548,
    (Point(row=11, col=3), Player.black): 6459191711849081191,
    (Point(row=11, col=3), Player.white): 1333989939522666744,
    (Point(row=11, col=4), Player.black): 4404329805061720517,
    (Point(row=11, col=4), Player.white): 8031632707748625026,
    (Point(row=11, col=5), Player.black): 6315442374324997477,
    (Point(row=11, col=5), Player.white): 5921358150765893741,
    (Point(row=11, col=6), Player.black): 2829244137945086994,
    (Point(row=11, col=6), Player.white): 7282705646404900052,
    (Point(row=11, col=7), Player.black): 4280443328512852368,
    (Point(row=11, col=7), Player.white): 1022991342743351889,
    (Point(row=11, col=8), Player.black): 7871038442990540934,
    (Point(row=11, col=8), Player.white): 8048999973996160363,
    (Point(row=11, col=9), Player.black): 8349814250535577695,
    (Point(row=11, col=9), Player.white): 6903980952969054988,
    (Point(row=11, col=10), Player.black): 2636182925266198218,
    (Point(row=11, col=10), Player.white): 6597384852076955196,
    (Point(row=11, col=11), Player.black): 7304185660435474241,
    (Point(row=11, col=11), Player.white): 1161428769983383423,
    (Point(row=11, col=12), Player.black): 4764222324594565334,
    (Point(row=11, col=12), Player.white): 4973434098545015793,
    (Point(row=11, col=13), Player.black): 2464925395013663474,
    (Point(row=11, col=13), Player.white): 8070605117031420796,
    (Point(row=11, col=14), Player.black): 6226819779382658570,
    (Point(row=11, col=14), Player.white): 1768676603373485595,
    (Point(row=11, col=15), Player.black): 3774839020667568405,
    (Point(row=11, col=15), Player.white): 3612990894153626431,
    (Point(row=11, col=16), Player.black): 4680962640063627081,
    (Point(row=11, col=16), Player.white): 4395180480776461709,
    (Point(row=11, col=17), Player.black): 7891328264877981175,
    (Point(row=11, col=17), Player.white): 4784648832232185044,
    (Point(row=11, col=18), Player.black): 7227594185716364054,
    (Point(row=11, col=18), Player.white): 2676007520110268962,
    (Point(row=11, col=19), Player.black): 8651325328126583228,
    (Point(row=11, col=19), Player.white): 2180297362002720016,
    (Point(row=12, col=1), Player.black): 569072085233777300,
    (Point(row=12, col=1), Player.white): 3962892463713553636,
    (Point(row=12, col=2), Player.black): 6279939543482054459,
    (Point(row=12, col=2), Player.white): 7047196536008696563,
    (Point(row=12, col=3), Player.black): 8441247324807337428,
    (Point(row=12, col=3), Player.white): 510137263107572572,
    (Point(row=12, col=4), Player.black): 2785091596871451816,
    (Point(row=12, col=4), Player.white): 5797756979063517338,
    (Point(row=12, col=5), Player.black): 6830589117034638793,
    (Point(row=12, col=5), Player.white): 6538062896471236161,
    (Point(row=12, col=6), Player.black): 1247793357136783780,
    (Point(row=12, col=6), Player.white): 5473746680142625577,
    (Point(row=12, col=7), Player.black): 8792374339307062852,
    (Point(row=12, col=7), Player.white): 8696409571971981958,
    (Point(row=12, col=8), Player.black): 1986788646529006433,
    (Point(row=12, col=8), Player.white): 1334048313728698490,
    (Point(row=12, col=9), Player.black): 5571764755728880579,
    (Point(row=12, col=9), Player.white): 8110195065241758761,
    (Point(row=12, col=10), Player.black): 7000303332689460577,
    (Point(row=12, col=10), Player.white): 3342697079614579858,
    (Point(row=12, col=11), Player.black): 5396600951650151682,
    (Point(row=12, col=11), Player.white): 3753217233276505022,
    (Point(row=12, col=12), Player.black): 418451402852089519,
    (Point(row=12, col=12), Player.white): 4496544675151207696,
    (Point(row=12, col=13), Player.black): 2138561593212247665,
    (Point(row=12, col=13), Player.white): 8559523709731213966,
    (Point(row=12, col=14), Player.black): 8040701969649459609,
    (Point(row=12, col=14), Player.white): 9122687236308934980,
    (Point(row=12, col=15), Player.black): 4991206904094122250,
    (Point(row=12, col=15), Player.white): 6940242564819928736,
    (Point(row=12, col=16), Player.black): 8673068435035150089,
    (Point(row=12, col=16), Player.white): 4940897150134805099,
    (Point(row=12, col=17), Player.black): 3178625453368722151,
    (Point(row=12, col=17), Player.white): 5821876020350324999,
    (Point(row=12, col=18), Player.black): 1085714775192227787,
    (Point(row=12, col=18), Player.white): 1025787196633417229,
    (Point(row=12, col=19), Player.black): 6113798611850165943,
    (Point(row=12, col=19), Player.white): 7742354712700579012,
    (Point(row=13, col=1), Player.black): 3668223328997750726,
    (Point(row=13, col=1), Player.white): 7010043925118139735,
    (Point(row=13, col=2), Player.black): 3561096209509449159,
    (Point(row=13, col=2), Player.white): 3794104467522863383,
    (Point(row=13, col=3), Player.black): 2083499440834748592,
    (Point(row=13, col=3), Player.white): 1104332826341311480,
    (Point(row=13, col=4), Player.black): 2337085801575977441,
    (Point(row=13, col=4), Player.white): 73496622814212846,
    (Point(row=13, col=5), Player.black): 986213037559210631,
    (Point(row=13, col=5), Player.white): 7858645873690090304,
    (Point(row=13, col=6), Player.black): 877885740364188784,
    (Point(row=13, col=6), Player.white): 7491462567832668631,
    (Point(row=13, col=7), Player.black): 1453245397849934843,
    (Point(row=13, col=7), Player.white): 9087626689514535797,
    (Point(row=13, col=8), Player.black): 943211089584136659,
    (Point(row=13, col=8), Player.white): 8564804778257589909,
    (Point(row=13, col=9), Player.black): 4281933374166808643,
    (Point(row=13, col=9), Player.white): 7582464849786823960,
    (Point(row=13, col=10), Player.black): 8081287552430406341,
    (Point(row=13, col=10), Player.white): 8975853951813914918,
    (Point(row=13, col=11), Player.black): 3297979934470085459,
    (Point(row=13, col=11), Player.white): 7415112116974850988,
    (Point(row=13, col=12), Player.black): 1045244640488930215,
    (Point(row=13, col=12), Player.white): 8633968582272730966,
    (Point(row=13, col=13), Player.black): 1667507781336455107,
    (Point(row=13, col=13), Player.white): 6085240277011882085,
    (Point(row=13, col=14), Player.black): 7232288043213468840,
    (Point(row=13, col=14), Player.white): 7184750073175391472,
    (Point(row=13, col=15), Player.black): 5538978544170635168,
    (Point(row=13, col=15), Player.white): 5040051481022218170,
    (Point(row=13, col=16), Player.black): 7600248111604936189,
    (Point(row=13, col=16), Player.white): 8586459941279921361,
    (Point(row=13, col=17), Player.black): 6376364379461146126,
    (Point(row=13, col=17), Player.white): 9041413789618077747,
    (Point(row=13, col=18), Player.black): 8743922232471933112,
    (Point(row=13, col=18), Player.white): 5535640510605537944,
    (Point(row=13, col=19), Player.black): 1558301110599766838,
    (Point(row=13, col=19), Player.white): 8797117176370424670,
    (Point(row=14, col=1), Player.black): 4949300437981660406,
    (Point(row=14, col=1), Player.white): 2708028507768117065,
    (Point(row=14, col=2), Player.black): 2399672244464706017,
    (Point(row=14, col=2), Player.white): 3131976715504297069,
    (Point(row=14, col=3), Player.black): 3762606156379134836,
    (Point(row=14, col=3), Player.white): 3018762030201101018,
    (Point(row=14, col=4), Player.black): 1553947336874678791,
    (Point(row=14, col=4), Player.white): 2787843346506338002,
    (Point(row=14, col=5), Player.black): 7286262140063523680,
    (Point(row=14, col=5), Player.white): 5601603084934773677,
    (Point(row=14, col=6), Player.black): 891332335745024460,
    (Point(row=14, col=6), Player.white): 146106825927916238,
    (Point(row=14, col=7), Player.black): 3003229725021016381,
    (Point(row=14, col=7), Player.white): 3950268400717425653,
    (Point(row=14, col=8), Player.black): 7555050606774196460,
    (Point(row=14, col=8), Player.white): 1933272040598731843,
    (Point(row=14, col=9), Player.black): 5687500800284396077,
    (Point(row=14, col=9), Player.white): 6064827202472988142,
    (Point(row=14, col=10), Player.black): 5460555210682265907,
    (Point(row=14, col=10), Player.white): 8523072018176531610,
    (Point(row=14, col=11), Player.black): 6969379328052051755,
    (Point(row=14, col=11), Player.white): 3797815354895629104,
    (Point(row=14, col=12), Player.black): 6102290786862268343,
    (Point(row=14, col=12), Player.white): 4451384081571233788,
    (Point(row=14, col=13), Player.black): 8867534245202963011,
    (Point(row=14, col=13), Player.white): 4062219166163574875,
    (Point(row=14, col=14), Player.black): 402600950333698869,
    (Point(row=14, col=14), Player.white): 4708727251635506003,
    (Point(row=14, col=15), Player.black): 685159067419537059,
    (Point(row=14, col=15), Player.white): 1831983877595190195,
    (Point(row=14, col=16), Player.black): 5693685234172893631,
    (Point(row=14, col=16), Player.white): 8606113477882278453,
    (Point(row=14, col=17), Player.black): 2559688780599510603,
    (Point(row=14, col=17), Player.white): 4916962426494787981,
    (Point(row=14, col=18), Player.black): 8239702277090843823,
    (Point(row=14, col=18), Player.white): 446427342887807507,
    (Point(row=14, col=19), Player.black): 8918191794051205877,
    (Point(row=14, col=19), Player.white): 5779966796775430704,
    (Point(row=15, col=1), Player.black): 5806602293025768267,
    (Point(row=15, col=1), Player.white): 4427894744050950930,
    (Point(row=15, col=2), Player.black): 3321210467146544466,
    (Point(row=15, col=2), Player.white): 1037996862123848054,
    (Point(row=15, col=3), Player.black): 2483790886051235894,
    (Point(row=15, col=3), Player.white): 865665247864907514,
    (Point(row=15, col=4), Player.black): 4709701331639358649,
    (Point(row=15, col=4), Player.white): 2520365484542978487,
    (Point(row=15, col=5), Player.black): 6726027909448868866,
    (Point(row=15, col=5), Player.white): 4334313633995489245,
    (Point(row=15, col=6), Player.black): 944638889554577315,
    (Point(row=15, col=6), Player.white): 1923788071090552491,
    (Point(row=15, col=7), Player.black): 2895976155820269644,
    (Point(row=15, col=7), Player.white): 7837513138217925255,
    (Point(row=15, col=8), Player.black): 7165664909107152910,
    (Point(row=15, col=8), Player.white): 928708421883905487,
    (Point(row=15, col=9), Player.black): 1694638656652867722,
    (Point(row=15, col=9), Player.white): 26294624292270423,
    (Point(row=15, col=10), Player.black): 2733509329805299637,
    (Point(row=15, col=10), Player.white): 1627504245446971116,
    (Point(row=15, col=11), Player.black): 5223103574740689621,
    (Point(row=15, col=11), Player.white): 7595236035274156841,
    (Point(row=15, col=12), Player.black): 7354618009789493265,
    (Point(row=15, col=12), Player.white): 3243165254510905460,
    (Point(row=15, col=13), Player.black): 9159748120463681730,
    (Point(row=15, col=13), Player.white): 43421204626817835,
    (Point(row=15, col=14), Player.black): 6664617357472592159,
    (Point(row=15, col=14), Player.white): 3484507983815796621,
    (Point(row=15, col=15), Player.black): 2064130300265745803,
    (Point(row=15, col=15), Player.white): 8586522619632753596,
    (Point(row=15, col=16), Player.black): 3231989606919965256,
    (Point(row=15, col=16), Player.white): 4410167066375254348,
    (Point(row=15, col=17), Player.black): 2229290957528159863,
    (Point(row=15, col=17), Player.white): 4955066102017269963,
    (Point(row=15, col=18), Player.black): 3077993148742326725,
    (Point(row=15, col=18), Player.white): 4617963469056641699,
    (Point(row=15, col=19), Player.black): 1535516273607427100,
    (Point(row=15, col=19), Player.white): 2587048028802906261,
    (Point(row=16, col=1), Player.black): 6693036795721593615,
    (Point(row=16, col=1), Player.white): 5797022974636371173,
    (Point(row=16, col=2), Player.black): 7309325935971511255,
    (Point(row=16, col=2), Player.white): 2048107833560567985,
    (Point(row=16, col=3), Player.black): 6916829147904847036,
    (Point(row=16, col=3), Player.white): 8976175690437770791,
    (Point(row=16, col=4), Player.black): 6337420236892083738,
    (Point(row=16, col=4), Player.white): 2699283977808103350,
    (Point(row=16, col=5), Player.black): 3138296515049670113,
    (Point(row=16, col=5), Player.white): 3900930764281472876,
    (Point(row=16, col=6), Player.black): 5243158829920910810,
    (Point(row=16, col=6), Player.white): 461920753856065231,
    (Point(row=16, col=7), Player.black): 2433567883755867339,
    (Point(row=16, col=7), Player.white): 8573361817426861282,
    (Point(row=16, col=8), Player.black): 1819618894604938,
    (Point(row=16, col=8), Player.white): 3991310233808842791,
    (Point(row=16, col=9), Player.black): 7711076169703051025,
    (Point(row=16, col=9), Player.white): 3853759001566152858,
    (Point(row=16, col=10), Player.black): 8507000970286247929,
    (Point(row=16, col=10), Player.white): 2224774790709653561,
    (Point(row=16, col=11), Player.black): 5257511071529447461,
    (Point(row=16, col=11), Player.white): 4568571098998923585,
    (Point(row=16, col=12), Player.black): 8391026675239904811,
    (Point(row=16, col=12), Player.white): 4478371109982441149,
    (Point(row=16, col=13), Player.black): 2592227198464777178,
    (Point(row=16, col=13), Player.white): 1373660215734540323,
    (Point(row=16, col=14), Player.black): 206163682928864712,
    (Point(row=16, col=14), Player.white): 4621022966321476065,
    (Point(row=16, col=15), Player.black): 1149238353779835973,
    (Point(row=16, col=15), Player.white): 8827492870593442684,
    (Point(row=16, col=16), Player.black): 69575139572224472,
    (Point(row=16, col=16), Player.white): 6441564680740325644,
    (Point(row=16, col=17), Player.black): 3089056130085564940,
    (Point(row=16, col=17), Player.white): 5952435253754032443,
    (Point(row=16, col=18), Player.black): 4737329103633574835,
    (Point(row=16, col=18), Player.white): 4343128122064860124,
    (Point(row=16, col=19), Player.black): 5628283739398868314,
    (Point(row=16, col=19), Player.white): 4022764529815418408,
    (Point(row=17, col=1), Player.black): 2799542350728883106,
    (Point(row=17, col=1), Player.white): 4210942142931533333,
    (Point(row=17, col=2), Player.black): 2339443054731527774,
    (Point(row=17, col=2), Player.white): 4928853112448883686,
    (Point(row=17, col=3), Player.black): 4417392374555331603,
    (Point(row=17, col=3), Player.white): 2507453166652721573,
    (Point(row=17, col=4), Player.black): 8507922001163833867,
    (Point(row=17, col=4), Player.white): 8998915572759011258,
    (Point(row=17, col=5), Player.black): 2594797118287840610,
    (Point(row=17, col=5), Player.white): 3515420103254932759,
    (Point(row=17, col=6), Player.black): 1028139137336838295,
    (Point(row=17, col=6), Player.white): 7711582185037469640,
    (Point(row=17, col=7), Player.black): 5659287222566492269,
    (Point(row=17, col=7), Player.white): 414963994641988728,
    (Point(row=17, col=8), Player.black): 1572352295549779740,
    (Point(row=17, col=8), Player.white): 8653964348753312095,
    (Point(row=17, col=9), Player.black): 1285644318249553858,
    (Point(row=17, col=9), Player.white): 8943714325861617477,
    (Point(row=17, col=10), Player.black): 6633177518122826638,
    (Point(row=17, col=10), Player.white): 5784574738634166660,
    (Point(row=17, col=11), Player.black): 8428433168480251645,
    (Point(row=17, col=11), Player.white): 8264794900765456725,
    (Point(row=17, col=12), Player.black): 1723187226414585962,
    (Point(row=17, col=12), Player.white): 7722974962003877798,
    (Point(row=17, col=13), Player.black): 8843550870650650496,
    (Point(row=17, col=13), Player.white): 7826916245428642595,
    (Point(row=17, col=14), Player.black): 563167361599960654,
    (Point(row=17, col=14), Player.white): 231523301481916518,
    (Point(row=17, col=15), Player.black): 4432773687497887257,
    (Point(row=17, col=15), Player.white): 3902615190797491444,
    (Point(row=17, col=16), Player.black): 3123967568891051730,
    (Point(row=17, col=16), Player.white): 8291618119520037897,
    (Point(row=17, col=17), Player.black): 2877564329571250234,
    (Point(row=17, col=17), Player.white): 7446782146894086930,
    (Point(row=17, col=18), Player.black): 5105815048566790300,
    (Point(row=17, col=18), Player.white): 620447092550959207,
    (Point(row=17, col=19), Player.black): 6815000078046116340,
    (Point(row=17, col=19), Player.white): 6151574574213299950,
    (Point(row=18, col=1), Player.black): 6571480042821441929,
    (Point(row=18, col=1), Player.white): 1199990554784515382,
    (Point(row=18, col=2), Player.black): 3796849565242642751,
    (Point(row=18, col=2), Player.white): 730627388821549642,
    (Point(row=18, col=3), Player.black): 1751528162981747310,
    (Point(row=18, col=3), Player.white): 9222944527743712463,
    (Point(row=18, col=4), Player.black): 1814376202105907330,
    (Point(row=18, col=4), Player.white): 4384078969581574670,
    (Point(row=18, col=5), Player.black): 3034085094723097573,
    (Point(row=18, col=5), Player.white): 6927961953433350838,
    (Point(row=18, col=6), Player.black): 6250120405369796891,
    (Point(row=18, col=6), Player.white): 7548122758175413070,
    (Point(row=18, col=7), Player.black): 3687503192185331146,
    (Point(row=18, col=7), Player.white): 3253751124438592240,
    (Point(row=18, col=8), Player.black): 7791409198905727389,
    (Point(row=18, col=8), Player.white): 8250324432776851912,
    (Point(row=18, col=9), Player.black): 1439293085965194443,
    (Point(row=18, col=9), Player.white): 617240904654191960,
    (Point(row=18, col=10), Player.black): 276630585824513566,
    (Point(row=18, col=10), Player.white): 8352247052107831969,
    (Point(row=18, col=11), Player.black): 661676028462227840,
    (Point(row=18, col=11), Player.white): 3657069844824784408,
    (Point(row=18, col=12), Player.black): 7088512060665037665,
    (Point(row=18, col=12), Player.white): 7416935933712485590,
    (Point(row=18, col=13), Player.black): 2318469128889327316,
    (Point(row=18, col=13), Player.white): 1249426624150088220,
    (Point(row=18, col=14), Player.black): 2267060888769604199,
    (Point(row=18, col=14), Player.white): 821882225288170376,
    (Point(row=18, col=15), Player.black): 3319741941550431328,
    (Point(row=18, col=15), Player.white): 2105230784839636855,
    (Point(row=18, col=16), Player.black): 2301242191607843604,
    (Point(row=18, col=16), Player.white): 5981497215870719022,
    (Point(row=18, col=17), Player.black): 98990970224676302,
    (Point(row=18, col=17), Player.white): 2178126089052175656,
    (Point(row=18, col=18), Player.black): 6449007811764407198,
    (Point(row=18, col=18), Player.white): 7805198354947116183,
    (Point(row=18, col=19), Player.black): 1041111610000268296,
    (Point(row=18, col=19), Player.white): 2441419540456318234,
    (Point(row=19, col=1), Player.black): 8628942558952620980,
    (Point(row=19, col=1), Player.white): 6503925635287259248,
    (Point(row=19, col=2), Player.black): 6548525928563290312,
    (Point(row=19, col=2), Player.white): 253622147386190057,
    (Point(row=19, col=3), Player.black): 7848312172346985629,
    (Point(row=19, col=3), Player.white): 5754420781802719712,
    (Point(row=19, col=4), Player.black): 1270488872375764706,
    (Point(row=19, col=4), Player.white): 7141894458017281044,
    (Point(row=19, col=5), Player.black): 5231528925201219548,
    (Point(row=19, col=5), Player.white): 2041343553201251536,
    (Point(row=19, col=6), Player.black): 644245936179045818,
    (Point(row=19, col=6), Player.white): 6113961115927663443,
    (Point(row=19, col=7), Player.black): 8089143129423172906,
    (Point(row=19, col=7), Player.white): 6427655425807228751,
    (Point(row=19, col=8), Player.black): 1706564856287370943,
    (Point(row=19, col=8), Player.white): 5478388786794426104,
    (Point(row=19, col=9), Player.black): 2602425697660956605,
    (Point(row=19, col=9), Player.white): 7624639913686172582,
    (Point(row=19, col=10), Player.black): 5745308943663018072,
    (Point(row=19, col=10), Player.white): 658288214186153350,
    (Point(row=19, col=11), Player.black): 4226459480241676938,
    (Point(row=19, col=11), Player.white): 168356440089928826,
    (Point(row=19, col=12), Player.black): 4111708675202059066,
    (Point(row=19, col=12), Player.white): 5723015285475349776,
    (Point(row=19, col=13), Player.black): 4374722413264218178,
    (Point(row=19, col=13), Player.white): 3074794750348349624,
    (Point(row=19, col=14), Player.black): 7210981354065565695,
    (Point(row=19, col=14), Player.white): 6369287181126618340,
    (Point(row=19, col=15), Player.black): 3939014091425425180,
    (Point(row=19, col=15), Player.white): 4450833838688552406,
    (Point(row=19, col=16), Player.black): 5408413225150116867,
    (Point(row=19, col=16), Player.white): 961401055240003418,
    (Point(row=19, col=17), Player.black): 6517207004494975912,
    (Point(row=19, col=17), Player.white): 2813539809172697790,
    (Point(row=19, col=18), Player.black): 1635317647869146951,
    (Point(row=19, col=18), Player.white): 6596044404321289826,
    (Point(row=19, col=19), Player.black): 4775892994827151184,
    (Point(row=19, col=19), Player.white): 2824421132400747346,
}

EMPTY_BOARD = 0

## 4 Plane Encoder

Original AlphaGo uses a 48 plane encoder, we use 4 plane encoder, which is also mentioned in the paper.

In [6]:
import numpy as np

"""
Feature name            num of planes   Description
Stone colour            3               Player stone / opponent stone / empty
Ones                    1               A constant plane filled with 1
"""

FEATURE_OFFSETS = {
    "stone_color": 0,
    "ones": 3,
    "current_player_color": 4,
    "legal_moves": 5
}


def offset(feature):
    return FEATURE_OFFSETS[feature]


class FourplaneEncoder:
    def __init__(self, board_size=(19, 19), use_player_plane=True, use_legal_moves=True):
        self.board_width, self.board_height = board_size
        self.use_player_plane = use_player_plane
        self.use_legal_moves = use_legal_moves
        self.num_planes = 4 + use_player_plane + use_legal_moves

    def name(self):
        return 'fourplane'

    def encode(self, game_state):
        board_tensor = np.zeros((self.num_planes, self.board_height, self.board_width))
        
        # Set empty cells plane (default to empty)
        board_tensor[offset("stone_color") + 2] = 1
        
        # Iterate over occupied points only (much faster for sparse boards)
        next_player = game_state.next_player
        opponent = next_player.other
        stone_color_offset = offset("stone_color")
        
        for point, go_string in game_state.board._grid.items():
            if go_string is None:
                continue
            r = point.row - 1
            c = point.col - 1
            if go_string.color == next_player:
                board_tensor[stone_color_offset][r][c] = 1
                board_tensor[stone_color_offset + 2][r][c] = 0  # Not empty
            elif go_string.color == opponent:
                board_tensor[stone_color_offset + 1][r][c] = 1
                board_tensor[stone_color_offset + 2][r][c] = 0  # Not empty
        
        # Set ones plane once (moved outside loop)
        board_tensor[offset("ones")] = 1
        
        # Set player plane once (moved outside loop)
        if self.use_player_plane and next_player == Player.black:
            board_tensor[offset("current_player_color")] = 1
        
        # Set legal moves - optimized inline computation to avoid expensive deep copies
        if self.use_legal_moves:
            if not game_state.is_over():
                legal_moves_offset = offset("legal_moves")
                board = game_state.board
                
                # Fast inline legal move checking without deep copies
                for row in range(1, board.num_rows + 1):
                    for col in range(1, board.num_cols + 1):
                        point = Point(row, col)
                        
                        # Fast check: point must be empty
                        if board._grid.get(point) is not None:
                            continue
                        
                        # Check self-capture without deep copy
                        has_liberty = False
                        would_capture = False
                        friendly_strings = []
                        
                        for neighbor in point.neighbors():
                            if not board.is_on_grid(neighbor):
                                continue
                            neighbor_string = board._grid.get(neighbor)
                            if neighbor_string is None:
                                has_liberty = True
                                break
                            elif neighbor_string.color == next_player:
                                friendly_strings.append(neighbor_string)
                            else:  # opponent string
                                if neighbor_string.num_liberties == 1:
                                    would_capture = True
                        
                        # If has liberty, not self-capture
                        if not has_liberty:
                            # Check if all friendly strings would have no liberties
                            if friendly_strings and all(s.num_liberties == 1 for s in friendly_strings):
                                if not would_capture:
                                    continue  # Self-capture, skip
                        
                        # Check ko violation - simplified heuristic for performance
                        # Full ko check requires deep copy, so we use a fast heuristic:
                        # If last move captured exactly one stone and we're trying to recapture
                        # at that same position, it's likely a ko violation
                        if would_capture and game_state.last_move and game_state.last_move.is_play:
                            # Check if we're trying to play at the last move position
                            # (which would be recapturing after a single-stone capture)
                            if point == game_state.last_move.point:
                                # Count how many stones we'd capture
                                captured_stones = 0
                                for neighbor in point.neighbors():
                                    if not board.is_on_grid(neighbor):
                                        continue
                                    neighbor_string = board._grid.get(neighbor)
                                    if neighbor_string and neighbor_string.color == opponent:
                                        if neighbor_string.num_liberties == 1:
                                            captured_stones += len(neighbor_string.stones)
                                # If capturing exactly one stone at last move position, likely ko
                                if captured_stones == 1:
                                    continue  # Likely ko violation, skip
                        
                        # Valid move - set it
                        r = row - 1
                        c = col - 1
                        board_tensor[legal_moves_offset][r][c] = 1

        return board_tensor

    def ones(self):
        return np.ones((1, self.board_height, self.board_width))

    def zeros(self):
        return np.zeros((1, self.board_height, self.board_width))

    def encode_point(self, point):
        return self.board_width * (point.row - 1) + (point.col - 1)

    def decode_point_index(self, index):
        row = index // self.board_width
        col = index % self.board_width
        return Point(row=row + 1, col=col + 1)

    def num_points(self):
        return self.board_width * self.board_height

    def shape(self):
        return self.num_planes, self.board_height, self.board_width


def create(board_size):
    return FourplaneEncoder(board_size)


## Supervised Learning Policy Network

 - Convolution layer with rectifier non-linearities.
 - CNN with 13 layers
 - Final softmax applied to legal moves
 - Data has been split into 4:1 train to test ratio. (In AlphaGo it is 28:1)
 - Pass moves has been excluded from the dataset
 - All 8 reflections and rotations has been applied and precomputed. We randomly sample mini batch from the augmented sample
 - Asynchronous stochastic gradient descent to minimize the log likelihood
 - Learning parameter is initialized to 0.03 and halved every 30k steps. (In AlphaGo, learning rate is 0.003 and halved every 80 mil training steps)
 - mini batch size of 16 as mentioned in the paper
 - zero momentum

In [8]:
import torch.nn as nn
import torch.nn.functional as F
import torch

class SLPolicyNetwork(nn.Module):
    def __init__(self, features=5, filters=192):
        super(SLPolicyNetwork, self).__init__()
        self.features = features
        self.filters = filters
        self.first_layer = nn.Conv2d(self.features, self.filters, kernel_size=5, stride=1, padding=2)

        self.hidden_layers = nn.ModuleList([
            nn.Conv2d(self.filters, self.filters, kernel_size=3, stride=1, padding=1)
            for _ in range(11)
        ])

        self.final_layer = nn.Conv2d(self.filters, 1, kernel_size=1, stride=1)

    def forward(self, x):
        # extract legal moves from the data tensor
        legal_moves = x[:, 5, :, :]
        x = F.relu(self.first_layer(x))
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        # Logits are the unnormalized scores output by the network for each possible move before softmax
        policy_logits = self.final_layer(x)  # (batch, 1, board_size, board_size)

        # apply legal move mask
        # Reshape and apply log_softmax for NLLLoss compatibility as per Alphago paper
        batch_size = policy_logits.size(0)
        policy_logits = policy_logits.view(batch_size, -1)
        legal_move_mask = legal_moves.view(batch_size, -1)
        # apply log softmax to only legal moves
        masked_policy_logits = torch.where(
            legal_move_mask.bool(),
            policy_logits,
            torch.full_like(policy_logits, float('-inf'))
        )
        log_probs = F.log_softmax(masked_policy_logits, dim=1)
        return log_probs

## Supervised Learning Policy Trainer

In [33]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

class SLPolicyTrainer:
    def __init__(self, model):
        # initialize hyperparams
        self.model = model
        self.optimizer = None
        self.criterion = None
        self.scheduler = None

    def initialize(self):
        self.optimizer = optim.SGD(
            params = self.model.parameters(),
            lr = 0.03
        )
        self.scheduler = StepLR(self.optimizer, step_size=30000, gamma=0.5)
        self.criterion = nn.NLLLoss()
        return
    def train(self, loader):
        self.model.train()
        total_loss = 0
        correct_predictions = 0
        total_sample = 0
        for (x, y) in loader:
            # Cast input to float32
            x = x.to(torch.float32)
            self.optimizer.zero_grad()
            log_probs = self.model(x)

            loss = self.criterion(log_probs, y)

            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            total_loss += loss.item()
            total_sample += y.size(0)
            prediction = log_probs.argmax(dim=1)
            correct_predictions += (prediction == y).sum().item()

        avg_loss = total_loss / len(loader)
        acc = correct_predictions / total_sample

        return avg_loss, acc
    def evaluate(self, loader):
        self.model.eval()
        total_loss = 0
        total_sample = 0
        correct_predictions = 0
        with torch.no_grad():
            for (x, y) in loader:
                # Cast input to float32
                x = x.to(torch.float32)
                log_probs = self.model(x)
                loss = self.criterion(log_probs, y)
                prediction = log_probs.argmax(dim=1)
                total_loss += loss.item()
                total_sample += y.size(0)
                correct_predictions += (prediction == y).sum().item()


        avg_loss = total_loss / len(loader)
        acc = correct_predictions / total_sample
        return avg_loss, acc

## DataLoader for Go

In [34]:
import numpy as np
import torch
import torch.utils.data as td

__all__ = [
    'GoDataLoader'
]

class GoDataLoader:
    def __init__(self, feature_path, label_path = None):
        self.feature_path = feature_path
        self.label_path = label_path

    def load_data(self):
        features = np.load(self.feature_path)
        features_tensor = torch.from_numpy(features)

        if self.label_path:
            labels = np.load(self.label_path)
            labels_tensor = torch.from_numpy(labels).to(torch.int64) # Cast labels to torch.int64
            # this dataset can be directly used with torch's DataLoader
            dataset = td.TensorDataset(features_tensor, labels_tensor)
        else:
            dataset = td.TensorDataset(features_tensor)
        return dataset

## Training Loop of Supervised Learning Policy

In [35]:
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
import torch.nn as nn

print("Load model")
model = SLPolicyNetwork(features=6)
num_of_epochs = 10

# prepare the dataset
print("fetch and load data")
training_feature_path = '/kaggle/input/alphago-kgs-200k/KGS-2019_04-19-1255-_train_features.npy'
training_label_path = '/kaggle/input/alphago-kgs-200k/KGS-2019_04-19-1255-_train_labels.npy'
test_feature_path = '/kaggle/input/alphago-kgs-200k/KGS-2019_04-19-1255-_test_features.npy'
test_label_path = '/kaggle/input/alphago-kgs-200k/KGS-2019_04-19-1255-_test_labels.npy'
training_dataset = GoDataLoader(training_feature_path, training_label_path).load_data()
test_dataset = GoDataLoader(test_feature_path, test_label_path).load_data()

# load the dataset to a loader
# this is what we will pass to our Trainer
training_loader = DataLoader(training_dataset, batch_size=16, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

# load the trainer
trainer = SLPolicyTrainer(model)
trainer.initialize()

for epoch in range(num_of_epochs):
    training_loss, training_acc = trainer.train(training_loader)
    test_loss, test_acc = trainer.evaluate(test_loader)
    print(f"epoch: {epoch}, training loss: {training_loss}, training acc: {training_acc}, test loss: {test_loss}, test_acc: {test_acc}")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': trainer.optimizer.state_dict(),
        'train_acc': training_acc,
        'test_acc': test_acc,
    }, f'sl_policy_epoch_{epoch}.pth')

Load model
fetch and load data
epoch: 0, training loss: 5.589934656644437, training acc: 0.007685738684884714, test loss: 5.406435972935444, test_acc: 0.007679180887372013
epoch: 1, training loss: 5.580736010554707, training acc: 0.008539709649871904, test loss: 5.383652764397699, test_acc: 0.017918088737201365
epoch: 2, training loss: 5.528502023260748, training acc: 0.01345004269854825, test loss: 5.363201173576149, test_acc: 0.012798634812286689
epoch: 3, training loss: 5.358327323666205, training acc: 0.024551665243381725, test loss: 5.363026000357963, test_acc: 0.011092150170648464
epoch: 4, training loss: 5.058352058658014, training acc: 0.035866780529461996, test loss: 5.2731228583567855, test_acc: 0.016211604095563138
epoch: 5, training loss: 4.863961750329965, training acc: 0.04099060631938514, test loss: 5.180677111084397, test_acc: 0.01962457337883959
epoch: 6, training loss: 4.668683294550144, training acc: 0.057643040136635355, test loss: 5.123422384262085, test_acc: 0.031

In [9]:
# check if gpu is available

import torch

print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

CUDA available: True
CUDA device count: 2
Current CUDA device: 0
CUDA device name: Tesla T4


## Supervised Learning Policy Agent

In [13]:
import torch

class SLPolicyAgent:
    def __init__(self, model):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.encoder = FourplaneEncoder()

    def select_move(self, game_state):
        assert isinstance(game_state, GameState)
        encoded_game_state = self.encoder.encode(game_state)
        x = torch.from_numpy(encoded_game_state)
        x = x.unsqueeze(0)
        x = x.to(torch.float32)
        x = x.to(self.device)
        log_probs = self.model(x)
        log_probs = log_probs.squeeze(0)
        legal_moves = game_state.legal_moves()
        legal_moves = legal_moves[:-2]
        if len(legal_moves) == 0:
            return None, float('-inf')

        legal_moves_mask = torch.zeros(log_probs.size(0), dtype=torch.bool, device=self.device)
        
        for move in legal_moves:
            if move is not None and move.is_play:
                encoded_point = self.encoder.encode_point(move.point)
                legal_moves_mask[encoded_point] = True

        log_probs = log_probs.masked_fill(~legal_moves_mask, float('-inf'))
        action = torch.argmax(log_probs)
        log_prob = log_probs[action]
        return action, log_prob
        

## Reinforcement learning of Policy Network

 - We use the same network as for Supervised Learning
 - We initialize the policy to be theta ( which is the supervised learning model)
 - We also use an opponent pool to randomly select opponent for mini batch training as mentioned in the paper
 - Also, using baseline default to 0 for the first pass. AlphaGo uses value network as baseline in the second pass
 - 10k episodes of 128 mini batch size
 - We add a model to the opponent pool after every 500 episodes
 - Used REINFORCE algorithm with stochastic gradient ascent for policy updates.

In [None]:
import torch.optim as optim
import torch
import random

class RLPolicyTrainer:
    def __init__(self, baseline = 0.0):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = SLPolicyNetwork(features=6).to(self.device)
        self.opponent_pool = []
        self.optimizer = None
        self.baseline = baseline
        self.episodes = 1
        self.mini_batch = 1
        self.board_size = 19
        self.encoder = FourplaneEncoder()

    def initialize(self):
        model_file = torch.load('/kaggle/input/sl-policy-epoch-9/pytorch/default/1/sl_policy_epoch_9.pth', map_location=self.device)
        self.model.load_state_dict(model_file['model_state_dict'])
        self.optimizer = optim.SGD(params=self.model.parameters(), lr=0.03)
        opponent_model = SLPolicyNetwork(features=6).to(self.device)
        opponent_model.load_state_dict(model_file['model_state_dict'])
        self.opponent_pool.append(opponent_model)

    def train_batch(self):
        opponent_model = random.choice(self.opponent_pool)
        opponent_model.to(self.device)
        opponent_model.eval()
        self.model.eval()

        streams = [torch.cuda.Stream(device=self.device) for _ in range(2)]
        
        agent = SLPolicyAgent(self.model)
        opponent = SLPolicyAgent(opponent_model)
        
        expected_reward = 0
        total_wins = 0
        trajectories_list = []
        rewards_list = []

        for i in range(self.mini_batch):
            print(f"starting game: {i+1}")

            stream = streams[i % 2]
            with torch.cuda.stream(stream):
                game_state = GameState.new_game(self.board_size)
                trajectory = []
                while not game_state.is_over():
                    if game_state.next_player == Player.black:
                        action, _ = agent.select_move(game_state)
                        if action is None:
                            # player does not have any legal moves left apart from resigning or pass
                            move = Move.resign()
                            point = None
                            break
                        point = self.encoder.decode_point_index(action.item())
                        trajectory.append((action, game_state))
                    else:
                        action, log_prob = opponent.select_move(game_state)
                        if action is None:
                            move = Move.resign()
                            point = None
                            break
                        point = self.encoder.decode_point_index(action.item())
                        
                        if point is not None:
                            move = Move(point)
                        game_state = game_state.apply_move(move)
                winner = game_state.winner()
                print(f"Winner: {winner}")
                reward = 1 if winner == Player.black else -1
                trajectories_list.append(trajectory)
                rewards_list.append(reward)

        torch.cuda.synchronize(device=self.device)

        for trajectory, reward in zip(trajectories_list, rewards_list):
            if reward == 1:
                total_wins += 1
            for (action, ) in trajectory:
                expected_reward += (reward - self.baseline) * log_prob

        expected_reward /= self.mini_batch
        win_rate = total_wins / self.mini_batch
        
        return expected_reward, win_rate

    def train(self):
        for i in range(self.episodes):
            print(f"starting episode: {i+1}")
            expected_reward, win_rate = self.train_batch()
            self.model.train()
            self.optimizer.zero_grad()
            expected_loss = -expected_reward
            expected_loss.backward()
            self.optimizer.step()

            torch.cuda.empty_cache()
            
            print(f"expected loss: {expected_loss}, accuracy: {win_rate}")
            
            if i % 500 == 0:
                opponent_model = SLPolicyNetwork(features=6)
                opponent_model.load_state_dict(self.model.state_dict())
                self.opponent_pool.append(opponent_model)
                
        
                    

rl_trainer = RLPolicyTrainer()
rl_trainer.initialize()
rl_trainer.train()
        
        
        

starting episode: 1
starting game: 1
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_38/1981055840.py", line 112, in <cell line: 0>
    rl_trainer.train()
  File "/tmp/ipykernel_38/1981055840.py", line 91, in train
    expected_reward, win_rate = self.train_batch()
                                ^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_38/1981055840.py", line 50, in train_batch
    action, _ = agent.select_move(game_state)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_38/730542582.py", line None, in select_move
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 2099, in showtraceback
    stb = value._render_traceback_()
          ^^^^^^^