# COVID-19 graph embedding alone

Load packages

In [1]:
import numpy as np
import pandas as pd
import time
import re
import math
import random
import pickle

from sklearn.model_selection import train_test_split
from sklearn import metrics 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.modules import Module
from torch.utils.data import Dataset, DataLoader

from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import train_test_split_edges
from torch_geometric.utils import add_remaining_self_loops, add_self_loops
from torch_geometric.utils import to_undirected
from torch_geometric.nn import GCNConv, SAGEConv,GAE, VGAE

In [2]:
# Directory the preprocessed pickles are read from; outputs are written here too.
data_path='data/'
# Experiment tag embedded in every input/output file name.
exp_id='v0'
device_id='cpu' #'cpu' if CPU, device number if GPU
# Size of the learned node embedding (passed as out_channels to the encoder).
embedding_size=128

Load preprocessed files

In [3]:
# Load the preprocessed artifacts. Each file is opened in a `with` block so the
# handle is closed deterministically instead of being left to the garbage
# collector.
# NOTE: pickle can execute arbitrary code on load; only read trusted files.
with open(data_path+'LabelEncoder_'+exp_id+'.pkl', 'rb') as fh:
    le=pickle.load(fh)
with open(data_path+'edge_index_'+exp_id+'.pkl','rb') as fh:
    edge_index=pickle.load(fh)
with open(data_path+'empty_node_feature_'+exp_id+'.pkl','rb') as fh:
    node_feature_alone_np=pickle.load(fh)

In [4]:
# Convert the loaded node-feature array to a float tensor for use as Data.x.
node_feature_alone=torch.tensor(node_feature_alone_np, dtype=torch.float)

In [5]:
# Endpoint columns of the edge table as a long tensor, one row per edge
# (presumably shape (num_edges, 2); it is transposed when building Data).
edge=torch.tensor(edge_index[['node1', 'node2']].values, dtype=torch.long)

In [6]:
# Integer code for each edge type; the encoder selects edges by these codes.
edge_attr_dict={'gene-drug':0,'gene-gene':1,'bait-gene':2, 'gene-phenotype':3, 'drug-phenotype':4}
# Replaces the string labels in place. Re-running this cell raises KeyError,
# because the column then holds integers that are not keys of the dict.
edge_index['type']=edge_index['type'].apply(lambda x: edge_attr_dict[x])

In [7]:
# Inspect how many edges there are of each encoded type.
edge_index['type'].value_counts()

0    27590
1     5971
3     1618
4     1338
2      247
Name: type, dtype: int64

In [8]:
# Encoded edge types as a long tensor, in the same order as the rows of `edge`.
edge_attr=torch.tensor(edge_index['type'].values,dtype=torch.long)

In [9]:
# Assemble the graph: node features, edge endpoints transposed to the
# (2, num_edges) layout, and the integer edge type of every edge.
data = Data(x=node_feature_alone,
            edge_index=edge.t().contiguous(),
            edge_attr=edge_attr
           )

In [10]:
# Sanity check: feature dimension, node count and edge count of the graph.
data.num_features, data.num_nodes,data.num_edges

(400, 15444, 36764)

In [12]:
# Peek at the first five feature values of node 0.
data.x[0,:5]

tensor([0., 0., 0., 0., 0.])

In [13]:
# One edge-type entry per edge is expected here.
edge_attr.size()

torch.Size([36764])

In [14]:
# Report whether the graph has isolated nodes and whether it is directed.
# NOTE(review): newer torch_geometric releases appear to rename
# contains_isolated_nodes() to has_isolated_nodes() — confirm against the
# installed version.
data.contains_isolated_nodes(), data.is_directed()



(True, True)

## Batch

