In [1]:
import copy
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from gym_tictactoe.env import TicTacToeEnv, agent_by_mark, next_mark
from myrl.buffers import ReplayBuffer
from myrl.utils import ExperimentWriter
# from myrl.value_functions import 

# Module-level environment; `RolloutPolicy(env)` further down is built from it.
# `obs` is the observation returned by reset().
env = TicTacToeEnv()
obs = env.reset()

In [2]:
# Inspect the action space (the captured output shows Discrete(9)).
env.action_space

Discrete(9)

In [201]:
class RolloutPolicy:
    """Uniform-random playout policy used to evaluate MCTS leaves.

    Parameters
    ----------
    env : environment whose ``action_space.n`` gives the number of actions.
    """

    def __init__(self, env):
        self.env = env

    def act(self, obs, not_available_actions=None, render=False):
        """Return a uniformly random action that is not in ``not_available_actions``.

        Parameters
        ----------
        obs : current observation (unused; kept for the shared policy interface).
        not_available_actions : iterable of action ids that must not be chosen.
            ``None`` or empty means every action is allowed. Callers such as
            ``eval_winrate`` pass this positionally, like for ``NNTreePolicy.act``.
        render : unused; kept for backward compatibility.

        Raises
        ------
        ValueError
            If every action is excluded.
        """
        # Use self.env (not a global) and only sample legal cells: the env's
        # action_space.sample() cannot be masked and happily picks occupied cells.
        forbidden = set(not_available_actions or ())
        legal = [a for a in range(self.env.action_space.n) if a not in forbidden]
        if not legal:
            raise ValueError("no legal action available")
        return random.choice(legal)

    def rollout(self, obs, model, render=False):
        """Play random legal moves from ``obs`` with ``model`` until the game ends.

        Returns
        -------
        The sum of the rewards returned by ``model.step`` along the playout.
        """
        done = False
        total_reward = 0
        while not done:
            action = self.act(obs, model.not_available_actions(obs))
            obs, r, done, _ = model.step(obs, action)
            total_reward += r
        return total_reward

class Model:
    """Simulator: loads an observation into a private env, then queries or steps it.

    Observations are read as ``obs[0]`` = board cells and ``obs[1]`` = mark to move
    (see ``_set_env``).
    """

    def __init__(self, env):
        self.env = env

    def step(self, obs, action):
        """Apply ``action`` to the state ``obs``; returns whatever ``env.step`` returns."""
        self._set_env(obs)
        return self.env.step(action)

    def available_actions(self, obs):
        """Return ``env.available_actions()`` evaluated at state ``obs``."""
        self._set_env(obs)
        return self.env.available_actions()

    def not_available_actions(self, obs):
        """Return every action id in ``range(n_actions)`` that is not available at ``obs``."""
        return self._list_cut(list(range(self.get_num_actions())), self.available_actions(obs))

    def _list_cut(self, l1, l2):
        """Return the items of ``l1`` that do not occur in ``l2``, keeping order."""
        return [a1 for a1 in l1 if a1 not in l2]

    def _set_env(self, obs):
        """Overwrite the env's board and mark with ``obs`` and clear its done flag."""
        self.env.board = list(obs[0])
        self.env.mark = obs[1]
        # The flag that matters lives on the env: if it stays True after a finished
        # game, env.step ignores further moves.
        self.env.done = False
        self.done = False  # legacy attribute, kept for backward compatibility

    def get_num_actions(self):
        """Size of the discrete action space."""
        return self.env.action_space.n

class TreePolicy:
    """Stateless uniform-random tree policy.

    NOTE(review): ``act`` chooses *from* its second argument. Callers in this
    notebook (``Node.create_child``, ``eval_winrate``) pass the actions that are
    NOT available, so this class is only safe with callers that pass legal actions.
    """

    def act(self, obs, available_actions):
        """Return one element of ``available_actions`` chosen uniformly at random."""
        import random
        return random.choice(available_actions)

    def get_action_probs(self, obs, available_actions):
        """Return a uniform probability for every entry of ``available_actions``."""
        count = len(available_actions)
        return [1 / count for _ in available_actions]

