## Summary

- *hidden_size = 162*.
- *num_heads = 9*.
- *dropout = 0*.
- N=16.
- Add node and edge features (node features as 81-dim. embedding in `hidden_size`-dim space).
- EdgeConv: embed `x` and the edge features to half of `hidden_size` each, and use only the `x[row]` node features (not `x[col]`).
- Embed attention with `model_size == 63` and add `output_dim` attribute to attention.

----

## Install dependencies (Google Colab only)

In [1]:
# Probe for the Colab runtime: the import only succeeds inside Google Colab.
try:
    import google.colab
except ImportError:
    GOOGLE_COLAB = False
else:
    GOOGLE_COLAB = True

In [2]:
# Colab only: install/upgrade the torch-geometric extension packages, then
# torch-geometric itself.
if GOOGLE_COLAB:
    !pip install --upgrade torch-scatter
    !pip install --upgrade torch-sparse
    !pip install --upgrade torch-cluster
    !pip install --upgrade torch-spline-conv
    !pip install torch-geometric

In [3]:
# Colab only: install the proteinsolver package straight from its GitLab repository.
if GOOGLE_COLAB:
    !pip install git+https://gitlab.com/ostrokach/proteinsolver.git

## Imports

In [4]:
import atexit
import csv
import itertools
import tempfile
import time
import uuid
import warnings
from collections import deque
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow
import torch
import torch.nn as nn
import torch.utils.tensorboard
from torch import optim
from torch_geometric.data import DataLoader

In [5]:
import proteinsolver
import proteinsolver.datasets



In [6]:
%load_ext autoreload
%autoreload 2

In [7]:
# Fail fast: the cells below place the model and all batches on a CUDA device.
assert torch.cuda.is_available()

## Parameters

In [8]:
# The model and every batch are moved to the first CUDA device.
device = torch.device("cuda:0")

In [9]:
# Use the "/localscratch/strokach.*" directory when one exists (presumably a
# per-job scratch area on the cluster -- TODO confirm); otherwise fall back to
# the system temp directory instead of failing with StopIteration.
_scratch_dir = next(Path("/localscratch/").glob("strokach.*"), None)
if _scratch_dir is None:
    _scratch_dir = Path(tempfile.gettempdir())
DATA_ROOT = _scratch_dir.joinpath("sudoku")
DATA_ROOT.mkdir(parents=True, exist_ok=True)
DATA_ROOT

PosixPath('/localscratch/strokach.3687035.0/sudoku')

In [10]:
# Name of the run directory; when None, a fresh 8-hex-character id is generated
# by the path-setup cell.
UNIQUE_ID = "4264619e"
# When True, the newest "*.state" checkpoint in the run directory is located and loaded.
CONTINUE_PREVIOUS = True

In [11]:
# Only compute the run paths when they are not defined yet, so that re-running
# this cell keeps pointing at the same run directory.
try:
    NOTEBOOK_PATH, UNIQUE_PATH
except NameError:
    NOTEBOOK_PATH = Path("sudoku_train").resolve()
    NOTEBOOK_PATH.mkdir(exist_ok=True)
    # A user-supplied id may name an existing run directory; a freshly
    # generated id must not collide with one.
    exist_ok = UNIQUE_ID is not None
    if not exist_ok:
        UNIQUE_ID = uuid.uuid4().hex[:8]
    UNIQUE_PATH = NOTEBOOK_PATH.joinpath(UNIQUE_ID)
    UNIQUE_PATH.mkdir(exist_ok=exist_ok)
NOTEBOOK_PATH, UNIQUE_PATH

(PosixPath('/lustre04/scratch/strokach/workspace/proteinsolver/notebooks/sudoku_train'),
 PosixPath('/lustre04/scratch/strokach/workspace/proteinsolver/notebooks/sudoku_train/4264619e'))

In [12]:
# Absolute location of the data-package output directory in the user's home.
_home_relative_dir = Path("~/datapkg_output_dir")
DATAPKG_DATA_DIR = _home_relative_dir.expanduser().resolve()
DATAPKG_DATA_DIR

PosixPath('/scratch/strokach/datapkg_output_dir')

In [13]:
# Point proteinsolver's `data_url` setting at the local data directory.
proteinsolver.settings.data_url = DATAPKG_DATA_DIR.as_posix()
proteinsolver.settings.data_url

