In [1]:
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu102.html --quiet
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cu102.html --quiet
!pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.0+cu102.html --quiet
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.7.0+cu102.html --quiet
!pip install torch-geometric==1.6.3 --quiet

In [2]:
import numpy as np

from typing import Tuple

from torch_geometric.data import Data, GraphSAINTRandomWalkSampler

import torch
from torch import Tensor
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Parameter as Param
from torch.nn import Parameter
from torch_scatter import scatter
from torch_sparse import matmul, SparseTensor
from torch_geometric.nn.conv import MessagePassing

from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score, average_precision_score

from tqdm.notebook import tqdm

import plotly.graph_objects as go

import json
import pickle

In [3]:
def masked_edge_index(edge_index, edge_mask):
    if isinstance(edge_index, Tensor):
        return edge_index[:, edge_mask]

In [4]:
path_f = "../input/data-loading-and-init-features/"

cheat_features = torch.load(path_f + "cheat_features.pt", map_location="cpu")
random_features = torch.load(path_f + "random_features.pt", map_location="cpu")
node2vec_ebeddings = torch.load(path_f + "node2vec_ebeddings.pt", map_location="cpu")
data = torch.load(path_f + "data.pt", map_location=torch.device("cpu"))

with open(path_f + "edge_type_mapping.p", "rb") as fp:
    edge_type_mapping = pickle.load(fp)
with open(path_f + "drug_nodes_mapping.p", "rb") as fp:
    drug_nodes_mapping = pickle.load(fp)
with open(path_f + "protein_nodes_mapping.p", "rb") as fp:
    protein_nodes_mapping = pickle.load(fp)
    
data.x = torch.tensor(cheat_features)

In [5]:
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.kaiming_uniform_(m.weight)
        m.bias.data.fill_(0.)

In [6]:
class RGCNConv(MessagePassing):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_relations: int,
        aggr: str = "mean",
    ):  

        super(RGCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations

        in_channels = (in_channels, in_channels)
        self.in_channels_l = in_channels[0]

        self.weight = Parameter(torch.Tensor(num_relations, in_channels[0], out_channels))
        self.register_parameter("comp", None)

        self.root = Param(torch.Tensor(in_channels[1], out_channels))
        self.bias = Param(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight)
        nn.init.kaiming_uniform_(self.root)
        nn.init.zeros_(self.bias)

    def forward(
        self,
        x,
        edge_index,
        edge_type = None,
    ):
        
        x_l = x
        x_r = x_l

        size = (x_l.size(0), x_r.size(0))
        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

        weight = self.weight
        
        for i in range(self.num_relations):
            tmp = masked_edge_index(edge_index, edge_type == i)
            h = self.propagate(tmp, x=x_l, size=size)
            out = out + (h @ weight[i])

        out += self.root[x_r] if x_r.dtype == torch.long else x_r @ self.root
        out += self.bias

        return out

    def message(self, x_j: Tensor):
        return x_j

    def message_and_aggregate(self, adj_t, x):
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x, reduce=self.aggr)

In [7]:
class RGCNEncoder(torch.nn.Module):
    def __init__(
        self,
        depth,
        dimensions,
        relation_len,
        decoder_type="fact",
        h_drug_mlp=[],
        h_rest_mlp=[],
        other_rel_len=3
    ):
        super(RGCNEncoder, self).__init__()
        self.rgcns = nn.ModuleList(
            [
                RGCNConv(dimensions[i], dimensions[i + 1], relation_len)
                for i in range(depth)
            ]
        )

        if decoder_type == "fact":
            self.w_rels = nn.Parameter(torch.Tensor(relation_len, dimensions[-1]))
            nn.init.kaiming_uniform_(self.w_rels, nonlinearity="relu")

        elif decoder_type == "mlp":
            self.decoder = MLPDecoder(h_drug_mlp, h_rest_mlp, dimensions[-1])
            self.decoder.apply(init_weights)

        else:
            self.decoder = DecaDecoder(dimensions[-1],relation_len=relation_len,other_rel_len=other_rel_len)

    def forward(self, x, edge_index, edge_type):
        
        depth = len(self.rgcns)
        
        for i, rgcn in enumerate(self.rgcns):
            x = rgcn(x, edge_index, edge_type)
            x = F.leaky_relu(x) if i != depth - 1 else F.log_softmax(x, dim=-1)

        return x