class Backbone(nn.Module):
    """MLP trunk built from consecutive ``nn.Linear`` layers sized by ``net_arch``.

    The input is flattened to ``(batch, -1)``. ``middle_activation`` follows every
    layer except the last; ``last_activation`` follows the last layer.
    """

    def __init__(self, net_arch, middle_activation=F.relu, last_activation=F.relu):
        super().__init__()
        self.middle_activation = middle_activation
        self.last_activation = last_activation
        size_pairs = zip(net_arch[:-1], net_arch[1:])
        self.layers = nn.ModuleList(nn.Linear(n_in, n_out) for n_in, n_out in size_pairs)

    def forward(self, h):
        out = h.view(h.shape[0], -1)
        hidden_layers = self.layers[:-1]
        for layer in hidden_layers:
            out = self.middle_activation(layer(out))
        output_layer = self.layers[-1]
        return self.last_activation(output_layer(out))

class ValueFunction(nn.Module):
    """Value head stacked on a (possibly shared) backbone.

    Parameters
    ----------
    net_arch : list of ints, sizes of the head's Linear layers (e.g. ``[16, 4, 1]``).
    backbone : module applied to the flattened input before the head.
    """

    def __init__(self, net_arch, backbone):
        super().__init__()
        self.backbone = backbone
        self.layers = nn.ModuleList([nn.Linear(a, b) for a, b in zip(net_arch[:-1], net_arch[1:])])

    def forward(self, h):
        """Return the head output for a batch; ``h`` may be a tensor or nested list."""
        # as_tensor avoids copying an input that is already a tensor (torch.tensor would
        # copy, detach and emit a UserWarning); lists are still converted.
        h = torch.as_tensor(h, dtype=torch.float)
        # reshape (unlike view) also accepts non-contiguous inputs.
        h = h.reshape(h.shape[0], -1)
        h = self.backbone(h)
        for lay in self.layers[:-1]:
            h = F.relu(lay(h))
        return self.layers[-1](h)

class NNTreePolicy(nn.Module):
    """Softmax policy head stacked on a (possibly shared) backbone.

    Parameters
    ----------
    net_arch : list of ints, sizes of the head's Linear layers (last = number of actions).
    backbone : module applied to the flattened input before the head.
    temperature : logits are divided by this value before the softmax.
    """

    def __init__(self, net_arch, backbone, temperature=1):
        super().__init__()
        self.backbone = backbone
        self.temperature = temperature
        self.layers = nn.ModuleList([nn.Linear(a, b) for a, b in zip(net_arch[:-1], net_arch[1:])])

    def forward(self, x, not_available_actions=None):
        """Return action probabilities of shape ``(batch, n_actions)``.

        Actions listed in ``not_available_actions`` get probability 0 in every row.
        """
        # as_tensor: no copy/warning when ``x`` is already a tensor.
        h = torch.as_tensor(x, dtype=torch.float)
        h = h.reshape(h.shape[0], -1)
        h = self.backbone(h)
        for lay in self.layers[:-1]:
            h = F.relu(lay(h))
        logits = self.layers[-1](h) / self.temperature
        if not_available_actions is not None and len(not_available_actions) > 0:
            idx = torch.as_tensor(not_available_actions, dtype=torch.long, device=logits.device)
            # Out-of-place mask over the action dimension of every row
            # (previously only row 0 was masked).
            logits = logits.index_fill(1, idx, float('-inf'))
        return torch.softmax(logits, dim=1)

    def act(self, obs, not_available_actions):
        """Sample an action for a single raw observation, excluding unavailable ones."""
        obs = self.obs2testorobs(obs)
        h = self.forward(obs, not_available_actions=not_available_actions)
        action = np.random.choice(range(len(h[0])), p=h.detach().squeeze(0).numpy())
        return action

    def get_action_probs(self, obs, available_actions):
        """Unmasked action probabilities for a single raw observation, as a list."""
        obs = self.obs2testorobs(obs)
        h = self.forward(obs)
        return h.tolist()[0]

    def obs2testorobs(self, obs):
        """Encode ``(board, mark)`` as a ``(1, 10)`` integer tensor.

        Board cells equal to 2 become -1. A trailing +1 is appended when the mark
        is 'O', and -1 otherwise.
        """
        l2 = [1] if obs[1] == 'O' else [-1]
        obs = torch.tensor([list(obs[0]) + l2])
        obs[obs == 2] = -1
        return obs

# Search components. One backbone instance is shared by the value head and the
# policy head (construction order is kept so parameter initialization is unchanged).
rollout_policy = RolloutPolicy(env=env)
model = Model(env=TicTacToeEnv())
backbone = Backbone(net_arch=[10, 16])
value_function = ValueFunction(net_arch=[16, 4, 1], backbone=backbone)
tree_policy = NNTreePolicy(net_arch=[16, 9, 9], backbone=backbone)

