In [1]:
import torch
import torch.nn.functional as F
from model import PolicyValueNetwork

In [50]:
d = torch.load("data/iter003.safetensors")

In [51]:
states, policies, values = d['states'], d['policies'], d['values']

In [52]:
from sklearn.model_selection import train_test_split

tr_s, te_s, tr_p, te_p, tr_v, te_v = train_test_split(states, policies, values, test_size=0.2)

In [53]:
net = PolicyValueNetwork()
optimizer = torch.optim.AdamW(params=net.parameters(), lr=1e-4, weight_decay=1e-4)

In [54]:
def train_iteration(optimizer, epochs:int=10):
    net.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        for (s, p, z) in zip(tr_s, tr_p, tr_v):
            optimizer.zero_grad()
            pred_policy, pred_value = net.forward(s.unsqueeze(dim=0))

            # build policy-value loss: L = (z - v)^2 - pi^T * log(p) + R(\theta)
            # weight decay comes from AdamW
            value_loss: torch.Tensor = F.mse_loss(pred_value.view(-1), z.view(-1))
            policy_loss: torch.Tensor = -torch.sum(p * F.log_softmax(pred_policy, dim=1)) / s.size(0)
            total_loss = value_loss + policy_loss

            total_loss.backward()
            optimizer.step()
            
            epoch_loss += total_loss.detach().item()


        print(f"Epoch {epoch}: loss is {epoch_loss}")

train_iteration(optimizer, 20)

Epoch 0: loss is 2278.515717148781
Epoch 1: loss is 1892.8241513967514
Epoch 2: loss is 1723.052909553051
Epoch 3: loss is 1637.100773692131
Epoch 4: loss is 1616.4045441150665
Epoch 5: loss is 1570.3321644067764
Epoch 6: loss is 1551.0690293312073
Epoch 7: loss is 1549.6628386378288
Epoch 8: loss is 1557.8566510677338
Epoch 9: loss is 1523.4548639059067


In [56]:
net.predict(te_s)

