In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt

In [2]:
# Load the raw checkpoint: a mapping of parameter name -> tensor that the export cell
# below turns into Rust tables. map_location="cpu" lets a checkpoint saved from a GPU
# load on a CPU-only machine; the export only needs CPU tensors.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
model = torch.load("multi-model-nonsense.pt", map_location="cpu")

In [3]:
# List the parameter names stored in the loaded checkpoint.
model.keys()

odict_keys(['main_embed.weight', 'main_embed.bias', 'white_main.0.weight', 'white_main.0.bias', 'white_main.2.weight', 'white_main.2.bias', 'white_main.4.weight', 'white_main.4.bias', 'black_main.0.weight', 'black_main.0.bias', 'black_main.2.weight', 'black_main.2.bias', 'black_main.4.weight', 'black_main.4.bias', 'white_duck.0.weight', 'white_duck.0.bias', 'white_duck.2.weight', 'white_duck.2.bias', 'white_duck.4.weight', 'white_duck.4.bias', 'black_duck.0.weight', 'black_duck.0.bias', 'black_duck.2.weight', 'black_duck.2.bias', 'black_duck.4.weight', 'black_duck.4.bias'])

In [4]:
# Quantization scales: float parameters are multiplied by these and rounded to the
# nearest integer before being written out as Rust arrays. The same values are exported
# as `INTEGER_SCALE_I8` / `INTEGER_SCALE_I16` in the generated file.
integer_scale_i8 = 150
integer_scale_i16 = 5_000

# Inclusive ranges of the Rust integer types emitted below. rustc rejects out-of-range
# integer literals, so check here and fail with a message naming the offending tensor.
RUST_INT_RANGES = {"i8": (-128, 127), "i16": (-32_768, 32_767)}

# The checkpoint numbers its layers 0, 2, 4 (see the keys listed above); the generated
# Rust names use 0, 1, 2. FIXME: This renaming should be part of the model itself.
LAYER_RENAMES = {"2": "1", "4": "2"}


def rust_param_name(weight_name):
    """Return the Rust constant name for a state-dict key.

    E.g. "white_main.4.bias" -> "PARAMS_WHITE_MAIN_2_BIAS". Only dotted components that
    are purely numeric (layer indices) are renumbered via LAYER_RENAMES, so digits that
    happen to appear inside module names are left untouched.
    """
    parts = [
        LAYER_RENAMES.get(part, part) if part.isdigit() else part
        for part in weight_name.split(".")
    ]
    return "PARAMS_" + "_".join(parts).upper()


def quantize(weight, scale, type_name, weight_name):
    """Scale `weight`, round to the nearest integer and return it as an int32 tensor.

    Raises:
        ValueError: if any quantized value falls outside the range of Rust `type_name`.
    """
    # Round to nearest rather than truncating toward zero: this halves the worst-case
    # quantization error and avoids a systematic bias toward zero.
    quantized = torch.round(scale * weight).to(torch.int32)
    low, high = RUST_INT_RANGES[type_name]
    if quantized.numel() and (int(quantized.min()) < low or int(quantized.max()) > high):
        raise ValueError(
            f"{weight_name}: quantized values span [{int(quantized.min())}, "
            f"{int(quantized.max())}], outside the {type_name} range [{low}, {high}]; "
            f"reduce the quantization scale."
        )
    return quantized


def rust_aligned_static(name, rust_type, body):
    """Return Rust source for `pub static NAME: &'static TYPE` backed by aligned storage.

    The data lives in a private `DUMMY_NAME` static wrapped in `SixteenByteAligned`;
    the public static is a reference to its `data` field.
    """
    return [
        f"pub static {name}: &'static {rust_type} = &DUMMY_{name}.data;\n",
        f"static DUMMY_{name}: SixteenByteAligned<{rust_type}>"
        f" = SixteenByteAligned {{\n  _align: [],\n  data: [\n  {body}\n  ],\n}};\n",
    ]


code = [
    "// Automatically generated by ExtractTables.ipynb\n\n"
    "pub const INTEGER_SCALE_I8: f32 = %.1f;\n" % integer_scale_i8,
    "pub const INTEGER_SCALE_I16: f32 = %.1f;\n" % integer_scale_i16,
    # `align(16)` guarantees the 16-byte alignment the name promises; the zero-length
    # `[u128; 0]` field alone only does so where u128 itself is 16-byte aligned.
    "#[repr(C, align(16))]\nstruct SixteenByteAligned<T> {\n  _align: [u128; 0],\n  data: T,\n}\n",
]

