# In [None]:
import pickle
import tarfile
import chess.pgn
import io
from lcztools import LeelaBoard, load_network
from torch import optim, nn
import numpy as np
import torch

net = load_network()

# Fine-tune only the value head: a parameter is unfrozen when its name contains
# '_val' and does not contain 'conv1_bn.weight'; every other parameter is frozen.
# NOTE(review): the reason for excluding 'conv1_bn.weight' is not visible here.
for name, param in net.model.named_parameters():
    trainable = '_val' in name and 'conv1_bn.weight' not in name
    if trainable:
        print('+', name)
    param.requires_grad = trainable

# Give the optimizer only the unfrozen parameters. Frozen ones never receive
# gradients, so listing them only adds dead entries to the optimizer.
optimizer = optim.Adam(
    (p for p in net.model.parameters() if p.requires_grad),
    lr=0.0002,
    weight_decay=1e-5,
)
criterion = nn.MSELoss()

# Per-step losses. do_step appends one value per step and prints the mean
# every 100 entries.
losses = []

def do_step(training_game):
    if training_game[1] == '1/2-1/2':
        result = 0
    elif training_game[1] == '1-0':
        result = 1
    elif training_game[1] == '0-1':
        result = -1
    elif training_game[1] == '*':
        return
    features_stack = []
    results = []
    for compressed_features in training_game[0]:
        features_stack.append(LeelaBoard.decompress_features(compressed_features))
        results.append(result)
        result *= -1
    features_stack = np.stack(features_stack)
    
    optimizer.zero_grad()
    pols, vals = net.model(features_stack)
    results = torch.Tensor(results).view(-1, 1).cuda()
    loss = criterion(vals, results)
    losses.append(loss.item())
    if len(losses)==100:
        print(sum(losses)/len(losses))
        losses.clear()
    loss.backward()
    optimizer.step()    

# Train on every regular file in the archive: each one is a pickled game, and
# each game produces one do_step call.
# NOTE(review): pickle.load can execute arbitrary code, so only use a trusted
# training_data.tgz.
with tarfile.open('training_data.tgz') as archive:
    for member in archive:
        if not member.isfile():
            continue
        # Close each extracted member's file handle once it has been unpickled.
        with archive.extractfile(member) as game_file:
            training_game = pickle.load(game_file)
        do_step(training_game)

net.model.save_weights_file('/home/trevor/projects/lczero/weights/txt/test_weights_3.txt')