In [202]:
class Node:
    """Search-tree node for MCTS.

    Attributes (as used in this class)
    ----------------------------------
    n : number of backpropagated playouts through this node.
    cumulative_reward : sum of ``r * gamma * reward_sign`` over those playouts.
    reward_sign : +1/-1 multiplier applied to rewards accumulated here; children get
        the opposite sign when ``change_child_rew_sign`` is True.
    action2child / taken_actions / nchildren : expanded children and their actions.
    nzeros / nones / nmones : counts of playouts with r == 0 and signed r == +1 / -1.
    """

    def __init__(self, obs, reward, change_child_rew_sign=True, reward_sign=1, done=False, parent=None):
        self.n = 0
        self.cumulative_reward = 0
        self.parent = parent
        self.action2child = {}
        self.nchildren = 0
        self.taken_actions = []
        self.obs = obs
        self.done = done
        self.reward = reward
        self.reward_sign = reward_sign
        self.change_child_rew_sign = change_child_rew_sign
        self.nzeros = 0
        self.nones = 0
        self.nmones = 0

    def get_q(self):
        """Mean signed reward; raises ZeroDivisionError if the node was never visited."""
        return self.cumulative_reward / (self.n)

    def backpropagate(self, r, gamma=1):
        """Add playout result ``r`` to this node and every ancestor.

        Each level accumulates ``r * gamma * reward_sign``. The same ``gamma`` factor is
        applied at every ancestor; it is not compounded per tree level.
        """
        node = self
        while node is not None:
            node.n += 1
            node.cumulative_reward += r * gamma * node.reward_sign
            node.nzeros += 1 if r == 0 else 0
            node.nones += 1 if r * node.reward_sign == 1 else 0
            node.nmones += 1 if r * node.reward_sign == -1 else 0
            node = node.parent

    def print_parents(self):
        """Walk up to the root (prints nothing; kept for compatibility)."""
        if self.parent is None:
            return
        self.parent.print_parents()

    def create_child(self, ChildType, policy, model):
        """Expand one new child with ``policy`` and return ``(child, action)``.

        ``policy.act`` receives the actions to exclude: already-expanded ones plus
        those the model reports as unavailable.
        """
        not_available_actions = self._list_union(self.taken_actions, model.not_available_actions(self.obs))
        action = policy.act(self.obs, not_available_actions)
        self.taken_actions.append(action)
        obs, reward, done, info = model.step(self.obs, action)
        reward_sign = -self.reward_sign if self.change_child_rew_sign else self.reward_sign
        child = ChildType(obs, reward, done=done, reward_sign=reward_sign, parent=self,
                          change_child_rew_sign=self.change_child_rew_sign)
        self.action2child[action] = child
        self.nchildren += 1
        return child, action

    def _list_cut(self, l1, l2):
        """Return the items of ``l1`` that do not occur in ``l2``, keeping order."""
        return [a1 for a1 in l1 if a1 not in l2]

    def _list_union(self, l1, l2):
        """Return ``l2`` followed by the items of ``l1`` that are not in ``l2``."""
        # A shallow copy is enough for lists of action ids, and it does not depend on
        # `copy` being imported before this cell runs.
        merged = list(l2)
        merged.extend(a1 for a1 in l1 if a1 not in l2)
        return merged

    def rollout(self, rollout_policy, model, render=False):
        """Return this node's reward if terminal; otherwise delegate to ``rollout_policy``."""
        if self.done:
            return self.reward
        return rollout_policy.rollout(self.obs, model, render=render)

In [203]:
def UCB(root_node, policy, model, alpha=1):
    """Score the expanded children of ``root_node`` and return the lowest-scoring one.

    score(a) = Q(child_a) - alpha * prior(a) / (1 + N(child_a)), where prior comes from
    ``policy.get_action_probs``. Lower scores are selected.

    Returns
    -------
    (scores, (best_action, best_score)) where ``scores[a]`` is the score of action ``a``
    and ``float('inf')`` for actions without a child. ``best_action`` is -1 and
    ``best_score`` is inf when there are no children.
    """
    n_actions = model.get_num_actions()
    probs = policy.get_action_probs(root_node.obs, list(range(n_actions)))
    # Index scores by action id (previously child scores were appended after n zeros).
    scores = [float('inf')] * n_actions
    best_action, best_score = -1, float('inf')
    for action, child in root_node.action2child.items():
        u = -probs[action] / (1 + child.n)
        score = child.get_q() + alpha * u
        scores[action] = score
        if score < best_score:
            best_score, best_action = score, action
    return scores, (best_action, best_score)

