## Summary


----

## Install dependencies (Google Colab only)

In [1]:
# Probe for the Colab runtime: the import only succeeds inside Google Colab.
try:
    import google.colab
except ImportError:
    GOOGLE_COLAB = False
else:
    GOOGLE_COLAB = True

In [2]:
if GOOGLE_COLAB:
    !pip install --upgrade torch-scatter

In [3]:
if GOOGLE_COLAB:
    !pip install --upgrade torch-sparse

In [4]:
if GOOGLE_COLAB:
    !pip install --upgrade torch-cluster

In [5]:
if GOOGLE_COLAB:
    !pip install --upgrade torch-spline-conv

In [6]:
if GOOGLE_COLAB:
    !pip install torch-geometric

In [7]:
if GOOGLE_COLAB:
    !pip install git+https://gitlab.com/ostrokach/proteinsolver.git

## Imports

In [26]:
import atexit
import csv
import tempfile
import time
import uuid
import warnings
from collections import deque
from contextlib import contextmanager
from pathlib import Path

import numpy as np

import matplotlib.pyplot as plt
import pandas as pd
import pyarrow
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torch_geometric.data import DataLoader
from torch_geometric.nn import ChebConv, EdgeConv, GATConv, GCNConv
from torch_geometric.nn.inits import reset
from torch_geometric.utils import add_self_loops, remove_self_loops, scatter_

In [9]:
import proteinsolver
import proteinsolver.datasets



In [10]:
# %load_ext autoreload
# %autoreload 2

In [11]:
# Fail fast, with an explanation, when PyTorch cannot see any CUDA device.
assert torch.cuda.is_available(), "CUDA is required, but PyTorch cannot see any GPU."

# Properties

In [27]:
# Name of the output directory for this notebook; it is resolved against the
# current working directory and created by the next cell.
NOTEBOOK_NAME = "protein_4xEdgeConv_bs4"

In [30]:
# Create the output directories only once per kernel session: when both names
# already exist (the cell is being re-run) the previous paths are kept, so no
# new unique sub-directory is generated.
if "NOTEBOOK_PATH" not in globals() or "UNIQUE_PATH" not in globals():
    NOTEBOOK_PATH = Path(NOTEBOOK_NAME).resolve()
    NOTEBOOK_PATH.mkdir(exist_ok=True)
    # Short random suffix that separates the artifacts of different sessions.
    unique_id = uuid.uuid4().hex[:8]
    UNIQUE_PATH = NOTEBOOK_PATH / unique_id
    UNIQUE_PATH.mkdir()
NOTEBOOK_PATH, UNIQUE_PATH

(PosixPath('/home/kimlab1/strokach/workspace/proteinsolver/notebooks/protein_4xEdgeConv_bs4'),
 PosixPath('/home/kimlab1/strokach/workspace/proteinsolver/notebooks/protein_4xEdgeConv_bs4/6a84866a'))

# Datasets

In [14]:
# Root directory under which every dataset below stores its files.
# The temp-dir default used to be overwritten unconditionally by the
# workstation path, so it never took effect. It is now used on Colab.
# NOTE(review): assumes the workstation path is not available on Colab - confirm.
if GOOGLE_COLAB:
    DATA_ROOT = Path(tempfile.gettempdir())
else:
    DATA_ROOT = Path("/home/strokach/ml_data")
DATA_ROOT.mkdir(exist_ok=True)
DATA_ROOT

PosixPath('/home/strokach/ml_data')

In [15]:
# "train_0" subset, rooted at DATA_ROOT/protein_train_0.
protein_dataset_train_0 = proteinsolver.datasets.ProteinDataset(
    root=DATA_ROOT.joinpath("protein_train_0"),
    subset="train_0",
)

In [16]:
# "train_1" subset, rooted at DATA_ROOT/protein_train_1.
protein_dataset_train_1 = proteinsolver.datasets.ProteinDataset(
    root=DATA_ROOT.joinpath("protein_train_1"),
    subset="train_1",
)

In [17]:
# protein_dataset_train_2 = proteinsolver.datasets.ProteinDataset(root=DATA_ROOT / "protein_train_2", subset="train_2")

