In [1]:
import numpy as np
import scipy
import matplotlib.pyplot as plt
import pandas as pd
from pettingzoo.classic import go_v5
from SMPyBandits import Policies
from torch.utils.tensorboard import SummaryWriter
import warnings
import torch
from torch import nn
import torch.nn.functional as F
import random
import itertools

If you want the speed up brought by numba.jit, try to manually install numba and check that it works (installing llvmlite can be tricky, cf. https://github.com/numba/numba#custom-python-environments
Info: Using the Jupyter notebook version of the tqdm() decorator, tqdm_notebook() ...
If you want the speed up brought by numba.jit, try to manually install numba and check that it works (installing llvmlite can be tricky, cf. https://github.com/numba/numba#custom-python-environments


In [2]:
def fixedMOSSEcomputeAllIndex(self):
    """Recompute the index of every arm at once and write it into ``self.index``.

    The index of an arm is its empirical mean (``rewards / pulls``) plus an
    exploration bonus ``sqrt(0.5 * max(0, log(ratio)) / pulls)``, where

    * ``ratio = t / P`` when ``P``, the total number of pulls over arms pulled
      fewer than ``sqrt(t)`` times, is positive;
    * ``ratio = t / (nbArms * pulls)`` (per arm) otherwise.

    Arms with fewer than one pull are not forced to ``+inf`` here.
    """
    empirical_means = self.rewards / self.pulls
    rarely_pulled = self.pulls < np.sqrt(self.t)
    pulls_of_suboptimal_arms = np.sum(self.pulls[rarely_pulled])
    if pulls_of_suboptimal_arms > 0:
        ratio = self.t / pulls_of_suboptimal_arms
    else:
        ratio = self.t / (self.nbArms * self.pulls)
    exploration_bonus = np.sqrt(0.5 * np.maximum(0, np.log(ratio)) / self.pulls)
    self.index[:] = empirical_means + exploration_bonus
# Monkey-patch: every MOSSExperimental policy now computes its indexes with the
# function defined above instead of the library's own computeAllIndex.
Policies.MOSSExperimental.computeAllIndex = fixedMOSSEcomputeAllIndex

In [3]:
class GoNode():
    """One position in the search tree, with a bandit policy over its legal moves.

    Parameters
    ----------
    num_arms : int
        Number of arms handed to ``policy_algorithm`` (one per legal move).
    legal_actions : array-like
        Environment action ids; arm ``k`` of the policy stands for
        ``legal_actions[k]``.
    policy_algorithm : callable
        Called as ``policy_algorithm(num_arms)`` to build the policy object.
        Its ``pulls``, ``rewards`` and ``t`` attributes are overwritten here.
    player : str
        Name of the player to move at this node.
    state : optional
        Observation stored with the node (not interpreted by this class).
    fictitious_alphas, fictitious_betas : array-like, optional
        Per-arm prior pseudo-counts. Either both or neither must be given.
        When given, the prior pulls are ``alphas + betas`` and the prior
        rewards are ``alphas``. When omitted, every arm starts with one pull
        and a reward of 0.5.

    Raises
    ------
    ValueError
        If exactly one of ``fictitious_alphas`` / ``fictitious_betas`` is given.
    """
    def __init__(self, num_arms, legal_actions, policy_algorithm, player='black_0', state=None, fictitious_alphas = None, fictitious_betas = None):
        self.num_arms = num_arms
        self.legal_actions = legal_actions
        self.player = player
        self.state = state
        self.policy_algorithm = policy_algorithm
        self.policy = self.policy_algorithm(num_arms)
        # Identity checks: `array == None` is element-wise for NumPy arrays and
        # cannot be used as a boolean.
        if fictitious_alphas is None and fictitious_betas is None:
            self.fictitious_pulls = np.ones(num_arms)
            self.fictitious_rewards = np.ones(num_arms) / 2
        elif fictitious_alphas is None or fictitious_betas is None:
            raise ValueError("fictitious_alphas and fictitious_betas must be given together")
        else:
            prior_alphas = np.asarray(fictitious_alphas, dtype=float)
            prior_betas = np.asarray(fictitious_betas, dtype=float)
            self.fictitious_pulls = prior_alphas + prior_betas
            self.fictitious_rewards = prior_alphas
        # The policy gets its own copies so later updates leave the priors intact.
        self.policy.pulls = self.fictitious_pulls.copy()
        self.policy.rewards = self.fictitious_rewards.copy()
        self.policy.t = self.policy.pulls.sum()
        # Children, keyed by the arm index (not the environment action id).
        self.next_nodes = {}

In [4]:
def pull_data(node):
    """Collect states and per-action statistics for ``node`` and all its descendants.

    The tree is walked in pre-order (a node first, then each child subtree in
    the insertion order of ``next_nodes``). For every visited node three
    entries are appended:

    * ``states``: the node's ``state``;
    * ``alphas``: a dense vector holding ``policy.rewards`` at the node's
      ``legal_actions`` and 0 elsewhere;
    * ``betas``: the same for ``policy.pulls - policy.rewards``.

    The dense vectors have ``node.state[0].size + 1`` entries, computed once
    from the node passed in (no dependency on any global).

    Returns
    -------
    (states, alphas, betas) : tuple of three equally long lists
    """
    num_arms = np.size(node.state[0]) + 1
    states, alphas, betas = [], [], []
    # Explicit stack instead of recursion: same visiting order, no recursion
    # limit, and no repeated list concatenation.
    pending = [node]
    while pending:
        current = pending.pop()
        reward_vector = np.zeros(num_arms)
        reward_vector[current.legal_actions] = current.policy.rewards
        pull_vector = np.zeros(num_arms)
        pull_vector[current.legal_actions] = current.policy.pulls
        states.append(current.state)
        alphas.append(reward_vector)
        betas.append(pull_vector - reward_vector)
        # Push children reversed so the first-inserted child is popped first.
        pending.extend(reversed(list(current.next_nodes.values())))
    return states, alphas, betas

In [15]:
class Net(nn.Module):
    """Stack of 3x3 'same'-padded convolutions producing per-action alpha/beta values.

    ``size`` (1-5) selects one of the channel layouts in ``_CHANNELS``; every
    layout takes 17 input planes and ends with a 2-channel convolution
    (channel 0 -> alpha, channel 1 -> beta).

    Raises
    ------
    ValueError
        If ``size`` is not one of the supported layouts.
    """

    # Channel counts from input to output; consecutive pairs define one Conv2d.
    _CHANNELS = {
        1: (17, 32, 32, 2),
        2: (17, 32, 32, 32, 2),
        3: (17, 32, 32, 32, 32, 32, 2),
        4: (17, 64, 32, 16, 2),
        5: (17, 128, 64, 32, 16, 2),
    }

    def __init__(self, size):
        super(Net, self).__init__()
        if size not in self._CHANNELS:
            raise ValueError(f"size must be one of {sorted(self._CHANNELS)}, got {size!r}")
        channels = self._CHANNELS[size]
        self.layers = nn.ModuleList(
            [nn.Conv2d(c_in, c_out, 3, 1, padding='same')
             for c_in, c_out in zip(channels[:-1], channels[1:])])

    def forward(self, x):
        """Return ``(alphas, betas)``, each of shape ``(batch, H*W + 1)``.

        ``x`` is indexed as ``(batch, channels, H, W)``. All layers but the
        last are followed by ReLU. The last layer's two output maps are
        flattened to ``H*W`` board entries; one extra "pass" entry per channel
        is appended, taken as the spatial mean of the first two channels of
        the last hidden activation (not of the final layer's output). Both
        outputs go through ``exp`` and are therefore strictly positive.
        """
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        out = torch.flatten(self.layers[-1](x), -2, -1)
        # NOTE(review): the pass entry comes from hidden channels 0 and 1 —
        # confirm this is intended rather than the final layer's output.
        passes = x[:, :2, :, :].mean(dim=(-1, -2))
        out = torch.cat((out, passes.unsqueeze(-1)), dim=-1)
        alphas = torch.exp(out[:, 0, :])
        betas = torch.exp(out[:, 1, :])
        return alphas, betas
    
def augment(inpt, target_alphas, target_betas):
    """Apply one random board symmetry to a batch of inputs and their targets.

    A number of quarter turns (0-3) and an optional flip of the last axis are
    drawn from the ``random`` module, in that order, and applied to the last
    two axes of ``inpt``. Each target tensor is treated as ``n*n`` board
    entries followed by one trailing entry: the board part is reshaped to
    ``(-1, n, n)``, transformed the same way and flattened back, while the
    trailing entry is left untouched.

    Returns ``(inpt, alpha_targets, beta_targets)``; the targets are contiguous.
    """
    quarter_turns = random.randrange(4)
    mirror = random.randrange(2)

    def apply_symmetry(planes):
        # Rotate first, then (optionally) mirror along the last axis.
        turned = torch.rot90(planes, quarter_turns, [-2, -1])
        if mirror:
            turned = torch.flip(turned, [-1])
        return turned

    def transform_targets(targets):
        board, pass_entry = torch.split(targets, [targets.shape[-1] - 1, 1], dim=-1)
        cells = board.shape[-1]
        side = int(np.sqrt(cells))
        grid = apply_symmetry(board.view(-1, side, side))
        return torch.cat((grid.reshape(-1, cells), pass_entry), -1).contiguous()

    return apply_symmetry(inpt), transform_targets(target_alphas), transform_targets(target_betas)

def loss1(target_alpha_batches, p_hat, s):
    """Mean of ``target_alpha - target_alpha * p_hat + p_hat * (1 - p_hat) / (s + 1)``.

    All arguments are combined element-wise; the result is a scalar tensor.
    """
    unexplained = target_alpha_batches - target_alpha_batches*p_hat
    spread = p_hat*(1-p_hat)/(s+1)
    return (unexplained + spread).mean()

def loss2(p_hat, target_probs):
    """Mean squared difference between ``p_hat`` and ``target_probs``."""
    residual = p_hat - target_probs
    return residual.pow(2).mean()

def loss3(target_beta_batches, target_alpha_batches, p_hat):
    """Mean of ``target_beta * p_hat**2 + target_alpha * (1 - p_hat)**2``."""
    beta_term = target_beta_batches*p_hat**2
    alpha_term = target_alpha_batches*(1-p_hat)**2
    return (beta_term + alpha_term).mean()

def loss4(alphas, betas, target_alphas, target_betas):
    """Mean of a digamma-based discrepancy between predictions and targets.

    Element-wise, with ``psi = torch.digamma``:
    ``(target_alphas + target_betas) * psi(alphas + betas)
    - target_alphas * psi(alphas) - target_betas * psi(betas)``.
    """
    total_term = (target_alphas + target_betas)*torch.digamma(alphas+betas)
    alpha_term = target_alphas*torch.digamma(alphas)
    beta_term = target_betas*torch.digamma(betas)
    return (total_term - alpha_term - beta_term).mean()

def tuning(network, loss_func, states=None, target_alphas=None, target_betas=None):
    """Train ``network`` for 10 epochs with ``loss_func``; return per-batch validation losses.

    Parameters
    ----------
    network : nn.Module
        Called as ``network(batch)`` and expected to return ``(alphas, betas)``.
    loss_func : callable
        One of ``loss1``, ``loss2``, ``loss3`` or ``loss4``; selects the
        training loss.
    states, target_alphas, target_betas : torch.Tensor, optional
        Inputs and targets, aligned along dimension 0. When omitted, the
        notebook-level globals of the same names are used (the behaviour of
        the original two-argument call).

    Returns
    -------
    list of float
        ``loss1`` evaluated on every validation batch. The validation metric
        is ``loss1`` whatever ``loss_func`` was used for training.

    Raises
    ------
    ValueError
        If ``loss_func`` is not one of the four supported losses.
    """
    # Backward compatibility: fall back to the notebook globals.
    if states is None:
        states = globals()['states']
    if target_alphas is None:
        target_alphas = globals()['target_alphas']
    if target_betas is None:
        target_betas = globals()['target_betas']

    if loss_func not in (loss1, loss2, loss3, loss4):
        raise ValueError("loss_func must be one of loss1, loss2, loss3, loss4")

    optimizer = torch.optim.Adam(network.parameters(), lr=1e-5)
    batch_size = 64

    def split_70_15_15(tensor):
        # The last part takes the remainder so the sizes always sum to
        # len(tensor), which torch.split requires.
        total = len(tensor)
        n_train = int(total * 0.7)
        n_val = int(total * 0.15)
        return torch.split(tensor, [n_train, n_val, total - n_train - n_val])

    train_target_alphas, val_target_alphas, _ = split_70_15_15(target_alphas)
    train_target_betas, _, _ = split_70_15_15(target_betas)
    train_states, val_states, _ = split_70_15_15(states)

    train_target_alpha_batches = torch.split(train_target_alphas, batch_size)
    train_target_beta_batches = torch.split(train_target_betas, batch_size)
    train_batches = torch.split(train_states, batch_size)

    for epoch in range(10):
        print(f'epoch:{epoch}')
        for i in range(len(train_batches)):
            optimizer.zero_grad()
            aug_train_inpt, aug_train_target_alpha, aug_train_target_beta = augment(
                train_batches[i], train_target_alpha_batches[i], train_target_beta_batches[i])
            alphas, betas = network(aug_train_inpt)

            s = alphas + betas
            p_hat = alphas / s

            if loss_func == loss1:
                train_loss = loss1(aug_train_target_alpha, p_hat, s)
            elif loss_func == loss2:
                pulls = aug_train_target_alpha + aug_train_target_beta
                target_probs = aug_train_target_alpha / pulls
                # 0/0 (never-pulled actions) yields NaN; treat those targets as 0.
                target_probs[target_probs != target_probs] = 0
                train_loss = loss2(p_hat, target_probs)
            elif loss_func == loss3:
                train_loss = loss3(aug_train_target_beta, aug_train_target_alpha, p_hat)
            else:
                # loss4 takes (alphas, betas, target_alphas, target_betas).
                train_loss = loss4(alphas, betas, aug_train_target_alpha, aug_train_target_beta)
            train_loss.backward()
            optimizer.step()
            if i % 500 == 0:
                print(train_loss.item())

    val_target_alpha_batches = torch.split(val_target_alphas, batch_size)
    val_batches = torch.split(val_states, batch_size)
    val_batch_loss = []
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for val_inpt, val_alpha in zip(val_batches, val_target_alpha_batches):
            alphas, betas = network(val_inpt)
            s = alphas + betas
            p_hat = alphas / s
            val_batch_loss.append(loss1(val_alpha, p_hat, s).item())

    return val_batch_loss

In [6]:
def train(network, states, target_alphas, target_betas):
    """Fit ``network`` to the targets with ``loss4`` for 10 epochs.

    The first 85% of the samples (in the order given, no shuffling) are used
    for training in batches of 1024, each passed through ``augment``; the
    remaining samples form a held-out set whose ``loss4`` value is printed
    after every epoch. The network is updated in place; nothing is returned.

    Parameters
    ----------
    network : nn.Module
        Called as ``network(batch)`` and expected to return ``(alphas, betas)``.
    states, target_alphas, target_betas : torch.Tensor
        Inputs and targets, aligned along dimension 0.
    """
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-4)
    batch_size = 1024

    def split_train_test(tensor):
        n_train = int(len(tensor) * 0.85)
        return torch.split(tensor, [n_train, len(tensor) - n_train])

    train_target_alphas, test_target_alphas = split_train_test(target_alphas)
    train_target_betas, test_target_betas = split_train_test(target_betas)
    train_states, test_states = split_train_test(states)

    train_target_alpha_batches = torch.split(train_target_alphas, batch_size)
    train_target_beta_batches = torch.split(train_target_betas, batch_size)
    train_batches = torch.split(train_states, batch_size)

    for epoch in range(10):
        print(f'epoch:{epoch}')
        for i in range(len(train_batches)):
            optimizer.zero_grad()
            aug_train_inpt, aug_train_target_alpha, aug_train_target_beta = augment(
                train_batches[i], train_target_alpha_batches[i], train_target_beta_batches[i])
            alphas, betas = network(aug_train_inpt)

            train_loss = loss4(alphas, betas, aug_train_target_alpha, aug_train_target_beta)
            train_loss.backward()
            optimizer.step()
            if i % 40 == 0:
                print(train_loss.item())
        # Held-out evaluation: the whole set goes through the network at once,
        # so skip gradient tracking to avoid retaining an autograd graph for it.
        with torch.no_grad():
            test_alphas, test_betas = network(test_states)
            val_loss = loss4(test_alphas, test_betas, test_target_alphas, test_target_betas).item()
        print(f'val loss:{val_loss}')

In [16]:
# Self-play driver: grow a search tree over 5x5 Go with bandit policies at each
# node, then periodically train a network on the statistics stored in the tree
# and use its outputs as priors for newly created nodes.
writer = SummaryWriter()
env = go_v5.env(board_size = 5, komi = 3.5)
policy_algorithm=Policies.MOSSExperimental
env.reset(seed=42)
last = env.last()
root_node = GoNode(num_arms = last[0]['action_mask'].sum(), legal_actions = last[0]['action_mask'].nonzero()[0], policy_algorithm=policy_algorithm, state = last[0]['observation'].transpose((2,1,0)))
winner_list = []
big_number = 1e10  # sentinel stored in rewards/pulls to mark solved arms
win_condition = 100
i = 0  # number of finished games; used as the TensorBoard step
using_network=False
states_per_training_batch = 200000
num_unique_states_visited = 1
root_solved = False
# data collection and training loop; stops after the training round that
# follows the root being solved (re-entering a finished game would leave
# node_action_list empty and fail on node_action_list[-1]).
while not root_solved:
    # data collection and tree update loop
    while num_unique_states_visited < states_per_training_batch:
        print(num_unique_states_visited)
        node = root_node
        node_action_list = []
        observation, reward, termination, truncation, info = env.last()
        state = observation['observation'].transpose((2,1,0))
        mask = observation["action_mask"]
        # single game playout loop
        for agent in env.agent_iter():
            if termination or truncation:
                action = None
                # Map the final reward to black's point of view, then to {0, 0.5, 1}.
                winner = reward if agent=='black_0' else -reward
                winner = (winner + 1)/2
                winner_list.append(winner)
                # One point per game: integer step that advances every game.
                writer.add_scalar('Winner', winner, i)
                i += 1
                break
            policy_choice = node.policy.choice()
            action = node.legal_actions[policy_choice]
            node_action_list.append((node,policy_choice))
            env.step(action)
            observation, reward, termination, truncation, info = env.last()
            state = observation['observation'].transpose((2,1,0))
            mask = observation["action_mask"]
            num_arms = np.count_nonzero(mask)
            if policy_choice in node.next_nodes:
                node = node.next_nodes[policy_choice]
            else:
                legal_actions = mask.nonzero()[0]
                if using_network:
                    # The network needs a batched float tensor and returns one
                    # value per action; keep only the legal moves and hand
                    # GoNode float64 NumPy arrays, one entry per arm.
                    with torch.no_grad():
                        net_alphas, net_betas = network(torch.as_tensor(state, dtype=torch.float32).unsqueeze(0))
                    legal_index = torch.as_tensor(legal_actions, dtype=torch.long)
                    fictitious_alphas = net_alphas[0][legal_index].double().numpy()
                    fictitious_betas = net_betas[0][legal_index].double().numpy()
                else:
                    fictitious_alphas, fictitious_betas = None, None
                node.next_nodes[policy_choice] = GoNode(num_arms=num_arms,
                                                        legal_actions = legal_actions,
                                                        policy_algorithm=policy_algorithm if num_arms > 3 else Policies.klUCBPlus,
                                                        player = 'white_0' if node.player=='black_0' else 'black_0',
                                                        state=state,
                                                        fictitious_alphas = fictitious_alphas,
                                                        fictitious_betas = fictitious_betas)
                num_unique_states_visited += 1
                node = node.next_nodes[policy_choice]
        env.close()

        # -1: the last move won for the player who made it; 1: it did not.
        if {'black_0':1,'white_0':0}[node_action_list[-1][0].player] == winner:
            node_solved=-1
        else:
            node_solved=1
        # tree update loop (leaf to root)
        for node, action in reversed(node_action_list):
            if node_solved !=0:
                # Solved arm: reward big_number marks a win, .5 (with pulls at
                # big_number) marks an arm that is not a win.
                node.policy.rewards[action] = big_number if node_solved==-1 else .5
                node.policy.pulls[action]=big_number
                # node.policy.pulls[action] +=1
            else:
                node.policy.getReward(action, (winner if node.player=='black_0' else 1-winner))
            # Status handed to the parent: 1 if this node has a winning arm,
            # -1 if every arm is marked solved and non-winning, otherwise 0.
            node_solved = 1 if any(node.policy.rewards==big_number) else (-1 if (all(node.policy.rewards==.5) and all(node.policy.pulls==big_number)) else 0)

        # if len(winner_list)>win_condition and ((np.array(winner_list[-win_condition:])==1).all() or (np.array(winner_list[-win_condition:])==0).all()):
        #     break
        if any(root_node.policy.rewards==big_number):
            root_solved = True
            break
        env.reset(seed=42)

    states, target_alphas, target_betas = pull_data(root_node)
    if not using_network:
        network = Net(5)
        using_network = True
    # Stack the per-node arrays first: building a tensor from a list of
    # ndarrays is very slow.
    train(network,
          torch.as_tensor(np.array(states), dtype=torch.float32),
          torch.as_tensor(np.array(target_alphas), dtype=torch.float32),
          torch.as_tensor(np.array(target_betas), dtype=torch.float32))
    num_unique_states_visited = 1


1
19
47
72
102
129
148
182
230
258
273
314
365
412
439
459
506
537
566
592
612
625
684
726
760
793
835
863
892
947
973
1003
1029
1063
1089
1142
1162
1193
1222
1255
1291
1345
1364
1401
1432
1466
1524
1567
1596
1647
1694
1717
1737
1788
1824
1875
1952
1999
2027
2055
2085
2108
2134
2203
2217
2253
2273
2303
2344
2370
2400
2448
2484
2530
2633
2658
2766
2799
2829
2871
2876
2914
2971
2998
3048
3078
3099
3118
3142
3178
3208
3249
3278
3321
3369
3402
3432
3464
3496
3556
3592
3623
3675
3708
3754
3785
3821
3874
3899
3965
3989
4023
4069
4084
4124
4170
4188
4292
4322
4361
4408
4448
4467
4497
4535
4599
4630
4640
4663
4689
4738
4781
4818
4871
4909
5026
5048
5084
5129
5155
5194
5237
5257
5291
5338
5386
5427
5469
5487
5521
5538
5569
5646
5661
5690
5713
5740
5799
5831
5856
5891
5928
5969
6009
6019
6045
6086
6122
6135
6162
6213
6258
6288
6325
6352
6388
6419
6442
6474
6522
6554
6591
6638
6699
6793
6833
6884
6899
6923
6960
7044
7088
7111
7150
7186
7233
7282
7294
7344
7394
7456
7477
7511
7563
7609
7632
7674
7