In [8]:
class MLPDecoder(torch.nn.Module):
    def __init__(self, h_drug_mlp_list: list, h_rest_mlp_list: list, out_h:int):
        
        super(MLPDecoder, self).__init__()
        
        layers = []
        layers.append(nn.Linear(out_h * 2, h_drug_mlp_list[0]))
        
        for i, h in enumerate(h_drug_mlp_list[:-1]):
            layers.append(nn.Linear(h_drug_mlp_list[i], h_drug_mlp_list[i+1]))
            layers.append(nn.ReLU()) if not i == len(h_drug_mlp_list) - 2 else None
            
        self.drug_mlp = nn.Sequential(*layers)
        
        layers = []
        layers.append(nn.Linear(out_h * 2, h_rest_mlp_list[0]))
            
        for i, h in enumerate(h_rest_mlp_list[:-1]):
            layers.append(nn.Linear(h_rest_mlp_list[i], h_rest_mlp_list[i+1]))
            layers.append(nn.ReLU()) if not i == len(h_rest_mlp_list) - 2 else None
            
        self.rest_mlp = nn.Sequential(*layers)

    def forward(self, embed, edge_index, edge_type):

        for edge in range(len(edge_index[0])):
            concat_embedding = torch.cat([embed[edge_index[0, edge]], embed[edge_index[1, edge]]], dim=-1)
            
            scores = torch.tensor([], device=device)
            
            if edge_type[edge] < 1317: # if drug-drug pair
                scores = torch.cat(scores, F.log_softmax(self.drug_mlp(concat_embedding).squeeze(1)[edge_type[edge]], dim=-1))
            else: # any other pair
                scores = torch.cat(scores, F.log_softmax(self.rest_mlp(concat_embedding).squeeze(1)), dim=-1)
                
        return scores

class DecaDecoder(torch.nn.Module):
    def __init__(self,embed_dim, relation_len,other_rel_len):
        super(DecaDecoder, self).__init__()
        self.R = torch.nn.Parameter(torch.randn(embed_dim,embed_dim)).to(device)
        self.D = [torch.nn.Parameter(torch.randn(embed_dim,embed_dim)).to(device) for i in range(relation_len)]
        self.M = [torch.nn.Parameter(torch.randn(embed_dim,embed_dim)).to(device) for i in range(other_rel_len)]
    
    def forward(self, embed, edge_index, edge_type):
        node_1, node_2 = (edge_index[0, :],edge_index[1, :])
        emb_1, emb_2 = (embed[node_1],embed[node_2])
        print(node_1.shape,emb_1.shape, edge_index.shape[1], edge_type.shape)
        scores = torch.stack([ emb_1[i]@self.D[edge_type[i]]@self.R@self.D[edge_type[i]]@emb_2[i].t() if edge_type[i] < 1317 else emb_1[i]@self.M[edge_type[i]-1317]@emb_2[i].t() for i in range(edge_index.shape[1])],dim=-1)

        scores = torch.sigmoid(scores)
        return scores

In [9]:
def DistMult(embed, edge_index, edge_type, model):
    s = embed[edge_index[0, :]]
    o = embed[edge_index[1, :]]
    r = model.w_rels[edge_type]
    scores = torch.sum(s * r * o, dim=1)

    return torch.sigmoid(scores)

In [10]:
def get_metrics(model, embed, edge_index, edge_type, labels):
    probs = DistMult(embed, edge_index, edge_type, model)
    # probs = model.decoder(embed, edge_index, edge_type)

    # print(probs)
    
    loss = F.binary_cross_entropy(probs, labels)

    probs = probs.cpu().detach().numpy()
    labels = labels.cpu().detach().numpy()

    return loss, probs, labels

In [11]:
def get_link_labels(edge_index_pos_len, edge_index_neg_len):
    link_labels = (torch.zeros(edge_index_pos_len + edge_index_neg_len).float().to(device))
    link_labels[: int(edge_index_pos_len)] = 1.0
    return link_labels

In [12]:
def get_embeddings(data):
    x = data.x
    edge_index_pos = data.edge_index
    edge_type = torch.squeeze(data.edge_type)
    embed = model(x, edge_index_pos, edge_type)

    return embed

