# Installation

### Install `pytorch_geometric`

In [None]:
# Install (or upgrade to the newest available release of) each torch-* package below, one per cell.
!pip install --upgrade torch-scatter

In [None]:
!pip install --upgrade torch-sparse

In [None]:
!pip install --upgrade torch-cluster

In [None]:
!pip install --upgrade torch-spline-conv

In [None]:
# torch-geometric is installed last and without --upgrade, so an already-installed copy is kept.
!pip install torch-geometric

### Install `proteinsolver`

In [None]:
# Install proteinsolver directly from its GitLab repository; no branch, tag or commit is pinned.
!pip install git+https://gitlab.com/ostrokach/proteinsolver.git

# Imports

In [None]:
# Show the NVIDIA GPU status for this runtime (cell output only).
!nvidia-smi

In [None]:
import atexit
import csv
import os
import tempfile
import time
import warnings
from collections import deque
from contextlib import contextmanager
from pathlib import Path
from typing import NamedTuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow
import pyarrow.parquet as pq
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import stats
from torch import optim
from torch.utils.data import DataLoader
from torch_geometric.data import DataLoader
from torch_geometric.nn import ChebConv, EdgeConv, GATConv, GCNConv
from torch_geometric.nn.inits import reset
from torch_geometric.utils import add_self_loops, remove_self_loops, scatter_

# from google.colab import files

In [None]:
# Display whether PyTorch can use CUDA in this runtime (cell output only).
torch.cuda.is_available()

In [None]:
import kmbio
from kmbio import PDB
from kmtools import structure_tools

import proteinsolver
import proteinsolver.datasets

In [None]:
# Load IPython's autoreload extension; mode 2 re-imports modules before each cell is executed.
%load_ext autoreload
%autoreload 2

# Properties

In [None]:
# Name of this notebook; also used as the name of its working directory.
NOTEBOOK_NAME = "generate_protein_sequences"

In [None]:
# Absolute working directory, resolved against the current working directory and created if missing.
NOTEBOOK_PATH = Path(NOTEBOOK_NAME).resolve()
NOTEBOOK_PATH.mkdir(exist_ok=True)
NOTEBOOK_PATH

In [None]:
# Structure to design sequences for: the STRUCTURE_FILE environment variable takes precedence,
# otherwise the file under NOTEBOOK_PATH / "inputs" is used. The commented-out lines are
# alternative default inputs.
STRUCTURE_FILE = Path(os.getenv("STRUCTURE_FILE", NOTEBOOK_PATH / "inputs" / "1n5uA03.pdb")).resolve()
# STRUCTURE_FILE = Path(os.getenv("STRUCTURE_FILE", NOTEBOOK_PATH / "inputs" / "4z8jA00.pdb")).resolve()
# STRUCTURE_FILE = Path(os.getenv("STRUCTURE_FILE", NOTEBOOK_PATH / "inputs" / "4unuA00.pdb")).resolve()
# STRUCTURE_FILE = Path(os.getenv("STRUCTURE_FILE", NOTEBOOK_PATH / "inputs" / "4beuA02.pdb")).resolve()
STRUCTURE_FILE

In [None]:
# Load the input file and build a new structure from what `extract('A')` returns for
# model 0 (presumably chain "A" only -- confirm against kmbio).
structure_all = PDB.load(STRUCTURE_FILE)
structure = PDB.Structure(STRUCTURE_FILE.name + "A", structure_all[0].extract('A'))
# Sanity check: the extracted structure must contain exactly one chain.
assert len(list(structure.chains)) == 1

In [None]:
# Render the extracted structure (cell output only).
PDB.view_structure(structure)

# Models

In [None]:
class EdgeConvMod(torch.nn.Module):
    """Edge convolution that returns node-level and edge-level outputs.

    For every edge, the features of its two endpoints (plus the edge's own
    attributes, when given) are concatenated and passed through ``nn``. The
    per-edge results are then reduced with ``aggr``, grouped by the first row
    of ``edge_index``.
    """

    def __init__(self, nn, aggr="max"):
        super().__init__()
        self.nn = nn
        self.aggr = aggr
        self.reset_parameters()

    def reset_parameters(self):
        """Hand the wrapped network to torch_geometric's ``reset`` helper."""
        reset(self.nn)

    def forward(self, x, edge_index, edge_attr=None):
        """Return ``(aggregated node output, per-edge output)``."""
        row, col = edge_index
        if x.dim() == 1:
            x = x.unsqueeze(-1)

        # TODO (kept from the original): try -x[col] instead of x[col] - x[row].
        # NOTE(review): that TODO looks stale -- the code concatenates x[row]
        # and x[col] and never takes a difference.
        pieces = [x[row], x[col]]
        if edge_attr is not None:
            pieces.append(edge_attr)
        edge_out = self.nn(torch.cat(pieces, dim=-1))

        node_out = scatter_(self.aggr, edge_out, row, dim_size=x.size(0))
        return node_out, edge_out

    def __repr__(self):
        return f"{self.__class__.__name__}(nn={self.nn})"