In [18]:
# "valid" subset, rooted at DATA_ROOT/protein_valid.
protein_dataset_valid = proteinsolver.datasets.ProteinInMemoryDataset(
    root=DATA_ROOT.joinpath("protein_valid"),
    subset="valid",
)

In [19]:
# "test" subset, rooted at DATA_ROOT/protein_test.
protein_dataset_test = proteinsolver.datasets.ProteinInMemoryDataset(
    root=DATA_ROOT.joinpath("protein_test"),
    subset="test",
)

In [20]:
# Parquet file handed to the "gfp" subset as its `data_url`.
# The literal is split across lines only for readability.
file = (
    "/home/kimlab1/database_data/datapkg_output_dir/adjacency-net-v2/master/"
    "validation_dataset_wdistances/adjacency_matrix.parquet/"
    "database_id=G3DSA%3A2.40.155.10/"
    "part-00000-d5e89475-69dd-45c6-9a80-5bf751fce422-c000.snappy.parquet"
)
protein_dataset_gfp = proteinsolver.datasets.ProteinInMemoryDataset(
    root=DATA_ROOT.joinpath("protein_gfp"),
    subset="gfp",
    data_url=file,
)

## `SudokuDataset`

In [21]:
# "train" subset of the sudoku dataset, rooted at DATA_ROOT/sudoku_train.
sudoku_dataset_train = proteinsolver.datasets.SudokuDataset2(
    root=DATA_ROOT / "sudoku_train",
    subset="train",
)

In [22]:
# "valid" subset of the sudoku dataset, rooted at DATA_ROOT/sudoku_valid.
sudoku_dataset_valid = proteinsolver.datasets.SudokuDataset2(
    root=DATA_ROOT / "sudoku_valid",
    subset="valid",
)

## `TUDataset`

In [23]:
from torch_geometric.datasets import TUDataset

In [24]:
# ENZYMES dataset from the TU collection, rooted in the system temp directory.
tu_dataset = TUDataset(
    root=f"{tempfile.gettempdir()}/ENZYMES",
    name="ENZYMES",
)

In [25]:
# Display the first element of the dataset.
tu_dataset[0]

Data(edge_index=[2, 168], x=[37, 3], y=[1])

# Models

In [78]:
%%file {UNIQUE_PATH}/model.py
import atexit
import csv
import tempfile
import time
import warnings

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.inits import reset
from torch_geometric.utils import add_self_loops, remove_self_loops, scatter_


class EdgeConvMod(torch.nn.Module):
    """Edge convolution that returns node features and per-edge messages.

    ``nn`` is applied to the concatenation of the features of both end points
    of every edge (followed by ``edge_attr`` when it is given). The resulting
    per-edge messages are aggregated with ``aggr`` onto the nodes indexed by
    the first row of ``edge_index``.
    """

    def __init__(self, nn, aggr="max"):
        super().__init__()
        self.nn = nn
        self.aggr = aggr
        self.reset_parameters()

    def reset_parameters(self):
        """Hand the wrapped network to torch_geometric's ``reset`` helper."""
        reset(self.nn)

    def forward(self, x, edge_index, edge_attr=None):
        """Return ``(aggregated_node_features, edge_messages)``."""
        node_i, node_j = edge_index
        # Promote a 1-D feature vector to a single-column feature matrix.
        if x.dim() == 1:
            x = x.unsqueeze(-1)

        # TODO: Try -x[col] instead of x[col] - x[row]
        pieces = [x[node_i], x[node_j]]
        if edge_attr is not None:
            pieces.append(edge_attr)
        messages = self.nn(torch.cat(pieces, dim=-1))

        # One output row per node, even for nodes without any edge.
        aggregated = scatter_(self.aggr, messages, node_i, dim_size=x.size(0))
        return aggregated, messages

    def __repr__(self):
        return f"{type(self).__name__}(nn={self.nn})"