'/scratch/strokach/datapkg_output_dir'

## Datasets

In [14]:
# Maps dataset name -> dataset object; filled in by the following cells.
datasets = {}

### `SudokuDataset`

In [15]:
# Register the ten training subsets "train_0" .. "train_9", each cached in its
# own directory under DATA_ROOT.
for i in range(10):
    dataset_name = f"sudoku_train_{i}"
    dataset_root = DATA_ROOT / dataset_name
    datasets[dataset_name] = proteinsolver.datasets.SudokuDataset4(
        root=dataset_root,
        subset=f"train_{i}",
    )

In [16]:
# Validation subset in the same format as the training subsets.
_valid_0_root = DATA_ROOT / "sudoku_valid_0"
datasets["sudoku_valid_0"] = proteinsolver.datasets.SudokuDataset4(
    root=_valid_0_root,
    subset="valid_0",
)

In [17]:
# Older validation set, read from a gzipped CSV inside the data-package directory.
_valid_old_csv = DATAPKG_DATA_DIR / "deep-protein-gen" / "sudoku" / "sudoku_valid.csv.gz"
datasets["sudoku_valid_old"] = proteinsolver.datasets.SudokuDataset2(
    root=DATA_ROOT / "sudoku_valid_old",
    data_url=_valid_old_csv.as_posix(),
)

Processing...
Done!


## Models

In [18]:
%%file {UNIQUE_PATH}/model.py
import copy
import tempfile

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.container import ModuleList
from torch_geometric.nn.inits import reset
from torch_geometric.utils import (
    add_self_loops,
    remove_self_loops,
    scatter_,
    to_dense_adj,
    to_dense_batch,
)


class EdgeConvMod(torch.nn.Module):
    """Edge convolution that returns node-level and edge-level outputs.

    Args:
        nn: Module applied to the concatenated per-edge feature vectors.
        aggr: Aggregation name forwarded to ``scatter_`` (default ``"max"``).
    """

    def __init__(self, nn, aggr="max"):
        super().__init__()
        self.nn = nn
        self.aggr = aggr
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the wrapped module through torch_geometric's ``reset``."""
        reset(self.nn)

    def forward(self, x, edge_index, edge_attr=None):
        """Compute per-edge messages and aggregate them per node.

        Returns:
            ``(x, out)``: ``out`` is the per-edge output of ``self.nn`` and
            ``x`` is ``out`` aggregated over ``row`` with ``self.aggr``.
        """
        row, col = edge_index
        if x.dim() == 1:
            x = x.unsqueeze(-1)

        # TODO: Try -x[col] instead of x[col] - x[row]
        pieces = [x[row], x[col]]
        if edge_attr is not None:
            pieces.append(edge_attr)
        out = self.nn(torch.cat(pieces, dim=-1))
        x = scatter_(self.aggr, out, row, dim_size=x.size(0))

        return x, out

    def __repr__(self):
        return f"{self.__class__.__name__}(nn={self.nn})"


class EdgeConvBatch(nn.Module):
    """Wrap a graph layer and post-process its node and edge outputs.

    Args:
        gnn: Module called as ``gnn(x, edge_index, edge_attr)`` and returning
            a ``(x, edge_attr)`` pair.
        hidden_size: Feature size passed to the normalisation layers.
        batch_norm: When truthy, normalise both outputs. Note that the layer
            used is ``nn.LayerNorm`` despite the parameter name. ``False`` and
            ``None`` both disable normalisation.
        dropout: Dropout probability; a falsy value disables dropout.
    """

    def __init__(self, gnn, hidden_size, batch_norm=True, dropout=0.2):
        super().__init__()

        self.gnn = gnn

        x_post_modules = []
        edge_attr_post_modules = []

        # Truthiness test (rather than ``is not None``) so that
        # ``batch_norm=False`` really switches normalisation off.
        if batch_norm:
            x_post_modules.append(nn.LayerNorm(hidden_size))
            edge_attr_post_modules.append(nn.LayerNorm(hidden_size))

        if dropout:
            x_post_modules.append(nn.Dropout(dropout))
            edge_attr_post_modules.append(nn.Dropout(dropout))

        # An empty ``nn.Sequential`` passes its input through unchanged.
        self.x_postprocess = nn.Sequential(*x_post_modules)
        self.edge_attr_postprocess = nn.Sequential(*edge_attr_post_modules)

    def forward(self, x, edge_index, edge_attr=None):
        """Run the wrapped layer, then post-process both of its outputs."""
        x, edge_attr = self.gnn(x, edge_index, edge_attr)
        x = self.x_postprocess(x)
        edge_attr = self.edge_attr_postprocess(edge_attr)
        return x, edge_attr