In [13]:
def negative_sample(edge_index, edge_meta_type):
    """
    generate negative samples but keep the node type the same
    """
    edge_index_copy = edge_index.clone()

    # resample ddi, the meta edge type for ddi is 1
    ddi = edge_index_copy[0, torch.squeeze(edge_meta_type == 1)]
    new_index = torch.randperm(ddi.shape[0])
    new_ddi = ddi[new_index]
    edge_index_copy[0, torch.squeeze(edge_meta_type == 1)] = new_ddi

    # resample ppi, the meta edge type for ppi is 2
    ppi = edge_index_copy[0, torch.squeeze(edge_meta_type == 2)]
    new_index = torch.randperm(ppi.shape[0])
    new_ppi = ppi[new_index]
    edge_index_copy[0, torch.squeeze(edge_meta_type == 2)] = new_ppi

    # resample target, the meta edge type for target is 3
    target = edge_index_copy[0, torch.squeeze(edge_meta_type == 3)]
    new_index = torch.randperm(target.shape[0])
    new_target = target[new_index]
    edge_index_copy[0, torch.squeeze(edge_meta_type == 3)] = new_target
    
    # resample targeted_by, the meta edge type for targeted_by is 4
    targeted_by = edge_index_copy[0, torch.squeeze(edge_meta_type == 4)]
    new_index = torch.randperm(targeted_by.shape[0])
    new_targeted_by = targeted_by[new_index]
    edge_index_copy[0, torch.squeeze(edge_meta_type == 4)] = new_targeted_by

    return edge_index_copy

In [14]:
params = {
    "depth": 2,
    "dimensions": (data.x.size(1), 64, 64),
    "relation_len": data.edge_type.unique().size(0),
    "decoder_type": "fact",
    "epochs": 50,
}

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RGCNEncoder(params["depth"], params["dimensions"], params["relation_len"], params["decoder_type"]).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, weight_decay=5e-4)

In [16]:
def train(data, embed):
    x = data.x
    
    edge_index_train_pos = data.edge_index[:, data.train_mask]
    edge_type_train = torch.squeeze(data.edge_type[data.train_mask])
    
    edge_meta_type = data.edge_meta_type[data.train_mask]
    edge_index_train_neg = negative_sample(edge_index_train_pos, edge_meta_type)

    edge_index_train_total = torch.cat([edge_index_train_pos, edge_index_train_neg], dim=-1)
    
    edge_type_train_total = torch.cat([edge_type_train, edge_type_train[:edge_index_train_neg.size(1)]], dim=-1)


    link_labels = get_link_labels(edge_index_train_pos.size(1), edge_index_train_neg.size(1))
    
    loss, probs, labels = get_metrics(model, embed, edge_index_train_total, edge_type_train_total, 
                                            link_labels)
    
    auroc = roc_auc_score(labels, probs)
    auprc = average_precision_score(labels, probs)
    
    loss_epoch_train.append(loss.item())
    auroc_epoch_train.append(auroc)
    
    loss.backward()
    optimizer.step()

In [17]:
@torch.no_grad()
def validation(data, embed, evaluate_rel=False):
    x = data.x
    
    edge_index_val_pos = data.edge_index[:, data.val_mask]
    edge_type_val = torch.squeeze(data.edge_type[data.val_mask])
    
    edge_meta_type = data.edge_meta_type[data.val_mask]
    edge_index_val_neg = negative_sample(edge_index_val_pos, edge_meta_type)
    
    edge_index_val_total = torch.cat([edge_index_val_pos, edge_index_val_neg], dim=-1)
    edge_type_val_total = torch.cat([edge_type_val, edge_type_val[:edge_index_val_neg.size(1)]], dim=-1)
    
    link_labels = get_link_labels(edge_index_val_pos.size(1), edge_index_val_neg.size(1))
    loss, probs, labels = get_metrics(model, embed, edge_index_val_total, edge_type_val_total, 
                                                                link_labels)
    auroc = roc_auc_score(labels, probs)
    auprc = average_precision_score(labels, probs)
    
    edge_type_val_total = edge_type_val_total.detach().cpu()
    
    loss_epoch_val.append(loss.item())
    auroc_epoch_val.append(auroc)
    
    if not evaluate_rel:
        return
    
    for i in range(num_relations):
        mask = (edge_type_val_total == i)
        if mask.sum() == 0:
            continue
        probs_per_rel = probs[mask]
        labels_per_rel = labels[mask]
        auroc_per_rel = roc_auc_score(labels_per_rel, probs_per_rel)
        auroc_edge_type[i].append(auroc_per_rel)

In [18]:
data_loader = GraphSAINTRandomWalkSampler(data, batch_size=128, walk_length=16, num_steps=32)

In [19]:
loss_train_total, loss_val_total = [], []
auroc_train_total, auroc_val_total = [], []

