In [1]:
import pickle
import os
import glob
from Yinsh.yinsh_model import YinshGameRule 
from Yinsh.yinsh_utils import *
import numpy as np
import re
import copy
import random
from keras.callbacks import TensorBoard
import time
from tqdm import tqdm

In [None]:
def calculate_reward(old_state,new_state, agent_index):
    # Check if any ring is removed or this is the end of game
    if new_state.rings_won[agent_index] == 3:
        return 10
    # If the opponent won, return -10
    elif new_state.rings_won[int(not agent_index)] == 3:
        return -10
    # Check if any ring is removed
    if new_state.rings_won[0] != old_state.rings_won[0] or new_state.rings_won[1] != old_state.rings_won[1]:
        own_change = new_state.rings_won[agent_index] - old_state.rings_won[agent_index]
        opp_change = new_state.rings_won[int(not agent_index)] - old_state.rings_won[int(not agent_index)]
        return 3*own_change - 3*opp_change
    
    return 0
    # Finally, return the potential function
    # return potential_func(new_state, agent_index) - potential_func(old_state, agent_index)
        

"""
game_rule_static = YinshGameRule(2)
def potential_func(state, agent_index):
    # Calculate potential function for a given state using heuristic by Khoi
    if state.rings_to_place > 0: return 0
    return 0.5*(ChainHeuristic(state,agent_index,0) - 0.1*ChainHeuristic(state,agent_index,1))

def ChainHeuristic( s, agent_index, opponent_perspective: int = 0, pts: dict = {0:0, 1:0, 2:1, 3:3, 4:5, 5:6}):
    '''
    Heuristic that returns score based on the feasible, unblocked sequences of markers.

    params
    ---
    - s: Game state
    - opponent_perspective: Whose view are we looking at? Our view (0) or opponent's view (1)?
    - pts: A dictionary mapping chains to corresponding point
    '''
    # What color are we looking at?
    view = agent_index ^ opponent_perspective
    tot_mark = 0
    ring = str(RING_1 if view else RING_0)
    counter = str(CNTR_1 if view else CNTR_0)
    opponent_counter = str(CNTR_0 if view else CNTR_1)
    # Get all markers first
    # Return the list of all locations of a specific type of grid on the board
    lookup_pos = lambda n,b : list(zip(*np.where(b==n)))
    markers = lookup_pos(CNTR_1 if view else CNTR_0,s.board)
    lines_of_interest = set()
    # Get all lines with markers
    for m in markers:
        for line in ['v','h','d']:
            lines_of_interest.add(tuple(game_rule_static.positionsOnLine(m,line)))
    # For each line that has markers, see if a feasible chain exists
    # R: ring       M: my marker     M_opponent: opponent marker
    for p in lines_of_interest:
        p_str  = ('').join([str(s.board[i]) for i in p])
        # Chains of 5 mixed R/M
        for st in re.findall(f'[{ring}{counter}]*',p_str):
            # How many rings needed to move to get 5-marker streak?
            if len(st)>4: tot_mark+=pts[5-st.count(ring)]
        # Incomplete but feasible EMPTY-R-nM (not summing up to 5 yet)
        for st in re.findall(f'{str(EMPTY)*3}{ring}{counter*1}'
                            + f'|{str(EMPTY)*2}{ring}{counter*2}'
                            + f'|{str(EMPTY)*1}{ring}{counter*3}',p_str):
            tot_mark+=pts[st.count(counter)]
        # Incomplete but feasible nM-R-EMPTY (not summing up to 5 yet)
        for st in re.findall(f'{counter*1}{ring}{str(EMPTY)*3}'
                            + f'|{counter*2}{ring}{str(EMPTY)*2}'
                            + f'|{counter*3}{ring}{str(EMPTY)*1}',p_str):
            tot_mark+=pts[st.count(counter)]
        # R-nM_opponent-EMPTY
        for st in re.findall(f'{ring}{opponent_counter*1}{str(EMPTY)*3}'
                            + f'|{ring}{opponent_counter*2}{str(EMPTY)*2}'
                            + f'|{ring}{opponent_counter*3}{str(EMPTY)*1}'
                            # Needs at least a space to land after jumping over
                            + f'|{ring}{opponent_counter*4}{str(EMPTY)*1}',p_str):
            tot_mark+=pts[st.count(counter)]
        # EMPTY-nM_opponent-R
        for st in re.findall(f'{str(EMPTY)*3}{opponent_counter*1}{ring}'
                            + f'|{str(EMPTY)*2}{opponent_counter*2}{ring}'
                            + f'|{str(EMPTY)*1}{opponent_counter*3}{ring}'
                            # Needs at least a space to land after jumping over
                            + f'|{str(EMPTY)*1}{opponent_counter*4}{ring}',p_str):
            tot_mark+=pts[st.count(counter)]
    return tot_mark

"""

