## Validate the Rust NNUE implementation

We load the predictions produced by `validate_nnue.rs` and check that they match the outputs of the PyTorch model.

In [None]:
import sys
# Make the training code importable; `train_nnue` is presumably located in ./ml/
# (path is relative to the notebook's working directory).
sys.path.append("./ml/")
import numpy as np
import torch
import matplotlib.pyplot as plt
import train_nnue

In [None]:
# Rebuild the PyTorch NNUE and load the trained weights from nnue.pt.
# map_location="cpu" lets a checkpoint saved on another device (e.g. a GPU) load on
# any machine; the rest of this notebook evaluates the model on CPU.
model = train_nnue.Nnue()
model.load_state_dict(torch.load("nnue.pt", map_location="cpu"))
# NOTE(review): adjust_leak is defined in train_nnue; presumably 0 disables a leaky
# activation slope so the model matches the Rust inference path — confirm.
model.adjust_leak(0)

In [None]:
# Predictions written by the Rust side (validate_nnue.rs). The archive is kept
# open because later cells index into it repeatedly.
archive = np.load(file="rust-nnue-predictions.npz")

In [None]:
# Value outputs computed by the Rust implementation (presumably one per position).
rust_predictions = archive["values"]

In [None]:
# Smallest entry in the Rust policy array, as a quick range sanity check.
np.min(archive["policies"])

In [None]:
# Heat-map entry [0, 0] of the Rust policy array as an 8x8 board
# (presumably position 0, first policy head — confirm axis order).
plt.matshow(archive["policies"][0, 0].reshape(8, 8))

In [None]:
# Game-data archive whose positions are fed to the PyTorch model below.
# NOTE(review): the element-wise comparison with `rust_predictions` assumes
# validate_nnue.rs evaluated this same file in the same order — confirm.
data_file = "run-011-duck-chess/step-100/games/games-mcts-24d42a799adffcfc-nnue-data.npz"

In [None]:
# Batch loader over `data_file`; the "cpu" argument presumably selects the device
# the produced tensors live on.
make_batch = train_nnue.get_make_batch([data_file], "cpu")

In [None]:
# Take 1000 positions without shuffling (randomize=False), presumably in file order
# so they line up with the Rust predictions, and evaluate the PyTorch model on them.
(
    indices, offsets, which_model, lengths, value_for_white,
    moves_from, moves_to, legal_move_masks, have_quiescence_moves,
) = make_batch(1000, randomize=False)
# Inference only: no_grad avoids building an autograd graph for the whole batch.
with torch.no_grad():
    value_output, policy_from, policy_to = model(indices, offsets, which_model, lengths)
# Optional check against the training target:
#   torch.nn.MSELoss()(value_output, value_for_white)


In [None]:
# Redundant with the first cell's import; harmless, and lets this cell run on its own.
import matplotlib.pyplot as plt

In [None]:
# Dimensions of the PyTorch `policy_from` output.
policy_from.size()

In [None]:
# Heat-map the PyTorch `policy_from` output of the first position as an 8x8 board.
plt.matshow(policy_from[0].detach().reshape(8, 8).numpy())

In [None]:
# Raw values of the same 8x8 PyTorch policy, for a numeric comparison with Rust.
policy_from[0].reshape((8, 8))

In [None]:
# Rust policy entry divided by 2**13 so it is on the same scale as the PyTorch values.
# NOTE(review): 2**13 is presumably the Rust fixed-point scale — confirm in validate_nnue.rs.
archive["policies"][0, 0].reshape(8, 8) / (1 << 13)

In [None]:
# Overlay the first 1000 value predictions from Rust and PyTorch, plus their difference.
# Both sides are flattened so `a - b` is element-wise: if the Rust array were shaped
# (N, 1), subtracting the 1-D `b` would silently broadcast to a 1000x1000 matrix.
a = np.ravel(rust_predictions)[:1000]
b = value_output.flatten().detach().cpu().numpy()[:1000]
plt.plot(a)
plt.plot(b)
plt.plot(a - b)
plt.legend(["Rust", "PyTorch", "Diff"])

In [None]:
# Relative error per position. The denominator uses magnitudes, floored at 0.5 so
# near-zero values don't blow up the ratio. With signed maxima, large negative
# predictions (e.g. around -50) would be divided by the 0.5 floor and show up as
# spurious huge errors.
plt.plot(np.abs(a - b) / np.maximum(0.5, np.maximum(np.abs(a), np.abs(b))))

In [None]:
# Hard-coded value outputs for 32 positions (presumably the same positions in each
# list), used for a quick side-by-side comparison.
# `rust`: agrees closely with `python` except at indices 19, 25, 26 and 31, where
# `python` is strongly negative and `rust` is positive (possibly an overflow — confirm).
rust = [-0.026367188, 0.6855469, -10.581055, 0.99365234, -0.06738281, 0.107910156, 0.14697266, 0.7441406, 0.91308594, 0.92822266, 0.7373047, -0.49902344, 1.3066406, -0.011230469, 0.8745117, 1.0239258, -0.1484375, 0.03173828, 1.0058594, 12.312012, -0.08203125, 0.08984375, 0.9951172, 0.7138672, -0.18603516, 2.4033203, 4.1411133, 0.9560547, -2.152832, 0.05908203, -0.06933594, 3.053711]
# `rust2`: presumably a later Rust build; tracks `python` to within about 0.1 everywhere.
rust2 = [-0.04296875, 0.6953125, -10.503906, 0.95703125, -0.06640625, 0.09765625, 0.1171875, 0.73828125, 0.87890625, 0.91015625, 0.73046875, -0.51953125, 1.2851563, 0.00390625, 0.84765625, 0.98828125, -0.1171875, 0.0234375, 0.9921875, -51.597656, -0.08203125, 0.078125, 1.0039063, 0.6953125, -0.1875, -29.523438, -27.777344, 0.953125, -2.1445313, 0.04296875, -0.0546875, -28.875, ]

