# ConnectX - Monte Carlo Tree Search

Monte Carlo Tree Search using an object-oriented tree with NumPy bitboard bit-shifting

In [None]:
# %%writefile submission.py

In [None]:
%%writefile submission.py
#!/usr/bin/env python3

##### 
##### ../../kaggle_compile.py agents/MontyCarlo/MontyCarloPure.py
##### 
##### 2020-08-26 16:45:21+01:00
##### 
##### origin	git@github.com:JamesMcGuigan/ai-games.git (fetch)
##### origin	git@github.com:JamesMcGuigan/ai-games.git (push)
##### 
##### * master 247327a [ahead 6] ConnectX | reduce safety_time to 0.25s
##### 
##### 247327afa97dfaa0c87ea36321e7be3deaa9d8d4
##### 

#####
##### START core/ConnectXBBNN.py
#####

# This is a functional implementation of ConnectX that has been optimized using both numpy and numba

from collections import namedtuple
from typing import List
from typing import Tuple
from typing import Union
from numba import njit, int8, int64

import numba
import numpy as np

# Hardcode for simplicity
# observation   = {'mark': 1, 'board': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
# configuration = {'columns': 7, 'rows': 6, 'inarow': 4, 'steps': 1000, 'timeout': 8}

# Numba type descriptor for a bitboard: a 1-D int64 array of length 2
bitboard_type = numba.typeof(np.ndarray((2,), dtype=np.int64))

# Board dimensions are fixed at module level; the functions below read `configuration` as a global
Configuration = namedtuple('configuration', ['rows', 'columns', 'inarow'])
configuration = Configuration(rows=6, columns=7, inarow=4)



### Conversions

def cast_configuration(configuration):
    """Copy the ``rows``, ``columns`` and ``inarow`` attributes of *configuration* into a ``Configuration`` namedtuple."""
    rows, columns, inarow = configuration.rows, configuration.columns, configuration.inarow
    return Configuration(rows, columns, inarow)


def is_bitboard(bitboard) -> bool:
    """Return True only when *bitboard* is a numpy int64 array of shape ``(2,)``."""
    if not isinstance(bitboard, np.ndarray):
        return False
    return bool(bitboard.dtype == np.int64 and bitboard.shape == (2,))

#@njit
def list_to_bitboard(listboard: Union[np.ndarray,List[int]]) -> np.ndarray:
    """
    Pack a board of per-cell values (0 = empty, 1 = player 1, 2 = player 2) into a bitboard.

    bitboard[0] = played mask: bit n is set when cell n is non-empty
    bitboard[1] = player mask: bit n is set when cell n equals 2; player 1 cells stay 0
    A numpy input of any shape is flattened first; cell n is the n-th flattened element.
    """
    cells = listboard.flatten() if isinstance(listboard, np.ndarray) else listboard
    played_bits = 0
    player_bits = 0
    for position, cell in enumerate(cells):
        if cell == 0:
            continue
        bit = 1 << position
        played_bits |= bit
        if cell == 2:
            player_bits |= bit
    return np.array([played_bits, player_bits], dtype=np.int64)


@njit(int8[:,:](int64[:]))
def bitboard_to_numpy2d(bitboard: np.ndarray) -> np.ndarray:
    """ Unpack a bitboard into a (rows, columns) int8 grid: 0 = empty, 1 = player 1, 2 = player 2 """
    global configuration
    rows    = configuration.rows
    columns = configuration.columns
    cells   = np.zeros((rows * columns,), dtype=np.int8)
    for position in range(rows * columns):
        if (bitboard[0] >> position) & 1:          # square has been played
            if (bitboard[1] >> position) & 1:      # player bit set -> player 2
                cells[position] = 2
            else:
                cells[position] = 1
    return cells.reshape((rows, columns))


### Bitboard Operations

@njit
def empty_bitboard() -> np.ndarray:
    """ Return a fresh bitboard with no squares played: int64 array [played=0, player=0] """
    return np.array([0, 0], dtype=np.int64)


def bitboard_from_actions(actions: List[Union[int, Tuple[int]]]) -> np.ndarray:
    """
    Replay a sequence of moves from an empty board and return the resulting bitboard.

    Each entry is either a column index (players alternate, starting with player 1)
    or an ``(action, player_id)`` tuple which overrides the player for that move;
    alternation then continues from the overridden player.
    """
    board = empty_bitboard()
    mover = 1
    for move in actions:
        if isinstance(move, tuple):
            column, mover = move
        else:
            column = move
        # result_action() only distinguishes 1 from not-1, so `% 2` maps player 2 onto 0
        board = result_action(board, column, player_id=mover % 2)
        mover = next_player_id(mover)
    return board


@njit
def hash_bitboard( bitboard: np.ndarray ) -> Tuple[int,int]:
    """ Canonical hashable key for a position: the lexicographically smaller of the board and its left-right mirror """
    played = bitboard[0]
    player = bitboard[1]
    if played == 0:
        return ( played, player )                  # returned unchanged without computing the mirror

    mirror_played = mirror_bitstring(played)
    if played < mirror_played:
        return ( played, player )

    mirror_player = mirror_bitstring(player)       # only computed once the played masks cannot decide
    if played == mirror_played and player <= mirror_player:
        return ( played, player )
    return ( mirror_played, mirror_player )


