In [92]:
import torch
from alpha_zero.utils import eval_winrate
from alpha_zero.models import RolloutPolicy, Model, Backbone, ValueFunction, NNTreePolicy, TreePolicy
from alpha_zero.mcts import mcts, Node, get_ns, get_qs, get_best_action
from alpha_zero.augmentations_tictactoe import symetric_add2rbuff
from myrl.buffers import ReplayBuffer
from myrl.utils import ExperimentWriter
from gym_tictactoe.env import TicTacToeEnv, agent_by_mark, next_mark
import copy
from alpha_zero.utils import play_mcts_against_human, play_mcts_against_itself, save, load


env = TicTacToeEnv()
obs = env.reset()
%load_ext autoreload
%autoreload 2

import random
import math
random.random()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


0.5236111213050683

In [99]:
rollout_policy = RolloutPolicy(env)
model = Model(TicTacToeEnv())
backbone = Backbone([10, 32])
value_function = ValueFunction([32, 16, 1], backbone=backbone)
tree_policy = NNTreePolicy([32, 16, 9], backbone=backbone)

In [100]:
eval_winrate(rollout_policy, tree_policy, TicTacToeEnv(), model, n_games=100)

(0.49, 0.06, 0.45)

In [101]:
# best_policy = copy.deepcopy(tree_policy)
wll = ExperimentWriter('tb/alpha_tictacte_zero_with_vfunc_')

In [116]:
test_env = TicTacToeEnv()
obs = test_env.reset()

test_env.render()
# rbuff = ReplayBuffer(nitems=3, max_len=50*9)
bsize = 128
# wll.new()
writer = wll.writer
opt = torch.optim.Adam(list(tree_policy.parameters())+list(value_function.parameters()), lr=1e-2)
# best_tree_policy = copy.deepcopy(tree_policy)
# best_opt = copy.deepcopy(opt)
# best_vfunc = copy.deepcopy(value_function)
# game = 0
n_simulations = 200
temperature = 10
rsize = 0

