In [None]:
import json, io
import numpy as np
import torch
import torch.nn as nn
!pip install chess
!pip install pytorch_optimizer
import chess
import re
import random
import pickle

# Mount Google Drive: every dataset, checkpoint and loss log used by this
# notebook lives under /content/drive/MyDrive/parrot/.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# configuring device: TPU (via torch_xla) if usable, else the first CUDA GPU, else CPU.
try:
    # torch_xla must be imported here; without the import `xm` is an undefined
    # name and the TPU branch can never succeed.
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    print("Running on the TPU")
except Exception:
    # Best-effort probe: torch_xla missing or no TPU attached. `Exception`
    # (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
        print('Running on the GPU')
        torch.cuda.synchronize()
    else:
        device = torch.device('cpu')
        print('Running on the CPU')

Mounted at /content/drive
Running on the GPU


In [None]:
LONG_INTERVAL = 200   # epochs between numbered checkpoints `{model_name}_{epoch}.pickle`
SHORT_INTERVAL = 10   # epochs between overwrites of the rolling checkpoint `{model_name}.pickle`

EVAL_SET_SIZE = 65536    # positions per evaluation-database file
PUZZLE_SET_SIZE = 4096   # positions per puzzle-database file

# Stages (curriculum). The percentages are the planned training mix of each stage.
# NOTE(review): only POSITIONAL has a mix implemented in this notebook, and it uses the
# DRAW/ADV/WIN categories: 3.125% easy, 6.25% mid, 6.25% hard puzzles,
# 50% draws, 25% advantageous, 9.375% winning positions.
EASY_PUZZLES = 0 # 75% easy puzzles, 9.375% mates, 3.125% openings, 6.25% midgames, 6.25% endgames
MID_PUZZLES = 1 # 12.5% easy puzzles, 62.5% mid puzzles, 6.25% mates, 3.125% openings, 9.375% midgames, 6.25% endgames
MID_CONSOLIDATION = 2 # 25% easy puzzles, 25% mid puzzles, 6.25% mates, 15.625% openings, 15.625% midgames, 12.5% endgames
HARD_PUZZLES = 3 # 6.25% easy puzzles, 9.375% mid puzzles, 50% hard puzzles, 12.5% mates, 6.25% openings, 9.375% midgames, 6.25% endgames
POSITIONAL = 4 # (original plan) 3.125% easy puzzles, 6.25% mid puzzles, 9.375% hard puzzles, 6.25% mates, 25% openings, 25% midgames, 25% endgames

# Evaluation-database categories (also used as DataLoader ids):
DRAW = 5
ADV = 6
WIN = 7

# Total (rough, assuming equal training time):
# Easy puzzles: 22.4%, Mid puzzles: 19.3%, Hard puzzles: 12.0%, Mates: 7.8%, Openings: 10.9%, Midgames: 14.1%, Endgames: 13.5%

STAGE = POSITIONAL

# Every stage uses the same epoch and validation sizes; only the batch size differs.
EPOCH_SIZE = 65536
VAL_SIZE = 131072
if STAGE in (EASY_PUZZLES, MID_PUZZLES, HARD_PUZZLES):
    BATCH_SIZE = 4096
elif STAGE in (MID_CONSOLIDATION, POSITIONAL):
    BATCH_SIZE = 2048
else:
    # Without this, BATCH_SIZE would be undefined and fail much later.
    raise ValueError(f"Unknown STAGE: {STAGE}")

# The training and validation loops only process whole batches, so any
# remainder would be silently skipped.
assert EPOCH_SIZE % BATCH_SIZE == 0 and VAL_SIZE % BATCH_SIZE == 0, "EPOCH_SIZE and VAL_SIZE must be multiples of BATCH_SIZE"

# NOTE(review): the training cell re-assigns `checkpoint`, `learning_rate` and `l2r`,
# overriding the values below.
learning_rate = 0.001
l2r = 0
checkpoint = 0
save = True
model_name = "complex2_parrot"

TIME_CHECK = False

