In [2]:
import numpy as np
import scipy
import matplotlib.pyplot as plt
import pandas as pd
from pettingzoo.classic import go_v5
from SMPyBandits import Policies
from torch.utils.tensorboard import SummaryWriter
import warnings
import torch
from torch import nn
import torch.nn.functional as F
import random
import itertools
import sys

If you want the speed up brought by numba.jit, try to manually install numba and check that it works (installing llvmlite can be tricky, cf. https://github.com/numba/numba#custom-python-environments
Info: Using the Jupyter notebook version of the tqdm() decorator, tqdm_notebook() ...
If you want the speed up brought by numba.jit, try to manually install numba and check that it works (installing llvmlite can be tricky, cf. https://github.com/numba/numba#custom-python-environments


In [3]:
# Raise Python's recursion limit well above the interpreter default.
sys.setrecursionlimit(50000)
# Use Apple's Metal (MPS) backend when it is available; otherwise fall back to the
# CPU so that `mps_device` is always bound and `.to(mps_device)` never raises NameError.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
else:
    mps_device = torch.device("cpu")

In [4]:
def fixedMOSSEcomputeAllIndex(self):
    """Compute the indexes of all arms at once and store them in ``self.index``.

    Reads ``self.pulls``, ``self.rewards``, ``self.t`` and ``self.nbArms``.
    The index of an arm is its empirical mean plus an exploration bonus
    ``sqrt(0.5 * max(0, log(t / d)) / pulls)``, where ``d`` is the total number
    of pulls spent on arms pulled fewer than ``sqrt(t)`` times, or
    ``nbArms * pulls`` (per arm) when that total is zero.

    No arm is forced to an infinite index here: the override for arms with
    fewer than one pull is left commented out at the bottom.
    """
    rarely_pulled = self.pulls < np.sqrt(self.t)
    pulls_of_suboptimal_arms = np.sum(self.pulls[rarely_pulled])

    # Denominator of the log term: a scalar if such arms exist, otherwise per-arm.
    if pulls_of_suboptimal_arms > 0:
        log_denominator = pulls_of_suboptimal_arms
    else:
        log_denominator = self.nbArms * self.pulls

    empirical_means = self.rewards / self.pulls
    clipped_log = np.maximum(0, np.log(self.t / log_denominator))
    exploration_bonus = np.sqrt(0.5 * clipped_log / self.pulls)

    # indexes[self.pulls < 1] = float('+inf')
    self.index[:] = empirical_means + exploration_bonus
# Monkey-patch SMPyBandits: every MOSSExperimental policy now uses fixedMOSSEcomputeAllIndex.
Policies.MOSSExperimental.computeAllIndex = fixedMOSSEcomputeAllIndex

In [5]:
class GoNode():
    """A search-tree node: one position together with a bandit policy over its legal moves.

    Attributes:
        num_arms: number of arms handed to the bandit policy.
        legal_actions: environment action ids; arm ``k`` of the policy plays
            ``legal_actions[k]``.
        player: name of the agent to move at this node.
        state: observation stored for this node (used later as a training input).
        policy: ``policy_algorithm(num_arms)``, with ``pulls``, ``rewards`` and
            ``t`` overwritten by the fictitious prior counts.
        fictitious_pulls / fictitious_rewards: prior pseudo-counts per arm.
        next_nodes: children, keyed by arm index.
    """

    def __init__(self, num_arms, legal_actions, policy_algorithm, player='black_0', state=None, fictitious_alphas = None, fictitious_betas = None):
        """Create the node and seed its policy with prior pseudo-counts.

        Args:
            num_arms: number of legal actions at this node.
            legal_actions: integer index array of the legal action ids.
            policy_algorithm: callable building a policy from ``num_arms``.
            player: agent to move.
            state: observation to store on the node.
            fictitious_alphas, fictitious_betas: optional torch tensors with one
                entry per action id (any shape; they are flattened).  Both must
                be given or both omitted.  When omitted, every arm starts with
                one pull and half a reward.

        Raises:
            ValueError: if only one of the two prior tensors is supplied.
        """
        if (fictitious_alphas is None) != (fictitious_betas is None):
            raise ValueError("fictitious_alphas and fictitious_betas must be given together")

        self.num_arms = num_arms
        self.legal_actions = legal_actions
        self.player = player
        self.state = state
        self.policy_algorithm = policy_algorithm
        self.policy = self.policy_algorithm(num_arms)

        if fictitious_alphas is None:
            self.fictitious_pulls = np.ones(num_arms)
            self.fictitious_rewards = np.ones(num_arms) / 2
        else:
            # detach()/cpu() make the conversion safe for tensors that track
            # gradients or live on an accelerator; plain .numpy() would raise.
            alpha_flat = fictitious_alphas.detach().cpu().reshape(-1).numpy()
            beta_flat = fictitious_betas.detach().cpu().reshape(-1).numpy()
            # Fancy indexing copies, so the node owns these arrays.
            self.fictitious_pulls = (alpha_flat + beta_flat)[legal_actions]
            self.fictitious_rewards = alpha_flat[legal_actions]

        # The policy gets its own copies so updating it leaves the priors intact.
        self.policy.pulls = self.fictitious_pulls.copy()
        self.policy.rewards = self.fictitious_rewards.copy()
        self.policy.t = self.policy.pulls.sum()
        self.next_nodes = {}

In [6]:
def pull_data(node):
    """Flatten the tree rooted at ``node`` into parallel training lists.

    Nodes are visited in pre-order (a node first, then each child's subtree in
    the insertion order of ``next_nodes``).  The walk uses an explicit stack
    instead of recursion, so deep trees cannot hit Python's recursion limit.

    For every node the per-arm policy statistics are scattered into vectors of
    length ``len(node.state[0].flatten()) + 1`` at the positions given by
    ``node.legal_actions``; all other positions are zero.

    Returns:
        (states, alphas, betas): three lists of equal length holding, per
        node, its ``state``, its scattered ``policy.rewards`` and its
        scattered ``policy.pulls - policy.rewards``.
    """
    states, alphas, betas = [], [], []
    pending = [node]
    while pending:
        current = pending.pop()
        num_arms = len(current.state[0].flatten()) + 1

        reward_counts = np.zeros(num_arms)
        reward_counts[current.legal_actions] = current.policy.rewards
        pull_counts = np.zeros(num_arms)
        pull_counts[current.legal_actions] = current.policy.pulls

        states.append(current.state)
        alphas.append(reward_counts)
        betas.append(pull_counts - reward_counts)

        # Push children reversed so they pop in insertion order (pre-order walk).
        pending.extend(reversed(list(current.next_nodes.values())))
    return states, alphas, betas

In [7]:
class Net(nn.Module):
    """Fully convolutional network predicting per-action Beta pseudo-counts.

    The input is a batch of 17-channel board planes.  The last convolution
    emits two channels (log-alpha and log-beta for every board point); the
    pass move is scored by spatially averaging the first two channels of the
    last hidden activation.

    NOTE: a second class that is also named ``Net`` is defined later in this
    file and rebinds the name, so this convolutional variant is unreachable
    once the cell has run unless one of the two classes is renamed.
    """

    # Channel widths (input, hidden..., output) of the 3x3 conv stack per `size`.
    _CHANNELS = {
        1: (17, 32, 32, 2),
        2: (17, 32, 32, 32, 2),
        3: (17, 32, 32, 32, 32, 32, 2),
        4: (17, 64, 32, 16, 2),
        5: (17, 128, 64, 32, 16, 2),
    }

    def __init__(self, size):
        """Build the conv stack for architecture preset ``size`` (1 to 5).

        Raises:
            ValueError: if ``size`` is not a supported preset.  Previously such
                a value produced a module without ``layers`` that only failed
                on first use.
        """
        super(Net, self).__init__()
        if size not in self._CHANNELS:
            raise ValueError(f"size must be one of {sorted(self._CHANNELS)}, got {size!r}")
        widths = self._CHANNELS[size]
        self.layers = nn.ModuleList(
            [nn.Conv2d(c_in, c_out, 3, 1, padding='same')
             for c_in, c_out in zip(widths[:-1], widths[1:])])

    def forward(self, x):
        """Return ``(alphas, betas)``, each of shape (batch, H*W + 1); the last entry is the pass move."""
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        # (batch, 2, H, W) -> (batch, 2, H*W)
        out = torch.flatten(self.layers[-1](x), -2, -1)
        # Pass-move logits: spatial mean of the first two hidden channels -> (batch, 2).
        passes = x[:, :2, :, :].mean(dim=(-1, -2))
        out = torch.cat((out, passes.unsqueeze(-1)), dim=-1)
        # exp keeps both pseudo-counts strictly positive.
        alphas = torch.exp(out[:, 0, :])
        betas = torch.exp(out[:, 1, :])
        return alphas, betas

class Net(nn.Module):
    """Multilayer perceptron mapping a flattened position to Beta pseudo-counts.

    The input is reshaped to 425 features per sample (presumably 17 planes of
    a 5x5 board; confirm against the caller).  The 52 outputs are
    exponentiated and split into 26 alphas followed by 26 betas.
    """

    def __init__(self):
        super().__init__()
        widths = (425, 500, 400, 300, 200, 100, 52)
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])])

    def forward(self, x):
        """Return ``(alphas, betas)``, each of shape (batch, 26) and strictly positive."""
        hidden = x.reshape(-1, 425)
        *body, head = self.layers
        # Every layer but the last is followed by a hardtanh non-linearity.
        for layer in body:
            hidden = F.hardtanh(layer(hidden))
        positive = torch.exp(head(hidden))
        return positive[:, :26], positive[:, 26:]
    