class EdgeConvBatch(nn.Module):
    """Wrap a graph layer with optional normalisation and dropout.

    The wrapped ``gnn`` must return ``(x, edge_attr)``. Both outputs go through
    their own post-processing pipeline: ``LayerNorm(hidden_size)`` when
    ``batch_norm`` is truthy, followed by ``Dropout(dropout)`` when ``dropout``
    is truthy.

    Args:
        gnn: Callable invoked as ``gnn(x, edge_index, edge_attr)``.
        hidden_size: Size of the last dimension of both outputs of ``gnn``.
        batch_norm: Enables normalisation. Despite the name, layer
            normalisation is used. ``False`` and ``None`` both disable it.
        dropout: Dropout probability; a falsy value disables dropout.
    """

    def __init__(self, gnn, hidden_size, batch_norm=True, dropout=0.2):
        super().__init__()

        self.gnn = gnn

        x_post_modules = []
        edge_attr_post_modules = []

        # Truthiness test: the former `is not None` check added the
        # normalisation layers even when the caller passed `batch_norm=False`.
        if batch_norm:
            x_post_modules.append(nn.LayerNorm(hidden_size))
            edge_attr_post_modules.append(nn.LayerNorm(hidden_size))

        if dropout:
            x_post_modules.append(nn.Dropout(dropout))
            edge_attr_post_modules.append(nn.Dropout(dropout))

        # An empty `nn.Sequential` returns its input unchanged.
        self.x_postprocess = nn.Sequential(*x_post_modules)
        self.edge_attr_postprocess = nn.Sequential(*edge_attr_post_modules)

    def forward(self, x, edge_index, edge_attr=None):
        """Run ``gnn`` and post-process both of its outputs."""
        x, edge_attr = self.gnn(x, edge_index, edge_attr)
        x = self.x_postprocess(x)
        edge_attr = self.edge_attr_postprocess(edge_attr)
        return x, edge_attr


def get_graph_conv_layer(input_size, hidden_size, output_size):
    """Build one graph-convolution block.

    The block is a Linear -> ReLU -> Linear network inside an ``EdgeConvMod``
    with "add" aggregation, wrapped in an ``EdgeConvBatch`` of width
    ``output_size`` with normalisation enabled and a dropout of 0.2.
    """
    layers = [
        nn.Linear(input_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, output_size),
    ]
    edge_conv = EdgeConvMod(nn=nn.Sequential(*layers), aggr="add")
    return EdgeConvBatch(edge_conv, output_size, batch_norm=True, dropout=0.2)


