## Summary

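- 5-layer graph-convolution + attention network.
- Batch size = 1 from the model's point of view: `Net` assigns every node to graph 0, so each DataLoader mini-batch of proteins is processed as a single graph.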

----

## Install dependencies (Google Colab only)

In [1]:
# Detect whether we are running inside Google Colab (the `google.colab`
# package is only importable there).
try:
    import google.colab  # noqa: F401
except ImportError:
    GOOGLE_COLAB = False
else:
    GOOGLE_COLAB = True

In [2]:
if GOOGLE_COLAB:
    # Install PyTorch Geometric and its extension packages (Colab only).
    # NOTE(review): versions are unpinned; a fresh runtime may pull releases
    # that no longer provide names imported below (e.g.
    # `torch_geometric.utils.scatter_`) — consider pinning versions.
    !pip install --upgrade torch-scatter
    !pip install --upgrade torch-sparse
    !pip install --upgrade torch-cluster
    !pip install --upgrade torch-spline-conv
    !pip install torch-geometric

In [3]:
if GOOGLE_COLAB:
    # Install proteinsolver from its Git repository. No ref is given, so this
    # installs whatever the repository's default branch currently contains.
    !pip install git+https://gitlab.com/ostrokach/proteinsolver.git

## Imports

In [4]:
import atexit
import csv
import os
import tempfile
import time
import uuid
import warnings
from collections import deque
from contextlib import contextmanager
from pathlib import Path

import numpy as np

import matplotlib.pyplot as plt
import pandas as pd
import pyarrow
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.tensorboard
from torch import optim
from torch_geometric.data import DataLoader
from torch_geometric.nn import ChebConv, EdgeConv, GATConv, GCNConv
from torch_geometric.nn.inits import reset
from torch_geometric.utils import add_self_loops, remove_self_loops, scatter_

In [5]:
import proteinsolver
import proteinsolver.datasets



In [6]:
# %load_ext autoreload
# %autoreload 2

In [7]:
# The cells below place the model and every batch on device "cuda:0", so fail
# early with an explicit message when no GPU is visible.
assert torch.cuda.is_available(), "A CUDA-capable GPU is required: this notebook runs on device cuda:0."

# Properties

In [8]:
# First CUDA device; used for the model and for every batch.
device = torch.device("cuda", 0)

In [9]:
# Root directory under which the processed datasets are stored / downloaded.
# The default path is machine-specific; set the DATA_ROOT environment variable
# to use another location (e.g. `tempfile.gettempdir()` for a scratch dir).
# (The former `Path(tempfile.gettempdir())` assignment was immediately
# overwritten and has been removed.)
DATA_ROOT = Path(os.environ.get("DATA_ROOT", "/home/strokach/ml_data"))
DATA_ROOT.mkdir(exist_ok=True)
DATA_ROOT

PosixPath('/home/strokach/ml_data')

In [10]:
# Create the output directories only once per kernel: re-running this cell
# keeps the existing NOTEBOOK_PATH / UNIQUE_PATH instead of making a new
# random run directory.
if "NOTEBOOK_PATH" not in globals() or "UNIQUE_PATH" not in globals():
    NOTEBOOK_PATH = Path("protein_train").resolve()
    NOTEBOOK_PATH.mkdir(exist_ok=True)
    unique_id = uuid.uuid4().hex[:8]
    UNIQUE_PATH = NOTEBOOK_PATH / unique_id
    UNIQUE_PATH.mkdir()
NOTEBOOK_PATH, UNIQUE_PATH

(PosixPath('/home/kimlab1/strokach/workspace/proteinsolver/notebooks/protein_train'),
 PosixPath('/home/kimlab1/strokach/workspace/proteinsolver/notebooks/protein_train/0007604c'))

In [11]:
# Local datapkg output directory in the user's home (symlinks resolved).
DATAPKG_OUTPUT_DIR = Path.home().joinpath("datapkg_output_dir").resolve()
DATAPKG_OUTPUT_DIR

PosixPath('/home/kimlab1/database_data/datapkg_output_dir')

In [12]:
# Point proteinsolver's `data_url` setting at the local datapkg output
# directory before the dataset cells below run.
# NOTE(review): presumably read by the proteinsolver dataset classes — confirm.
proteinsolver.settings.data_url = DATAPKG_OUTPUT_DIR.as_posix()
proteinsolver.settings.data_url

'/home/kimlab1/database_data/datapkg_output_dir'

# Datasets

In [13]:
# Registry of all datasets used below, keyed by name.
datasets = dict()

In [14]:
# Training data is split into ten shards, protein_train_0 ... protein_train_9;
# each shard gets its own root directory under DATA_ROOT.
for shard_idx in range(10):
    shard_name = f"protein_train_{shard_idx}"
    datasets[shard_name] = proteinsolver.datasets.ProteinDataset2(
        root=DATA_ROOT / shard_name,
        subset=shard_name,
    )

In [15]:
# Validation set (held in memory).
datasets["protein_valid"] = proteinsolver.datasets.ProteinInMemoryDataset(
    root=DATA_ROOT.joinpath("protein_valid"),
    subset="valid",
)

In [16]:
# Test set (held in memory).
datasets["protein_test"] = proteinsolver.datasets.ProteinInMemoryDataset(
    root=DATA_ROOT.joinpath("protein_test"),
    subset="test",
)

In [17]:
# Single parquet partition (database_id=G3DSA%3A2.40.155.10) used as the
# "gfp" dataset. Built from DATAPKG_OUTPUT_DIR instead of a hard-coded
# absolute path so it follows the configured output directory; the "%3A"
# (URL-encoded ":") is kept verbatim, exactly as in the original path.
file = DATAPKG_OUTPUT_DIR.joinpath(
    "adjacency-net-v2",
    "master",
    "validation_dataset_wdistances",
    "adjacency_matrix.parquet",
    "database_id=G3DSA%3A2.40.155.10",
    "part-00000-d5e89475-69dd-45c6-9a80-5bf751fce422-c000.snappy.parquet",
).as_posix()
datasets["protein_gfp"] = proteinsolver.datasets.ProteinInMemoryDataset(
    root=DATA_ROOT / "protein_gfp", subset="gfp", data_url=file
)

# Models

In [18]:
%%file {UNIQUE_PATH}/model.py
import atexit
import copy
import csv
import tempfile
import time
import warnings

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.container import ModuleList
from torch_geometric.nn.inits import reset
from torch_geometric.utils import add_self_loops, remove_self_loops, scatter_, to_dense_adj, to_dense_batch


class EdgeConvMod(torch.nn.Module):
    """Edge convolution returning both aggregated node features and edge messages.

    For every edge ``k`` with endpoints ``row[k]`` / ``col[k]`` the vector
    ``[x[row[k]], x[col[k]]]`` (extended with ``edge_attr[k]`` when given) is
    passed through ``nn``. The resulting per-edge messages are combined onto
    their ``row`` node with ``scatter_(aggr, ...)``.

    Args:
        nn: Module applied to the concatenated per-edge features.
        aggr: Aggregation name forwarded to ``scatter_`` (e.g. ``"max"``, ``"add"``).
    """

    def __init__(self, nn, aggr="max"):
        super().__init__()
        self.nn = nn
        self.aggr = aggr
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the wrapped ``nn`` via torch_geometric's ``reset`` helper."""
        reset(self.nn)

    def forward(self, x, edge_index, edge_attr=None):
        """Return ``(node_features, edge_messages)``.

        A 1-D ``x`` is treated as a single feature per node.
        """
        source, target = edge_index
        if x.dim() == 1:
            x = x.unsqueeze(-1)

        # TODO: Try -x[col] instead of x[col] - x[row]
        pieces = [x[source], x[target]]
        if edge_attr is not None:
            pieces.append(edge_attr)
        messages = self.nn(torch.cat(pieces, dim=-1))

        node_features = scatter_(self.aggr, messages, source, dim_size=x.size(0))
        return node_features, messages

    def __repr__(self):
        return f"{self.__class__.__name__}(nn={self.nn})"


class EdgeConvBatch(nn.Module):
    """Wrap a graph layer returning ``(x, edge_attr)`` with output post-processing.

    Both outputs of ``gnn`` pass through an optional ``LayerNorm`` (enabled by
    ``batch_norm`` — despite the name, this is a layer norm, not a batch norm)
    followed by optional dropout.

    Args:
        gnn: Module called as ``gnn(x, edge_index, edge_attr)`` that returns
            ``(x, edge_attr)``.
        hidden_size: Feature size of both outputs (used for ``LayerNorm``).
        batch_norm: Truthy to add ``LayerNorm``; ``False``/``None`` disable it.
        dropout: Dropout probability; ``0``/falsy disables dropout.
    """

    def __init__(self, gnn, hidden_size, batch_norm=True, dropout=0.2):
        super().__init__()

        self.gnn = gnn

        x_post_modules = []
        edge_attr_post_modules = []

        # Truthiness check: the former `batch_norm is not None` test added the
        # normalization even when the caller passed batch_norm=False.
        if batch_norm:
            x_post_modules.append(nn.LayerNorm(hidden_size))
            edge_attr_post_modules.append(nn.LayerNorm(hidden_size))

        if dropout:
            x_post_modules.append(nn.Dropout(dropout))
            edge_attr_post_modules.append(nn.Dropout(dropout))

        self.x_postprocess = nn.Sequential(*x_post_modules)
        self.edge_attr_postprocess = nn.Sequential(*edge_attr_post_modules)

    def forward(self, x, edge_index, edge_attr=None):
        """Run ``gnn`` and post-process both of its outputs."""
        x, edge_attr = self.gnn(x, edge_index, edge_attr)
        x = self.x_postprocess(x)
        # Guard: LayerNorm cannot be applied to a missing edge output.
        if edge_attr is not None:
            edge_attr = self.edge_attr_postprocess(edge_attr)
        return x, edge_attr


def get_graph_conv_layer(input_size, hidden_size, output_size):
    """Build a sum-aggregating ``EdgeConvMod`` wrapped in ``EdgeConvBatch``.

    The edge MLP maps ``input_size -> hidden_size -> output_size`` (ReLU in
    between). Its outputs are post-processed with LayerNorm and dropout 0.2.
    """
    edge_mlp = nn.Sequential(
        nn.Linear(input_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, output_size),
    )
    return EdgeConvBatch(
        EdgeConvMod(nn=edge_mlp, aggr="add"),
        output_size,
        batch_norm=True,
        dropout=0.2,
    )


class MyEdgeConv(torch.nn.Module):
    """Edge-update layer: ``edge_out = MLP([x[row], x[col], edge_attr])``.

    The MLP maps ``3 * hidden_size -> 2 * hidden_size -> hidden_size``. So in
    practice ``x`` and ``edge_attr`` must both have ``hidden_size`` features.
    NOTE(review): the ``edge_attr is None`` branch produces only
    ``2 * hidden_size`` inputs and would not fit the first linear layer.
    """

    def __init__(self, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(3 * hidden_size, 2 * hidden_size),
            nn.ReLU(),
            nn.Linear(2 * hidden_size, hidden_size),
        ]
        self.nn = nn.Sequential(*layers)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the MLP via torch_geometric's ``reset`` helper."""
        reset(self.nn)

    def forward(self, x, edge_index, edge_attr=None):
        """Return the new per-edge embeddings (one row per column of ``edge_index``)."""
        source, target = edge_index
        if x.dim() == 1:
            x = x.unsqueeze(-1)

        # TODO: Try -x[col] instead of x[col] - x[row]
        features = [x[source], x[target]]
        if edge_attr is not None:
            features.append(edge_attr)
        return self.nn(torch.cat(features, dim=-1))

    def __repr__(self):
        return f"{self.__class__.__name__}(nn={self.nn})"


class MyAttn(torch.nn.Module):
    """Node update by multi-head attention over the node's incident edge embeddings.

    Each node's embedding is a length-1 query sequence (nodes act as the
    ``nn.MultiheadAttention`` batch dimension). Keys/values are the rows of the
    dense ``(N, N, E)`` edge-embedding tensor built with ``to_dense_adj``.
    Positions with no edge are masked out, except the diagonal, which is always
    left unmasked. Presumably this guarantees every query has at least one
    unmasked key, avoiding NaN outputs.

    NOTE(review): the keys for node ``b`` are ``dense[:, b]`` (edges ``s -> b``)
    while the padding mask uses ``adjacency[b, :]`` (edges ``b -> s``). These
    agree only for symmetric graphs — confirm the inputs are symmetric.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the attention module via torch_geometric's ``reset`` helper."""
        reset(self.attn)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return attention outputs with the same shape as ``x``.

        ``batch`` must assign all nodes to a single graph: ``squeeze(0)``
        drops the graph dimension of the dense tensors.
        """
        query = x.unsqueeze(0)
        key = to_dense_adj(edge_index, batch=batch, edge_attr=edge_attr).squeeze(0)

        adjacency = to_dense_adj(edge_index, batch=batch).squeeze(0)
        key_padding_mask = adjacency == 0
        # Unmask the diagonal in place (no per-call CPU `eye` allocation and
        # no cross-device boolean indexing).
        key_padding_mask.diagonal().fill_(False)

        x_out, _ = self.attn(query, key, key, key_padding_mask=key_padding_mask)
        x_out = x_out.squeeze(0)
        assert not torch.isnan(x_out).any().item()
        assert x.shape == x_out.shape, (x.shape, x_out.shape)
        return x_out

    def __repr__(self):
        # The former version referenced the nonexistent `self.nn`, so printing
        # this module (or any model containing it) raised AttributeError.
        return "{}(attn={})".format(self.__class__.__name__, self.attn)


class Net(nn.Module):
    """Stack of edge-update + node-attention blocks producing per-node logits.

    Each of the ``num_layers`` blocks:
      1. computes new edge embeddings with ``MyEdgeConv`` and adds them to the
         running edge embeddings (dropout + residual + LayerNorm);
      2. lets every node attend over the LayerNorm-ed new edge embeddings with
         ``MyAttn`` and adds the result to the node embeddings (dropout +
         residual + LayerNorm).
    ``linear_out`` maps the final node embeddings to ``output_size`` logits.

    Args:
        x_input_size: Vocabulary size of the integer node tokens (``nn.Embedding``).
        adj_input_size: Number of raw edge features. If falsy, ``embed_adj`` is
            ``None`` and passing ``edge_attr`` to ``forward`` fails.
            NOTE(review): ``forward`` effectively requires ``edge_attr`` anyway
            (``MyEdgeConv`` expects ``3 * hidden_size`` inputs).
        hidden_size: Width of node and edge embeddings.
        output_size: Number of output logits per node.
        num_layers: Number of blocks (default 5, the original architecture).
        num_heads: Attention heads per ``MyAttn``; must divide ``hidden_size``.
        dropout: Dropout probability on both residual branches. It is applied
            only in training mode.
    """

    def __init__(
        self, x_input_size, adj_input_size, hidden_size, output_size, num_layers=5, num_heads=8, dropout=0.1
    ):
        super().__init__()

        self.embed_x = nn.Sequential(
            nn.Embedding(x_input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
        )

        if adj_input_size:
            self.embed_adj = nn.Sequential(
                nn.Linear(adj_input_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.LayerNorm(hidden_size),
            )
        else:
            self.embed_adj = None

        self.dropout = dropout

        norm = nn.LayerNorm(hidden_size)
        # x_norms_0 is not used in forward(); it is kept so that state dicts
        # saved by earlier runs still load with strict=True.
        self.x_norms_0 = _get_clones(norm, num_layers)
        self.adj_norms_0 = _get_clones(norm, num_layers)
        self.x_norms_1 = _get_clones(norm, num_layers)
        self.adj_norms_1 = _get_clones(norm, num_layers)

        self.edge_convs = _get_clones(MyEdgeConv(hidden_size), num_layers)
        self.attns = _get_clones(MyAttn(hidden_size, num_heads), num_layers)

        self.linear_out = nn.Linear(hidden_size, output_size)

        # All-zeros graph-assignment vector: every node belongs to graph 0, so
        # each (possibly multi-protein) mini-batch is one graph for MyAttn.
        self.register_buffer("batch", torch.zeros(10000, dtype=torch.int64))

    def forward(self, x, edge_index, edge_attr):
        """Return ``(num_nodes, output_size)`` logits for integer node tokens ``x``."""
        x = self.embed_x(x)
        # Self-loops are intentionally not added here.
        edge_attr = self.embed_adj(edge_attr) if edge_attr is not None else None

        num_nodes = x.size(0)
        if num_nodes <= self.batch.numel():
            batch = self.batch[:num_nodes]
        else:
            # Slicing the preallocated buffer would silently return fewer
            # entries than there are nodes.
            batch = x.new_zeros(num_nodes, dtype=torch.long)

        for i in range(len(self.edge_convs)):
            # Edge update. `training=self.training` disables dropout in eval
            # mode (F.dropout defaults to training=True).
            edge_attr_out = self.edge_convs[i](x, edge_index, edge_attr)
            edge_attr = edge_attr + F.dropout(edge_attr_out, self.dropout, training=self.training)
            edge_attr = self.adj_norms_1[i](edge_attr)

            # Node update from the normalized freshly computed edge embeddings.
            x_out = self.attns[i](x, edge_index, self.adj_norms_0[i](edge_attr_out), batch)
            x = x + F.dropout(x_out, self.dropout, training=self.training)
            x = self.x_norms_1[i](x)

        return self.linear_out(x)


def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def to_fixed_width(lst, precision=None):
    """Format each item of ``lst`` left-aligned in an 18-character column.

    Floats are rounded to ``precision`` decimal places first. With
    ``precision=None`` they are left as-is (``round(x, None)`` would otherwise
    round them to ints).

    Returns:
        A list of strings, one per input item.
    """
    formatted = []
    for item in lst:
        if precision is not None and isinstance(item, float):
            item = round(item, precision)
        formatted.append(f"{item: <18}")
    return formatted


class Stats:
    """Running training/validation metrics with stdout, CSV and TensorBoard output.

    ``epoch`` and ``step`` are advanced by the caller. The per-interval
    accumulators (loss, prediction counts, timer) are cleared by
    :meth:`reset_parameters`.

    Args:
        epoch: Initial epoch counter.
        step: Initial step counter.
        batch_size: Used to convert ``step`` into a ``datapoint`` count.
        filename: Optional CSV log path; its header is written immediately.
        echo: Print header/rows as fixed-width text.
        tb_writer: Optional TensorBoard writer (``add_scalar`` / ``flush``).
    """

    epoch: int
    step: int
    batch_size: int
    echo: bool
    num_steps: int
    total_loss: float
    num_correct_preds: int
    num_preds: int
    num_correct_preds_missing: int
    num_preds_missing: int
    num_correct_preds_missing_valid: int
    num_preds_missing_valid: int
    start_time: float

    def __init__(self, *, epoch=0, step=0, batch_size=1, filename=None, echo=True, tb_writer=None):
        self.epoch = epoch
        self.step = step
        self.batch_size = batch_size
        self.echo = echo
        self.tb_writer = tb_writer
        self.reset_parameters()

        # Tracks whether the CSV header exists, so write_header() does not
        # insert a second header line into the log file.
        self._csv_header_written = False
        if filename:
            self.filehandle = open(filename, "wt", newline="")
            self.writer = csv.DictWriter(self.filehandle, list(self.stats.keys()), dialect="unix")
            self.writer.writeheader()
            self._csv_header_written = True
            atexit.register(self.filehandle.close)
        else:
            self.filehandle = None
            self.writer = None

    def reset_parameters(self):
        """Clear the per-interval accumulators and restart the timer."""
        self.num_steps = 0
        self.total_loss = 0
        self.num_correct_preds = 0
        self.num_preds = 0
        self.num_correct_preds_missing = 0
        self.num_preds_missing = 0
        self.num_correct_preds_missing_valid = 0
        self.num_preds_missing_valid = 0
        self.start_time = time.perf_counter()

    @property
    def header(self):
        """Fixed-width text header matching :attr:`row`."""
        return "".join(to_fixed_width(self.stats.keys()))

    @property
    def row(self):
        """Fixed-width text row of the current metrics (floats rounded to 4 places)."""
        return "".join(to_fixed_width(self.stats.values(), 4))

    @property
    def stats(self):
        """Current metrics as a dict (insertion-ordered).

        Ratios are computed with ``np.float64`` so an empty interval yields
        ``nan`` instead of raising ``ZeroDivisionError``; numpy's warning is
        suppressed.
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return {
                "epoch": self.epoch,
                "step": self.step,
                "datapoint": self.datapoint,
                "avg_loss": np.float64(1) * self.total_loss / self.num_steps,
                "accuracy": np.float64(1) * self.num_correct_preds / self.num_preds,
                "accuracy_m": np.float64(1) * self.num_correct_preds_missing / self.num_preds_missing,
                "accuracy_mv": self.accuracy_mv,
                "time_elapsed": time.perf_counter() - self.start_time,
            }

    @property
    def accuracy_mv(self):
        """Accuracy on masked positions of the validation set."""
        return np.float64(1) * self.num_correct_preds_missing_valid / self.num_preds_missing_valid

    @property
    def datapoint(self):
        """Number of training examples seen (``step * batch_size``)."""
        return self.step * self.batch_size

    def write_header(self):
        """Echo the text header; write the CSV header only if not already written."""
        if self.echo:
            print(self.header)
        if self.writer is not None and not self._csv_header_written:
            self.writer.writeheader()
            self._csv_header_written = True

    def write_row(self):
        """Report the current metrics to stdout, CSV and TensorBoard."""
        # Snapshot once so every sink reports the same values (notably
        # `time_elapsed`, which changes on each `self.stats` access).
        stats = self.stats
        if self.echo:
            print("".join(to_fixed_width(stats.values(), 4)))
        if self.writer is not None:
            self.writer.writerow(stats)
            # The handle is only closed at interpreter exit; flush so the log
            # survives a crashed or interrupted kernel.
            self.filehandle.flush()
        if self.tb_writer is not None:
            scalars = dict(stats)
            datapoint = scalars.pop("datapoint")
            for key, value in scalars.items():
                self.tb_writer.add_scalar(key, value, datapoint)
            self.tb_writer.flush()

Writing /home/kimlab1/strokach/workspace/proteinsolver/notebooks/protein_train/0007604c/model.py


In [19]:
# Execute the generated model.py; IPython's %run copies its top-level names
# (Net, Stats, ...) into the notebook namespace for the cells below.
%run {UNIQUE_PATH}/model.py

In [20]:
# Parameters
num_features = 20  # number of residue classes (presumably the 20 amino acids); token `num_features` marks a masked residue
adj_input_size = 2  # number of raw edge features fed to Net.embed_adj
hidden_size = 128  # width of node and edge embeddings
batch_size = 4  # proteins per training DataLoader batch
frac_present = 0.5  # probability that a residue stays unmasked during training
frac_present_valid = frac_present  # same masking rate during validation
info_size = 1024  # run validation and log a stats row every `info_size` datapoints

In [21]:
@contextmanager
def eval_net(net: nn.Module):
    """Temporarily put ``net`` in evaluation mode.

    The previous training/eval mode is restored on exit, even if the body raises.
    """
    previous_mode = net.training
    try:
        net.eval()
        yield
    finally:
        net.train(mode=previous_mode)


def get_stats_on_missing(data, output, mask_token=None):
    """Count correct predictions at masked ("missing") positions.

    Args:
        data: Graph batch with ``x`` (input tokens, masked positions equal to
            ``mask_token``) and ``y`` (true tokens).
        output: Per-node logits, ``(num_nodes, num_classes)``.
        mask_token: Token marking a masked position. Defaults to the notebook
            global ``num_features``.

    Returns:
        ``(num_correct, num_masked)`` as Python ints.
    """
    if mask_token is None:
        mask_token = num_features
    # reshape(-1) instead of squeeze(): squeeze() turns a single-node mask into
    # a 0-d tensor, which then indexes `output` incorrectly.
    mask = (data.x == mask_token).reshape(-1)
    predicted_missing = output.detach()[mask].argmax(dim=1)
    return (predicted_missing == data.y[mask]).sum().item(), predicted_missing.numel()


def get_data_x(data, frac_present, mask_token=None):
    """Randomly mask the true tokens ``data.y``.

    Each position keeps its true token with probability ``frac_present`` and is
    otherwise replaced by ``mask_token``.

    Args:
        data: Object with a 1-D token tensor ``y``.
        frac_present: Probability of keeping a token.
        mask_token: Replacement token. Defaults to the notebook global
            ``num_features`` (the extra embedding index of ``Net``).

    Returns:
        Tensor like ``data.y`` with masked positions set to ``mask_token``.
    """
    if mask_token is None:
        mask_token = num_features
    keep = torch.rand(data.y.size(0), device=data.y.device) < frac_present
    return torch.where(keep, data.y, torch.full_like(data.y, mask_token))

In [22]:
# TensorBoard logs go to <NOTEBOOK_PATH>/runs/<run id>.
tensorboard_path = NOTEBOOK_PATH / "runs" / UNIQUE_PATH.name

In [None]:
# Set to True to resume with the `net`, `stats` and `optimizer` objects that
# already exist in this kernel (nothing is reloaded from disk).
continue_previous = False

# Mini-batches with more than `max_size` nodes in total are cut to their first
# `max_size` nodes; edges touching removed nodes are dropped.
# NOTE(review): `data.batch` is not truncated — fine while Net ignores it.
max_size = 700

if not continue_previous:
    net = Net(
        x_input_size=num_features + 1, adj_input_size=adj_input_size, hidden_size=hidden_size, output_size=num_features
    )
    net = net.to(device)
    stats = Stats(
        epoch=0,
        step=0,
        batch_size=batch_size,
        filename=UNIQUE_PATH.joinpath("training.log"),
        echo=True,
        tb_writer=torch.utils.tensorboard.writer.SummaryWriter(log_dir=tensorboard_path.with_suffix(".xxx")),
    )
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

criterion = nn.CrossEntropyLoss()
# NOTE: the scheduler is created but never stepped (see the commented call below).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", verbose=True)

net = net.train()
stats.write_header()
for epoch in range(stats.epoch + 1 if continue_previous else 0, 100_000):
    stats.epoch = epoch
    # One training shard (protein_train_0 ... protein_train_9) per epoch.
    train_loader = DataLoader(
        datasets[f"protein_train_{epoch % 10}"], shuffle=False, num_workers=1, batch_size=batch_size, drop_last=True
    )
    for data in train_loader:
        stats.step += 1
        optimizer.zero_grad()

        if data.x.size(0) > max_size:
            data.x = data.x[:max_size]
            keep_edges = (data.edge_index < max_size).all(dim=0)
            data.edge_index = data.edge_index[:, keep_edges]
            data.edge_attr = data.edge_attr[keep_edges, :]

        data = data.to(device)
        data.y = data.x
        data.x = get_data_x(data, frac_present)
        output = net(data.x, data.edge_index, data.edge_attr)
        loss = criterion(output, data.y)
        loss.backward()

        stats.total_loss += loss.detach().item()
        stats.num_steps += 1

        # Accuracy over all positions (masked and unmasked).
        _, predicted = torch.max(output.data, 1)
        stats.num_correct_preds += (predicted == data.y).sum().item()
        stats.num_preds += len(predicted)

        # Accuracy over masked positions only.
        num_correct, num_total = get_stats_on_missing(data, output)
        stats.num_correct_preds_missing += num_correct
        stats.num_preds_missing += num_total

        optimizer.step()

        # Every `info_size` datapoints: score the first 128 validation proteins
        # and log one stats row.
        if (stats.datapoint % info_size) < batch_size:
            valid_loader = DataLoader(
                datasets["protein_valid"][:128], shuffle=False, num_workers=1, batch_size=1, drop_last=True
            )
            for valid_data in valid_loader:
                valid_data = valid_data.to(device)
                valid_data.y = valid_data.x
                valid_data.x = get_data_x(valid_data, frac_present_valid)

                # Two context managers separated by a comma; the former
                # `torch.no_grad() and eval_net(net)` evaluated to eval_net(net)
                # alone, so gradients were tracked during validation.
                with torch.no_grad(), eval_net(net):
                    valid_output = net(valid_data.x, valid_data.edge_index, valid_data.edge_attr)

                num_correct, num_total = get_stats_on_missing(valid_data, valid_output)
                stats.num_correct_preds_missing_valid += num_correct
                stats.num_preds_missing_valid += num_total

            # scheduler.step(stats.stats['accuracy'])
            stats.write_row()
            stats.reset_parameters()
    output_filename = f"e{stats.epoch}-s{stats.step}-d{stats.datapoint}.state"
    torch.save(net.state_dict(), UNIQUE_PATH.joinpath(output_filename))

epoch             step              datapoint         avg_loss          accuracy          accuracy_m        accuracy_mv       time_elapsed      
0                 256               1024              1.7852            0.4911            0.0896            0.0907            46.6933           
0                 512               2048              1.4817            0.5475            0.0958            0.0856            45.9344           
0                 768               3072              1.4576            0.5498            0.0983            0.0902            47.8757           
0                 1024              4096              1.4481            0.5511            0.1014            0.0959            45.3028           
0                 1280              5120              1.443             0.5505            0.0974            0.0996            46.2231           
0                 1536              6144              1.4475            0.5491            0.1011            0.0987            45.2