# `python`: the corresponding PyTorch outputs.
python = [-0.02199588716030121, 0.6845449805259705, -10.594429016113281, 0.9983289837837219, -0.07030828297138214, 0.11041910946369171, 0.14891983568668365, 0.7469731569290161, 0.9151668548583984, 0.9316335916519165, 0.7408483624458313, -0.49524450302124023, 1.3086481094360352, -0.010220184922218323, 0.8767296671867371, 1.0258991718292236, -0.1514769196510315, 0.03474739193916321, 1.0084795951843262, -51.69867706298828, -0.08426901698112488, 0.09226831793785095, 0.9956939220428467, 0.7162597179412842, -0.18898600339889526, -29.60687255859375, -27.870325088500977, 0.9553431272506714, -2.1533918380737305, 0.06028661131858826, -0.0694335401058197, -28.958084106445312]
    

In [None]:
# Compare the `rust2` outputs against PyTorch, labeled so the two lines are
# distinguishable (consistent with the `rust` vs `python` plot).
plt.plot(rust2)
plt.plot(python)
plt.legend(["Rust (rust2)", "Python"])

In [None]:
# Materialize (name, parameter) pairs so they can be indexed by position.
named_params = [*model.named_parameters()]

In [None]:
# Name of the second registered parameter (the tensor picked up as `mw`).
named_params[1][0]

In [None]:
# First two parameter tensors, picked by registration order.
# NOTE(review): the names suggest mb = bias and mw = weight, but that depends on the
# order parameters are registered in train_nnue.Nnue — check named_params[i][0].
mb = named_params[0][1]
mw = named_params[1][1]

In [None]:
# Dimensions of `mw`.
mw.size()

In [None]:
# Per-column sum of the positive entries of `mw` in columns 16..47.
# NOTE(review): presumably an upper bound on an activation range — confirm intent.
mw[:, 16:48].clamp(min=0).sum(dim=0).tolist()

In [None]:
# Per-column sum of the negative entries of `mw` in columns 16..47.
# NOTE(review): presumably a lower bound on an activation range — confirm intent.
mw[:, 16:48].clamp(max=0).sum(dim=0).tolist()

In [None]:
# Print name, shape, min, max and the largest L1 norm over the last dimension for
# every parameter whose name contains "policy".
for k, v in model.named_parameters():
    if "policy" not in k:
        continue
    l1 = v.abs().sum(dim=-1).max().item()
    print(f"{k:>35} {str(tuple(v.shape)):>10} {v.min().item():.3f} {v.max().item():.3f} l1={l1:.3f}")
        

In [None]:
# Overlay the `rust` outputs on the PyTorch outputs.
plt.plot(rust, label="Rust")
plt.plot(python, label="Python")
plt.legend()

## Examine the training data

In [None]:
# Redundant with the first cell's import; kept so this section can run on its own.
import numpy as np

In [None]:
# IPython shell escape: list the NNUE data archives for every step of run-011.
!ls run-011-duck-chess/step-*/games/games-mcts-*-nnue*

In [None]:
# Archive to inspect (note: rebinds the `data_file` used in the validation section).
data_file = "run-011-duck-chess/step-001/games/games-mcts-9a3432a1d5657e13-nnue-data.npz"

In [None]:
# Open the archive; `d` is kept open because the cells below read several arrays from it.
d = np.load(data_file)

In [None]:
# Names of the arrays stored in the archive.
list(d.files)

In [None]:
# Dimensions of the per-position metadata array.
np.shape(d["meta"])

In [None]:
# First metadata row.
d["meta"][0]

In [None]:
# Duplicate of the previous cell (first metadata row); safe to delete.
d["meta"][0]

In [None]:
# Render metadata column 5 of the first 1000 rows as a string: '.' for 0, '@' for 1
# (any other value raises IndexError). NOTE(review): column 5's meaning is not shown here.
"".join([".@"[flag] for flag in d["meta"][:1000, 5]])

In [None]:
# Dimensions of the recorded-moves array.
np.shape(d["moves"])

In [None]:
# Dimensions of the legal-move-mask array.
np.shape(d["legal_move_masks"])

In [None]:
# Short aliases used by the legality checks below.
ms = d["moves"]
lmm = d["legal_move_masks"]

In [None]:
# True iff, for every row, mask plane 0 is set at the index given by ms[:, 0].
# NOTE(review): presumably checks that each recorded 'from' square is legal — confirm.
np.all(lmm[np.arange(ms.shape[0]), 0, ms[:, 0]])

In [None]:
# True iff, for every row, mask plane 1 is set at the index given by ms[:, 1].
# NOTE(review): presumably checks that each recorded 'to' square is legal — confirm.
np.all(lmm[np.arange(ms.shape[0]), 1, ms[:, 1]])

In [None]:
# Legal-move masks of row 3, for visual inspection.
d["legal_move_masks"][3]