In [14]:
import torch
from alpha_zero.utils import eval_winrate
from alpha_zero.models import RolloutPolicy, Model, Backbone, ValueFunction, NNTreePolicy, TreePolicy
from alpha_zero.mcts import mcts, Node, get_ns, get_qs, get_best_action
from alpha_zero.augmentations_tictactoe import symetric_add2rbuff
from myrl.buffers import ReplayBuffer
from myrl.utils import ExperimentWriter
from gym_tictactoe.env import TicTacToeEnv, agent_by_mark, next_mark
import copy
from alpha_zero.utils import play_mcts_against_human, play_mcts_against_itself, save, load, play_model_against_human


env = TicTacToeEnv()
obs = env.reset()
%load_ext autoreload
%autoreload 2

import random
import math
random.random()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


0.9713480454050764

In [47]:
rollout_policy = RolloutPolicy(env)
model = Model(TicTacToeEnv())
backbone = Backbone([10, 32])
value_function = ValueFunction([32, 16, 1], backbone=backbone)
tree_policy = NNTreePolicy([32, 16, 9], backbone=backbone)

In [48]:
eval_winrate(rollout_policy, tree_policy, TicTacToeEnv(), model, n_games=100)

(0.46, 0.1, 0.44000000000000006)

In [49]:
# best_policy = copy.deepcopy(tree_policy)
wll = ExperimentWriter('tb/alpha_tictacte_egreedy_explore_arena___')

In [50]:
game_step = 0

In [51]:
test_env = TicTacToeEnv()
obs = test_env.reset()

rbuff = ReplayBuffer(nitems=3, max_len=50*9)
bsize = 64
wll.new()
writer = wll.writer
opt = torch.optim.Adam(list(tree_policy.parameters())+list(value_function.parameters()), lr=1e-3)
best_tree_policy = copy.deepcopy(tree_policy)
best_opt = copy.deepcopy(opt)
best_vfunc = copy.deepcopy(value_function)
game = game_step
n_simulations = 100
temperature = 10
rsize = 0

