In [1]:
import sys
sys.path.append("./ml/")
import glob
import json
import struct
import numpy as np
import torch
import matplotlib.pyplot as plt

In [2]:
import train_nnue

In [3]:
model = train_nnue.Nnue()
model.load_state_dict(torch.load("nnue.pt"))
model.adjust_leak(0)

In [4]:
data_files = glob.glob("./run-011-duck-chess/step-*/games/*.npz")[-5:]
data_files

['./run-011-duck-chess/step-091/games/games-mcts-801d9b63846d6733-nnue-data.npz',
 './run-011-duck-chess/step-066/games/games-mcts-1060499d6c37ec40-nnue-data.npz',
 './run-011-duck-chess/step-066/games/games-mcts-14f2d4a43d9e5dd9-nnue-data.npz',
 './run-011-duck-chess/step-075/games/games-mcts-fb530bceb09373bc-nnue-data.npz',
 './run-011-duck-chess/step-075/games/games-mcts-86bf3b09ec0916c0-nnue-data.npz']

In [5]:
make_batch = train_nnue.get_make_batch(data_files, "cpu")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00,  3.66it/s]

Total examples: 3641764
Constant model loss: 0.44617455699141884





In [6]:
indices, offsets, which_model, lengths, targets = make_batch(1024 * 1024)

In [7]:
def estimate_loss():
    outputs = model(indices, offsets, which_model, lengths)
    return torch.mean((outputs - targets)**2)

In [8]:
estimate_loss()

tensor(0.3419, grad_fn=<MeanBackward0>)

In [9]:
outputs = model(indices, offsets, which_model, lengths)
lose = outputs[targets == -1]
draw = outputs[targets == 0]
win = outputs[targets == 1]

In [10]:
# Each clipped relu wants inputs from -128 to +127 for its active range.
# If the largest intermediates we care to represent are -2.0 to +2.0,
# then this means that -2.0 should map to -32768, and +1.99... should map to +32767.
# Therefore we scale down inputs by 128 before passing them in to the clipped relu.
# Therefore, 128 * 128 = 1684 represents 1.0 as an input to relu.
# This means that a quantized weight of 128 represents the weight 1.0 in the original.
# There is one exception to this, in the original embedding layer, and all biases,
# where 16384 represents 1.0.

In [11]:
new_values = {}
quantized_weights = {}
output_right_shift = {}
for k, v in model.named_parameters():
    output_right_shift[k] = 0
    if "main_embed" in k or "bias" in k:
        quantized = (v * 16384).to(torch.int16)
        f = quantized.float().detach() / 16384
        output_right_shift[k] = 14
    else:
        quantized = (v * 128).to(torch.int8)
        f = quantized.float().detach() / 128
        output_right_shift[k] = 7
    zero_fraction = (quantized == 0).sum() / v.numel()
    new_values[k] = f
    quantized_weights[k] = quantized
    print(f"{k:20} {str(tuple(v.shape)):15} {v.min().item():.3f} {v.max().item():.3f} zero={100 * zero_fraction:.3f}%")

main_bias            (256,)          -0.015 0.016 zero=0.391%
main_embed.weight    (106496, 256)   -0.361 0.388 zero=0.213%
networks.0.0.weight  (16, 256)       -0.229 0.214 zero=13.232%
networks.0.0.bias    (16,)           -0.094 0.067 zero=0.000%
networks.0.2.weight  (32, 16)        -0.267 0.265 zero=2.734%
networks.0.2.bias    (32,)           -0.228 0.242 zero=0.000%
networks.0.4.weight  (1, 32)         -0.121 0.147 zero=25.000%
networks.0.4.bias    (1,)            -0.084 -0.084 zero=0.000%
networks.1.0.weight  (16, 256)       -0.215 0.220 zero=12.939%
networks.1.0.bias    (16,)           -0.034 0.048 zero=0.000%
networks.1.2.weight  (32, 16)        -0.278 0.275 zero=3.320%
networks.1.2.bias    (32,)           -0.237 0.222 zero=0.000%
networks.1.4.weight  (1, 32)         -0.133 0.142 zero=15.625%
networks.1.4.bias    (1,)            0.014 0.014 zero=0.000%
networks.2.0.weight  (16, 256)       -0.216 0.275 zero=13.281%
networks.2.0.bias    (16,)           -0.035 0.062 zero=0.000%
net

In [12]:
model.load_state_dict(new_values)

<All keys matched successfully>

In [13]:
estimate_loss()

tensor(0.3420, grad_fn=<MeanBackward0>)

In [14]:
def pack_i32(i32):
    return struct.pack("<i", i32)

In [15]:
header_alloc = 7808
aligned_storage = bytearray(header_alloc)

def add_bytes(b):
    # Align to the nearest 32-byte boundary.
    padding = (32 - len(aligned_storage)) % 32
    aligned_storage.extend(b'\0' * padding)
    offset = len(aligned_storage)
    aligned_storage.extend(b)
    return offset

weights = {}
for k, v in quantized_weights.items():
    shift = output_right_shift[k]
    v = v.detach().cpu().numpy()
    k = k.replace("networks.", "n")
    k = k.replace("0.weight", "0.w")
    k = k.replace("2.weight", "1.w")
    k = k.replace("4.weight", "2.w")
    k = k.replace("0.bias", "0.b")
    k = k.replace("2.bias", "1.b")
    k = k.replace("4.bias", "2.b")
    offset = add_bytes(v.tobytes())
    assert offset % 32 == 0
    weights[k] = {
        "shape": tuple(v.shape),
        "dtype": {"int8": "i8", "int16": "i16"}[str(v.dtype)],
        "offset": offset,
        "shift": shift,
    }