# Lookup table mapping every possible row bit pattern to its left-right mirror image,
# built by reversing the zero-padded binary string, e.g. for 7 columns: mirror_bits[ 0b0100000 ] == 0b0000010
# The padding width follows configuration.columns (previously hard-coded to 7) so that the
# table stays consistent with its own size for any board width.
mirror_bits = np.array([
    int( f'{n:0{configuration.columns}b}'[::-1], 2 )
    for n in range(2**configuration.columns)
], dtype=np.int64)

@njit
def mirror_bitstring( bitstring: int ) -> int:
    """
    Return the left-right mirror of a board bitstring, used for hashing:  0100000 -> 0000010

    The bitstring is treated as `configuration.rows` consecutive groups of `configuration.columns` bits.
    Each group is extracted, reversed via the `mirror_bits` lookup table, and shifted back to its
    original offset, so rows keep their position and only the column order is flipped.
    """
    global configuration

    if bitstring == 0:
        return 0  # short-circuit for empty board

    bitsize     = configuration.columns * configuration.rows        # total number of bits to process
    unit_size   = configuration.columns                             # size of each row in bits
    unit_mask   = (1 << unit_size) - 1                              # == 0b1111111 | 0x7f
    offsets     = np.arange(0, bitsize, unit_size, dtype=np.int64)  # == [ 0, 7, 14, 21, 28, 35 ]

    # Step-by-step form of the one-liner below:
    # row_masks   = unit_mask               << offsets  # create bitmasks for each row
    # bits        = (bitstring & row_masks) >> offsets  # extract out the bits for each row
    # stib        = mirror_bits[ bits ]     << offsets  # lookup mirror bits for each row and shift back into position
    # output      = np.sum(stib)                        # rows occupy disjoint bit ranges, so np.sum() acts as a bitwise OR

    # This can technically be done as a one liner:
    output = np.sum( mirror_bits[ (bitstring & (unit_mask << offsets)) >> offsets ] << offsets )

    ### Old Loop Implementation
    # output = 0
    # for row in range(configuration.rows):
    #     offset = row * configuration.columns
    #     mask   = unit_mask          << offset
    #     bits   = (bitstring & mask) >> offset
    #     if bits == 0: continue
    #     stib   = mirror_bits[ bits ]
    #     output = output | (stib << offset)

    return int(output)


@njit
def mirror_bitboard( bitboard: np.ndarray ) -> np.ndarray:
    """ Return a new bitboard with both the played mask and the player mask mirrored left-right """
    played = mirror_bitstring(bitboard[0])
    player = mirror_bitstring(bitboard[1])
    return np.array([ played, player ], dtype=bitboard.dtype)



### Player Id

@njit
def current_player_id( bitboard: np.ndarray ) -> int:
    """ Id of the player whose turn it is: 1 = p1, 2 = p2 (p1 moves on even move numbers, including the empty board) """
    if get_move_number(bitboard) % 2 == 0:
        return 1
    return 2

def current_player_index( bitboard: np.ndarray ) -> int:
    """ Zero-based index of the player whose turn it is: 0 = p1, 1 = p2 (p1 moves on even move numbers) """
    if get_move_number(bitboard) % 2 == 0:
        return 0
    return 1


@njit(int8(int8))
def next_player_id(player_id: int) -> int:
    """ Swap players: 2 -> 1, and any other value (normally 1) -> 2 """
    # assert player_id in [1,2]
    if player_id == 2:
        return 1
    return 2



### Coordinates

@njit
def index_to_coords(index: int) -> Tuple[int,int]:
    """ Convert a flat square index into (row, column), inverse of coords_to_index() """
    global configuration
    columns = configuration.columns
    row     = index // columns
    return (row, index - row * columns)


@njit
def coords_to_index(row: int, column: int) -> int:
    """ Convert (row, column) into the flat square index: row * columns + column """
    global configuration
    return row * configuration.columns + column



### Moves

@njit(int64[:](int8))
def get_bitcount_mask(size: int = configuration.columns * configuration.rows) -> np.ndarray:
    """ Return the int64 array of single-bit masks [1<<0, 1<<1, ..., 1<<(size-1)] """
    positions = np.arange(0, size, dtype=np.int64)
    return 1 << positions

# bitcount_mask = get_bitcount_mask()


@njit(int8(int64[:]))
def get_move_number(bitboard: np.ndarray) -> int:
    """ Number of tokens on the board: counts the set bits of the played mask bitboard[0] """
    global configuration
    played = bitboard[0]
    if played == 0: return 0                       # short-circuit for the empty board
    single_bits = get_bitcount_mask(configuration.columns * configuration.rows)
    return np.count_nonzero(played & single_bits)  # one non-zero entry per set bit


# Bitmask with one bit set for every square on the board (the lowest rows * columns bits)
mask_board       = (1 << (configuration.rows * configuration.columns)) - 1
# Bitmask for the lowest `columns` bits only, i.e. the top row that decides which columns are still playable
mask_legal_moves = (1 << configuration.columns) - 1