(tensor([[0.1489, 0.1358, 0.1680,  ..., 0.1491, 0.1523, 0.1102],
         [0.1359, 0.1254, 0.1421,  ..., 0.1374, 0.1442, 0.1417],
         [0.1414, 0.1425, 0.1318,  ..., 0.1515, 0.1428, 0.1413],
         ...,
         [0.1530, 0.1229, 0.1424,  ..., 0.1317, 0.1400, 0.1573],
         [0.1130, 0.1557, 0.1439,  ..., 0.1375, 0.1906, 0.1509],
         [0.1587, 0.1389, 0.1343,  ..., 0.1401, 0.1354, 0.1373]]),
 tensor([[-0.3068],
         [-0.9960],
         [ 0.3106],
         [-0.8950],
         [ 0.8472],
         [-0.7278],
         [ 0.9400],
         [-0.8699],
         [-0.0644],
         [-0.6350],
         [ 0.9925],
         [ 0.3717],
         [-0.0672],
         [-0.4608],
         [ 0.2195],
         [ 0.8900],
         [ 0.9735],
         [ 0.7206],
         [ 0.1894],
         [-0.1321],
         [-0.6763],
         [ 0.9694],
         [ 0.9464],
         [ 0.8663],
         [-0.1914],
         [ 0.1032],
         [ 0.2931],
         [ 0.5067],
         [ 0.9721],
         [-0.1

In [7]:
from state import Game

In [155]:
g = Game(-1)

In [103]:
from zero import AlphaZero
from state import Game

In [104]:
az = AlphaZero(noise=0.3, model_pth="models/iter003.safetensors")
g = Game()

In [105]:
for _ in range(6):
    g.make_move(0)

g


Turn: O

 X |   |   |   |   |   |   
---+---+---+---+---+---+---
 O |   |   |   |   |   |   
---+---+---+---+---+---+---
 X |   |   |   |   |   |   
---+---+---+---+---+---+---
 O |   |   |   |   |   |   
---+---+---+---+---+---+---
 X |   |   |   |   |   |   
---+---+---+---+---+---+---
 O |   |   |   |   |   |   

In [107]:
az.get_best_move(g)
print(az.q[g.get_hash()])
print(az.priors[g.get_hash()])

tensor(300.)
tensor([ 0.0000, -0.4985,  0.0658,  0.0891, -0.0256, -0.5691, -0.3278])
tensor([0.0207, 0.1959, 0.1649, 0.1739, 0.1610, 0.1316, 0.1519])


In [108]:
az.visits[g.get_hash()]

tensor([ 0., 55., 56., 59., 53., 33., 44.])

In [109]:
net.predict(g.get_state_tensor())

(tensor([0.0017, 0.1395, 0.1611, 0.1967, 0.1838, 0.1576, 0.1597]),
 tensor([[-0.2341]]))

In [15]:
while not g.over():
    m = az.get_best_move(g)
    g.make_move(m)
    print(g)

tensor(300.)

Turn: X

   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   | O |   |   |   
tensor(300.)

Turn: O

   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   | X |   |   |   
---+---+---+---+---+---+---
   |   |   | O |   |   |   
tensor(300.)

Turn: X

   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   | X |   |   |   
---+---+---+---+---+---+---
   |   |   | O | O |   |   
tensor(

In [20]:
az.q[g.get_hash()]

AttributeError: 'AlphaZero' object has no attribute 'q'

In [204]:
az.get_best_move(g)

array(5)

In [206]:
az.visits[g.get_hash()]

tensor([50., 23., 55., 19., 56., 58., 39.])

In [5]:
import torch
d = torch.load("data/iter001.safetensors")

In [6]:
d

{'states': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],
 
          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],
 
 
         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],
 
          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0.

In [52]:
# manually inspect a game state + policy
i = 22
s, p, v = d['states'][i], d['policies'][i], d['values'][i]

board = s[0] + -1 * s[1]

g = Game()
g.board = board.numpy()

print(g)
print(p)
print(v)


Turn: O

   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
   |   |   |   |   |   |   
---+---+---+---+---+---+---
 X |   |   |   |   |   |   
tensor([0.1033, 0.2200, 0.1367, 0.1200, 0.1200, 0.1967, 0.1033])
tensor(1.)


In [43]:
az.get_best_move(g)

KeyError: b"\x94A\xff>\xd7P\xbd\xd8{\x858\xd3\xfbH\x1e\x99\r\xf9\xfcQ\x88w\xbd\xe7\xf2'\x10\xdf;7@~"

In [64]:
from model import PolicyValueNetwork

In [65]:
net = PolicyValueNetwork(path="models/iter001.safetensors")

In [66]:
for s, pi in zip(d['states'], d['policies']):
    p,v = net.predict(s.unsqueeze(dim=0))
    print("="*30)
    print(f"Predicted: {p} ")
    print("Real:", pi)

Predicted: tensor([0.0569, 0.0983, 0.1440, 0.1682, 0.1705, 0.1750, 0.1871]) 
Real: tensor([0.1383, 0.1317, 0.1067, 0.1383, 0.1200, 0.2400, 0.1250])
Predicted: tensor([0.1321, 0.1245, 0.1335, 0.1224, 0.1600, 0.1679, 0.1595]) 
Real: tensor([0.2000, 0.1200, 0.1067, 0.1367, 0.1467, 0.1050, 0.1850])
Predicted: tensor([0.1440, 0.1194, 0.1576, 0.1552, 0.1441, 0.1408, 0.1390]) 
Real: tensor([0.1183, 0.1150, 0.1067, 0.1750, 0.2700, 0.1117, 0.1033])
Predicted: tensor([0.1247, 0.1293, 0.1511, 0.1479, 0.1496, 0.1597, 0.1377]) 
Real: tensor([0.1067, 0.1250, 0.1167, 0.1667, 0.1817, 0.1533, 0.1500])
Predicted: tensor([0.1289, 0.1268, 0.1641, 0.1779, 0.1309, 0.1375, 0.1339]) 
Real: tensor([0.1900, 0.1083, 0.1083, 0.1250, 0.2100, 0.1350, 0.1233])
Predicted: tensor([0.1493, 0.1482, 0.1530, 0.1280, 0.1568, 0.1276, 0.1371]) 
Real: tensor([0.1017, 0.2100, 0.1650, 0.1250, 0.1167, 0.1467, 0.1350])
Predicted: tensor([0.1340, 0.1196, 0.1542, 0.1720, 0.1397, 0.1447, 0.1358]) 
Real: tensor([0.1200, 0.1483, 0.123