message = {
    "version": "v1",
    "weights": weights,
}

message_bytes = json.dumps(message).encode()
assert len(message_bytes) < header_alloc
aligned_storage[:len(message_bytes)] = message_bytes
len(message_bytes)

7610

In [16]:
with open("src/nnue-data.bin", "wb") as f:
    f.write(aligned_storage)

In [17]:
message

{'version': 'v1',
 'weights': {'main_bias': {'shape': (256,),
   'dtype': 'i16',
   'offset': 7808,
   'shift': 14},
  'main_embed.weight': {'shape': (106496, 256),
   'dtype': 'i16',
   'offset': 8320,
   'shift': 14},
  'n0.0.w': {'shape': (16, 256),
   'dtype': 'i8',
   'offset': 54534272,
   'shift': 7},
  'n0.0.b': {'shape': (16,), 'dtype': 'i16', 'offset': 54538368, 'shift': 14},
  'n0.1.w': {'shape': (32, 16), 'dtype': 'i8', 'offset': 54538400, 'shift': 7},
  'n0.1.b': {'shape': (32,), 'dtype': 'i16', 'offset': 54538912, 'shift': 14},
  'n0.2.w': {'shape': (1, 32), 'dtype': 'i8', 'offset': 54538976, 'shift': 7},
  'n0.2.b': {'shape': (1,), 'dtype': 'i16', 'offset': 54539008, 'shift': 14},
  'n1.0.w': {'shape': (16, 256),
   'dtype': 'i8',
   'offset': 54539040,
   'shift': 7},
  'n1.0.b': {'shape': (16,), 'dtype': 'i16', 'offset': 54543136, 'shift': 14},
  'n1.1.w': {'shape': (32, 16), 'dtype': 'i8', 'offset': 54543168, 'shift': 7},
  'n1.1.b': {'shape': (32,), 'dtype': 'i16', '

In [34]:
import zlib
import bz2

In [36]:
x = zlib.compress(bytes(aligned_storage), level=9)

In [38]:
len(x) * 1e-6

44.387080999999995

In [43]:
x = bz2.compress(bytes(aligned_storage))

In [44]:
len(x) * 1e-6

40.349599999999995

In [46]:
len(aligned_storage) * 1e-6

54.815265999999994

In [20]:
b = bytearray()
b.extend(b"v1")
b.extend(pack_i32(len(quantized_weights)))

def 

for k, v in quantized_weights.items():
    

In [None]:
output = open("network.bin", "wb")
output.write("v1")

In [115]:
message

{'version': 'v1',
 'weights': {'main_bias': {'shape': (257,),
   'dtype': 'i16',
   'offset': 7808,
   'shift': 7},
  'main_embed.weight': {'shape': (106496, 257),
   'dtype': 'i16',
   'offset': 8352,
   'shift': 7},
  'n0.0.w': {'shape': (16, 256),
   'dtype': 'i8',
   'offset': 54747296,
   'shift': 0},
  'n0.0.b': {'shape': (16,), 'dtype': 'i16', 'offset': 54751392, 'shift': 7},
  'n0.1.w': {'shape': (32, 16), 'dtype': 'i8', 'offset': 54751424, 'shift': 0},
  'n0.1.b': {'shape': (32,), 'dtype': 'i16', 'offset': 54751936, 'shift': 7},
  'n0.2.w': {'shape': (1, 32), 'dtype': 'i8', 'offset': 54752000, 'shift': 0},
  'n0.2.b': {'shape': (1,), 'dtype': 'i16', 'offset': 54752032, 'shift': 7},
  'n1.0.w': {'shape': (16, 256),
   'dtype': 'i8',
   'offset': 54752064,
   'shift': 0},
  'n1.0.b': {'shape': (16,), 'dtype': 'i16', 'offset': 54756160, 'shift': 7},
  'n1.1.w': {'shape': (32, 16), 'dtype': 'i8', 'offset': 54756192, 'shift': 0},
  'n1.1.b': {'shape': (32,), 'dtype': 'i16', 'offset

In [16]:
for k, v in quantized_weights.items():
    v = v.detach().cpu().numpy()
    #print(k, v.shape, v.dtype)

main_bias (257,) int16
main_embed.weight (106496, 257) int16
networks.0.0.weight (16, 256) int8
networks.0.0.bias (16,) int16
networks.0.2.weight (32, 16) int8
networks.0.2.bias (32,) int16
networks.0.4.weight (1, 32) int8
networks.0.4.bias (1,) int16
networks.1.0.weight (16, 256) int8
networks.1.0.bias (16,) int16
networks.1.2.weight (32, 16) int8
networks.1.2.bias (32,) int16
networks.1.4.weight (1, 32) int8
networks.1.4.bias (1,) int16
networks.2.0.weight (16, 256) int8
networks.2.0.bias (16,) int16
networks.2.2.weight (32, 16) int8
networks.2.2.bias (32,) int16
networks.2.4.weight (1, 32) int8
networks.2.4.bias (1,) int16
networks.3.0.weight (16, 256) int8
networks.3.0.bias (16,) int16
networks.3.2.weight (32, 16) int8
networks.3.2.bias (32,) int16
networks.3.4.weight (1, 32) int8
networks.3.4.bias (1,) int16
networks.4.0.weight (16, 256) int8
networks.4.0.bias (16,) int16
networks.4.2.weight (32, 16) int8
networks.4.2.bias (32,) int16
networks.4.4.weight (1, 32) int8
networks.4.4.

In [None]:
output.close()