@njit
def has_no_illegal_moves( bitboard: np.ndarray ) -> int:
    """ Return 1 when every top-row square is still empty (all columns playable), else 0 """
    if (bitboard[0] & mask_legal_moves) == 0:
        return 1
    return 0


@njit
def has_no_more_moves(bitboard: np.ndarray) -> bool:
    """ True when every top-row square has been played, i.e. no column can accept another token """
    top_row = bitboard[0] & mask_legal_moves
    return top_row == mask_legal_moves


# Mask selecting the top-row bits of the played mask (one bit per column)
_is_legal_move_mask  = (1 << configuration.columns) - 1
# Lookup table: _is_legal_move_cache[top_row_bits, action] == 1 when that column's top square is empty
_is_legal_move_cache = np.array([
    [ 0 if (bits >> action) & 1 else 1 for action in range(configuration.columns) ]
    for bits in range(1 << configuration.columns)
], dtype=np.int8)

@njit
def is_legal_move(bitboard: np.ndarray, action: int) -> int:
    """ Return 1 if a token can still be dropped into column `action`, else 0 (table lookup on the top row bits) """
    # The precomputed table replaces: int( (bitboard[0] >> action) & 1 == 0 )
    # NOTE: [bits,action] indexing is faster than [bits][action]
    return _is_legal_move_cache[bitboard[0] & _is_legal_move_mask, action]

#@njit
def get_legal_moves(bitboard: np.ndarray) -> np.ndarray:
    """
    Return the array of columns that can still accept a token.

    The lowest `columns` bits of the played mask are the top row; a column is playable while its top bit is 0.
    When the whole top row is empty the shared module-level `actions` array (int64) is returned as-is,
    otherwise a new int8 array is built - callers should treat the result as read-only.
    """
    top_row = bitboard[0] & _is_legal_move_mask
    if top_row == 0:
        return actions  # get_all_moves()
    playable = [
        column
        for column, is_open in enumerate(_is_legal_move_cache[top_row])
        if is_open
    ]
    return np.array(playable, dtype=np.int8)


# Every column index [0 .. columns-1] as int64; returned directly by get_all_moves() and get_legal_moves()
actions = np.array(list(range(configuration.columns)), dtype=np.int64)
@njit
def get_all_moves() -> np.ndarray:
    """ Return the module-level `actions` array of every column index, without checking legality """
    return actions
    # Earlier implementation that rebuilt the array on every call:
    # global configuration
    # return np.array([ action for action in range(configuration.columns) ])


@njit
def get_random_move(bitboard: np.ndarray) -> int:
    """
    Pick a uniformly random legal column by rejection sampling.
    This is slightly quicker than random.choice(get_all_moves()).
    NOTE: never returns if the board has no legal moves - callers must check first.
    """
    # assert not has_no_more_moves(bitboard)
    global configuration
    columns = configuration.columns
    action  = np.random.randint(0, columns)
    while not is_legal_move(bitboard, action):
        action = np.random.randint(0, columns)
    return action



# Actions + Results

@njit
def get_next_index(bitboard: np.ndarray, action: int) -> int:
    """ Return the flat index of the square a token dropped into column `action` would land on """
    global configuration
    # assert is_legal_move(bitboard, action)
    columns = configuration.columns

    # Scan from the last row (the ground) towards row 0 and stop at the first unplayed square
    row = configuration.rows - 1
    while row >= 0:
        index = action + (row * columns)
        if ((bitboard[0] >> index) & 1) == 0:
            return index
        row -= 1
    return action  # column is full: this should never happen - implies not is_legal_move(action)

@njit
def get_next_row(bitboard: np.ndarray, action: int) -> int:
    """ Return the row a token dropped into column `action` would land on """
    global configuration
    return get_next_index(bitboard, action) // configuration.columns


@njit
def result_action(bitboard: np.ndarray, action: int, player_id: int) -> np.ndarray:
    """
    Return a NEW bitboard with a token for `player_id` dropped into column `action`.
    The input bitboard is not modified. Any player_id other than 1 is recorded as player 2.
    """
    # assert is_legal_move(bitboard, action)
    index  = get_next_index(bitboard, action)
    mark   = 0 if player_id == 1 else 1            # player bit: 0 = p1, 1 = p2
    played = bitboard[0] | (1 << index)
    player = bitboard[1] | (mark << index)
    return np.array([ played, player ], dtype=bitboard.dtype)


### Simulations

#@njit
def run_random_simulation( bitboard: np.ndarray, player_id: int ) -> float:
    """
    Play uniformly random legal moves until the game ends, then score the final position for `player_id`.
    Returns +1 = victory | 0.5 = draw | 0 = loss
    """
    mover = current_player_id(bitboard)            # derived from the move count: player 1 moves first on an empty board
    board = bitboard
    while not is_gameover(board):
        column = np.random.choice(get_legal_moves(board))
        board  = result_action(board, column, mover)
        mover  = next_player_id(mover)
        # print( bitboard_to_numpy2d(board) )  # DEBUG
    return get_utility_zero_one(board, player_id)


### Endgame

