- guideline
    - nodes correspond to states $s$
    - edges refer to actions $a$
        - each edge transfers the environment from its parent state to its child state
            - state transition
    - UCT => pUCT: score = Q + U
        - early in the search, U dominates (more exploration)
        - later, as visit counts grow, Q matters more (less exploration, more exploitation)
    - training & inference
        - inside the tree search: select the child that maximizes pUCT = Q + U
        - choosing the move to play: use the root children's visit counts



In [1]:
import collections
import numpy as np
import math
from IPython.display import Image
import tensorflow as tf
from tensorflow.keras import layers, models
from copy import deepcopy
import time
from collections import deque
import random


In [2]:
# --- search / self-play hyper-parameters ---
c_puct = 2  # weight of the exploration term U in pUCT = Q + U
iterations = 600  # MCTS simulations per move (passed to UCT_search as num_reads)
temp = 1e-3  # temperature applied to root visit counts when turning them into move probabilities
dirichlet = 0.3  # concentration of the Dirichlet noise mixed into self-play move probabilities

# --- board geometry ---
BOARD_SIZE = 5
# Every (row, col) square in row-major order; action id k refers to ALL_ACTIONS[k].
ALL_ACTIONS = [(row, col) for row in range(BOARD_SIZE) for col in range(BOARD_SIZE)]
# Board cells store 1 / -1 for the players and EMPTY for free squares; X and O are not used by the code here.
X, O, EMPTY = 'X', 'O', 0
input_shape = (BOARD_SIZE, BOARD_SIZE, 1)  # network input: (rows, cols, channels)

# --- replay memory / training ---
MEMORY_SIZE = 10000  # capacity (maxlen) of the replay deque
MINIBATCH_SIZE = 256  # experiences sampled per training step


## Node & search 

### Node

In [3]:
class UCTNode():
    """Search-tree node for PUCT Monte Carlo tree search.

    Per-child statistics live in arrays on this node, indexed by action id
    (0..24, an index into ALL_ACTIONS): priors P, total values t_i and visit
    counts n_i. Consequently a node's *own* visit count and total value are
    read from / written to its parent's arrays (see the properties below);
    the search root uses a DummyNode parent for this purpose.

    backup() flips the sign of the value at every level (except the 0.5 tie
    value), so, assuming the leaf value passed in is from the perspective of
    the player who moved into the leaf (as UCT_search does), each stored value
    is from the perspective of the player who made the move into that child.
    """

    def __init__(self, state, action, parent=None):
        self.state = state  # Board instance for this position
        self.action = action  # action id that led here (None for the root)

        self.is_expanded = False
        # The position is terminal once no empty squares remain.
        self.is_terminal = state.left == 0

        self.parent = parent  # UCTNode, or DummyNode for the root

        self.children = {}  # Dict[action id, UCTNode], created lazily by maybe_add_child
        # P: prior probability of each child move
        self.child_priors = np.zeros([25], dtype=np.float32)
        # t_i: accumulated value of each child
        self.child_total_value = np.zeros([25], dtype=np.float32)
        # n_i: visit count of each child
        self.child_number_visits = np.zeros([25], dtype=np.float32)

    # N: this node's visit count, stored in the parent's child array
    @property
    def number_visits(self):
        return self.parent.child_number_visits[self.action]

    @number_visits.setter
    def number_visits(self, value):
        self.parent.child_number_visits[self.action] = value

    # t: this node's accumulated value, stored in the parent's child array
    @property
    def total_value(self):
        return self.parent.child_total_value[self.action]

    @total_value.setter
    def total_value(self, value):
        self.parent.child_total_value[self.action] = value

    def child_Q(self) -> np.ndarray:
        """Return Q_i = t_i / (1 + n_i) for every child (exploitation term)."""
        return self.child_total_value / (1 + self.child_number_visits)

    def child_U(self) -> np.ndarray:
        """Return U_i = c_puct * sqrt(N) * P_i / (1 + n_i), N = this node's visits (exploration term)."""
        return c_puct * math.sqrt(self.number_visits) * (
            self.child_priors / (1 + self.child_number_visits))

    def best_child(self) -> int:
        """Return the action id with the highest pUCT score Q + U."""
        return np.argmax(self.child_Q() + self.child_U())

    def select_leaf(self):
        """Follow best_child from this node until reaching an unexpanded or terminal node."""
        current = self
        while current.is_expanded and not current.is_terminal:
            best_action = current.best_child()
            current = current.maybe_add_child(best_action)
        return current

    def expand(self, child_priors):
        """Mark the node expanded and, if non-terminal, install priors with illegal moves masked.

        Illegal moves get prior 0 and total value -100, so their Q + U stays far
        below any legal move and best_child effectively never selects them.
        Note that `child_priors` is stored by reference and modified in place.
        For terminal nodes `child_priors` is ignored.
        """
        self.is_expanded = True
        if not self.is_terminal:
            self.child_priors = child_priors
            # Build the legal-move set once instead of regenerating the move
            # list for each of the 25 action ids.
            legal_actions = set(self.state.generate_actions())
            for action in range(25):
                if ALL_ACTIONS[action] not in legal_actions:
                    self.child_priors[action] = 0
                    self.child_total_value[action] = -100

    def maybe_add_child(self, action):
        """Return the child reached by `action`, creating it (by playing the move) on first use."""
        if action not in self.children:
            self.children[action] = UCTNode(
                self.state.move(ALL_ACTIONS[action]), action, parent=self)
        return self.children[action]

    def backup(self, value_estimate: float):
        """Add one visit and `value_estimate` to this node and every ancestor up to the root.

        The loop stops at the DummyNode above the root (whose parent is None),
        so the root's own count is updated too; child_U of the root needs it.
        """
        current = self
        while current.parent is not None:
            current.number_visits += 1
            current.total_value += value_estimate
            # Switch to the other player's perspective for the next level up;
            # 0.5 encodes a tie and is kept unchanged for both players.
            if value_estimate != 0.5:
                value_estimate = -value_estimate
            current = current.parent

### Q + U

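- Ranking = Quality + Uncertainty (Q + U)
    - Quality: exploitation
    - Uncertainty: exploration
        - FOMO (fear of missing out)
        - P from policy network

$$
\begin{split}
&Q_i=\frac{t_i}{1+n_i}\\
&U_i=c_{puct}\sqrt{N}\times \frac{P_i}{1+n_i}
\end{split}
$$

Here $t_i$ and $n_i$ are child $i$'s total value and visit count, $P_i$ is its prior from the policy network, and $N$ is the visit count of the parent node.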

## Game state

In [4]:
class Board():  # must contain (scores,player,board,valid actions,move)
    """Five-Tigers position: grid, side to move, scores and number of empty squares.

    Cells hold the value of the player who occupies them (1 for the first
    player, -1 for the second) or EMPTY. ``scores[0]`` belongs to player 1,
    ``scores[1]`` to player -1.
    """

    # Scoring pattern ids that contain each square, dispatched in update_score:
    #   1-5    full row (row = id - 1)                         +5
    #   6-10   full column (col = id - 6)                      +5
    #   11-12  main / anti diagonal                            +5
    #   13-16  length-4 diagonals                              +4
    #   17-20  length-3 diagonals                              +3
    #   21-29  small X: an interior square + its 4 diagonal neighbours  +5
    #   30     big X: the four corners + the centre            +10
    #   31-46  2x2 square ("well")                             +1
    # Defined once at class level rather than rebuilt on every update_score call.
    _RELATED = {
        (0, 0): [1, 6, 11, 21, 30, 31, ],
        (0, 1): [1, 7, 13, 22, 31, 32, ],
        (0, 2): [1, 8, 17, 18, 21, 23, 32, 33, ],
        (0, 3): [1, 9, 14, 22, 33, 34, ],
        (0, 4): [1, 10, 12, 23, 30, 34, ],
        (1, 0): [2, 6, 15, 24, 31, 35, ],
        (1, 1): [2, 7, 11, 17, 21, 25, 31, 32, 35, 36, ],
        (1, 2): [2, 8, 13, 14, 22, 24, 26, 32, 33, 36, 37, ],
        (1, 3): [2, 9, 12, 18, 23, 25, 33, 34, 37, 38, ],
        (1, 4): [2, 10, 16, 26, 34, 38, ],
        (2, 0): [3, 6, 17, 20, 21, 27, 35, 39, ],
        (2, 1): [3, 7, 14, 15, 22, 24, 28, 35, 36, 39, 40, ],
        (2, 2): [3, 8, 11, 12, 21, 23, 25, 27, 29, 30, 36, 37, 40, 41, ],
        (2, 3): [3, 9, 13, 16, 22, 26, 28, 37, 38, 41, 42, ],
        (2, 4): [3, 10, 18, 19, 23, 29, 38, 42, ],
        (3, 0): [4, 6, 14, 24, 39, 43, ],
        (3, 1): [4, 7, 12, 20, 25, 27, 39, 40, 43, 44, ],
        (3, 2): [4, 8, 15, 16, 24, 26, 28, 40, 41, 44, 45, ],
        (3, 3): [4, 9, 11, 19, 25, 29, 41, 42, 45, 46, ],
        (3, 4): [4, 10, 13, 26, 42, 46, ],
        (4, 0): [5, 6, 12, 27, 30, 43, ],
        (4, 1): [5, 7, 16, 28, 43, 44, ],
        (4, 2): [5, 8, 19, 20, 27, 29, 44, 45, ],
        (4, 3): [5, 9, 15, 28, 45, 46, ],
        (4, 4): [5, 10, 11, 29, 30, 46, ]
    }

    def __init__(self, board=None):
        """Create a position from `board` (deep-copied) or an empty grid when None.

        `None` replaces the former list-literal default, which was built once at
        definition time and shared by every call.
        """
        if board is None:
            board = [[EMPTY for _ in range(BOARD_SIZE)] for _ in range(BOARD_SIZE)]
        self.player = 1  # side to move: 1 moves first
        self.board = deepcopy(board)
        self.scores = [0, 0]
        self.left = 25  # number of empty squares

    def copy(self):
        """Return an independent copy (grid deep-copied, scores list copied)."""
        new = Board(self.board)
        new.player = self.player
        new.scores = self.scores.copy()
        new.left = self.left
        return new

    def update_score(self, action):
        """Score every pattern through square `action` that the stone just placed there completes.

        Only patterns containing the new stone are checked, so each pattern is
        scored exactly once, when its last square is filled. Points go to the
        owner of the stone on `action`.
        """
        def add_score(base_sign, score):
            if base_sign == 1:
                self.scores[0] += score
            else:
                self.scores[1] += score

        def check_row(base_sign, i):
            if self.board[i][0] == base_sign and self.board[i][1] == base_sign and self.board[i][2] == base_sign and self.board[i][3] == base_sign and self.board[i][4] == base_sign:
                add_score(base_sign, 5)

        def check_column(base_sign, i):
            if self.board[0][i] == base_sign and self.board[1][i] == base_sign and self.board[2][i] == base_sign and self.board[3][i] == base_sign and self.board[4][i] == base_sign:
                add_score(base_sign, 5)

        def check_5x(base_sign, i):
            if i == 1 and self.board[2][2] == base_sign and self.board[0][0] == base_sign and self.board[1][1] == base_sign and self.board[3][3] == base_sign and self.board[4][4] == base_sign:
                add_score(base_sign, 5)
            if i == 2 and self.board[2][2] == base_sign and self.board[0][4] == base_sign and self.board[1][3] == base_sign and self.board[3][1] == base_sign and self.board[4][0] == base_sign:
                add_score(base_sign, 5)

        def check_4x(base_sign, i):
            if i == 2 and self.board[0][3] == base_sign and self.board[1][2] == base_sign and self.board[2][1] == base_sign and self.board[3][0] == base_sign:
                add_score(base_sign, 4)
            if i == 1 and self.board[0][1] == base_sign and self.board[1][2] == base_sign and self.board[2][3] == base_sign and self.board[3][4] == base_sign:
                add_score(base_sign, 4)
            if i == 3 and self.board[4][3] == base_sign and self.board[3][2] == base_sign and self.board[2][1] == base_sign and self.board[1][0] == base_sign:
                add_score(base_sign, 4)
            if i == 4 and self.board[4][1] == base_sign and self.board[3][2] == base_sign and self.board[2][3] == base_sign and self.board[1][4] == base_sign:
                add_score(base_sign, 4)

        def check_3x(base_sign, i):
            if i == 1 and self.board[0][2] == base_sign and self.board[1][1] == base_sign and self.board[2][0] == base_sign:
                add_score(base_sign, 3)
            if i == 2 and self.board[0][2] == base_sign and self.board[1][3] == base_sign and self.board[2][4] == base_sign:
                add_score(base_sign, 3)
            if i == 4 and self.board[4][2] == base_sign and self.board[3][1] == base_sign and self.board[2][0] == base_sign:
                add_score(base_sign, 3)
            if i == 3 and self.board[4][2] == base_sign and self.board[3][3] == base_sign and self.board[2][4] == base_sign:
                add_score(base_sign, 3)

        def check_big5(base_sign):
            if self.board[2][2] == base_sign and self.board[0][0] == base_sign and self.board[4][0] == base_sign and self.board[0][4] == base_sign and self.board[4][4] == base_sign:
                add_score(base_sign, 10)

        def check_small5(base_sign, index):
            # Centre of the X is one of the 9 interior squares.
            i, j = index // 3 + 1, index % 3 + 1
            if self.board[i][j] == base_sign and self.board[i-1][j-1] == base_sign and self.board[i-1][j+1] == base_sign and self.board[i+1][j-1] == base_sign and self.board[i+1][j+1] == base_sign:
                add_score(base_sign, 5)

        def check_well(base_sign, index):
            # Top-left corner of one of the 16 2x2 squares.
            i, j = index // 4, index % 4
            if self.board[i][j] == base_sign and self.board[i][j+1] == base_sign and self.board[i+1][j] == base_sign and self.board[i+1][j+1] == base_sign:
                add_score(base_sign, 1)

        i, j = action
        base_sign = self.board[i][j]
        for index in self._RELATED[action]:
            if 1 <= index <= 5:
                check_row(base_sign, index-1)
            elif 6 <= index <= 10:
                check_column(base_sign, index-6)
            elif 11 <= index <= 12:
                check_5x(base_sign, index-10)
            elif 13 <= index <= 16:
                check_4x(base_sign, index-12)
            elif 17 <= index <= 20:
                check_3x(base_sign, index-16)
            elif 21 <= index <= 29:
                check_small5(base_sign, index-21)
            elif index == 30:
                check_big5(base_sign)
            elif 31 <= index <= 46:
                check_well(base_sign, index-31)

    def check_winner(self):
        """Return 1 if player 1 leads on score, -1 if player -1 leads, 0 on a tie."""
        if self.scores[0] > self.scores[1]:
            return 1
        elif self.scores[0] < self.scores[1]:
            return -1
        else:
            return 0

    def move(self, action):
        """Return a new Board with the side to move placed on `action` = (row, col).

        The current board is left untouched; the new one has updated scores,
        one fewer empty square and the other player to move.
        """
        row, col = action
        next_state = Board(self.board)

        next_state.board[row][col] = self.player
        next_state.left = self.left - 1

        next_state.scores = self.scores.copy()
        next_state.update_score(action)

        next_state.player = -self.player

        return next_state

    def generate_actions(self):
        """Return the empty squares as (row, col) tuples in row-major order."""
        return [(row, col)
                for row in range(BOARD_SIZE)
                for col in range(BOARD_SIZE)
                if self.board[row][col] == EMPTY]



In [15]:
class Fiver_tigers():
    """Game wrapper around a Board that mirrors its fields and records the winner."""

    def __init__(self, board=None):
        # Build a fresh Board per game instead of using a `Board()` default
        # argument, which would be created once and shared by every instance.
        self.state = Board() if board is None else board
        self.board = self.state.board
        self.winner = None  # set by move() once the board is full
        self.player = self.state.player
        self.left = self.state.left
        self.scores = self.state.scores

    def available_actions(self):
        """Return the legal (row, col) moves of the current position."""
        return self.state.generate_actions()

    def render(self):
        """Print the board to stdout, showing '-' for empty squares."""
        print()
        print("board:")
        print("   0 1 2 3 4")
        for i in range(BOARD_SIZE):
            print(i, end="  ")
            for j in self.board[i]:
                print(j if j != 0 else '-', end=" ")
            print()
        print()

    def move(self, action):
        """Play `action` = (row, col), refresh the mirrored fields and set the winner at the end."""
        self.state = self.state.move(action)
        self.board = self.state.board
        self.player = -self.player
        self.left = self.state.left
        self.scores = self.state.scores
        if self.left == 0:
            self.winner = self.state.check_winner()

## Policy network & Value network

- 结合使用策略网络（Policy network）来指导搜索方向，并使用价值网络（Value network）来评估棋局的潜在价值，可以显著减少搜索树的大小，提高搜索的效率。
    - 策略网络（Policy network）能够从先前的对局中学习到有效的走棋模式和策略，这相当于在搜索过程中加入了大量的“先验知识”（child_priors）。
    - 价值网络（Value network）可以给出对当前棋局胜负的直接评估，而不需要到达游戏的终局。这种评估能力对于减少搜索深度、加速决策过程至关重要。

In [28]:
model = models.Sequential()

# Shared convolutional trunk used by both heads
model.add(layers.Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=input_shape))
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(layers.Flatten())
# Policy head
policy_head = layers.Dense(128, activation='relu')(model.output)
policy_head = layers.Dense(25, activation='linear')(policy_head)  # one logit per action id (softmax applied later)

# Value head
value_head = layers.Dense(128, activation='relu')(model.output)
value_head = layers.Dense(1, activation='tanh')(value_head)  # state value; tanh bounds it to (-1, 1)

# Functional model exposing both heads: outputs = [policy logits, value]
policy_value_network = tf.keras.Model(inputs=model.input, outputs=[policy_head, value_head])
#policy_value_network.summary()
#np.random.random([25]), np.random.random()
# Used by train_network as a module-level global.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

In [None]:
# save (uncomment to write the current network to disk)
#policy_value_network.save('policy_value_model', save_format='tf')

# load: replaces the freshly built network with a saved one.
# NOTE(review): the self-play loop saves to 'policy_value_model' while this cell and
# the offline training loop use 'policy_value_model_cnn' — confirm which checkpoint is intended.
policy_value_network = tf.keras.models.load_model('policy_value_model_cnn')
# Fresh optimizer for the loaded network (train_network reads this global).
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

In [6]:
def train_network(network, experiences):
    """Take one optimizer step on a minibatch and return the combined loss.

    The loss is the mean softmax cross-entropy between the policy logits and
    the MCTS move probabilities plus the mean squared error between the value
    head and the value targets. Gradients are applied with the module-level
    `optimizer`.

    Args:
        network: model returning (policy_logits, values) for a batch of states.
        experiences: (states, mcts_probs, mcts_values) as built by get_experiences.
    """
    batch_states, target_probs, target_values = experiences

    with tf.GradientTape() as tape:
        policy_logits, predicted_values = network(batch_states)
        cross_entropy = tf.keras.losses.categorical_crossentropy(
            target_probs, policy_logits, from_logits=True)
        squared_error = tf.keras.losses.mean_squared_error(target_values, predicted_values)
        loss = tf.reduce_mean(cross_entropy) + tf.reduce_mean(squared_error)

    weights = network.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, weights), weights))

    return loss


In [7]:
def state2input(state):
    """Convert a Board into a single-sample network input of shape (1, rows, cols, 1).

    The array is float32 so inference sees the same dtype as the training
    batches built in get_experiences (which convert boards to float32); the
    previous version passed the raw integer array.
    """
    board = np.asarray(state.board, dtype=np.float32)  # (5, 5)
    return board[np.newaxis, :, :, np.newaxis]  # (1, 5, 5, 1): batch and channel axes added


In [8]:
def get_experiences(memory_buffer):
    """Sample MINIBATCH_SIZE experiences and stack them into float32 tensors.

    Each experience is a (Board, mcts_probs, value) tuple. Returns
    (states, mcts_probs, mcts_values) with shapes (n, 5, 5, 1), (n, 25) and
    (n, 1). random.sample raises ValueError when the buffer holds fewer than
    MINIBATCH_SIZE items.
    """
    batch = [e for e in random.sample(memory_buffer, k=MINIBATCH_SIZE) if e is not None]

    board_arrays = [np.array(e[0].board)[..., np.newaxis] for e in batch]  # each (5, 5, 1)
    prob_arrays = [e[1] for e in batch]  # each (25,)
    value_rows = [[e[2]] for e in batch]  # each [value]

    return (
        tf.convert_to_tensor(np.array(board_arrays), dtype=tf.float32),
        tf.convert_to_tensor(np.array(prob_arrays), dtype=tf.float32),
        tf.convert_to_tensor(np.array(value_rows), dtype=tf.float32),
    )

## UCT_search

In [9]:
def softmax(x):
    """Return the softmax of `x`, normalised over all of its elements.

    The maximum is subtracted before exponentiating so that large inputs
    (such as log-visit counts divided by a small temperature) do not overflow.
    """
    shifted = np.asarray(x) - np.max(x)
    weights = np.exp(shifted)
    return weights / weights.sum()

In [10]:
class DummyNode:  # for root node
    """Placeholder parent for the search root.

    The root stores its own visit count and total value in
    ``parent.child_number_visits[action]`` / ``parent.child_total_value[action]``
    (with action None); the defaultdicts make those entries start at 0.0.
    ``parent`` is None, which is where UCTNode.backup stops climbing.
    """

    def __init__(self):
        self.child_number_visits = collections.defaultdict(float)
        self.child_total_value = collections.defaultdict(float)
        self.parent = None

In [11]:
def UCT_search(state, num_reads):
    """Run `num_reads` MCTS simulations from `state` and return move probabilities.

    Each simulation selects a leaf with pUCT, evaluates it (the real game
    result for a full board, otherwise the policy/value network), expands it
    and backs the value up to the root.

    Returns:
        acts: tuple of action ids (indices into ALL_ACTIONS) of the visited root children.
        act_probs: np.ndarray of probabilities over `acts`, softmax(log(visits) / temp);
            with a small `temp` this concentrates on the most-visited move.

    Raises:
        ValueError: if no root child was created (e.g. num_reads < 2 or a
            terminal `state`).
    """
    root = UCTNode(state, action=None, parent=DummyNode())
    for _ in range(num_reads):
        leaf = root.select_leaf()
        if leaf.is_terminal:
            # Score terminal positions from the actual result and skip the network:
            # expand() ignores priors for terminal nodes, so the forward pass was wasted.
            # The value is from the side that made the last (25th) move, i.e. player 1;
            # 0.5 encodes a tie.
            winner = leaf.state.check_winner()
            if winner == 0:
                value_estimate = 0.5
            elif winner == 1:
                value_estimate = 1
            else:
                value_estimate = -1
            child_priors = None
        else:
            policy_logits, value = policy_value_network(state2input(leaf.state))
            child_priors = softmax(tf.squeeze(policy_logits).numpy())  # probabilities in [0, 1]
            value_estimate = tf.squeeze(value).numpy().item()

        leaf.expand(child_priors)
        leaf.backup(value_estimate)

    if not root.children:
        raise ValueError(
            "UCT_search created no root children; use num_reads >= 2 on a non-terminal state")
    acts, visits = zip(*[(action, root.child_number_visits[action]) for action in root.children])
    act_probs = softmax(1.0 / temp * np.log(np.array(visits) + 1e-10))
    return acts, act_probs


In [None]:
start = time.time()

num_rounds = 500

total_loss_history = []

# Replay memory D with capacity N (the deque drops the oldest experiences when full)
memory_buffer = deque(maxlen=MEMORY_SIZE)

for i in range(num_rounds):

    # Start a new self-play game from the empty board
    game = Fiver_tigers()
    states, mcts_probs = [], []

    for t in range(25):
        state = game.state  # Board instance
        print(f'\rEpisode {i+1}: now making move {t+1}, now score: {game.state.scores}', end='')
        acts, act_probs = UCT_search(state, iterations)
        # Policy target: the search probabilities scattered onto all 25 action ids
        move_probs = np.zeros(25)
        move_probs[list(acts)] = act_probs
        # Mix Dirichlet noise into the probabilities used to sample the move (exploration)
        p = 0.75 * act_probs + 0.25 * np.random.dirichlet(dirichlet * np.ones(len(act_probs)))
        move = np.random.choice(
            acts,
            p=p / np.sum(p)
        )

        states.append(state)  # Board instance
        mcts_probs.append(move_probs)  # np.ndarray, shape (25,)

        game.move(ALL_ACTIONS[move])

    # Value targets from the side that moved into each stored state: states[t] with
    # even t have player 1 to move, so their previous mover is player -1.
    # A tie is encoded as 0.5, matching UCTNode.backup.
    winners = np.zeros(25)
    if game.winner == -1:
        winners[0::2] = 1
        winners[1::2] = -1
    elif game.winner == 1:
        winners[0::2] = -1
        winners[1::2] = 1
    else:
        winners += 0.5

    memory_buffer.extend(zip(states, mcts_probs, winners))

    loss = None
    # random.sample needs at least MINIBATCH_SIZE items, so train as soon as that many exist.
    if len(memory_buffer) >= MINIBATCH_SIZE:
        experiences = get_experiences(memory_buffer)
        loss = train_network(policy_value_network, experiences)
        total_loss_history.append(loss)

    if loss is not None:
        # Report this episode's loss; the all-time mean is shown separately.
        print(f"\rEpisode {i+1} | current loss: {float(loss):.2f} | mean loss: {np.mean(total_loss_history):.2f}")
    else:
        print(f"\rEpisode {i+1} | haven't update yet")

    if (i+1) % 50 == 0:
        # NOTE(review): the load cell and the offline loop use 'policy_value_model_cnn';
        # confirm this path is intended.
        policy_value_network.save('policy_value_model', save_format='tf')


tot_time = time.time() - start

print(f"\nTotal Runtime: {tot_time:.2f} s ({(tot_time/60):.2f} min)")

In [None]:
import pygame
import sys,os
import time

def play(human_experiences):
    """Open a pygame window in which a human plays Five-Tigers against the MCTS agent.

    The loop runs until the window is closed, which shuts pygame down and
    calls sys.exit(). For every finished game, the AI's search records are
    appended once to `human_experiences` as (Board, move_probs, value)
    tuples. move_probs has shape (25,). value is the result from the human's
    side: 1 for a win, -1 for a loss, 0.5 for a tie.

    Args:
        human_experiences: buffer (e.g. a collections.deque) that receives the tuples via `extend`.
    """
    # Font path
    FONT_PATH = './resources/OpenSans-Regular.ttf'

    pygame.init()
    size = width, height = 900, 600

    # Colors
    orange = (174, 130, 44)
    black = (0, 0, 0)
    white = (255, 255, 255)

    screen = pygame.display.set_mode(size)

    mediumFont = pygame.font.Font(FONT_PATH, 32)
    largeFont = pygame.font.Font(FONT_PATH, 50)
    scoreFont = pygame.font.Font(FONT_PATH, 40)

    user = None  # 1: human plays first (black); -1: human plays second (white)
    ai_turn = False  # delays the AI search by one frame so "Computer thinking..." is drawn first
    ai_move = 0  # number of AI moves recorded in `states` / `mcts_probs`
    game = Fiver_tigers()
    board = game.board
    states, mcts_probs = [], []
    # The game-over branch runs on every frame; this flag makes sure a finished
    # game's experiences are stored only once instead of once per frame.
    experiences_recorded = False

    while True:

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        screen.fill(orange)

        # Let user choose a player.
        if user is None:

            # Draw title
            title = largeFont.render("Play Five-Tigers", True, black)
            titleRect = title.get_rect()
            titleRect.center = ((width / 2), 60)
            screen.blit(title, titleRect)

            # Draw buttons
            playXButton = pygame.Rect((width / 8), (height / 2), width / 4, 60)
            playX = mediumFont.render("Play first", True, orange)
            playXRect = playX.get_rect()
            playXRect.center = playXButton.center
            pygame.draw.rect(screen, black, playXButton)
            screen.blit(playX, playXRect)

            playOButton = pygame.Rect(5 * (width / 8), (height / 2), width / 4, 60)
            playO = mediumFont.render("Play second", True, orange)
            playORect = playO.get_rect()
            playORect.center = playOButton.center
            pygame.draw.rect(screen, black, playOButton)
            screen.blit(playO, playORect)

            # Check if button is clicked
            click, _, _ = pygame.mouse.get_pressed()
            if click == 1:
                mouse = pygame.mouse.get_pos()
                if playXButton.collidepoint(mouse):
                    time.sleep(0.2)
                    user = 1
                elif playOButton.collidepoint(mouse):
                    time.sleep(0.2)
                    user = -1

        else:

            # Draw game board: stones sit on the vertices of a 4x4 grid of outlined squares
            tile_size = 80
            tile_origin = (width / 2 - (2.5 * tile_size),
                           height / 2 - (2.5 * tile_size))
            line_origin = (width / 2 - (2 * tile_size),
                           height / 2 - (2 * tile_size))
            tiles = []
            for i in range(5):
                row = []
                for j in range(5):
                    rect = pygame.Rect(
                        tile_origin[0] + j * tile_size,
                        tile_origin[1] + i * tile_size,
                        tile_size, tile_size
                    )
                    if i != 4 and j != 4:
                        lines = pygame.Rect(
                            line_origin[0] + j * tile_size,
                            line_origin[1] + i * tile_size,
                            tile_size, tile_size
                        )
                        pygame.draw.rect(screen, black, lines, 3)

                    if board[i][j] != 0:
                        if board[i][j] == 1:
                            pygame.draw.circle(screen, black, rect.center, 30, 0)
                        elif board[i][j] == -1:
                            pygame.draw.circle(screen, white, rect.center, 30, 0)

                    row.append(rect)
                tiles.append(row)

            game_over = game.left == 0
            player = game.player

            # Show title
            if game_over:
                winner = game.winner
                if winner == 0:
                    title = f"Game Over: Tie."
                else:
                    if game.winner == user:
                        title = f"Game Over: Human wins."
                    else:
                        title = f"Game Over: AI wins."
            elif user == player:
                if user == 1:
                    title = f"Play as black"
                elif user == -1:
                    title = f"Play as white"
            else:
                title = f"Computer thinking..."
            title = largeFont.render(title, True, black)
            titleRect = title.get_rect()
            titleRect.center = ((width / 2), 40)
            screen.blit(title, titleRect)

            # Check for AI move
            if user != player and not game_over:
                if ai_turn:
                    state = game.state
                    acts, act_probs = UCT_search(state, iterations)
                    move_probs = np.zeros(25)
                    move_probs[list(acts)] = act_probs
                    p = act_probs
                    move = np.random.choice(
                        acts,
                        p=p / np.sum(p)
                    )

                    states.append(state)  # Board instance
                    mcts_probs.append(move_probs)
                    ai_move += 1
                    # Apply the AI's move (UCT_search builds a fresh tree on every call).
                    game.move(ALL_ACTIONS[move])
                    board = game.board
                    ai_turn = False
                else:
                    ai_turn = True

            # Check for a user move
            click, _, _ = pygame.mouse.get_pressed()
            if click == 1 and user == player and not game_over:
                mouse = pygame.mouse.get_pos()
                for i in range(5):
                    for j in range(5):
                        if (board[i][j] == 0 and tiles[i][j].collidepoint(mouse)):
                            game.move((i, j))
                            board = game.board

            if game_over:
                againButton = pygame.Rect(width / 3, height - 70, width / 3, 60)
                again = mediumFont.render("Play Again", True, orange)
                againRect = again.get_rect()
                againRect.center = againButton.center
                pygame.draw.rect(screen, black, againButton)
                screen.blit(again, againRect)

                scores = scoreFont.render("scores:", True, black)
                scoresRect = scores.get_rect()
                scoresRect.center = (100, 230)
                scoresRect.x = 25
                screen.blit(scores, scoresRect)

                # scores[0] belongs to player 1 and scores[1] to player -1
                aiscores = scoreFont.render(f"AI: {game.scores[1-int((1-user)/2)]}", True, black)
                aiscoresRect = aiscores.get_rect()
                aiscoresRect.center = (100, 300)
                aiscoresRect.x = 25
                screen.blit(aiscores, aiscoresRect)

                huscores = scoreFont.render(f'Human: {game.scores[int((1-user)/2)]}', True, black)
                huscoresRect = huscores.get_rect()
                huscoresRect.center = (100, 370)
                huscoresRect.x = 25
                screen.blit(huscores, huscoresRect)

                if not experiences_recorded:
                    # Stored states are positions where the AI was to move, so the value
                    # is taken from the other side (the human), like the self-play targets.
                    # Ties use 0.5, matching self-play and UCTNode.backup.
                    if game.winner == user:
                        ai_win = np.ones(ai_move)
                    elif game.winner == -user:
                        ai_win = -np.ones(ai_move)
                    else:
                        ai_win = np.full(ai_move, 0.5)
                    human_experiences.extend(zip(states, mcts_probs, ai_win))
                    experiences_recorded = True

                click, _, _ = pygame.mouse.get_pressed()
                if click == 1:
                    mouse = pygame.mouse.get_pos()
                    if againButton.collidepoint(mouse):
                        time.sleep(0.2)
                        user = None
                        game = Fiver_tigers()
                        states, mcts_probs = [], []
                        ai_move = 0
                        board = game.board
                        ai_turn = False
                        experiences_recorded = False

        pygame.display.flip()


In [None]:
# Buffer for (state, mcts_probs, value) tuples gathered while playing against a human.
human_experiences = deque(maxlen=MEMORY_SIZE)
play(human_experiences=human_experiences)

In [15]:
def expend_experiences(memory_buffer):
    """Augment (Board, action_probs, value) experiences with the 8 symmetries of the board.

    For every input tuple the result contains, in this order:
    1. the original tuple;
    2. its left-right mirror;
    3. for k = 1, 2, 3: the board rotated k quarter-turns, then the mirrored
       board rotated k quarter-turns.

    The grid and the 5x5-reshaped action probabilities are transformed
    together. The value target is reused unchanged, because the scoring
    patterns in Board.update_score form a set that is closed under these
    symmetries. Player, scores and empty-square count are copied from the
    original Board via Board.copy().

    Returns:
        collections.deque with maxlen MEMORY_SIZE * 8.
    """
    expended_memory_buffer = deque(maxlen=MEMORY_SIZE * 8)

    def with_grid(template, grid):
        # Copy of `template` (player, scores, left) carrying the transformed grid.
        new_board = template.copy()
        new_board.board = grid.tolist()
        return new_board

    for experience in memory_buffer:
        board_state, value = experience[0], experience[2]
        expended_memory_buffer.append(experience)

        grid = np.array(board_state.board)
        prob_grid = experience[1].reshape((5, 5))  # action id k -> (k // 5, k % 5)

        # Mirror
        flip_grid = np.fliplr(grid)
        flip_prob_grid = np.fliplr(prob_grid)
        expended_memory_buffer.append((with_grid(board_state, flip_grid), flip_prob_grid.reshape(25), value))

        # Rotations of the original and of the mirror
        for k in range(1, 4):
            expended_memory_buffer.append(
                (with_grid(board_state, np.rot90(grid, k)), np.rot90(prob_grid, k).reshape(25), value))
            expended_memory_buffer.append(
                (with_grid(board_state, np.rot90(flip_grid, k)), np.rot90(flip_prob_grid, k).reshape(25), value))
    return expended_memory_buffer

In [None]:
# Grow the self-play buffer with the 8 board symmetries, then report its size.
extended_memory_buffer = expend_experiences(memory_buffer)
print(len(extended_memory_buffer))

In [None]:
# Offline training on the symmetry-augmented buffer: 50,000 minibatch steps,
# a running report every 1,000 steps and a checkpoint every 5,000 steps.
total_loss_history = []
for step in range(1, 50001):
    batch = get_experiences(extended_memory_buffer)
    loss = train_network(policy_value_network, batch)
    total_loss_history.append(loss)
    print(f"\rIteration {step} | current loss: {total_loss_history[-1]:.2f}", end='')
    if step % 1000 == 0:
        av_latest_loss = np.mean(total_loss_history[-1000:])
        print(f"\rIteration {step} | current loss: {av_latest_loss:.2f}")
    if step % 5000 == 0:
        policy_value_network.save('policy_value_model_cnn', save_format='tf')