In [None]:
import sys
sys.path.append("ml")

In [None]:
import time
import json
import glob
import random
import numpy as np
import torch
import make_dataset2 as make_dataset

In [None]:
# Seed every RNG this notebook draws from: Python's `random`, NumPy
# (make_batch samples indices with np.random.randint) and PyTorch
# (initial layer weights). Seeding only `random` leaves runs irreproducible.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Device that the validation tensors and each training batch are placed on.
# NOTE(review): the model is never moved with .to(device) in the visible
# cells, so changing this value also requires moving the model.
device = "cpu"

In [None]:
class EWMA:
    """Exponentially weighted moving average of a stream of values.

    `value` is None until the first sample arrives; the first sample is taken
    as-is, and every later sample is blended in with weight `alpha`.
    """

    def __init__(self, alpha=0.02):
        self.alpha = alpha
        self.value = None

    def apply(self, x):
        """Fold sample `x` into the running average (updates `self.value`)."""
        if self.value is None:
            # First sample: nothing to blend with yet.
            self.value = x
            return
        keep = 1 - self.alpha
        self.value = keep * self.value + self.alpha * x

In [None]:
%%time
dm_train = np.load("dm_train.npz")
dm_features = torch.tensor(dm_train["features"])
dm_policy_to = torch.tensor(dm_train["policy_to"])
dm_policy_from = torch.tensor(dm_train["policy_from"])
dm_value = torch.tensor(dm_train["value"])

In [None]:
# View each flat feature row as 15 planes of 8x8 and read square (0, 0) of
# plane 13 (stored as is_white_turn) and plane 14 (stored as is_duck_move).
# NOTE(review): presumably each flag plane is constant across the board, so
# one square is enough — confirm against make_dataset2.
is_white_turn, is_duck_move = (
    dm_features.reshape(-1, 15, 8, 8)[:, plane, 0, 0] for plane in (13, 14)
)

In [None]:
%%time
dm_val = np.load("dm_val.npz")
dm_val_features = torch.tensor(dm_val["features"], dtype=torch.float32, device=device)
dm_val_policy_to = torch.tensor(dm_val["policy_to"], dtype=torch.float32, device=device)
dm_val_policy_from = torch.tensor(dm_val["policy_from"], dtype=torch.float32, device=device)
dm_val_value = torch.tensor(dm_val["value"], dtype=torch.float32, device=device)

In [None]:
# Same flag extraction as for the training set: square (0, 0) of planes 13
# and 14 of the 15x8x8 view, converted to int64 so the flags can be combined
# into an integer head index.
val_is_white_turn, val_is_duck_move = (
    dm_val_features.reshape(-1, 15, 8, 8)[:, plane, 0, 0].to(torch.int64)
    for plane in (13, 14)
)

In [None]:
# Here we make the important decision to predict
# the target squares of moves, not source squares.
# dm_policy_from / dm_val_policy_from stay loaded but are not referenced
# again in the visible cells.
dm_policy = dm_policy_to
dm_val_policy = dm_val_policy_to

In [None]:
# Width of one flattened feature row; MultiModel uses it as the input size
# of its first linear layer.
feature_count = dm_features.shape[1]