In [15]:
def train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1):
    r"""Splits the edges of a :obj:`torch_geometric.data.Data` object
    into positive and negative train/val/test edges while keeping the edge
    type (:obj:`edge_attr`) of every positive edge.

    Unlike the upper-triangular undirected split, every stored (directed)
    edge is treated as its own positive sample. The following attributes are
    added to :attr:`data`: `train_pos_edge_index`, `train_pos_edge_attr`,
    `val_pos_edge_index`, `val_pos_edge_attr`, `test_pos_edge_index`,
    `test_pos_edge_attr`, `train_neg_adj_mask`, `val_neg_edge_index` and
    `test_neg_edge_index`. `test_post_edge_attr` is kept as an alias of
    `test_pos_edge_attr` for code written against the old misspelled name.

    Args:
        data (Data): The data object; must provide :obj:`num_nodes`,
            :obj:`edge_index` and :obj:`edge_attr`. It is modified in place.
        val_ratio (float, optional): The ratio of positive validation
            edges. (default: :obj:`0.05`)
        test_ratio (float, optional): The ratio of positive test
            edges. (default: :obj:`0.1`)

    :rtype: :class:`torch_geometric.data.Data`
    """

    assert 'batch' not in data  # No batch-mode.

    num_nodes = data.num_nodes
    row, col = data.edge_index
    attr = data.edge_attr

    num_edges = row.size(0)
    n_v = int(math.floor(val_ratio * num_edges))
    n_t = int(math.floor(test_ratio * num_edges))

    # Positive edges: shuffle endpoints and edge types with one permutation
    # so that each edge keeps its own type.
    perm = torch.randperm(num_edges)
    row, col = row[perm], col[perm]
    attr = attr[perm]

    r, c = row[:n_v], col[:n_v]
    data.val_pos_edge_index = torch.stack([r, c], dim=0)
    data.val_pos_edge_attr = attr[:n_v]

    r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t]
    data.test_pos_edge_index = torch.stack([r, c], dim=0)
    data.test_pos_edge_attr = attr[n_v:n_v + n_t]
    # Backward-compatible alias for the original (misspelled) attribute name.
    data.test_post_edge_attr = data.test_pos_edge_attr

    r, c = row[n_v + n_t:], col[n_v + n_t:]
    data.train_pos_edge_index = torch.stack([r, c], dim=0)
    data.train_pos_edge_attr = attr[n_v + n_t:]

    # Negative edges: candidates are node pairs (i, j) with i < j that are
    # not connected in either direction.
    neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8)
    neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool)
    # Positives are stored in arbitrary orientation, so clear both (i, j) and
    # (j, i). Clearing only [row, col] would leave the reverse of an edge
    # stored with row > col eligible as a "negative", which a symmetric
    # decoder cannot distinguish from the positive itself.
    neg_adj_mask[row, col] = False
    neg_adj_mask[col, row] = False

    neg_row, neg_col = neg_adj_mask.nonzero().t()
    perm = random.sample(range(neg_row.size(0)),
                         min(n_v + n_t, neg_row.size(0)))
    perm = torch.tensor(perm)
    perm = perm.to(torch.long)  # an empty sample would otherwise be float
    neg_row, neg_col = neg_row[perm], neg_col[perm]

    # Pairs reserved for val/test must not be available as training negatives.
    neg_adj_mask[neg_row, neg_col] = False
    data.train_neg_adj_mask = neg_adj_mask

    row, col = neg_row[:n_v], neg_col[:n_v]
    data.val_neg_edge_index = torch.stack([row, col], dim=0)

    row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t]
    data.test_neg_edge_index = torch.stack([row, col], dim=0)

    return data

In [16]:
# Resolve the configured device id into a torch.device used for all tensors.
device=torch.device(device_id)

In [17]:
# Hold out 10% of the edges for testing; val_ratio=0 means no validation set.
# The split function adds attributes to `data` and returns that same object.
data_split=train_test_split_edges(data, test_ratio=0.1, val_ratio=0)
# Move node features and the training edges (with their types) to the device.
x, train_pos_edge_index, train_pos_edge_attr = data_split.x.to(device), \
    data_split.train_pos_edge_index.to(device), data_split.train_pos_edge_attr.to(device)


