# 1. COVID-19 graph embedding alone

Load packages

In [1]:
import numpy as np
import pandas as pd
import time
import re
import math
import random
import pickle

from sklearn.model_selection import train_test_split
from sklearn import metrics 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.modules import Module
from torch.utils.data import Dataset, DataLoader

from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import train_test_split_edges
from torch_geometric.utils import add_remaining_self_loops, add_self_loops
from torch_geometric.utils import to_undirected
from torch_geometric.nn import GCNConv, SAGEConv,GAE, VGAE

In [2]:
# Experiment configuration shared by every cell below.
data_path = 'data/'  # directory holding the preprocessed pickles (also used for outputs)
exp_id = 'v0'        # suffix appended to every input/output file name
# 'cpu' if CPU, device number if GPU.
device_id = 'cpu'
embedding_size = 128  # latent width passed to the encoder (mu / logvar size)

Load preprocessed files

In [111]:
# Load the preprocessed artifacts for this experiment. Each `with` block closes
# its file handle as soon as the object has been read.
# NOTE(review): pickle.load can execute arbitrary code — only load files
# produced by the trusted preprocessing step.
with open(data_path + 'LabelEncoder_' + exp_id + '.pkl', 'rb') as fh:
    le = pickle.load(fh)
# Edge table; used below as a pandas DataFrame with 'node1', 'node2', 'type' columns.
with open(data_path + 'edge_index_' + exp_id + '.pkl', 'rb') as fh:
    edge_index = pickle.load(fh)
with open(data_path + 'dirichlet_node_feature_' + exp_id + '.pkl', 'rb') as fh:
    node_feature_alone_np = pickle.load(fh)

In [112]:
# Node-feature matrix -> float32 tensor, one row per node (the cell below
# reports 400 features for 15444 nodes).
node_feature_alone=torch.tensor(node_feature_alone_np, dtype=torch.float)

In [113]:
# One (node1, node2) pair per row, i.e. an (E, 2) integer tensor; it is
# transposed to PyG's (2, E) layout when the Data object is built.
edge=torch.tensor(edge_index[['node1', 'node2']].values, dtype=torch.long)

In [114]:
# Map each relation name to an integer edge-type code. The encoder selects the
# edges of each relation by these exact codes, so they must not be renumbered.
# An unknown type string raises KeyError in the lambda.
edge_attr_dict={'gene-drug':0,'gene-gene':1,'bait-gene':2, 'gene-phenotype':3, 'drug-phenotype':4}
edge_index['type']=edge_index['type'].apply(lambda x: edge_attr_dict[x])

In [115]:
# Number of edges per relation code (sums to the edge count reported below).
edge_index['type'].value_counts()

0    27590
1     5971
3     1618
4     1338
2      247
Name: type, dtype: int64

In [116]:
# Relation code per edge, aligned with the rows of `edge`.
edge_attr=torch.tensor(edge_index['type'].values,dtype=torch.long)

In [117]:
# PyG graph: node features, (2, E) edge index, one relation code per edge.
data = Data(x=node_feature_alone,
            edge_index=edge.t().contiguous(),
            edge_attr=edge_attr
           )

In [118]:
# Sanity check: (feature size, node count, edge count).
data.num_features, data.num_nodes,data.num_edges

(400, 15444, 36764)

In [34]:
# The graph contains isolated nodes and is directed (at least some edges have
# no reverse copy), which the custom edge split below has to account for.
data.contains_isolated_nodes(), data.is_directed()



(True, True)