In [None]:
class UltraSimpleModel(nn.Module):
    """Minimal evaluator: one board-sized convolution plus extra features.

    The conv stack is Conv2d(1, 64, kernel 8) -> BatchNorm -> LeakyReLU -> Flatten.
    Its output is column-stacked with `feat` and fed to Linear(76, 1) + Sigmoid.
    Assuming (N, 1, 8, 8) boards (TODO confirm), the conv yields 64 features, so
    `feat` must contribute the remaining 12 columns.
    """

    def __init__(self, name):
        super().__init__()

        self.name = name

        # Module names become state_dict keys; they must not change.
        self.conv_net = nn.Sequential()
        for label, layer in (
            ("Conv 1", nn.Conv2d(1, 64, 8, 1)),
            ("Conv batchnorm 1", nn.BatchNorm2d(64, momentum=0.2)),
            ("Conv activation 1", nn.LeakyReLU()),
            ("Flattener", nn.Flatten()),
        ):
            self.conv_net.add_module(label, layer)

        self.mlp = nn.Sequential()
        for label, layer in (
            ("Layer 1", nn.Linear(76, 1)),
            ("Activation 1", nn.Sigmoid()),
        ):
            self.mlp.add_module(label, layer)

    def forward(self, x, feat):
        """Score boards `x` together with per-sample extra features `feat`."""
        combined = torch.column_stack((self.conv_net.forward(x), feat))
        return self.mlp.forward(combined)

    def count_parameters(self):
        """Return the number of trainable parameters."""
        trainable = (p for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in trainable)

class SimpleModel(nn.Module):
    """Four unpadded convolutions followed by a three-layer MLP with sigmoid output.

    Kernels 2, 2, 3 and 4 shrink an 8x8 board to 7x7, 6x6, 4x4 and finally 1x1.
    That leaves 700 flattened features, matching the 700 inputs of "Layer 1".
    Assumes (N, 1, 8, 8) input — TODO confirm.
    """

    def __init__(self, name):
        super().__init__()

        self.name = name

        # Module names become state_dict keys; they must not change.
        self.conv_net = nn.Sequential()
        for label, layer in (
            ("Conv 1", nn.Conv2d(1, 100, 2, 1)),
            ("Batchnorm 1", nn.BatchNorm2d(100)),
            ("Conv activation", nn.LeakyReLU()),
            ("Conv 2", nn.Conv2d(100, 250, 2, 1)),
            ("Batchnorm 2", nn.BatchNorm2d(250)),
            ("Conv activation 2", nn.LeakyReLU()),
            ("Conv 3", nn.Conv2d(250, 400, 3, 1)),
            ("Batchnorm 3", nn.BatchNorm2d(400)),
            ("Conv activation 3", nn.LeakyReLU()),
            ("Conv 4", nn.Conv2d(400, 700, 4, 1)),
            ("Batchnorm 4", nn.BatchNorm2d(700)),
            ("Conv activation 4", nn.LeakyReLU()),
            ("Flattener", nn.Flatten()),
        ):
            self.conv_net.add_module(label, layer)

        self.mlp = nn.Sequential()
        for label, layer in (
            ("Layer 1", nn.Linear(700, 500)),
            ("Batchnorm 1", nn.BatchNorm1d(500)),
            ("Activation 1", nn.LeakyReLU()),
            ("Layer 2", nn.Linear(500, 250)),
            ("Activation 2", nn.LeakyReLU()),
            ("Layer 3", nn.Linear(250, 1)),
            ("Activation 3", nn.Sigmoid()),
        ):
            self.mlp.add_module(label, layer)

    def forward(self, x):
        """Map a batch of boards to one sigmoid score per board."""
        features = self.conv_net.forward(x)
        return self.mlp.forward(features)

    def count_parameters(self):
        """Return the number of trainable parameters."""
        trainable = (p for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in trainable)


