In [22]:
from gym_tictactoe.env import TicTacToeEnv, agent_by_mark, next_mark
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from myrl.buffers import ReplayBuffer
from myrl.utils import ExperimentWriter
from myrl.value_functions import 

env = TicTacToeEnv()
obs = env.reset()

In [23]:
env.action_space

Discrete(9)

In [120]:
class RolloutPolicy:
    def __init__(self, env):
        self.env = env
    def act(self, obs, render=False):
        return env.action_space.sample()
    def rollout(self, obs, model, render=False):
        d = False
        rsum = 0
        while not d:
            obs, r, d, _ = model.step(obs, self.act(obs))
            rsum += r
        return rsum

class Model:
    def __init__(self, env):
        self.env = env
    def step(self, obs, action):
        self._set_env(obs)
        return self.env.step(action)
    def available_actions(self, obs):
        self._set_env(obs)
        return self.env.available_actions()
    def _set_env(self, obs):
        self.env.board = list(obs[0])
        self.env.mark  = obs[1] 
        self.done = False
    def get_num_actions(self):
        return self.env.action_space.n

class TreePolicy:
    def __init__(self):
        pass
    def act(self, obs, available_actions):
        import random
        return random.choice(available_actions)
    def get_action_probs(self, obs, available_actions):
        return [1/len(available_actions) for i in range(len(available_actions))]

class Backbone(nn.Module):
    def __init__(self, net_arch, middle_activation=F.relu, last_activation=F.relu):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(a, b) for a, b in zip(net_arch[:-1], net_arch[1:])])
    def forward(self, h):
        h = h.view(h.shape[0], -1)
        for lay in self.layers[:-1]:
            h = middle_activation(lay(h))
        h = self.layers[-1](h)
        h = last_activation(h)
        return h

class ValueFunction(nn.Module):
    def __init__(self, net_arch, backbone):
        super().__init__()
        self.backbone = backbone
        self.layers = nn.ModuleList([nn.Linear(a, b) for a, b in zip(net_arch[:-1], net_arch[1:])])
    def forward(self, h):
        h = torch.tensor(x, dtype=torch.float)
        h = h.view(h.shape[0], -1)
        h = self.backbone(h)
        for lay in self.layers[:-1]:
            h = middle_activation(lay(h))
        h = self.layers[-1](h)
        return h

class NNTreePolicy(nn.Module):
    def __init__(self, net_arch, backbone):
        super().__init__()
        self.backbone = backbone
        self.layers = nn.ModuleList([nn.Linear(a, b) for a, b in zip(net_arch[:-1], net_arch[1:])])
    def forward(self, x):
        h = torch.tensor(x, dtype=torch.float)
        h = h.view(h.shape[0], -1)
        h = self.backbone(h)
        for lay in self.layers[:-1]:
            # if print:
            #     print(h.shape, x.shape)
            h = F.relu(lay(h))
        h = self.layers[-1](h)
        h = torch.softmax(h, dim=1)
        return h
    def act(self, obs, available_actions):
        obs = self.obs2testorobs(obs)
        h = self.forward(obs)
        indices = h.sort(descending=True).indices[0]
        # print(indices)
        for i in indices:
            if int(i.item()) in available_actions:
                return int(i.item())        
    def get_action_probs(self, obs, available_actions):
        obs = self.obs2testorobs(obs)
        h = self.forward(obs)
        return h.tolist()[0]
    def obs2testorobs(self, obs):
        l2 = [1] if obs[1]=='O' else [-1]
        obs = torch.tensor([list(obs[0])+l2])
        obs[obs==2] = -1
        return obs




rollout_policy = RolloutPolicy(env)
model = Model(TicTacToeEnv())
backbone = Backbone([10, 16])
value_function = ValueFunction([16, 4, 1], backbone=backbone)
tree_policy = NNTreePolicy([16, 9, 9], backbone=backbone)

# obs = env.reset()
# rollout_policy.rollout(obs, model)

In [121]:
test_env.mark

'X'