def augment(inpt, target_alphas, target_betas):
    """Apply one random board symmetry to a batch and to its targets.

    A number of quarter turns (0 to 3) and an optional mirror along the last
    axis are drawn from Python's ``random`` module (rotation first, then
    flip).  The same transform is applied to the last two axes of ``inpt``
    and to the board part of both target tensors.  The final entry of each
    target row is not part of the square board and is carried over unchanged.

    Returns:
        (inpt, alpha_targets, beta_targets) after the transform; the targets
        are contiguous and keep their (batch, board_points + 1) shape.
    """
    quarter_turns = random.randrange(4)
    mirror = random.randrange(2)

    def apply_symmetry(planes):
        # Rotate in the board plane, then optionally mirror the last axis.
        planes = torch.rot90(planes, quarter_turns, [-2, -1])
        if mirror:
            planes = torch.flip(planes, [-1])
        return planes

    def transform_targets(targets):
        n_board = targets.shape[-1] - 1
        board, off_board = torch.split(targets, [n_board, 1], dim=-1)
        side = int(np.sqrt(board.shape[-1]))
        grid = apply_symmetry(board.view(-1, side, side))
        flat_board = grid.reshape(-1, n_board)
        return torch.cat((flat_board, off_board), -1).contiguous()

    return apply_symmetry(inpt), transform_targets(target_alphas), transform_targets(target_betas)