class SE_Block(nn.Module):
    """Squeeze-and-Excitation channel re-weighting.

    Each channel is global-average-pooled. The per-channel means pass through a
    bottleneck MLP (c -> c // r -> c, ReLU then Sigmoid). Every input channel is
    then scaled by its resulting weight. Input and output are (N, c, H, W).

    credits: https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py#L4
    """
    def __init__(self, c, r=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(c, c // r, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c // r, c, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        bs, c, _, _ = x.shape
        y = self.squeeze(x).view(bs, c)
        y = self.excitation(y).view(bs, c, 1, 1)
        return x * y.expand_as(x)

class ComplexModel(nn.Module):
    """Deeper CNN evaluator: 10 x (Conv3x3 -> SE -> BatchNorm -> Mish), then Linear + Sigmoid.

    Channel widths shrink from 200 to 110 in steps of 10. Every conv uses
    padding 1, so the spatial size is preserved. The head's 7040 inputs equal
    110 * 8 * 8, so the model expects (N, 1, 8, 8) boards and returns (N, 1)
    scores in (0, 1).
    """

    # Output channels of the ten conv stages: 200, 190, ..., 110.
    _WIDTHS = tuple(range(200, 100, -10))

    def __init__(self, name):
        super().__init__()

        self.name = name

        # Module names are part of the state_dict keys ("conv_net.Conv 1.weight", ...),
        # so they are kept exactly as in saved checkpoints. Note that the first
        # activation is named "Conv activation" (no index).
        self.conv_net = nn.Sequential()
        in_channels = 1
        for i, out_channels in enumerate(self._WIDTHS, start=1):
            self.conv_net.add_module(f"Conv {i}", nn.Conv2d(in_channels, out_channels, 3, 1, 1))
            self.conv_net.add_module(f"SE {i}", SE_Block(out_channels))
            self.conv_net.add_module(f"Batchnorm {i}", nn.BatchNorm2d(out_channels))
            self.conv_net.add_module("Conv activation" if i == 1 else f"Conv activation {i}", nn.Mish())
            in_channels = out_channels
        self.conv_net.add_module("Flattener", nn.Flatten())

        self.mlp = nn.Sequential()
        self.mlp.add_module("Linear 1", nn.Linear(in_channels * 8 * 8, 1))
        self.mlp.add_module("Activation 1", nn.Sigmoid())

    def forward(self, x):
        # Evaluate the conv stack exactly once. Running it twice doubles compute and,
        # in train mode, updates every BatchNorm's running statistics twice per step.
        return self.mlp(self.conv_net(x))

    def count_parameters(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


class MeanModel(nn.Module):
    """Trivial baseline whose prediction is the mean of everything in its first input."""

    def __init__(self):
        super(MeanModel, self).__init__()

    def forward(self, a, b):
        """Return the scalar mean of `a`.

        `b` is accepted so the signature mirrors two-input models such as
        UltraSimpleModel.forward(x, feat); it is ignored.
        """
        mean_value = torch.mean(a)
        return mean_value



In [None]:
# Smoke test: print the trainable-parameter count of both architectures, then run
# ComplexModel on two random 8x8 boards (the cell displays the resulting tensor).
for m in (SimpleModel("test"), ComplexModel("test").to(device)):
    print(m.count_parameters())
m(torch.rand(2, 1, 8, 8, device=device))

5961751
2041051


tensor([[0.5021],
        [0.4646]], device='cuda:0', grad_fn=<SigmoidBackward0>)

In [None]:
# Curriculum learning

def read_eval_data(phase, num, root="/content/drive/MyDrive/parrot"):
    """Load one pickled evaluation-database file.

    Args:
        phase: file-name prefix; this notebook uses "draw", "adv" and "winning".
        num: file index.
        root: directory containing `evaluation_database/` (defaults to the Drive folder).

    Returns:
        The unpickled object; callers in this notebook unpack it as `(boards, evals)`.
    """
    # SECURITY: pickle.load can execute arbitrary code — only read trusted files.
    # `with` closes the handle; previously every call leaked an open file.
    with open(f"{root}/evaluation_database/{phase}_data_{num}.chess", "rb") as fh:
        return pickle.load(fh)

def read_puzzle_data(person, difficulty, num, root="/content/drive/MyDrive/parrot"):
    """Load one pickled puzzle-database file.

    Args:
        person: puzzle source index (0-4 in this notebook).
        difficulty: "easy", "mid" or "hard".
        num: file index within that source/difficulty.
        root: directory containing `puzzle_database/` (defaults to the Drive folder).

    Returns:
        The unpickled object; callers in this notebook unpack it as `(boards, evals)`.
    """
    # SECURITY: pickle.load can execute arbitrary code — only read trusted files.
    # `with` closes the handle; previously every call leaked an open file.
    with open(f"{root}/puzzle_database/data_{person}_{difficulty}_{num}.chess", "rb") as fh:
        return pickle.load(fh)

# Summary of the on-disk datasets (files x positions per file).
print("\n".join((
    "Dataset info:",
    "- Drawish positions: 65536 * 150 = 9,830,400",
    "- Advantageous positions: 65536 * 250 = 16,384,000",
    "- Winning positions: 65536 * 112 = 7,340,032",
    "- Total = 33,554,432",
    "",
    "- Easy puzzles: 4096 * 213 =   872,448",
    "- Mid puzzles : 4096 * 384 = 1,572,864",
    "- Hard puzzles: 4096 * 152 =   622,592",
    "- Total = 3,067,904",
    "----------------",
)))


# Number of files per puzzle source (sources 0-4). Valid file indices for
# source s are 0 .. DISTRIBUTION[s] - 1.
EASY_DISTRIBUTION = [51, 37, 42, 41, 42]
MID_DISTRIBUTION = [92, 66, 76, 74, 76]
HARD_DISTRIBUTION = [36, 27, 30, 29, 30]

# Held-out puzzle file indices, one list per source. These files form the
# validation set and are skipped when training files are drawn.
EASY_PUZZLE_VALIDATION_SET = [
    [0, 5, 13, 21, 29, 37, 45, 50],
    [1, 6, 11, 16, 21, 26],
    [2, 8, 14, 20, 26, 32],
    [2, 8, 14, 20, 26, 32],
    [2, 8, 14, 20, 26, 32],
]
MID_PUZZLE_VALIDATION_SET = [
    [1, 11, 21, 31, 41, 51, 61, 71],
    [2, 14, 26, 38, 50, 62],
    [3, 16, 29, 42, 55, 68],
    [4, 16, 28, 40, 52, 64],
    [5, 16, 27, 38, 49, 60],
]
HARD_PUZZLE_VALIDATION_SET = [
    [3, 12, 20, 34],
    [1, 12, 26],
    [5, 10, 20],
    [5, 19, 24],
    [8, 18, 28],
]

# Held-out evaluation-database file indices per category.
DRAW_VALIDATION_SET = [25, 125]
ADV_VALIDATION_SET = [75, 200]
WIN_VALIDATION_SET = [50, 100]

# Build validation sets
print("Target validation set size:", VAL_SIZE)

# A dedicated, seeded RNG picks the validation files. The same files are then
# chosen on every kernel restart, so a `best_vloss` restored from a checkpoint is
# compared against losses measured on the same data. A separate Random instance
# also leaves the global `random` stream (used for training sampling) untouched.
VAL_SEED = 0
_val_rng = random.Random(VAL_SEED)


def _take_eval_positions(phase, file_ids, count):
    """Collect the first `count` (board, eval) pairs from evaluation files `file_ids`, in order.

    Only as many files as needed are read. Returns two lists, which are shorter
    than `count` if the listed files do not hold enough positions.
    """
    boards, evals = [], []
    for file_id in file_ids:
        missing = count - len(boards)
        if missing <= 0:
            break
        file_boards, file_evals = read_eval_data(phase, file_id)
        boards += list(file_boards[:missing])
        evals += list(file_evals[:missing])
    return boards, evals


if STAGE != POSITIONAL:
    raise NotImplementedError(f"Validation mix is only defined for STAGE == POSITIONAL ({POSITIONAL}), got {STAGE}.")

# POSITIONAL validation mix (fractions of VAL_SIZE):
# 3.125% easy puzzles (1 file), 6.25% mid puzzles (2 files), 6.25% hard puzzles (2 files),
# 50% draws, 25% advantages, 9.375% wins.

# Easy puzzles: one random held-out file of puzzle source 0.
easy_id = _val_rng.choice(EASY_PUZZLE_VALIDATION_SET[0])
vb, ve = read_puzzle_data(0, "easy", easy_id)
val_easy_puzzle_boards, val_easy_puzzle_evals = list(vb), list(ve)
print("Easy puzzle board validation set dimensions: ", np.shape(val_easy_puzzle_boards))

# Mid puzzles: two distinct random held-out files of puzzle source 0.
val_mid_puzzle_boards, val_mid_puzzle_evals = [], []
for mid_id in _val_rng.sample(MID_PUZZLE_VALIDATION_SET[0], 2):
    vb, ve = read_puzzle_data(0, "mid", mid_id)
    val_mid_puzzle_boards += list(vb)
    val_mid_puzzle_evals += list(ve)
print("Mid puzzle board validation set dimensions: ", np.shape(val_mid_puzzle_boards))

# Hard puzzles: the first held-out file of puzzle sources 1 and 2.
val_hard_puzzle_boards, val_hard_puzzle_evals = [], []
for source in (1, 2):
    vb, ve = read_puzzle_data(source, "hard", HARD_PUZZLE_VALIDATION_SET[source][0])
    val_hard_puzzle_boards += list(vb)
    val_hard_puzzle_evals += list(ve)
print("Hard puzzle board validation set dimensions: ", np.shape(val_hard_puzzle_boards))

val_draw_boards, val_draw_evals = _take_eval_positions("draw", DRAW_VALIDATION_SET, round(VAL_SIZE * 0.5))
print("Draws board validation set dimensions: ", np.shape(val_draw_boards))

val_adv_boards, val_adv_evals = _take_eval_positions("adv", ADV_VALIDATION_SET, round(VAL_SIZE * 0.25))
print("Advantages board validation set dimensions: ", np.shape(val_adv_boards))

val_win_boards, val_win_evals = _take_eval_positions("winning", WIN_VALIDATION_SET, round(VAL_SIZE * 0.09375))
print("Winning board validation set dimensions: ", np.shape(val_win_boards))

val_position_list = (val_easy_puzzle_boards + val_mid_puzzle_boards + val_hard_puzzle_boards
                     + val_draw_boards + val_adv_boards + val_win_boards)
val_eval_list = (val_easy_puzzle_evals + val_mid_puzzle_evals + val_hard_puzzle_evals
                 + val_draw_evals + val_adv_evals + val_win_evals)

print("Total board validation set dimensions: ", np.shape(val_position_list))

# The training loop consumes exactly VAL_SIZE // BATCH_SIZE full batches. Extra
# positions would be silently ignored; missing ones would break the reshape.
assert len(val_position_list) == len(val_eval_list) == VAL_SIZE, (
    f"validation set has {len(val_position_list)} boards / {len(val_eval_list)} evals, expected {VAL_SIZE}"
)


Dataset info:
- Drawish positions: 65536 * 150 = 9,830,400
- Advantageous positions: 65536 * 250 = 16,384,000
- Winning positions: 65536 * 112 = 7,340,032
- Total = 33,554,432

- Easy puzzles: 4096 * 213 =   872,448
- Mid puzzles : 4096 * 384 = 1,572,864
- Hard puzzles: 4096 * 152 =   622,592
- Total = 3,067,904
----------------
Target validation set size: 131072
Easy puzzle board validation set dimensions:  (4096, 1, 8, 8)
Mid puzzle board validation set dimensions:  (8192, 1, 8, 8)
Hard puzzle board validation set dimensions:  (8192, 1, 8, 8)
Draws board validation set dimensions:  (65536, 1, 8, 8)
Advantages board validation set dimensions:  (32768, 1, 8, 8)
Winning board validation set dimensions:  (12288, 1, 8, 8)
Total board validation set dimensions:  (131072, 1, 8, 8)


In [None]:
# Training data loader

class DataLoader:
    """Streams training positions of one data category from the Drive datasets.

    A random training file is loaded into memory; files listed in the matching
    *_VALIDATION_SET are never chosen. `get_data` hands out consecutive slices
    of the file and loads another random training file when it is used up.

    Attributes:
        dataset: category id (EASY_PUZZLES, MID_PUZZLES, HARD_PUZZLES, DRAW, ADV or WIN).
        pointer: index of the next unread position in the current file.
        length: number of positions in the current file (nominal size until the first load).
        data: the loaded `(boards, evals)` pair, or None before the first load.
    """

    def __init__(self, dataset):
        if dataset in (EASY_PUZZLES, MID_PUZZLES, HARD_PUZZLES):
            length = PUZZLE_SET_SIZE
        elif dataset in (DRAW, ADV, WIN):
            length = EVAL_SET_SIZE
        else:
            # Fail fast instead of leaving `length` unset (AttributeError later).
            raise ValueError(f"Unknown dataset id: {dataset!r}")
        self.dataset = dataset
        self.pointer = 0
        self.length = length
        self.data = None

    @staticmethod
    def _pick_training_file(num_files, held_out):
        """Return a uniformly random file index in [0, num_files) that is not in `held_out`."""
        return random.choice([i for i in range(num_files) if i not in held_out])

    def get_dataset(self):
        """Load a random training file of this loader's category into `self.data`."""
        if self.dataset in (EASY_PUZZLES, MID_PUZZLES, HARD_PUZZLES):
            difficulty, file_counts, held_out = {
                EASY_PUZZLES: ("easy", EASY_DISTRIBUTION, EASY_PUZZLE_VALIDATION_SET),
                MID_PUZZLES: ("mid", MID_DISTRIBUTION, MID_PUZZLE_VALIDATION_SET),
                HARD_PUZZLES: ("hard", HARD_DISTRIBUTION, HARD_PUZZLE_VALIDATION_SET),
            }[self.dataset]
            # Pick a puzzle source uniformly, then a non-validation file of that source.
            source = random.randrange(len(file_counts))
            file_id = self._pick_training_file(file_counts[source], held_out[source])
            self.data = read_puzzle_data(source, difficulty, file_id)
        else:
            phase, num_files, held_out = {
                DRAW: ("draw", 150, DRAW_VALIDATION_SET),
                ADV: ("adv", 250, ADV_VALIDATION_SET),
                WIN: ("winning", 112, WIN_VALIDATION_SET),
            }[self.dataset]
            file_id = self._pick_training_file(num_files, held_out)
            self.data = read_eval_data(phase, file_id)
        # Record the real size of the file just loaded.
        self.length = len(self.data[0])

    def get_data(self, num):
        """Return the next `num` training positions as `(positions, evals)` lists.

        Reading continues where the previous call stopped and rolls over into
        freshly loaded random files as needed.
        """
        res_pos, res_eval = [], []
        if self.data is None:
            self.get_dataset()
        while num > 0:
            # Use the loaded file's actual size. Trusting a fixed nominal length made
            # the call return fewer than `num` positions whenever a file was shorter.
            available = len(self.data[0]) - self.pointer
            if available <= 0:
                raise ValueError("Loaded data file contains no positions.")
            take = min(num, available)
            res_pos += self.data[0][self.pointer:self.pointer + take]
            res_eval += self.data[1][self.pointer:self.pointer + take]
            num -= take
            if take == available:
                # Current file exhausted: continue from a new random file.
                self.pointer = 0
                self.get_dataset()
            else:
                self.pointer += take
        return res_pos, res_eval


In [None]:
import time
import random
import os
import math
import statistics

# Hyperparameters for this training run.
# NOTE(review): these re-assign `checkpoint`, `learning_rate` and `l2r`, overriding
# the values set in the configuration cell.
checkpoint, learning_rate, l2r = -1, 0.00003, 0

# Fresh model on the selected device; weights may be replaced when a checkpoint
# is restored in the next step of this cell.
model = ComplexModel(model_name).to(device=device)
loss_fn = nn.MSELoss()

from pytorch_optimizer import SOAP
optimizer = SOAP(model.parameters(), lr=learning_rate)


# Initialise or resume training state according to `checkpoint`:
#    0 -> fresh run: epoch 0, reset best validation loss, truncate the loss CSV.
#   -1 -> resume from the rolling checkpoint `{model_name}.pickle` (weights, epoch, best loss).
#   -2 -> load the rolling checkpoint's weights but restart epoch count and best loss.
#   N  -> resume from `{model_name}_{N}.pickle`, including the optimizer state.
# Every restore path re-applies `learning_rate` to the optimizer afterwards.

def _restore_checkpoint(path, restore_optimizer):
    """Load model (and optionally optimizer) state from `path` into the global `model` / `optimizer`.

    Re-applies `learning_rate` to every optimizer param group. Returns the loaded
    checkpoint dict so callers can read "epoch" / "best_loss".
    """
    ckpt = torch.load(path, weights_only=True, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device=device)
    if restore_optimizer:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    # Update learning rate in case it is changed midway.
    for g in optimizer.param_groups:
        g['lr'] = learning_rate
    return ckpt


if checkpoint == 0:
    num_epoch = 0
    best_vloss = 1000
    # Truncate the per-epoch loss log.
    with open(f"/content/drive/MyDrive/parrot/{model_name}_loss.csv", "w") as file:
        file.write("")
    print("Using new models.")

elif checkpoint == -1:
    print("Loading from last checkpoint.")
    # NOTE(review): optimizer state is not restored on this path (its load was commented
    # out), so the optimizer's internal state starts fresh — confirm this is intended.
    checkpoint_file = _restore_checkpoint(f"/content/drive/MyDrive/parrot/{model_name}.pickle", restore_optimizer=False)
    num_epoch = checkpoint_file["epoch"] + 1
    best_vloss = checkpoint_file["best_loss"]

elif checkpoint == -2:
    print("Switching to new dataset.")
    checkpoint_file = _restore_checkpoint(f"/content/drive/MyDrive/parrot/{model_name}.pickle", restore_optimizer=False)
    num_epoch = 0
    best_vloss = 1000

else:
    print(f"Loading from epoch {checkpoint}.")
    checkpoint_file = _restore_checkpoint(f"/content/drive/MyDrive/parrot/{model_name}_{checkpoint}.pickle", restore_optimizer=True)
    num_epoch = checkpoint_file["epoch"] + 1
    best_vloss = checkpoint_file["best_loss"]

model.train()
if checkpoint not in (0, -2):
    print(f"Best validation loss at checkpoint: {best_vloss}")



# Log the architecture and the run configuration.
print(model)
for label, value in (
    ("Batch size = ", BATCH_SIZE),
    ("Validation size = ", VAL_SIZE),
    ("L2 regularisation strength =", l2r),
    ("Learning rate =", learning_rate),
    ("Stage: ", STAGE),
):
    print(label, value)

def warmup_then_expo(epoch, warmup_epochs=58, decay=0.999):
    """LR multiplier for `torch.optim.lr_scheduler.LambdaLR`: linear warm-up, then exponential decay.

    Args:
        epoch: zero-based scheduler step.
        warmup_epochs: number of warm-up steps. The multiplier ramps linearly
            from 1 / warmup_epochs at step 0 to 1.0 at step warmup_epochs - 1.
        decay: per-step multiplicative decay applied once warm-up is over.

    Returns:
        Float multiplier for the optimizer's base learning rate.
    """
    if epoch < warmup_epochs:
        # (epoch + 1), not epoch: LambdaLR evaluates the lambda at step 0 on
        # construction, and a 0 multiplier would run the first epoch at lr = 0.
        return (epoch + 1) / warmup_epochs
    return decay ** (epoch - warmup_epochs)
#scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_then_expo)
# NOTE(review): the scheduler is disabled; `lr_scheduler` (torch.optim.lr_scheduler)
# is not imported anywhere in this notebook.

# Loss histories: per-batch training (tl), per-batch regularised (tlr), per-epoch validation (vl).
tlr, tl, vl = [], [], []

# Welford accumulators for the running mean / variance of the data-loading time.
running_mean, M2, readings = 0, 0, 0

# Setup dataloaders: one streaming loader per data category.
easy_puzzle_loader, mid_puzzle_loader, hard_puzzle_loader, draw_loader, adv_loader, win_loader = (
    DataLoader(kind) for kind in (EASY_PUZZLES, MID_PUZZLES, HARD_PUZZLES, DRAW, ADV, WIN)
)

# Per-epoch training mix for the POSITIONAL stage (fractions of EPOCH_SIZE):
# 3.125% easy puzzles, 6.25% mid puzzles, 6.25% hard puzzles,
# 50% draws, 25% advantageous positions, 9.375% winning positions.
# Loaders are queried in this order every epoch.
if STAGE == POSITIONAL:
    epoch_mix = (
        (easy_puzzle_loader, 0.03125),
        (mid_puzzle_loader, 0.0625),
        (hard_puzzle_loader, 0.0625),
        (draw_loader, 0.5),
        (adv_loader, 0.25),
        (win_loader, 0.09375),
    )
else:
    # Fail before training starts rather than with a NameError inside the loop.
    raise NotImplementedError(f"No training mix is defined for STAGE={STAGE}; only POSITIONAL ({POSITIONAL}) is implemented.")

# Only whole batches are processed.
steps_per_epoch = EPOCH_SIZE // BATCH_SIZE
val_steps = VAL_SIZE // BATCH_SIZE
save_dir = "/content/drive/MyDrive/parrot"

for epoch in range(num_epoch, 800000):
    start = time.time()

    # Gather this epoch's positions from every loader, then shuffle boards and
    # evals together so the pairs stay aligned.
    bl, el = [], []
    for loader, fraction in epoch_mix:
        boards, evals = loader.get_data(round(EPOCH_SIZE * fraction))
        bl += boards
        el += evals

    zipped = list(zip(bl, el))
    random.shuffle(zipped)
    bl, el = zip(*zipped)

    cl = 0
    clr = 0
    norms = []

    # Welford's online update of the mean / variance of data-loading time.
    data_load_time = time.time() - start
    readings += 1
    delta = data_load_time - running_mean
    running_mean += delta / readings
    delta2 = data_load_time - running_mean
    M2 += (delta * delta2)
    variance = M2 / readings
    print(f"Data loaded in {data_load_time} seconds. Mean {running_mean}, Stdev {variance ** 0.5}")

    if len(bl) != EPOCH_SIZE or len(el) != EPOCH_SIZE:
        print("Dataset is invalid.")
        continue

    # ---- Training ----
    start = time.time()
    for batch in range(steps_per_epoch):
        lo, hi = batch * BATCH_SIZE, (batch + 1) * BATCH_SIZE
        tb = torch.tensor(bl[lo:hi], device=device, dtype=torch.float).reshape(BATCH_SIZE, 1, 8, 8)
        te = torch.tensor(el[lo:hi], device=device, dtype=torch.float).reshape(BATCH_SIZE, 1)

        optimizer.zero_grad()
        # Call the module (not .forward) so any registered hooks run.
        out = model(tb)
        loss = loss_fn(out, te)
        batch_loss = loss.item()
        cl += batch_loss
        if l2r != 0:
            l2 = sum(p.pow(2).sum() for p in model.parameters())
            loss = loss + l2 * l2r
            batch_reg_loss = loss.item()
            clr += batch_reg_loss
        loss.backward()

        # Gradient monitoring (clipping currently disabled):
        #torch.nn.utils.clip_grad_norm_(model.parameters(), 0.03)
        # Global L2 norm of all gradients, reduced on-device with a single .item()
        # instead of one host synchronisation per parameter tensor.
        grad_norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None and p.requires_grad]
        total_norm = torch.linalg.vector_norm(torch.stack(grad_norms)).item() if grad_norms else 0.0
        norms.append(total_norm)

        optimizer.step()
        tl.append(batch_loss)
        if l2r != 0:
            tlr.append(batch_reg_loss)
        completion = int(20 * batch / steps_per_epoch) + 1
        if l2r != 0:
            print("\r" + f"[{'-' * completion} {' ' * (20 - completion)}]     Loss {round(batch_loss, 6)}, Regularised loss {round(batch_reg_loss, 6)}", end = "")
        else:
            print("\r" + f"[{'-' * completion} {' ' * (20 - completion)}]     Loss {round(batch_loss, 6)}", end = "")

    train_loss = cl / steps_per_epoch
    if l2r != 0:
        print(f"\nEpoch {epoch}, loss {round(train_loss, 6)}, regularised loss {round(clr / steps_per_epoch, 6)}, completed in {time.time() - start} seconds.")
    else:
        print(f"\nEpoch {epoch}, loss {round(train_loss, 6)}, completed in {time.time() - start} seconds.")

    # ---- Validation ----
    with torch.no_grad():
        model.eval()
        tvl = 0
        start = time.time()
        for vbatch in range(val_steps):
            lo, hi = vbatch * BATCH_SIZE, (vbatch + 1) * BATCH_SIZE
            vp = torch.tensor(val_position_list[lo:hi], device=device, dtype=torch.float).reshape(BATCH_SIZE, 1, 8, 8)
            ve = torch.tensor(val_eval_list[lo:hi], device=device, dtype=torch.float).reshape(BATCH_SIZE, 1)
            tvl += loss_fn(model(vp), ve).item()

        val_loss = tvl / val_steps
        print("Validation loss", round(val_loss, 6), "completed in", time.time() - start, "seconds.")
        vl.append(val_loss)
        if val_loss < best_vloss:
            print("New best model!")
            best_vloss = val_loss
            torch.save(model.state_dict(), f"{save_dir}/best_{model_name}.pickle")

        model.train()

    # ---- Checkpointing ----
    if save:
        is_long = epoch % LONG_INTERVAL == 0
        is_short = epoch % SHORT_INTERVAL == 0
        if is_long or is_short:
            state = {"epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "best_loss": best_vloss}
            if is_long:
                torch.save(state, f"{save_dir}/{model_name}_{epoch}.pickle")
            if is_short:
                torch.save(state, f"{save_dir}/{model_name}.pickle")
        if is_long:
            # Keep only the newest numbered checkpoint. Numbered files are only written
            # at multiples of LONG_INTERVAL, so this is the only epoch one can exist.
            # OSError (not a bare except) keeps this best-effort without swallowing
            # KeyboardInterrupt.
            try:
                os.remove(f"{save_dir}/{model_name}_{epoch - LONG_INTERVAL}.pickle")
            except OSError:
                pass

    with open(f"{save_dir}/{model_name}_loss.csv", "a") as file:
        file.write(f"{epoch}, {round(train_loss, 6)}, {round(val_loss, 6)}, {round(max(norms), 6)}\n")


Loading from last checkpoint.
Best validation loss at checkpoint: 0.013924401515396312
ComplexModel(
  (conv_net): Sequential(
    (Conv 1): Conv2d(1, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (SE 1): SE_Block(
      (squeeze): AdaptiveAvgPool2d(output_size=1)
      (excitation): Sequential(
        (0): Linear(in_features=200, out_features=12, bias=False)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=12, out_features=200, bias=False)
        (3): Sigmoid()
      )
    )
    (Batchnorm 1): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (Conv activation): Mish()
    (Conv 2): Conv2d(200, 190, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (SE 2): SE_Block(
      (squeeze): AdaptiveAvgPool2d(output_size=1)
      (excitation): Sequential(
        (0): Linear(in_features=190, out_features=11, bias=False)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=11, out_features=190, bias=False)
        

KeyboardInterrupt: 