class Net(nn.Module):
    """Four stacked edge-convolution blocks with residual connections.

    Node labels are embedded with ``embed_x``; edge attributes (when
    ``adj_input_size`` is truthy) with ``embed_adj``. Every graph-convolution
    block updates both node and edge features, and its output is added to its
    input. ``linear_out`` maps the final node features to ``output_size``
    values per node.

    Args:
        x_input_size: Number of distinct node labels (embedding vocabulary).
        adj_input_size: Number of raw features per edge; a falsy value builds
            a network without edge-attribute embedding.
        hidden_size: Width of node and edge representations.
        output_size: Number of outputs per node.
    """

    def __init__(self, x_input_size, adj_input_size, hidden_size, output_size):
        super().__init__()

        self.embed_x = nn.Sequential(
            nn.Embedding(x_input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
            #             nn.ReLU(),
        )

        if adj_input_size:
            self.embed_adj = nn.Sequential(
                nn.Linear(adj_input_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.LayerNorm(hidden_size),
                #                 nn.ELU(),
            )
        else:
            self.embed_adj = None

        # The first block sees [x_i, x_j] plus the embedded edge attributes
        # when they exist; later blocks always see [x_i, x_j, edge features].
        # Attribute names are kept as-is so saved state dicts still load.
        self.graph_conv_1 = get_graph_conv_layer((2 + bool(adj_input_size)) * hidden_size, 2 * hidden_size, hidden_size)
        self.graph_conv_2 = get_graph_conv_layer(3 * hidden_size, 2 * hidden_size, hidden_size)
        self.graph_conv_3 = get_graph_conv_layer(3 * hidden_size, 2 * hidden_size, hidden_size)
        self.graph_conv_4 = get_graph_conv_layer(3 * hidden_size, 2 * hidden_size, hidden_size)
        self.linear_out = nn.Linear(hidden_size, output_size)

    def forward(self, x, edge_index, edge_attr=None):
        """Return per-node outputs of size ``output_size``."""
        x = self.embed_x(x)
        # Remove self loops from the edge list AND from the matching edge
        # attributes. Dropping the filtered attributes (as was done before)
        # leaves `edge_attr` with more rows than `edge_index` has columns
        # whenever the input contains a self loop.
        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        edge_attr = self.embed_adj(edge_attr) if edge_attr is not None else None

        x_out, edge_attr_out = self.graph_conv_1(x, edge_index, edge_attr)
        x = x + x_out
        # Without input edge attributes the first block's messages become the
        # edge representation used by the following blocks.
        edge_attr = (edge_attr + edge_attr_out) if edge_attr is not None else edge_attr_out

        # Remaining blocks: ReLU on both streams, convolution, residual add.
        for graph_conv in (self.graph_conv_2, self.graph_conv_3, self.graph_conv_4):
            x = F.relu(x)
            edge_attr = F.relu(edge_attr)
            x_out, edge_attr_out = graph_conv(x, edge_index, edge_attr)
            x = x + x_out
            edge_attr = edge_attr + edge_attr_out

        return self.linear_out(x)


def to_fixed_width(lst, precision=None):
    """Format every item as a left-aligned, space-padded 18-character cell.

    Floats are rounded to ``precision`` digits first; with the default
    ``precision=None`` they are rounded to the nearest integer. Other values
    are formatted unchanged.
    """
    cells = []
    for item in lst:
        if isinstance(item, float):
            item = round(item, precision)
        cells.append(format(item, " <18"))
    return cells


class Stats:
    """Accumulate training statistics and report them to stdout and/or a CSV file.

    Counters are updated directly by the caller (``stats.num_preds += ...``);
    ``reset_parameters`` clears them and restarts the timer.

    Args:
        epoch: Initial epoch number.
        step: Initial step number.
        batch_size: Multiplied by ``step`` to obtain ``datapoint``.
        filename: Optional path of a CSV log. The file is opened in write mode
            (truncating existing content) and closed at interpreter exit.
        echo: Whether ``write_header`` / ``write_row`` also print to stdout.
    """

    epoch: int
    step: int
    batch_size: int
    echo: bool
    num_steps: int
    total_loss: float
    num_correct_preds: int
    num_preds: int
    num_correct_preds_missing: int
    num_preds_missing: int
    num_correct_preds_missing_valid: int
    num_preds_missing_valid: int
    start_time: float

    def __init__(self, *, epoch=0, step=0, batch_size=1, filename=None, echo=True):
        self.epoch = epoch
        self.step = step
        self.batch_size = batch_size
        self.echo = echo
        self.reset_parameters()

        if filename:
            self.filehandle = open(filename, "wt", newline="")
            self.writer = csv.DictWriter(self.filehandle, list(self.stats.keys()), dialect="unix")
            # The CSV header is written here, exactly once per file.
            self.writer.writeheader()
            self.filehandle.flush()
            atexit.register(self.filehandle.close)
        else:
            self.filehandle = None
            self.writer = None

    def reset_parameters(self):
        """Zero all counters and restart the elapsed-time clock."""
        self.num_steps = 0
        self.total_loss = 0
        self.num_correct_preds = 0
        self.num_preds = 0
        self.num_correct_preds_missing = 0
        self.num_preds_missing = 0
        self.num_correct_preds_missing_valid = 0
        self.num_preds_missing_valid = 0
        self.start_time = time.perf_counter()

    @property
    def header(self):
        """Column names as one fixed-width line."""
        return "".join(to_fixed_width(self.stats.keys()))

    @property
    def row(self):
        """Current values as one fixed-width line (floats rounded to 4 digits)."""
        return "".join(to_fixed_width(self.stats.values(), 4))

    @property
    def stats(self):
        """Current statistics as an ordered ``dict``.

        Ratios are computed with ``np.float64`` and warnings are suppressed,
        so a zero denominator yields a non-finite value instead of raising
        ``ZeroDivisionError``.
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return {
                "epoch": self.epoch,
                "step": self.step,
                "datapoint": self.datapoint,
                "avg_loss": np.float64(1) * self.total_loss / self.num_steps,
                "accuracy": np.float64(1) * self.num_correct_preds / self.num_preds,
                "accuracy_m": np.float64(1) * self.num_correct_preds_missing / self.num_preds_missing,
                "accuracy_mv": self.accuracy_mv,
                "time_elapsed": time.perf_counter() - self.start_time,
            }

    @property
    def accuracy_mv(self):
        """Ratio of ``num_correct_preds_missing_valid`` to ``num_preds_missing_valid``."""
        return np.float64(1) * self.num_correct_preds_missing_valid / self.num_preds_missing_valid

    @property
    def datapoint(self):
        """``step * batch_size``."""
        return self.step * self.batch_size

    def write_header(self):
        """Print the header when ``echo`` is set.

        The CSV file already received its header in ``__init__``; writing it
        again here used to put a duplicate header row into the log.
        """
        if self.echo:
            print(self.header)

    def write_row(self):
        """Print the current row when ``echo`` is set and append it to the CSV log."""
        if self.echo:
            print(self.row)
        if self.writer is not None:
            self.writer.writerow(self.stats)
            # Flush so the log can be inspected while training is running.
            self.filehandle.flush()

Overwriting /home/kimlab1/strokach/workspace/proteinsolver/notebooks/protein_4xEdgeConv_bs4/6a84866a/model.py


In [79]:
%run {UNIQUE_PATH}/model.py

In [80]:
# Smoke test of `Stats`. The log goes to a file inside a fresh temporary
# directory: `tempfile.NamedTemporaryFile().name` came from an object that was
# discarded immediately, so its file was deleted and the path merely reused.
stats = Stats(
    epoch=0,
    step=0,
    batch_size=64,
    filename=Path(tempfile.mkdtemp()).joinpath("stats.csv"),
    echo=True,
)
# Give every ratio a denominator of one so the printed row is well defined.
stats.num_steps += 1
stats.num_preds += 1
stats.num_preds_missing += 1
stats.num_preds_missing_valid += 1

stats.write_header()
stats.write_row()

epoch             step              datapoint         avg_loss          accuracy          accuracy_m        accuracy_mv       time_elapsed      
0                 0                 0                 0.0               0.0               0.0               0.0               0.001             


In [81]:
# Parameters
device = torch.device("cuda:1")
# device = "cpu"
batch_size = 4
num_features = 20  # also used as the label that marks a masked node
adj_input_size = 2
hidden_size = 128
frac_present = 0.5  # fraction of node labels left visible during training
frac_present_valid = frac_present
info_size = 1024  # validate and log roughly every `info_size` datapoints

dataloaders = {
    "train_0": DataLoader(protein_dataset_train_0, shuffle=True, num_workers=4, batch_size=batch_size, drop_last=True),
    "train_1": DataLoader(protein_dataset_train_1, shuffle=True, num_workers=4, batch_size=batch_size, drop_last=True),
    "train_2": None,
    "valid": DataLoader(protein_dataset_valid[:128], shuffle=False, num_workers=4, batch_size=1, drop_last=False),
    "test": DataLoader(protein_dataset_test[:128], shuffle=False, num_workers=4, batch_size=1, drop_last=False),
    # The GFP loaders previously sliced `protein_dataset_test`, which looks like
    # a copy-paste slip; they now read from the GFP dataset.
    # NOTE(review): confirm `protein_dataset_gfp` holds more than 928 entries.
    "gfp_train": DataLoader(protein_dataset_gfp[:928], shuffle=True, num_workers=4, batch_size=batch_size, drop_last=True),
    "gfp_valid": DataLoader(protein_dataset_gfp[928:], shuffle=False, num_workers=4, batch_size=1, drop_last=False),
}

print(device)

cuda:1


In [82]:
@contextmanager
def eval_net(net: nn.Module):
    """Switch ``net`` to evaluation mode for the duration of the ``with`` block.

    The mode that was active on entry is restored on exit, even if the body
    raises.
    """
    previous_mode = net.training
    try:
        net.train(False)
        yield
    finally:
        net.train(previous_mode)


def get_stats_on_missing(data, output):
    """Count correct predictions on the masked nodes.

    A node counts as masked when ``data.x`` equals the global ``num_features``.
    Returns ``(num_correct, num_masked)``, where a prediction is the arg-max
    of ``output`` along dimension 1 and the target is ``data.y``.
    """
    is_missing = (data.x == num_features).squeeze()
    predictions = torch.max(output[is_missing].data, 1)[1]
    targets = data.y[is_missing]
    num_correct = (predictions == targets).sum().item()
    return num_correct, len(predictions)


def get_data_x(data, frac_present):
    """Return a copy of ``data.y`` with randomly chosen entries masked.

    Each entry is kept when a uniform random draw is below ``frac_present``
    and is replaced by the global ``num_features`` otherwise.
    """
    target_device = data.y.device
    keep = torch.rand(data.y.size(0), device=target_device) < frac_present
    # Single-element tensor that `torch.where` broadcasts over all positions.
    mask_label = torch.full((1,), num_features, dtype=torch.long, device=target_device)
    return torch.where(keep, data.y, mask_label)

In [None]:
# Set to True to resume with the `net`, `stats` and `optimizer` objects that
# already exist in the kernel instead of creating fresh ones.
continue_previous = False

if not continue_previous:
    net = Net(
        x_input_size=num_features + 1, adj_input_size=adj_input_size, hidden_size=hidden_size, output_size=num_features
    )
    net = net.to(device)
    stats = Stats(epoch=0, step=0, batch_size=batch_size, filename=NOTEBOOK_PATH.joinpath("training.log"), echo=True)
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

criterion = nn.CrossEntropyLoss()
# NOTE: the scheduler is created but never stepped (see the commented call below).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', verbose=True)

net = net.train()
stats.write_header()
for epoch in range(stats.epoch + 1 if continue_previous else 0, 100_000):
    stats.epoch = epoch
    # Alternate between the two training subsets from one epoch to the next.
    for data in dataloaders[f"train_{epoch % 2}"]:
        stats.step += 1
        optimizer.zero_grad()

        data = data.to(device)
        # The true labels become the target; the input gets a random subset masked.
        data.y = data.x
        data.x = get_data_x(data, frac_present)
        output = net(data.x, data.edge_index, data.edge_attr if hasattr(data, "edge_attr") else None)

        loss = criterion(output, data.y)
        loss.backward()

        stats.total_loss += loss.detach().item()
        stats.num_steps += 1

        # Accuracy for all
        _, predicted = torch.max(output.data, 1)
        stats.num_correct_preds += (predicted == data.y).sum().item()
        stats.num_preds += len(predicted)

        # Accuracy for missing only
        num_correct, num_total = get_stats_on_missing(data, output)
        stats.num_correct_preds_missing += num_correct
        stats.num_preds_missing += num_total

        optimizer.step()

        # True once per `info_size` datapoints (when `datapoint` crosses a multiple).
        if (stats.datapoint % info_size) < batch_size:
            # Both context managers must be entered. The former
            # `torch.no_grad() and eval_net(net)` evaluated to `eval_net(net)`
            # alone, so validation ran with gradient tracking enabled.
            with torch.no_grad(), eval_net(net):
                for valid_data in dataloaders["valid"]:
                    valid_data = valid_data.to(device)
                    valid_data.y = valid_data.x
                    valid_data.x = get_data_x(valid_data, frac_present_valid)

                    valid_output = net(
                        valid_data.x,
                        valid_data.edge_index,
                        valid_data.edge_attr if hasattr(valid_data, "edge_attr") else None,
                    )

                    num_correct, num_total = get_stats_on_missing(valid_data, valid_output)
                    stats.num_correct_preds_missing_valid += num_correct
                    stats.num_preds_missing_valid += num_total

#             scheduler.step(stats.stats['accuracy'])
            stats.write_row()
            stats.reset_parameters()
    # Save the weights at the end of every epoch.
    output_filename = f"e{stats.epoch}-s{stats.step}-d{stats.datapoint}.state"
    torch.save(net.state_dict(), NOTEBOOK_PATH.joinpath(output_filename))

epoch             step              datapoint         avg_loss          accuracy          accuracy_m        accuracy_mv       time_elapsed      
0                 256               1024              1.7162            0.4886            0.0742            0.0814            10.6911           
0                 512               2048              1.5226            0.5345            0.0791            0.084             10.9926           
0                 768               3072              1.5012            0.5399            0.084             0.0911            10.9548           
0                 1024              4096              1.4931            0.5403            0.0828            0.0899            10.4932           
0                 1280              5120              1.4844            0.5447            0.0883            0.0933            10.8028           
0                 1536              6144              1.465             0.5474            0.0897            0.0844            10.8