- 일단 중복 고려 안함. 이동 시 발생하는 reward 0.
- 오직 게임이 끝나고 win, lose, draw에 따라서만 reward 발생
- MTCS 구현
- [MTCS 코드 레퍼런스](https://cafe.daum.net/oracleoracle/SSSv/9?q=몬테카를로%20트리%20탐색)

In [1]:
import time
import os
import pickle
import numpy as np
import pandas as pd
from typing import Tuple
from collections import deque
import copy
from scipy.special import softmax
import random
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# 틱택토 환경

- player = True : 컴퓨터 ('X')
- player = False : 상대방 ('O')

In [47]:
class Environment:
    def __init__(self):
        self.n = 3
        self.num_actions = self.n**2
        self.present_state = np.zeros((self.n, self.n))
        self.action_space = np.arange(self.num_actions)
        self.available_actions = np.ones(self.num_actions)
        self.reward_dict = {'win':1, 'lose':-1, 'draw': -0.1, 'good_action':0, 'overlapped':0}
        self.done = False


    def step(self, action_idx:int, player:bool):
        '''
        에이전트가 선택한 action에 따라 주어지는 next_state, reward, done
        '''
        x, y = np.divmod(action_idx, self.n)

        self.present_state[x,y] = player*2 -1
        next_state = self.present_state
        done, is_win = self.is_done(next_state)
        reward = self.reward_dict['good_action']
        self.available_actions = self.check_available_action(self.present_state)

        if done:
            if is_win == "win":
                reward = self.reward_dict['win']
            elif is_win == "lose":
                reward = self.reward_dict['lose']
            else:
                reward = self.reward_dict['draw']

        self.done = done

        return next_state, reward, done, is_win


    def reset(self):
        '''
        게임판 초기화
        '''
        self.present_state = np.zeros((self.n, self.n))
        self.available_actions = np.ones(self.num_actions)
        self.done = False


    def render(self):
        '''
        print the current state
        '''
        render_state = np.array([['.','.','.'],
                                ['.','.','.'],
                                ['.','.','.']])
        render_str = ""
        for i in range(self.num_actions):
            x, y = np.divmod(i, 3)
            if self.present_state[x,y] == 1:
                render_state[x,y] = 'X'
            elif self.present_state[x,y] == -1:
                render_state[x,y] = 'O'

            render_str += f" {render_state[x,y]}"
            if (i+1) % 3 == 0:
                render_str += "\n" + "-"*11 + "\n"
            else:
                render_str += " |"

        print(render_str)


    def check_available_action(self, state):
        '''
        현재 state에서 가능한 actions array 반환
        원핫 인코딩 방식으로 구현
        '''
        impossible_actions = np.argwhere(state.reshape(-1) != 0)
        available_actions = np.ones(self.num_actions)
        available_actions[impossible_actions] = 0

        return available_actions


    def is_done(self, state):
        '''
        틱택토 게임 종료 조건 및 승리 여부 확인하는 함수
        '''
        is_done, is_win = False, "null"

        # 무승부 여부 확인
        if (state==0).sum()==0:
            is_done, is_win = True, "draw"

        else:
            axis_sum = np.concatenate((state.sum(axis=0), state.sum(axis=1)))
            diag_sum = np.array([state.trace(), np.fliplr(state).trace()])

            sum_array = np.concatenate((axis_sum, diag_sum))
            max_sum = np.max(sum_array)
            min_sum = np.min(sum_array)

            if max_sum == 3:
                is_done, is_win = True, "win"
            elif min_sum == -3:
                is_done, is_win = True, "lose"
            else:
                is_done, is_win = False, "null"

        return is_done, is_win

# MTCS 에이전트

- 헷갈려... 일단 밖에 빼서 따로 시험해봄

In [48]:
class Agent:
    def __init__(self, env, player:bool):
        self.env = env
        self.player = player

        self.n = self.env.n
        self.num_actions = self.env.num_actions
        self.actions = self.env.action_space

        self.stepsize = STEPSIZE
        self.gamma = GAMMA
        self.epsilon = EPSILON
        self.epsilon_decay = EPSILON_DECAY
        self.epsilon_min = EPSILON_MIN


    def UCT(self, environment, itermax, player):
        root_state = env.present_state
        root_node = Node(env=env, state=root_state, player = player)

        for _ in range(itermax):
            node = root_node
            state = copy.deepcopy(root_state)
            env = copy.deepcopy(environment)

            # selection
            while (node.available_actions.sum() > 0) and node.childnode_list != []:
                node = node.select_child()
                state, _, _, _ = env.step(node.action, node.player)

            # expansion
            if node.available_actions.sum() > 0:
                available_action = np.where(node.available_actions != 0)[0]

                action = np.random.choice(available_action)
                state, _, _, _ = env.step(action, node.player)
                node = node.add_child(action, state)

            # simulation
            while not env.done:
                available_action = np.where(env.available_actions != 0)[0]

                action = np.random.choice(available_action)
                state, _, _, _ = env.step(action, node.player)

            # backpropagation
            while node.parent_node != None:
                result = env.is_done(state)
                node.update(result)
                node = node.parent_node

        # 최적의 행동 선택
        s = sorted(root_node.childnode_list, key=lambda c: c.wins / c.visits)

        return s[-1].action


    def update_value_table(self, history):
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        else:
            self.epsilon = self.epsilon_min

        G = 0
        for hist in reversed(history):
            action_idx, reward = hist

            G = self.gamma * G + reward
            self.returns[action_idx].append(G)
            self.value_table[action_idx] = np.mean(self.returns[action_idx])


    def get_action(self, state, available_actions):
        available_action = np.where(available_actions != 0)[0]
        if (np.random.rand() <= self.epsilon) or (not self.player):
            action = np.random.choice(available_action)

        else:
            available_value = self.value_table * available_actions
            action = np.argmax(available_value)

        return action

# Node 클래스, UCT 함수

In [49]:
class Node():
    def __init__(self, env, action = None, state = None, parent = None, player = True):
        self.env = copy.deepcopy(env)
        self.action = action
        self.player = player
        self.available_actions = self.env.check_available_action(state)
        self.parent_node = parent
        self.childnode_list = []

        self.visits = 0
        self.wins = 0


    def select_child(self):
        '''
        UCT를 이용한 child node 선택 (selection에 해당)
        '''
        # UTC 공식을 적용하여 값이 낮은 것부터 child node들을 정렬 (마지막이 최댓값을 갖는 노드)
        s = sorted(self.childnode_list, key=lambda c: c.wins / c.visits + np.sqrt(2 * np.log(self.visits) / c.visits))
        return s[-1]


    def add_child(self, action, state):
        '''
        action에 해당하는 child node 추가 (expansion에 해당)
        '''
        n = Node(self.env, action, copy.deepcopy(state), parent=self, player = not self.player)
        self.available_actions[action] = 0
        self.childnode_list.append(n)
        return n

    def update(self, result):
        '''
        이 노드의 값 업데이트 (backpropagation에 해당)
        '''
        self.visits += 1
        self.wins += result

In [55]:
def UCT(environment, itermax, player):
    root_state = environment.present_state
    root_node = Node(env=environment, state=root_state, player = player)

    for _ in range(itermax):
        node = root_node
        state = copy.deepcopy(root_state)
        env = copy.deepcopy(environment)

        # selection
        while (node.available_actions.sum() > 0) and node.childnode_list != []:
            node = node.select_child()
            state, _, _, _ = env.step(node.action, node.player)

        # expansion
        if node.available_actions.sum() > 0:
            available_action = np.where(node.available_actions != 0)[0]

            action = np.random.choice(available_action)
            state, _, _, _ = env.step(action, node.player)
            node = node.add_child(action, state)

        # simulation
        while not env.done:
            available_action = np.where(env.available_actions != 0)[0]

            action = np.random.choice(available_action)
            state, _, _, _ = env.step(action, node.player)

        # backpropagation
        while node.parent_node != None:
            is_done, is_win = env.is_done(state)
            result = env.reward_dict[is_win]
            node.update(result)
            node = node.parent_node

    # 최적의 행동 선택
    s = sorted(root_node.childnode_list, key=lambda c: c.wins / c.visits)

    return s[-1].action

## UCT main

In [56]:
MAXITER = 10000
env = Environment()

In [57]:
def UCTPlayGame(first_player:bool):
    env = Environment()
    state = env.present_state
    done = env.done
    player = first_player
    is_win = "null"

    env.render()

    while not done:
        print(f"player:{player}")
        if player:
            action = UCT(env, MAXITER, player)

        else:
            available_action = np.where(env.available_actions != 0)[0]
            action = np.random.choice(available_action)

        state, _, done, is_win = env.step(action, player)
        player = not player

        env.render()
        print(f"win:{is_win}")

In [58]:
UCTPlayGame(True)

 . | . | .
-----------
 . | . | .
-----------
 . | . | .
-----------

player:True


  s = sorted(self.childnode_list, key=lambda c: c.wins / c.visits + np.sqrt(2 * np.log(self.visits) / c.visits))
  s = sorted(self.childnode_list, key=lambda c: c.wins / c.visits + np.sqrt(2 * np.log(self.visits) / c.visits))


 . | . | .
-----------
 . | . | .
-----------
 . | X | .
-----------

win:null
player:False
 O | . | .
-----------
 . | . | .
-----------
 . | X | .
-----------

win:null
player:True
 O | . | .
-----------
 . | . | .
-----------
 X | X | .
-----------

win:null
player:False
 O | . | .
-----------
 O | . | .
-----------
 X | X | .
-----------

win:null
player:True
 O | X | .
-----------
 O | . | .
-----------
 X | X | .
-----------

win:null
player:False
 O | X | .
-----------
 O | . | O
-----------
 X | X | .
-----------

win:null
player:True
 O | X | .
-----------
 O | . | O
-----------
 X | X | X
-----------

win:win