@njit(int64[:]())
def get_gameovers() -> np.ndarray:
    """
    Creates an int64 array of all winning board positions, over 4 directions: horizontal, vertical and 2 diagonals.
    Each entry is a bitmask with exactly `inarow` bits set, one per square of a winning line.
    """
    global configuration

    rows    = configuration.rows
    columns = configuration.columns
    inarow  = configuration.inarow

    gameovers = []

    # Build one template mask per direction, anchored at square index 0
    mask_horizontal  = 0
    mask_vertical    = 0
    mask_diagonal_dl = 0
    mask_diagonal_ul = 0
    for n in range(inarow):  # prange
        mask_horizontal  |= 1 << n                               # n squares along the row
        mask_vertical    |= 1 << n * columns                     # n rows further on, same column
        mask_diagonal_dl |= 1 << n * columns + n                 # column increases with the row
        mask_diagonal_ul |= 1 << n * columns + (inarow - 1 - n)  # column decreases with the row

    # Largest anchor row/column at which a line of length `inarow` still fits on the board
    row_inner = rows    - inarow
    col_inner = columns - inarow
    for row in range(rows):  # prange
        for col in range(columns):  # prange
            offset = col + row * columns  # shift the template so that it is anchored at (row, col)
            if col <= col_inner:
                gameovers.append( mask_horizontal << offset )
            if row <= row_inner:
                gameovers.append( mask_vertical << offset )
            if col <= col_inner and row <= row_inner:
                # both diagonals span the same inarow x inarow window anchored at (row, col)
                gameovers.append( mask_diagonal_dl << offset )
                gameovers.append( mask_diagonal_ul << offset )

    _get_gameovers_cache = np.array(gameovers, dtype=np.int64)
    return _get_gameovers_cache

# Compute the win masks once at import time; get_winner() reads this module-level array
gameovers = get_gameovers()


@njit
def is_gameover(bitboard: np.ndarray) -> bool:
    """ True when the top row is full (no moves left) or either player has a winning line """
    return has_no_more_moves(bitboard) or get_winner(bitboard) != 0


@njit
def get_winner(bitboard: np.ndarray) -> int:
    """
    Endgame get_winner: 0 for no get_winner, 1 = player 1, 2 = player 2

    A player wins when all bits of any mask in `gameovers` are covered by that player's tokens.
    Player 2 is tested first, so 2 is returned if both players somehow hold a line.
    """
    global gameovers
    # gameovers = get_gameovers()
    p2_tokens = bitboard[0] &  bitboard[1]         # played squares whose player bit is set
    if np.any((p2_tokens & gameovers) == gameovers):
        return 2
    p1_tokens = bitboard[0] & ~bitboard[1]         # played squares whose player bit is clear
    if np.any((p1_tokens & gameovers) == gameovers):
        return 1
    return 0

    # NOTE: testing each player's tokens directly (as above) measured 2x faster than the original
    #       attempt, which first filtered `gameovers` down to the fully played lines and then
    #       compared that subset against bitboard[1] for each player.


### Utility Scores

@njit
def get_utility_one(bitboard: np.ndarray, player_id: int) -> int:
    """ Score the position for player_id: 1 for victory, -1 for loss, 0 when nobody has won """
    winner = get_winner(bitboard)
    if winner == 0:
        return 0
    if winner == player_id:
        return 1
    return -1


@njit
def get_utility_zero_one(bitboard: np.ndarray, player_id: int) -> float:
    """ Score the position for player_id: 1 for victory, 0 for loss, 0.5 when nobody has won """
    winner = get_winner(bitboard)
    if winner == 0:
        return 0.5
    if winner == player_id:
        return 1.0
    return 0.0


@njit
def get_utility_inf(bitboard: np.ndarray, player_id: int) -> float:
    """ Score the position for player_id: +inf for victory, -inf for loss, 0 when nobody has won """
    winner = get_winner(bitboard)
    if winner == 0:
        return 0
    if winner == player_id:
        return +np.inf
    return -np.inf


#####
##### END   core/ConnectXBBNN.py
#####

#####
##### START util/base64_file.py
#####

import base64
import gzip
import os
import pickle
import re
import time
from typing import Any, Union
import humanize

# _base64_file__test_base64_static_import = """
# H4sIAPx9LF8C/2tgri1k0IjgYGBgKCxNLS7JzM8rZIwtZNLwZvBm8mYEkjAI4jFB2KkRbED1iXnF
# 5alFhczeWqV6AEGfwmBHAAAA
# """


def base64_file_varname(filename: str) -> str:
    """
    Derive the global variable name under which a file's base64 payload is embedded.
    ../data/AntColonyTreeSearchNode.pickle.zip.base64 -> _base64_file__AntColonyTreeSearchNode__pickle__zip__base64
    """
    basename = re.sub(r'^.*/', '', filename)          # drop everything up to the last '/'
    suffix   = re.sub(r'[.\W]+', '__', basename)      # each run of dots / non-word characters becomes __
    return "_base64_file__" + suffix