def MCTS(root_node, max_depth, n_times, policy, model, alpha_UCB=1, rollout_policy=None, gamma=0.99):
    """Run ``n_times`` MCTS simulations from ``root_node`` and score each root action.

    Each simulation descends from the root. It expands a new child (via ``policy``)
    while the node has untried actions, otherwise it follows the UCB choice. Once
    ``max_depth`` is reached or a terminal node is hit, it plays out with
    ``rollout_policy`` and backpropagates the result with ``gamma``.

    Parameters
    ----------
    rollout_policy : playout policy. ``None`` falls back to the notebook-level global
        ``rollout_policy`` (the previous, implicit behavior).
    gamma : factor passed to ``Node.backpropagate``.

    Returns
    -------
    List with the root child's Q for every action id, and ``float('inf')`` for
    actions that were never expanded. Callers here pick moves with argmin.
    """
    if rollout_policy is None:
        rollout_policy = globals()["rollout_policy"]

    current_node = root_node
    current_depth = 0
    n_times_done = 0
    while n_times_done < n_times:
        if current_depth == max_depth or current_node.done:
            reward = current_node.rollout(rollout_policy, model)
            current_node.backpropagate(reward, gamma=gamma)
            current_node = root_node
            current_depth = 0
            n_times_done += 1
            # Clear the simulator's terminal flag left by the playout.
            model.env.done = False
        elif current_node.nchildren < len(model.available_actions(current_node.obs)):
            current_node, _ = current_node.create_child(Node, policy, model)
            current_depth += 1
        else:
            _, (best_action, _) = UCB(current_node, policy, model, alpha_UCB)
            current_node = current_node.action2child[best_action]
            current_depth += 1

    return [
        root_node.action2child[a].get_q() if a in root_node.action2child else float('inf')
        for a in range(model.get_num_actions())
    ]

In [204]:
def eval_winrate(totest, bench, env, n_games=100, model=None):
    """Play ``totest`` against ``bench`` and return ``(winrate, drawrate, lossrate)``.

    ``totest`` moves first in the first ``n_games // 2`` games and ``bench`` moves
    first in the rest. A game counts as a win for ``totest`` when a step reward equals
    +1 if ``totest`` moved first, or -1 otherwise. It counts as a draw when the final
    reward is 0.

    Parameters
    ----------
    model : object providing ``not_available_actions(obs)``. ``None`` falls back to
        the notebook-level global ``model`` (the previous, implicit behavior).

    Raises
    ------
    ValueError
        If ``n_games`` is not positive.
    """
    if n_games <= 0:
        raise ValueError("n_games must be positive")
    if model is None:
        model = globals()["model"]

    # Strict '<' gives each side the same number of starts (was 51/49 for 100 games).
    n_totest_first = n_games // 2
    wins = 0
    draws = 0
    for igame in range(n_games):
        totest_first = igame < n_totest_first
        curr_policy = totest if totest_first else bench
        win_reward = 1 if totest_first else -1
        obs = env.reset()
        done, r = False, 0
        while not done:
            action = curr_policy.act(obs, model.not_available_actions(obs))
            obs, r, done, _ = env.step(action)
            if r == win_reward:
                wins += 1
            curr_policy = bench if curr_policy is totest else totest
        if r == 0:
            draws += 1
    winrate = wins / n_games
    drawrate = draws / n_games
    return winrate, drawrate, (1 - winrate - drawrate)

In [197]:
# Demo: play one game where every move is chosen by MCTS guided by `tree_policy`,
# printing the per-move scores. numpy is already imported in the first cell.
test_env = TicTacToeEnv()
obs = test_env.reset()
done = False
test_env.render()
random_policy = TreePolicy()

while not done:
    # +1 when the starting mark is to move, -1 otherwise.
    rew_sign = 1 if test_env.mark == test_env.start_mark else -1
    root = Node(obs, 0, reward_sign=rew_sign)
    dic = np.array(MCTS(root, max_depth=10, n_times=10, policy=tree_policy, model=model, alpha_UCB=100))
    # Moves are picked by argmin, so softmax(-scores) displays them as probabilities;
    # unexpanded actions score +inf and get probability 0.
    tdic = torch.tensor(-dic).unsqueeze(0)
    print("move probs= ", torch.softmax(tdic, dim=1).view(3, 3))
    move = np.argmin(dic)
    print(move, test_env.mark)
    print(dic.reshape(3, 3))
    obs, r, done, _ = test_env.step(move)
    test_env.render()
    print(" ")
    print(" ")

   | | 
  -----
   | | 
  -----
   | | 

