# A replication of methods used in the AlphaGo paper "Mastering the game of Go with deep neural networks and tree search"

In [None]:
!pip install -q line_profiler
%load_ext line_profiler


## Game Environment

A Python-based game environment taken from https://github.com/maxpumperla/deep_learning_and_the_game_of_go


### Point, Player

In [None]:
import enum
from collections import namedtuple

class Point(namedtuple('Point', 'row col')):
    def neighbors(self):
        """Return the four orthogonally adjacent points: up, down, left, right.

        Off-board points are included; callers filter them with ``is_on_grid``.
        """
        r, c = self.row, self.col
        offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))
        return [Point(r + dr, c + dc) for dr, dc in offsets]
class Player(enum.Enum):
    """The two stone colours; ``value`` is 1 for black and 2 for white."""
    black = 1
    white = 2

    @property
    def other(self):
        """The opposing colour."""
        if self is Player.black:
            return Player.white
        return Player.black

### GameResult, Territory

In [None]:
from __future__ import absolute_import
from collections import namedtuple

class Territory(object):
    """Tally of stones, owned territory and neutral points from a status map.

    ``territory_map`` maps each point to ``Player.black`` / ``Player.white``
    (a stone), ``'territory_b'`` / ``'territory_w'`` (empty point owned by that
    colour) or ``'dame'`` (neutral point).  Other values are ignored.
    """

    def __init__(self, territory_map):
        occurrences = {}
        self.dame_points = []
        for point, status in territory_map.items():
            occurrences[status] = occurrences.get(status, 0) + 1
            if status == 'dame':
                # Neutral points are also listed individually.
                self.dame_points.append(point)
        self.num_black_territory = occurrences.get('territory_b', 0)
        self.num_white_territory = occurrences.get('territory_w', 0)
        self.num_black_stones = occurrences.get(Player.black, 0)
        self.num_white_stones = occurrences.get(Player.white, 0)
        self.num_dame = occurrences.get('dame', 0)

class GameResult(namedtuple('GameResult', 'b w komi')):
    """Final area score: black total ``b``, white total ``w`` and white's ``komi``."""

    @property
    def winner(self):
        """``Player.black`` / ``Player.white`` for the higher score, ``None`` on a tie."""
        margin = self.b - (self.w + self.komi)
        if margin > 0:
            return Player.black
        if margin < 0:
            return Player.white
        return None

    @property
    def winning_margin(self):
        """Absolute score difference after komi."""
        return abs(self.b - (self.w + self.komi))

    def __str__(self):
        # A tie is reported as a white result ('W+0.0').
        white_total = self.w + self.komi
        if self.b <= white_total:
            return 'W+%.1f' % (white_total - self.b,)
        return 'B+%.1f' % (self.b - white_total,)

def _collect_region(start_pos, board, visited=None):
    """Recursively flood-fill the region of equal ``board.get`` values around ``start_pos``.

    Two points belong to the same region when they are orthogonally adjacent
    and ``board.get`` returns the same value for both (``None`` for empty).

    Returns ``(points, borders)``: the region's points in depth-first order and
    the set of distinct ``board.get`` values found just outside it.  ``visited``
    is shared through the recursion; an already-visited start yields
    ``([], set())``.
    """
    if visited is None:
        visited = {}
    if start_pos in visited:
        return [], set()
    visited[start_pos] = True

    region = [start_pos]
    border_values = set()
    own_value = board.get(start_pos)
    for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        candidate = Point(row=start_pos.row + dr, col=start_pos.col + dc)
        if not board.is_on_grid(candidate):
            continue
        candidate_value = board.get(candidate)
        if candidate_value != own_value:
            border_values.add(candidate_value)
            continue
        sub_points, sub_borders = _collect_region(candidate, board, visited)
        region.extend(sub_points)
        border_values.update(sub_borders)
    return region, border_values

def evaluate_territory(board):
    """Classify every point of ``board`` as a stone, owned territory or dame.

    Stones map to their colour.  Each empty region is flood-filled; if all
    stones bordering it share one colour, the region becomes that colour's
    territory (``'territory_b'`` / ``'territory_w'``), otherwise ``'dame'``.
    Returns a :class:`Territory` built from the resulting status map.
    """
    status = {}
    for r in range(1, board.num_rows + 1):
        for c in range(1, board.num_cols + 1):
            p = Point(row=r, col=c)
            if p in status:
                continue
            stone = board.get(p)
            if stone is not None:
                status[p] = stone
                continue
            group, neighbors = _collect_region(p, board)
            if len(neighbors) == 1:
                neighbor_stone = neighbors.pop()
                stone_str = 'b' if neighbor_stone == Player.black else 'w'
                # Must match the status strings Territory counts
                # ('territory_b' / 'territory_w'), otherwise no territory is scored.
                fill_with = 'territory_' + stone_str
            else:
                # Bordered by both colours (or by none, e.g. an empty board).
                fill_with = 'dame'
            for pos in group:
                status[pos] = fill_with
    return Territory(status)

def compute_game_result(game_state):
    """Area-score ``game_state``: stones plus territory per colour, komi 7.5 for white."""
    counts = evaluate_territory(game_state.board)
    black_points = counts.num_black_territory + counts.num_black_stones
    white_points = counts.num_white_territory + counts.num_white_stones
    return GameResult(black_points, white_points, komi=7.5)

### GameBoard, GameState, Move, GoString

In [None]:
import copy

class Board():
    """Go board mapping each occupied ``Point`` (1-indexed) to its ``GoString``.

    Captured points are set to ``None`` in ``_grid`` rather than deleted.  A
    Zobrist hash of the position is maintained incrementally in ``_hash``.
    """

    def __init__(self, num_rows, num_cols):
        self.num_rows = num_rows
        self.num_cols = num_cols
        self._grid = {}
        self._hash = EMPTY_BOARD

    @staticmethod
    def _hash_code(point, color):
        """Return the Zobrist code for a stone of ``color`` at ``point``.

        ``HASH_CODE`` is keyed by the integer colour (1 / 2), so a ``Player``
        member is converted to its ``value``; indexing with the enum member
        itself raises ``KeyError`` because ``Player.black != 1``.
        """
        return HASH_CODE[point, getattr(color, 'value', color)]

    def _replace_string(self, new_string):
        """Point every stone of ``new_string`` at ``new_string`` (mutates ``_grid``)."""
        for point in new_string.stones:
            self._grid[point] = new_string

    def _remove_string(self, string):
        """Remove a captured string, giving its points back as liberties to neighbours."""
        for point in string.stones:
            for neighbor in point.neighbors():
                neighbor_string = self._grid.get(neighbor)
                if neighbor_string is None:
                    continue
                if neighbor_string is not string:
                    self._replace_string(neighbor_string.with_liberty(point))
            self._grid[point] = None
            self._hash ^= self._hash_code(point, string.color)

    def zobrist_hash(self):
        return self._hash

    def place_stone(self, player, point):
        """Place ``player``'s stone on empty ``point``, merging friends and capturing foes.

        Suicide and ko are not checked here; ``GameState.is_valid_move`` does that.
        """
        assert self.is_on_grid(point)
        assert self._grid.get(point) is None
        adjacent_same_color = []
        adjacent_opposite_color = []
        liberties = []
        for neighbor in point.neighbors():
            if not self.is_on_grid(neighbor):
                continue
            neighbor_string = self._grid.get(neighbor)
            if neighbor_string is None:
                liberties.append(neighbor)
            elif neighbor_string.color == player:
                if neighbor_string not in adjacent_same_color:
                    adjacent_same_color.append(neighbor_string)
            else:
                if neighbor_string not in adjacent_opposite_color:
                    adjacent_opposite_color.append(neighbor_string)
        new_string = GoString(player, [point], liberties)
        for same_color_string in adjacent_same_color:
            new_string = new_string.merged_with(same_color_string)
        self._replace_string(new_string)
        self._hash ^= self._hash_code(point, player)
        for other_color_string in adjacent_opposite_color:
            replacement = other_color_string.without_liberty(point)
            if replacement.num_liberties:
                self._replace_string(replacement)
            else:
                # No liberties left: the opponent string is captured.
                self._remove_string(other_color_string)

    def is_on_grid(self, point):
        return 1 <= point.row <= self.num_rows and \
            1 <= point.col <= self.num_cols

    def get(self, point):
        """Colour of the stone at ``point``, or ``None`` if it is empty."""
        string = self._grid.get(point)
        if string is None:
            return None
        return string.color

    def get_go_string(self, point):
        """The ``GoString`` occupying ``point``, or ``None`` if it is empty."""
        return self._grid.get(point)

class Move():
    """A single action: play a stone at ``point``, pass, or resign."""

    def __init__(self, point=None, is_pass=False, is_resign=False):
        # Exactly one kind of action must be requested.  A chained XOR
        # (a ^ b ^ c) would also accept all three being set, so count instead.
        assert sum((point is not None, bool(is_pass), bool(is_resign))) == 1
        self.point = point
        self.is_play = (self.point is not None)
        self.is_pass = is_pass
        self.is_resign = is_resign

    @classmethod
    def play(cls, point):
        """Move that places a stone at ``point``."""
        return Move(point=point)

    @classmethod
    def pass_turn(cls):
        """Move that passes."""
        return Move(is_pass=True)

    @classmethod
    def resign(cls):
        """Move that resigns the game."""
        return Move(is_resign=True)

# Chain of connected stones (used to e.g. efficiently check for liberties)
class GoString():
    """Immutable chain of connected same-coloured stones with its liberties.

    The ``with_*`` / ``without_*`` / ``merged_with`` helpers return new
    instances instead of mutating this one.
    """

    def __init__(self, color, stones, liberties):
        self.color = color
        self.stones = frozenset(stones)
        self.liberties = frozenset(liberties)

    def without_liberty(self, point):
        """Copy of this string with ``point`` removed from its liberties."""
        return GoString(self.color, self.stones, self.liberties.difference({point}))

    def with_liberty(self, point):
        """Copy of this string with ``point`` added to its liberties."""
        return GoString(self.color, self.stones, self.liberties.union({point}))

    def merged_with(self, go_string):
        """Union of two same-coloured strings; their own stones are not liberties."""
        assert go_string.color == self.color
        stones = self.stones.union(go_string.stones)
        liberties = self.liberties.union(go_string.liberties).difference(stones)
        return GoString(self.color, stones, liberties)

    @property
    def num_liberties(self):
        return len(self.liberties)

    def __eq__(self, other):
        if not isinstance(other, GoString):
            return False
        mine = (self.color, self.stones, self.liberties)
        theirs = (other.color, other.stones, other.liberties)
        return mine == theirs

class GameState():
    """Snapshot of a game: board, side to move, last move and previous state.

    ``previous_states`` is a frozenset of ``(next_player, zobrist_hash)`` pairs
    for every earlier state; :meth:`does_move_violate_ko` rejects a play that
    would recreate one of them.
    """

    def __init__(self, board, next_player, previous, move):
        self.board = board
        self.next_player = next_player
        self.previous_state = previous
        if previous is None:
            history = frozenset()
        else:
            earlier = (previous.next_player, previous.board.zobrist_hash())
            history = frozenset(previous.previous_states | {earlier})
        self.previous_states = history
        self.last_move = move

    def apply_move(self, move):
        """Return the state after the side to move plays ``move`` (this state is unchanged)."""
        if not move.is_play:
            # Pass / resign: the board object is shared, not copied.
            return GameState(self.board, self.next_player.other, self, move)
        board_after = copy.deepcopy(self.board)
        board_after.place_stone(self.next_player, move.point)
        return GameState(board_after, self.next_player.other, self, move)

    def is_over(self):
        """True after a resignation or two consecutive passes."""
        last = self.last_move
        if last is None:
            return False
        if last.is_resign:
            return True
        before_last = self.previous_state.last_move
        return before_last is not None and last.is_pass and before_last.is_pass

    def is_move_self_capture(self, player, move):
        """True if ``player`` playing ``move`` leaves its own string without liberties."""
        if not move.is_play:
            return False
        trial = copy.deepcopy(self.board)
        trial.place_stone(player, move.point)
        return trial.get_go_string(move.point).num_liberties == 0

    def is_valid_move(self, move):
        """Legality for the side to move: empty point, no suicide, no repeated situation."""
        if self.is_over():
            return False
        if move.is_pass or move.is_resign:
            return True
        if self.board.get(move.point) is not None:
            return False
        if self.is_move_self_capture(self.next_player, move):
            return False
        return not self.does_move_violate_ko(self.next_player, move)

    def legal_moves(self):
        """All legal plays in row-major order, then pass and resign; empty once over."""
        if self.is_over():
            return []
        candidates = (
            Move.play(Point(r, c))
            for r in range(1, self.board.num_rows + 1)
            for c in range(1, self.board.num_cols + 1))
        moves = [m for m in candidates if self.is_valid_move(m)]
        moves += [Move.pass_turn(), Move.resign()]
        return moves

    def winner(self):
        """Winning ``Player`` (``None`` if unfinished or tied)."""
        if not self.is_over():
            return None
        if self.last_move.is_resign:
            # The resigning player just moved, so the side to move wins.
            return self.next_player
        return compute_game_result(self).winner

    @classmethod
    def new_game(cls, board_size):
        """Fresh game with black to move; ``board_size`` is an int or ``(rows, cols)``."""
        if isinstance(board_size, int):
            board_size = (board_size, board_size)
        return GameState(Board(*board_size), Player.black, None, None)

    @property
    def situation(self):
        return (self.next_player, self.board)

    def does_move_violate_ko(self, player, move):
        """True if ``move`` recreates an earlier ``(next_player, board hash)`` situation."""
        if not move.is_play:
            return False
        trial = copy.deepcopy(self.board)
        trial.place_stone(player, move.point)
        return (player.other, trial.zobrist_hash()) in self.previous_states

## Zobrist Hash

In [None]:
import random

# Upper bound for the random 63-bit Zobrist codes.
MAX63 = 0x7fffffffffffffff

# Zobrist table: HASH_CODE[Point(row, col), colour] -> random code, for
# 1-indexed points of boards up to 19x19 and colours 1 (black) / 2 (white).
# The boards XOR these codes in and out as stones are placed and removed.
HASH_CODE = {}
# Hash of a board with no stones (identity element of XOR).
EMPTY_BOARD = 0

for row in range(1, 20):
    for col in range(1, 20):
        for state in (1, 2):
            code = random.randint(0, MAX63)
            HASH_CODE[Point(row, col), state] = code

# Dumping all 722 entries floods the notebook output; a summary is enough to
# confirm the table was built.
print(f"Generated {len(HASH_CODE)} Zobrist hash codes")

## Fast Game Environment

In [None]:
import numpy as np
from collections import defaultdict

class FastBoard:
    """Optimized board using numpy arrays.

    ``grid[row, col]`` is 0-indexed and holds 0 = empty, 1 = black, 2 = white.
    ``_hash`` XORs ``HASH_CODE[Point(row + 1, col + 1), colour]`` for every
    stone, i.e. the same table the 1-indexed ``Board`` uses.
    """

    # Orthogonal neighbour offsets: up, down, left, right.
    _DIRECTIONS = ((-1, 0), (1, 0), (0, -1), (0, 1))

    def __init__(self, num_rows, num_cols):
        self.num_rows = num_rows
        self.num_cols = num_cols
        # Use numpy arrays: 0=empty, 1=black, 2=white
        self.grid = np.zeros((num_rows, num_cols), dtype=np.int8)
        self._hash = 0

    def copy(self):
        """Return an independent copy; the grid array is copied, not shared."""
        new_board = FastBoard.__new__(FastBoard)
        new_board.num_rows = self.num_rows
        new_board.num_cols = self.num_cols
        new_board.grid = self.grid.copy()
        new_board._hash = self._hash
        return new_board

    def place_stone(self, player, row, col):
        """Put ``player``'s stone (1 or 2) at 0-indexed ``(row, col)`` and remove captures.

        No legality check is done here (occupancy, suicide, ko);
        ``FastGameState.is_valid_move`` tests that on a copy first.
        """
        self.grid[row, col] = player
        self._hash ^= HASH_CODE[Point(row + 1, col + 1), int(player)]

        # Any adjacent opponent group left without liberties is captured.
        opponent = 3 - player  # Toggles between 1 and 2
        for dr, dc in self._DIRECTIONS:
            nr, nc = row + dr, col + dc
            if self._is_on_grid(nr, nc) and self.grid[nr, nc] == opponent:
                if self._count_liberties(nr, nc) == 0:
                    self._remove_group(nr, nc)

    def _is_on_grid(self, row, col):
        return 0 <= row < self.num_rows and 0 <= col < self.num_cols

    def _count_liberties(self, row, col):
        """Return the number of distinct empty points adjacent to the group at ``(row, col)``.

        Returns 0 for an empty point.  An empty point touching several stones
        of the group is counted once.
        """
        color = self.grid[row, col]
        if color == 0:
            return 0

        visited = np.zeros_like(self.grid, dtype=bool)
        liberties = set()
        stack = [(row, col)]

        while stack:
            r, c = stack.pop()
            if visited[r, c]:
                continue
            visited[r, c] = True

            for dr, dc in self._DIRECTIONS:
                nr, nc = r + dr, c + dc
                if not self._is_on_grid(nr, nc):
                    continue
                neighbor = self.grid[nr, nc]
                if neighbor == 0:
                    liberties.add((nr, nc))
                elif neighbor == color and not visited[nr, nc]:
                    stack.append((nr, nc))

        return len(liberties)

    def _remove_group(self, row, col):
        """Remove the whole group containing ``(row, col)`` and update the hash."""
        color = int(self.grid[row, col])
        if color == 0:
            # Nothing to remove; HASH_CODE has no entries for empty points.
            return
        stack = [(row, col)]
        visited = set()

        while stack:
            r, c = stack.pop()
            if (r, c) in visited:
                continue
            visited.add((r, c))

            if not self._is_on_grid(r, c) or self.grid[r, c] != color:
                continue

            self.grid[r, c] = 0
            self._hash ^= HASH_CODE[Point(row=r + 1, col=c + 1), color]

            for dr, dc in self._DIRECTIONS:
                stack.append((r + dr, c + dc))

    def zobrist_hash(self):
        return self._hash

    def get(self, row, col):
        """Value at 0-indexed ``(row, col)``: 0 empty, 1 black, 2 white."""
        return self.grid[row, col]


class FastGameState:
    """Optimized game state over a :class:`FastBoard`.

    Players are ints (1 = black, 2 = white).  ``previous_hashes`` holds the
    Zobrist hashes of every earlier board in the game; :meth:`is_valid_move`
    rejects a play whose resulting board hash is among them.
    """
    def __init__(self, board, next_player, previous_hash=None, last_move=None, second_last_move=None):
        self.board = board
        self.next_player = next_player  # 1 for black, 2 for white
        self.previous_hashes = set()
        if previous_hash is not None:
            self.previous_hashes.add(previous_hash)
        self.last_move = last_move
        self.second_last_move = second_last_move

    def apply_move(self, move):
        """Return the state after the side to move plays ``move``.

        The board is always copied, so this state's board is never mutated.
        """
        new_board = self.board.copy()  # Fast numpy copy
        if move.is_play:
            point = move.point
            new_board.place_stone(self.next_player, point.row - 1, point.col - 1)

        current_hash = self.board.zobrist_hash()
        new_state = FastGameState(new_board, 3 - self.next_player,
                                  current_hash, move, self.last_move)
        # The new state's history is ours plus the board we are leaving.
        new_state.previous_hashes = self.previous_hashes.copy()
        new_state.previous_hashes.add(current_hash)

        return new_state

    def is_over(self):
        """True after a resignation or two consecutive passes."""
        if self.last_move is None:
            return False
        if self.last_move.is_resign:
            return True
        if self.second_last_move is None:
            return False
        return self.last_move.is_pass and self.second_last_move.is_pass

    def is_valid_move(self, move):
        """Check if move is valid: empty point, no suicide, no repeated board."""
        if self.is_over():
            return False
        if move.is_pass or move.is_resign:
            return True

        point = move.point
        row = point.row - 1
        col = point.col - 1

        if self.board.get(row, col) != 0:
            return False

        # Simulate the move on a copy (captures are applied by place_stone).
        test_board = self.board.copy()
        test_board.place_stone(self.next_player, row, col)

        # Self-capture: the placed stone's group has no liberties after captures.
        if test_board._count_liberties(row, col) == 0:
            return False

        # Ko / superko: the resulting board must not repeat an earlier one.
        if test_board.zobrist_hash() in self.previous_hashes:
            return False

        return True

    def legal_moves(self):
        """Return all legal plays as ``Move`` objects (row-major), then pass and resign."""
        legal = []
        for row in range(self.board.num_rows):
            for col in range(self.board.num_cols):
                move = Move.play(Point(row + 1, col + 1))
                if self.is_valid_move(move):
                    legal.append(move)
        legal.append(Move.pass_turn())
        legal.append(Move.resign())
        return legal

    def winner(self):
        """Return the winner once the game is over, else ``None``."""
        if not self.is_over():
            return None
        if self.last_move.is_resign:
            # The resigning player just moved, so the side to move wins.
            return self.next_player
        # Use the FastBoard-aware scorer: the slow ``compute_game_result`` calls
        # ``board.get(Point)``, which FastBoard (``get(row, col)``) does not support.
        return compute_game_result_fast(self).winner

    @classmethod
    def new_game(cls, board_size=19):
        board = FastBoard(board_size, board_size)
        return FastGameState(board, 1)  # Black plays first

In [None]:
class FastTerritory:
    """Fast territory evaluation results.

    Five counters that start at 0, plus an initially empty ``dame_points``
    list; ``__slots__`` keeps instances small and rejects unknown attributes.
    """
    __slots__ = ['num_black_territory', 'num_white_territory',
                 'num_black_stones', 'num_white_stones',
                 'num_dame', 'dame_points']

    def __init__(self):
        for counter in ('num_black_territory', 'num_white_territory',
                        'num_black_stones', 'num_white_stones', 'num_dame'):
            setattr(self, counter, 0)
        self.dame_points = []

class FastGameResult:
    """Area-scoring result using integer colours (1 = black, 2 = white, 0 = tie)."""
    __slots__ = ['b', 'w', 'komi']

    def __init__(self, b, w, komi=7.5):
        self.b, self.w, self.komi = b, w, komi

    @property
    def winner(self):
        """Return 1 for black, 2 for white, 0 for tie"""
        white_total = self.w + self.komi
        if self.b > white_total:
            return 1  # Black
        return 2 if self.b < white_total else 0

    @property
    def winning_margin(self):
        """Absolute score difference after komi."""
        return abs(self.b - (self.w + self.komi))

    def __str__(self):
        # A tie is reported as a white result ('W+0.0').
        white_total = self.w + self.komi
        if self.b <= white_total:
            return f'W+{white_total - self.b:.1f}'
        return f'B+{self.b - white_total:.1f}'

def evaluate_territory_fast(board):
    """Area-count a :class:`FastBoard` using numpy and iterative flood fill.

    Stones are counted directly from ``board.grid``.  Each empty region is
    credited to a colour when that colour is the only one bordering it and
    counted as dame otherwise (including regions bordered by nothing).

    Returns:
        FastTerritory with plain-int stone, territory and dame counts, and
        ``dame_points`` as 0-indexed ``(row, col)`` tuples.
    """
    grid = board.grid
    rows, cols = grid.shape

    # Status: 0=unvisited empty, 1=black stone, 2=white stone,
    # 3=black territory, 4=white territory, 5=dame
    status = np.zeros((rows, cols), dtype=np.int8)
    stones = grid > 0
    status[stones] = grid[stones]  # Copy stones

    # ``Territory`` needs a status map; the fast path fills the
    # slot-based FastTerritory counters directly instead.
    territory = FastTerritory()
    territory.num_black_stones = int(np.count_nonzero(grid == 1))
    territory.num_white_stones = int(np.count_nonzero(grid == 2))

    for r in range(rows):
        for c in range(cols):
            if status[r, c] != 0:  # Stone or already-classified empty point
                continue
            points, borders = _collect_region_fast(r, c, grid, status)

            if len(borders) == 1:
                # Single-colour border = territory of that colour
                owner = borders.pop()
                if owner == 1:
                    territory.num_black_territory += len(points)
                    fill_value = 3
                else:
                    territory.num_white_territory += len(points)
                    fill_value = 4
            else:
                # Both colours or no border = dame
                territory.num_dame += len(points)
                territory.dame_points.extend(points)
                fill_value = 5

            # Mark the region so it is not flood-filled again.
            for pr, pc in points:
                status[pr, pc] = fill_value

    return territory


def _collect_region_fast(start_r, start_c, grid, status):
    """Fast flood fill using stack instead of recursion
    
    Args:
        start_r, start_c: Starting position
        grid: Board grid (numpy array)
        status: Status tracking array
        
    Returns:
        (points, borders): List of (row, col) tuples and set of border colors
    """
    rows, cols = grid.shape
    points = []
    borders = set()
    
    # Stack-based flood fill (much faster than recursion)
    stack = [(start_r, start_c)]
    visited = np.zeros((rows, cols), dtype=bool)
    
    while stack:
        r, c = stack.pop()
        
        if visited[r, c]:
            continue
        visited[r, c] = True
        
        points.append((r, c))
        
        # Check all 4 neighbors
        for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
            nr, nc = r + dr, c + dc
            
            if not (0 <= nr < rows and 0 <= nc < cols):
                continue
            
            neighbor_value = grid[nr, nc]
            
            if neighbor_value == 0 and not visited[nr, nc]:
                # Empty space, continue flood fill
                stack.append((nr, nc))
            elif neighbor_value > 0:
                # Stone found, add to borders
                borders.add(neighbor_value)
    
    return points, borders


def compute_game_result_fast(game_state, komi=7.5):
    """Compute the final area score of a :class:`FastGameState`.

    Args:
        game_state: FastGameState object
        komi: Komi added to white's score (default 7.5)

    Returns:
        FastGameResult, whose ``winner`` uses the fast environment's integer
        colours (1 = black, 2 = white, 0 = tie) like ``FastGameState`` does,
        instead of the ``Player`` members returned by ``GameResult``.
    """
    territory = evaluate_territory_fast(game_state.board)

    black_score = territory.num_black_territory + territory.num_black_stones
    white_score = territory.num_white_territory + territory.num_white_stones

    return FastGameResult(black_score, white_score, komi)

In [None]:
import random
import time
class FastRandomBot:
    """Agent that picks a uniformly random move from ``game_state.legal_moves()``.

    With ``FastGameState`` that list includes pass and resign, so the bot may
    pass or resign on any turn.
    """

    def __init__(self, player):
        self.player = player

    def play_move(self, game_state):
        """Return ``(next_state, next_player)`` after playing a random legal move."""
        chosen = random.choice(game_state.legal_moves())
        return game_state.apply_move(chosen), 3 - self.player

def play_game():
    """Play one random-vs-random game on a fresh 19x19 board and return its winner."""
    game_state = FastGameState.new_game()
    bots = {1: FastRandomBot(player=1), 2: FastRandomBot(player=2)}
    to_move = 1
    while not game_state.is_over():
        game_state, to_move = bots[to_move].play_move(game_state)
    return game_state.winner()

# Benchmark: play 10 random self-play games, printing each winner.
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
for _ in range(10):
    winner = play_game()
    print(winner)
end = time.perf_counter()

print(f"time elapsed {end - start} seconds")


## 4 Plane Encoder

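The original AlphaGo uses a 48-plane encoder; we use a 4-plane encoder (three stone-colour planes plus a constant ones plane), which is also mentioned in the paper, optionally extended with a current-player plane and a legal-moves plane.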

In [None]:
import numpy as np

"""
Feature name            num of planes   Description
Stone colour            3               Player stone / opponent stone / empty
Ones                    1               A constant plane filled with 1
"""

FEATURE_OFFSETS = {
    "stone_color": 0,
    "ones": 3,
    "current_player_color": 4,
    "legal_moves": 5
}


def offset(feature):
    return FEATURE_OFFSETS[feature]

# TODO: I need to change the Four plane encoder, just because I have changed the environment
class FourplaneEncoder:
    """Encode a ``GameState`` as a ``(num_planes, height, width)`` feature tensor.

    Planes, in order: own stones, opponent stones, empty points, a constant
    ones plane, then optionally a "black to move" plane and a legal-moves plane.

    NOTE(review): this reads ``game_state.board._grid`` and ``Player`` members,
    so it works with ``GameState`` but not with ``FastGameState``.
    """

    def __init__(self, board_size=(19, 19), use_player_plane=True, use_legal_moves=True):
        self.board_width, self.board_height = board_size
        self.use_player_plane = use_player_plane
        self.use_legal_moves = use_legal_moves
        self.num_planes = 4 + use_player_plane + use_legal_moves
        # The legal-moves plane is always the last one: index 5
        # (FEATURE_OFFSETS["legal_moves"]) with the player plane, 4 without it,
        # so it never falls outside the tensor.
        self._legal_moves_plane = 4 + int(bool(use_player_plane))

    def name(self):
        return 'fourplane'

    def encode(self, game_state):
        """Return the feature tensor for the side to move in ``game_state``."""
        board_tensor = np.zeros((self.num_planes, self.board_height, self.board_width))
        next_player = game_state.next_player
        opponent = next_player.other
        stone_plane = offset("stone_color")

        # Every point starts as empty; occupied points are switched below.
        board_tensor[stone_plane + 2] = 1
        for point, go_string in game_state.board._grid.items():
            if go_string is None:  # captured point
                continue
            r, c = point.row - 1, point.col - 1
            if go_string.color == next_player:
                board_tensor[stone_plane, r, c] = 1
                board_tensor[stone_plane + 2, r, c] = 0
            elif go_string.color == opponent:
                board_tensor[stone_plane + 1, r, c] = 1
                board_tensor[stone_plane + 2, r, c] = 0

        board_tensor[offset("ones")] = 1

        if self.use_player_plane and next_player == Player.black:
            board_tensor[offset("current_player_color")] = 1

        if self.use_legal_moves and not game_state.is_over():
            legal_plane = board_tensor[self._legal_moves_plane]  # view into the tensor
            board = game_state.board
            for row in range(1, board.num_rows + 1):
                for col in range(1, board.num_cols + 1):
                    if self._is_legal_point(game_state, Point(row, col)):
                        legal_plane[row - 1, col - 1] = 1

        return board_tensor

    @staticmethod
    def _is_legal_point(game_state, point):
        """Cheap legality test for the side to move playing at ``point`` (no deep copy).

        Occupancy and suicide are decided exactly from string liberties.  Simple
        ko is detected when the move would recapture exactly the single stone
        just played; that candidate is confirmed with
        ``game_state.does_move_violate_ko`` when available.  Longer superko
        cycles are not detected.
        """
        board = game_state.board
        if board._grid.get(point) is not None:
            return False
        player = game_state.next_player

        has_empty_neighbor = False
        safe_friendly_neighbor = False
        captured = {}  # id(string) -> opponent string this move would capture
        for neighbor in point.neighbors():
            if not board.is_on_grid(neighbor):
                continue
            neighbor_string = board._grid.get(neighbor)
            if neighbor_string is None:
                has_empty_neighbor = True
            elif neighbor_string.color == player:
                # Joining a string that keeps another liberty is never suicide.
                if neighbor_string.num_liberties > 1:
                    safe_friendly_neighbor = True
            elif neighbor_string.num_liberties == 1:
                # Its only liberty is ``point``, so playing here captures it.
                captured[id(neighbor_string)] = neighbor_string

        if not (has_empty_neighbor or safe_friendly_neighbor or captured):
            return False  # self-capture

        last_move = game_state.last_move
        if len(captured) == 1 and last_move is not None and last_move.is_play:
            (victim,) = captured.values()
            if len(victim.stones) == 1 and last_move.point in victim.stones:
                violates_ko = getattr(game_state, 'does_move_violate_ko', None)
                if violates_ko is None:
                    return False  # assume ko when it cannot be verified
                return not violates_ko(player, Move.play(point))
        return True

    def ones(self):
        return np.ones((1, self.board_height, self.board_width))

    def zeros(self):
        return np.zeros((1, self.board_height, self.board_width))

    def encode_point(self, point):
        """Flatten a 1-indexed point to a row-major index."""
        return self.board_width * (point.row - 1) + (point.col - 1)

    def decode_point_index(self, index):
        """Inverse of :meth:`encode_point`."""
        row = index // self.board_width
        col = index % self.board_width
        return Point(row=row + 1, col=col + 1)

    def num_points(self):
        return self.board_width * self.board_height

    def shape(self):
        return self.num_planes, self.board_height, self.board_width


def create(board_size):
    """Factory: a FourplaneEncoder for ``board_size`` with its default options."""
    encoder = FourplaneEncoder(board_size)
    return encoder


## Supervised Learning Policy Network

 - Convolutional layers with rectifier (ReLU) non-linearities.
 - CNN with 13 layers
 - Final softmax applied to legal moves only
 - The data has been split into a 4:1 train-to-test ratio. (In AlphaGo it is 28:1)
 - Pass moves have been excluded from the dataset
 - All 8 reflections and rotations have been applied and precomputed. We randomly sample mini-batches from the augmented samples
 - Asynchronous stochastic gradient descent to maximize the log likelihood of the expert move (i.e. minimize the negative log-likelihood)
 - The learning rate is initialized to 0.03 and halved every 30k steps. (In AlphaGo, the learning rate is 0.003 and is halved every 80 million training steps)
 - Mini-batch size of 16, as in the paper
 - Zero momentum

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch

class SLPolicyNetwork(nn.Module):
    """13-layer convolutional policy network (AlphaGo SL policy).

    Architecture: one 5x5 convolution and eleven 3x3 convolutions (all
    followed by ReLU), then a 1x1 convolution producing one logit per board
    point. The logits are masked with a legal-move plane taken from the input
    and converted to log-probabilities, which is the input ``nn.NLLLoss``
    expects.

    Args:
        features: Number of input planes.
        filters: Number of convolution filters in every non-final layer.
        legal_moves_plane: Index of the input plane holding the legal-move
            mask (non-zero = legal). Defaults to the last plane,
            ``features - 1``. For the 6-plane input used in this notebook,
            that is plane 5, the index used previously.
    """

    def __init__(self, features=5, filters=192, legal_moves_plane=None):
        super(SLPolicyNetwork, self).__init__()
        self.features = features
        self.filters = filters
        # The old code hard-coded plane 5, which made the default features=5 unusable.
        self.legal_moves_plane = features - 1 if legal_moves_plane is None else legal_moves_plane
        self.first_layer = nn.Conv2d(self.features, self.filters, kernel_size=5, stride=1, padding=2)

        self.hidden_layers = nn.ModuleList([
            nn.Conv2d(self.filters, self.filters, kernel_size=3, stride=1, padding=1)
            for _ in range(11)
        ])

        self.final_layer = nn.Conv2d(self.filters, 1, kernel_size=1, stride=1)

    def forward(self, x):
        """Return log-probabilities over board points.

        Args:
            x: Float tensor of shape (batch, features, height, width).

        Returns:
            Tensor of shape (batch, height * width) holding log-probabilities.
            Illegal points are ``-inf``. If a position has no legal point at
            all, its row is the unmasked log-softmax, because masking every
            entry would produce NaN.
        """
        batch_size = x.size(0)
        # Read the legal-move mask before x is overwritten by the conv stack.
        legal_move_mask = x[:, self.legal_moves_plane, :, :].reshape(batch_size, -1).bool()

        x = F.relu(self.first_layer(x))
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        # Unnormalized per-point scores, flattened to (batch, height * width).
        policy_logits = self.final_layer(x).view(batch_size, -1)

        # log_softmax over an all -inf row is NaN; leave such rows unmasked.
        has_legal_move = legal_move_mask.any(dim=1, keepdim=True)
        effective_mask = legal_move_mask | ~has_legal_move
        masked_policy_logits = policy_logits.masked_fill(~effective_mask, float('-inf'))
        return F.log_softmax(masked_policy_logits, dim=1)

## Supervised Learning Policy Trainer

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

class SLPolicyTrainer:
    """Supervised trainer for a network that outputs log-probabilities.

    Call ``initialize()`` before ``train``. It sets up plain SGD (lr 0.03,
    no momentum), a StepLR schedule that halves the learning rate every
    30,000 optimizer steps, and NLL loss.
    """

    def __init__(self, model):
        # Hyperparameter objects are created lazily by initialize().
        self.model = model
        self.optimizer = None
        self.criterion = None
        self.scheduler = None

    def initialize(self):
        """Create the optimizer, the per-step LR scheduler and the loss."""
        self.optimizer = optim.SGD(
            params = self.model.parameters(),
            lr = 0.03
        )
        # scheduler.step() is called once per batch, so step_size counts batches.
        self.scheduler = StepLR(self.optimizer, step_size=30000, gamma=0.5)
        self.criterion = nn.NLLLoss()

    def _device(self):
        """Return the device of the model's parameters (CPU if it has none)."""
        param = next(self.model.parameters(), None)
        return param.device if param is not None else torch.device('cpu')

    def _run_epoch(self, loader, training):
        """Iterate over ``loader`` once and return (mean batch loss, accuracy).

        Each batch is an ``(x, y)`` pair. ``x`` is cast to float32, and both
        tensors are moved to the model's device so that a CUDA model works
        too. Returns ``(nan, nan)`` for an empty loader.
        """
        device = self._device()
        total_loss = 0.0
        num_batches = 0
        total_sample = 0
        correct_predictions = 0
        for x, y in loader:
            x = x.to(device=device, dtype=torch.float32)
            y = y.to(device)
            if training:
                self.optimizer.zero_grad()
            log_probs = self.model(x)
            loss = self.criterion(log_probs, y)
            if training:
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()

            total_loss += loss.item()
            num_batches += 1
            total_sample += y.size(0)
            correct_predictions += (log_probs.argmax(dim=1) == y).sum().item()

        if num_batches == 0:
            return float('nan'), float('nan')
        return total_loss / num_batches, correct_predictions / total_sample

    def train(self, loader):
        """Run one training epoch; return (mean batch loss, accuracy)."""
        self.model.train()
        return self._run_epoch(loader, training=True)

    def evaluate(self, loader):
        """Evaluate without gradients; return (mean batch loss, accuracy)."""
        self.model.eval()
        with torch.no_grad():
            return self._run_epoch(loader, training=False)

## DataLoader for Go

In [None]:
import numpy as np
import torch
import torch.utils.data as td

__all__ = [
    'GoDataLoader'
]

class GoDataLoader:
    """Wrap precomputed ``.npy`` feature arrays (plus optional labels) as a TensorDataset.

    Args:
        feature_path: Path to the ``.npy`` file of encoded board features.
        label_path: Optional path to the ``.npy`` file of move labels. When
            it is falsy, the dataset yields feature tensors only.
    """

    def __init__(self, feature_path, label_path = None):
        self.feature_path = feature_path
        self.label_path = label_path

    def load_data(self):
        """Load the arrays and return a ``td.TensorDataset`` usable by DataLoader."""
        columns = [torch.from_numpy(np.load(self.feature_path))]
        if self.label_path:
            raw_labels = np.load(self.label_path)
            # NLLLoss wants int64 class indices.
            columns.append(torch.from_numpy(raw_labels).to(torch.int64))
        return td.TensorDataset(*columns)

## Training Loop of Supervised Learning Policy

In [None]:
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
import torch.nn as nn

# Train the SL policy network on the precomputed KGS dataset and checkpoint every epoch.
print("Load model")
model = SLPolicyNetwork(features=6)
num_of_epochs = 10

# prepare the dataset
print("fetch and load data")
training_feature_path = '/kaggle/input/alphago-kgs-200k/KGS-2019_04-19-1255-_train_features.npy'
training_label_path = '/kaggle/input/alphago-kgs-200k/KGS-2019_04-19-1255-_train_labels.npy'
test_feature_path = '/kaggle/input/alphago-kgs-200k/KGS-2019_04-19-1255-_test_features.npy'
test_label_path = '/kaggle/input/alphago-kgs-200k/KGS-2019_04-19-1255-_test_labels.npy'
training_dataset = GoDataLoader(training_feature_path, training_label_path).load_data()
test_dataset = GoDataLoader(test_feature_path, test_label_path).load_data()

# Wrap the datasets in DataLoaders; these are what the trainer consumes.
# Mini-batch size 16 matches the paper.
training_loader = DataLoader(training_dataset, batch_size=16, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

# load the trainer
trainer = SLPolicyTrainer(model)
trainer.initialize()

for epoch in range(num_of_epochs):
    training_loss, training_acc = trainer.train(training_loader)
    test_loss, test_acc = trainer.evaluate(test_loader)
    print(f"epoch: {epoch}, training loss: {training_loss}, training acc: {training_acc}, test loss: {test_loss}, test_acc: {test_acc}")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': trainer.optimizer.state_dict(),
        # Without this, a resumed run would restart the step-wise LR decay at lr=0.03.
        'scheduler_state_dict': trainer.scheduler.state_dict(),
        'train_acc': training_acc,
        'test_acc': test_acc,
    }, f'sl_policy_epoch_{epoch}.pth')

In [None]:
# check if gpu is available

import torch

# Report the CUDA setup visible to PyTorch; device details only when CUDA exists.
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if cuda_ok:
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

## Supervised Learning Policy Agent

In [None]:
import torch

class SLPolicyAgent:
    """Greedy agent that plays the highest-scoring legal move of a policy network.

    ``select_move`` runs without ``torch.no_grad()``, so the returned
    log-probability keeps its autograd history. The RL trainer relies on this
    for REINFORCE updates.
    """

    def __init__(self, model):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.encoder = FourplaneEncoder()

    def select_move(self, game_state):
        """Pick a board move for ``game_state``.

        Returns:
            ``(action, log_prob)``. ``action`` is a 0-dim tensor holding the
            encoded point index (decode it with
            ``encoder.decode_point_index(action.item())``). ``log_prob`` is
            the model's log-probability for that point. Returns
            ``(None, float('-inf'))`` when there is no legal board placement,
            i.e. only pass or resign remain.
        """
        assert isinstance(game_state, GameState)
        # Filter to placements explicitly. Slicing off the last two entries
        # assumed pass/resign are always listed last.
        play_moves = [
            move for move in game_state.legal_moves()
            if move is not None and move.is_play
        ]
        if not play_moves:
            return None, float('-inf')

        x = torch.from_numpy(self.encoder.encode(game_state))
        x = x.unsqueeze(0).to(device=self.device, dtype=torch.float32)
        log_probs = self.model(x).squeeze(0)

        legal_indices = torch.tensor(
            [self.encoder.encode_point(move.point) for move in play_moves],
            dtype=torch.long,
            device=self.device,
        )
        legal_moves_mask = torch.zeros(log_probs.size(0), dtype=torch.bool, device=self.device)
        legal_moves_mask[legal_indices] = True
        log_probs = log_probs.masked_fill(~legal_moves_mask, float('-inf'))

        if torch.isfinite(log_probs).any():
            action = torch.argmax(log_probs)
        else:
            # The model gave every legal point zero probability (or NaN). In
            # that case argmax would return index 0, which may be illegal, so
            # fall back to the first legal point.
            action = legal_indices[0]
        log_prob = log_probs[action]
        return action, log_prob
        

## Reinforcement learning of Policy Network

 - We use the same network architecture as for supervised learning.
 - The RL policy weights ρ are initialized to the supervised learning policy weights σ (ρ = σ).
 - As in the paper, we keep an opponent pool and randomly sample an opponent from it for each mini-batch.
 - The baseline defaults to 0, as in AlphaGo's first training pass; AlphaGo used the value network as the baseline in its second pass.
 - AlphaGo trains for 10k mini-batches of 128 games each (the code below defaults to 1 mini-batch of 1 game for quick testing).
 - The current model is added to the opponent pool every 500 mini-batches (episodes).
 - Policy updates use the REINFORCE algorithm with stochastic gradient ascent on the expected reward.

In [None]:
import torch.optim as optim
import torch
import random

class RLPolicyTrainer:
    """REINFORCE self-play trainer for the policy network (AlphaGo RL stage).

    The learning policy always plays black against an opponent sampled from
    ``opponent_pool``. A copy of the learner is added to the pool every
    ``snapshot_interval`` updates.

    Args:
        baseline: Constant subtracted from each game's reward (0 in AlphaGo's
            first pass).
        episodes: Number of policy updates (mini-batches) ``train`` performs.
        mini_batch: Number of games played per update.
        snapshot_interval: Add the current learner to the opponent pool after
            every this many updates.
    """

    def __init__(self, baseline = 0.0, episodes=1, mini_batch=1, snapshot_interval=500):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = SLPolicyNetwork(features=6).to(self.device)
        self.opponent_pool = []
        self.optimizer = None
        self.baseline = baseline
        self.episodes = episodes
        self.mini_batch = mini_batch
        self.snapshot_interval = snapshot_interval
        self.board_size = 19
        self.encoder = FourplaneEncoder()

    def initialize(self, checkpoint_path='/kaggle/input/sl-policy-epoch-9/pytorch/default/1/sl_policy_epoch_9.pth'):
        """Load the SL checkpoint into the learner and seed the opponent pool with it."""
        model_file = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(model_file['model_state_dict'])
        self.optimizer = optim.SGD(params=self.model.parameters(), lr=0.03)
        opponent_model = SLPolicyNetwork(features=6).to(self.device)
        opponent_model.load_state_dict(model_file['model_state_dict'])
        self.opponent_pool.append(opponent_model)

    def _play_game(self, agent, opponent):
        """Play one game with ``agent`` as black and ``opponent`` as white.

        Returns:
            ``(log_probs, reward)``. ``log_probs`` holds the learner's
            log-probabilities (with autograd history) for each of its moves.
            ``reward`` is +1 if black wins, -1 otherwise.
        """
        game_state = GameState.new_game(self.board_size)
        log_probs = []
        while not game_state.is_over():
            agent_to_move = game_state.next_player == Player.black
            if agent_to_move:
                action, log_prob = agent.select_move(game_state)
            else:
                # The opponent is never trained; skip building its autograd graph.
                with torch.no_grad():
                    action, _ = opponent.select_move(game_state)
            if action is None:
                # No board placement left: resign, so the game has a defined
                # winner. NOTE(review): assumes dlgo semantics, where a resign
                # ends the game and winner() is the non-resigning player.
                game_state = game_state.apply_move(Move.resign())
                break
            if agent_to_move:
                log_probs.append(log_prob)
            point = self.encoder.decode_point_index(action.item())
            # Apply the move for BOTH colours; previously only white's was applied.
            game_state = game_state.apply_move(Move(point))

        winner = game_state.winner()
        print(f"Winner: {winner}")
        reward = 1 if winner == Player.black else -1
        return log_probs, reward

    def train_batch(self):
        """Play ``mini_batch`` games against one sampled opponent.

        Returns:
            ``(expected_reward, win_rate)``. ``expected_reward`` is a scalar
            tensor equal to sum over games and moves of
            (reward - baseline) * log_prob, divided by ``mini_batch``.
        """
        opponent_model = random.choice(self.opponent_pool)
        opponent_model.to(self.device)
        opponent_model.eval()
        self.model.eval()

        agent = SLPolicyAgent(self.model)
        opponent = SLPolicyAgent(opponent_model)

        expected_reward = torch.zeros((), device=self.device)
        total_wins = 0
        for i in range(self.mini_batch):
            print(f"starting game: {i+1}")
            log_probs, reward = self._play_game(agent, opponent)
            if reward == 1:
                total_wins += 1
            for log_prob in log_probs:
                # A -inf/NaN log-prob (zero-probability fallback move) would poison the update.
                if torch.isfinite(log_prob):
                    expected_reward = expected_reward + (reward - self.baseline) * log_prob

        expected_reward = expected_reward / self.mini_batch
        win_rate = total_wins / self.mini_batch
        return expected_reward, win_rate

    def train(self):
        """Run ``episodes`` REINFORCE updates (gradient ascent on expected reward)."""
        for i in range(self.episodes):
            print(f"starting episode: {i+1}")
            expected_reward, win_rate = self.train_batch()
            self.model.train()
            self.optimizer.zero_grad()
            expected_loss = -expected_reward
            # With no finite learner log-prob there is no graph to backprop through.
            if expected_loss.requires_grad:
                expected_loss.backward()
                self.optimizer.step()

            torch.cuda.empty_cache()

            print(f"expected loss: {expected_loss.item()}, accuracy: {win_rate}")

            # Snapshot after every `snapshot_interval` completed updates (not at update 0).
            if (i + 1) % self.snapshot_interval == 0:
                opponent_model = SLPolicyNetwork(features=6).to(self.device)
                opponent_model.load_state_dict(self.model.state_dict())
                opponent_model.eval()
                self.opponent_pool.append(opponent_model)
        
                    

# Build the RL trainer with a zero baseline, warm-start it from the SL
# checkpoint, then run self-play REINFORCE.
rl_trainer = RLPolicyTrainer(baseline=0.0)
rl_trainer.initialize()
rl_trainer.train()
        
        
        