In [96]:
import numpy as np
import scipy
import matplotlib.pyplot as plt
import pandas as pd
from pettingzoo.classic import go_v5
from SMPyBandits import Policies
from torch.utils.tensorboard import SummaryWriter
import warnings
import torch
from torch import nn
import torch.nn.functional as F
import random
import itertools

In [97]:
def fixedMOSSEcomputeAllIndex(self):
    """ Compute the current indexes for all arms, in a vectorized manner."""
    pulls_of_suboptimal_arms = np.sum(self.pulls[self.pulls < np.sqrt(self.t)])
    if pulls_of_suboptimal_arms > 0:
        indexes = (self.rewards / self.pulls) + np.sqrt(0.5 * np.maximum(0, np.log(self.t / pulls_of_suboptimal_arms)) / self.pulls)
    else:
        indexes = (self.rewards / self.pulls) + np.sqrt(0.5 * np.maximum(0, np.log(self.t / (self.nbArms * self.pulls))) / self.pulls)
    # indexes[self.pulls < 1] = float('+inf')
    self.index[:] = indexes
Policies.MOSSExperimental.computeAllIndex = fixedMOSSEcomputeAllIndex

In [98]:
class GoNode():
    def __init__(self, num_arms, legal_actions, policy_algorithm, player='black_0', state=None, fictitious_alphas = None, fictitious_betas = None):
        self.num_arms = num_arms
        self.legal_actions = legal_actions
        self.player = player
        self.state = state
        self.policy_algorithm = policy_algorithm
        self.policy = self.policy_algorithm(num_arms)
        if fictitious_alphas == None and fictitious_betas == None:
            self.fictitious_pulls = np.ones(num_arms)
            self.fictitious_rewards = np.ones(num_arms) / 2
        else:
            self.fictitious_pulls = (fictitious_alphas + fictitious_betas).view(-1).numpy()[legal_actions]
            self.fictitious_rewards = fictitious_alphas.view(-1).numpy()[legal_actions]
        self.policy.pulls = self.fictitious_pulls.copy()
        self.policy.rewards = self.fictitious_rewards.copy()
        self.policy.t = self.policy.pulls.sum()
        self.next_nodes = {}

In [99]:
def pull_data(node):
    states = [node.state]
    num_arms = len(root_node.state[0].flatten())+1
    a = np.zeros(num_arms)
    a[node.legal_actions] = node.policy.rewards
    alphas = [a]
    p = np.zeros(num_arms)
    p[node.legal_actions] = node.policy.pulls
    betas = [p-a]
    for next_node in node.next_nodes.values():
        next_states, next_alphas, next_betas = pull_data(next_node)
        states += next_states
        alphas += next_alphas
        betas += next_betas
    return states, alphas, betas

In [194]:
class Net(nn.Module):
    def __init__(self, size):
        super(Net, self).__init__()
        if size == 1:
            self.layers = nn.ModuleList(
                [nn.Conv2d(17, 32, 3, 1, padding='same')] + \
                [nn.Conv2d(32, 32, 3, 1, padding='same')] + \
                [nn.Conv2d(32, 2, 3, 1, padding='same')])
        elif size == 2:
            self.layers = nn.ModuleList(
                [nn.Conv2d(17, 32, 3, 1, padding='same')] + \
                [nn.Conv2d(32, 32, 3, 1, padding='same') for _ in range(2)] + \
                [nn.Conv2d(32, 2, 3, 1, padding='same')])
        elif size == 3:
            self.layers = nn.ModuleList(
                [nn.Conv2d(17, 32, 3, 1, padding='same')] + \
                [nn.Conv2d(32, 32, 3, 1, padding='same') for _ in range(4)] + \
                [nn.Conv2d(32, 2, 3, 1, padding='same')])
        elif size == 4:
            self.layers = nn.ModuleList(
                [nn.Conv2d(17, 64, 3, 1, padding='same')] + \
                [nn.Conv2d(64, 32, 3, 1, padding='same')] + \
                [nn.Conv2d(32, 16, 3, 1, padding='same')] + \
                [nn.Conv2d(16, 2, 3, 1, padding='same')])
        elif size == 5:
            self.layers = nn.ModuleList(
                [nn.Conv2d(17, 128, 3, 1, padding='same')] + \
                [nn.Conv2d(128, 64, 3, 1, padding='same')] + \
                [nn.Conv2d(64, 32, 3, 1, padding='same')] + \
                [nn.Conv2d(32, 16, 3, 1, padding='same')] + \
                [nn.Conv2d(16, 2, 3, 1, padding='same')])

    def forward(self, x):
        for i in range(len(self.layers)-1):
            x = self.layers[i](x)
            x = F.relu(x)
        out = torch.flatten(self.layers[-1](x), -2, -1)
        passes = x[:,:2,:,:].mean(axis=(-1,-2))
        out = torch.cat((out, passes.unsqueeze(-1)), axis=-1)
        alphas = torch.exp(out[:,0,:])
        betas=torch.exp(out[:,1,:])
        return alphas, betas
    