move probs=  tensor([[0.0865, 0.0321, 0.2328],
        [0.2328, 0.0865, 0.0321],
        [0.2328, 0.0321, 0.0321]], dtype=torch.float64)
2 O
[[ 0.    0.99 -0.99]
 [-0.99  0.    0.99]
 [-0.99  0.99  0.99]]
   | |O
  -----
   | | 
  -----
   | | 

 
 
move probs=  tensor([[0.0567, 0.1527, 0.0000],
        [0.0567, 0.1527, 0.0567],
        [0.4109, 0.0567, 0.0567]], dtype=torch.float64)
6 X
[[ 0.99  0.     inf]
 [ 0.99  0.    0.99]
 [-0.99  0.99  0.99]]
   | |O
  -----
   | | 
  -----
  X| | 

 
 
move probs=  tensor([[0.0874, 0.2351, 0.0000],
        [0.0874, 0.0874, 0.2351],
        [0.0000, 0.0325, 0.2351]], dtype=torch.float64)
1 O
[[ 0.   -0.99   inf]
 [ 0.    0.   -0.99]
 [  inf  0.99 -0.99]]
   |O|O
  -----
   | | 
  -----
  X| | 

 
 
move probs=  tensor([[0.5496, 0.0000, 0.0000],
        [0.0759, 0.1468, 0.0759],
        [0.0000, 0.0759, 0.0759]], dtype=torch.float64)
0 X
[[-0.99   inf   inf]
 [ 0.99  0.33  0.99]
 [  inf  0.99  0.99]]
  X|

In [214]:
# Fresh networks for training. This repeats the setup from the class-definition cell,
# keeping the same construction order so parameter initialization is consistent.
rollout_policy = RolloutPolicy(env=env)
model = Model(env=TicTacToeEnv())
backbone = Backbone(net_arch=[10, 16])
value_function = ValueFunction(net_arch=[16, 4, 1], backbone=backbone)
tree_policy = NNTreePolicy(net_arch=[16, 9, 9], backbone=backbone)

In [215]:
best_policy = copy.deepcopy(tree_policy)
# Experiment logger. The training cell calls `wll.new()` and reads `wll.writer`, so
# `wll` must exist before that cell runs on a fresh kernel.
wll = ExperimentWriter('tb/alpha_tictacte_zero_valuef_accurateMCTS_')

In [217]:
# Self-play training. MCTS move scores supervise the policy head, final game outcomes
# supervise the value head, and an arena every 5 games keeps the better network.
test_env = TicTacToeEnv()
obs = test_env.reset()
done = False
test_env.render()
rbuff = ReplayBuffer(nitems=3, max_len=150)
bsize = 128
wll.new()
writer = wll.writer
# The backbone is shared by both heads. Concatenating the two parameter lists would
# register its parameters twice, so Adam would update them twice per step.
# nn.ModuleList(...).parameters() yields each parameter once.
opt = torch.optim.Adam(nn.ModuleList([tree_policy, value_function]).parameters(), lr=5e-3)
# Snapshot the networks and the optimizer with ONE deepcopy call. The shared memo keeps
# the copied heads on one copied backbone and keeps the copied optimizer pointing at the
# copied parameters. Separate deepcopies would leave a restored optimizer updating
# tensors that no network uses.
best_state = copy.deepcopy((tree_policy, value_function, opt))
best_tree_policy, best_vfunc, best_opt = best_state