def loss1(target_alpha_batches, p_hat, s):
    """Mean of ``alpha * (1 - p_hat)`` (written as ``alpha - alpha * p_hat``) plus ``p_hat * (1 - p_hat) / (s + 1)``.

    Args:
        target_alpha_batches: target success pseudo-counts.
        p_hat: predicted success probability.
        s: predicted total pseudo-count.
    """
    miss_term = target_alpha_batches - target_alpha_batches * p_hat
    spread_term = p_hat * (1 - p_hat) / (s + 1)
    return (miss_term + spread_term).mean()

def loss2(p_hat, target_probs):
    """Mean squared error between predicted and target probabilities."""
    residual = p_hat - target_probs
    return (residual ** 2).mean()

def loss3(target_beta_batches, target_alpha_batches, p_hat):
    """Count-weighted squared error: mean of ``beta * p_hat**2 + alpha * (1 - p_hat)**2``."""
    q_hat = 1 - p_hat
    weighted_sq_err = target_beta_batches * p_hat ** 2 + target_alpha_batches * q_hat ** 2
    return weighted_sq_err.mean()

def loss4(alphas, betas, target_alphas, target_betas):
    """Mean of ``(ta + tb) * psi(a + b) - ta * psi(a) - tb * psi(b)`` where ``psi`` is the digamma function.

    This equals the expected negative log-likelihood of ``ta`` successes and
    ``tb`` failures when the success probability follows Beta(alphas, betas).
    """
    pseudo_total = target_alphas + target_betas
    per_entry = (
        pseudo_total * torch.digamma(alphas + betas)
        - target_alphas * torch.digamma(alphas)
        - target_betas * torch.digamma(betas)
    )
    return per_entry.mean()

def loss_mse(alphas, betas, target_alphas, target_betas):
    """Mean squared error between smoothed success ratios.

    Both ratios use ``count / (alpha + beta + 1)``; the extra 1 in the
    denominator keeps the ratio finite when both counts are zero.
    """
    predicted_ratio = alphas / (alphas + betas + 1)
    target_ratio = target_alphas / (target_alphas + target_betas + 1)
    return ((predicted_ratio - target_ratio) ** 2).mean()