def augment(inpt, target_alphas, target_betas):
    rotation = random.randrange(4)
    flip = random.randrange(2)
    alpha_board_targets , alpha_pass_targets = torch.split(target_alphas, [target_alphas.shape[-1]-1,1], dim=-1)
    beta_board_targets , beta_pass_targets = torch.split(target_betas, [target_betas.shape[-1]-1,1], dim=-1)
    alpha_board_targets_reshaped = alpha_board_targets.view(-1, int(np.sqrt(alpha_board_targets.shape[-1])),int(np.sqrt(alpha_board_targets.shape[-1])))
    beta_board_targets_reshaped = beta_board_targets.view(-1, int(np.sqrt(beta_board_targets.shape[-1])),int(np.sqrt(beta_board_targets.shape[-1])))
    inpt = torch.rot90(inpt,rotation,[-2,-1])
    alpha_board_targets_reshaped = torch.rot90(alpha_board_targets_reshaped,rotation,[-2,-1])
    beta_board_targets_reshaped = torch.rot90(beta_board_targets_reshaped,rotation,[-2,-1])
    if flip:
        inpt = torch.flip(inpt, [-1])
        alpha_board_targets_reshaped = torch.flip(alpha_board_targets_reshaped, [-1])
        beta_board_targets_reshaped = torch.flip(beta_board_targets_reshaped, [-1])
    alpha_board_targets = alpha_board_targets_reshaped.reshape(-1, alpha_board_targets.shape[-1])
    beta_board_targets = beta_board_targets_reshaped.reshape(-1, beta_board_targets.shape[-1])
    alpha_targets = torch.cat((alpha_board_targets, alpha_pass_targets),-1).contiguous()
    beta_targets = torch.cat((beta_board_targets, beta_pass_targets),-1).contiguous()
    return inpt, alpha_targets, beta_targets

def loss1(target_alpha_batches, p_hat, s):
    # return (target_beta_batches*p_hat**2 + target_alpha_batches*(1-p_hat)**2 + p_hat*(1-p_hat)/(s+1)*pulls).mean()
    return (target_alpha_batches - target_alpha_batches*p_hat + p_hat*(1-p_hat)/(s+1)).mean()

def loss2(p_hat, target_probs):
    return ((p_hat-target_probs)**2).mean()

def loss3(target_beta_batches, target_alpha_batches, p_hat):
    return (target_beta_batches*p_hat**2 + target_alpha_batches*(1-p_hat)**2).mean()

def loss4(alphas, betas, target_alphas, target_betas):
    # return (target_alpha_batches*(torch.digamma(s)-torch.digamma(alphas))).mean()
    return ((target_alphas + target_betas)*torch.digamma(alphas+betas)-target_alphas*torch.digamma(alphas)-target_betas*torch.digamma(betas)).mean()

def loss_mse(alphas, betas, target_alphas, target_betas):
    s = alphas+betas + 1
    return ((alphas/s-target_alphas/(target_alphas+target_betas+1))**2).mean()

def loss_dkl(alphas, betas, target_alphas, target_betas):
    return (torch.lgamma(alphas)+torch.lgamma(betas)-torch.lgamma(alphas+betas)
            -torch.lgamma(target_alphas)-torch.lgamma(target_betas)+torch.lgamma(target_alphas+target_betas)
            +(target_alphas-alphas)*torch.digamma(target_alphas)+(target_betas-betas)*torch.digamma(target_betas)
            +(alphas-target_alphas+betas-target_betas)*torch.digamma(target_alphas+target_betas)).mean()

def loss_direct_mse(alphas, betas, target_alphas, target_betas):
    return ((alphas-target_alphas)**2 +(betas-target_betas)**2).mean()

def reverse_log_likelihood_loss(alpha, beta, target_alpha, target_beta):
    return -(
        (target_alpha-1)*(torch.digamma(alpha)-torch.digamma(alpha+beta)) + 
        (target_beta-1) *(torch.digamma(beta) -torch.digamma(alpha+beta))
        # -torch.lgamma(target_alpha)-torch.lgamma(target_beta)+torch.lgamma(target_alpha+target_beta)).mean()
        ).mean()