In [3]:

replay_memory = [] # Observed state, action, reward, next state
folder_path = 'replays'

for filename in glob.glob(os.path.join(folder_path, '*.replay')):
    replay = pickle.load(open(filename,'rb'),encoding="bytes")
    # Initial game state
    num_of_agent = replay["num_of_agent"]
    game_rule = YinshGameRule(num_of_agent)
    # Run the replay
    # Old state and action for agents
    old_state = [None,None]
    old_action = [None,None]
    end_unexpected = False
    for item in replay["actions"]:
        (index, info), = item.items()
        selected = info["action"]
        agent_index = info["agent_id"]
        if old_state[agent_index] is not None:
            current_state = copy.deepcopy(game_rule.current_game_state)
            reward = calculate_reward(old_state[agent_index],current_state, agent_index)
            replay_memory.append((old_state[agent_index], old_action[agent_index], reward,current_state,agent_index,False))
        old_state[agent_index] = copy.deepcopy(game_rule.current_game_state)
        old_action[agent_index] = selected
        # print(info)
        game_rule.current_agent_index = agent_index          
        action_candidates = game_rule.getLegalActions(game_rule.current_game_state, agent_index)

        if selected not in action_candidates:
            print("Error: illegal action")
            old_state[agent_index] = None
            old_action[agent_index] = None
            end_unexpected = True
            break
        game_rule.update(selected)
    # Game ends, do a final update
    if not game_rule.gameEnds():
        print(filename)
    assert game_rule.gameEnds() or end_unexpected
    if not end_unexpected:
        for agent_index in range(num_of_agent):
            current_state = copy.deepcopy(game_rule.current_game_state)
            reward = calculate_reward(old_state[agent_index],current_state, agent_index)
            replay_memory.append((old_state[agent_index], old_action[agent_index], reward,current_state,agent_index,True))  
        

In [4]:
len(replay_memory) # 57068

57068

In [5]:
import tensorflow as tf
from tensorflow import keras

In [6]:
model = keras.Sequential()
model.add(keras.Input(shape=(11, 11, 4)))
model.add(keras.layers.Conv2D(filters=64,kernel_size = 3,activation = 'relu'))
model.add(keras.layers.Conv2D(filters=128,kernel_size = 3,activation = 'relu'))
model.add(keras.layers.Conv2D(filters=256,kernel_size = 3,activation = 'relu'))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(units=1933,activation = 'softmax'))
model.compile(loss=tf.keras.losses.Huber(), optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.0001), metrics=['accuracy'])

model.load_weights('main_model_dqn_replay.h5')

2022-05-24 13:15:23.479421: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [7]:
def reshape_state_for_net(state,agent_index):
    # Reshape the state to be [11, 11, 4] => 4 dimensions are our ring, our marker, opponent's ring, opponent's marker
    reshaped_state = np.zeros((11, 11, 4),dtype=np.int8)
    assert state.board.shape[0] == state.board.shape[1] == 11
    for i in range (0, 11):
        for j in range (0, 11):
            # If the pos is illegal, update all layers to be 0
            if state.board[i][j] == 5:
                for z in range (0, 4):
                    reshaped_state[i][j][z] = 0
            elif state.board[i][j] == 1: # RING_0
                # If the pos is our ring, update layer 0 to be 1
                if agent_index == 0:
                    reshaped_state[i][j][0] = 1
                # If the pos is opponent's ring, update layer 2 to be 1
                else:
                    reshaped_state[i][j][2] = 1
            elif state.board[i][j] == 2: # CNTR_0
                # If the pos is our marker, update layer 1 to be 1
                if agent_index == 0:
                    reshaped_state[i][j][1] = 1
                # If the pos is opponent's marker, update layer 3 to be 1
                else:
                    reshaped_state[i][j][3] = 1
            elif state.board[i][j] == 3: # RING_1
                # If the pos is our ring, update layer 0 to be 1
                if agent_index == 1:
                    reshaped_state[i][j][0] = 1
                # If the pos is opponent's ring, update layer 2 to be 1
                else:
                    reshaped_state[i][j][2] = 1
            elif state.board[i][j] == 4: # CNTR_1
                # If the pos is our marker, update layer 1 to be 1
                if agent_index == 1:
                    reshaped_state[i][j][1] = 1
                # If the pos is opponent's marker, update layer 3 to be 1
                else:
                    reshaped_state[i][j][3] = 1
    return reshaped_state