In [18]:
# Add a self-loop for every node that does not already have one.
# NOTE(review): the edge types are passed in the edge-weight slot, so the new
# self-loops appear to receive the default fill value 1, i.e. they are typed
# as 'gene-gene' (the type-1 count rises from 5971 in the full graph to 20747
# after this step). Confirm this is intended.
train_pos_edge_index, train_pos_edge_attr=add_remaining_self_loops(train_pos_edge_index,train_pos_edge_attr)

In [19]:
# Edge-type distribution of the training edges after self-loops were added.
pd.Series(train_pos_edge_attr.cpu().numpy()).value_counts()

0    24835
1    20747
3     1459
4     1213
2      226
dtype: int64

In [20]:
# Wrap the training tensors in autograd Variables.
# NOTE(review): torch.autograd.Variable is deprecated and returns plain
# tensors in current PyTorch, so this is presumably a no-op — confirm.
x,train_pos_edge_index,train_pos_edge_attr = Variable(x),Variable(train_pos_edge_index),Variable(train_pos_edge_attr)

## Learning VGAE model

Define VGAE model

In [21]:
class Encoder_VGAE(nn.Module):
    """Variational graph encoder with one SAGEConv branch per edge type.

    Each of the five branches sees only the edges whose ``edge_attr`` equals
    its type id (0: gene-drug, 1: gene-gene, 2: bait-gene, 3: gene-phenotype,
    4: drug-phenotype). The branch outputs are concatenated, batch-normalised
    and fed to two SAGEConv heads that run over all edges and produce the
    mean and log-variance of the latent embedding.

    Args:
        in_channels: Size of the input node features.
        out_channels: Size of the latent embedding; every branch outputs
            ``2 * out_channels`` features.
        isClassificationTask: Stored on the module; not read by ``forward``.
    """

    def __init__(self, in_channels, out_channels, isClassificationTask=False):
        super().__init__()
        self.isClassificationTask = isClassificationTask

        branch_width = 2 * out_channels
        merged_width = 5 * branch_width

        # One convolution per edge type.
        self.conv_gene_drug = SAGEConv(in_channels, branch_width)
        self.conv_gene_gene = SAGEConv(in_channels, branch_width)
        self.conv_bait_gene = SAGEConv(in_channels, branch_width)
        self.conv_gene_phenotype = SAGEConv(in_channels, branch_width)
        self.conv_drug_phenotype = SAGEConv(in_channels, branch_width)

        self.bn = nn.BatchNorm1d(merged_width)

        # Variational heads.
        self.conv_mu = SAGEConv(merged_width, out_channels)
        self.conv_logvar = SAGEConv(merged_width, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Return ``(mu, logvar)`` for the given graph.

        Args:
            x: Node feature matrix.
            edge_index: Edge endpoints, indexed as ``edge_index[:, k]``.
            edge_attr: Integer edge type of each edge, aligned with
                the columns of ``edge_index``.
        """
        x = F.dropout(x, training=self.training)

        # (convolution, edge-type id, dropout probability) for each branch.
        branches = (
            (self.conv_gene_drug, 0, 0.5),
            (self.conv_gene_gene, 1, 0.5),
            (self.conv_bait_gene, 2, 0.1),
            (self.conv_gene_phenotype, 3, 0.5),
            (self.conv_drug_phenotype, 4, 0.5),
        )

        branch_outputs = []
        for conv, type_id, drop_p in branches:
            # Keep only the edges of this branch's type.
            typed_edges = edge_index[:, edge_attr == type_id]
            activated = F.relu(conv(x, typed_edges))
            branch_outputs.append(
                F.dropout(activated, p=drop_p, training=self.training)
            )

        merged = self.bn(torch.cat(branch_outputs, dim=1))

        mu = self.conv_mu(merged, edge_index)
        logvar = self.conv_logvar(merged, edge_index)
        return mu, logvar

In [22]:
# Build the VGAE around the multi-branch encoder; input width is the node
# feature dimension, latent width is `embedding_size`.
model=VGAE(Encoder_VGAE(node_feature_alone.shape[1], embedding_size)).to(device)
# Adam with its default hyper-parameters.
optimizer=torch.optim.Adam(model.parameters())

In [23]:
def train():
    """Run one full-graph optimisation step and print the resulting loss.

    Uses the notebook-level ``model``, ``optimizer``, ``x``,
    ``train_pos_edge_index``, ``train_pos_edge_attr`` and ``data``. The loss
    is the reconstruction loss on the training edges plus the KL term scaled
    by ``1 / data.num_nodes``.
    """
    model.train()
    optimizer.zero_grad()

    latent = model.encode(x, train_pos_edge_index, train_pos_edge_attr)
    recon_term = model.recon_loss(latent, train_pos_edge_index)
    kl_weight = 1 / data.num_nodes
    total_loss = recon_term + kl_weight * model.kl_loss()

    total_loss.backward()
    optimizer.step()
    print(total_loss.item())
    
def test(pos_edge_index, neg_edge_index):
    """Score held-out edges with the current model.

    Encodes the nodes from the training graph (notebook-level ``x``,
    ``train_pos_edge_index``, ``train_pos_edge_attr``) in eval mode without
    gradients, then returns whatever ``model.test`` yields for the given
    positive and negative edges (printed as AUC and AP by the caller).
    """
    model.eval()
    with torch.no_grad():
        latent = model.encode(x, train_pos_edge_index, train_pos_edge_attr)
    scores = model.test(latent, pos_edge_index, neg_edge_index)
    return scores

In [24]:
# baseline
# Scores the test edges using the raw node features `x` directly as the
# latent vectors, i.e. without running the encoder.
model.test(x,data_split.test_pos_edge_index, data_split.test_neg_edge_index )

(0.5, 0.5)

In [25]:
# range(1, 10) runs 9 epochs. Each epoch takes one optimisation step and then
# reports metrics on the held-out test edges.
# NOTE(review): there is no validation split (val_ratio=0), so the test set is
# the only signal monitored during training — confirm this is acceptable.
for epoch in range(1, 10):
    train()
    auc, ap = test(data_split.test_pos_edge_index, data_split.test_neg_edge_index)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))


14.093152046203613
Epoch: 001, AUC: 0.5631, AP: 0.5631
13.795549392700195
Epoch: 002, AUC: 0.5000, AP: 0.5000
13.72352123260498
Epoch: 003, AUC: 0.6537, AP: 0.6536
13.340164184570312
Epoch: 004, AUC: 0.4599, AP: 0.4808
13.256511688232422
Epoch: 005, AUC: 0.5000, AP: 0.5000
12.484635353088379
Epoch: 006, AUC: 0.6076, AP: 0.6011
12.39892578125
Epoch: 007, AUC: 0.5000, AP: 0.5000
11.845335006713867
Epoch: 008, AUC: 0.5000, AP: 0.5000
11.118988037109375
Epoch: 009, AUC: 0.4850, AP: 0.4926


Node embedding

In [26]:
# Final node embedding: encode with the full edge set (train and test edges)
# in eval mode. Gradients are not needed here, so disable autograd as test()
# does instead of building a graph and detaching afterwards.
model.eval()
with torch.no_grad():
    z=model.encode(x, data.edge_index.to(device), data.edge_attr.to(device))
z_np = z.squeeze().detach().cpu().numpy()

Save the new embedding 

In [29]:
# Write the embedding inside a `with` block so the file is flushed and closed
# before the notebook continues.
with open(data_path+'COVID_embedding_'+exp_id+'.pkl', 'wb') as fh:
    pickle.dump(z_np, fh)

Save the trained model weights (`state_dict`)

In [30]:
# Persist only the model weights (state_dict). Despite the .pkl suffix the
# file is written by torch.save, not by pickle.dump.
torch.save(model.state_dict(), data_path+'COVID_VAE_encoders_'+exp_id+'.pkl')

# Hybrid Embedding: COVID-19 graph + DRKG pre-train

In [31]:
# Load the node features for the hybrid run, closing the file deterministically.
# NOTE: pickle can execute arbitrary code on load; only read trusted files.
with open(data_path+'node_feature_'+exp_id+'.pkl','rb') as fh:
    node_feature_np=pickle.load(fh)
node_feature=torch.tensor(node_feature_np, dtype=torch.float)

In [32]:
# Rebuild the graph with the loaded node features; the edges and edge types
# are the same tensors used for the first graph. Rebinds the global `data`.
data = Data(x=node_feature,
            edge_index=edge.t().contiguous(),
            edge_attr=edge_attr
           )

In [33]:
# Draw a fresh random 90/10 train/test edge split (no validation set) for the
# hybrid graph; this split is independent of the one used in the first run.
data_split=train_test_split_edges(data, test_ratio=0.1, val_ratio=0)
# Move node features and the training edges (with their types) to the device.
x, train_pos_edge_index, train_pos_edge_attr = data_split.x.to(device), \
    data_split.train_pos_edge_index.to(device), data_split.train_pos_edge_attr.to(device)

In [34]:
# Add a self-loop for every node that does not already have one.
# NOTE(review): the edge types are passed in the edge-weight slot, so the new
# self-loops presumably receive the default fill value 1 and are therefore
# treated as 'gene-gene' edges by the encoder — confirm this is intended.
train_pos_edge_index, train_pos_edge_attr=add_remaining_self_loops(train_pos_edge_index,train_pos_edge_attr)

In [35]:
# Wrap the training tensors in autograd Variables.
# NOTE(review): torch.autograd.Variable is deprecated and returns plain
# tensors in current PyTorch, so this is presumably a no-op — confirm.
x,train_pos_edge_index,train_pos_edge_attr = Variable(x),Variable(train_pos_edge_index),Variable(train_pos_edge_attr)

In [36]:
# DRKG's accuracy for comparison: the node features `x` are passed directly as
# the latent vectors, so the encoder is not involved in this score.
model.test(x,data_split.test_pos_edge_index, data_split.test_neg_edge_index)

(0.3787092865405814, 0.4853520899962566)

In [37]:
# range(1, 10) runs 9 epochs, each one optimisation step followed by metrics
# on the held-out test edges.
# NOTE(review): `model` and `optimizer` are not re-created before this loop,
# so training continues from the weights and optimizer state left by the
# earlier run rather than from a fresh initialisation — confirm this is
# intended.
for epoch in range(1, 10):
    train()
    auc, ap = test(data_split.test_pos_edge_index, data_split.test_neg_edge_index)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

17.10679817199707
Epoch: 001, AUC: 0.7388, AP: 0.7223
15.070136070251465
Epoch: 002, AUC: 0.8035, AP: 0.8242
14.061445236206055
Epoch: 003, AUC: 0.8060, AP: 0.8264
13.770694732666016
Epoch: 004, AUC: 0.7829, AP: 0.8067
13.596809387207031
Epoch: 005, AUC: 0.7773, AP: 0.7994
12.387971878051758
Epoch: 006, AUC: 0.7812, AP: 0.7783
11.4154052734375
Epoch: 007, AUC: 0.7551, AP: 0.7263
10.28616714477539
Epoch: 008, AUC: 0.7076, AP: 0.6791
9.336152076721191
Epoch: 009, AUC: 0.6987, AP: 0.6692


In [38]:
# Final hybrid embedding: encode with the full edge set (train and test edges)
# in eval mode. Gradients are not needed here, so disable autograd as test()
# does instead of building a graph and detaching afterwards.
model.eval()
with torch.no_grad():
    z=model.encode(x, data.edge_index.to(device), data.edge_attr.to(device))
z_np = z.squeeze().detach().cpu().numpy()

In [39]:
# Write the embedding inside a `with` block so the file is flushed and closed
# before the notebook continues.
with open(data_path+'hybrid_embedding_'+exp_id+'.pkl', 'wb') as fh:
    pickle.dump(z_np, fh)

In [40]:
# Persist only the model weights (state_dict). Despite the .pkl suffix the
# file is written by torch.save, not by pickle.dump.
torch.save(model.state_dict(), data_path+'hybrid_VAE_encoders_'+exp_id+'.pkl')