def tuning(network, loss_func):
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-5)

    batch_size = 64
    
    train_target_alphas, val_target_alphas, test_target_alphas = torch.split(target_alphas, [int(len(target_alphas)*0.7), int(len(target_alphas)*0.15), int(len(target_alphas)*0.15) + 1])
    train_target_betas, val_target_betas, test_target_betas = torch.split(target_betas, [int(len(target_betas)*0.7), int(len(target_betas)*0.15), int(len(target_betas)*0.15) + 1])
    train_states, val_states, test_states = torch.split(states, [int(len(states)*0.7), int(len(states)*0.15), int(len(states)*0.15) + 1])
    
    train_target_alpha_batches = torch.split(train_target_alphas, batch_size)
    train_target_beta_batches = torch.split(train_target_betas, batch_size)
    train_batches = torch.split(train_states, batch_size)
    ma = 0
        
    for epoch in range(10):
        print(f'epoch:{epoch}')
        for i in range(len(train_batches)):
            optimizer.zero_grad()
            aug_train_inpt, aug_train_target_alpha, aug_train_target_beta = augment(train_batches[i], train_target_alpha_batches[i], train_target_beta_batches[i])
            alphas, betas = network(aug_train_inpt)

            s = alphas+betas
            p_hat = alphas/s
            
            pulls = aug_train_target_alpha + aug_train_target_beta
            target_probs = aug_train_target_alpha/pulls
            target_probs[target_probs != target_probs] = 0

            if loss_func == loss1:
                train_loss = loss1(aug_train_target_alpha, p_hat, s)
            elif loss_func == loss2:
                train_loss = loss2(p_hat, target_probs)
            elif loss_func == loss3:
                train_loss = loss3(aug_train_target_beta, aug_train_target_alpha, p_hat)
            elif loss_func == loss4:
                train_loss = loss4(aug_train_target_alpha, s, alphas)
            train_loss.backward()
            optimizer.step()
            if i % 500 == 0:
                print(train_loss.item())
    
    val_target_alpha_batches, val_target_beta_batches = torch.split(val_target_alphas, batch_size), torch.split(val_target_betas, batch_size)
    val_batches = torch.split(val_states, batch_size)
    val_batch_loss = []
    for i in range(len(val_batches)):
        optimizer.zero_grad()
        alphas, betas = network(val_batches[i])

        s = alphas+betas
        p_hat = alphas/s
        
        pulls = val_target_alpha_batches[i] + val_target_beta_batches[i]
        target_probs = val_target_alpha_batches[i]/pulls
        target_probs[target_probs != target_probs] = 0

        val_batch_loss += [loss1(val_target_alpha_batches[i], p_hat, s).item()]

    return val_batch_loss

In [197]:
def train(network, states, target_alphas, target_betas):
    optimizer = torch.optim.Adam(network.parameters(), lr=1e-2)
    batch_size = 1024

    train_target_alphas, test_target_alphas = torch.split(target_alphas, [int(len(target_alphas)*0.85), len(target_alphas) - int(len(target_alphas)*0.85)])
    train_target_betas,  test_target_betas = torch.split(target_betas, [int(len(target_betas)*0.85), len(target_betas) - int(len(target_betas)*0.85)])
    train_states, test_states = torch.split(states, [int(len(states)*0.85), len(states) - int(len(states)*0.85)])
    
    train_target_alpha_batches = torch.split(train_target_alphas, batch_size)
    train_target_beta_batches = torch.split(train_target_betas, batch_size)
    train_batches = torch.split(train_states, batch_size)
        
    for epoch in range(10):
        print(f'epoch:{epoch}')
        for i in range(len(train_batches)):
            optimizer.zero_grad()
            aug_train_inpt, aug_train_target_alpha, aug_train_target_beta = augment(train_batches[i], train_target_alpha_batches[i], train_target_beta_batches[i])
            alphas, betas = network(aug_train_inpt)
            print(alphas)
            print(betas)
            print(aug_train_target_alpha)
            print(aug_train_target_beta)
            train_loss = loss_dkl(alphas+1/2, betas+1/2, aug_train_target_alpha+1/2, aug_train_target_beta+1/2)
            # train_loss = loss_direct_mse(alphas, betas, aug_train_target_alpha, aug_train_target_beta)
            # train_loss = loss_mse(alphas, betas, aug_train_target_alpha, aug_train_target_beta)
            print(train_loss)
            # print(train_loss)
            # train_loss = loss4(alphas, betas, aug_train_target_alpha, aug_train_target_beta)
            # train_loss = loss4(alphas, betas, aug_train_target_alpha, aug_train_target_beta)
            train_loss.backward()
            optimizer.step()
            if i % 20 == 0:
                print(train_loss.item())
        test_alphas, test_betas = network(test_states)
        # val_loss = loss_mse(test_alphas, test_betas, test_target_alphas, test_target_betas)
        val_loss = loss_dkl(test_alphas+1/2, test_betas+1/2, test_target_alphas+1/2, test_target_betas+1/2)
        # val_loss = loss_direct_mse(alphas, betas, aug_train_target_alpha, aug_train_target_beta)
        # val_loss = loss4(test_alphas, test_betas, test_target_alphas, test_target_betas)
        print(f'val loss:{val_loss}')

