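## Summary

Train a graph neural network (`Net`: four residual edge-convolution layers) to fill in the missing cells of Sudoku puzzles from `proteinsolver.datasets.SudokuDataset2`. Training progress is logged to `sudoku_train/training.log`. The log records loss, overall accuracy, and accuracy on missing cells for both the training batches and a validation subset. A model checkpoint is saved after every epoch.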

## Install dependencies (Google Colab only)

In [1]:
# Detect whether the notebook runs inside Google Colab: the notebook assumes the
# `google.colab` package can only be imported there.
try:
    import google.colab  # noqa: F401  (imported only to probe the environment)
except ImportError:
    GOOGLE_COLAB = False
else:
    GOOGLE_COLAB = True

In [None]:
if GOOGLE_COLAB:
    # Install PyTorch Geometric and its extension packages (Colab only).
    # NOTE(review): versions are unpinned; these extensions presumably need to match the
    # installed torch/CUDA versions, so consider pinning them (and using `%pip`, which
    # targets the running kernel's environment).
    !pip install --upgrade torch-scatter
    !pip install --upgrade torch-sparse
    !pip install --upgrade torch-cluster
    !pip install --upgrade torch-spline-conv
    !pip install torch-geometric

In [None]:
if GOOGLE_COLAB:
    # Install proteinsolver (provides the Sudoku datasets used below) from its Git repository.
    !pip install git+https://gitlab.com/ostrokach/proteinsolver.git

## Imports

In [1]:
import atexit
import csv
import os
import tempfile
import time
import warnings
from collections import deque
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [2]:
import proteinsolver
import proteinsolver.datasets

ModuleNotFoundError: No module named 'proteinsolver'

In [None]:
# Automatically re-import edited modules (e.g. proteinsolver) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
# Training below runs on the GPU; fail fast with a clear message if none is visible.
assert torch.cuda.is_available(), "CUDA is not available; this notebook requires a GPU"
# Device that the model and every data batch are moved to in the training cell
# (it was previously referenced there without ever being defined).
device = torch.device("cuda")

# Datasets

In [None]:
# Root directory for the datasets. Set the DATA_ROOT environment variable to reuse an
# existing data folder (e.g. `DATA_ROOT=~/ml_data`). Otherwise the system temporary
# directory is used, so the notebook also runs on machines without a local data
# folder (e.g. Google Colab).
DATA_ROOT = Path(os.environ.get("DATA_ROOT", tempfile.gettempdir())).expanduser()
# `parents=True` so a nested DATA_ROOT that does not exist yet is created in full.
DATA_ROOT.mkdir(parents=True, exist_ok=True)
DATA_ROOT

In [None]:
# Per-notebook output directory, relative to the current working directory; the
# training log and the per-epoch model checkpoints are written here.
NOTEBOOK_NAME = "sudoku_train"
NOTEBOOK_PATH = Path().joinpath(NOTEBOOK_NAME)
NOTEBOOK_PATH.mkdir(exist_ok=True)
NOTEBOOK_PATH

## `SudokuDataset`

In [None]:
# Training split, stored under DATA_ROOT/sudoku_train.
sudoku_dataset_train = proteinsolver.datasets.SudokuDataset2(
    root=DATA_ROOT / "sudoku_train",
    subset="sudoku_train",
)

In [None]:
# Validation split, stored under DATA_ROOT/sudoku_valid.
sudoku_dataset_valid = proteinsolver.datasets.SudokuDataset2(
    root=DATA_ROOT / "sudoku_valid",
    subset="sudoku_valid",
)

## `TUDataset`

In [None]:
from torch_geometric.datasets import TUDataset

In [None]:
# The ENZYMES TUDataset, stored under <system temp dir>/ENZYMES.
tu_dataset = TUDataset(root=str(Path(tempfile.gettempdir(), "ENZYMES")), name="ENZYMES")

In [None]:
# Display the first ENZYMES graph as a quick sanity check of the dataset contents.
tu_dataset[0]

# Models

In [None]:
import torch
import torch.nn.functional as F
from torch import optim
from torch_geometric.data import DataLoader
from torch_geometric.nn import ChebConv, EdgeConv, GATConv, GCNConv
from torch_geometric.nn.inits import reset
from torch_geometric.utils import add_self_loops, remove_self_loops, scatter_

In [None]:
class EdgeConvMod(torch.nn.Module):
    """EdgeConv-style layer returning both aggregated node features and per-edge messages.

    For every edge ``(row, col)`` of ``edge_index`` the message is
    ``nn(cat([x[row], x[col] - x[row], edge_attr]))`` (``edge_attr`` only when given),
    and messages are aggregated onto their ``row`` node with ``aggr``.

    Args:
        nn: Module applied to the concatenated per-edge features.
        aggr: Aggregation name passed to ``scatter_`` (default ``"max"``).
    """

    def __init__(self, nn, aggr="max"):
        super().__init__()
        self.nn = nn
        self.aggr = aggr
        self.reset_parameters()

    def reset_parameters(self):
        """Delegate parameter (re)initialization of the wrapped module to ``reset``."""
        reset(self.nn)

    def forward(self, x, edge_index, edge_attr=None):
        """Return ``(aggregated_node_features, edge_messages)``."""
        center, neighbor = edge_index
        if x.dim() == 1:
            # A flat vector is treated as a single feature per node.
            x = x.unsqueeze(-1)

        x_center = x[center]
        # TODO: Try -x[neighbor] instead of x[neighbor] - x[center]
        pieces = [x_center, x[neighbor] - x_center]
        if edge_attr is not None:
            pieces.append(edge_attr)
        messages = self.nn(torch.cat(pieces, dim=-1))

        aggregated = scatter_(self.aggr, messages, center, dim_size=x.size(0))
        return aggregated, messages

    def __repr__(self):
        return f"{self.__class__.__name__}(nn={self.nn})"

In [None]:
class EdgeConvBatch(nn.Module):
    """Wrap a graph layer returning ``(x, edge_attr)`` and post-process both outputs.

    Each output is passed through its own optional ``BatchNorm1d(hidden_size)``
    followed by an optional ``Dropout(dropout)``.

    Args:
        gnn: Module called as ``gnn(x, edge_index, edge_attr)`` that returns a
            ``(node_features, edge_features)`` pair.
        hidden_size: Feature size of both outputs (used for batch normalization).
        batch_norm: Whether to batch-normalize both outputs; any falsy value
            (``False``, ``None``) disables it.
        dropout: Dropout probability; a falsy value disables dropout.
    """

    def __init__(self, gnn, hidden_size, batch_norm=True, dropout=0.2):
        super().__init__()

        self.gnn = gnn

        x_post_modules = []
        edge_attr_post_modules = []

        # Truthiness check, so that ``batch_norm=False`` really disables batch norm
        # (an ``is not None`` check would also be true for False).
        if batch_norm:
            x_post_modules.append(nn.BatchNorm1d(hidden_size))
            edge_attr_post_modules.append(nn.BatchNorm1d(hidden_size))

        if dropout:
            x_post_modules.append(nn.Dropout(dropout))
            edge_attr_post_modules.append(nn.Dropout(dropout))

        # An empty Sequential acts as the identity.
        self.x_postprocess = nn.Sequential(*x_post_modules)
        self.edge_attr_postprocess = nn.Sequential(*edge_attr_post_modules)

    def forward(self, x, edge_index, edge_attr=None):
        """Run ``gnn`` and return its post-processed ``(x, edge_attr)`` outputs."""
        x, edge_attr = self.gnn(x, edge_index, edge_attr)
        x = self.x_postprocess(x)
        edge_attr = self.edge_attr_postprocess(edge_attr)
        return x, edge_attr

In [None]:
def get_graph_conv_layer(input_size, hidden_size, output_size):
    """Build one graph layer: an ``EdgeConvMod`` wrapped in ``EdgeConvBatch``.

    The per-edge message MLP is ``Linear(input_size, hidden_size) -> ReLU ->
    Linear(hidden_size, output_size)``. Messages are aggregated with ``aggr="add"``,
    and both outputs get batch norm and dropout 0.2.
    """
    message_mlp = nn.Sequential(
        nn.Linear(input_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, output_size),
    )
    edge_conv = EdgeConvMod(nn=message_mlp, aggr="add")
    return EdgeConvBatch(edge_conv, output_size, batch_norm=True, dropout=0.2)

In [None]:
class Net(nn.Module):
    """Graph network that predicts one of ``output_size`` classes for every node.

    Node labels are embedded, optional edge features are embedded, and both are refined
    by four residual graph-convolution layers before a final linear projection
    (in this notebook: Sudoku cells with 10 input symbols and 9 output digits).

    Args:
        x_input_size: Number of distinct node labels (size of the embedding table).
        adj_input_size: Number of edge features, or a falsy value (e.g. ``None``) when
            edges carry no features.
        hidden_size: Width of the node and edge representations.
        output_size: Number of output classes per node.
    """

    def __init__(self, x_input_size, adj_input_size, hidden_size, output_size):
        super().__init__()

        self.embed_x = nn.Sequential(
            nn.Embedding(x_input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
        )

        if adj_input_size:
            self.embed_adj = nn.Sequential(
                nn.BatchNorm1d(adj_input_size),
                nn.Linear(adj_input_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
            )
        else:
            self.embed_adj = None

        # EdgeConvMod concatenates [x_i, x_j - x_i, edge_attr?] per edge. The first layer
        # sees edge_attr only if edges have features; later layers always see the edge
        # representation produced by the previous layer, hence 3 * hidden_size.
        self.graph_conv_1 = get_graph_conv_layer((2 + bool(adj_input_size)) * hidden_size, 2 * hidden_size, hidden_size)
        self.graph_conv_2 = get_graph_conv_layer(3 * hidden_size, 2 * hidden_size, hidden_size)
        self.graph_conv_3 = get_graph_conv_layer(3 * hidden_size, 2 * hidden_size, hidden_size)
        self.graph_conv_4 = get_graph_conv_layer(3 * hidden_size, 2 * hidden_size, hidden_size)
        self.linear_out = nn.Linear(hidden_size, output_size)

    def forward(self, x, edge_index, edge_attr=None):
        """Return per-node logits (last dimension ``output_size``).

        Raises:
            ValueError: If ``edge_attr`` is given but the model was built without
                edge features (``adj_input_size`` falsy).
        """
        x = self.embed_x(x)
        # Self loops are removed (for them x[col] - x[row] is always zero). edge_attr is
        # filtered with the same mask so that it stays aligned with edge_index.
        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        if edge_attr is not None:
            if self.embed_adj is None:
                raise ValueError("edge_attr was provided but the model was built with adj_input_size=None")
            edge_attr = self.embed_adj(edge_attr)

        graph_convs = (self.graph_conv_1, self.graph_conv_2, self.graph_conv_3, self.graph_conv_4)
        for i, graph_conv in enumerate(graph_convs):
            if i > 0:
                x = F.relu(x)
                edge_attr = F.relu(edge_attr)
            x_out, edge_attr_out = graph_conv(x, edge_index, edge_attr)
            # Out-of-place residual updates: ``x += x_out`` would modify ReLU outputs that
            # autograd saved for the backward pass and make ``loss.backward()`` fail.
            x = x + x_out
            edge_attr = edge_attr_out if edge_attr is None else edge_attr + edge_attr_out

        return self.linear_out(x)

In [None]:
def to_fixed_width(lst, precision=None):
    """Render each item of ``lst`` as a left-aligned string padded to 18 characters.

    Float items (including ``np.float64``) are first passed through
    ``round(item, precision)``; other items are formatted unchanged.
    """
    rounded = [round(item, precision) if isinstance(item, float) else item for item in lst]
    return [f"{item: <18}" for item in rounded]


class Stats:
    """Accumulate training metrics and report them to stdout and/or a CSV log.

    The caller increments the counters between reports. ``write_row`` emits the metrics
    of the current reporting window, and ``reset_parameters`` starts a new window while
    keeping the finished window's metrics in ``prev``.

    Ratios use ``np.float64`` arithmetic, so an empty window yields ``nan`` instead of
    raising ``ZeroDivisionError`` (the warning is suppressed when read via ``stats``).
    """

    epoch: int
    step: int
    batch_size: int
    echo: bool
    num_steps: int
    total_loss: float
    num_correct_preds: int
    num_preds: int
    num_correct_preds_missing: int
    num_preds_missing: int
    num_correct_preds_missing_valid: int
    num_preds_missing_valid: int
    start_time: float

    def __init__(self, *, epoch=0, step=0, batch_size=1, filename=None, echo=True):
        self.epoch = epoch
        self.step = step
        self.batch_size = batch_size
        self.echo = echo
        self.stats_columns = [
            "epoch",
            "step",
            "datapoint",
            "avg_loss",
            "accuracy",
            "accuracy_m",
            "accuracy_mv",
            "time_elapsed",
        ]
        self.reset_parameters()

        # Tracks whether the CSV header exists, so that calling ``write_header`` after
        # construction does not add a second header row to the log.
        self._csv_header_written = False
        if filename:
            self.filehandle = open(filename, "wt", newline="")
            self.writer = csv.DictWriter(self.filehandle, list(self.stats_columns), dialect="unix")
            self._write_csv_header()
            atexit.register(self.filehandle.close)
        else:
            self.filehandle = None
            self.writer = None

    def reset_parameters(self):
        """Start a new reporting window, saving the current metrics in ``prev``."""
        try:
            self.prev = self.stats
        except AttributeError:
            # First call (from __init__): the counters do not exist yet.
            self.prev = {}
        self.num_steps = 0
        self.total_loss = 0
        self.num_correct_preds = 0
        self.num_preds = 0
        self.num_correct_preds_missing = 0
        self.num_preds_missing = 0
        self.num_correct_preds_missing_valid = 0
        self.num_preds_missing_valid = 0
        self.start_time = time.perf_counter()

    @property
    def header(self):
        return "".join(to_fixed_width(self.stats_columns))

    @property
    def row(self):
        return "".join(to_fixed_width(self.stats.values(), 4))

    @property
    def stats(self):
        """Current metrics as a ``{column: value}`` dict, in ``stats_columns`` order."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return {k: getattr(self, k) for k in self.stats_columns}

    @property
    def datapoint(self):
        return self.step * self.batch_size

    @property
    def avg_loss(self):
        return np.float64(1) * self.total_loss / self.num_steps

    @property
    def accuracy(self):
        return np.float64(1) * self.num_correct_preds / self.num_preds

    @property
    def accuracy_m(self):
        return np.float64(1) * self.num_correct_preds_missing / self.num_preds_missing

    @property
    def accuracy_mv(self):
        return np.float64(1) * self.num_correct_preds_missing_valid / self.num_preds_missing_valid

    @property
    def time_elapsed(self):
        return time.perf_counter() - self.start_time

    def _write_csv_header(self):
        """Write the CSV header row once (no-op without a file or if already written)."""
        if self.writer is not None and not self._csv_header_written:
            self.writer.writeheader()
            self.filehandle.flush()
            self._csv_header_written = True

    def write_header(self):
        """Print the column header (if ``echo``) and ensure the CSV log has its header."""
        if self.echo:
            print(self.header)
        self._write_csv_header()

    def write_row(self):
        """Emit the current metrics to stdout (if ``echo``) and to the CSV log."""
        # One snapshot, so that stdout and the CSV report identical values.
        values = self.stats
        if self.echo:
            print("".join(to_fixed_width(values.values(), 4)))
        if self.writer is not None:
            self.writer.writerow(values)
            # Flush so the log can be inspected while a long training run is in progress.
            self.filehandle.flush()

In [None]:
def get_stats_on_missing(x, y, output, missing_value=9):
    """Count correct predictions on nodes whose input label equals ``missing_value``.

    Args:
        x: Input node labels, one per node (e.g. shape ``(N,)`` or ``(N, 1)``).
        y: Target class per node.
        output: Per-node class scores; the predicted class is the argmax over dim 1.
        missing_value: Label marking a missing (to-be-predicted) cell. Defaults to 9,
            presumably the extra token of this notebook's 10-symbol input vocabulary.

    Returns:
        ``(num_correct, num_missing)`` as Python ints; ``(0, 0)`` if nothing is missing.
    """
    # ``reshape(-1)`` rather than ``squeeze()``: squeeze turns a single-node mask into a
    # 0-d tensor, which changes how ``output[mask]`` indexes.
    mask = (x == missing_value).reshape(-1)
    if not mask.any():
        return 0, 0
    _, predicted_missing = torch.max(output[mask].detach(), dim=1)
    num_correct = (predicted_missing == y[mask]).sum().item()
    return num_correct, len(predicted_missing)


from contextlib import contextmanager


@contextmanager
def eval_net(net: nn.Module):
    """Temporarily put ``net`` in evaluation mode, restoring its previous mode on exit.

    This does not disable gradient tracking; combine it with ``torch.no_grad()``.
    """
    was_training = net.training
    try:
        net.train(False)
        yield
    finally:
        net.train(was_training)

In [None]:
# Hyper-parameters shared with the training loop below.
batch_size = 128
info_size = 2_000  # validate and log a row roughly every `info_size` training examples

# NOTE: PyTorch Geometric's DataLoader; it shadows `torch.utils.data.DataLoader`
# imported at the top of the notebook.
from torch_geometric.data import DataLoader

dataloaders = dict(
    train=DataLoader(
        sudoku_dataset_train,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=4,
    ),
    # Only the first 300 validation puzzles (presumably to keep periodic validation cheap).
    valid=DataLoader(
        sudoku_dataset_valid[:300],
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=4,
    ),
)

In [None]:
continue_previous = False


def _predict(model, batch):
    """Run ``model`` on a graph batch, passing ``edge_attr`` when the batch has it."""
    return model(batch.x, batch.edge_index, getattr(batch, "edge_attr", None))


def _validate_missing(model, loader):
    """Return ``(num_correct, num_total)`` on missing cells over all batches in ``loader``.

    Runs with the model in evaluation mode and gradient tracking disabled.
    """
    total_correct, total_missing = 0, 0
    # Two context managers separated by a comma: ``a and b`` would enter only ``b``,
    # leaving autograd enabled during validation.
    with torch.no_grad(), eval_net(model):
        for valid_batch in loader:
            valid_batch = valid_batch.to(device)
            output = _predict(model, valid_batch)
            num_correct, num_total = get_stats_on_missing(valid_batch.x, valid_batch.y, output)
            total_correct += num_correct
            total_missing += num_total
    return total_correct, total_missing


if not continue_previous:
    net = Net(x_input_size=10, adj_input_size=None, hidden_size=128, output_size=9).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", verbose=True)

    stats = Stats(epoch=0, step=0, batch_size=batch_size, filename=NOTEBOOK_PATH / "training.log", echo=True)
    stats.write_header()

net = net.train()
for epoch in range(stats.epoch + 1 if continue_previous else 0, 100_000):
    stats.epoch = epoch
    for data in dataloaders["train"]:
        stats.step += 1
        optimizer.zero_grad()

        data = data.to(device)
        output = _predict(net, data)

        loss = criterion(output, data.y)
        loss.backward()

        stats.total_loss += loss.detach().item()
        stats.num_steps += 1

        # Accuracy over all cells
        _, predicted = torch.max(output.detach(), 1)
        stats.num_correct_preds += (predicted == data.y).sum().item()
        stats.num_preds += len(predicted)

        # Accuracy over missing cells only
        num_correct, num_total = get_stats_on_missing(data.x, data.y, output)
        stats.num_correct_preds_missing += num_correct
        stats.num_preds_missing += num_total

        optimizer.step()

        # Periodically validate, report, and start a new reporting window.
        if (stats.datapoint % info_size) < batch_size:
            num_correct, num_total = _validate_missing(net, dataloaders["valid"])
            stats.num_correct_preds_missing_valid += num_correct
            stats.num_preds_missing_valid += num_total

            stats.write_row()
            stats.reset_parameters()

    # NOTE: `stats.prev` is only populated after the first report, so the first epoch
    # must contain at least one reporting step.
    accuracy_mv = stats.prev["accuracy_mv"]
    scheduler.step(accuracy_mv)
    output_filename = (
        f"e{stats.epoch}-s{stats.step}-d{stats.datapoint}"
        f"-amv{str(round(accuracy_mv, 4)).replace('.', '')}.state"
    )
    torch.save(net.state_dict(), NOTEBOOK_PATH.joinpath(output_filename))