def loss_dkl(alphas, betas, target_alphas, target_betas):
    """Mean KL divergence KL(Beta(target_alphas, target_betas) || Beta(alphas, betas)).

    All four arguments must be strictly positive for the result to be finite.
    """
    target_total = target_alphas + target_betas
    divergence = (
        # log B(alphas, betas) - log B(target_alphas, target_betas)
        torch.lgamma(alphas) + torch.lgamma(betas) - torch.lgamma(alphas + betas)
        - torch.lgamma(target_alphas) - torch.lgamma(target_betas) + torch.lgamma(target_total)
        # digamma-weighted differences of the parameters
        + (target_alphas - alphas) * torch.digamma(target_alphas)
        + (target_betas - betas) * torch.digamma(target_betas)
        + (alphas - target_alphas + betas - target_betas) * torch.digamma(target_total)
    )
    return divergence.mean()

def loss_direct_mse(alphas, betas, target_alphas, target_betas):
    """Mean over entries of the summed squared errors of alpha and beta."""
    alpha_sq_err = (alphas - target_alphas) ** 2
    beta_sq_err = (betas - target_betas) ** 2
    return (alpha_sq_err + beta_sq_err).mean()

def reverse_log_likelihood_loss(alpha, beta, target_alpha, target_beta):
    """Negative mean of ``(ta - 1) * E[log p] + (tb - 1) * E[log(1 - p)]`` for ``p ~ Beta(alpha, beta)``.

    The expectations are ``psi(alpha) - psi(alpha + beta)`` and
    ``psi(beta) - psi(alpha + beta)``.  The target's log-normaliser is not
    included.
    """
    digamma_total = torch.digamma(alpha + beta)
    expected_log_p = torch.digamma(alpha) - digamma_total
    expected_log_q = torch.digamma(beta) - digamma_total
    weighted = (target_alpha - 1) * expected_log_p + (target_beta - 1) * expected_log_q
    return -(weighted.mean())

def regularized_mse_loss(alpha, beta, target_alpha, target_beta):
    s_hat = alpha+beta
    p_hat = alpha/s_hat
    loss = (target_alpha+target_beta)*2*p_hat*(1-p_hat)/(s_hat+1) + 2*target_alpha*(1-p_hat)**2 + 2*target_beta*p_hat**2
    euler_mascheroni = 0.57721066
    regularizer = target_alpha*(torch.log(1/beta)+(1-beta)*(-euler_mascheroni)+(beta-1)*(1-euler_mascheroni)) + \
                  target_beta*(torch.log(1/alpha)+(1-alpha)*(-euler_mascheroni)+(alpha-1)*(1-euler_mascheroni))
    annealing_coefficient = 1
    # print(loss/target_alpha+target_beta)
    # print(regularizer)
    return ((loss + annealing_coefficient * regularizer)*((target_alpha+target_beta)>0)).sum()

def log_likelihood_loss(alpha, beta, target_alpha, target_beta):
    """Negative expected log-density of Beta(alpha, beta) under the target Beta, plus a regulariser.

    The first term is ``-E_{p ~ Beta(ta, tb)}[log Beta(p; alpha, beta)]``
    averaged over all entries (a scalar).  One tenth of the element-wise
    regulariser is added and the result is averaged again.

    The regulariser weights ``beta - 1 - log(beta)`` by ``target_alpha`` and
    ``alpha - 1 - log(alpha)`` by ``target_beta``.  The earlier form,
    ``log(1/x) + (1-x)*(-g) + (x-1)*(1-g)``, is algebraically identical for
    any ``g``, so the mistyped Euler-Mascheroni constant it used (0.57721066
    instead of 0.5772156649) had no effect and is dropped here.
    """
    target_digamma_total = torch.digamma(target_alpha + target_beta)
    loss = -(
        (alpha - 1) * (torch.digamma(target_alpha) - target_digamma_total) +
        (beta - 1) * (torch.digamma(target_beta) - target_digamma_total) -
        torch.lgamma(alpha) - torch.lgamma(beta) + torch.lgamma(alpha + beta)).mean()
    regularizer = target_alpha * (beta - 1 - torch.log(beta)) \
        + target_beta * (alpha - 1 - torch.log(alpha))
    annealing_coefficient = .1
    return (loss + regularizer * annealing_coefficient).mean()