for game_step in range(0, 10000):
    obs = test_env.reset()
    game_buff = [] 
    done = False
    reward_sign = 1
    while not done:
        nodes = {obs:Node()}
        for i in range(n_simulations):
            mcts(nodes, obs, tree_policy, model, value_function=value_function, cpucb=50, discount_factor=0.99)
        qs = get_qs(nodes, obs, model)
        ns = get_ns(nodes, obs, model)

        # print(ns)
        monte_ns = torch.tensor(ns, dtype=torch.float).unsqueeze(0)/n_simulations
        monte_probs = torch.softmax(monte_ns*temperature, dim=-1).detach()
        # print(monte_probs)
        tensor_obs = tree_policy.obs2testorobs(obs).unsqueeze(0)
        tensor_rew_sign = torch.tensor([[reward_sign]])
        game_buff.append((tensor_obs, monte_probs, tensor_rew_sign))
        # print(monte_probs)
        
        move = get_best_action(nodes, obs, model)
        obs, r, done, _ = test_env.step(move)
        reward_sign *= -1
    for tensor_obs, monte_probs, tensor_rew_sign in game_buff:
        rbuff.add(tensor_obs, monte_probs, tensor_rew_sign*r)
    # symetric_add2rbuff(rbuff, game_buff, rew_sign, r)

    game += 1
    if len(rbuff) <= bsize:
        print("len rbuff=", len(rbuff), len(game_buff))
        continue   
    rsize = 0  
        
    for step in range(2*(len(rbuff)//bsize+1)):
        tensor_obs, monte_probs, game_finish = rbuff.get(bsize)
        # print(tensor_obs.shape, monte_probs.shape, game_finish.shape)
        # print(tensor_obs, monte_probs, game_finish)
        for opt_step in range(4):
            policy_probs = tree_policy(tensor_obs)
            loss_policy = -(monte_probs*torch.log(policy_probs+1e-8)).mean()*10
            loss_value  = ((value_function(tensor_obs)-game_finish)**2).mean()
            loss = loss_policy + loss_value
            opt.zero_grad()
            loss.backward()
            opt.step()

    print("loss=", loss.item(), loss_policy.item(), loss_value.item())
    writer.add_scalar('loss/loss', loss.item(), game_step)
    writer.add_scalar('loss/policy', loss_policy.item(), game_step)
    writer.add_scalar('loss/vfunc', loss_value.item(), game_step)
    # writer.add_scalar('loss/game', game, game_step)

    # if game % 7 == 0:
    #     tree_policy.temperature, best_tree_policy.temperature = 1, 1
    #     winrate, drawrate, loserate = eval_winrate(tree_policy, best_tree_policy, test_env, n_games=100)
    #     tree_policy.temperature, best_tree_policy.temperature = 0.1, 0.1
    #     print("ARENA!!! ", winrate, drawrate, loserate)
    #     if winrate > loserate:
    #         best_tree_policy = copy.deepcopy(tree_policy)
    #         best_opt = copy.deepcopy(opt)
    #         best_vfunc = copy.deepcopy(value_function)
    #         print("upgrade", winrate, drawrate, loserate)
    #         winrate, drawrate, _ = eval_winrate(tree_policy, rollout_policy, test_env, n_games=300)
    #         print("winrate against random ", winrate, drawrate)
    #     else:
    #         tree_policy = copy.deepcopy(best_tree_policy)
    #         opt = copy.deepcopy(best_opt)
    #         value_function = copy.deepcopy(best_vfunc)

        
    winrate2, drawrate2, _ = eval_winrate(tree_policy, rollout_policy, test_env, model, n_games=100)
    writer.add_scalar('winrate/winrate', winrate2, game)
    writer.add_scalar('winrate/drawrate', drawrate2, game)
    print(game, "winrate=", winrate2, drawrate2)



7193
278 winrate= 0.56 0.12
loss= 0.0050238193944096565 0.005008228588849306 1.5590823750244454e-05
279 winrate= 0.61 0.08
loss= 0.005001746583729982 0.004994604270905256 7.1421254688175395e-06
280 winrate= 0.57 0.05
loss= 0.004980285651981831 0.004937073215842247 4.321230881032534e-05
281 winrate= 0.63 0.07
loss= 0.004934994969516993 0.00490149250254035 3.350228143972345e-05
282 winrate= 0.58 0.08
loss= 0.004818437620997429 0.0048089646734297276 9.472805686527863e-06
283 winrate= 0.61 0.05
loss= 0.00483403354883194 0.004808552097529173 2.5481491320533678e-05
284 winrate= 0.59 0.05
loss= 0.005052801687270403 0.005049074999988079 3.726638851730968e-06
285 winrate= 0.62 0.08
loss= 0.0049452693201601505 0.004930158611387014 1.5110685126273893e-05
286 winrate= 0.54 0.06
loss= 0.004843760281801224 0.004831216298043728 1.2543925549834967e-05
287 winrate= 0.65 0.03
loss= 0.004918344784528017 0.004868274088948965 5.007081927033141e-05
288 winrate= 0.54 0.06
loss= 0.0048578944988548756 0.004777

KeyboardInterrupt: 

In [104]:
env = TicTacToeEnv()
obs = env.reset()
env.step(0)
env.step(1)
# env.step(4)
env.step(2)
obs, _, _, _ = env.step(6)
obs, _, _, _ = env.step(2)
env.render()


tree_policy.act(obs, model.available_actions(obs)), value_function.get(obs)

  O|X|O
  -----
   | | 
  -----
  X| | 



(5, -0.6840388178825378)

In [67]:
tree_policy(tree_policy.obs2testorobs(obs))

tensor([[2.1998e-05, 2.1533e-03, 4.5727e-06, 1.1061e-02, 1.2088e-01, 2.0388e-04,
         7.7113e-04, 5.1501e-01, 3.4989e-01]], grad_fn=<SoftmaxBackward>)

In [73]:
model.available_actions(obs)

[3, 4, 5, 7, 8]

In [118]:
save(tree_policy, 'TicTacToe_policy', 1, 0)
save(value_function, 'TicTacToe_value', 1, 0)
save(opt, 'TicTacToe_opt', 1, 0)

In [113]:
def play_model_against_human(env, model, policy, value_function, human='first'):
    obs = env.reset()
    env.render()
    done = False
    turn = 0
    human = 0 if human=='first' else 1 
    while not done:
        print("-"*20)
        env.render()
        print("turn number", turn)
        print("value of position = ", value_function.get(obs))
        if human == turn%2:
            besta = int(input("your play? "))
            print("your action was", besta)
            obs, r, done, _ = env.step(besta)
        else:
            besta = policy.act(obs, model.available_actions(obs))
            obs, r, done, _ = env.step(besta)
        turn += 1
    print("-"*20)
    env.render()    
    print("final reward=", r)

In [117]:
play_model_against_human(env, model, tree_policy, value_function, human='nas')

   | | 
  -----
   | | 
  -----
   | | 

--------------------
   | | 
  -----
   | | 
  -----
   | | 

turn number 0
value of position =  1.005598783493042
--------------------
   | |O
  -----
   | | 
  -----
   | | 

turn number 1
value of position =  -1.0068035125732422
your action was 4
--------------------
   | |O
  -----
   |X| 
  -----
   | | 

turn number 2
value of position =  0.8598470687866211
--------------------
   | |O
  -----
   |X|O
  -----
   | | 

turn number 3
value of position =  -0.7394993305206299
your action was 8
--------------------
   | |O
  -----
   |X|O
  -----
   | |X

turn number 4
value of position =  1.0248593091964722
--------------------
   | |O
  -----
  O|X|O
  -----
   | |X

turn number 5
value of position =  0.06275974959135056
your action was 0
--------------------
  X| |O
  -----
  O|X|O
  -----
   | |X

final reward= -1