In [None]:
class EdgeConvBatch(nn.Module):
    """Wrap an edge-convolution module with optional normalisation and dropout.

    ``gnn`` must be callable as ``gnn(x, edge_index, edge_attr)`` and return a
    ``(node_output, edge_output)`` pair. Both outputs are post-processed by
    their own ``nn.Sequential``.

    Args:
        gnn: The wrapped graph module.
        hidden_size: Size of the last dimension of both ``gnn`` outputs.
        batch_norm: When truthy, normalise both outputs. Despite the name,
            ``nn.LayerNorm`` is used, not batch normalisation.
        dropout: Dropout probability; a falsy value disables dropout.
    """

    def __init__(self, gnn, hidden_size, batch_norm=True, dropout=0.2):
        super().__init__()

        self.gnn = gnn

        x_post_modules = []
        edge_attr_post_modules = []

        # Truthiness test: the previous `is not None` check also added the
        # normalisation layers when the caller passed `batch_norm=False`.
        if batch_norm:
            x_post_modules.append(nn.LayerNorm(hidden_size))
            edge_attr_post_modules.append(nn.LayerNorm(hidden_size))

        if dropout:
            x_post_modules.append(nn.Dropout(dropout))
            edge_attr_post_modules.append(nn.Dropout(dropout))

        # An empty nn.Sequential is the identity, so disabling both options
        # leaves the outputs of `gnn` untouched.
        self.x_postprocess = nn.Sequential(*x_post_modules)
        self.edge_attr_postprocess = nn.Sequential(*edge_attr_post_modules)

    def forward(self, x, edge_index, edge_attr=None):
        """Run ``gnn`` and post-process its node and edge outputs."""
        x, edge_attr = self.gnn(x, edge_index, edge_attr)
        x = self.x_postprocess(x)
        edge_attr = self.edge_attr_postprocess(edge_attr)
        return x, edge_attr

In [None]:
def get_graph_conv_layer(input_size, hidden_size, output_size):
    """Build one graph-convolution stage.

    The stage is a two-layer MLP (``input_size -> hidden_size -> output_size``
    with a ReLU in between) used as the edge network of an ``EdgeConvMod``
    with "add" aggregation, wrapped in an ``EdgeConvBatch`` that applies
    normalisation and dropout of 0.2 to outputs of size ``output_size``.
    """
    edge_mlp = nn.Sequential(
        nn.Linear(input_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, output_size),
    )
    return EdgeConvBatch(
        EdgeConvMod(nn=edge_mlp, aggr="add"),
        output_size,
        batch_norm=True,
        dropout=0.2,
    )