def get_graph_conv_layer(input_size, hidden_size, output_size):
    """Build an ``EdgeConvBatch`` around a two-layer MLP edge convolution.

    The MLP maps ``input_size -> hidden_size -> output_size`` with a ReLU in
    between. Messages are aggregated with ``"add"``, and the outputs are
    normalised and passed through dropout (p=0.2).
    """
    layers = [
        nn.Linear(input_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, output_size),
    ]
    edge_conv = EdgeConvMod(nn=nn.Sequential(*layers), aggr="add")
    return EdgeConvBatch(edge_conv, output_size, batch_norm=True, dropout=0.2)


class MyEdgeConv(torch.nn.Module):
    """Edge update computed from ``x[row]`` and the current edge features.

    Node and edge features are each projected to ``hidden_size // 2``,
    concatenated per edge, and passed through a two-layer MLP that returns
    ``hidden_size`` features per edge.
    """

    def __init__(self, hidden_size):
        super().__init__()
        half_size = hidden_size // 2
        self.embed_x = nn.Linear(hidden_size, half_size)
        self.embed_edge = nn.Linear(hidden_size, half_size)

        self.nn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2),
            nn.ReLU(),
            nn.Linear(hidden_size * 2, hidden_size),
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the MLP through torch_geometric's ``reset``."""
        reset(self.nn)

    def forward(self, x, edge_index, edge_attr=None):
        """Return the updated per-edge features.

        Only ``x[row]`` is used; ``x[col]`` does not enter the computation.
        """
        row, _ = edge_index
        if x.dim() == 1:
            x = x.unsqueeze(-1)

        node_part = self.embed_x(x)[row]
        edge_part = self.embed_edge(edge_attr)
        edge_attr_out = self.nn(torch.cat([node_part, edge_part], dim=-1))

        return edge_attr_out

    def __repr__(self):
        return f"{self.__class__.__name__}(nn={self.nn})"


