# SARS-CoV-2 Knowledge Graph 

Load packages

In [209]:
import numpy as np
import pandas as pd
import time
import re
import math
import random
import pickle

from sklearn.model_selection import train_test_split
from sklearn import metrics 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import add_remaining_self_loops, negative_sampling
from torch_geometric.nn import VGAE, GCNConv
from torch_geometric.utils import negative_sampling

In [210]:
data_path='data/'
exp_id='v0'
device_id='cpu' #'cpu' if CPU, device number if GPU
embedding_size=128
topk=300

Load preprocessed files

In [211]:
le=pickle.load(open(data_path+'LabelEncoder_'+exp_id+'.pkl', 'rb'))
edge_index=pickle.load(open(data_path+'edge_index_'+exp_id+'.pkl','rb'))
node_feature_np=pickle.load(open(data_path+'node_feature_'+exp_id+'.pkl','rb'))

In [212]:
# For tensor representing node
node_feature=torch.tensor(node_feature_np, dtype=torch.float)
print("node_feature: ", node_feature)

node_feature:  tensor([[-0.7242, -0.4082, -0.4957,  ...,  0.5417, -0.5262, -0.3699],
        [-0.7054,  0.2955, -0.4708,  ..., -0.3246, -0.5933, -0.8806],
        [-0.7455,  0.5428,  0.4459,  ...,  0.3047,  0.5029,  0.1170],
        ...,
        [-0.1645,  0.5358,  0.1716,  ...,  0.6800, -0.1641,  0.5851],
        [-0.6400, -0.5277,  0.4265,  ...,  0.8151, -0.8159, -0.0602],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])


In [213]:
edge=torch.tensor(edge_index[['node1', 'node2']].values, dtype=torch.long)
print(edge)

tensor([[ 403,   44],
        [1287,   44],
        [1689,   44],
        ...,
        [2265,   44],
        [2225,   33],
        [2211,   54]])


In [214]:
# edge_attr_dict maps edge type to the unique numeric identifier
edge_attr_dict={'gene-drug':0,'gene-gene':1,'bait-gene':2, 'phenotype-gene':3, 'phenotype-drug':4}
edge_index['type']=edge_index['type'].apply(lambda x: edge_attr_dict[x])


In [215]:
edge_index['type'].value_counts()

type
1    14242
4      410
3      325
2      281
0      189
Name: count, dtype: int64

In [216]:
edge_attr=torch.tensor(edge_index['type'].values,dtype=torch.long)
print(edge_attr)

tensor([0, 0, 0,  ..., 4, 4, 4])


In [217]:
# Preparing a Data object for a graph, including node features, edge information (transposed and contiguous), and edge attributes
data = Data(x=node_feature,
            edge_index=edge.t().contiguous(),
            edge_attr=edge_attr
           )

## Batch

In [218]:
def train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1):
    r"""Splits the edges of a :obj:`torch_geometric.data.Data` object
    into positive and negative train/val/test edges, and adds attributes of
    `train_pos_edge_index`, `train_neg_adj_mask`, `val_pos_edge_index`,
    `val_neg_edge_index`, `test_pos_edge_index`, and `test_neg_edge_index`
    to :attr:`data`.

    Args:
        data (Data): The data object.
        val_ratio (float, optional): The ratio of positive validation
            edges. (default: :obj:`0.05`)
        test_ratio (float, optional): The ratio of positive test
            edges. (default: :obj:`0.1`)

    :rtype: :class:`torch_geometric.data.Data`
    """

    assert 'batch' not in data  # No batch-mode.

    num_nodes = data.num_nodes
    row, col = data.edge_index
    #data.edge_index = None
    attr = data.edge_attr

    # Return upper triangular portion.
    #mask = row < col
    #row, col = row[mask], col[mask]

    n_v = int(math.floor(val_ratio * row.size(0)))
    n_t = int(math.floor(test_ratio * row.size(0)))

    # Positive edges.
    perm = torch.randperm(row.size(0))
    row, col = row[perm], col[perm]
    attr=attr[perm]

    r, c = row[:n_v], col[:n_v]
    data.val_pos_edge_index = torch.stack([r, c], dim=0)
    data.val_pos_edge_attr = attr[:n_v]
    
    r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t]
    data.test_pos_edge_index = torch.stack([r, c], dim=0)
    data.test_post_edge_attr = attr[n_v:n_v + n_t]

    r, c = row[n_v + n_t:], col[n_v + n_t:]
    data.train_pos_edge_index = torch.stack([r, c], dim=0)
    data.train_pos_edge_attr = attr[n_v+n_t:]

    # Negative edges.
    neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8)
    neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool)
    neg_adj_mask[row, col] = 0

    neg_row, neg_col = neg_adj_mask.nonzero().t()
    perm = random.sample(range(neg_row.size(0)),
                         min(n_v + n_t, neg_row.size(0)))
    perm = torch.tensor(perm)
    perm = perm.to(torch.long)
    neg_row, neg_col = neg_row[perm], neg_col[perm]

    neg_adj_mask[neg_row, neg_col] = 0
    data.train_neg_adj_mask = neg_adj_mask

    row, col = neg_row[:n_v], neg_col[:n_v]
    data.val_neg_edge_index = torch.stack([row, col], dim=0)

    row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t]
    data.test_neg_edge_index = torch.stack([row, col], dim=0)

    return data