In [35]:
def train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1):
    r"""Split the edges of a graph into positive/negative train/val/test sets.

    Local variant of ``torch_geometric.utils.train_test_split_edges``. It keeps
    directed edges as they are (no upper-triangle filtering) and carries each
    edge's type code from ``data.edge_attr`` along with the positive splits.
    ``data.edge_index`` is left untouched, so the full graph is still
    available after the split.

    Attributes added to ``data``:
        ``val_pos_edge_index`` / ``val_pos_edge_attr``,
        ``test_pos_edge_index`` / ``test_pos_edge_attr`` (also aliased as
        ``test_post_edge_attr`` for backward compatibility),
        ``train_pos_edge_index`` / ``train_pos_edge_attr``,
        ``val_neg_edge_index``, ``test_neg_edge_index``, and
        ``train_neg_adj_mask``: a boolean ``(num_nodes, num_nodes)`` mask of the
        upper-triangular node pairs still available as training negatives.

    Negatives are drawn from the strict upper triangle (``i < j``). A pair is
    excluded if it is connected by a positive edge in *either* direction. This
    matters because the edges are directed, and PyG's default (inner-product)
    decoder scores ``(i, j)`` and ``(j, i)`` identically.

    Args:
        data (Data): graph with ``edge_index`` of shape (2, E), an
            ``edge_attr`` tensor indexable by edge (E entries) and
            ``num_nodes``.
        val_ratio (float, optional): fraction of positive edges used for
            validation. (default: :obj:`0.05`)
        test_ratio (float, optional): fraction of positive edges used for
            testing. (default: :obj:`0.1`)

    Returns:
        Data: the same ``data`` object, modified in place.

    Note:
        The negative mask is dense, so memory grows as ``num_nodes ** 2``.
    """
    assert 'batch' not in data  # No batch-mode.

    num_nodes = data.num_nodes
    row, col = data.edge_index
    attr = data.edge_attr

    n_v = int(math.floor(val_ratio * row.size(0)))
    n_t = int(math.floor(test_ratio * row.size(0)))

    # Positive edges: shuffle once, then slice into val | test | train,
    # permuting the edge types with the same permutation so they stay aligned.
    perm = torch.randperm(row.size(0))
    row, col, attr = row[perm], col[perm], attr[perm]

    data.val_pos_edge_index = torch.stack([row[:n_v], col[:n_v]], dim=0)
    data.val_pos_edge_attr = attr[:n_v]

    data.test_pos_edge_index = torch.stack(
        [row[n_v:n_v + n_t], col[n_v:n_v + n_t]], dim=0)
    data.test_pos_edge_attr = attr[n_v:n_v + n_t]
    # Alias under the misspelled name this function used to set.
    data.test_post_edge_attr = data.test_pos_edge_attr

    data.train_pos_edge_index = torch.stack(
        [row[n_v + n_t:], col[n_v + n_t:]], dim=0)
    data.train_pos_edge_attr = attr[n_v + n_t:]

    # Negative edges: candidate pairs are the strict upper triangle.
    neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8)
    neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool)
    # Clear both orientations of every positive edge. Clearing only
    # [row, col] misses edges stored with row > col (they fall in the
    # already-zero lower triangle), leaving their reverse pair eligible as a
    # "negative".
    neg_adj_mask[row, col] = 0
    neg_adj_mask[col, row] = 0

    neg_row, neg_col = neg_adj_mask.nonzero().t()
    num_neg = min(n_v + n_t, neg_row.size(0))
    perm = torch.tensor(random.sample(range(neg_row.size(0)), num_neg),
                        dtype=torch.long)
    neg_row, neg_col = neg_row[perm], neg_col[perm]

    # Pairs used as val/test negatives must not be reused as training negatives.
    neg_adj_mask[neg_row, neg_col] = 0
    data.train_neg_adj_mask = neg_adj_mask

    data.val_neg_edge_index = torch.stack([neg_row[:n_v], neg_col[:n_v]], dim=0)
    data.test_neg_edge_index = torch.stack(
        [neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t]], dim=0)

    return data

In [36]:
device=torch.device(device_id)

In [37]:
# Hold out 10% of the (directed) edges as test positives, plus sampled test
# negatives; no validation split. train_test_split_edges mutates `data` in
# place and returns it, so data_split is data, and data.edge_index still holds
# the full graph (including the held-out test edges).
data_split=train_test_split_edges(data, test_ratio=0.1, val_ratio=0)
x, train_pos_edge_index, train_pos_edge_attr = data_split.x.to(device), \
    data_split.train_pos_edge_index.to(device), data_split.train_pos_edge_attr.to(device)