def tuning(network, loss_func):
    """Train ``network`` for 10 epochs with ``loss_func`` and return per-batch validation losses.

    The data comes from the notebook-level variables ``states``,
    ``target_alphas`` and ``target_betas``, which must be tensors with the
    same first dimension.  The first 70% of the rows are used for training,
    the next 15% for validation; the remainder is held out and unused here.

    Args:
        network: module returning ``(alphas, betas)`` for a batch of states.
        loss_func: one of ``loss1``, ``loss2``, ``loss3`` or ``loss4``.

    Returns:
        list[float]: ``loss1`` evaluated on every validation batch (whatever
        ``loss_func`` was used for training).

    Raises:
        ValueError: if ``loss_func`` is not one of the four supported losses.
    """
    if not any(loss_func is known for known in (loss1, loss2, loss3, loss4)):
        raise ValueError("loss_func must be one of loss1, loss2, loss3 or loss4")

    optimizer = torch.optim.Adam(network.parameters(), lr=1e-5)
    batch_size = 64

    def split_train_val(data):
        # The third chunk takes the remainder so the sizes always sum to len(data).
        n_total = len(data)
        n_train = int(n_total * 0.7)
        n_val = int(n_total * 0.15)
        train_part, val_part, _held_out = torch.split(
            data, [n_train, n_val, n_total - n_train - n_val])
        return train_part, val_part

    train_target_alphas, val_target_alphas = split_train_val(target_alphas)
    train_target_betas, val_target_betas = split_train_val(target_betas)
    train_states, val_states = split_train_val(states)

    train_target_alpha_batches = torch.split(train_target_alphas, batch_size)
    train_target_beta_batches = torch.split(train_target_betas, batch_size)
    train_batches = torch.split(train_states, batch_size)

    for epoch in range(10):
        print(f'epoch:{epoch}')
        for i in range(len(train_batches)):
            optimizer.zero_grad()
            aug_train_inpt, aug_train_target_alpha, aug_train_target_beta = augment(
                train_batches[i], train_target_alpha_batches[i], train_target_beta_batches[i])
            alphas, betas = network(aug_train_inpt)

            s = alphas + betas
            p_hat = alphas / s

            pulls = aug_train_target_alpha + aug_train_target_beta
            target_probs = aug_train_target_alpha / pulls
            # 0/0 for actions that were never tried: treat the target as 0.
            target_probs[torch.isnan(target_probs)] = 0

            if loss_func is loss1:
                train_loss = loss1(aug_train_target_alpha, p_hat, s)
            elif loss_func is loss2:
                train_loss = loss2(p_hat, target_probs)
            elif loss_func is loss3:
                train_loss = loss3(aug_train_target_beta, aug_train_target_alpha, p_hat)
            else:
                # loss4 takes (alphas, betas, target_alphas, target_betas).
                train_loss = loss4(alphas, betas, aug_train_target_alpha, aug_train_target_beta)
            train_loss.backward()
            optimizer.step()
            if i % 500 == 0:
                print(train_loss.item())

    val_target_alpha_batches = torch.split(val_target_alphas, batch_size)
    val_batches = torch.split(val_states, batch_size)
    val_batch_loss = []
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for val_inpt, val_alpha in zip(val_batches, val_target_alpha_batches):
            alphas, betas = network(val_inpt)
            s = alphas + betas
            p_hat = alphas / s
            val_batch_loss.append(loss1(val_alpha, p_hat, s).item())

    return val_batch_loss

In [8]:
def train(network, states, target_alphas, target_betas, epochs=8, batch_size=20, lr=1e-5):
    """Fit ``network`` to the pseudo-count targets with SGD and ``regularized_mse_loss``.

    The first 85% of the rows are used for training (in order, with a random
    board symmetry applied to every batch); the rest is evaluated after each
    epoch.  Progress is printed: exponential running averages of the training
    loss and of a squared error on the success ratio after every batch, the
    raw batch loss every 20 batches, and the held-out loss after every epoch.

    Args:
        network: module returning ``(alphas, betas)`` for a batch of states.
        states, target_alphas, target_betas: tensors sharing their first dimension.
        epochs: number of passes over the training split.
        batch_size: rows per optimisation step.
        lr: SGD learning rate.
    """
    optimizer = torch.optim.SGD(network.parameters(), lr=lr)

    def split_train_test(data):
        n_train = int(len(data) * 0.85)
        return torch.split(data, [n_train, len(data) - n_train])

    train_target_alphas, test_target_alphas = split_train_test(target_alphas)
    train_target_betas, test_target_betas = split_train_test(target_betas)
    train_states, test_states = split_train_test(states)

    train_target_alpha_batches = torch.split(train_target_alphas, batch_size)
    train_target_beta_batches = torch.split(train_target_betas, batch_size)
    train_batches = torch.split(train_states, batch_size)

    # Plain floats: accumulating loss tensors would keep every batch's autograd
    # history alive for the whole run.
    running_loss = 0.0
    running_p_hat_loss = 0.0

    for epoch in range(epochs):
        print(f'epoch:{epoch}')
        for i in range(len(train_batches)):
            optimizer.zero_grad()
            aug_train_inpt, aug_train_target_alpha, aug_train_target_beta = augment(
                train_batches[i], train_target_alpha_batches[i], train_target_beta_batches[i])
            alphas, betas = network(aug_train_inpt)
            train_loss = regularized_mse_loss(alphas, betas, aug_train_target_alpha, aug_train_target_beta)

            running_loss = running_loss * .999 + train_loss.item() * .001
            print(f'train loss: {running_loss}')
            with torch.no_grad():
                # The .01 keeps the ratios finite when both counts are zero.
                predicted_ratio = alphas / (alphas + betas + .01)
                target_ratio = aug_train_target_alpha / (aug_train_target_alpha + aug_train_target_beta + .01)
                p_hat_error = ((predicted_ratio - target_ratio) ** 2).mean().item()
            running_p_hat_loss = running_p_hat_loss * .999 + p_hat_error * .001
            print(f'p_hat loss: {running_p_hat_loss}')

            train_loss.backward()
            optimizer.step()
            if i % 20 == 0:
                print(train_loss.item())

        # Held-out evaluation does not need gradients.
        with torch.no_grad():
            test_alphas, test_betas = network(test_states)
            val_loss = regularized_mse_loss(test_alphas, test_betas, test_target_alphas, test_target_betas).item()
        print(f'val loss:{val_loss}')