In [219]:
device=torch.device(device_id)

In [220]:
data_split=train_test_split_edges(data)
x,train_pos_edge_index,train_pos_edge_attr = data_split.x.to(device), data_split.train_pos_edge_index.to(device), data_split.train_pos_edge_attr.to(device)
train_pos_edge_index, train_pos_edge_attr=add_remaining_self_loops(train_pos_edge_index,train_pos_edge_attr)
x,train_pos_edge_index,train_pos_edge_attr = Variable(x),Variable(train_pos_edge_index),Variable(train_pos_edge_attr)

## Learning models

Define GCN model

In [221]:
class Encoder_GCN(nn.Module):
    def __init__(self, in_channels, out_channels, isClassificationTask=False):
        super(Encoder_GCN, self).__init__()
        self.isClassificationTask=isClassificationTask
        self._gene_drug=  GCNConv(in_channels, 2*out_channels)
        self._gene_gene = GCNConv(in_channels, 2*out_channels)
        self._bait_gene = GCNConv(in_channels, 2*out_channels)
        self._gene_phenotype = GCNConv(in_channels, 2*out_channels)
        self._drug_phenotype = GCNConv(in_channels, 2*out_channels)

        self.bn = nn.BatchNorm1d(5*2*out_channels)
        #variational encoder
        self._mu = GCNConv(5*2*out_channels, out_channels)
        self._logvar = GCNConv(5*2*out_channels, out_channels)

    def forward(self,x,edge_index,edge_attr):
        
        x = F.dropout(x, training=self.training)

        # Extract subgraphs based on edge attributes
        edge_indices = []
        for attr_value in range(5):
            indices = (edge_attr == attr_value).nonzero().reshape(1, -1)[0]
            edge_indices.append(edge_index[:, indices])
        
        edge_index_gene_drug, edge_index_gene_gene, edge_index_bait_gene, edge_index_gene_phenotype, edge_index_drug_phenotype = edge_indices

        # gcn -> relu -> dropout
        x_gene_drug = F.dropout(F.relu(self._gene_drug(x, edge_index_gene_drug)), p=0.5, training=self.training)
        x_gene_gene = F.dropout(F.relu(self._gene_gene(x, edge_index_gene_gene)), p=0.5, training=self.training)
        x_bait_gene = F.dropout(F.relu(self._bait_gene(x, edge_index_bait_gene)), p=0.1, training=self.training)
        x_gene_phenotype = F.dropout(F.relu(self._gene_phenotype(x, edge_index_gene_phenotype)), training=self.training)
        x_drug_phenotype = F.dropout(F.relu(self._drug_phenotype(x, edge_index_drug_phenotype)), training=self.training)

        # concat
        batch_input = torch.cat([x_gene_drug, x_gene_gene, x_bait_gene, x_gene_phenotype, x_drug_phenotype], dim=1)

        # batch norm
        x = self.bn(batch_input)

        return self._mu(x, edge_index), self._logvar(x, edge_index)

In [222]:
model = VGAE(Encoder_GCN(node_feature.shape[1], embedding_size)).to(device)
optimizer=torch.optim.Adam(model.parameters(), lr=0.01)
print(model)

VGAE(
  (encoder): Encoder_GCN(
    (_gene_drug): GCNConv(400, 256)
    (_gene_gene): GCNConv(400, 256)
    (_bait_gene): GCNConv(400, 256)
    (_gene_phenotype): GCNConv(400, 256)
    (_drug_phenotype): GCNConv(400, 256)
    (bn): BatchNorm1d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (_mu): GCNConv(1280, 128)
    (_logvar): GCNConv(1280, 128)
  )
  (decoder): InnerProductDecoder()
)