def base64_file_var_wrap(base64_data: Union[str,bytes], varname: str) -> str:
    """
    Wrap base64 text as Python source:  varname = triple-quoted string containing the stripped data.

    `base64_data` may be str or bytes; bytes are decoded as UTF-8 first, because formatting
    bytes directly into the f-string would embed their repr (b'...') instead of the data.
    """
    if isinstance(base64_data, (bytes, bytearray)):
        base64_data = bytes(base64_data).decode('utf8')
    return f'{varname} = """\n{base64_data.strip()}\n"""'                    # add varname = """\n\n""" wrapper


def base64_file_var_unwrap(base64_data: str) -> str:
    """ Strip an optional leading `varname = triple-quote` and trailing triple-quote, returning the bare payload """
    stripped = base64_data.strip()
    payload  = re.sub(r'^\w+ = """|"""$', '', stripped)  # remove varname = """ """ wrapper
    return payload.strip()


def base64_file_encode(data: Any) -> str:
    """ Serialize `data` as pickle -> gzip -> base64 and return the base64 text without surrounding whitespace """
    compressed = gzip.compress(pickle.dumps(data))
    return base64.encodebytes(compressed).decode('utf8').strip()


def base64_file_decode(encoded: str) -> Any:
    """
    Inverse of the pickle -> gzip -> base64 encoding: decode base64, gunzip, then unpickle.
    WARNING: pickle.loads() can execute arbitrary code - only decode data from a trusted source.
    """
    compressed = base64.b64decode(encoded)
    return pickle.loads(gzip.decompress(compressed))


def base64_file_save(data: Any, filename: str, vebose=True) -> float:
    """
        Saves a base64 encoded version of data into filename, with a varname wrapper for importing via kaggle_compile.py
        # Doesn't create a global variable; only refreshes it if one with the derived name already exists.
        Returns filesize in bytes, or 0.0 if saving failed (the exception is printed, not raised)

        data:     any picklable object
        filename: destination path; missing parent directories are created
        vebose:   when truthy, print the file size and time taken
    """
    varname    = base64_file_varname(filename)
    start_time = time.perf_counter()
    try:
        dirname = os.path.dirname(filename)
        if dirname:  # os.makedirs('') raises FileNotFoundError for filenames without a directory part
            os.makedirs(dirname, exist_ok=True)

        # Encode BEFORE opening the file: open(..., 'wb') truncates, so a pickling failure
        # after the open would otherwise destroy the previously saved file
        encoded = base64_file_encode(data)
        output  = base64_file_var_wrap(encoded, varname).encode('utf8')
        with open(filename, 'wb') as file:
            file.write(output)
        if varname in globals(): globals()[varname] = encoded  # globals not shared between modules, but update for safety

        filesize = os.path.getsize(filename)
        if vebose:
            time_taken = time.perf_counter() - start_time
            print(f"base64_file_save(): {filename:40s} | {humanize.naturalsize(filesize)} in {time_taken:4.1f}s")
        return filesize
    except Exception as exception:
        # Saving is best-effort, but report the failure rather than hiding it
        print(f'base64_file_save({filename}): Exception:', exception)
    return 0.0


def base64_file_load(filename: str, vebose=True) -> Union[Any,None]:
    """
        Performs a lookup to see if the global variable for this file already exists
        If not, reads the base64 encoded file from filename, with an optional varname wrapper
        # Doesn't create/update global variable.
        Returns decoded data, or None if nothing could be loaded (exceptions are printed, not raised)

        WARNING: the payload is unpickled by base64_file_decode() - only load trusted files.
    """
    varname    = base64_file_varname(filename)
    start_time = time.perf_counter()
    try:
        # Hard-coding PyTorch weights into a script - https://www.kaggle.com/c/connectx/discussion/126678
        encoded = None

        if varname in globals():
            encoded = globals()[varname]

        if encoded is None and os.path.exists(filename):
            with open(filename, 'rb') as file:
                encoded = file.read().decode('utf8')
                encoded = base64_file_var_unwrap(encoded)
                # globals()[varname] = encoded  # globals are not shared between modules

        if encoded is not None:
            data = base64_file_decode(encoded)

            if vebose:
                # When the payload came from an embedded global the file may not exist on disk;
                # os.path.getsize() would then raise and discard the successfully decoded data
                filesize   = os.path.getsize(filename) if os.path.exists(filename) else len(encoded)
                time_taken = time.perf_counter() - start_time
                print(f"base64_file_load(): {filename:40s} | {humanize.naturalsize(filesize)} in {time_taken:4.1f}s")
            return data
    except Exception as exception:
        print(f'base64_file_load({filename}): Exception:', exception)
    return None


#####
##### END   util/base64_file.py
#####

#####
##### START agents/MontyCarlo/MontyCarloPure.py
#####

# This is a LinkedList implementation of MontyCarlo Tree Search
# Inspired by https://www.kaggle.com/matant/monte-carlo-tree-search-connectx
import atexit
import time
from struct import Struct
from typing import Callable

# from core.ConnectXBBNN import *
# from util.base64_file import base64_file_load
# from util.base64_file import base64_file_save

# Placeholder namedtuple type with no fields; not referenced by the code visible in this file
Hyperparameters = namedtuple('hyperparameters', [])