class MultiModel(torch.nn.Module):
    """Shared embedding followed by four small heads, one selected per sample.

    Every input row goes through `main_embed` + ReLU, then through all four
    heads (`white_main`, `black_main`, `white_duck`, `black_duck`). For each
    sample, `which_model` picks which head's output is used: 0..3 in that
    order.

    Only the value output is currently produced; the policy output is
    disabled and returned as None.
    """

    ACCUM_SIZE = 256
    SIZE1 = 16
    SIZE2 = 32
    FINAL_SIZE = 1

    def __init__(self, input_size=None):
        """Build the embedding and the four heads.

        Args:
            input_size: number of input features per row. Defaults to the
                notebook-level `feature_count`, which keeps `MultiModel()`
                working exactly as before while allowing the class to be
                built without that global.
        """
        super().__init__()
        if input_size is None:
            input_size = feature_count
        self.main_embed = torch.nn.Linear(input_size, self.ACCUM_SIZE)
        self.relu = torch.nn.ReLU()
        self.tanh = torch.nn.Tanh()
        # Creation order is kept (embedding, then the heads in this order) so
        # parameter names and initialization order match the original.
        self.white_main = self._make_head()
        self.black_main = self._make_head()
        self.white_duck = self._make_head()
        self.black_duck = self._make_head()

    def _make_head(self):
        """Return a fresh ACCUM_SIZE -> SIZE1 -> SIZE2 -> FINAL_SIZE MLP."""
        return torch.nn.Sequential(
            torch.nn.Linear(self.ACCUM_SIZE, self.SIZE1),
            torch.nn.ReLU(),
            torch.nn.Linear(self.SIZE1, self.SIZE2),
            torch.nn.ReLU(),
            torch.nn.Linear(self.SIZE2, self.FINAL_SIZE),
        )

    def forward(self, inputs, which_model):
        """Run all heads and keep, per sample, the one named by `which_model`.

        Args:
            inputs: tensor of feature rows, one row per sample.
            which_model: per-sample integer index (0..3) into the stack
                [white_main, black_main, white_duck, black_duck].

        Returns:
            (None, value): the policy slot is None; `value` is the tanh of
            the first output column of the selected head, one row per sample.
        """
        embedding = self.relu(self.main_embed(inputs))
        # Stacked shape is [4 heads, batch, FINAL_SIZE].
        data = torch.stack([
            self.white_main(embedding),
            self.black_main(embedding),
            self.white_duck(embedding),
            self.black_duck(embedding),
        ])
        # Advanced indexing: row i of the result is data[which_model[i], i].
        data = data[which_model, torch.arange(len(which_model))]
        #policy = data[:, :64]
        #value = data[:, 64:]
        value = data[:, :1]
        return None, self.tanh(value)
        #return policy, self.tanh(value)

model = MultiModel()

# Tensor.numel() counts the elements of each parameter directly;
# np.product was removed in NumPy 2.0.
print("Parameters:", sum(p.numel() for p in model.parameters()))

In [None]:
# Loss functions, optimizer and smoothed loss trackers for the training loop.
# cross_en is only referenced from a commented-out line in the training
# loop, because the policy head is currently disabled.
cross_en = torch.nn.CrossEntropyLoss()
mse_func = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# Exponentially smoothed losses, used only for the progress printout.
policy_loss_ewma = EWMA()
value_loss_ewma = EWMA()

In [None]:
# Override the learning rate of every parameter group in place. This
# replaces the lr=3e-4 given when the optimizer was constructed.
for param_group in optimizer.param_groups:
    param_group["lr"] = 1e-5

In [None]:
# Head index per validation row, matching the stacking order in
# MultiModel.forward: 0 = white_main, 1 = black_main, 2 = white_duck,
# 3 = black_duck.
dm_val_which_model = 2 * val_is_duck_move + (1 - val_is_white_turn)

def make_batch(batch_size):
    """Sample a random training batch (with replacement).

    Args:
        batch_size: number of rows to draw.

    Returns:
        (features, policy, value, which_model) on `device`: features and
        value as float32, policy and which_model as int64. `which_model` is
        the per-row head index 2 * is_duck_move + (1 - is_white_turn).
    """
    # Keep np.random as the source of randomness so the sampling stream is
    # unchanged; convert once so every lookup uses the same index tensor.
    indices = torch.from_numpy(
        np.random.randint(0, len(dm_features), size=batch_size)
    )
    # The dm_* objects are already tensors; wrapping them in torch.tensor()
    # copy-constructs and emits a UserWarning. Indexing already yields a
    # fresh tensor, so .to() performs the cast/move.
    features = dm_features[indices].to(dtype=torch.float32, device=device)
    policy = dm_policy[indices].to(dtype=torch.int64, device=device)
    value = dm_value[indices].to(dtype=torch.float32, device=device)
    which_model = (
        2 * is_duck_move[indices] + (1 - is_white_turn[indices])
    ).to(dtype=torch.int64, device=device)
    return features, policy, value, which_model