In [122]:
class Node:
    def __init__(self, obs, reward, change_child_rew_sign=True, reward_sign=1, done=False, parent=None):
        self.n = 0
        self.cumulative_reward = 0#reward
        self.parent = parent
        self.action2child = {}
        self.nchildren = 0
        self.taken_actions = []
        self.obs = obs
        self.done = done
        self.reward = reward
        self.reward_sign = reward_sign
        self.change_child_rew_sign = change_child_rew_sign
    def get_q(self):
        return self.cumulative_reward/(self.n)

    def backpropagate(self, r, gamma=1):
        self.n += 1
        self.cumulative_reward += r*gamma*self.reward_sign
        if not( -1 <= self.get_q() <= 1 ) and 0:
            print(self.__dict__)
        if self.parent is None:
            return 
        self.parent.backpropagate(r, gamma=gamma)
    def print_parents(self):
        if self.parent is None:
            return
        self.parent.print_parents()

    def create_child(self, ChildType, policy, model):
        non_taken_actions = self._list_cut(model.available_actions(self.obs), self.taken_actions)
        action = policy.act(self.obs, non_taken_actions)
        self.taken_actions.append(action)
        obs, reward, done, info = model.step(self.obs, action)
        reward_sign= -self.reward_sign if self.change_child_rew_sign else self.reward_sign
        child = ChildType(obs, reward, done=done, reward_sign=reward_sign, parent=self, change_child_rew_sign=self.change_child_rew_sign)
        self.action2child[action] = child
        self.nchildren += 1
        return child, action

    def _list_cut(self, l1, l2):
        toret = []
        for a1 in l1:
            if a1 not in l2:
                toret.append(a1)
        return toret

    def rollout(self, rollout_policy, model, render=False):
        if self.done:
            return self.reward
        return rollout_policy.rollout(self.obs, model, render=render)

In [123]:
def UCB(root_node, policy, model, alpha=1):
    scores = [0 for i in range(model.get_num_actions())]
    all_actions = list(range(model.get_num_actions()))
    probs = policy.get_action_probs(root_node.obs, all_actions)
    maxscore, a_maxscore = float('-inf'), -1
    for a in root_node.action2child.keys():
        child = root_node.action2child[a]
        u = probs[a]/(1+child.n)
        q = child.get_q()
        score = q + alpha*u
        scores.append(score)
        if score > maxscore:
            maxscore = score
            a_maxscore = a
    return scores, (a_maxscore, maxscore)

def MCTS(root_node, max_depth, n_times, policy, model, alpha_UCB=1):
    current_node  = root_node
    current_depth = 0
    n_times_done  = 0

    while n_times_done != n_times:
        if current_depth == max_depth or current_node.done:
            # current_node.print_parents()
            reward = current_node.rollout(rollout_policy, model)
            current_node.backpropagate(reward, gamma=0.99)
            current_node = root_node
            n_times_done += 1
            current_depth = 0
            model.env.done = False
        elif current_node.nchildren < len(model.available_actions(current_node.obs)):
            child, action = current_node.create_child(Node, policy, model) 
            current_node = child
            current_depth += 1
        else:
            scores, (a_maxscore, maxscore) = UCB(current_node, policy, model, alpha_UCB)
            current_node = current_node.action2child[a_maxscore]
            current_depth += 1

    visits = []
    for a in range(model.get_num_actions()):
        if a in root_node.action2child:
            visits.append(root_node.action2child[a].get_q())
        else:
            visits.append(float('inf'))
    return visits

In [133]:
def eval_winrate(totest, bench, env, n_games=100):
    wins = 0
    draws = 0
    for igame in range(n_games):
        done, reward = False, 0
        obs = env.reset()
        curr_policy = totest if igame<=n_games//2 else bench
        rew2count = 1 if igame<=n_games//2 else -1
        while not done:
            action = curr_policy.act(obs, env.available_actions())
            obs, r, done, _ = test_env.step(action)
            curr_policy = totest if curr_policy==bench else bench
            wins += 1 if r==rew2count else 0
            # print(r)
        draws += 1 if r==0 else 0
        # print(" ")
    winrate = wins/n_games
    drawrate = draws/n_games
    return winrate, drawrate, (1-winrate-drawrate)

In [134]:
eval_winrate(tree_policy, rollout_policy, test_env, n_games=100)

(0.63, 0.04, 0.33)

In [126]:
import numpy as np
test_env = TicTacToeEnv()
obs = test_env.reset()
obs, r, done, _ = test_env.step(4)
obs, r, done, _ = test_env.step(0)
# obs, r, done, _ = test_env.step(1)
# obs, r, done, _ = test_env.step(7)
# obs, r, done, _ = test_env.step(6)
# obs, r, done, _ = test_env.step(2)
# obs, r, done, _ = test_env.step(3)
# obs, r, done, _ = test_env.step(5)
done = False
test_env.render()
# rbuff = ReplayBuffer(200)