for game_step in range(395, 10000):
    obs = test_env.reset()
    game_buff = []
    done = False
    while not done:
        nodes = {obs:Node()}
        for i in range(n_simulations):
            mcts(nodes, obs, tree_policy, model, value_function=value_function, cpucb=50, discount_factor=0.99)
        qs = get_qs(nodes, obs, model)
        ns = get_ns(nodes, obs, model)

        monte_ns = torch.tensor(ns, dtype=torch.float).unsqueeze(0)/n_simulations
        monte_probs = torch.softmax(monte_ns*temperature, dim=-1).detach()
        tensor_obs = tree_policy.obs2testorobs(obs).unsqueeze(0)
        game_buff.append((tensor_obs, monte_probs))

        if random.random() > max(0, min(1, 10/(game+1e-8))):
            move = random.choice(model.available_actions(obs))
        else:
            move = get_best_action(nodes, obs, model)
        obs, r, done, _ = test_env.step(move)

    tensor_rew_sign = torch.ones(1, 1)*r
    for tensor_obs, monte_probs in game_buff:
        rbuff.add(tensor_obs, monte_probs, tensor_rew_sign)
        tensor_rew_sign *= -1
    # symetric_add2rbuff(rbuff, game_buff, rew_sign, r)

    game += 1
    if len(rbuff) <= bsize:
        print("len rbuff=", len(rbuff), len(game_buff))
        continue
    rsize = 0

    for step in range(2*(len(rbuff)//bsize+1)):
        tensor_obs, monte_probs, game_finish = rbuff.get(bsize)
        # print(tensor_obs.shape, monte_probs.shape, game_finish.shape)
        # print(tensor_obs, monte_probs, game_finish)
        for opt_step in range(4):
            policy_probs = tree_policy(tensor_obs)
            loss_policy = -(monte_probs*torch.log(policy_probs+1e-8)).mean()
            loss_value  = ((value_function(tensor_obs)-game_finish)**2).mean()*10
            loss = loss_policy + loss_value
            opt.zero_grad()
            loss.backward()
            opt.step()

    print("loss=", loss.item(), loss_policy.item(), loss_value.item())
    writer.add_scalar('loss/loss', loss.item(), game)
    writer.add_scalar('loss/policy', loss_policy.item(), game)
    writer.add_scalar('loss/vfunc', loss_value.item(), game)
    # writer.add_scalar('loss/game', game, game_step)

    if game % 7 == 0:
        tree_policy.temperature, best_tree_policy.temperature = 0.1, 0.1
        winrate, drawrate, loserate = eval_winrate(tree_policy, best_tree_policy, test_env, model, n_games=100)
        tree_policy.temperature, best_tree_policy.temperature = 1, 1
        print("ARENA!!! ", winrate, drawrate, loserate)
        if winrate > loserate:
            best_tree_policy = copy.deepcopy(tree_policy)
            best_opt = copy.deepcopy(opt)
            best_vfunc = copy.deepcopy(value_function)
            print("upgrade", winrate, drawrate, loserate)
            winrate, drawrate, _ = eval_winrate(tree_policy, rollout_policy, test_env, model, n_games=300)
            print("winrate against random ", winrate, drawrate)
        else:
            tree_policy = copy.deepcopy(best_tree_policy)
            opt = copy.deepcopy(best_opt)
            value_function = copy.deepcopy(best_vfunc)


    winrate2, drawrate2, _ = eval_winrate(tree_policy, rollout_policy, test_env, model, n_games=100)
    writer.add_scalar('winrate/winrate', winrate2, game)
    writer.add_scalar('winrate/drawrate', drawrate2, game)
    print(game, "winrate=", winrate2, drawrate2)



018
379 winrate= 0.68 0.02
loss= 3.6435694694519043 0.09263595193624496 3.550933599472046
380 winrate= 0.73 0.03
loss= 2.7076313495635986 0.09353616833686829 2.6140952110290527
381 winrate= 0.76 0.03
loss= 2.9721243381500244 0.08649469912052155 2.885629653930664
382 winrate= 0.62 0.09
loss= 2.733546018600464 0.08592704683542252 2.6476190090179443
383 winrate= 0.79 0.02
loss= 1.4606837034225464 0.08749023079872131 1.3731935024261475
384 winrate= 0.71 0.05
loss= 1.930840253829956 0.10789040476083755 1.8229498863220215
ARENA!!!  0.51 0.0 0.49
upgrade 0.51 0.0 0.49
winrate against random  0.7433333333333333 0.03333333333333333
385 winrate= 0.69 0.05
loss= 3.005021572113037 0.1100187599658966 2.895002841949463
386 winrate= 0.7 0.06
loss= 2.9385101795196533 0.09222414344549179 2.8462860584259033
387 winrate= 0.68 0.04
loss= 2.241718053817749 0.07567187398672104 2.166046142578125
388 winrate= 0.65 0.03
loss= 3.12062668800354 0.09821098297834396 3.0224156379699707
389 winrate= 0.64 0.08
loss= 

KeyboardInterrupt: 

In [53]:
play_model_against_human(env, model, tree_policy, value_function, human='firt')

   | | 
  -----
   | | 
  -----
   | | 

--------------------
   | | 
  -----
   | | 
  -----
   | | 

turn number 0
value of position =  -0.874956488609314
--------------------
   | | 
  -----
   | | 
  -----
  O| | 

turn number 1
value of position =  -0.827090859413147
your action was 0
--------------------
  X| | 
  -----
   | | 
  -----
  O| | 

turn number 2
value of position =  -0.9447544813156128
--------------------
  X| | 
  -----
   |O| 
  -----
  O| | 

turn number 3
value of position =  -1.0376200675964355
your action was 2
--------------------
  X| |X
  -----
   |O| 
  -----
  O| | 

turn number 4
value of position =  -0.8899586200714111
--------------------
  X|O|X
  -----
   |O| 
  -----
  O| | 

turn number 5
value of position =  -0.7020097970962524
your action was 3
--------------------
  X|O|X
  -----
  X|O| 
  -----
  O| | 

turn number 6
value of position =  -1.099194049835205
--------------------
  X|O|X
  -----
  X|O| 
  -----
  O|O| 

final reward= 1


In [9]:
tree_policy.temperature = 1e-8