for weight_name, weight in model.items():
    print("  Weight:", weight_name, weight.shape)
    name = rust_param_name(weight_name)
    if weight_name == "main_embed.weight":
        # Transpose only the very first weight.
        weight = weight.T
    if "bias" in weight_name:
        # Biases: 1-D tensors emitted as `[i16; N]`.
        if weight.dim() != 1:
            raise ValueError(
                f"{weight_name}: expected a 1-D bias, got shape {tuple(weight.shape)}"
            )
        values = quantize(weight, integer_scale_i16, "i16", weight_name).tolist()
        rust_type = f"[i16; {len(values)}]"
        body = " ".join(f"{value}," for value in values)
    else:
        # Weights: 2-D tensors emitted row by row as `[[i8; cols]; rows]`.
        if weight.dim() != 2:
            raise ValueError(
                f"{weight_name}: expected a 2-D weight, got shape {tuple(weight.shape)}"
            )
        rows = quantize(weight, integer_scale_i8, "i8", weight_name).tolist()
        rust_type = f"[[i8; {weight.shape[1]}]; {weight.shape[0]}]"
        body = "\n".join(
            "  [" + ", ".join(str(value) for value in row) + "],"
            for row in rows
        )
    code.extend(rust_aligned_static(name, rust_type, body))

  Weight: main_embed.weight torch.Size([256, 960])
  Weight: main_embed.bias torch.Size([256])
  Weight: white_main.0.weight torch.Size([16, 256])
  Weight: white_main.0.bias torch.Size([16])
  Weight: white_main.2.weight torch.Size([32, 16])
  Weight: white_main.2.bias torch.Size([32])
  Weight: white_main.4.weight torch.Size([1, 32])
  Weight: white_main.4.bias torch.Size([1])
  Weight: black_main.0.weight torch.Size([16, 256])
  Weight: black_main.0.bias torch.Size([16])
  Weight: black_main.2.weight torch.Size([32, 16])
  Weight: black_main.2.bias torch.Size([32])
  Weight: black_main.4.weight torch.Size([1, 32])
  Weight: black_main.4.bias torch.Size([1])
  Weight: white_duck.0.weight torch.Size([16, 256])
  Weight: white_duck.0.bias torch.Size([16])
  Weight: white_duck.2.weight torch.Size([32, 16])
  Weight: white_duck.2.bias torch.Size([32])
  Weight: white_duck.4.weight torch.Size([1, 32])
  Weight: white_duck.4.bias torch.Size([1])
  Weight: black_duck.0.weight torch.Size([16

In [5]:
# Write the generated Rust source: the `code` chunks joined by newlines.
rust_source = "\n".join(code)
with open("src/nnue_data.rs", "w") as rust_file:
    rust_file.write(rust_source)

# Debug incorrect outputs

In [6]:
def patch_up_features(feats):
    """Convert raw [batch, 15, 8, 8] planes into flat [batch, 13 * 8 * 8] model inputs.

    Channel 13 (index -2) at square (0, 0) is read as the side-to-move flag; any nonzero
    value selects the white-to-move layout. The last two channels are dropped and the
    remaining 13 are rearranged so the output is always (black pieces, white pieces)
    with white at the bottom:

    * white to move: the 6-channel halves 0-5 and 6-11 are swapped, channel 12 kept;
    * black to move: the rank axis (axis 2) is flipped.

    Raises:
        AssertionError: if the trailing dimensions are not (15, 8, 8).
    """
    _, channels, ranks, files = feats.shape
    assert (channels, ranks, files) == (15, 8, 8)

    white_to_move = feats[:, -2, 0, 0].reshape(-1, 1, 1, 1)
    pieces = feats[:, :-2]

    # Black to move: planes are already (black, white) but drawn from black's side,
    # so flip the rank axis to put white at the bottom.
    black_view = pieces[:, :, ::-1, :]

    # White to move: planes are (white, black) from white's side; swap the two halves.
    white_view = np.concatenate(
        [pieces[:, 6:12], pieces[:, :6], pieces[:, 12:]], axis=1
    )

    selected = np.where(white_to_move, white_view, black_view)
    return selected.reshape(-1, 13 * 8 * 8)

feature_count = 13 * 64

class MultiModel(torch.nn.Module):
    FEATS = 64

    def __init__(self):
        super().__init__()
        self.main_embed = torch.nn.Linear(feature_count, self.FEATS)
        self.relu = torch.nn.ReLU()
        self.tanh = torch.nn.Tanh()
        self.white_main = torch.nn.Linear(self.FEATS, 64 + 1)
        self.black_main = torch.nn.Linear(self.FEATS, 64 + 1)
        self.white_duck = torch.nn.Linear(self.FEATS, 64 + 1)
        self.black_duck = torch.nn.Linear(self.FEATS, 64 + 1)

    def forward(self, inputs, which_model):
        embedding = self.main_embed(inputs)
        embedding = self.relu(embedding)
        white_main = self.white_main(embedding)
        black_main = self.black_main(embedding)
        white_duck = self.white_duck(embedding)
        black_duck = self.black_duck(embedding)
        data = torch.stack([white_main, black_main, white_duck, black_duck])
        data = data[which_model, torch.arange(len(which_model))]
        policy = data[:, :64]
        value = data[:, 64:]
        return policy, value, embedding

model = MultiModel()

In [7]:
# Load trained weights into the MultiModel instance. The model's parameters are on CPU,
# so map the checkpoint there too; a GPU-saved file then also loads on CPU-only hosts.
model.load_state_dict(torch.load("multi-model-feat64-002.pt", map_location="cpu"))

FileNotFoundError: [Errno 2] No such file or directory: 'multi-model-feat64-002.pt'

In [None]:
# Validation data archive; the next cell reads its "features" array.
dm_val = np.load("dm_val.npz")

In [None]:
# Validation positions as flat model inputs (one row of 13 * 8 * 8 values each).
raw_val_planes = dm_val["features"].reshape(-1, 15, 8, 8)
f = patch_up_features(raw_val_planes)

In [None]:
# First validation position, viewed as 13 planes of 8x8.
f[0].reshape(13, 8, 8)

In [None]:
# Independent copy of the first position, so the manual edits below leave `f` intact.
starting_feat = f[0].reshape(13, 8, 8).copy()

In [None]:
# Inspect the position before the manual edits made in the next cell.
starting_feat

In [None]:
# Hand-edit plane 6: set square (1, 1) and clear square (2, 1).
starting_feat[6, 1:3, 1] = (1, 0)

In [None]:
# Sanity check: should still be (13, 8, 8).
starting_feat.shape

In [None]:
# Run the single edited position through the model using head index 0.
probe_inputs = torch.tensor(starting_feat.reshape(1, -1), dtype=torch.float32)
probe_heads = torch.tensor([0])
a, b, c = model(probe_inputs, probe_heads)  # a: policy, b: value, c: embedding

In [None]:
# Value output for the probed position.
b

In [None]:
# Policy output (64 values) for the probed position, drawn as an 8x8 grid.
plt.matshow(a.reshape(8, 8).detach().cpu().numpy())
plt.colorbar()
plt.title("Policy output (head 0)")
plt.show()

In [None]:
# Embedding-layer linear state without bias (W @ x) for the probed position, for
# comparison with `initial_linear_state`. It is computed directly rather than as
# `embedding - bias`: the model returns the post-ReLU embedding, so subtracting the bias
# gives -bias (not W @ x) wherever the pre-activation was negative.
with torch.no_grad():
    probe = torch.tensor(starting_feat.reshape(1, -1), dtype=torch.float32)
    target_linear_state = (probe @ model.main_embed.weight.T).cpu().numpy()

In [None]:
# Target linear state (64 values) drawn as an 8x8 grid.
plt.matshow(target_linear_state.reshape(8, 8))
plt.colorbar()
plt.title("target_linear_state")
plt.show()

In [None]:
# Reference linear state drawn as an 8x8 grid, for side-by-side comparison with the target.
# NOTE(review): `initial_linear_state` is assigned in a later cell of this notebook; run
# that cell first, otherwise this raises NameError on a fresh kernel.
plt.matshow(initial_linear_state.reshape(8, 8))
plt.colorbar()
plt.title("initial_linear_state")
plt.show()

In [None]:
# Elementwise difference between `target_linear_state` and `initial_linear_state`.
# NOTE(review): `initial_linear_state` is assigned in a later cell of this notebook; run
# that cell first, otherwise this raises NameError on a fresh kernel.
target_linear_state - initial_linear_state

In [None]:
# Hard-coded integer linear-state values (presumably dumped from the integer engine —
# TODO confirm source), converted to float and divided by a quantization scale so they
# can be compared with `target_linear_state`.
# An earlier dump, kept for reference:
#initial_linear_state = np.array([120001, 383580, 137398, 283756, 509098, 558127, -12455, 80365, 53778, 28397, -212367, 404435, 60373, 181293, 335118, -277217, 60015, 86152, 57077, -66096, -178402, -244652, 218425, 148898, -4227, -96644, -260137, 223537, 9941, 265566, 22002, 113431, -436525, 49648, -158136, -279139, 16615, 292464, -995, 141168, 28381, 177354, 466708, 105724, 8724, 162147, 426389, -41261, 114401, 169764, -323398, 50508, -32995, 384948, -163998, 207688, 111465, 23619, 168553, 257046, -323129, -294584, 333272, -362578])
initial_linear_state = np.array([208542, 74304, 103035, 177503, 99221, 99967, 49951, 122207, 126187, 64259, 159768, 151716, 140049, 161754, 193359, 5650, 123477, 134962, 280206, 61696, 169496, 112387, 147432, 49673, 58197, 83177, 200995, 156319, 83740, 49658, 178320, 140592, 29850, 120040, 86116, 102971, 214913, 83073, 147879, 166541, 11799, 143009, 47335, 123813, 119054, 147022, 141939, 145978, 265811, 151486, 73656, 147084, 177416, 142320, 194774, 206018, 196784, 89110, 148636, 182453, 128863, 137679, 165823, 47865])
# Convert to float before dividing so the in-place division below is not integer-typed.
initial_linear_state = initial_linear_state.astype(np.float64)
# FIXME(review): `integer_scale` is not defined anywhere in this notebook (the export cell
# defines only `integer_scale_i8` and `integer_scale_i16`), so this raises NameError on a
# fresh kernel. Confirm which scale these dumped values use and reference it explicitly.
initial_linear_state /= integer_scale

In [None]:
# Display the rescaled reference linear state.
initial_linear_state