for epoch in range(0, params["epochs"]):
    loss_epoch_train, loss_epoch_val = [], []
    auroc_epoch_train, auroc_epoch_val = [], []

    for batch in data_loader:

        optimizer.zero_grad()
        model.train()
        embed = get_embeddings(batch)
        train(batch, embed)
        model.eval()
        validation(batch, embed)

    loss_train_total.append(np.mean(loss_epoch_train))
    auroc_train_total.append(np.mean(auroc_epoch_train))
    loss_val_total.append(np.mean(loss_epoch_val))
    auroc_val_total.append(np.mean(auroc_epoch_val))

    print(
        "Epoch: {} | train loss: {} | train auroc: {} |".format(
            epoch + 1,
            "%.3f" % np.mean(loss_epoch_train),
            "%.3f" % np.mean(auroc_epoch_train),
        )
    )
    print(
        "Epoch: {} | val loss: {} | val auroc: {} |".format(
            epoch + 1,
            "%.3f" % np.mean(loss_epoch_val),
            "%.3f" % np.mean(auroc_epoch_val),
        )
    )

    print("---------------------------------------------------------------------------")

Epoch: 1 | train loss: 3.367 | train auroc: 0.494 |
Epoch: 1 | val loss: 3.160 | val auroc: 0.494 |
---------------------------------------------------------------------------
Epoch: 2 | train loss: 0.805 | train auroc: 0.498 |
Epoch: 2 | val loss: 0.799 | val auroc: 0.498 |
---------------------------------------------------------------------------
Epoch: 3 | train loss: 0.723 | train auroc: 0.501 |
Epoch: 3 | val loss: 0.722 | val auroc: 0.501 |
---------------------------------------------------------------------------
Epoch: 4 | train loss: 0.697 | train auroc: 0.501 |
Epoch: 4 | val loss: 0.697 | val auroc: 0.501 |
---------------------------------------------------------------------------
Epoch: 5 | train loss: 0.695 | train auroc: 0.501 |
Epoch: 5 | val loss: 0.695 | val auroc: 0.501 |
---------------------------------------------------------------------------
Epoch: 6 | train loss: 0.694 | train auroc: 0.500 |
Epoch: 6 | val loss: 0.694 | val auroc: 0.500 |
--------------------

KeyboardInterrupt: 

In [20]:
from plotly.subplots import make_subplots

fig = make_subplots(specs=[[{"secondary_y": True}]])

fig.add_trace(
    go.Scatter(
        x=list(range(1, params["epochs"] + 1)),
        y=loss_train_total,
        mode="lines",
        name="Training Loss",
    ),
    secondary_y=False,
)

fig.add_trace(
    go.Scatter(
        x=list(range(1, params["epochs"] + 1)),
        y=loss_val_total,
        mode="lines",
        name="Validation Loss",
    ),
    secondary_y=False,
)

fig.add_trace(
    go.Scatter(
        x=list(range(1, params["epochs"] + 1)),
        y=auroc_train_total,
        mode="lines",
        name="Training AUROC",
    ),
    secondary_y=True,
)

fig.add_trace(
    go.Scatter(
        x=list(range(1, params["epochs"] + 1)),
        y=auroc_val_total,
        mode="lines",
        name="Validation AUROC",
    ),
    secondary_y=True,
)

# Set x-axis title
fig.update_xaxes(title_text="<b>Epochs</b>")

# Set y-axes titles
fig.update_yaxes(title_text="<b>Loss</b>", secondary_y=False)
fig.update_yaxes(title_text="<b>AUROC</b>", secondary_y=True)


fig.update_layout(
    title="Loss/AUROC plot",
    autosize=False,
    width=700,
    height=500,
)
fig.show()

In [21]:
data = data.to(device)

embed = get_embeddings(data)
ys = data.y.cpu().detach().numpy()

drug_index = np.where(ys == 0)
protien_index = np.where(ys == 1)

tsne_embedding = TSNE(exaggeration=1, n_jobs=4).fit(embed.cpu().detach().numpy())

In [22]:
def plot(embedding, drug_index, protien_index):
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=embedding[drug_index, 0][0],
            y=embedding[drug_index, 1][0],
            mode="markers",
            name="drug",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=embedding[protien_index, 0][0],
            y=embedding[protien_index, 1][0],
            mode="markers",
            name="protien",
        )
    )

    fig.update_layout(
        autosize=False,
        width=700,
        height=700,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False)
    )
    return fig

In [23]:
fig = plot(tsne_embedding, drug_index, protien_index)
fig.show()