for game in range(10000):
    game_step = 0
    done = False
    tmp_buff = []
    while not done:
        game_step += 1
        rew_sign = 1 if test_env.mark == test_env.start_mark else -1
        root = Node(obs, 0, reward_sign=rew_sign)
        dic = np.array(MCTS(root, max_depth=10, n_times=100, policy=tree_policy, model=model, alpha_UCB=100))

        # Moves are picked by argmin of the MCTS scores, so softmax(-scores) is the
        # policy target (shape (1, 1, n_actions)).
        tdic = torch.tensor(-dic).view(1, 1, -1)
        monte_probs = torch.softmax(tdic, dim=-1).detach()
        tensor_obs = tree_policy.obs2testorobs(obs).unsqueeze(0)
        tmp_buff.append([tensor_obs, monte_probs, rew_sign])

        move = np.argmin(dic)
        obs, r, done, _ = test_env.step(move)

    # Value target: the final reward, sign-flipped for positions where the
    # non-starting mark was to move.
    for elements in tmp_buff:
        rbuff.add(elements[0], elements[1], elements[2] * torch.tensor([[[r]]]).float())

    if len(rbuff) > bsize:
        for opt_step in range(4):
            tensor_obs, monte_probs, game_finish = rbuff.get(bsize)
            policy_probs = tree_policy(tensor_obs)
            values = value_function(tensor_obs)
            # NOTE(review): ReplayBuffer.get's batch layout is not visible here (items are
            # stored as (1, 1, 9) and (1, 1, 1)). Reshape the targets to the network
            # output shapes so they align row-for-row instead of broadcasting to
            # (bsize, bsize, ...).
            target_probs = monte_probs.reshape(policy_probs.shape)
            target_values = game_finish.reshape(values.shape)
            loss_policy = -(target_probs * torch.log(policy_probs + 1e-8)).mean()
            loss_value = ((values - target_values) ** 2).mean()
            loss = loss_policy + loss_value
            opt.zero_grad()
            loss.backward()
            opt.step()
        writer.add_scalar('loss/loss', loss.item(), game)
        writer.add_scalar('loss/policy', loss_policy.item(), game)
        writer.add_scalar('loss/vfunc', loss_value.item(), game)
    else:
        print(len(rbuff))
    if game % 5 == 0:
        winrate, drawrate, loserate = eval_winrate(tree_policy, best_tree_policy, test_env, n_games=100)
        print("ARENA!!! ", winrate, drawrate, loserate)
        if winrate > loserate:
            best_state = copy.deepcopy((tree_policy, value_function, opt))
            best_tree_policy, best_vfunc, best_opt = best_state
            print("upgrade", winrate, drawrate, loserate)
        else:
            # Restore all three together so parameter sharing and the optimizer's
            # parameter references stay consistent.
            tree_policy, value_function, opt = copy.deepcopy(best_state)

    winrate2, drawrate2, _ = eval_winrate(tree_policy, rollout_policy, test_env, n_games=100)
    writer.add_scalar('winrate/winrate', winrate2, game)
    writer.add_scalar('winrate/drawrate', drawrate2, game)
    print(game, "winrate=", winrate, drawrate, winrate2, drawrate2)

    obs = test_env.reset()


   | | 
  -----
   | | 
  -----
   | | 

9
ARENA!!!  0.46 0.12 0.42000000000000004
upgrade 0.46 0.12 0.42000000000000004
0 winrate= 0.46 0.12 0.52 0.11
16
1 winrate= 0.46 0.12 0.52 0.08
25
2 winrate= 0.46 0.12 0.51 0.08
34
3 winrate= 0.46 0.12 0.53 0.11
43
4 winrate= 0.46 0.12 0.58 0.08
52
ARENA!!!  0.42 0.22 0.3600000000000001
upgrade 0.42 0.22 0.3600000000000001
5 winrate= 0.42 0.22 0.51 0.08
61
6 winrate= 0.42 0.22 0.48 0.06
66
7 winrate= 0.42 0.22 0.51 0.06
71
8 winrate= 0.42 0.22 0.53 0.08
76
9 winrate= 0.42 0.22 0.6 0.09
83
ARENA!!!  0.48 0.11 0.41000000000000003
upgrade 0.48 0.11 0.41000000000000003
10 winrate= 0.48 0.11 0.67 0.09
92
11 winrate= 0.48 0.11 0.53 0.11
99
12 winrate= 0.48 0.11 0.54 0.07
108
13 winrate= 0.48 0.11 0.5 0.1
117
14 winrate= 0.48 0.11 0.48 0.09
122
ARENA!!!  0.5 0.07 0.43
upgrade 0.5 0.07 0.43
15 winrate= 0.5 0.07 0.56 0.09
16 winrate= 0.5 0.07 0.57 0.05
17 winrate= 0.5 0.07 0.49 0.08
18 winrate= 0.5 0.07 0.55 0.03
19 winrate= 0.5 0.07 0.53 0.06
ARENA!!! 

KeyboardInterrupt: 