In [223]:
def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(x, train_pos_edge_index, train_pos_edge_attr)
    loss = model.recon_loss(z, train_pos_edge_index)
    loss = loss + (1 / data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    print(loss.item())
    
def test(pos_edge_index, neg_edge_index):
    model.eval()
    with torch.no_grad():
        z=model.encode(x, train_pos_edge_index,train_pos_edge_attr)
    return model.test(z, pos_edge_index, neg_edge_index)

In [224]:
#DRKG's accuracy for comparison
model.test(x,data_split.test_pos_edge_index, data_split.test_neg_edge_index )

(0.8260458753926281, 0.7778310666373031)

In [225]:
for epoch in range(1, 10):
    train()
    auc, ap = test(data_split.test_pos_edge_index, data_split.test_neg_edge_index)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

30.776477813720703
Epoch: 001, AUC: 0.6098, AP: 0.5623
4225.87158203125
Epoch: 002, AUC: 0.5654, AP: 0.5352
1297.2353515625
Epoch: 003, AUC: 0.5610, AP: 0.5326
1541.6563720703125
Epoch: 004, AUC: 0.5655, AP: 0.5352
1304.3856201171875
Epoch: 005, AUC: 0.5873, AP: 0.5481
2693.078857421875
Epoch: 006, AUC: 0.6041, AP: 0.5584
1544.1636962890625
Epoch: 007, AUC: 0.6314, AP: 0.5760
3082.4345703125
Epoch: 008, AUC: 0.6436, AP: 0.5842
938.4722290039062
Epoch: 009, AUC: 0.6571, AP: 0.5936


Node embedding

In [226]:
model.eval()
z=model.encode(x, data.edge_index.to(device), data.edge_attr.to(device))
z_np = z.squeeze().detach().cpu().numpy()

Save the new embedding 

In [227]:
pickle.dump(z_np, open(data_path+'node_embedding_'+exp_id+'.pkl', 'wb'))

Save the torch model

In [228]:
torch.save(model.state_dict(), data_path+'VAE_encoders_'+exp_id+'.pkl')

In [229]:
model.load_state_dict(torch.load(data_path+'VAE_encoders_'+exp_id+'.pkl'))
model.eval()

VGAE(
  (encoder): Encoder_GCN(
    (_gene_drug): GCNConv(400, 256)
    (_gene_gene): GCNConv(400, 256)
    (_bait_gene): GCNConv(400, 256)
    (_gene_phenotype): GCNConv(400, 256)
    (_drug_phenotype): GCNConv(400, 256)
    (bn): BatchNorm1d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (_mu): GCNConv(1280, 128)
    (_logvar): GCNConv(1280, 128)
  )
  (decoder): InnerProductDecoder()
)

## Ranking model

In [230]:
from sklearn.model_selection import train_test_split
from sklearn import metrics

In [231]:
types=np.array([classtype.split('_')[0] for classtype in le.classes_ ])

Load drugs under clinical trial

In [232]:
#label
trials=pd.read_excel(data_path+'literature-mining/All_trails_5_24.xlsx',header=1,index_col=0)
trials_drug=set([drug.strip().upper() for lst in trials.loc[trials['study_category'].apply(lambda x: 'drug' in x.lower()),'intervention'].apply(lambda x: re.split(r'[+|/|,]',x.replace(' vs. ', '/').replace(' vs ', '/').replace(' or ', '/').replace(' with and without ', '/').replace(' /wo ', '/').replace(' /w ', '/').replace(' and ', '/').replace(' - ', '/').replace(' (', '/').replace(') ', '/'))).values for drug in lst])
drug_labels=[1 if drug.split('_')[1] in trials_drug else 0 for drug in le.classes_[types=='drug'] ]

In [233]:
indices = np.arange(len(drug_labels))
X_train, X_test, y_train, y_test, indices_train, indices_test = train_test_split(z_np[types=='drug'], drug_labels, indices, test_size=0.5)

In [234]:
#Variable wrapping for torch.tensor
_X_train, _y_train=Variable(torch.tensor(X_train,dtype=torch.float).to(device)), Variable(torch.tensor(y_train,dtype=torch.float).to(device))
_X_test, _y_test=Variable(torch.tensor(X_test,dtype=torch.float).to(device)), Variable(torch.tensor(y_test,dtype=torch.float).to(device))

NN

In [235]:
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__() 
        self.fc1=nn.Linear(128,128)
        self.fc2=nn.Linear(128,1)
        self.bn=nn.BatchNorm1d(128)
    def forward(self, x):
        residual = x
        out = F.dropout(F.relu(self.fc1(x)))
        out = self.bn(out)
        out += residual
        return self.fc2(x)        

BPR Loss Function

In [236]:
from torch.utils.data import BatchSampler, WeightedRandomSampler
class BPRLoss(nn.Module):
    def __init__(self, num_neg_samples):
        super(BPRLoss, self).__init__()
        self.num_neg_samples=num_neg_samples
    
    def forward(self, output, label):
        positive_output=output[label==1]
        negative_output=output[label!=1]

        log_prob = F.logsigmoid(positive_output.view(-1,1) - negative_output).mean()

        return -log_prob

In [237]:
clf=Classifier().to(device)
optimizer=torch.optim.Adam(clf.parameters())
criterion=BPRLoss(num_neg_samples=15)

In [238]:
best_auprc=0
for epoch in range(30):
    clf.train()
    optimizer.zero_grad()
    out = clf(_X_train)
    loss=criterion(out.squeeze(), _y_train)
    loss.backward()
    optimizer.step()   
    print('training loss',loss.item())

    clf.eval()
    print('test loss', criterion(clf(_X_test).squeeze(), _y_test).item())
    prob=torch.sigmoid(clf(_X_test)).cpu().detach().numpy().squeeze()
    auprc=metrics.average_precision_score(y_test,prob)
    if auprc>best_auprc:
        best_auproc=auprc
        torch.save(clf, data_path+'nn_clf-temp.pt')


training loss 1.3339946269989014
test loss 0.5207222104072571
training loss 1.0608125925064087
test loss 0.5801140069961548
training loss 0.9250597357749939
test loss 0.8643887042999268
training loss 1.0317251682281494
test loss 0.8753237128257751
training loss 1.0192252397537231
test loss 0.7023117542266846
training loss 0.9109266400337219
test loss 0.5697870850563049
training loss 0.8488369584083557
test loss 0.5141294002532959
training loss 0.8488251566886902
test loss 0.49606671929359436
training loss 0.8623873591423035
test loss 0.4897958040237427
training loss 0.8599364757537842
test loss 0.48686501383781433
training loss 0.8348497152328491
test loss 0.48861661553382874
training loss 0.7930968403816223
test loss 0.502312421798706
training loss 0.7478501796722412
test loss 0.5382651686668396
training loss 0.7162315845489502
test loss 0.5988163948059082
training loss 0.709517240524292
test loss 0.6545520424842834
training loss 0.7136891484260559
test loss 0.6603466868400574
trainin

In [239]:
clf.load_state_dict(torch.load(data_path+'nn_clf-temp.pt').state_dict())

<All keys matched successfully>

In [240]:
#Compute AUC
clf.eval()

prob=torch.sigmoid(clf(_X_test)).cpu().detach().numpy().squeeze()
print("AUROC", metrics.roc_auc_score(y_test,prob))
print("AUPRC", metrics.average_precision_score(y_test,prob))

AUROC 0.7142857142857143
AUPRC 0.6692628205128205


In [241]:
z_np[types=='drug']
res_tensor = -clf(torch.from_numpy(z_np[types=='drug']))
sorted_picks = np.argsort(res_tensor.squeeze().detach().numpy())
sorted_picks

array([40, 27, 29, 45, 21, 12, 15, 41, 22, 36, 25,  7, 19, 33, 38,  8, 24,
       10, 17, 39, 31,  6, 32,  3,  9, 28,  5, 37, 26, 16, 18, 46,  1, 44,
       35, 30, 14, 43, 23, 11, 42, 20,  4,  2, 34, 13,  0], dtype=int64)

Save the high-ranked drugs into csv file

In [242]:

topk_drugs=pd.DataFrame([(rank, drug.split('_')[1]) for rank,drug in enumerate(le.inverse_transform((types=='drug').nonzero()[0][sorted_picks])[:topk+1])], columns=['rank', 'drug'])
topk_drugs['under_trials']=topk_drugs['drug'].isin(trials_drug).astype(int)
topk_drugs['is_used_in_training']=topk_drugs['drug'].isin(np.array([drug.split('_')[1] for drug in le.classes_[types=='drug']])[indices_train]).astype(int)
topk_drugs.to_csv('top300_drugs.csv')