In [38]:
# Add a self-loop to every node that lacks one. edge_attr is passed in the
# edge-weight slot, so the new loops receive the default fill value. The counts
# below (type 1 rising to 20755) show they are labelled 1, i.e. routed to the
# gene-gene branch of the encoder for every node type.
# NOTE(review): num_nodes is not passed, so PyG infers it from the largest
# node index present; trailing isolated nodes get no loop. Confirm intended.
train_pos_edge_index, train_pos_edge_attr=add_remaining_self_loops(train_pos_edge_index,train_pos_edge_attr)

In [39]:
pd.Series(train_pos_edge_attr.cpu().numpy()).value_counts()

0    24838
1    20755
3     1460
4     1205
2      221
dtype: int64

In [40]:
# NOTE: torch.autograd.Variable is deprecated and returns the tensor unchanged
# in PyTorch >= 0.4, so this wrapping has no effect.
x,train_pos_edge_index,train_pos_edge_attr = Variable(x),Variable(train_pos_edge_index),Variable(train_pos_edge_attr)

Define VGAE model

In [19]:
class Encoder_VGAE(nn.Module):
    """Variational graph encoder with one SAGEConv branch per edge type.

    Each branch aggregates ``x`` only over the edges whose ``edge_attr`` code
    matches its relation (0 gene-drug, 1 gene-gene, 2 bait-gene,
    3 gene-phenotype, 4 drug-phenotype). The branch outputs are concatenated,
    batch-normalised and fed to two SAGEConv heads over the full
    ``edge_index``, which produce ``mu`` and ``logvar``.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of the latent embedding; every branch emits
            ``2 * out_channels`` features.
        isClassificationTask: stored on the instance; not used by ``forward``.
    """

    def __init__(self, in_channels, out_channels, isClassificationTask=False):
        super().__init__()
        self.isClassificationTask = isClassificationTask
        branch_width = 2 * out_channels
        # Creation order is kept fixed so that parameter initialisation
        # consumes the RNG in the same sequence (and state_dict keys stay put).
        self.conv_gene_drug = SAGEConv(in_channels, branch_width)
        self.conv_gene_gene = SAGEConv(in_channels, branch_width)
        self.conv_bait_gene = SAGEConv(in_channels, branch_width)
        self.conv_gene_phenotype = SAGEConv(in_channels, branch_width)
        self.conv_drug_phenotype = SAGEConv(in_channels, branch_width)

        concat_width = 5 * branch_width
        self.bn = nn.BatchNorm1d(concat_width)
        # variational heads
        self.conv_mu = SAGEConv(concat_width, out_channels)
        self.conv_logvar = SAGEConv(concat_width, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Return ``(mu, logvar)``, each with ``out_channels`` columns per node."""
        x = F.dropout(x, training=self.training)

        # (branch, edge-type code, dropout probability) in concatenation order.
        branches = (
            (self.conv_gene_drug, 0, 0.5),
            (self.conv_gene_gene, 1, 0.5),
            (self.conv_bait_gene, 2, 0.1),
            (self.conv_gene_phenotype, 3, 0.5),
            (self.conv_drug_phenotype, 4, 0.5),
        )
        pieces = []
        for conv, type_code, drop_p in branches:
            keep = (edge_attr == type_code).nonzero().flatten()
            activated = F.relu(conv(x, edge_index[:, keep]))
            pieces.append(F.dropout(activated, p=drop_p, training=self.training))

        joined = self.bn(torch.cat(pieces, dim=1))
        mu = self.conv_mu(joined, edge_index)
        logvar = self.conv_logvar(joined, edge_index)
        return mu, logvar

In [41]:
# Section 1 model: VGAE around the relation-aware encoder, input width taken
# from the node features; Adam with its default hyper-parameters.
model = VGAE(
    Encoder_VGAE(node_feature_alone.shape[1], embedding_size)
).to(device)
optimizer = optim.Adam(model.parameters())

In [42]:
def train():
    """Run one full-batch optimisation step on the training graph and print the loss.

    Reads the notebook globals ``model``, ``optimizer``, ``x``,
    ``train_pos_edge_index``, ``train_pos_edge_attr`` and ``data``. The loss
    is the VGAE reconstruction loss on the training positives plus the KL
    term scaled by ``1 / data.num_nodes``.
    """
    model.train()
    optimizer.zero_grad()
    latent = model.encode(x, train_pos_edge_index, train_pos_edge_attr)
    kl_weight = 1 / data.num_nodes
    # recon_loss is evaluated before kl_loss (left operand first).
    total = model.recon_loss(latent, train_pos_edge_index) + kl_weight * model.kl_loss()
    total.backward()
    optimizer.step()
    print(total.item())

def test(pos_edge_index, neg_edge_index):
    """Return ``(auc, ap)`` of ``model.test`` on the given positive/negative edges.

    Embeddings are computed in eval mode from the training graph only
    (``train_pos_edge_index`` / ``train_pos_edge_attr``) with gradients off.
    """
    model.eval()
    with torch.no_grad():
        latent = model.encode(x, train_pos_edge_index, train_pos_edge_attr)
    return model.test(latent, pos_edge_index, neg_edge_index)

In [43]:
# baseline: score the held-out test edges using the raw input features `x` in
# place of a learned embedding (model.test receives x, not z), before training.
model.test(x,data_split.test_pos_edge_index, data_split.test_neg_edge_index )

(0.35942201036988447, 0.4050600513856305)

In [44]:
# 9 full-batch epochs (range(1, 10)). After each step, report (AUC, AP) on the
# held-out test edges. There is no validation split (val_ratio=0), so the test
# set doubles as the monitoring set.
for epoch in range(1, 10):
    train()
    auc, ap = test(data_split.test_pos_edge_index, data_split.test_neg_edge_index)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))


14.621352195739746
Epoch: 001, AUC: 0.5389, AP: 0.5028
14.286808967590332
Epoch: 002, AUC: 0.5734, AP: 0.5276
13.90215015411377
Epoch: 003, AUC: 0.6072, AP: 0.5535
13.445852279663086
Epoch: 004, AUC: 0.6328, AP: 0.5716
13.195359230041504
Epoch: 005, AUC: 0.6706, AP: 0.6131
12.844427108764648
Epoch: 006, AUC: 0.6965, AP: 0.6473
12.30419921875
Epoch: 007, AUC: 0.7252, AP: 0.7003
11.705829620361328
Epoch: 008, AUC: 0.7405, AP: 0.7326
11.174623489379883
Epoch: 009, AUC: 0.7537, AP: 0.7623


Node embedding

In [45]:
# Embed every node using the full graph (data.edge_index, which still includes
# the held-out test edges), in eval mode.
# NOTE(review): unlike training, no self-loops are added here, so the gene-gene
# branch sees different neighbourhoods than it did during training.
model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    z = model.encode(x, data.edge_index.to(device), data.edge_attr.to(device))
z_np = z.squeeze().detach().cpu().numpy()

Save the new embedding 

In [46]:
# Write the embedding; the context manager flushes and closes the file.
with open(data_path + 'COVID_embedding_' + exp_id + '.pkl', 'wb') as fh:
    pickle.dump(z_np, fh)

Save the torch model

In [47]:
# Persist the parameters only (state_dict), not the pickled module object.
torch.save(model.state_dict(), f'{data_path}COVID_VAE_encoders_{exp_id}.pkl')

# 2. Hybrid embedding: COVID-19 graph + DRKG pre-trained node features

## 2.1. embedding dimension: 128 (default)

In [119]:
embedding_size = 128
# Node features for the hybrid model; the context manager closes the file.
with open(data_path + 'node_feature_' + exp_id + '.pkl', 'rb') as fh:
    node_feature_np = pickle.load(fh)
node_feature = torch.tensor(node_feature_np, dtype=torch.float)

In [120]:
# Same edges and edge types as section 1 (`edge`, `edge_attr`), new node features.
data = Data(x=node_feature,
            edge_index=edge.t().contiguous(),
            edge_attr=edge_attr
           )

In [121]:
# Fresh random 10% test split for this feature set (no validation split).
data_split=train_test_split_edges(data, test_ratio=0.1, val_ratio=0)
x, train_pos_edge_index, train_pos_edge_attr = data_split.x.to(device), \
    data_split.train_pos_edge_index.to(device), data_split.train_pos_edge_attr.to(device)

In [122]:
# Self-loops take the default fill value as their edge type (label 1 per the
# section 1 value counts).
train_pos_edge_index, train_pos_edge_attr=add_remaining_self_loops(train_pos_edge_index,train_pos_edge_attr)

In [123]:
# Variable(...) is a deprecated no-op wrapper on modern PyTorch.
x,train_pos_edge_index,train_pos_edge_attr = Variable(x),Variable(train_pos_edge_index),Variable(train_pos_edge_attr)

In [124]:
model=VGAE(Encoder_VGAE(node_feature.shape[1], embedding_size)).to(device)
optimizer=torch.optim.Adam(model.parameters())

In [125]:
# Baseline for comparison: score the test edges directly with the input
# features x (DRKG pre-trained, per the section title), with no learned embedding.
model.test(x,data_split.test_pos_edge_index, data_split.test_neg_edge_index)

(0.35835666399466704, 0.473765308861948)

In [126]:
for epoch in range(1, 10):
    train()
    auc, ap = test(data_split.test_pos_edge_index, data_split.test_neg_edge_index)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

18.6003360748291
Epoch: 001, AUC: 0.6386, AP: 0.5786
17.687053680419922
Epoch: 002, AUC: 0.6800, AP: 0.5970
16.505571365356445
Epoch: 003, AUC: 0.6984, AP: 0.6072
16.146207809448242
Epoch: 004, AUC: 0.7155, AP: 0.6272
16.748153686523438
Epoch: 005, AUC: 0.7362, AP: 0.6656
16.870141983032227
Epoch: 006, AUC: 0.7481, AP: 0.6815
16.659015655517578
Epoch: 007, AUC: 0.7700, AP: 0.7054
16.335309982299805
Epoch: 008, AUC: 0.7942, AP: 0.7402
15.86346435546875
Epoch: 009, AUC: 0.7977, AP: 0.7399


In [127]:
# Embed every node using the full graph (data.edge_index, which still includes
# the held-out test edges), in eval mode.
# NOTE(review): unlike training, no self-loops are added here, so the gene-gene
# branch sees different neighbourhoods than it did during training.
model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    z = model.encode(x, data.edge_index.to(device), data.edge_attr.to(device))
z_np = z.squeeze().detach().cpu().numpy()

In [129]:
# Write the embedding; the context manager flushes and closes the file.
with open(data_path + f'hybrid_embedding_{embedding_size}_' + exp_id + '.pkl', 'wb') as fh:
    pickle.dump(z_np, fh)

In [130]:
# Persist the parameters only (state_dict), keyed by embedding size.
torch.save(model.state_dict(), f'{data_path}hybrid_VAE_encoders_{embedding_size}_{exp_id}.pkl')

## 2.2. embedding dimension: 64

In [131]:
# Reuses the split, x and training graph defined above; only the latent width changes.
embedding_size = 64
model=VGAE(Encoder_VGAE(node_feature.shape[1], embedding_size)).to(device)
optimizer=torch.optim.Adam(model.parameters())

In [132]:
# Baseline for comparison: inner-product scores of the input features x. It
# does not depend on the model, hence the same numbers as in section 2.1.
model.test(x,data_split.test_pos_edge_index, data_split.test_neg_edge_index)

(0.35835666399466704, 0.473765308861948)

In [133]:
# NOTE(review): in the recorded run, test AUC peaks at epoch 6 and then declines.
# The embedding saved below is taken after the last epoch, not the best one.
for epoch in range(1, 10):
    train()
    auc, ap = test(data_split.test_pos_edge_index, data_split.test_neg_edge_index)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

15.738143920898438
Epoch: 001, AUC: 0.6779, AP: 0.6313
14.36874008178711
Epoch: 002, AUC: 0.7193, AP: 0.7028
13.348702430725098
Epoch: 003, AUC: 0.7569, AP: 0.7800
12.40897274017334
Epoch: 004, AUC: 0.7787, AP: 0.8067
12.35556697845459
Epoch: 005, AUC: 0.7993, AP: 0.8266
11.076025009155273
Epoch: 006, AUC: 0.8133, AP: 0.8369
10.4027099609375
Epoch: 007, AUC: 0.8073, AP: 0.8198
9.729873657226562
Epoch: 008, AUC: 0.7676, AP: 0.7744
9.39797306060791
Epoch: 009, AUC: 0.7551, AP: 0.7595


In [134]:
# Embed every node using the full graph (data.edge_index, which still includes
# the held-out test edges), in eval mode.
# NOTE(review): unlike training, no self-loops are added here, so the gene-gene
# branch sees different neighbourhoods than it did during training.
model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    z = model.encode(x, data.edge_index.to(device), data.edge_attr.to(device))
z_np = z.squeeze().detach().cpu().numpy()

In [136]:
# Write the embedding; the context manager flushes and closes the file.
with open(data_path + f'hybrid_embedding_{embedding_size}_' + exp_id + '.pkl', 'wb') as fh:
    pickle.dump(z_np, fh)

In [137]:
# Persist the parameters only (state_dict), keyed by embedding size.
torch.save(model.state_dict(), f'{data_path}hybrid_VAE_encoders_{embedding_size}_{exp_id}.pkl')

## 2.3. embedding dimension: 256

In [67]:
# Reuses whatever split / x / training graph are currently defined (in file
# order, those of section 2.1); only the latent width changes.
# NOTE(review): these cells' execution counts (In [67]-[72]) are lower than
# section 2.1's and the baseline below differs, so the recorded outputs came
# from a different split than the one in 2.1.
embedding_size = 256
model=VGAE(Encoder_VGAE(node_feature.shape[1], embedding_size)).to(device)
optimizer=torch.optim.Adam(model.parameters())

In [68]:
# Baseline for comparison: score the test edges directly with the input features x.
model.test(x,data_split.test_pos_edge_index, data_split.test_neg_edge_index)

(0.3765956514686327, 0.4823402969907613)

In [69]:
for epoch in range(1, 10):
    train()
    auc, ap = test(data_split.test_pos_edge_index, data_split.test_neg_edge_index)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

21.637834548950195
Epoch: 001, AUC: 0.5987, AP: 0.5107
18.98118782043457
Epoch: 002, AUC: 0.6702, AP: 0.6053
18.94562339782715
Epoch: 003, AUC: 0.6729, AP: 0.6091
19.052785873413086
Epoch: 004, AUC: 0.6692, AP: 0.6064
19.239566802978516
Epoch: 005, AUC: 0.6709, AP: 0.6080
18.930299758911133
Epoch: 006, AUC: 0.6842, AP: 0.6191
18.367544174194336
Epoch: 007, AUC: 0.7063, AP: 0.6377
17.764646530151367
Epoch: 008, AUC: 0.7266, AP: 0.6554
17.29031753540039
Epoch: 009, AUC: 0.7442, AP: 0.6719


In [70]:
# Embed every node using the full graph (data.edge_index, which still includes
# the held-out test edges), in eval mode.
# NOTE(review): unlike training, no self-loops are added here, so the gene-gene
# branch sees different neighbourhoods than it did during training.
model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    z = model.encode(x, data.edge_index.to(device), data.edge_attr.to(device))
z_np = z.squeeze().detach().cpu().numpy()

In [71]:
# Write the embedding; the context manager flushes and closes the file.
with open(data_path + f'hybrid_embedding_{embedding_size}_' + exp_id + '.pkl', 'wb') as fh:
    pickle.dump(z_np, fh)

In [72]:
# Persist the parameters only (state_dict), keyed by embedding size.
torch.save(model.state_dict(), f'{data_path}hybrid_VAE_encoders_{embedding_size}_{exp_id}.pkl')

## 2.4. Default embedding size (128) without bait–prey (bait-gene) edges

In [74]:
# Graph and node features for the variant without bait-gene edges; each
# context manager closes its file once the object has been read.
with open(data_path + 'edge_index_no_bp_' + exp_id + '.pkl', 'rb') as fh:
    edge_index = pickle.load(fh)
with open(data_path + 'node_feature_no_bp_' + exp_id + '.pkl', 'rb') as fh:
    node_feature_no_bp_np = pickle.load(fh)

In [75]:
node_feature_no_bp = torch.tensor(node_feature_no_bp_np, dtype=torch.float)

In [78]:
# (E, 2) integer tensor of (node1, node2) pairs for the graph without bait-gene edges.
edge=torch.tensor(edge_index[['node1', 'node2']].values, dtype=torch.long)

In [79]:
# Same codes as the full graph minus 'bait-gene' (2). The gap is kept so the
# remaining codes match the filters in Encoder_VGAE_no_bp (0, 1, 3, 4).
edge_attr_dict={'gene-drug':0,'gene-gene':1,'gene-phenotype':3, 'drug-phenotype':4}
edge_index['type']=edge_index['type'].apply(lambda x: edge_attr_dict[x])

In [81]:
edge_index['type'].value_counts()

0    27590
1     5971
3     1618
4     1338
Name: type, dtype: int64

In [94]:
edge_attr=torch.tensor(edge_index['type'].values,dtype=torch.long)

In [96]:
data = Data(x=node_feature_no_bp,
            edge_index=edge.t().contiguous(),
            edge_attr=edge_attr
           )

In [97]:
# 36517 = 36764 - 247, i.e. exactly the bait-gene edges were removed.
data.num_features, data.num_nodes,data.num_edges

(400, 15444, 36517)

In [98]:
# 10% test split, no validation split.
data_split=train_test_split_edges(data, test_ratio=0.1, val_ratio=0)
x, train_pos_edge_index, train_pos_edge_attr = data_split.x.to(device), \
    data_split.train_pos_edge_index.to(device), data_split.train_pos_edge_attr.to(device)


In [99]:
# Self-loops take the default fill value as their edge type (label 1 per the
# section 1 value counts).
train_pos_edge_index, train_pos_edge_attr=add_remaining_self_loops(train_pos_edge_index,train_pos_edge_attr)

In [100]:
# Variable(...) is a deprecated no-op wrapper on modern PyTorch.
x,train_pos_edge_index,train_pos_edge_attr = Variable(x),Variable(train_pos_edge_index),Variable(train_pos_edge_attr)

In [101]:
class Encoder_VGAE_no_bp(nn.Module):
    """Relation-aware VGAE encoder without the bait-gene branch.

    Four SAGEConv branches aggregate ``x`` over the edges whose ``edge_attr``
    code is 0 (gene-drug), 1 (gene-gene), 3 (gene-phenotype) or
    4 (drug-phenotype). Edges with any other code feed no branch but are
    still used by the two heads. The branch outputs are concatenated,
    batch-normalised and passed to SAGEConv heads over the full
    ``edge_index`` that return ``mu`` and ``logvar``.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of the latent embedding; every branch emits
            ``2 * out_channels`` features.
        isClassificationTask: stored on the instance; not used by ``forward``.
    """

    def __init__(self, in_channels, out_channels, isClassificationTask=False):
        super().__init__()
        self.isClassificationTask = isClassificationTask
        branch_width = 2 * out_channels
        # Creation order is kept fixed so that parameter initialisation
        # consumes the RNG in the same sequence (and state_dict keys stay put).
        self.conv_gene_drug = SAGEConv(in_channels, branch_width)
        self.conv_gene_gene = SAGEConv(in_channels, branch_width)
        self.conv_gene_phenotype = SAGEConv(in_channels, branch_width)
        self.conv_drug_phenotype = SAGEConv(in_channels, branch_width)

        concat_width = 4 * branch_width
        self.bn = nn.BatchNorm1d(concat_width)
        # variational heads
        self.conv_mu = SAGEConv(concat_width, out_channels)
        self.conv_logvar = SAGEConv(concat_width, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Return ``(mu, logvar)``, each with ``out_channels`` columns per node."""
        x = F.dropout(x, training=self.training)

        # (branch, edge-type code, dropout probability) in concatenation order.
        branches = (
            (self.conv_gene_drug, 0, 0.5),
            (self.conv_gene_gene, 1, 0.5),
            (self.conv_gene_phenotype, 3, 0.5),
            (self.conv_drug_phenotype, 4, 0.5),
        )
        pieces = []
        for conv, type_code, drop_p in branches:
            keep = (edge_attr == type_code).nonzero().flatten()
            activated = F.relu(conv(x, edge_index[:, keep]))
            pieces.append(F.dropout(activated, p=drop_p, training=self.training))

        joined = self.bn(torch.cat(pieces, dim=1))
        mu = self.conv_mu(joined, edge_index)
        logvar = self.conv_logvar(joined, edge_index)
        return mu, logvar

In [102]:
embedding_size = 128
# The encoder's input width comes from node_feature_no_bp, the matrix used as
# x in this section.
model = VGAE(Encoder_VGAE_no_bp(node_feature_no_bp.shape[1], embedding_size)).to(device)
optimizer = torch.optim.Adam(model.parameters())

In [103]:
# Baseline for comparison: score the test edges directly with the input
# features x (no learned embedding), before training.
model.test(x,data_split.test_pos_edge_index, data_split.test_neg_edge_index)

(0.40109878609590643, 0.4553820212669144)

In [104]:
# 9 full-batch epochs; the held-out test edges are used for monitoring.
for epoch in range(1, 10):
    train()
    auc, ap = test(data_split.test_pos_edge_index, data_split.test_neg_edge_index)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

19.09906005859375
Epoch: 001, AUC: 0.4841, AP: 0.4471
18.170408248901367
Epoch: 002, AUC: 0.5105, AP: 0.4581
18.1473331451416
Epoch: 003, AUC: 0.5472, AP: 0.4769
17.576248168945312
Epoch: 004, AUC: 0.5846, AP: 0.4984
17.136661529541016
Epoch: 005, AUC: 0.6121, AP: 0.5166
15.71886157989502
Epoch: 006, AUC: 0.6325, AP: 0.5338
15.182626724243164
Epoch: 007, AUC: 0.6502, AP: 0.5560
14.717133522033691
Epoch: 008, AUC: 0.6614, AP: 0.5829
14.302988052368164
Epoch: 009, AUC: 0.6656, AP: 0.6041


In [105]:
# Embed every node using the full no-bait-gene graph (data.edge_index, which
# still includes the held-out test edges), in eval mode.
# NOTE(review): unlike training, no self-loops are added here, so the gene-gene
# branch sees different neighbourhoods than it did during training.
model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    z = model.encode(x, data.edge_index.to(device), data.edge_attr.to(device))
z_np = z.squeeze().detach().cpu().numpy()

In [106]:
# Write the embedding; the context manager flushes and closes the file.
with open(data_path + f'hybrid_embedding_{embedding_size}_no_bp_' + exp_id + '.pkl', 'wb') as fh:
    pickle.dump(z_np, fh)

In [107]:
# Persist the parameters only (state_dict), keyed by embedding size.
torch.save(model.state_dict(), f'{data_path}hybrid_VAE_encoders_{embedding_size}_no_bp_{exp_id}.pkl')