In [9]:
# train(network.to(mps_device), torch.FloatTensor(states).to(mps_device), torch.FloatTensor(target_alphas).to(mps_device), torch.FloatTensor(target_betas).to(mps_device))

In [10]:
writer = SummaryWriter()
env = go_v5.env(board_size = 5, komi = 3.5)
policy_algorithm=Policies.MOSSExperimental
env.reset(seed=42)
last = env.last()
winner_list = []
# Sentinel pull/reward value written to an arm whose outcome has been proven.
big_number = 100
win_condition = 100
# Number of finished playouts; used as the TensorBoard step.
i = 0
using_network=False
states_per_training_batch = 50000
num_unique_states_visited = 1
# data collection and training loop
while True:
    # Every tree starts from a freshly reset game.  Without this, the env is
    # still in a finished game when the previous tree ended via the
    # "root solved" break below.
    env.reset(seed=42)
    last = env.last()
    root_state = last[0]['observation'].transpose((2,1,0))
    # Priors for the root must come from the root position itself, not from
    # whatever `state` was left over from the last playout.
    with torch.no_grad():
        fictitious_alphas, fictitious_betas = network(torch.FloatTensor(root_state).unsqueeze(0)) if using_network else (None, None)
    root_node = GoNode(num_arms = last[0]['action_mask'].sum(),
                       legal_actions = last[0]['action_mask'].nonzero()[0],
                       policy_algorithm=policy_algorithm,
                       state = root_state,
                       fictitious_alphas = fictitious_alphas,
                       fictitious_betas = fictitious_betas)
    # data collection and tree update loop
    while num_unique_states_visited < states_per_training_batch:
        print(num_unique_states_visited)
        node = root_node
        # (node, arm index) pairs along the current playout, root first.
        node_action_list = []
        observation, reward, termination, truncation, info = env.last()
        state = observation['observation'].transpose((2,1,0))
        mask = observation["action_mask"]
        # single game playout loop
        for agent in env.agent_iter():
            if termination or truncation:
                action = None
                # Map the final reward to 1 if black won and 0 if white won.
                winner = reward if agent=='black_0' else -reward
                winner = (winner + 1)/2
                winner_list.append(winner)
                writer.add_scalar('Winner', winner, i)
                i += 1
                break
            # Follow the node's bandit policy, with a 0.1% chance of a uniformly random arm.
            if np.random.random()>.001:
                policy_choice = node.policy.choice()
            else:
                policy_choice = np.random.randint(node.num_arms)
            action = node.legal_actions[policy_choice]

            node_action_list.append((node,policy_choice))
            env.step(action)
            observation, reward, termination, truncation, info = env.last()
            state = observation['observation'].transpose((2,1,0))
            mask = observation["action_mask"]
            num_arms = np.count_nonzero(mask)
            if policy_choice in node.next_nodes:
                node = node.next_nodes[policy_choice]
            else:
                # First visit of this position: create the child, optionally
                # seeded with the network's predicted pseudo-counts.
                with torch.no_grad():
                    fictitious_alphas, fictitious_betas = network(torch.FloatTensor(state).unsqueeze(0)) if using_network else (None, None)
                node.next_nodes[policy_choice] = GoNode(num_arms=num_arms,
                                                        legal_actions = mask.nonzero()[0],
                                                        policy_algorithm=policy_algorithm if num_arms > 3 else Policies.klUCBPlus,
                                                        player = 'white_0' if node.player=='black_0' else 'black_0',
                                                        state=state,
                                                        fictitious_alphas = fictitious_alphas,
                                                        fictitious_betas = fictitious_betas)
                num_unique_states_visited += 1
                node = node.next_nodes[policy_choice]
        env.close()

        # node_solved describes the arm about to be updated:
        #   -1 -> proven win for the node's player (reward = pulls = big_number)
        #    1 -> proven loss (reward .5 over big_number pulls)
        #    0 -> nothing proven, apply an ordinary reward update.
        # The last move of a finished game is always proven one way or the other.
        if {'black_0':1,'white_0':0}[node_action_list[-1][0].player] == winner:
            node_solved=-1
        else:
            node_solved=1
        # tree update loop
        for node, action in reversed(node_action_list):
            if node_solved !=0:
                node.policy.rewards[action] = big_number if node_solved==-1 else .5
                node.policy.pulls[action]=big_number
                node.fictitious_rewards[action] = big_number if node_solved==-1 else .5
                node.fictitious_pulls[action] = big_number

                # node.policy.pulls[action] +=1
            else:
                node.policy.getReward(action, (winner if node.player=='black_0' else 1-winner))
            # Status passed to the parent's arm: this node is won if any arm is a
            # proven win, lost if every arm is a proven loss, otherwise open.
            node_solved = 1 if any(node.policy.rewards==big_number) else (-1 if (all(node.policy.rewards==.5) and all(node.policy.pulls==big_number)) else 0)

        # if len(winner_list)>win_condition and ((np.array(winner_list[-win_condition:])==1).all() or (np.array(winner_list[-win_condition:])==0).all()):
        #     break
        # Stop collecting as soon as the root has a proven winning move.
        if any(root_node.policy.rewards==big_number):
            break
        env.reset(seed=42)

    states, target_alphas, target_betas = pull_data(root_node)
    if not using_network:
        network = Net()
        using_network = True
    train(network, torch.FloatTensor(states), torch.FloatTensor(target_alphas), torch.FloatTensor(target_betas))
    num_unique_states_visited = 1