In [None]:
class Net(nn.Module):
    """Four-stage residual edge-convolution network over a graph.

    Node tokens are embedded, edge attributes (if any) are projected to the
    same width, and four graph-convolution stages update node and edge
    representations with residual connections. A final linear layer maps each
    node representation to ``output_size`` scores.

    Args:
        x_input_size: Number of distinct node tokens (embedding vocabulary).
        adj_input_size: Number of features per edge; a falsy value builds a
            network without an edge-attribute embedding.
        hidden_size: Width of node and edge representations.
        output_size: Number of output scores per node.
    """

    def __init__(self, x_input_size, adj_input_size, hidden_size, output_size):
        super().__init__()

        self.embed_x = nn.Sequential(
            nn.Embedding(x_input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
        )

        if adj_input_size:
            self.embed_adj = nn.Sequential(
                nn.Linear(adj_input_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.LayerNorm(hidden_size),
            )
        else:
            self.embed_adj = None

        # The first stage sees [x_row, x_col] plus the embedded edge attributes
        # only when the network has an edge embedding; later stages always
        # receive the edge representation produced by the previous stage.
        # Attribute names are part of the checkpoint's state_dict keys.
        self.graph_conv_1 = get_graph_conv_layer((2 + bool(adj_input_size)) * hidden_size, 2 * hidden_size, hidden_size)
        self.graph_conv_2 = get_graph_conv_layer(3 * hidden_size, 2 * hidden_size, hidden_size)
        self.graph_conv_3 = get_graph_conv_layer(3 * hidden_size, 2 * hidden_size, hidden_size)
        self.graph_conv_4 = get_graph_conv_layer(3 * hidden_size, 2 * hidden_size, hidden_size)
        self.linear_out = nn.Linear(hidden_size, output_size)

    def forward(self, x, edge_index, edge_attr=None):
        """Return per-node output scores for the given graph."""
        x = self.embed_x(x)

        # Self loops are removed. `edge_attr` is filtered together with
        # `edge_index`; previously the filtered attributes were discarded,
        # leaving `edge_attr` with more rows than `edge_index` has edges
        # whenever a self loop was present.
        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        edge_attr = self.embed_adj(edge_attr) if edge_attr is not None else None

        x_out, edge_attr_out = self.graph_conv_1(x, edge_index, edge_attr)
        # Residual updates are written out-of-place so that no tensor produced
        # by an earlier op is modified after the fact.
        x = x + x_out
        edge_attr = (edge_attr + edge_attr_out) if edge_attr is not None else edge_attr_out

        for graph_conv in (self.graph_conv_2, self.graph_conv_3, self.graph_conv_4):
            x = F.relu(x)
            edge_attr = F.relu(edge_attr)
            x_out, edge_attr_out = graph_conv(x, edge_index, edge_attr)
            x = x + x_out
            edge_attr = edge_attr + edge_attr_out

        return self.linear_out(x)

In [None]:
@contextmanager
def eval_net(net: nn.Module):
    """Put ``net`` in evaluation mode for the duration of the ``with`` block.

    The training/evaluation mode the network had on entry is restored on
    exit, including when the block raises.
    """
    previous_mode = net.training
    try:
        net.eval()
        yield
    finally:
        net.train(previous_mode)


def get_stats_on_missing(data, output):
    """Count correct predictions at the masked positions.

    A position is masked when ``data.x`` equals the module-level
    ``num_features``. Returns ``(number of correct predictions, number of
    masked positions)``, where a prediction is the index of the largest value
    along dimension 1 of ``output`` and is compared against ``data.y``.
    """
    is_missing = (data.x == num_features).squeeze()
    predictions = torch.max(output[is_missing].data, 1)[1]
    num_correct = (predictions == data.y[is_missing]).sum().item()
    return num_correct, len(predictions)


def get_data_x(data, frac_present):
    """Return a copy of ``data.y`` with random positions masked out.

    Each position independently keeps its label with probability
    ``frac_present``; every other position is replaced by the module-level
    ``num_features`` value.
    """
    labels = data.y
    keep = torch.rand(labels.size(0), device=labels.device) < frac_present
    missing_token = torch.full((1,), num_features, dtype=torch.long, device=labels.device)
    return torch.where(keep, labels, missing_token)

# Pipeline

## Load network

In [None]:
# Parameters
# Use the first GPU when CUDA is available and fall back to the CPU otherwise,
# so the notebook also runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Not referenced elsewhere in the visible part of this notebook.
batch_size = 4
# Number of classes the network predicts per node; the value `num_features`
# itself doubles as the "missing" token in get_data_x / get_stats_on_missing.
num_features = 20
# Number of features per edge fed to the network's edge embedding.
adj_input_size = 2
# Width of the node and edge representations.
hidden_size = 128
# Fraction of positions whose label is kept visible (see get_data_x).
frac_present = 0.5
frac_present_valid = frac_present
# Not referenced elsewhere in the visible part of this notebook.
info_size = 1024

In [None]:
# The embedding gets one extra token (index `num_features`) for "missing".
net = Net(
    x_input_size=num_features + 1, adj_input_size=adj_input_size, hidden_size=hidden_size, output_size=num_features
)
# map_location="cpu" lets the checkpoint load on machines that lack the device
# it was saved from; the whole network is moved to `device` afterwards.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
net.load_state_dict(torch.load("protein_4xEdgeConv_bs4/e12-s1652709-d6610836.state", map_location="cpu"))
# Evaluation mode, so the dropout layers are inactive during sequence generation.
net.eval()
net = net.to(device)

## Load protein data

In [None]:
class ProteinData(NamedTuple):
    """Sequence of one protein chain plus its residue-pair distance table."""

    # Sequence of the chain.
    sequence: str
    # Index of the first / second residue of each pair, and the pair's distance.
    # NOTE(review): extract_seq_and_adj fills these three fields with pandas
    # `.values` arrays rather than torch tensors -- confirm which type
    # downstream consumers expect.
    row_index: torch.LongTensor
    col_index: torch.LongTensor
    distances: torch.FloatTensor

In [None]:
def get_interaction_dataset_wdistances(structure_file, model_id, chain_id, r_cutoff=12):
    """Extract one chain as a domain and compute its residue-pair distances.

    The domain spans residues 1 through the number of residues in the chain.
    Returns ``(domain, distances_core)`` where ``distances_core`` is the table
    produced by ``structure_tools.get_distances`` with ``groupby="residue"``
    and the given ``r_cutoff``.
    """
    full_structure = PDB.load(structure_file)
    # NOTE(review): residues are counted in model 0 while the domain definition
    # below uses `model_id` -- confirm this is intended for model_id != 0.
    residue_count = sum(1 for _ in full_structure[0][chain_id].residues)
    domain_def = structure_tools.DomainDef(model_id, chain_id, 1, residue_count)
    domain = structure_tools.extract_domain(full_structure, [domain_def])
    distances_core = structure_tools.get_distances(domain, r_cutoff, 0, groupby="residue")
    # Every pair must be listed with the lower residue index first.
    is_ordered = distances_core["residue_idx_1"] <= distances_core["residue_idx_2"]
    assert is_ordered.all()
    return domain, distances_core

In [None]:
def extract_seq_and_adj(structure_file, chain_id):
    """Return a ``ProteinData`` with the sequence and pair distances of one chain.

    Uses model 0 of ``structure_file`` and a distance cutoff of 12.
    """
    domain, pairs = get_interaction_dataset_wdistances(structure_file, 0, chain_id, r_cutoff=12)
    sequence = structure_tools.get_chain_sequence(domain)

    # Every residue index must point inside the sequence.
    row_index = pairs["residue_idx_1"].values
    assert max(row_index) < len(sequence)
    col_index = pairs["residue_idx_2"].values
    assert max(col_index) < len(sequence)

    return ProteinData(sequence, row_index, col_index, pairs["distance"].values)

In [None]:
# Extract the sequence and residue-pair distances for chain "A" of the input structure,
# then show the result and the sequence length.
pdata = extract_seq_and_adj(STRUCTURE_FILE, 'A')
print(pdata)
print(len(pdata.sequence))

## Generate sequences

In [None]:
@torch.no_grad()
def design_protein_old(net, x, edge_index, edge_attr, results, x_proba=None, cutoff=-0.7):
    """Enumerate sequence designs by recursive depth-first search.

    At every step the unassigned position (token 20) with the most confident
    prediction is selected, and the function recurses once for each residue
    whose probability passes the cutoff. Completed designs are appended to
    ``results`` as ``(x, x_proba)`` tuples; nothing is returned.

    Args:
        net: Callable ``net(x, edge_index, edge_attr)`` returning per-position
            scores with one column per residue type.
        x: 1-D integer tensor of residue tokens; 20 marks an unassigned position.
        edge_index, edge_attr: Passed through to ``net`` unchanged.
        results: List that receives the completed designs (mutated in place).
        x_proba: Log-probability recorded for each assigned position; defaults
            to all zeros.
        cutoff: Average log-probability per position that a design must be
            able to retain for a branch to be explored.

    Note:
        The recursion is one level deep per position, so very long sequences
        can exceed Python's recursion limit.
    """
    if x_proba is None:
        x_proba = torch.zeros_like(x, dtype=torch.float)

    mask = x == 20
    if not mask.any():
        if (len(results) + 1) % 100 == 0:
            print(f"Num. results: {len(results) + 1}", flush=True)
        results.append((x, x_proba))
        return

    # Created on x's device: a CPU index tensor cannot be indexed with a mask
    # that lives on another device.
    index_array = torch.arange(x.size(0), device=x.device)[mask]

    output = torch.softmax(net(x, edge_index, edge_attr), dim=1)[mask]

    # Choose the unassigned position whose best residue has the highest probability.
    _, max_index = output.max(dim=1)[0].max(dim=0)
    position = index_array[max_index]
    row_with_max_proba = output[max_index]

    sum_log_prob = x_proba.sum()
    assert sum_log_prob.item() <= 0, x_proba
    # A residue with probability p is explored only if
    # sum_log_prob + log(p) >= cutoff * len(x). The threshold is capped at the
    # row maximum so that the most likely residue is always explored.
    p_cutoff = min(torch.exp(cutoff * x.size(0) - sum_log_prob), row_with_max_proba.max()).item()

    for i, p in enumerate(row_with_max_proba):
        if p < p_cutoff:
            continue
        x_clone = x.clone()
        x_proba_clone = x_proba.clone()
        assert x_clone[position] == 20
        assert x_proba_clone[position] == 0
        x_clone[position] = i
        x_proba_clone[position] = torch.log(p)
        # Recurse into this function. The previous call targeted
        # `design_protein`, which takes no `x_proba` argument.
        design_protein_old(net, x_clone, edge_index, edge_attr, results=results, x_proba=x_proba_clone, cutoff=cutoff)

In [None]:
import heapq
from dataclasses import dataclass, field
from typing import Any


@dataclass(order=True)
class PrioritizedItem:
    """Heap entry that is ordered by ``p`` alone; the payload fields are excluded from comparisons."""

    # Priority; heapq pops the entry with the smallest `p` first.
    # NOTE: design_protein stores `-x_proba.sum()` here, which is a 0-dim
    # tensor rather than a Python float.
    p: float
    # Candidate residue tokens.
    x: Any = field(compare=False)
    # Per-position log-probabilities that belong to `x`.
    x_proba: Any = field(compare=False)


@torch.no_grad()
def get_descendents(net, x, x_proba, edge_index, edge_attr, cutoff):
    """Expand a partial design by assigning one more position.

    The unassigned position (token 20) whose best residue has the highest
    predicted probability is selected, and one child is produced for every
    residue type at that position.

    Args:
        net: Callable ``net(x, edge_index, edge_attr)`` returning per-position
            scores with one column per residue type.
        x: 1-D integer tensor of residue tokens; it must contain at least one
            unassigned position (token 20).
        x_proba: Per-position log-probabilities; unassigned positions must
            hold the placeholder value ``cutoff``.
        edge_index, edge_attr: Passed through to ``net`` unchanged.
        cutoff: Placeholder log-probability of unassigned positions.

    Returns:
        A list of ``(x_child, x_proba_child)`` tuples, one per residue type,
        in residue-index order. The inputs are not modified.
    """
    mask = x == 20
    # Created on x's device: a CPU index tensor cannot be indexed with a mask
    # that lives on another device.
    index_array = torch.arange(x.size(0), device=x.device)[mask]

    output = torch.softmax(net(x, edge_index, edge_attr), dim=1)[mask]

    # Choose the unassigned position whose best residue has the highest probability.
    _, max_index = output.max(dim=1)[0].max(dim=0)
    position = index_array[max_index]
    row_with_max_proba = output[max_index]

    assert x_proba.sum().item() <= 0, x_proba
    # Invariants of the chosen position. They are identical for every child,
    # so they are checked once instead of once per residue type.
    assert x[position] == 20
    assert x_proba[position] == cutoff

    children = []
    for i, p in enumerate(row_with_max_proba):
        x_clone = x.clone()
        x_proba_clone = x_proba.clone()
        x_clone[position] = i
        x_proba_clone[position] = torch.log(p)
        children.append((x_clone, x_proba_clone))
    return children


@torch.no_grad()
def design_protein(
    net, x, edge_index, edge_attr, results, cutoff, max_heap_size=1_000_000, pruned_heap_size=700_000
):
    """Generate sequence designs with a best-first search over partial designs.

    Candidates live in a min-heap keyed by the negated sum of their
    per-position log-probabilities, where every unassigned position (token 20)
    contributes the placeholder ``cutoff``. The most promising candidate is
    popped each iteration: a complete design is appended to ``results``, an
    incomplete one is expanded with ``get_descendents``.

    Args:
        net: Network passed through to ``get_descendents``.
        x: 1-D integer tensor with the starting tokens; 20 marks an
            unassigned position.
        edge_index, edge_attr: Passed through to ``get_descendents``.
        results: List that receives completed ``PrioritizedItem`` objects. It
            is mutated in place, so it holds the designs found so far even if
            the search is interrupted.
        cutoff: Placeholder log-probability for unassigned positions.
        max_heap_size: Heap size above which the heap is pruned.
        pruned_heap_size: Number of best candidates kept when pruning.

    Returns:
        ``results``, once the heap is exhausted.
    """
    x_proba = torch.ones_like(x).to(torch.float) * cutoff
    heap = [PrioritizedItem(0, x, x_proba)]
    i = 0
    while heap:
        item = heapq.heappop(heap)
        if i % 1000 == 0:
            print(
                f"i: {i}; p: {item.p:.4f}; num missing: {(item.x == 20).sum()}; "
                f"heap size: {len(heap):7d}; results size: {len(results)}"
            )
        if not (item.x == 20).any():
            results.append(item)
        else:
            children = get_descendents(net, item.x, item.x_proba, edge_index, edge_attr, cutoff)
            # Distinct names: the loop no longer shadows the `x` / `x_proba` of this function.
            for child_x, child_x_proba in children:
                heapq.heappush(heap, PrioritizedItem(-child_x_proba.sum(), child_x, child_x_proba))
        i += 1
        if len(heap) > max_heap_size:
            # A heap's backing list is only partially ordered, so slicing it
            # keeps an arbitrary subset of candidates. Sort by priority first
            # and drop the worst entries; a sorted list already satisfies the
            # heap invariant, so no heapify is needed. The float key avoids
            # comparing tensors during the sort.
            heap.sort(key=lambda entry: float(entry.p))
            del heap[pruned_heap_size:]
    return results

In [None]:
@torch.no_grad()
def get_protein_proba(net, x_ref, edge_index, edge_attr):
    """Score a reference sequence by revealing it one position at a time.

    Starting from a fully unassigned sequence (token 20 everywhere), each
    iteration asks the network for the probability of the reference residue
    at every unassigned position, reveals the position where that probability
    is highest, and records its log-probability.

    Args:
        net: Callable ``net(x, edge_index, edge_attr)`` returning per-position
            scores with one column per residue type.
        x_ref: 1-D integer tensor with the reference residue tokens.
        edge_index, edge_attr: Passed through to ``net`` unchanged.

    Returns:
        A float tensor shaped like ``x_ref`` holding the log-probability that
        was recorded for each position when it was revealed.
    """
    x = torch.full_like(x_ref, 20)
    x_proba = torch.zeros_like(x_ref, dtype=torch.float)
    # Created on x_ref's device: a CPU index tensor cannot be indexed with a
    # mask that lives on another device.
    index_array_ref = torch.arange(x_ref.size(0), device=x_ref.device)
    mask = x == 20
    while mask.any():
        output = torch.softmax(net(x, edge_index, edge_attr), dim=1)
        # Probability assigned to the reference residue, at unassigned positions only.
        output_for_x = output.gather(1, x_ref.view(-1, 1))[mask]
        index_array = index_array_ref[mask]
        max_proba, max_proba_position = output_for_x.max(dim=0)

        position = index_array[max_proba_position]
        assert x[position] == 20
        assert x_proba[position] == 0
        x[position] = x_ref[position]
        x_proba[position] = torch.log(max_proba)
        mask = x == 20
    return x_proba

In [None]:
# Convert the extracted protein data into the object consumed by the network.
data = proteinsolver.datasets.protein.row_to_data(pdata)
data = proteinsolver.datasets.protein.transform_edge_attr(data)
# NOTE(review): the return value is discarded, so this relies on `to` moving `data` in place -- confirm.
data.to(device)

# Sum of the per-position log-probabilities of the reference sequence (cell output only).
get_protein_proba(net, data.x, data.edge_index, data.edge_attr).sum().item()

In [None]:
# Rebuild the network input from the extracted protein data.
data = proteinsolver.datasets.protein.row_to_data(pdata)
data = proteinsolver.datasets.protein.transform_edge_attr(data)
# NOTE(review): the return value is discarded, so this relies on `to` moving `data` in place -- confirm.
data.to(device)

# Keep the reference tokens as labels and start the design from a sequence
# in which every position holds the "missing" token 20.
data.y = data.x
x_in = torch.ones_like(data.x) * 20
# Completed designs are appended to `results` by design_protein.
results = []
design_protein(net, x_in, data.edge_index, data.edge_attr, results=results, cutoff=np.log(0.15))
# The commented-out lines below refer to `y` and `is_present`, which are not
# defined in the visible part of this notebook.
# identity_all = float((y == data.y).sum()) / data.y.size(0)
# identity_missing = float((y[~is_present] == data.y[~is_present]).sum()) / (~is_present).sum().item()
# result = {"identity_all": identity_all, "identity_missing": identity_missing}
# result

In [None]:
# Save the collected designs next to the notebook as "<structure file stem>.torch".
# NOTE(review): `results` holds PrioritizedItem objects, so loading the file presumably
# requires that class to be importable under the same name -- confirm.
torch.save(results, NOTEBOOK_PATH / (STRUCTURE_FILE.stem + ".torch"))