In [198]:
writer = SummaryWriter()
env = go_v5.env(board_size = 5, komi = 3.5)
policy_algorithm=Policies.MOSSExperimental
env.reset(seed=42)
last = env.last()
winner_list = []
big_number = 1e5
win_condition = 100
i = 0
using_network=False
states_per_training_batch = 50000
num_unique_states_visited = 1
# data collection and training loop
while True:
    with torch.no_grad():
        fictitious_alphas, fictitious_betas = network(torch.FloatTensor(state).unsqueeze(0)) if using_network else (None, None)
    root_node = GoNode(num_arms = last[0]['action_mask'].sum(), 
                       legal_actions = last[0]['action_mask'].nonzero()[0], 
                       policy_algorithm=policy_algorithm, 
                       state = last[0]['observation'].transpose((2,1,0)), 
                       fictitious_alphas = fictitious_alphas,
                       fictitious_betas = fictitious_betas)
    # data collection and tree update loop
    while num_unique_states_visited < states_per_training_batch:
        print(num_unique_states_visited)
        node = root_node
        node_action_list = []
        observation, reward, termination, truncation, info = env.last()
        state = observation['observation'].transpose((2,1,0))
        mask = observation["action_mask"]
        # single game playout loop
        for agent in env.agent_iter():
            if termination or truncation:
                action = None
                winner = reward if agent=='black_0' else -reward
                winner = (winner + 1)/2
                winner_list.append(winner)
                writer.add_scalar('Winner', winner, str(i))
                break
            policy_choice = node.policy.choice()
            action = node.legal_actions[policy_choice]
            node_action_list.append((node,policy_choice))
            env.step(action)
            observation, reward, termination, truncation, info = env.last()
            state = observation['observation'].transpose((2,1,0))
            mask = observation["action_mask"]
            num_arms = np.count_nonzero(mask)
            if policy_choice in node.next_nodes:
                node = node.next_nodes[policy_choice]
            else:
                with torch.no_grad():
                    fictitious_alphas, fictitious_betas = network(torch.FloatTensor(state).unsqueeze(0)) if using_network else (None, None)
                node.next_nodes[policy_choice] = GoNode(num_arms=num_arms, 
                                                        legal_actions = mask.nonzero()[0], 
                                                        policy_algorithm=policy_algorithm if num_arms > 3 else Policies.klUCBPlus, 
                                                        player = 'white_0' if node.player=='black_0' else 'black_0', 
                                                        state=state,
                                                        fictitious_alphas = fictitious_alphas,
                                                        fictitious_betas = fictitious_betas)
                num_unique_states_visited += 1
                node = node.next_nodes[policy_choice]
        env.close()

        if {'black_0':1,'white_0':0}[node_action_list[-1][0].player] == winner:
            node_solved=-1
        else:
            node_solved=1
        # tree update loop
        for node, action in reversed(node_action_list):
            if node_solved !=0:
                node.policy.rewards[action] = big_number if node_solved==-1 else .5
                node.policy.pulls[action]=big_number
                # node.policy.pulls[action] +=1
            else:
                node.policy.getReward(action, (winner if node.player=='black_0' else 1-winner))
            node_solved = 1 if any(node.policy.rewards==big_number) else (-1 if (all(node.policy.rewards==.5) and all(node.policy.pulls==big_number)) else 0)
        
        # if len(winner_list)>win_condition and ((np.array(winner_list[-win_condition:])==1).all() or (np.array(winner_list[-win_condition:])==0).all()):
        #     break
        if any(root_node.policy.rewards==big_number):
            break
        env.reset(seed=42)

    states, target_alphas, target_betas = pull_data(root_node)
    if not using_network:
        network = Net(5)
        using_network = True
    train(network, torch.FloatTensor(states), torch.FloatTensor(target_alphas), torch.FloatTensor(target_betas))
    num_unique_states_visited = 1