1
7
26
73
103
156
180
212
300
332
363
398
430
457
484
530
560
583
621
658
725
770
807
834
892
916
940
989
1015
1040
1110
1204
1236
1273
1282
1303
1324
1383
1420
1460
1503
1546
1593
1623
1645
1701
1730
1760
1790
1827
1859
1911
1948
1992
2022
2047
2099
2153
2168
2206
2233
2287
2327
2336
2387
2389
2446
2475
2491
2506
2527
2551
2579
2604
2625
2670
2699
2737
2774
2801
2841
2885
2930
2962
2996
3033
3104
3151
3241
3329
3355
3378
3426
3450
3490
3499
3539
3577
3605
3637
3675
3718
3738
3779
3818
3865
3901
3934
3966
4063
4100
4121
4139
4189
4223
4248
4259
4279
4291
4318
4356
4405
4449
4472
4507
4541
4557
4589
4628
4694
4712
4736
4772
4796
4840
4877
4904
5010
5073
5123
5171
5237
5272
5295
5322
5361
5383
5429
5462
5523
5577
5601
5673
5702
5727
5772
5809
5856
5860
5895
5924
5955
5982
6034
6058
6100
6151
6195
6225
6270
6297
6338
6409
6445
6470
6501
6506
6570
6619
6666
6706
6754
6789
6833
6860
6883
6909
6924
6969
6991
7028
7060
7082
7110
7140
7180
7246
7336
7368
7413
7452
7474
7500
7593
7639
7681
7712

  train(network, torch.FloatTensor(states), torch.FloatTensor(target_alphas), torch.FloatTensor(target_betas))