start_time = time.time()
for i in range(1_000_000):
    optimizer.zero_grad()
    features, target_policy, target_value, which_model = make_batch(512)
    policy_output, value_output = model(features, which_model)
    # The policy head is disabled (the model returns None for it), so the
    # policy loss is a constant zero placeholder that keeps the logging shape.
    #policy_loss = cross_en(policy_output, target_policy)
    policy_loss = torch.tensor(0)
    # NOTE(review): value_output has shape (batch, 1). If dm_value is 1-D,
    # MSELoss would broadcast to (batch, batch) — confirm dm_value is (N, 1).
    value_loss = mse_func(value_output, target_value)
    loss = policy_loss + value_loss
    loss.backward()
    optimizer.step()
    policy_loss_ewma.apply(policy_loss.item())
    value_loss_ewma.apply(value_loss.item())

    if i % 2500 == 0:
        # Run the validation forward pass without autograd: no gradients are
        # needed here, and tracking them over the whole validation set only
        # costs memory and time.
        with torch.no_grad():
            val_policy_output, val_value_output = model(dm_val_features, dm_val_which_model)
        # Accuracy is a placeholder while the policy head is disabled.
        #correct = val_policy_output.argmax(axis=-1) == dm_val_policy
        #accuracy = correct.mean(dtype=torch.float32).item()
        correct = 0
        accuracy = 0
        print("(%7.1f) [%7i] loss = %.4f (policy = %.4f  value = %0.4f) (val acc: %5.1f%%)" % (
            time.time() - start_time,
            i,
            policy_loss_ewma.value + value_loss_ewma.value,
            policy_loss_ewma.value,
            value_loss_ewma.value,
            100 * accuracy,
        ))

In [None]:
# Print min, max, mean and variance of every parameter tensor as a quick
# sanity check of the trained weights.
for param in model.parameters():
    stats = (param.min(), param.max(), param.mean(), param.var())
    print(*(stat.item() for stat in stats))

In [None]:
# Save only the parameters (the state_dict), not the module object itself.
torch.save(model.state_dict(), "multi-model-nonsense.pt")

In [None]:
#torch.save(model.state_dict(), "move_model.pt")

In [None]:
# List every entry of the state_dict with its tensor shape.
for param_name, tensor in model.state_dict().items():
    print(param_name, tensor.shape)

In [None]:
# MultiModel's state_dict has no "0.weight" entry: its keys are prefixed with
# the attribute names (e.g. "main_embed.weight", "white_main.0.weight"), so
# the old key raises KeyError. Read the first layer's weight matrix under its
# actual key.
W = model.state_dict()["main_embed.weight"].detach().cpu().numpy()

In [None]:
# Histogram of all entries of W, in 100 bins.
# NOTE(review): `plt` is not imported in any visible cell; this needs
# `import matplotlib.pyplot as plt` in the imports cell to run on a fresh kernel.
_ = plt.hist(W.flatten(), bins=100)

In [None]:
# Largest weight magnitude in W.
np.abs(W).max()

In [None]:
# Fixed-point copy of W with a scale of 1000; astype(np.int32) truncates
# toward zero rather than rounding.
Wi = (W * 1000).astype(np.int32)

In [None]:
# Display the integer-scaled weights.
Wi

In [None]:
# Standalone demo of the per-sample head selection used in MultiModel.forward.
# NOTE: `features` and `data` here overwrite any earlier notebook variables
# with the same names.
batch_size = 30
cases = 4
features = 64

# Generate some fake data.
data = torch.tensor(np.random.randn(cases, batch_size, features))
idx = torch.tensor(np.random.randint(low=0, high=cases, size=batch_size))

# Index into the data of shape [cases, batch_size, features], getting a result of shape [batch_size, features].
# This is the same as:
#   result = np.zeros((batch_size, features))
#   for i in range(batch_size):
#       result[i] = data[idx[i], i]
result = data[idx, torch.arange(batch_size)]


In [None]:
# One selected row of `features` values per batch element.
result.shape

In [None]:
# The randomly chosen case index for each batch element.
idx

In [None]:
# Row 0 of the result must equal the slice selected by idx[0]. The original
# hard-coded case 3, which only holds when the random idx[0] happens to be 3.
result[0] == data[idx[0], 0]