1
30
54
88
107
192
213
244
289
317
406
462
506
538
599
640
661
706
786
816
861
874
917
941
972
998
1002
1043
1080
1126
1182
1229
1261
1288
1331
1365
1389
1408
1455
1496
1547
1599
1621
1658
1695
1732
1771
1796
1814
1838
1883
1953
1993
2016
2096
2100
2122
2160
2186
2199
2227
2273
2327
2438
2473
2514
2555
2577
2633
2658
2679
2713
2749
2798
2804
2831
2851
2877
2902
2978
3040
3085
3091
3117
3147
3181
3225
3296
3345
3372
3429
3467
3517
3536
3566
3603
3646
3691
3728
3770
3837
3861
3895
3975
3999
4037
4058
4088
4119
4167
4260
4286
4308
4345
4387
4409
4454
4476
4498
4527
4554
4607
4615
4655
4700
4726
4780
4831
4875
4897
4919
4936
4968
5008
5031
5076
5084
5117
5158
5182
5195
5241
5260
5290
5336
5353
5383
5415
5535
5578
5599
5641
5683
5708
5755
5794
5822
5851
5889
5927
5945
5965
6039
6083
6130
6171
6208
6236
6247
6327
6366
6408
6415
6446
6471
6500
6523
6541
6572
6592
6621
6666
6690
6714
6738
6769
6811
6835
6860
6914
7059
7120
7128
7150
7184
7239
7285
7310
7336
7381
7418
7478
7506
7522
7572
7611
7

KeyboardInterrupt: 

In [103]:
network = Net(5)

In [114]:
train(network, torch.FloatTensor(states), torch.FloatTensor(target_alphas), torch.FloatTensor(target_betas))

epoch:0
nan


KeyboardInterrupt: 

In [None]:
train(network, torch.FloatTensor(states), torch.FloatTensor(target_alphas), torch.FloatTensor(target_betas))

epoch:0
112.05393981933594
68.24877166748047
50.61867904663086
40.52751159667969
60.80434036254883
val loss:35.66358947753906
epoch:1
93.04460906982422
47.82365417480469
35.40055847167969
24.78400421142578
42.740726470947266
val loss:33.83845138549805
epoch:2
83.05951690673828
38.5574951171875
35.672977447509766
43.107872009277344
35.690853118896484
val loss:27.170724868774414
epoch:3
60.86329650878906
53.04750061035156
34.7341194152832
19.440134048461914
25.23750877380371
val loss:21.388957977294922
epoch:4
58.416114807128906
43.52470397949219
25.83867073059082
15.260404586791992
35.57423400878906
val loss:38.22598648071289
epoch:5
83.50448608398438
38.12604904174805
34.62892532348633
26.771793365478516
15.723441123962402
val loss:30.369722366333008
epoch:6
59.47940444946289
38.21063995361328
32.1850471496582
20.06827735900879
29.365459442138672
val loss:27.099061965942383
epoch:7
58.81993103027344
60.528724670410156
39.799583435058594
14.612092971801758
16.209644317626953
val loss:34

In [185]:
network(torch.FloatTensor(state).unsqueeze(0))

(tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan]], grad_fn=<ExpBackward0>),
 tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
          nan, nan]], grad_fn=<ExpBackward0>))

In [None]:
root_node.fictitious_pulls

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [None]:
root_node.fictitious_rewards

array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
       0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5])

In [None]:
root_node.state

array([[[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]],

       [[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]],

       [[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]],

       [[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]],

       [[False, False, False, False, False],
        [False, False, False, False, False],
  

In [153]:
a,b =network(torch.FloatTensor(root_node.state).unsqueeze(0))

In [186]:
network = Net(5)

In [152]:
loss = loss_dkl(alphas, betas, torch.FloatTensor(root_node.policy.rewards), torch.FloatTensor(root_node.policy.pulls-root_node.policy.rewards))

In [199]:
alphas,betas =network(torch.FloatTensor(root_node.state).unsqueeze(0))
print(alphas)
print(betas)
print((alphas/(alphas+betas ))[0,:25].view(5,5))

tensor([[3.0002e-30, 0.0000e+00, 0.0000e+00, 4.7131e-41, 1.7877e-22, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         5.4651e-44, 1.0000e+00]], grad_fn=<ExpBackward0>)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 1.]], grad_fn=<ExpBackward0>)
tensor([[1., nan, nan, 1., 1.],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, 1.]], grad_fn=<ViewBackward0>)


