In [130]:
import torch
import torch.nn.functional as F
from model import PolicyValueNetwork

In [50]:
d = torch.load("data/iter003.safetensors")

In [51]:
states, policies, values = d['states'], d['policies'], d['values']

In [52]:
from sklearn.model_selection import train_test_split

tr_s, te_s, tr_p, te_p, tr_v, te_v = train_test_split(states, policies, values, test_size=0.2)

In [53]:
net = PolicyValueNetwork()
optimizer = torch.optim.AdamW(params=net.parameters(), lr=1e-4, weight_decay=1e-4)

In [54]:
def train_iteration(optimizer, epochs:int=10):
    net.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        for (s, p, z) in zip(tr_s, tr_p, tr_v):
            optimizer.zero_grad()
            pred_policy, pred_value = net.forward(s.unsqueeze(dim=0))

            # build policy-value loss: L = (z - v)^2 - pi^T * log(p) + R(\theta)
            # weight decay comes from AdamW
            value_loss: torch.Tensor = F.mse_loss(pred_value.view(-1), z.view(-1))
            policy_loss: torch.Tensor = -torch.sum(p * F.log_softmax(pred_policy, dim=1)) / s.size(0)
            total_loss = value_loss + policy_loss

            total_loss.backward()
            optimizer.step()
            
            epoch_loss += total_loss.detach().item()


        print(f"Epoch {epoch}: loss is {epoch_loss}")

train_iteration(optimizer, 20)

Epoch 0: loss is 2278.515717148781
Epoch 1: loss is 1892.8241513967514
Epoch 2: loss is 1723.052909553051
Epoch 3: loss is 1637.100773692131
Epoch 4: loss is 1616.4045441150665
Epoch 5: loss is 1570.3321644067764
Epoch 6: loss is 1551.0690293312073
Epoch 7: loss is 1549.6628386378288
Epoch 8: loss is 1557.8566510677338
Epoch 9: loss is 1523.4548639059067


In [56]:
net.predict(te_s)

(tensor([[0.1489, 0.1358, 0.1680,  ..., 0.1491, 0.1523, 0.1102],
         [0.1359, 0.1254, 0.1421,  ..., 0.1374, 0.1442, 0.1417],
         [0.1414, 0.1425, 0.1318,  ..., 0.1515, 0.1428, 0.1413],
         ...,
         [0.1530, 0.1229, 0.1424,  ..., 0.1317, 0.1400, 0.1573],
         [0.1130, 0.1557, 0.1439,  ..., 0.1375, 0.1906, 0.1509],
         [0.1587, 0.1389, 0.1343,  ..., 0.1401, 0.1354, 0.1373]]),
 tensor([[-0.3068],
         [-0.9960],
         [ 0.3106],
         [-0.8950],
         [ 0.8472],
         [-0.7278],
         [ 0.9400],
         [-0.8699],
         [-0.0644],
         [-0.6350],
         [ 0.9925],
         [ 0.3717],
         [-0.0672],
         [-0.4608],
         [ 0.2195],
         [ 0.8900],
         [ 0.9735],
         [ 0.7206],
         [ 0.1894],
         [-0.1321],
         [-0.6763],
         [ 0.9694],
         [ 0.9464],
         [ 0.8663],
         [-0.1914],
         [ 0.1032],
         [ 0.2931],
         [ 0.5067],
         [ 0.9721],
         [-0.1

In [154]:
from state import Game

In [155]:
g = Game(-1)

In [201]:
g.make_move(0)
g.make_move(1)
g.make_move(0)
g.make_move(2)
g.make_move(0)

In [202]:
g


Turn: X

   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
 O |   |   |   |   |   |   
---+---+---+---+---+---+---
 O |   |   |   |   |   |   
---+---+---+---+---+---+---
 O | X | X |   |   |   |   

In [199]:
from zero import AlphaZero

In [207]:
az = AlphaZero(noise=0.3, model_pth="models/iter021.safetensors")
g = Game()

In [198]:
while not g.over():
    m = az.get_best_move(g)
    g.make_move(m)
    print(g)


Turn: X

   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   | O |   

Turn: O

   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   | X |   |   | O |   

Turn: X

   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   | O |   
---+---+---+---+---+---+---
   |   | X |   |   | O |   

Turn: O

   |   |   |   |   |   |   
---+---+

In [203]:
az.q[g.get_hash()]

AttributeError: 'AlphaZero' object has no attribute 'q'

In [204]:
az.get_best_move(g)

array(5)

In [206]:
az.visits[g.get_hash()]

tensor([50., 23., 55., 19., 56., 58., 39.])