In [8]:
board = np.zeros((11,11), dtype=np.int8)
for pos in ILLEGAL_POS:
    board[pos] = ILLEGAL
# A dictionary that map an action into an index in NN output
action_to_index_dict = {}
count = 0
legal_pos = []
for i in range(11):
    for j in range(11):
        if board[i,j] != ILLEGAL:
            legal_pos.append((i,j))
            action_to_index_dict[(i,j)] = count
            count += 1
assert len(legal_pos) == 85
for pos in legal_pos:
    for line in ['v', 'h', 'd']:
        if line == 'h':
            for i in range(11):
                if board[pos[0], i] != ILLEGAL and i != pos[1]:
                    # From pos[0],pos[1] to pos[0],i
                    action_to_index_dict[(pos[0],pos[1],pos[0],i)] = count
                    count += 1
        elif line == 'v':
            for i in range(11):
                if board[i, pos[1]] != ILLEGAL and i != pos[0]:
                    # From pos[0],pos[1] to pos[0],i
                    action_to_index_dict[(pos[0],pos[1],i,pos[1])] = count
                    count += 1
        elif line == 'd':
            for i in range(-10, 11):
                if i == 0: continue
                if (0 <= pos[0]+i <= 10 and 0 <= pos[1]-i <= 10 and board[pos[0]+i, pos[1]-i] != ILLEGAL):
                    action_to_index_dict[(pos[0],pos[1],pos[0]+i,pos[1]-i)] = count
                    count += 1

In [9]:
# Visualize the learning process
class ModifiedTensorBoard(TensorBoard):

    # Overriding init to set initial step and writer (we want one log file for all .fit() calls)
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.step = 1
        self.writer = tf.summary.create_file_writer(self.log_dir)
        self._log_write_dir = self.log_dir
    # Overriding this method to stop creating default log writer
    def set_model(self, model):
        self.model = model
        self._train_dir = os.path.join(self._log_write_dir, 'train')
        self._train_step = self.model._train_counter
        self._val_dir = os.path.join(self._log_write_dir, 'validation')
        self._val_step = self.model._test_counter
        self._should_write_train_graph = False

    # Overrided, saves logs with our step number
    # (otherwise every .fit() will start writing from 0th step)
    def on_epoch_end(self, epoch, logs=None):
        self.update_stats(**logs)

    # Overrided
    # We train for one batch only, no need to save anything at epoch end
    def on_batch_end(self, batch, logs=None):
        pass

    # Overrided, so won't close writer
    def on_train_end(self, _):
        pass

    # Creates writer, writes custom metrics and closes writer
    def update_stats(self, **stats):
        self._write_logs(stats, self.step)

    def _write_logs(self, logs, index):
        with self.writer.as_default():
            for name, value in logs.items():
                tf.summary.scalar(name, value, step=index)
                self.step += 1
                self.writer.flush()

tensorboard = ModifiedTensorBoard(log_dir="replay_train_logs/{}".format(int(time.time())))

In [10]:
# Get a minibatch of random samples from memory replay table
MINIBATCH_SIZE = 50
for i in tqdm(range(15000)):
    minibatch = random.sample(replay_memory, MINIBATCH_SIZE)
   #  old_state(not flatted) , old_action, reward,new state(not flatted), agent index,game end
    X = []
    y = []
    for experience in minibatch:
        # Get stored values
        agent_index = experience[4]
        old_state = reshape_state_for_net(experience[0],agent_index)
        old_state_prediction = model.predict(np.expand_dims(old_state,axis=0))[0]
        action_index = None
        action = experience[1]
        if action["type"] == "place ring":
            action_index = action_to_index_dict[action["place pos"]]
        elif action["type"] == "place and move" or  action["type"] == "place, move, remove":
            action_index = action_to_index_dict[(action["place pos"][0],action["place pos"][1],action["move pos"][0],action["move pos"][1])]

        reward = experience[2]
        new_state = np.expand_dims(reshape_state_for_net(experience[3],agent_index), axis=0) 
        # Get prediction of the new state
        new_state_prediction = model.predict(new_state)[0]
        # Get the target value
        target_value = None
        if experience[5]:
            target_value = reward
        else:
            target_value = reward + 0.9 * np.max(new_state_prediction)
        # Get the target value for the old state
        target_f = old_state_prediction
        target_f[action_index] = target_value
        # Train the network
        X.append(old_state)
        y.append(target_f)
    X = np.array(X)
    y = np.array(y)
    model.fit(X, y, epochs=1, verbose=0, callbacks=[tensorboard] )

100%|██████████| 15000/15000 [17:57:24<00:00,  4.31s/it]  


In [10]:
model.save_weights('main_model_dqn_replay.h5')