class MontyCarloNode:
    """
    One node of a Monte Carlo Tree Search (MCTS) tree for ConnectX.

    Each node stores a bitboard position, the id of the player to move from it, links to its
    parent and per-action children, and running visit/score totals. `total_score` is accumulated
    from the point of view of the player who moved INTO the node (the parent's player), so a
    parent can compare its children's scores directly.

    Class-level state:
        persist    - when truthy, agent() loads a saved tree on first use and registers save() at exit
        save_node  - save_node[cls.__name__] holds the tree root that load()/save() persist
        root_nodes - root_nodes[observation.mark] holds the node most recently searched for that player
    """
    persist   = True
    save_node = {}                                                        # save_node[cls.__name__] = cls(empty_bitboard(), 1)
    root_nodes: List[Union['MontyCarloNode', None]] = [None, None, None]  # root_nodes[observation.mark]

    def __init__(
            self,
            bitboard:      np.ndarray,
            player_id:     int,
            parent:        Union['MontyCarloNode', None] = None,
            parent_action: Union[int,None]       = None,
            exploration:   float = 1.0,
            **kwargs
    ):
        """
        bitboard:      position represented by this node
        player_id:     player to move from this position (1 or 2)
        parent:        node this one was expanded from, or None for a root
        parent_action: column played in `parent` to reach this node
        exploration:   weight of the exploration term in get_ucb1_score()
        kwargs:        stored and forwarded unchanged to every child created by create_child()
        """
        self.bitboard      = bitboard
        self.player_id     = player_id
        self.next_player   = 3 - player_id

        self.exploration   = exploration
        self.kwargs        = kwargs

        # self.mirror_hash   = hash_bitboard(bitboard)  # BUG: using mirror hashes causes get_best_action() to return invalid moves
        self.legal_moves   = get_legal_moves(bitboard)
        self.is_gameover   = is_gameover(bitboard)
        self.winner        = get_winner(bitboard) if self.is_gameover else 0

        # Scores in range 0-1, from the perspective of self.player_id (the player to move).
        # A finished game without a winner is a draw and scores 0.5, matching simulate();
        # scoring it 0 would backpropagate as a full win for the player who made the last move.
        if self.is_gameover and self.winner == 0:
            self.utility   = 0.5
        else:
            self.utility   = 1 if self.winner == self.player_id else 0

        self.parent        = parent
        self.parent_action = parent_action
        self.is_expanded   = False
        self.children: List[Union['MontyCarloNode', None]] = [None for action in get_all_moves()]  # include illegal moves to preserve indexing
        self.total_score   = 0.0
        self.total_visits  = 0
        self.ucb1_score    = self.get_ucb1_score(self)



    def __hash__(self):
        """ Hash of the bitboard contents. __hash__ must return an int, so the tuple itself is hashed. """
        return hash(tuple(int(bits) for bits in self.bitboard))
        # return self.mirror_hash  # BUG: using mirror hashes causes get_best_action() to return invalid moves



    ### Loading and Saving

    @classmethod
    def prune(cls, node: 'MontyCarloNode', min_visits=7, pruned_count=0, total_count=0):
        """
        Recursively drop every child subtree with fewer than `min_visits` visits (modifies the tree in place).
        Returns (pruned_count, total_count): running estimates of removed nodes and of all nodes seen.
        """
        for n, child in enumerate(node.children):
            if child is None: continue
            if child.total_visits < min_visits:
                pruned_count    += child.total_visits  # excepting terminal states, this equals the number of grandchildren
                total_count     += child.total_visits  # excepting terminal states, this equals the number of grandchildren
                node.children[n] = None
                node.is_expanded = False  # Use def expand(self) to reinitialize state
            else:
                total_count += 1
                pruned_count, total_count = cls.prune(child, min_visits, pruned_count, total_count)
        return pruned_count, total_count


    @classmethod
    def filename(cls):
        """ Path of the persisted tree for this class """
        return f"data/{cls.__name__}_base64.py"

    @classmethod
    def load(cls):
        """ Load the persisted tree into cls.save_node; returns the loaded root, or None if nothing was loaded """
        filename = cls.filename()
        loaded   = base64_file_load(filename)
        if loaded is not None:
            cls.save_node[cls.__name__] = loaded
            return loaded
        else:
            return None


    @classmethod
    def save(cls) -> Union[str,None]:
        """ Prune and persist cls.save_node[cls.__name__]; returns the filename, or None when there is nothing to save """
        save_node = cls.save_node.get(cls.__name__, None)
        if cls.persist and save_node is not None:
            start_time   = time.perf_counter()
            pruned_count, total_count = cls.prune(save_node)  # This reduces a 47MB base64 file down to 5Mb
            print(f'{cls.__name__}.save() - pruned {pruned_count:.0f}/{total_count:.0f} nodes leaving {total_count-pruned_count:.0f} in {time.perf_counter() - start_time:.2f}s')

            filename = cls.filename()
            base64_file_save(save_node, filename)
            return filename
        return None

    ### Constructors and Lookups

    def create_child( self, action: int ) -> 'MontyCarloNode':
        """ Create the child reached by playing `action`, attach it, and refresh is_expanded """
        result = result_action(self.bitboard, action, self.player_id)
        child  = None  # self.find_mirror_child(result, depth=1)  # BUG: using mirror hashes causes get_best_action() to return invalid moves
        if child is None:
            child = self.__class__(
                bitboard      = result,
                player_id     = next_player_id(self.player_id),
                parent        = self,
                parent_action = action,
                exploration   = self.exploration,
                **self.kwargs
            )
        self.children[action] = child
        self.is_expanded      = self._is_expanded()
        if self.is_expanded:
            self.on_expanded()
        return child


    def find_child( self, bitboard: np.array, depth=2 ) -> Union['MontyCarloNode', None]:
        """
        Search this node (depth 0), its children (depth 1) and its grandchildren (depth 2)
        for a node whose bitboard equals `bitboard`. Returns the node, or None if not found.
        """
        # assert 0 <= depth <= 2

        if depth >= 0:
            if np.all( self.bitboard == bitboard ):
                return self
        if depth >= 1:
            for child in self.children:
                if child is None: continue
                if np.all( child.bitboard == bitboard ):
                    return child
        if depth >= 2:
            # Explicit nested loop rather than recursion: only two levels are ever searched
            for child in self.children:
                if child is None: continue
                for grandchild in child.children:
                    if grandchild is None: continue
                    if np.all( grandchild.bitboard == bitboard ):
                        return grandchild
        return None

    # NOTE: a find_mirror_child() lookup keyed on hash_bitboard() used to live here.
    #       It is disabled because sharing nodes between mirrored positions caused
    #       get_best_action() to return invalid moves (actions are not mirrored back).



    ### Properties

    def _is_expanded(self) -> bool:
        """ True when every legal move already has a child node """
        is_expanded = True
        for action in self.legal_moves:
            if self.children[action] is None:
                is_expanded = False
                break
        return is_expanded


    def get_unexpanded(self) -> List[int]:
        """ Legal moves that do not have a child node yet """
        return [
            action
            for action in self.legal_moves
            if self.children[action] is None
        ]


    ### Action Selection

    def get_best_action(self) -> int:
        """ Legal action whose child has the highest total_score; missing children count as 0 """
        scores = [
            self.children[action].total_score
            if self.children[action] is not None else 0
            for action in self.legal_moves
        ]
        index  = np.argmax(scores)
        action = self.legal_moves[index]
        return action


    def get_exploration_action(self) -> int:
        """ Legal action whose child has the highest cached ucb1_score; missing children count as 0 """
        scores = [
            self.children[action].ucb1_score
            if self.children[action] is not None else 0
            for action in self.legal_moves
        ]
        index  = np.argmax(scores)
        action = self.legal_moves[index]
        return action



    ### Scores

    def get_ucb1_score(self, node: 'MontyCarloNode') -> float:
        """
        UCB1 score of `node`:  mean_score + exploration * sqrt( 2 * ln(parent_visits) / node_visits )

        Unvisited (or missing) nodes score +inf so that they are always tried first.
        The exploration term is omitted for nodes without a visited parent.
        """
        if node is None or node.total_visits == 0:
            return np.inf
        else:
            score = node.total_score / node.total_visits
            if node.parent is not None and node.parent.total_visits > 0:
                # The square root covers the whole ratio, as in the standard UCB1 formula
                score += node.exploration * np.sqrt(
                    2 * np.log(node.parent.total_visits) / node.total_visits
                )
            return score


    @staticmethod
    def opponents_score(score: float):
        """ Convert a 0-1 score into the same result seen from the other player's perspective """
        # assert 0 <= score <= 1
        return 1 - score



    ### Training and Backpropagation

    def single_run(self):
        """ One MCTS iteration from this node: select down the tree, expand, simulate and backpropagate """
        if self.is_gameover:
            self.backpropagate(self.utility)
        elif not self.is_expanded:
            child = self.expand()
            score = child.simulate()    # score from the perspective of the other player
            child.backpropagate(score)
        else:
            # Recurse down tree, until a gameover or not expanded node is found
            action = self.get_exploration_action()
            child  = self.children[action]
            child.single_run()


    def expand(self) -> 'MontyCarloNode':
        """ Create and return a child for one randomly chosen unexpanded legal move """
        # assert not self.is_gameover
        # assert not self.is_expanded

        unexpanded = self.get_unexpanded()
        # assert len(unexpanded)

        action     = np.random.choice(unexpanded)
        child      = self.create_child(action)
        return child

    def on_expanded(self) -> None:
        # Callback placeholder for any subclass hooks
        pass

    def simulate(self) -> float:
        """ Random playout from this position, scored 0-1 for self.player_id (the player to move) """
        return run_random_simulation(self.bitboard, self.player_id)


    def backpropagate(self, score: float):
        """
        Add `score` (given from the perspective of this node's player to move) to this node and all ancestors.

        The score is flipped before being added at each level: total_score is read via the parent,
        so each node stores the result from the perspective of the player who moved into it.
        """
        node = self
        while node is not None:
            score = self.opponents_score(score)
            node.total_score  += score
            node.total_visits += 1
            node = node.parent      # when we reach the root: node.parent == None which terminates

        # get_ucb1_score() gets called 4x less often if we cache the value after backpropagation
        # get_ucb1_score() depends on parent.total_visits, so needs to be called after updating scores
        node = self
        while node is not None:
            node.ucb1_score = node.get_ucb1_score(node)
            node = node.parent      # when we reach the root: node.parent == None which terminates



    ### Agent
    @classmethod
    def agent(cls, **kwargs) -> Callable[[Struct, Struct],int]:
        """
        Build a kaggle agent function(observation, configuration) -> action.

        kwargs: `safety_time` (seconds reserved before the move timeout, default 0.25) is read here;
                all kwargs are also forwarded to newly created root nodes.
        """
        def kaggle_agent( observation: Struct, _configuration_: Struct ):
            first_move_time = 0
            safety_time     = kwargs.get('safety_time', 0.25)
            start_time      = time.perf_counter()
            # configuration   = cast_configuration(_configuration_)

            player_id     = int(observation.mark)
            listboard     = np.array(observation.board, dtype=np.int8)
            bitboard      = list_to_bitboard(listboard)
            move_number   = get_move_number(bitboard)
            is_first_move = int(move_number < 2)
            endtime       = start_time + _configuration_.timeout - safety_time - (first_move_time * is_first_move)

            if cls.persist and cls.save_node.get(cls.__name__, None) is None:
                atexit.register(cls.save)
                cls.save_node[cls.__name__] = cls.load() or cls(empty_bitboard(), player_id=1)
                cls.root_nodes[1] = cls.root_nodes[2] = cls.save_node[cls.__name__]  # implement shared state

            # Reuse the existing subtree when the new position is within two plies of the previous root,
            # otherwise start a fresh tree at the observed position
            previous  = cls.root_nodes[player_id]
            root_node = previous.find_child(bitboard, depth=2) if previous is not None else None
            if root_node is None:
                root_node = cls(
                    bitboard      = bitboard,
                    player_id     = player_id,
                    parent        = None,
                    # configuration = configuration,
                    **kwargs
                )
            cls.root_nodes[player_id] = root_node
            # assert root_node is not None

            count = 0
            while time.perf_counter() < endtime:
                count += 1
                root_node.single_run()

            action     = root_node.get_best_action()
            time_taken = time.perf_counter() - start_time
            print(f'{cls.__name__}: p{player_id} action = {action} after {count} simulations in {time_taken:.3f}s')
            return int(action)

        kaggle_agent.__name__ = cls.__name__
        return kaggle_agent