class MyAttn(torch.nn.Module):
    """Node update via multi-head attention over dense edge features.

    Node features are projected to a 63-dimensional query (9 heads), while
    keys and values are the ``hidden_size``-dimensional edge features laid
    out as a dense adjacency tensor. The attention output is projected back
    to ``hidden_size``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        model_size = 63
        self.embed_x = nn.Linear(hidden_size, model_size)
        self.attn = MultiheadAttention(
            embed_dim=model_size,
            output_dim=hidden_size,
            kdim=hidden_size,
            vdim=hidden_size,
            num_heads=9,
            dropout=0,
            bias=True,
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the attention module through torch_geometric's ``reset``."""
        reset(self.attn)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return new node features with the same shape as ``x``.

        Args:
            x: Node features.
            edge_index: Edge list passed to ``to_dense_adj``.
            edge_attr: Per-edge features, densified into the key/value tensor.
            batch: Batch assignment vector passed to ``to_dense_adj``.
        """
        query = self.embed_x(x).unsqueeze(0)
        key = to_dense_adj(edge_index, batch=batch, edge_attr=edge_attr).squeeze(0)

        # Mask out every node pair that is not connected by an edge.
        adjacency = to_dense_adj(edge_index, batch=batch).squeeze(0)
        key_padding_mask = adjacency == 0
        # Always leave the diagonal unmasked (presumably so no row is fully
        # masked, which would turn the softmax into NaN -- TODO confirm).
        # The index mask is built on the same device as the tensor it indexes.
        diagonal = torch.eye(
            key_padding_mask.size(0),
            dtype=torch.bool,
            device=key_padding_mask.device,
        )
        key_padding_mask[diagonal] = False

        x_out, _ = self.attn(query, key, key, key_padding_mask=key_padding_mask)
        x_out = x_out.squeeze(0)
        # NaN != NaN, so this fails as soon as the output contains a NaN.
        assert (x_out == x_out).all().item()
        assert x.shape == x_out.shape, (x.shape, x_out.shape)
        return x_out

    def __repr__(self):
        # This class has no ``nn`` attribute; report the attention module.
        return "{}(attn={})".format(self.__class__.__name__, self.attn)


class Net(nn.Module):
    """Stack of ``N = 16`` (edge-conv, attention) layers over a graph.

    Node inputs are integer tokens embedded with ``nn.Embedding``; each is
    concatenated with a learned embedding of one of 81 position labels. Edge
    inputs are embedded with a small MLP. Each layer first updates the edge
    features residually, then the node features residually.

    Args:
        x_input_size: Number of distinct node tokens (embedding vocabulary).
        adj_input_size: Size of the raw edge feature vectors. A falsy value
            leaves ``embed_adj`` as ``None``, in which case ``forward`` (which
            calls it unconditionally) cannot be used.
        hidden_size: Width of node and edge representations.
        output_size: Number of output logits per node.
        batch_size: Accepted but not used anywhere in this class.
    """

    def __init__(
        self, x_input_size, adj_input_size, hidden_size, output_size, batch_size=1
    ):
        super().__init__()

        # One label per node position; repeated in `forward` for every graph
        # in the batch.
        x_labels = torch.arange(81, dtype=torch.long)
        self.register_buffer("x_labels", x_labels)

        # All-zero batch vector sliced in `forward`; with more than 10000
        # nodes the slice would be shorter than `x`.
        self.register_buffer("batch", torch.zeros(10000, dtype=torch.int64))

        self.embed_x = nn.Sequential(nn.Embedding(x_input_size, hidden_size), nn.ReLU())
        self.embed_x_labels = nn.Sequential(nn.Embedding(81, hidden_size), nn.ReLU())
        # Fuses the token embedding and the position embedding back to hidden_size.
        self.finalize_x = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size), nn.LayerNorm(hidden_size)
        )

        if adj_input_size:
            self.embed_adj = nn.Sequential(
                nn.Linear(adj_input_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.LayerNorm(hidden_size),
                #                 nn.ELU(),
            )
        else:
            self.embed_adj = None

        N = 16
        self.N = N

        # Each list holds N deep copies of the same prototype, so all copies
        # start from identical parameters.
        norm = nn.LayerNorm(hidden_size)
        # NOTE: `x_norms_0` is created (and stored in the state dict) but is
        # never applied in `forward`.
        self.x_norms_0 = _get_clones(norm, N)
        self.adj_norms_0 = _get_clones(norm, N)
        self.x_norms_1 = _get_clones(norm, N)
        self.adj_norms_1 = _get_clones(norm, N)

        edge_conv = MyEdgeConv(hidden_size)
        self.edge_convs = _get_clones(edge_conv, N)

        attn = MyAttn(hidden_size)
        self.attns = _get_clones(attn, N)

        self.dropout = nn.Dropout(0.1)

        self.linear_out = nn.Linear(hidden_size, output_size)

    def forward(self, x, edge_index, edge_attr):
        """Return per-node logits of size ``output_size``.

        ``x.size(0)`` is expected to be a multiple of 81 so that the repeated
        position embeddings line up with ``x`` for concatenation.
        """

        x = self.embed_x(x)
        x_labels = self.embed_x_labels(self.x_labels)
        # Tile the 81 position embeddings once per graph in the batch.
        x_labels = x_labels.repeat(x.size(0) // x_labels.size(0), 1)
        x = torch.cat([x, x_labels], dim=1)
        x = self.finalize_x(x)

        edge_attr = self.embed_adj(edge_attr)

        for i in range(self.N):
            # Residual edge update followed by LayerNorm.
            edge_attr_out = self.edge_convs[i](x, edge_index, edge_attr)
            edge_attr = edge_attr + self.dropout(edge_attr_out)
            edge_attr = self.adj_norms_1[i](edge_attr)

            # The attention sees the normalised edge *update* (edge_attr_out),
            # not the residual-summed edge_attr.
            x_out = self.attns[i](
                x,
                edge_index,
                self.adj_norms_0[i](edge_attr_out),
                self.batch[: x.size(0)],
            )
            # Residual node update followed by LayerNorm.
            x = x + self.dropout(x_out)
            x = self.x_norms_1[i](x)

        x = self.linear_out(x)

        return x


class MultiheadAttention(nn.Module):
    """Multi-head attention whose output projection may change the width.

    Mirrors the layout of ``torch.nn.MultiheadAttention`` but adds
    ``output_dim``: the final projection maps ``embed_dim -> output_dim``
    instead of ``embed_dim -> embed_dim``.

    Args:
        embed_dim: Query dimension and internal model dimension; must be
            divisible by ``num_heads``.
        output_dim: Feature size of the returned attention output.
        num_heads: Number of attention heads.
        dropout: Dropout probability on the attention weights.
        bias: Whether the input and output projections carry a bias.
        add_bias_kv: Whether to add learned bias vectors to key and value.
        add_zero_attn: Forwarded to ``F.multi_head_attention_forward``.
        kdim: Key feature size (defaults to ``embed_dim``).
        vdim: Value feature size (defaults to ``embed_dim``).
    """

    def __init__(
        self,
        embed_dim,
        output_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            # Separate projection matrices because key/value widths differ
            # from the query width.
            self.q_proj_weight = nn.Parameter(torch.empty(embed_dim, embed_dim))
            self.k_proj_weight = nn.Parameter(torch.empty(embed_dim, self.kdim))
            self.v_proj_weight = nn.Parameter(torch.empty(embed_dim, self.vdim))
        else:
            # One packed matrix for the query, key and value projections.
            self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))

        if bias:
            self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter("in_proj_bias", None)
        self.out_proj = nn.Linear(embed_dim, output_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = nn.Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise the projection weights and zero the biases."""
        if self._qkv_same_embed_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.q_proj_weight)
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)

        # `in_proj_bias` and `out_proj.bias` are both controlled by `bias`,
        # so they are either both present or both absent.
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.0)
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        need_weights=True,
        attn_mask=None,
    ):
        """Delegate to ``F.multi_head_attention_forward``.

        Returns whatever that function returns: the attention output and,
        when ``need_weights`` is true, the attention weights.
        """
        if hasattr(self, "_qkv_same_embed_dim") and self._qkv_same_embed_dim is False:
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                None,  # no packed in_proj_weight in this configuration
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
            )
        else:
            if not hasattr(self, "_qkv_same_embed_dim"):
                # Imported here because the module header does not import
                # `warnings`; without it this branch raised NameError.
                import warnings

                warnings.warn(
                    "A new version of MultiheadAttention module has been implemented. "
                    "Please re-train your model with the new module",
                    UserWarning,
                )

            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
            )