epoch:0
train loss: 1.3341552019119263
p_hat loss: 8.266701479442418e-05
1334.1551513671875
train loss: 1.6235532760620117
p_hat loss: 0.0002332527219550684
train loss: 1.8956615924835205
p_hat loss: 0.00039530853973701596
train loss: 2.105135440826416
p_hat loss: 0.0005084575386717916
train loss: 2.4653360843658447
p_hat loss: 0.0006543154013343155
train loss: 2.7394704818725586
p_hat loss: 0.0007771989330649376
train loss: 3.0169436931610107
p_hat loss: 0.0008855742053128779
train loss: 3.526123046875
p_hat loss: 0.0010553696192800999
train loss: 4.125044345855713
p_hat loss: 0.0011716950684785843
train loss: 4.233937740325928
p_hat loss: 0.0013504467206075788
train loss: 4.363389015197754
p_hat loss: 0.0015159172471612692
train loss: 4.668794631958008
p_hat loss: 0.0016506689134985209
train loss: 4.807648181915283
p_hat loss: 0.0018091037636622787
train loss: 5.167330741882324
p_hat loss: 0.001949411234818399
train loss: 5.59084415435791
p_hat loss: 0.002096290700137615
train loss: 

KeyboardInterrupt: 

In [None]:
# Flatten the current search tree into parallel lists of states and per-action targets.
states, target_alphas, target_betas = pull_data(root_node)

In [None]:
# Display the alpha targets collected from the tree.
target_alphas

In [None]:
# Run the network on the first collected state only (a batch of one).
alphas, betas = network(torch.FloatTensor(states[:1]))

In [None]:
# Alpha targets of the first collected state.
target_alphas[:1]

In [None]:
# Beta targets of the first collected state.
target_betas[:1]

In [None]:
# Loss of the single-state prediction against its targets.
# NOTE(review): pull_data returns Python lists, so these slices are lists unless the
# names were re-bound to tensors in between; the loss does tensor arithmetic -- confirm.
regularized_mse_loss(alphas, betas, target_alphas[:1], target_betas[:1])

In [None]:
# Re-bind `network` to a freshly initialised Net (previous weights are dropped).
network = Net()

In [None]:
# Train on the data currently held in states / target_alphas / target_betas.
train(network, torch.FloatTensor(states), torch.FloatTensor(target_alphas), torch.FloatTensor(target_betas))

In [None]:
# Another training pass over the same data (continues from the current weights).
train(network, torch.FloatTensor(states), torch.FloatTensor(target_alphas), torch.FloatTensor(target_betas))

In [None]:
# Network output for `state`, i.e. the most recent observation assigned by the playout loop.
network(torch.FloatTensor(state).unsqueeze(0))

In [None]:
# Prior pull counts stored on the root node.
root_node.fictitious_pulls

In [None]:
# Prior reward counts stored on the root node.
root_node.fictitious_rewards

In [None]:
# Observation planes stored on the root node.
root_node.state

In [None]:
# Predicted pseudo-counts (a = alphas, b = betas) for the root position.
a,b =network(torch.FloatTensor(root_node.state).unsqueeze(0))

In [None]:
# Start over with an untrained network.
network = Net()

In [None]:
# KL-based loss between the current `alphas`/`betas` and the root's accumulated statistics
# (rewards as alpha, pulls - rewards as beta).
loss = loss_dkl(alphas, betas, torch.FloatTensor(root_node.policy.rewards), torch.FloatTensor(root_node.policy.pulls-root_node.policy.rewards))

In [None]:
# Predict pseudo-counts for the root position and print them.
alphas, betas = network(torch.FloatTensor(root_node.state).unsqueeze(0))
print(alphas)
print(betas)
# Predicted success ratio of the first 25 actions, laid out as a 5x5 grid.
print((alphas/(alphas+betas))[0,:25].view(5,5))

In [None]:
# Squared error of the predicted success ratio against the root's accumulated statistics.
loss_mse(alphas, betas, torch.FloatTensor(root_node.policy.rewards), torch.FloatTensor(root_node.policy.pulls-root_node.policy.rewards))

In [None]:
# Empirical mean reward per root arm; the first 25 arms shown as a 5x5 grid.
(root_node.policy.rewards/root_node.policy.pulls).round(4)[:25].reshape(5,5)

In [None]:
# Empirical mean reward of every root arm, rounded to three decimals.
(root_node.policy.rewards/root_node.policy.pulls).round(3)

In [None]:
# Beta targets of the first collected entry.
target_betas[0]

In [None]:
# Alpha targets of the first collected entry.
target_alphas[0]

In [None]:
# Accumulated rewards per root arm.
root_node.policy.rewards

In [None]:
# Pulls minus rewards per root arm (the quantity pull_data stores as the beta target).
root_node.policy.pulls - root_node.policy.rewards

In [None]:
# Predicted success ratio from the pseudo-counts `a` and `b`.
a/(a+b)

In [None]:
# Display the value currently bound to `a`.
a

In [None]:
# Display the value currently bound to `b`.
b