In [189]:
loss_mse(alphas, betas, torch.FloatTensor(root_node.policy.rewards), torch.FloatTensor(root_node.policy.pulls-root_node.policy.rewards))

tensor(0.0520, grad_fn=<MeanBackward0>)

In [171]:
(root_node.policy.rewards/root_node.policy.pulls).round(4)[:25].reshape(5,5)

array([[0.0714, 0.375 , 0.3452, 0.412 , 0.25  ],
       [0.3452, 0.0714, 0.15  , 0.4344, 0.2083],
       [0.3873, 0.3462, 0.4387, 0.3523, 0.0714],
       [0.379 , 0.4427, 0.4136, 0.3452, 0.25  ],
       [0.3571, 0.2083, 0.4271, 0.3167, 0.3   ]])

In [151]:
(root_node.policy.rewards/root_node.policy.pulls).round(3)

array([0.219, 0.349, 0.414, 0.363, 0.319, 0.362, 0.15 , 0.383, 0.15 ,
       0.083, 0.192, 0.344, 0.402, 0.412, 0.319, 0.192, 0.288, 0.401,
       0.353, 0.15 , 0.15 , 0.375, 0.381, 0.394, 0.15 , 0.   ])

In [None]:
target_betas[0]

array([ 70.5,  92.5, 107.5,  26.5,  68.5, 199.5, 285.5, 141.5,  64.5,
       105.5,  46.5, 252.5, 297.5, 257.5,  88.5, 132.5, 352.5, 189.5,
       218.5,  22.5,  41.5, 134.5, 102.5, 122.5,  36.5,  29.5])

In [None]:
target_alphas[0]

array([ 35.5,  50.5,  60.5,   8.5,  34.5, 126.5, 190.5,  84.5,  31.5,
        59.5,  20.5, 165.5, 199.5, 169.5,  47.5,  80.5, 240.5, 119.5,
       140.5,   6.5,  17.5,  79.5,  57.5,  71.5,  14.5,  10.5])

In [None]:
root_node.policy.rewards

array([ 35.5,  50.5,  60.5,   8.5,  34.5, 126.5, 190.5,  84.5,  31.5,
        59.5,  20.5, 165.5, 199.5, 169.5,  47.5,  80.5, 240.5, 119.5,
       140.5,   6.5,  17.5,  79.5,  57.5,  71.5,  14.5,  10.5])

In [None]:
root_node.policy.pulls - root_node.policy.rewards

array([ 70.5,  92.5, 107.5,  26.5,  68.5, 199.5, 285.5, 141.5,  64.5,
       105.5,  46.5, 252.5, 297.5, 257.5,  88.5, 132.5, 352.5, 189.5,
       218.5,  22.5,  41.5, 134.5, 102.5, 122.5,  36.5,  29.5])

In [154]:
a/(a+b)

tensor([[0.5017, 0.5000, 0.5006, 0.4992, 0.5046, 0.5011, 0.5020, 0.5033, 0.5015,
         0.4986, 0.5011, 0.5030, 0.5044, 0.5032, 0.4997, 0.5000, 0.5014, 0.5028,
         0.5018, 0.4992, 0.5020, 0.5000, 0.5009, 0.5013, 0.5011, 0.5000]],
       grad_fn=<DivBackward0>)

In [122]:
a

tensor([[1.2070, 1.3767, 1.2609, 1.1449, 0.8606, 2.2842, 4.4812, 5.0612, 4.0220,
         1.5442, 2.2111, 4.6003, 5.3446, 4.2993, 1.6081, 1.5523, 2.7485, 3.1910,
         2.8743, 1.4348, 0.8869, 1.1883, 1.4465, 1.5555, 1.3247, 1.0000]],
       grad_fn=<ExpBackward0>)

In [123]:
b

tensor([[  21.4252,  108.9102,  184.7312,   78.5860,   15.5823,  120.7302,
         1090.3922, 2246.7175,  614.5872,   53.4604,  203.6358, 2204.3550,
         4772.0830, 1172.7881,   75.3444,   89.7863,  664.1031, 1315.8488,
          433.8573,   39.8708,   11.7174,   32.9229,   50.3503,   29.0238,
            8.1727,   25.1777]], grad_fn=<ExpBackward0>)