while not done:
    rew_sign = 1 if test_env.mark==test_env.start_mark else -1
    root = Node(obs, 0, reward_sign=rew_sign)
    dic = MCTS(root, 10, 1, tree_policy, model, 100)
    dic = np.array(dic)
    tdic = torch.tensor([-dic])
    print("move probs= ", torch.softmax(tdic, dim=1).view(3, 3))
    move = np.argmin(dic)
    # print(move, test_env.mark)
    # print(dic.reshape(3, 3))
    obs, r, done, _ = test_env.step(move)
    test_env.render()
    # print(" ")
    # print(" ")

  X| | 
  -----
   |O| 
  -----
   | | 

move probs=  tensor([[0., 0., 0.],
        [0., 0., 0.],
        [1., 0., 0.]], dtype=torch.float64)
  X| | 
  -----
   |O| 
  -----
  O| | 

move probs=  tensor([[0., 0., 0.],
        [1., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)
  X| | 
  -----
  X|O| 
  -----
  O| | 

move probs=  tensor([[0., 0., 1.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)
  X| |O
  -----
  X|O| 
  -----
  O| | 



In [136]:
test_env = TicTacToeEnv()
obs = test_env.reset()
done = False
test_env.render()
# rbuff = ReplayBuffer(nitems=2, max_len=400)
bsize = 32
wll = ExperimentWriter('tb/alpha_tictacte_zero_3_')
wll.new()
writer = wll.writer
# opt = torch.optim.Adam(tree_policy.parameters(), lr=1e-3)
import copy
# best_tree_policy = copy.deepcopy(tree_policy)
# best_opt = copy.deepcopy(opt)

for game in range(10000):
    game_step = 0
    done = False
    while not done:
        game_step += 1
        rew_sign = 1 if test_env.mark==test_env.start_mark else -1
        root = Node(obs, 0, reward_sign=rew_sign)
        dic = MCTS(root, 10, 100, tree_policy, model, 100)
        dic = np.array(dic)

        tdic = torch.tensor([[-dic]])
        monte_probs = torch.softmax(tdic, dim=-1).detach()
        tensor_obs = tree_policy.obs2testorobs(obs).unsqueeze(0)
        rbuff.add(tensor_obs, monte_probs)
        
        move = np.argmin(dic)
        obs, r, done, _ = test_env.step(move)

        if len(rbuff) > bsize:
            for opt_step in range(2):
                tensor_obs, monte_probs = rbuff.get(bsize)
                policy_probs = tree_policy(tensor_obs)
                loss = -(monte_probs*torch.log(policy_probs+1e-8)).mean()
                opt.zero_grad()
                loss.backward()
                opt.step()
            writer.add_scalar('loss', loss.item(), game)
    winrate, drawrate, loserate = eval_winrate(tree_policy, best_tree_policy, test_env, n_games=1000)
    if winrate > loserate:
        best_tree_policy = copy.deepcopy(tree_policy)
        best_opt = copy.deepcopy(opt)
        print("upgrade", winrate, drawrate, loserate)
    else:
        tree_policy = copy.deepcopy(best_tree_policy)
        opt = copy.deepcopy(best_opt)
        
    winrate, drawrate, _ = eval_winrate(tree_policy, rollout_policy, test_env, n_games=100)
    writer.add_scalar('winrate/winrate', winrate, game)
    writer.add_scalar('winrate/drawrate', drawrate, game)
    print(game, "winrate=", winrate, drawrate)

    obs = test_env.reset()


   | | 
  -----
   | | 
  -----
   | | 

upgrade 0.501 0.0 0.499
0 winrate= 0.58 0.11
upgrade 0.501 0.0 0.499
1 winrate= 0.6 0.06
upgrade 0.501 0.0 0.499
2 winrate= 0.59 0.06
upgrade 0.501 0.0 0.499
3 winrate= 0.56 0.05
upgrade 0.501 0.0 0.499
4 winrate= 0.59 0.06
upgrade 0.501 0.0 0.499
5 winrate= 0.54 0.09


KeyboardInterrupt: 

In [None]:
rbuff.deqs[1]

In [None]:
print(inspect.getsource(rbuff.add))

In [26]:
tree_policy

NNTreePolicy(
  (layers): ModuleList(
    (0): Linear(in_features=9, out_features=9, bias=True)
    (1): Linear(in_features=9, out_features=9, bias=True)
  )
)