def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])

Overwriting /lustre04/scratch/strokach/workspace/proteinsolver/notebooks/sudoku_train/4264619e/model.py


In [19]:
%run {UNIQUE_PATH}/model.py

In [20]:
%%file {UNIQUE_PATH}/stats.py
import atexit
import csv
import time
import warnings

import numpy as np


class Stats:
    """Running training/validation counters with CSV, stdout and TensorBoard output.

    The counters cover one reporting interval: ``reset_parameters`` stores a
    snapshot of the current statistics in ``prev`` and zeroes the counters.
    Ratios whose denominator is zero evaluate to NaN/inf (NumPy semantics)
    with the corresponding warnings suppressed.
    """

    epoch: int
    step: int
    batch_size: int
    echo: bool
    total_loss: float
    num_correct_preds: int
    num_preds: int
    num_correct_preds_missing: int
    num_preds_missing: int
    num_correct_preds_missing_valid: int
    num_preds_missing_valid: int
    num_correct_preds_missing_valid_old: int
    num_preds_missing_valid_old: int
    start_time: float

    def __init__(
        self, *, epoch=0, step=0, batch_size=1, filename=None, echo=True, tb_writer=None
    ):
        self.epoch, self.step, self.batch_size = epoch, step, batch_size
        self.echo = echo
        self.tb_writer = tb_writer
        self.prev = {}
        self.init_parameters()

        self.filehandle = None
        self.writer = None
        if filename:
            # The CSV log is (re)created on construction and closed at exit.
            self.filehandle = open(filename, "wt", newline="")
            fieldnames = list(self.stats.keys())
            self.writer = csv.DictWriter(self.filehandle, fieldnames, dialect="unix")
            atexit.register(self.filehandle.close)

    def init_parameters(self):
        """Zero every interval counter and restart the interval timer."""
        for counter in (
            "num_steps",
            "total_loss",
            "num_correct_preds",
            "num_preds",
            "num_correct_preds_missing",
            "num_preds_missing",
            "num_correct_preds_missing_valid",
            "num_preds_missing_valid",
            "num_correct_preds_missing_valid_old",
            "num_preds_missing_valid_old",
        ):
            setattr(self, counter, 0)
        self.start_time = time.perf_counter()

    def reset_parameters(self):
        """Remember the current statistics in ``prev`` and start a new interval."""
        self.prev = self.stats
        self.init_parameters()

    @staticmethod
    def _ratio(numerator, denominator):
        # The NumPy scalar makes division by zero yield nan/inf instead of raising.
        return np.float64(1) * numerator / denominator

    @property
    def header(self):
        """Fixed-width header line listing the statistic names."""
        return "".join(to_fixed_width(self.stats.keys()))

    @property
    def row(self):
        """Fixed-width line with the current values (floats rounded to 4 places)."""
        return "".join(to_fixed_width(self.stats.values(), 4))

    @property
    def stats(self):
        """Ordered mapping of statistic name to its current value."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            snapshot = {
                "epoch": self.epoch,
                "step": self.step,
                "datapoint": self.datapoint,
            }
            snapshot["avg_loss"] = self._ratio(self.total_loss, self.num_steps)
            snapshot["accuracy"] = self._ratio(self.num_correct_preds, self.num_preds)
            snapshot["accuracy_m"] = self._ratio(
                self.num_correct_preds_missing, self.num_preds_missing
            )
            snapshot["accuracy_mv"] = self.accuracy_mv
            snapshot["accuracy_mv_old"] = self.accuracy_mv_old
            snapshot["time_elapsed"] = time.perf_counter() - self.start_time
            return snapshot

    @property
    def accuracy_mv(self):
        """Accuracy on missing entries of the validation set."""
        return self._ratio(
            self.num_correct_preds_missing_valid, self.num_preds_missing_valid
        )

    @property
    def accuracy_mv_old(self):
        """Accuracy on missing entries of the old validation set."""
        return self._ratio(
            self.num_correct_preds_missing_valid_old, self.num_preds_missing_valid_old
        )

    @property
    def datapoint(self):
        """Number of datapoints seen so far: ``step * batch_size``."""
        return self.step * self.batch_size

    def write_header(self):
        """Emit the header to stdout (if echoing) and to the CSV log (if any)."""
        if self.echo:
            print(self.header)
        if self.writer is not None:
            self.writer.writeheader()

    def write_row(self):
        """Emit the current statistics to stdout, the CSV log and TensorBoard."""
        if self.echo:
            # Carriage return so that consecutive rows overwrite each other.
            print(self.row, end="\r")
        if self.writer is not None:
            self.writer.writerow(self.stats)
            self.filehandle.flush()
        if self.tb_writer is not None:
            scalars = self.stats
            # "datapoint" is the x-axis, not a scalar to plot.
            global_step = scalars.pop("datapoint")
            for name, value in scalars.items():
                self.tb_writer.add_scalar(name, value, global_step)
            self.tb_writer.flush()


def to_fixed_width(lst, precision=None):
    """Format each item as a left-aligned, space-padded cell of width 18.

    Floats are first rounded with ``round(item, precision)``; every other
    value is formatted as is.
    """
    cells = []
    for item in lst:
        if isinstance(item, float):
            item = round(item, precision)
        cells.append(f"{item: <18}")
    return cells

Overwriting /lustre04/scratch/strokach/workspace/proteinsolver/notebooks/sudoku_train/4264619e/stats.py


In [21]:
%run {UNIQUE_PATH}/stats.py

In [22]:
def get_stats_on_missing(x, y, output):
    """Count correct predictions on the nodes whose input token equals 9.

    Token ``9`` presumably marks an unfilled cell -- TODO confirm against the
    dataset encoding.

    Args:
        x: Node input tokens, one per node (a trailing singleton dim is allowed).
        y: Target class per node.
        output: Per-node class scores with classes along dim 1.

    Returns:
        ``(num_correct, num_total)`` as Python ints; ``(0, 0)`` when no node
        has token 9.
    """
    # `reshape(-1)` keeps the mask 1-D even for a single node, where
    # `squeeze()` would collapse it to a 0-dim tensor and break the indexing.
    mask = (x == 9).reshape(-1)
    if not mask.any():
        return 0, 0
    predicted_missing = output.detach()[mask].argmax(dim=1)
    num_correct = (predicted_missing == y[mask]).sum().item()
    return num_correct, len(predicted_missing)



from contextlib import contextmanager
@contextmanager
def eval_net(net: nn.Module):
    """Run the ``with`` body with ``net`` in evaluation mode.

    The training/evaluation mode that ``net`` had on entry is restored on
    exit, even if the body raises.
    """
    was_training = net.training
    try:
        net.eval()
        yield
    finally:
        net.train(was_training)

In [23]:
batch_size = 6  # graphs per training batch (passed to DataLoader)
info_size = 200  # validate and log a row roughly every `info_size` datapoints
hidden_size = 162  # width of the node and edge representations in `Net`
checkpoint_size = 100_000  # save a checkpoint roughly every `checkpoint_size` datapoints

batch_size, info_size, hidden_size

(6, 200, 162)

In [24]:
# Directory reserved for this run's TensorBoard output: <NOTEBOOK_PATH>/runs/<run id>.
tensorboard_path = NOTEBOOK_PATH.joinpath("runs", UNIQUE_PATH.name)
# `parents=True` creates a missing "runs" directory instead of raising
# FileNotFoundError on the first run.
tensorboard_path.mkdir(parents=True, exist_ok=True)
tensorboard_path

PosixPath('/lustre04/scratch/strokach/workspace/proteinsolver/notebooks/sudoku_train/runs/4264619e')

In [25]:
# Locate the checkpoint with the largest datapoint count. File names look like
# "e<epoch>-s<step>-d<datapoint>-amv<accuracy>.state".
last_epoch = last_step = last_datapoint = last_state_file = None

if CONTINUE_PREVIOUS:
    for state_path in UNIQUE_PATH.glob("*.state"):
        epoch_tag, step_tag, datapoint_tag, _accuracy_tag = state_path.name.split("-")
        n_datapoints = int(datapoint_tag.strip("d"))
        # Skip anything older than the best candidate so far (ties are replaced).
        if last_datapoint is not None and n_datapoints < last_datapoint:
            continue
        last_datapoint = n_datapoints
        last_epoch = int(epoch_tag.strip("e"))
        last_step = int(step_tag.strip("s"))
        last_state_file = state_path

last_epoch, last_step, last_datapoint, last_state_file

(0,
 66667,
 400002,
 PosixPath('/lustre04/scratch/strokach/workspace/proteinsolver/notebooks/sudoku_train/4264619e/e0-s66667-d400002-amv07144.state'))

In [26]:
# Build the model on the target device. The optimizer must be created after
# the model (it captures the parameters) and the scheduler after the optimizer.
net = Net(
    x_input_size=13,
    adj_input_size=3,
    hidden_size=hidden_size,
    output_size=9,
    batch_size=batch_size,
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)
# Mode "max": the scheduler is stepped with a validation accuracy.
# NOTE(review): only the network weights are checkpointed below; optimizer and
# scheduler state start fresh when a run is resumed.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", verbose=True)

In [27]:
# Restore the latest checkpoint, if there is one. Without the None check a
# resume request on an empty run directory would call `torch.load(None)`.
if CONTINUE_PREVIOUS:
    if last_state_file is None:
        print("No previous state file found; starting from freshly initialised weights.")
    else:
        # `map_location` loads the tensors straight onto the training device.
        net.load_state_dict(torch.load(last_state_file, map_location=device))
        print("Loaded network state file.")

Loaded network state file.


In [28]:
# Resume the counters only when a checkpoint was actually found; otherwise
# `last_epoch` / `last_step` are None and must not reach `Stats`.
_resuming = CONTINUE_PREVIOUS and last_state_file is not None
stats = Stats(
    epoch=last_epoch if _resuming else 0,
    step=last_step if _resuming else 0,
    batch_size=batch_size,
    # NOTE(review): `Stats` opens this file in "wt" mode, so resuming a run
    # truncates the existing log -- confirm that this is intended.
    filename=UNIQUE_PATH.joinpath("training.log"),
    echo=True,
    tb_writer=torch.utils.tensorboard.writer.SummaryWriter(
        # NOTE(review): events go to "<run id>.xxx", not to `tensorboard_path`
        # itself -- the suffix looks like a leftover placeholder; confirm.
        log_dir=tensorboard_path.with_suffix(".xxx"),
        purge_step=(last_datapoint if _resuming else None),
    ),
)
stats.write_header()

epoch             step              datapoint         avg_loss          accuracy          accuracy_m        accuracy_mv       accuracy_mv_old   time_elapsed      


In [None]:
# Fixed validation subsets (first 300 items each), materialised once.
datasets["sudoku_valid_0"].reset()
valid_0_data = list(itertools.islice(datasets["sudoku_valid_0"], 300))
valid_old_data = list(itertools.islice(datasets["sudoku_valid_old"], 300))
# Edge features of one new-format puzzle, reused for the old validation set
# below (presumably because that set carries no compatible `edge_attr` --
# TODO confirm).
tmp_data = valid_0_data[0].to(device)
edge_index = tmp_data.edge_index
edge_attr = tmp_data.edge_attr

net = net.train()
# NOTE(review): only "sudoku_train_0" .. "sudoku_train_9" are registered, so
# epoch 10 raises KeyError long before 100_000 is reached.
for epoch in range(stats.epoch, 100_000):
    stats.epoch = epoch
    train_dataloader = DataLoader(
        datasets[f"sudoku_train_{epoch}"],
        shuffle=False,
        num_workers=1,
        batch_size=batch_size,
        drop_last=True,
    )
    for data in train_dataloader:
        stats.step += 1
        # NOTE(review): `stats.step` already starts at `last_step` when
        # resuming, so this skip never triggers and a resumed epoch restarts
        # from its first batch -- confirm the intended resume behaviour.
        if CONTINUE_PREVIOUS and last_step is not None and stats.step <= last_step:
            continue

        optimizer.zero_grad()

        data = data.to(device)
        output = net(data.x, data.edge_index, data.edge_attr)

        loss = criterion(output, data.y)
        loss.backward()

        stats.total_loss += loss.detach().item()
        stats.num_steps += 1

        # Accuracy for all
        _, predicted = torch.max(output.data, 1)
        stats.num_correct_preds += (predicted == data.y).sum().item()
        stats.num_preds += len(predicted)

        # Accuracy for missing only
        num_correct, num_total = get_stats_on_missing(data.x, data.y, output)
        stats.num_correct_preds_missing += num_correct
        stats.num_preds_missing += num_total

        optimizer.step()

        # True on the first batch after each multiple of `info_size` datapoints.
        if (stats.datapoint % info_size) < batch_size:
            for data in valid_0_data:
                data = data.to(device)

                # Both context managers must be entered: `A and B` would only
                # enter `eval_net` and leave gradient tracking switched on.
                with torch.no_grad(), eval_net(net):
                    output = net(data.x, data.edge_index, data.edge_attr)

                num_correct, num_total = get_stats_on_missing(data.x, data.y, output)
                stats.num_correct_preds_missing_valid += num_correct
                stats.num_preds_missing_valid += num_total

            for data in valid_old_data:
                data = data.to(device)

                with torch.no_grad(), eval_net(net):
                    # Uses the shared `edge_attr` captured above, not `data.edge_attr`.
                    output = net(data.x, data.edge_index, edge_attr)

                num_correct, num_total = get_stats_on_missing(data.x, data.y, output)
                stats.num_correct_preds_missing_valid_old += num_correct
                stats.num_preds_missing_valid_old += num_total

            stats.write_row()
            stats.reset_parameters()

        # `checkpoint_size` is a multiple of `info_size`, so `stats.prev` has
        # just been refreshed whenever this condition holds.
        if (stats.datapoint % checkpoint_size) < batch_size:
            output_filename = (
                f"e{stats.epoch}-s{stats.step}-d{stats.datapoint}"
                f"-amv{str(round(stats.prev['accuracy_mv'], 4)).replace('.', '')}.state"
            )
            torch.save(net.state_dict(), UNIQUE_PATH.joinpath(output_filename))

    # End of epoch: adapt the learning rate and always save a checkpoint.
    scheduler.step(stats.prev["accuracy_mv"])
    output_filename = (
        f"e{stats.epoch}-s{stats.step}-d{stats.datapoint}"
        f"-amv{str(round(stats.prev['accuracy_mv'], 4)).replace('.', '')}.state"
    )
    torch.save(net.state_dict(), UNIQUE_PATH.joinpath(output_filename))

0                 72534             435204            0.4503            0.8049            0.7045            0.7195            1.0               35.2388           