def MontyCarloPure(**kwargs):
    """
    Factory returning a kaggle agent function backed by MontyCarloNode; kwargs are forwarded to MontyCarloNode.agent()

    Example arguments received by the returned agent:
    observation   = {'mark': 1, 'board': [0, 0, 0, ... 42 squares ... 0, 0, 0]}
    configuration = {'columns': 7, 'rows': 6, 'inarow': 4, 'steps': 1000, 'timeout': 8}
    """
    def MontyCarloPure(observation: Struct, configuration: Struct) -> int:
        agent = MontyCarloNode.agent(**kwargs)
        return agent(observation, configuration)
    return MontyCarloPure

def MontyCarloPureKaggle(observation, configuration):
    """ Kaggle entry point: run the MontyCarloPure agent with its default settings """
    agent = MontyCarloPure()
    return agent(observation, configuration)


#####
##### END   agents/MontyCarlo/MontyCarloPure.py
#####

##### 
##### ../../kaggle_compile.py agents/MontyCarlo/MontyCarloPure.py
##### 
##### 2020-08-26 16:45:21+01:00
##### 
##### origin	git@github.com:JamesMcGuigan/ai-games.git (fetch)
##### origin	git@github.com:JamesMcGuigan/ai-games.git (push)
##### 
##### * master 247327a [ahead 6] ConnectX | reduce safety_time to 0.25s
##### 
##### 247327afa97dfaa0c87ea36321e7be3deaa9d8d4
##### 


In [None]:
%run submission.py

# Test your Agent

In [None]:
from kaggle_environments import evaluate, make, utils

%load_ext autoreload
%autoreload 2

## Versus Self

In [None]:
### Play against yourself without an ERROR or INVALID.
### Note: The first episode in the competition will run this to weed out erroneous agents.

env = make("connectx", debug=True)
env.run(["/kaggle/working/submission.py", "/kaggle/working/submission.py"])
print("\nEXCELLENT SUBMISSION!" if env.toJSON()["statuses"] == ["DONE", "DONE"] else "MAYBE BAD SUBMISSION?")
env.render(mode="ipython", width=500, height=450)

## Versus Negamax

In [None]:
env = make("connectx", debug=True)
env.run(["/kaggle/working/submission.py", "negamax"])
print("\nEXCELLENT SUBMISSION!" if env.toJSON()["statuses"] == ["DONE", "DONE"] else "MAYBE BAD SUBMISSION?")
env.render(mode="ipython", width=500, height=450)

# Versus Human

In [None]:
env = make("connectx", debug=True)
env.play([None, "/kaggle/working/submission.py"], width=500, height=450)