In [1]:
import numpy as np
import torch as th
import torch.nn as nn
import argparse

In [2]:
def parse_args():
    parser = argparse.ArgumentParser(description='GRDTI')

    parser.add_argument("--epochs", type=int, default=4000,
                        help="number of training epochs")
    parser.add_argument("--rounds", type=int, default=1,
                        help="number of training rounds")
    parser.add_argument("--device", default='cuda',
                        help="cuda or cpu")
    parser.add_argument("--dim-embedding", type=int, default=1000,
                        help="dimension of embeddings")
    parser.add_argument("--k", type=int, default=3,
                        help="Number of iterations in propagation")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate")
    parser.add_argument('--weight-decay', type=float, default=0,
                        help="weight decay")
    parser.add_argument('--reg_lambda', type=float, default=1,
                        help="reg_lambda")
    parser.add_argument('--patience', type=int, default=6,
                        help='Early stopping patience.')
    parser.add_argument("--alpha", type=float, default=0.9,
                        help="Restart Probability")
    parser.add_argument("--edge-drop", type=float, default=0.5,
                        help="edge dropout in propagation")
    parser.add_argument('-f'),

    return parser.parse_args()


def row_normalize(t):
    t = t.float()
    row_sums = t.sum(1) + 1e-12
    output = t / row_sums[:, None]
    output[th.isnan(output) | th.isinf(output)] = 0.0
    return output


def col_normalize(a_matrix, substract_self_loop):
    if substract_self_loop:
        np.fill_diagonal(a_matrix, 0)
    a_matrix = a_matrix.astype(float)
    col_sums = a_matrix.sum(axis=0) + 1e-12
    new_matrix = a_matrix / col_sums[np.newaxis, :]
    new_matrix[np.isnan(new_matrix) | np.isinf(new_matrix)] = 0.0
    return new_matrix


def l2_norm(t, axit=1):
    t = t.float()
    norm = th.norm(t, 2, axit, True) + 1e-12
    output = th.div(t, norm)
    output[th.isnan(output) | th.isinf(output)] = 0.0
    return output

In [3]:
import torch as th
from torch import nn
from dgl import function as fn


class Propagation(nn.Module):
    def __init__(self, k, alpha, edge_drop=0.):
        super(Propagation, self).__init__()
        self._k = k
        self._alpha = alpha
        self.edge_drop = nn.Dropout(edge_drop)

    def forward(self, graph, feat):
        graph = graph.local_var().to('cuda')
        norm = th.pow(graph.in_degrees().float().clamp(min=1e-12), -0.5)
        shp = norm.shape + (1,) * (feat.dim() - 1)
        norm = th.reshape(norm, shp).to(feat.device)
        feat_0 = feat
        for _ in range(self._k):
            feat = feat * norm
            graph.ndata['h'] = feat
            #graph.edata['w'] = th.ones(graph.number_of_edges(), 1).to(feat.device)
            graph.edata['w'] = self.edge_drop(th.ones(graph.number_of_edges(), 1).to(feat.device))
            graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
            feat = graph.ndata.pop('h')
            feat = feat * norm
            feat = (1 - self._alpha) * feat + self._alpha * feat_0

        return feat

In [4]:
#MetaPath2Vec

import random
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
import tqdm

from dgl.base import NID
from dgl.convert import to_homogeneous, to_heterogeneous
from dgl.random import choice
from dgl.sampling import random_walk


class MetaPath2Vec(nn.Module):
    
    def __init__(self,
                 g,
                 metapath,
                 window_size,
                 emb_dim=128,
                 negative_size=5,
                 sparse=True):
        super().__init__()

        assert len(metapath) + 1 >= window_size, \
            f'Expect len(metapath) >= window_size - 1, got {metapath} and {window_size}'

        self.hg = g
        self.emb_dim = emb_dim
        self.metapath = metapath
        self.window_size = window_size
        self.negative_size = negative_size
        self.device = th.device(args.device)
        

        # convert edge metapath to node metapath
        # get initial source node type
        src_type, _, _ = g.to_canonical_etype(metapath[0])
        node_metapath = [src_type]
        for etype in metapath:
            _, _, dst_type = g.to_canonical_etype(etype)
            node_metapath.append(dst_type)
        self.node_metapath = node_metapath

        # Convert the graph into a homogeneous one for global to local node ID mapping
        g = to_homogeneous(g)
        #g.to('cuda')
        # Convert it back to the hetero one for local to global node ID mapping
        hg = to_heterogeneous(g, self.hg.ntypes, self.hg.etypes)
        #hg.to('cuda')
        local_to_global_nid = hg.ndata[NID]
        for key, val in local_to_global_nid.items():
            #local_to_global_nid[key] = list(val.cpu().numpy())
            local_to_global_nid[key] = list(val.to(torch.device("cuda")))
        self.local_to_global_nid = local_to_global_nid

        num_nodes_total = hg.num_nodes()
        node_frequency = torch.zeros(num_nodes_total)
        # random walk
        for idx in tqdm.trange(hg.num_nodes(node_metapath[0])):
            traces, _ = random_walk(g=hg, nodes=[idx], metapath=metapath)
            #for tr in traces.cpu().numpy():
            for tr in traces.to(torch.device("cuda")):
                tr_nids = [
                    self.local_to_global_nid[node_metapath[i]][tr[i]] for i in range(len(tr))]
                
                #node_frequency.append((torch.ones(size=(len(traces),)) * i).to('cuda'))
                node_frequency[torch.LongTensor(tr_nids)] += 1
                

        neg_prob = node_frequency.pow(0.75)
        self.neg_prob = neg_prob / neg_prob.sum()

        # center node embedding
        self.node_embed = nn.Embedding(num_nodes_total, self.emb_dim, sparse=sparse)
        #self.node_embed.to(device)
        self.context_embed = nn.Embedding(num_nodes_total, self.emb_dim, sparse=sparse)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters"""
        init_range = 1.0 / self.emb_dim
        init.uniform_(self.node_embed.weight.data, -init_range, init_range)
        init.constant_(self.context_embed.weight.data, 0)


    def sample(self, indices):
        device = th.device(args.device)
        
        traces, _ = random_walk(g=self.hg, nodes=indices, metapath=self.metapath)
        u_list = []
        v_list = []
        #for tr in traces.cpu().numpy():
        for tr in traces.to(torch.device("cuda")):
            tr_nids = [
                self.local_to_global_nid[self.node_metapath[i]][tr[i]] for i in range(len(tr))]
            for i, u in enumerate(tr_nids):
                for j, v in enumerate(tr_nids[max(i - self.window_size, 0):i + self.window_size]):
                    if i == j:
                        continue
                    u_list.append(u)
                    v_list.append(v)

        neg_v = choice(self.hg.num_nodes(), size=len(u_list) * self.negative_size,
                       prob=self.neg_prob).reshape(len(u_list), self.negative_size)

        return torch.cuda.LongTensor(u_list), torch.cuda.LongTensor(v_list), neg_v.to(device='cuda')
        #return torch.LongTensor(u_list).to('cuda'), torch.LongTensor(v_list).to('cuda'), neg_v.to('cuda')

    def forward(self, pos_u, pos_v, neg_v):

        emb_u = self.node_embed(pos_u)
        #emb_u.to('cuda')
        emb_v = self.context_embed(pos_v)
        #emb_v.to('cuda')
        emb_neg_v = self.context_embed(neg_v)
        #emb_neg_v.to('cuda')

        score = torch.sum(torch.mul(emb_u, emb_v), dim=1)
        score = torch.clamp(score, max=10, min=-10)
        score = -F.logsigmoid(score)

        neg_score = torch.bmm(emb_neg_v, emb_u.unsqueeze(2)).squeeze()
        neg_score = torch.clamp(neg_score, max=10, min=-10)
        neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1)

        return torch.mean(score + neg_score)

In [5]:
import dgl
import torch.nn.functional as F
import torch as th
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import SparseAdam



class GRDTI(nn.Module):
    def __init__(self, g, n_disease, n_drug, n_protein, n_sideeffect, args):
        super(GRDTI, self).__init__()
        self.g = g
        self.device = th.device(args.device)
        self.dim_embedding = args.dim_embedding

        self.activation = F.elu
        self.reg_lambda = args.reg_lambda

        self.num_disease = n_disease
        self.num_drug = n_drug
        self.num_protein = n_protein
        self.num_sideeffect = n_sideeffect
        
        '''
        Calculating Meta-Path2Vec for each node
        
        Constructing drug_feat_embedding
        Constructing protein_feat_embedding
        Constructing disease_feat_embedding
        Constructing sideeffect_feat_embedding
        '''
        
        '''metapath : Len(2) - Drug - MP2''' 
        #for metapath D-P-D
        '''model_DPD = MetaPath2Vec(g, ['drug_protein interaction', 'protein_drug interaction'], window_size=1,emb_dim=1000)
        model_DPD.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('drug')), batch_size=16, shuffle=True, collate_fn=model_DPD.sample)
        optimizer = SparseAdam(model_DPD.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DPD(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Drug_P_nids = torch.cuda.LongTensor(model_DPD.local_to_global_nid['drug'])
        Drug_P_emb = model_DPD.node_embed(Drug_P_nids)
        Drug_P_emb = th.tensor(Drug_P_emb).to(self.device)
        #print(Drug_P_emb)'''

        #for metapath D-Di-D
        model_DDiD = MetaPath2Vec(g, ['drug_disease association', 'disease_drug association'], window_size=1,emb_dim=1000)
        model_DDiD.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('drug')), batch_size=16, shuffle=True, collate_fn=model_DDiD.sample)
        optimizer = SparseAdam(model_DDiD.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DDiD(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Drug_Di_nids = torch.cuda.LongTensor(model_DDiD.local_to_global_nid['drug'])
        Drug_Di_emb = model_DDiD.node_embed(Drug_Di_nids)
        Drug_Di_emb = th.tensor(Drug_Di_emb).to(self.device)
        #print(Drug_Di_emb)

        #for metapath D-Si-D
        '''model_DSiD = MetaPath2Vec(g, ['drug_sideeffect association', 'sideeffect_drug association'], window_size=1,emb_dim=1000)
        model_DSiD.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('drug')), batch_size=16, shuffle=True, collate_fn=model_DSiD.sample)
        optimizer = SparseAdam(model_DSiD.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DSiD(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Drug_Si_nids = torch.cuda.LongTensor(model_DSiD.local_to_global_nid['drug'])
        Drug_Si_emb = model_DSiD.node_embed(Drug_Si_nids)
        Drug_Si_emb = th.tensor(Drug_Si_emb).to(self.device)'''
        
        '''metapath : Len(3) - Drug'''
        #for metapath D-P-Di-D
        model_DPDiD = MetaPath2Vec(g, ['drug_protein interaction', 'protein_disease association',
                                       'disease_drug association'],
                                 window_size=1,emb_dim=1000)
        model_DPDiD.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('drug')), batch_size=16, shuffle=True,
                                collate_fn=model_DPDiD.sample)
        optimizer = SparseAdam(model_DPDiD.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DPDiD(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Drug_PDi_nids = torch.cuda.LongTensor(model_DPDiD.local_to_global_nid['drug'])
        Drug_PDi_emb = model_DPDiD.node_embed(Drug_PDi_nids)
        Drug_PDi_emb = th.tensor(Drug_PDi_emb).to(self.device)
        
        '''#for metapath D-Di-P-D
        model_DDiPD = MetaPath2Vec(g, ['drug_disease association', 'disease_protein association',
                                       'protein_drug interaction'],
                                 window_size=1,emb_dim=1000)
        model_DDiPD.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('drug')), batch_size=16, shuffle=True,
                                collate_fn=model_DDiPD.sample)
        optimizer = SparseAdam(model_DDiPD.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DDiPD(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Drug_DiP_nids = torch.cuda.LongTensor(model_DDiPD.local_to_global_nid['drug'])
        Drug_DiP_emb = model_DDiPD.node_embed(Drug_DiP_nids)
        Drug_DiP_emb = th.tensor(Drug_DiP_emb).to(self.device)'''
        
        '''metapath : Len(4) - Drug'''
        '''#for metapath D-P-Di-P-D
        model_DD4_1 = MetaPath2Vec(g, ['drug_protein interaction','protein_disease association',
                                     'disease_protein association', 'protein_drug interaction'],
                                 window_size=1,emb_dim=1000)
        model_DD4_1.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('drug')), batch_size=16, shuffle=True,
                                collate_fn=model_DD4_1.sample)
        optimizer = SparseAdam(model_DD4_1.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DD4_1(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Drug_D1_nids = torch.cuda.LongTensor(model_DD4_1.local_to_global_nid['drug'])
        Drug_D1_emb = model_DD4_1.node_embed(Drug_D1_nids)
        Drug_D1_emb = th.tensor(Drug_D1_emb).to(self.device)'''
        
        #for metapath D-Di-P-Di-D
        model_DD4_2 = MetaPath2Vec(g, ['drug_disease association','disease_protein association',
                                     'protein_disease association', 'disease_drug association'],
                                 window_size=1,emb_dim=1000)
        model_DD4_2.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('drug')), batch_size=16, shuffle=True,
                                collate_fn=model_DD4_2.sample)
        optimizer = SparseAdam(model_DD4_2.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DD4_2(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Drug_D2_nids = torch.cuda.LongTensor(model_DD4_2.local_to_global_nid['drug'])
        Drug_D2_emb = model_DD4_2.node_embed(Drug_D2_nids)
        Drug_D2_emb = th.tensor(Drug_D2_emb).to(self.device)
        
        '''#for metapath D-Si-D-P-D
        model_DD4_3 = MetaPath2Vec(g, ['drug_sideeffect association','sideeffect_drug association',
                                     'drug_protein interaction', 'protein_drug interaction'],
                                 window_size=1,emb_dim=1000)
        model_DD4_3.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('drug')), batch_size=16, shuffle=True,
                                collate_fn=model_DD4_3.sample)
        optimizer = SparseAdam(model_DD4_3.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DD4_3(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Drug_D3_nids = torch.cuda.LongTensor(model_DD4_3.local_to_global_nid['drug'])
        Drug_D3_emb = model_DD4_3.node_embed(Drug_D3_nids)
        Drug_D3_emb = th.tensor(Drug_D3_emb).to(self.device)'''
        
        #for metapath D-Si-D-Di-D
        model_DD4_4 = MetaPath2Vec(g, ['drug_sideeffect association','sideeffect_drug association',
                                      'drug_disease association', 'disease_drug association'],
                                 window_size=1,emb_dim=1000)
        model_DD4_4.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('drug')), batch_size=16, shuffle=True,
                                collate_fn=model_DD4_4.sample)
        optimizer = SparseAdam(model_DD4_4.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DD4_4(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Drug_D4_nids = torch.cuda.LongTensor(model_DD4_4.local_to_global_nid['drug'])
        Drug_D4_emb = model_DD4_4.node_embed(Drug_D4_nids)
        Drug_D4_emb = th.tensor(Drug_D4_emb).to(self.device)
        
        self.drug_feat_meta = torch.mean(torch.stack((Drug_Di_emb,
                                                      Drug_PDi_emb,
                                                      Drug_D2_emb, Drug_D4_emb),
                                                     dim = 1), dim = 1)
        
        '''metapath : Len(2) - Protein'''
        '''#for metapath P-D-P
        model_PDP = MetaPath2Vec(g, ['protein_drug interaction', 'drug_protein interaction'],
                                 window_size=1,emb_dim=1000)
        model_PDP.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('protein')), batch_size=16, shuffle=True, collate_fn=model_PDP.sample)
        optimizer = SparseAdam(model_PDP.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_PDP(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Protein_D_nids = torch.cuda.LongTensor(model_PDP.local_to_global_nid['protein'])
        Protein_D_emb = model_PDP.node_embed(Protein_D_nids)
        Protein_D_emb = th.tensor(Protein_D_emb).to(self.device)
        #print(Protein_D_emb)'''

        #for metapath P-Di-P
        model_PDiP = MetaPath2Vec(g, ['protein_disease association', 'disease_protein association'],
                                  window_size=1,emb_dim=1000)
        model_PDiP.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('protein')), batch_size=16, shuffle=True,
                                collate_fn=model_PDiP.sample)
        optimizer = SparseAdam(model_PDiP.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_PDiP(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Protein_Di_nids = torch.cuda.LongTensor(model_PDiP.local_to_global_nid['protein'])
        Protein_Di_emb = model_PDiP.node_embed(Protein_Di_nids)
        Protein_Di_emb = th.tensor(Protein_Di_emb).to(self.device)
        
        '''metapath : Len(3) - Protein'''
        #for metapath P-Di-D-P
        model_PDiDP = MetaPath2Vec(g, ['protein_disease association','disease_drug association','drug_protein interaction'],
                                 window_size=1,emb_dim=1000)
        model_PDiDP.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('protein')), batch_size=16,
                                shuffle=True, collate_fn=model_PDiDP.sample)
        optimizer = SparseAdam(model_PDiDP.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_PDiDP(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Protein_DiP_nids = torch.cuda.LongTensor(model_PDiDP.local_to_global_nid['protein'])
        Protein_DiDP_emb = model_PDiDP.node_embed(Protein_DiP_nids)
        Protein_DiDP_emb = th.tensor(Protein_DiDP_emb).to(self.device)
        
        #for metapath P-D-Di-P
        model_PDDiP = MetaPath2Vec(g, ['protein_drug interaction', 'drug_disease association', 'disease_protein association'],
                                 window_size=1,emb_dim=1000)
        model_PDDiP.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('protein')), batch_size=16,
                                shuffle=True, collate_fn=model_PDDiP.sample)
        optimizer = SparseAdam(model_PDDiP.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_PDDiP(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Protein_DDiP_nids = torch.cuda.LongTensor(model_PDDiP.local_to_global_nid['protein'])
        Protein_DDiP_emb = model_PDDiP.node_embed(Protein_DDiP_nids)
        Protein_DDiP_emb = th.tensor(Protein_DDiP_emb).to(self.device)
        
        '''metapath : Len(4) - Protein'''
        '''#for metapath P-D-Di-D-P
        
        model_PP4_1 = MetaPath2Vec(g, ['protein_drug interaction',  'drug_disease association', 
                                       'disease_drug association','drug_protein interaction'],
                                   window_size=1,emb_dim=1000)
        model_PP4_1.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('protein')), batch_size=16,
                                shuffle=True, collate_fn=model_PP4_1.sample)
        optimizer = SparseAdam(model_PP4_1.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_PP4_1(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Protein_P1_nids = torch.cuda.LongTensor(model_PP4_1.local_to_global_nid['protein'])
        Protein_P1_emb = model_PP4_1.node_embed(Protein_P1_nids)
        Protein_P1_emb = th.tensor(Protein_P1_emb).to(self.device)'''
        
        #for metapath P-Di-D-Di-P
        model_PP4_2 = MetaPath2Vec(g, ['protein_disease association',  'disease_drug association', 
                                       'drug_disease association','disease_protein association'],
                                   window_size=1,emb_dim=1000)
        model_PP4_2.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('protein')), batch_size=16,
                                shuffle=True, collate_fn=model_PP4_2.sample)
        optimizer = SparseAdam(model_PP4_2.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_PP4_2(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Protein_P2_nids = torch.cuda.LongTensor(model_PP4_2.local_to_global_nid['protein'])
        Protein_P2_emb = model_PP4_2.node_embed(Protein_P2_nids)
        Protein_P2_emb = th.tensor(Protein_P2_emb).to(self.device)
        
        '''#for metapath P-D-Si-D-P
        model_PP4_3 = MetaPath2Vec(g, ['protein_drug interaction', 'drug_sideeffect association',
                                       'sideeffect_drug association','drug_protein interaction'],
                                   window_size=1,emb_dim=1000)
        model_PP4_3.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('protein')), batch_size=16,
                                shuffle=True, collate_fn=model_PP4_3.sample)
        optimizer = SparseAdam(model_PP4_3.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_PP4_3(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Protein_P3_nids = torch.cuda.LongTensor(model_PP4_3.local_to_global_nid['protein'])
        Protein_P3_emb = model_PP4_3.node_embed(Protein_P3_nids)
        Protein_P3_emb = th.tensor(Protein_P3_emb).to(self.device)'''
        
        self.protein_feat_meta = torch.mean(torch.stack((Protein_Di_emb,
                                                         Protein_DiDP_emb, Protein_DDiP_emb,
                                                         Protein_P2_emb),
                                                        dim = 1), dim = 1)
        
        '''metapath : Len(2) - Disease'''
        #metapath for Disease
        #for meta-path Di-D-Di
        '''model_DiDDi = MetaPath2Vec(g, ['disease_drug association', 'drug_disease association'], window_size=1,emb_dim=1000)
        model_DiDDi.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('disease')), batch_size=16, shuffle=True, collate_fn=model_DiDDi.sample)
        optimizer = SparseAdam(model_DiDDi.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DiDDi(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Disease_D_nids = torch.cuda.LongTensor(model_DiDDi.local_to_global_nid['disease'])
        Disease_D_emb = model_DiDDi.node_embed(Disease_D_nids)
        Disease_D_emb = th.tensor(Disease_D_emb).to(self.device)'''

        #for meta-path Di-P-Di
        model_DiPDi = MetaPath2Vec(g, ['disease_protein association', 'protein_disease association'], window_size=1,emb_dim=1000)
        model_DiPDi.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('disease')), batch_size=16, shuffle=True, collate_fn=model_DiPDi.sample)
        optimizer = SparseAdam(model_DiPDi.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DiPDi(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Disease_P_nids = torch.cuda.LongTensor(model_DiPDi.local_to_global_nid['disease'])
        Disease_P_emb = model_DiPDi.node_embed(Disease_P_nids)
        Disease_P_emb = th.tensor(Disease_P_emb).to(self.device)
        
        '''metapath : Len(3) - Disease'''
        '''#for meta-path Di-P-D-Di
        model_DiPDDi = MetaPath2Vec(g, ['disease_protein association', 'protein_drug interaction', 'drug_disease association'],
                                    window_size=1,emb_dim=1000)
        model_DiPDDi.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('disease')), batch_size=16, shuffle=True,
                                collate_fn=model_DiPDDi.sample)
        optimizer = SparseAdam(model_DiPDDi.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DiPDDi(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Disease_PDDi_nids = torch.cuda.LongTensor(model_DiPDDi.local_to_global_nid['disease'])
        Disease_PDDi_emb = model_DiPDDi.node_embed(Disease_PDDi_nids)
        Disease_PDDi_emb = th.tensor(Disease_PDDi_emb).to(self.device)'''
        
        #for meta-path Di-D-P-Di
        model_DiDPDi = MetaPath2Vec(g, ['disease_drug association', 'drug_protein interaction', 'protein_disease association'],
                                    window_size=1,emb_dim=1000)
        model_DiDPDi.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('disease')), batch_size=16, shuffle=True,
                                collate_fn=model_DiDPDi.sample)
        optimizer = SparseAdam(model_DiDPDi.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DiDPDi(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Disease_DPDi_nids = torch.cuda.LongTensor(model_DiDPDi.local_to_global_nid['disease'])
        Disease_DPDi_emb = model_DiDPDi.node_embed(Disease_DPDi_nids)
        Disease_DPDi_emb = th.tensor(Disease_DPDi_emb).to(self.device)
        
        '''metapath : Len(4) - Disease'''
        '''#for metapath Di-D-P-D-Di
        model_DiDi4_1 = MetaPath2Vec(g, ['disease_drug association','drug_protein interaction',
                                         'protein_drug interaction','drug_disease association'],
                                   window_size=1,emb_dim=1000)
        model_DiDi4_1.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('disease')), batch_size=16,
                                shuffle=True, collate_fn=model_DiDi4_1.sample)
        optimizer = SparseAdam(model_DiDi4_1.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DiDi4_1(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Disease_Di1_nids = torch.cuda.LongTensor(model_DiDi4_1.local_to_global_nid['disease'])
        Disease_Di1_emb = model_DiDi4_1.node_embed(Disease_Di1_nids)
        Disease_Di1_emb = th.tensor(Disease_Di1_emb).to(self.device)'''
        
        #for metapath Di-P-D-P-Di
        model_DiDi4_2 = MetaPath2Vec(g, ['disease_protein association', 'protein_drug interaction',
                                         'drug_protein interaction','protein_disease association'],
                                   window_size=1,emb_dim=1000)
        model_DiDi4_2.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('disease')), batch_size=16,
                                shuffle=True, collate_fn=model_DiDi4_2.sample)
        optimizer = SparseAdam(model_DiDi4_2.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DiDi4_2(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Disease_Di2_nids = torch.cuda.LongTensor(model_DiDi4_2.local_to_global_nid['disease'])
        Disease_Di2_emb = model_DiDi4_2.node_embed(Disease_Di2_nids)
        Disease_Di2_emb = th.tensor(Disease_Di2_emb).to(self.device)
        
        #for metapath Di-D-Si-D-Di
        model_DiDi4_3 = MetaPath2Vec(g, ['disease_drug association', 'drug_sideeffect association',
                                         'sideeffect_drug association','drug_disease association'],
                                   window_size=1,emb_dim=1000)
        model_DiDi4_3.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('disease')), batch_size=16,
                                shuffle=True, collate_fn=model_DiDi4_3.sample)
        optimizer = SparseAdam(model_DiDi4_3.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_DiDi4_3(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Disease_Di3_nids = torch.cuda.LongTensor(model_DiDi4_3.local_to_global_nid['disease'])
        Disease_Di3_emb = model_DiDi4_3.node_embed(Disease_Di3_nids)
        Disease_Di3_emb = th.tensor(Disease_Di3_emb).to(self.device)
        
        self.disease_feat_meta = torch.mean(torch.stack((Disease_P_emb,
                                                         Disease_DPDi_emb,
                                                         Disease_Di2_emb, Disease_Di3_emb),
                                                        dim = 1), dim = 1)

        '''metapath : Len(2) - sideeffect'''
        
        model_SiDSi = MetaPath2Vec(g, ['sideeffect_drug association', 'drug_sideeffect association'],
                                   window_size=1,emb_dim=1000)
        model_SiDSi.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('sideeffect')), batch_size=16, shuffle=True, collate_fn=model_SiDSi.sample)
        optimizer = SparseAdam(model_SiDSi.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_SiDSi(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Sideeffect_D_nids = torch.cuda.LongTensor(model_SiDSi.local_to_global_nid['sideeffect'])
        Sideeffect_D_emb = model_SiDSi.node_embed(Sideeffect_D_nids)
        Sideeffect_D_emb = th.tensor(Sideeffect_D_emb).to(self.device)
        
        '''metapath : Len(4) - sideeffect'''
        #for metapath Si-D-P-D-Si 
        
        model_SiSi4_1 = MetaPath2Vec(g, ['sideeffect_drug association', 'drug_protein interaction',
                                         'protein_drug interaction','drug_sideeffect association'],
                                     window_size=1,emb_dim=1000)
        model_SiSi4_1.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('sideeffect')), batch_size=16,
                                shuffle=True, collate_fn=model_SiSi4_1.sample)
        optimizer = SparseAdam(model_SiSi4_1.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_SiSi4_1(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Sideeffect_Si1_nids = torch.cuda.LongTensor(model_SiSi4_1.local_to_global_nid['sideeffect'])
        Sideeffect_Si1_emb = model_SiSi4_1.node_embed(Sideeffect_Si1_nids)
        Sideeffect_Si1_emb = th.tensor(Sideeffect_Si1_emb).to(self.device)
        
        #for metapath Si-D-Di-D-Si
        model_SiSi4_2 = MetaPath2Vec(g, ['sideeffect_drug association', 'drug_disease association',
                                         'disease_drug association','drug_sideeffect association'],
                                     window_size=1,emb_dim=1000)
        model_SiSi4_2.to(self.device)
        dataloader = DataLoader(torch.arange(g.num_nodes('sideeffect')), batch_size=16,
                                shuffle=True, collate_fn=model_SiSi4_2.sample)
        optimizer = SparseAdam(model_SiSi4_2.parameters(), lr=0.025)

        for (pos_u, pos_v, neg_v) in dataloader:
            loss = model_SiSi4_2(pos_u, pos_v, neg_v)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        Sideeffect_Si2_nids = torch.cuda.LongTensor(model_SiSi4_2.local_to_global_nid['sideeffect'])
        Sideeffect_Si2_emb = model_SiSi4_2.node_embed(Sideeffect_Si2_nids)
        Sideeffect_Si2_emb = th.tensor(Sideeffect_Si2_emb).to(self.device)
        
        self.sideeffect_feat_meta = torch.mean(torch.stack((Sideeffect_D_emb,
                                                            Sideeffect_Si1_emb, Sideeffect_Si2_emb),
                                                           dim = 1), dim = 1)

        
        self.node_feat_meta = th.cat((self.disease_feat_meta, self.drug_feat_meta, self.protein_feat_meta,
                                      self.sideeffect_feat_meta), dim=0)

        
        self.drug_feat = nn.Parameter(th.FloatTensor(self.num_drug, self.dim_embedding))
        nn.init.normal_(self.drug_feat, mean=0, std=0.1)
        self.protein_feat = nn.Parameter(th.FloatTensor(self.num_protein, self.dim_embedding))
        nn.init.normal_(self.protein_feat, mean=0, std=0.1)
        self.disease_feat = nn.Parameter(th.FloatTensor(self.num_disease, self.dim_embedding))
        nn.init.normal_(self.disease_feat, mean=0, std=0.1)
        self.sideeffect_feat = nn.Parameter(th.FloatTensor(self.num_sideeffect, self.dim_embedding))
        nn.init.normal_(self.sideeffect_feat, mean=0, std=0.1)

        # 邻居信息的权重矩阵，对应论文公式（1）中的Wr、br
        self.fc_DDI = nn.Linear(self.dim_embedding, self.dim_embedding).float()
        self.fc_D_ch = nn.Linear(self.dim_embedding, self.dim_embedding).float()
        self.fc_D_Di = nn.Linear(self.dim_embedding, self.dim_embedding).float()
        self.fc_D_Side = nn.Linear(self.dim_embedding, self.dim_embedding).float()
        self.fc_D_P = nn.Linear(self.dim_embedding, self.dim_embedding).float()
        self.fc_PPI = nn.Linear(self.dim_embedding, self.dim_embedding).float()
        self.fc_P_seq = nn.Linear(self.dim_embedding, self.dim_embedding).float()
        self.fc_P_Di = nn.Linear(self.dim_embedding, self.dim_embedding).float()
        self.fc_P_D = nn.Linear(self.dim_embedding, self.dim_embedding).float()
        self.fc_Di_D = nn.Linear(self.dim_embedding, self.dim_embedding).float()
        self.fc_Di_P = nn.Linear(self.dim_embedding, self.dim_embedding).float()
        self.fc_Side_D = nn.Linear(self.dim_embedding, self.dim_embedding).float()

        self.propagation = Propagation(args.k, args.alpha, args.edge_drop)

        # Linear transformation for reconstruction
        tmp = th.randn(self.dim_embedding).float()
        self.re_DDI = nn.Parameter(th.diag(th.nn.init.normal_(tmp, mean=0, std=0.1)))
        self.re_D_ch = nn.Parameter(th.diag(th.nn.init.normal_(tmp, mean=0, std=0.1)))
        self.re_D_Di = nn.Parameter(th.diag(th.nn.init.normal_(tmp, mean=0, std=0.1)))
        self.re_D_Side = nn.Parameter(th.diag(th.nn.init.normal_(tmp, mean=0, std=0.1)))
        self.re_D_P = nn.Parameter(th.diag(th.nn.init.normal_(tmp, mean=0, std=0.1)))
        self.re_PPI = nn.Parameter(th.diag(th.nn.init.normal_(tmp, mean=0, std=0.1)))
        self.re_P_seq = nn.Parameter(th.diag(th.nn.init.normal_(tmp, mean=0, std=0.1)))
        self.re_P_Di = nn.Parameter(th.diag(th.nn.init.normal_(tmp, mean=0, std=0.1)))
        #self.re_P_D = nn.Parameter(th.diag(th.nn.init.normal_(tmp, mean=0, std=0.1)))
        #self.re_Di_P = nn.Parameter(th.diag(th.nn.init.normal_(tmp, mean=0, std=0.1)))
        #self.re_Di_D = nn.Parameter(th.diag(th.nn.init.normal_(tmp, mean=0, std=0.1)))
        #self.re_Side_D = nn.Parameter(th.diag(th.nn.init.normal_(tmp, mean=0, std=0.1)))

        self.reset_parameters()

    def reset_parameters(self):
        for m in GRDTI.modules(self):
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, mean=0, std=0.1)
                if m.bias is not None:
                    m.bias.data.fill_(0.1)

    def forward(self, drug_drug, drug_chemical, drug_disease, drug_sideeffect, protein_protein,
                protein_sequence, protein_disease, drug_protein, drug_protein_mask):

        disease_feat = th.mean(th.stack((th.mm(row_normalize(drug_disease.T).float(),
                                               F.relu(self.fc_Di_D(self.drug_feat))),
                                         th.mm(row_normalize(protein_disease.T).float(),
                                               F.relu(self.fc_Di_P(self.protein_feat))),
                                         self.disease_feat), dim=1), dim=1)

        drug_feat = th.mean(th.stack((th.mm(row_normalize(drug_drug).float(),
                                            F.relu(self.fc_DDI(self.drug_feat))),
                                      th.mm(row_normalize(drug_chemical).float(),
                                            F.relu(self.fc_D_ch(self.drug_feat))),
                                      th.mm(row_normalize(drug_disease).float(),
                                            F.relu(self.fc_D_Di(self.disease_feat))),
                                      th.mm(row_normalize(drug_sideeffect).float(),
                                            F.relu(self.fc_D_Side(self.sideeffect_feat))),
                                      th.mm(row_normalize(drug_protein).float(),
                                            F.relu(self.fc_D_P(self.protein_feat))),
                                      self.drug_feat), dim=1), dim=1)

        protein_feat = th.mean(th.stack((th.mm(row_normalize(protein_protein).float(),
                                               F.relu(self.fc_PPI(self.protein_feat))),
                                         th.mm(row_normalize(protein_sequence).float(),
                                               F.relu(self.fc_P_seq(self.protein_feat))),
                                         th.mm(row_normalize(protein_disease).float(),
                                               F.relu(self.fc_P_Di(self.disease_feat))),
                                         th.mm(row_normalize(drug_protein.T).float(),
                                               F.relu(self.fc_P_D(self.drug_feat))),
                                         self.protein_feat), dim=1), dim=1)

        sideeffect_feat = th.mean(th.stack((th.mm(row_normalize(drug_sideeffect.T).float(),
                                                  F.relu(self.fc_Side_D(self.drug_feat))),
                                            self.sideeffect_feat), dim=1), dim=1)

        node_feat = th.cat((disease_feat, drug_feat, protein_feat, sideeffect_feat), dim=0)
        
        '''Concatenate node_feat from GCN with node_feat_meta from Meta-Path'''
        beta = 0.5
        node_feat = beta*node_feat + (1-beta)*self.node_feat_meta
        #node_feat = torch.mean(torch.stack((node_feat,self.node_feat_meta), dim = 1), dim = 1)
        
        node_feat = self.propagation(dgl.to_homogeneous(self.g), node_feat)

        disease_embedding = node_feat[:self.num_disease].to(self.device)
        drug_embedding = node_feat[self.num_disease:self.num_disease + self.num_drug].to(self.device)
        protein_embedding = node_feat[self.num_disease + self.num_drug:self.num_disease + self.num_drug +
                                                                       self.num_protein].to(self.device)
        sideeffect_embedding = node_feat[-self.num_sideeffect:].to(self.device)

        disease_vector = l2_norm(disease_embedding)
        drug_vector = l2_norm(drug_embedding)
        protein_vector = l2_norm(protein_embedding)
        sideeffect_vector = l2_norm(sideeffect_embedding)

        drug_drug_reconstruct = th.mm(th.mm(drug_vector, self.re_DDI), drug_vector.t())
        drug_drug_reconstruct_loss = th.sum(
            (drug_drug_reconstruct - drug_drug.float()) ** 2)

        drug_chemical_reconstruct = th.mm(th.mm(drug_vector, self.re_D_ch), drug_vector.t())
        drug_chemical_reconstruct_loss = th.sum(
            (drug_chemical_reconstruct - drug_chemical.float()) ** 2)

        drug_disease_reconstruct = th.mm(th.mm(drug_vector, self.re_D_Di), disease_vector.t())
        drug_disease_reconstruct_loss = th.sum(
            (drug_disease_reconstruct - drug_disease.float()) ** 2)

        drug_sideeffect_reconstruct = th.mm(th.mm(drug_vector, self.re_D_Side), sideeffect_vector.t())
        drug_sideeffect_reconstruct_loss = th.sum(
            (drug_sideeffect_reconstruct - drug_sideeffect.float()) ** 2)

        protein_protein_reconstruct = th.mm(th.mm(protein_vector, self.re_PPI), protein_vector.t())
        protein_protein_reconstruct_loss = th.sum(
            (protein_protein_reconstruct - protein_protein.float()) ** 2)

        protein_sequence_reconstruct = th.mm(th.mm(protein_vector, self.re_P_seq), protein_vector.t())
        protein_sequence_reconstruct_loss = th.sum(
            (protein_sequence_reconstruct - protein_sequence.float()) ** 2)

        protein_disease_reconstruct = th.mm(th.mm(protein_vector, self.re_P_Di), disease_vector.t())
        protein_disease_reconstruct_loss = th.sum(
            (protein_disease_reconstruct - protein_disease.float()) ** 2)

        drug_protein_reconstruct = th.mm(th.mm(drug_vector, self.re_D_P), protein_vector.t())
        tmp = th.mul(drug_protein_mask.float(), (drug_protein_reconstruct - drug_protein.float()))
        DTI_potential = drug_protein_reconstruct - drug_protein.float()
        drug_protein_reconstruct_loss = th.sum(tmp ** 2)

        other_loss = drug_drug_reconstruct_loss + drug_chemical_reconstruct_loss + drug_disease_reconstruct_loss + \
                     drug_sideeffect_reconstruct_loss + protein_protein_reconstruct_loss + \
                     protein_sequence_reconstruct_loss + protein_disease_reconstruct_loss

        L2_loss = 0.
        for name, param in GRDTI.named_parameters(self):
            if 'bias' not in name:
                L2_loss = L2_loss + th.sum(param.pow(2))
        L2_loss = L2_loss * 0.5

        tloss = drug_protein_reconstruct_loss + 1.0 * other_loss + self.reg_lambda * L2_loss

        return tloss, drug_protein_reconstruct_loss, L2_loss, drug_protein_reconstruct, DTI_potential

In [None]:
import dgl
import time
import torch as th
import numpy as np
import matplotlib.pyplot as plt


from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.model_selection import train_test_split, StratifiedKFold
from imblearn.under_sampling import RandomUnderSampler



def loda_data():
    network_path = 'E:/Alzahra-University/GADTI-main/data/'
    

    drug_drug = np.loadtxt(network_path + '/' + 'mat_drug_drug.txt')
    true_drug = 708
    drug_chemical = np.loadtxt(network_path + '/' + 'Similarity_Matrix_Drugs.txt')
    drug_chemical = drug_chemical[:true_drug, :true_drug]
    drug_disease = np.loadtxt(network_path + '/' + 'mat_drug_disease.txt')
    drug_sideeffect = np.loadtxt(network_path + '/' + 'mat_drug_se.txt')

    protein_protein = np.loadtxt(network_path + '/' + 'mat_protein_protein.txt')
    protein_sequence = np.loadtxt(network_path + '/' + 'Similarity_Matrix_Proteins.txt')
    protein_disease = np.loadtxt(network_path + '/' + 'mat_protein_disease.txt')

    num_drug = len(drug_drug)
    num_protein = len(protein_protein)

    # Removed the self-loop
    drug_chemical = drug_chemical - np.identity(num_drug)
    protein_sequence = protein_sequence / 100.
    protein_sequence = protein_sequence - np.identity(num_protein)

    drug_protein = np.loadtxt(network_path + '/' + 'mat_drug_protein.txt')

    # Removed DTIs with similar drugs or proteins
    #drug_protein = np.loadtxt(network_path + 'mat_drug_protein_homo_protein_drug.txt')

    print("Load data finished.")

    return drug_drug, drug_chemical, drug_disease, drug_sideeffect, protein_protein, protein_sequence, \
           protein_disease, drug_protein


def ConstructGraph(drug_drug, drug_chemical, drug_disease, drug_sideeffect, protein_protein, protein_sequence,
                   protein_disease, drug_protein):
    num_drug = len(drug_drug)
    num_protein = len(protein_protein)
    num_disease = len(drug_disease.T)
    num_sideeffect = len(drug_sideeffect.T)

    list_drug = []
    for i in range(num_drug):
        list_drug.append((i, i))

    list_protein = []
    for i in range(num_protein):
        list_protein.append((i, i))

    list_disease = []
    for i in range(num_disease):
        list_disease.append((i, i))

    list_sideeffect = []
    for i in range(num_sideeffect):
        list_sideeffect.append((i, i))

    list_DDI = []
    for row in range(num_drug):
        for col in range(num_drug):
            if drug_drug[row, col] > 0:
                list_DDI.append((row, col))

    list_PPI = []
    for row in range(num_protein):
        for col in range(num_protein):
            if protein_protein[row, col] > 0:
                list_PPI.append((row, col))

    list_drug_protein = []
    list_protein_drug = []
    for row in range(num_drug):
        for col in range(num_protein):
            if drug_protein[row, col] > 0:
                list_drug_protein.append((row, col))
                list_protein_drug.append((col, row))

    list_drug_sideeffect = []
    list_sideeffect_drug = []
    for row in range(num_drug):
        for col in range(num_sideeffect):
            if drug_sideeffect[row, col] > 0:
                list_drug_sideeffect.append((row, col))
                list_sideeffect_drug.append((col, row))

    list_drug_disease = []
    list_disease_drug = []
    for row in range(num_drug):
        for col in range(num_disease):
            if drug_disease[row, col] > 0:
                list_drug_disease.append((row, col))
                list_disease_drug.append((col, row))

    list_protein_disease = []
    list_disease_protein = []
    for row in range(num_protein):
        for col in range(num_disease):
            if protein_disease[row, col] > 0:
                list_protein_disease.append((row, col))
                list_disease_protein.append((col, row))

    g_HIN = dgl.heterograph({('disease', 'disease_disease virtual', 'disease'): list_disease,
                             ('drug', 'drug_drug virtual', 'drug'): list_drug,
                             ('protein', 'protein_protein virtual', 'protein'): list_protein,
                             ('sideeffect', 'sideeffect_sideeffect virtual', 'sideeffect'): list_sideeffect,
                             ('drug', 'drug_drug interaction', 'drug'): list_DDI, \
                             ('protein', 'protein_protein interaction', 'protein'): list_PPI, \
                             ('drug', 'drug_protein interaction', 'protein'): list_drug_protein, \
                             ('protein', 'protein_drug interaction', 'drug'): list_protein_drug, \
                             ('drug', 'drug_sideeffect association', 'sideeffect'): list_drug_sideeffect, \
                             ('sideeffect', 'sideeffect_drug association', 'drug'): list_sideeffect_drug, \
                             ('drug', 'drug_disease association', 'disease'): list_drug_disease, \
                             ('disease', 'disease_drug association', 'drug'): list_disease_drug, \
                             ('protein', 'protein_disease association', 'disease'): list_protein_disease, \
                             ('disease', 'disease_protein association', 'protein'): list_disease_protein})

    g = g_HIN.edge_type_subgraph(['drug_drug interaction', 'protein_protein interaction',
                                  'drug_protein interaction', 'protein_drug interaction',
                                  'drug_sideeffect association', 'sideeffect_drug association',
                                  'drug_disease association', 'disease_drug association',
                                  'protein_disease association', 'disease_protein association'
                                  ])

    return g


def TrainAndEvaluate(DTItrain, DTIvalid, DTItest, args, drug_drug, drug_chemical, drug_disease,
                     drug_sideeffect, protein_protein, protein_sequence, protein_disease):
    device = th.device(args.device)

    # Numbers of different nodes
    num_disease = len(drug_disease.T)
    num_drug = len(drug_drug)
    num_protein = len(protein_protein)
    num_sideeffect = len(drug_sideeffect.T)

    drug_protein = th.zeros((num_drug, num_protein))
    mask = th.zeros((num_drug, num_protein)).to(device)
    for ele in DTItrain:
        drug_protein[ele[0], ele[1]] = ele[2]
        mask[ele[0], ele[1]] = 1

    best_valid_aupr = 0.
    # best_valid_auc = 0
    test_aupr = 0.
    test_auc = 0.
    patience = 0.

    pos = np.count_nonzero(DTItest[:, 2])
    neg = np.size(DTItest[:, 2]) - pos
    xy_roc_sampling = []
    xy_pr_sampling = []

    g = ConstructGraph(drug_drug, drug_chemical, drug_disease, drug_sideeffect, protein_protein, protein_sequence,
                       protein_disease, drug_protein)
    g.to('cuda')

    drug_drug = th.tensor(drug_drug).to(device)
    drug_chemical = th.tensor(drug_chemical).to(device)
    drug_disease = th.tensor(drug_disease).to(device)
    drug_sideeffect = th.tensor(drug_sideeffect).to(device)
    protein_protein = th.tensor(protein_protein).to(device)
    protein_sequence = th.tensor(protein_sequence).to(device)
    protein_disease = th.tensor(protein_disease).to(device)
    drug_protein = drug_protein.to(device)

    model = GRDTI(g, num_disease, num_drug, num_protein, num_sideeffect, args)
    model.to(device)

    optimizer = th.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    for i in range(args.epochs):

        model.train()
        tloss, dtiloss, l2loss, dp_re, DTI_p = model(drug_drug, drug_chemical, drug_disease, drug_sideeffect,
                                                     protein_protein, protein_sequence, protein_disease,
                                                     drug_protein, mask)

        results = dp_re.detach().cpu()
        optimizer.zero_grad()
        loss = tloss
        loss.backward()
        th.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

        model.eval()

        if i % 25 == 0:
            with th.no_grad():
                print("step", i, ":", "Total_loss & DTIloss & L2_loss:", loss.cpu().data.numpy(), ",", dtiloss.item(),
                      ",", l2loss.item())

                pred_list = []
                ground_truth = []

                for ele in DTIvalid:
                    pred_list.append(results[ele[0], ele[1]])
                    ground_truth.append(ele[2])

                valid_auc = roc_auc_score(ground_truth, pred_list)
                valid_aupr = average_precision_score(ground_truth, pred_list)

                if valid_aupr >= best_valid_aupr:
                    best_valid_aupr = valid_aupr
                    # best_valid_auc = valid_auc
                    best_DTI_potential = DTI_p
                    patience = 0

                    # Calculating AUC & AUPR (pos:neg=1:10)
                    '''db = []
                    xy_roc = []
                    xy_pr = []
                    for ele in DTItest:
                        db.append([results[ele[0], ele[1]], ele[2]])

                    db = sorted(db, key=lambda x: x[0], reverse=True)

                    tp, fp = 0., 0.
                    for i_db in range(len(db)):
                        if db[i_db][0]:
                            if db[i_db][1]:
                                tp = tp + 1
                            else:
                                fp = fp + 1
                            xy_roc.append([fp / neg, tp / pos])
                            xy_pr.append([tp / pos, tp / (tp + fp)])

                    test_auc = 0.
                    prev_x = 0.
                    for x, y in xy_roc:
                        if x != prev_x:
                            test_auc += (x - prev_x) * y
                            prev_x = x

                    test_aupr = 0.
                    prev_x = 0.
                    for x, y in xy_pr:
                        if x != prev_x:
                            test_aupr += (x - prev_x) * y
                            prev_x = x'''

                    # All unknown DTI pairs all treated as negative examples
                    pred_list = []
                    ground_truth = []
                    for ele in DTItest:
                        pred_list.append(results[ele[0], ele[1]])
                        ground_truth.append(ele[2])
                    test_auc = roc_auc_score(ground_truth, pred_list)
                    test_aupr = average_precision_score(ground_truth, pred_list)

                else:
                    patience += 1
                    if patience > args.patience:
                        print("Early Stopping")

                        # sampling (pos:neg=1:10) for averaging and plotting
                        '''xy_roc_sampling = []
                        xy_pr_sampling = []
                        for i_xy in range(len(xy_roc)):
                            if i_xy % 10 == 0:
                                xy_roc_sampling.append(xy_roc[i_xy])
                                xy_pr_sampling.append(xy_pr[i_xy])'''

                        # Record data for sampling, averaging and plotting.
                        # All unknown DTI pairs all treated as negative examples
                        #t1 = time.localtime()
                        #time_creat_txt = str(t1.tm_year) + '_' + str(t1.tm_mon) + '_' + str(t1.tm_mday) + '_' + str(
                        #    t1.tm_hour) + '_' + str(t1.tm_min)

                        fpr, tpr, threshold = roc_curve(ground_truth, pred_list)
                        print("len(fpr):", len(fpr))
                        np.savetxt('fpr_.csv', fpr)
                        np.savetxt('tpr_.csv', tpr)
                        np.savetxt('ROC_threshold_.csv', threshold)
                        precision, recall, threshold = precision_recall_curve(ground_truth, pred_list)
                        print("len(recall):", len(recall))
                        np.savetxt('precision_.csv', precision)
                        np.savetxt('recall_.csv', recall)
                        np.savetxt('PRC_threshold_.csv', threshold)
                        
                        break

                print('Valid auc & aupr:', valid_auc, valid_aupr, ";  ", 'Test auc & aupr:', test_auc, test_aupr)

    return test_auc, test_aupr, xy_roc_sampling, xy_pr_sampling, best_DTI_potential


def main(args):
    drug_d, drug_ch, drug_di, drug_side, protein_p, protein_seq, protein_di, dti_original = loda_data()

    # sampling
    whole_positive_index = []
    whole_negative_index = []
    for i in range(np.shape(dti_original)[0]):
        for j in range(np.shape(dti_original)[1]):
            if int(dti_original[i][j]) == 1:
                whole_positive_index.append([i, j])
            elif int(dti_original[i][j]) == 0:
                whole_negative_index.append([i, j])

    # pos:neg=1:10
    '''negative_sample_index = np.random.choice(np.arange(len(whole_negative_index)),
                                             size=10 * len(whole_positive_index), replace=False)'''

    # All unknown DTI pairs all treated as negative examples
    negative_sample_index = np.random.choice(np.arange(len(whole_negative_index)),
                                             size=len(whole_negative_index), replace=False)

    data_set = np.zeros((len(negative_sample_index) + len(whole_positive_index), 3), dtype=int)
    count = 0
    for i in whole_positive_index:
        data_set[count][0] = i[0]
        data_set[count][1] = i[1]
        data_set[count][2] = 1
        count += 1
    for i in negative_sample_index:
        data_set[count][0] = whole_negative_index[i][0]
        data_set[count][1] = whole_negative_index[i][1]
        data_set[count][2] = 0
        count += 1

    test_auc_round = []
    test_aupr_round = []
    tpr_mean = []
    fpr = []
    precision_mean = []
    recall = []

    rounds = args.rounds
    for r in range(rounds):
        print("----------------------------------------")

        test_auc_fold = []
        test_aupr_fold = []

        kf = StratifiedKFold(n_splits=10, random_state=None, shuffle=True)
        k_fold = 0

        for train_index, test_index in kf.split(data_set[:, :2], data_set[:, 2]):
            train = data_set[train_index]
            DTItest = data_set[test_index]
            DTItrain, DTIvalid = train_test_split(train, test_size=0.05, random_state=None)
            
            '''X_train = DTItrain[:, :2]
            y_train = DTItrain[:, 2]
            #counter = Counter(y_train)
            #print(counter)
            undersample = RandomUnderSampler(sampling_strategy=0.5)
            X_train_underSample, y_train_underSample  = undersample.fit_resample(X_train, y_train)
            #counter = Counter(X_train_underSample)
            #print(counter)
            DTItrain = np.hstack((X_train_underSample, np.vstack(y_train_underSample)))'''

            k_fold += 1
            print("--------------------------------------------------------------")
            print("round ", r + 1, " of ", rounds, ":", "KFold ", k_fold, " of 10")
            print("--------------------------------------------------------------")

            time_roundStart = time.time()

            t_auc, t_aupr, xy_roc, xy_pr, DTI_potential = TrainAndEvaluate(DTItrain, DTIvalid, DTItest, args, drug_d,
                                                                           drug_ch, drug_di, drug_side, protein_p,
                                                                           protein_seq, protein_di)

            time_roundEnd = time.time()
            print("Time spent in this fold:", time_roundEnd - time_roundStart)
            test_auc_fold.append(t_auc)
            test_aupr_fold.append(t_aupr)

            order_txt1 = 'DTI_potential_' + 'r' + str(r + 1) + '_f' + str(k_fold) + '.csv'
            np.savetxt(order_txt1, DTI_potential.detach().cpu().numpy(), fmt='%-.4f', delimiter=',')
            top_values, top_indices = th.topk(DTI_potential, 40)
            order_txt2 = 'top40_' + 'r' + str(r + 1) + '_f' + str(k_fold) + '.csv'
            np.savetxt(order_txt2, top_indices.detach().cpu().numpy(), fmt='%d', delimiter=',')

            # pos:neg=1:10
            '''if not fpr:
                fpr = [_v[0] for _v in xy_roc]
            if not recall:
                recall = [_v[0] for _v in xy_pr]

            temp = [_v[1] for _v in xy_roc]
            tpr_mean.append(temp)
            temp = [_v[1] for _v in xy_pr]
            precision_mean.append(temp)'''

        print("Training and evaluation is OK.")

        test_auc_round.append(np.mean(test_auc_fold))
        test_aupr_round.append(np.mean(test_aupr_fold))

    #t1 = time.localtime()
    #time_creat_txt = str(t1.tm_year) + '_' + str(t1.tm_mon) + '_' + str(t1.tm_mday) + '_' + str(t1.tm_hour) + '_' + str(
    #    t1.tm_min)
    
    np.savetxt('test_auc_' , test_auc_round)
    np.savetxt('test_aupr_' , test_aupr_round)
    

    # pos:neg=1:10
    '''tpr = (np.mean(np.array(tpr_mean), axis=0)).tolist()
    precision = (np.mean(np.array(precision_mean), axis=0)).tolist()

    np.savetxt('fpr.csv', fpr, fmt='%-.4f', delimiter=',')
    np.savetxt('tpr.csv', tpr, fmt='%-.4f', delimiter=',')
    np.savetxt('recall.csv', recall, fmt='%-.4f', delimiter=',')
    np.savetxt('precision.csv', precision, fmt='%-.4f', delimiter=',')'''


if __name__ == "__main__":
    args = parse_args()
    print(args)

    start = time.time()
    main(args)
    end = time.time()
    print("Total time:", end - start)

Namespace(epochs=4000, rounds=1, device='cuda', dim_embedding=1000, k=3, lr=0.001, weight_decay=0, reg_lambda=1, patience=6, alpha=0.9, edge_drop=0.5, f='C:\\Users\\USER\\AppData\\Roaming\\jupyter\\runtime\\kernel-1896fcd4-420a-4ca8-9110-bc6f9ce89b52.json')
Load data finished.
----------------------------------------
--------------------------------------------------------------
round  1  of  1 : KFold  1  of 10
--------------------------------------------------------------


100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:00<00:00, 879.25it/s]
  Drug_Di_emb = th.tensor(Drug_Di_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:00<00:00, 724.12it/s]
  Drug_PDi_emb = th.tensor(Drug_PDi_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 646.46it/s]
  Drug_D2_emb = th.tensor(Drug_D2_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 685.78it/s]
  Drug_D4_emb = th.tensor(Drug_D4_emb).to(self.device)
100%|█████████████████████████████████████████████████████████████████████████████| 1512/1512 [00:01<00:00, 909.53it/s]
  Protein_Di_emb = th.tensor(Protein_Di_emb).to(self.device)
100%|█████████████████████████████████████████████████████████████████████████████| 1512/1512 [00:01<00:00, 787.82it/

step 0 : Total_loss & DTIloss & L2_loss: 2082887.1 , 1650.5286865234375 , 120155.3203125
Valid auc & aupr: 0.4964060859783488 0.001762270828380805 ;   Test auc & aupr: 0.5178592754900125 0.001848806593122057
step 25 : Total_loss & DTIloss & L2_loss: 1147700.5 , 4630.6083984375 , 117214.0
Valid auc & aupr: 0.4157440742383132 0.0013437488647997287 ;   Test auc & aupr: 0.5178592754900125 0.001848806593122057
step 50 : Total_loss & DTIloss & L2_loss: 852685.8 , 4046.52880859375 , 108754.578125
Valid auc & aupr: 0.6307563378905516 0.002498760226950481 ;   Test auc & aupr: 0.623971841501797 0.0026166794252606594
step 75 : Total_loss & DTIloss & L2_loss: 791229.94 , 1819.2357177734375 , 97402.7890625
Valid auc & aupr: 0.8399960415586237 0.02593402102318419 ;   Test auc & aupr: 0.8452322921907982 0.033791794778675215
step 100 : Total_loss & DTIloss & L2_loss: 761383.6 , 1633.046875 , 88902.4609375
Valid auc & aupr: 0.8629096548389441 0.036494416501659944 ;   Test auc & aupr: 0.8817687247607704

Valid auc & aupr: 0.9340195466833021 0.614775083949413 ;   Test auc & aupr: 0.9415888291169822 0.56674941697409
step 1025 : Total_loss & DTIloss & L2_loss: 137848.33 , 556.9945068359375 , 34036.41015625
Valid auc & aupr: 0.9349397590361446 0.6147382206675905 ;   Test auc & aupr: 0.9415888291169822 0.56674941697409
step 1050 : Total_loss & DTIloss & L2_loss: 135152.48 , 551.252197265625 , 34129.62109375
Valid auc & aupr: 0.9315154567114372 0.6169400123578452 ;   Test auc & aupr: 0.9448784902160415 0.5688727920971187
step 1075 : Total_loss & DTIloss & L2_loss: 132549.9 , 545.6062622070312 , 34229.7734375
Valid auc & aupr: 0.9315342467812611 0.6173128847783772 ;   Test auc & aupr: 0.9450767601328218 0.5698666284912717
step 1100 : Total_loss & DTIloss & L2_loss: 130734.14 , 539.7762451171875 , 34327.21484375
Valid auc & aupr: 0.9304406647175101 0.6163585978578102 ;   Test auc & aupr: 0.9450767601328218 0.5698666284912717
step 1125 : Total_loss & DTIloss & L2_loss: 129861.59 , 533.949951171

step 2025 : Total_loss & DTIloss & L2_loss: 94268.69 , 412.9160461425781 , 38214.6328125
Valid auc & aupr: 0.9214747950003384 0.6235221848300329 ;   Test auc & aupr: 0.9367946634990029 0.5870771214258285
step 2050 : Total_loss & DTIloss & L2_loss: 93771.09 , 410.8794860839844 , 38287.5625
Valid auc & aupr: 0.9212407959974647 0.6233612626778574 ;   Test auc & aupr: 0.9367946634990029 0.5870771214258285
step 2075 : Total_loss & DTIloss & L2_loss: 93072.78 , 408.7580261230469 , 38362.2890625
Valid auc & aupr: 0.9204937028212664 0.6225458207040785 ;   Test auc & aupr: 0.9367946634990029 0.5870771214258285
step 2100 : Total_loss & DTIloss & L2_loss: 92646.71 , 407.0987854003906 , 38430.08203125
Valid auc & aupr: 0.9209409064830751 0.6233285343383869 ;   Test auc & aupr: 0.9367946634990029 0.5870771214258285
step 2125 : Total_loss & DTIloss & L2_loss: 92043.48 , 405.0478210449219 , 38498.83984375
Valid auc & aupr: 0.9214286966957036 0.6239048627115289 ;   Test auc & aupr: 0.9348361116578201 

100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:00<00:00, 813.93it/s]
  Drug_Di_emb = th.tensor(Drug_Di_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:00<00:00, 735.90it/s]
  Drug_PDi_emb = th.tensor(Drug_PDi_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:00<00:00, 712.98it/s]
  Drug_D2_emb = th.tensor(Drug_D2_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:00<00:00, 731.41it/s]
  Drug_D4_emb = th.tensor(Drug_D4_emb).to(self.device)
100%|█████████████████████████████████████████████████████████████████████████████| 1512/1512 [00:01<00:00, 875.31it/s]
  Protein_Di_emb = th.tensor(Protein_Di_emb).to(self.device)
100%|█████████████████████████████████████████████████████████████████████████████| 1512/1512 [00:01<00:00, 807.70it/

step 0 : Total_loss & DTIloss & L2_loss: 2075125.4 , 1651.612060546875 , 120118.4453125
Valid auc & aupr: 0.5076651569252049 0.0017476843415531464 ;   Test auc & aupr: 0.5278048222987082 0.001811925014032161
step 25 : Total_loss & DTIloss & L2_loss: 1147878.4 , 6129.96484375 , 117150.3515625
Valid auc & aupr: 0.4392373744724203 0.0013044129814199026 ;   Test auc & aupr: 0.5278048222987082 0.001811925014032161
step 50 : Total_loss & DTIloss & L2_loss: 851635.5 , 3822.64306640625 , 108608.09375
Valid auc & aupr: 0.8470856044200153 0.016559035391032643 ;   Test auc & aupr: 0.8354328658013292 0.018615306529690196
step 75 : Total_loss & DTIloss & L2_loss: 806626.4 , 1654.189208984375 , 95436.7734375
Valid auc & aupr: 0.9243412161607628 0.03201020437808136 ;   Test auc & aupr: 0.8750598894915538 0.026979998519862526
step 100 : Total_loss & DTIloss & L2_loss: 743771.0 , 1963.623291015625 , 85380.1796875
Valid auc & aupr: 0.9175583223814157 0.03148888658642222 ;   Test auc & aupr: 0.8750598894

Valid auc & aupr: 0.9636491169708582 0.6396564806228191 ;   Test auc & aupr: 0.9349437078782495 0.5972895468373834
step 1025 : Total_loss & DTIloss & L2_loss: 133885.28 , 568.009033203125 , 34071.27734375
Valid auc & aupr: 0.9640991395677376 0.6418046287562879 ;   Test auc & aupr: 0.9361527713106098 0.5986798831507829
step 1050 : Total_loss & DTIloss & L2_loss: 131907.19 , 561.0204467773438 , 34185.26171875
Valid auc & aupr: 0.9642383866813009 0.641695392705352 ;   Test auc & aupr: 0.9361527713106098 0.5986798831507829
step 1075 : Total_loss & DTIloss & L2_loss: 129662.76 , 555.1941528320312 , 34305.625
Valid auc & aupr: 0.9643294854805475 0.6438646080159109 ;   Test auc & aupr: 0.9360796504902367 0.6011529917056834
step 1100 : Total_loss & DTIloss & L2_loss: 128035.78 , 548.79931640625 , 34429.49609375
Valid auc & aupr: 0.9642006340257572 0.6456850100294693 ;   Test auc & aupr: 0.9349329434073459 0.6025853659821486
step 1125 : Total_loss & DTIloss & L2_loss: 126964.09 , 542.6446533203

step 2025 : Total_loss & DTIloss & L2_loss: 90643.07 , 411.19891357421875 , 38559.6953125
Valid auc & aupr: 0.9519277381344498 0.6746928099093777 ;   Test auc & aupr: 0.927304958091054 0.6201931896245922
step 2050 : Total_loss & DTIloss & L2_loss: 90419.5 , 408.9823913574219 , 38622.921875
Valid auc & aupr: 0.9525583169100164 0.6746036706934473 ;   Test auc & aupr: 0.927304958091054 0.6201931896245922
step 2075 : Total_loss & DTIloss & L2_loss: 89854.44 , 407.75836181640625 , 38688.67578125
Valid auc & aupr: 0.9514812719471506 0.6755368396415944 ;   Test auc & aupr: 0.9277019828106916 0.6206607549235694
step 2100 : Total_loss & DTIloss & L2_loss: 90554.97 , 405.900390625 , 38760.98046875
Valid auc & aupr: 0.9527572022765398 0.6749928987260885 ;   Test auc & aupr: 0.9277019828106916 0.6206607549235694
step 2125 : Total_loss & DTIloss & L2_loss: 89274.84 , 404.57720947265625 , 38852.34375
Valid auc & aupr: 0.9522718891538261 0.676287111734675 ;   Test auc & aupr: 0.9275804461155559 0.621

100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 695.25it/s]
  Drug_Di_emb = th.tensor(Drug_Di_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 648.95it/s]
  Drug_PDi_emb = th.tensor(Drug_PDi_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 597.41it/s]
  Drug_D2_emb = th.tensor(Drug_D2_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 589.02it/s]
  Drug_D4_emb = th.tensor(Drug_D4_emb).to(self.device)
100%|█████████████████████████████████████████████████████████████████████████████| 1512/1512 [00:01<00:00, 763.74it/s]
  Protein_Di_emb = th.tensor(Protein_Di_emb).to(self.device)
100%|█████████████████████████████████████████████████████████████████████████████| 1512/1512 [00:02<00:00, 678.03it/

step 0 : Total_loss & DTIloss & L2_loss: 2064643.1 , 1655.5084228515625 , 120097.6328125
Valid auc & aupr: 0.5040966464118504 0.0017929050397698227 ;   Test auc & aupr: 0.49122130729068403 0.0016735810625809503
step 25 : Total_loss & DTIloss & L2_loss: 1143109.0 , 4129.2353515625 , 117149.234375
Valid auc & aupr: 0.7204722026715421 0.005024838431215186 ;   Test auc & aupr: 0.7673893845151922 0.005956853640920093
step 50 : Total_loss & DTIloss & L2_loss: 847997.25 , 4935.6650390625 , 108645.8671875
Valid auc & aupr: 0.8434282608431498 0.042888533990038204 ;   Test auc & aupr: 0.8497768384564699 0.0442824133664447
step 75 : Total_loss & DTIloss & L2_loss: 789470.3 , 4495.3837890625 , 97776.9140625
Valid auc & aupr: 0.713885531780059 0.006127270065766568 ;   Test auc & aupr: 0.8497768384564699 0.0442824133664447
step 100 : Total_loss & DTIloss & L2_loss: 757503.0 , 3171.2451171875 , 87466.8984375
Valid auc & aupr: 0.8417962853396069 0.03350509976866723 ;   Test auc & aupr: 0.8497768384564

Valid auc & aupr: 0.937410561479957 0.6495929190007146 ;   Test auc & aupr: 0.9254729130272936 0.6531958414777271
step 1025 : Total_loss & DTIloss & L2_loss: 127329.42 , 559.07421875 , 34215.85546875
Valid auc & aupr: 0.9373105492354672 0.6520999802535189 ;   Test auc & aupr: 0.9261619846309539 0.6554370002588666
step 1050 : Total_loss & DTIloss & L2_loss: 125062.625 , 550.3356323242188 , 34363.6484375
Valid auc & aupr: 0.9372252584194888 0.6534084577244194 ;   Test auc & aupr: 0.9281584060747303 0.658017520562462
step 1075 : Total_loss & DTIloss & L2_loss: 122832.85 , 542.5899658203125 , 34504.46875
Valid auc & aupr: 0.9351922992441131 0.6573907607958502 ;   Test auc & aupr: 0.9289765543520199 0.6599396896103841
step 1100 : Total_loss & DTIloss & L2_loss: 120886.58 , 535.1491088867188 , 34643.62890625
Valid auc & aupr: 0.9352612329173011 0.6579856376250297 ;   Test auc & aupr: 0.9294327351730202 0.6612049264694398
step 1125 : Total_loss & DTIloss & L2_loss: 119052.2 , 528.63232421875 

step 2025 : Total_loss & DTIloss & L2_loss: 91156.86 , 415.5186462402344 , 38301.546875
Valid auc & aupr: 0.939238121676579 0.684144544217645 ;   Test auc & aupr: 0.9301465359665945 0.6830884683797311
step 2050 : Total_loss & DTIloss & L2_loss: 90424.42 , 413.8144226074219 , 38376.78515625
Valid auc & aupr: 0.9419751390123464 0.6837954295633011 ;   Test auc & aupr: 0.9301465359665945 0.6830884683797311
step 2075 : Total_loss & DTIloss & L2_loss: 90004.84 , 412.0242919921875 , 38435.7421875
Valid auc & aupr: 0.9411568145258906 0.6830021367506904 ;   Test auc & aupr: 0.9301465359665945 0.6830884683797311
step 2100 : Total_loss & DTIloss & L2_loss: 89615.766 , 410.7459716796875 , 38497.94140625
Valid auc & aupr: 0.9436641308421871 0.6851862788036102 ;   Test auc & aupr: 0.9288817106354088 0.6839140574122009
step 2125 : Total_loss & DTIloss & L2_loss: 89806.125 , 408.7608642578125 , 38562.25390625
Valid auc & aupr: 0.9445906461445279 0.6839000416981647 ;   Test auc & aupr: 0.92888171063540

100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:00<00:00, 788.52it/s]
  Drug_Di_emb = th.tensor(Drug_Di_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 705.18it/s]
  Drug_PDi_emb = th.tensor(Drug_PDi_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 621.94it/s]
  Drug_D2_emb = th.tensor(Drug_D2_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 677.51it/s]
  Drug_D4_emb = th.tensor(Drug_D4_emb).to(self.device)
100%|█████████████████████████████████████████████████████████████████████████████| 1512/1512 [00:01<00:00, 824.43it/s]
  Protein_Di_emb = th.tensor(Protein_Di_emb).to(self.device)
100%|█████████████████████████████████████████████████████████████████████████████| 1512/1512 [00:01<00:00, 772.23it/

step 0 : Total_loss & DTIloss & L2_loss: 2063226.5 , 1665.590576171875 , 120177.0625
Valid auc & aupr: 0.4981919146237499 0.014894108467646017 ;   Test auc & aupr: 0.4988419941651538 0.0018198227697487392
step 25 : Total_loss & DTIloss & L2_loss: 1145675.6 , 3629.94189453125 , 117232.0625
Valid auc & aupr: 0.7462467250951281 0.007571093508556738 ;   Test auc & aupr: 0.4988419941651538 0.0018198227697487392
step 50 : Total_loss & DTIloss & L2_loss: 853887.0 , 2660.3154296875 , 108804.921875
Valid auc & aupr: 0.8681169816813259 0.032909033138328846 ;   Test auc & aupr: 0.8764545686019453 0.030987773898313982
step 75 : Total_loss & DTIloss & L2_loss: 787191.75 , 3745.8662109375 , 97521.234375
Valid auc & aupr: 0.7344166510718816 0.010488804480466483 ;   Test auc & aupr: 0.8764545686019453 0.030987773898313982
step 100 : Total_loss & DTIloss & L2_loss: 764816.94 , 4321.70068359375 , 87375.625
Valid auc & aupr: 0.7145570821117418 0.01064895658745281 ;   Test auc & aupr: 0.8764545686019453 0

Valid auc & aupr: 0.9510763000852515 0.6030679014511515 ;   Test auc & aupr: 0.9453007534921735 0.6290477329946483
step 1025 : Total_loss & DTIloss & L2_loss: 125885.14 , 555.5535888671875 , 34388.42578125
Valid auc & aupr: 0.9499802466055352 0.6064079560296294 ;   Test auc & aupr: 0.9442123250014037 0.6297893266078727
step 1050 : Total_loss & DTIloss & L2_loss: 123617.96 , 547.3773193359375 , 34503.7265625
Valid auc & aupr: 0.9501879171605017 0.610184831437998 ;   Test auc & aupr: 0.942691810237262 0.6307927059132348
step 1075 : Total_loss & DTIloss & L2_loss: 121594.14 , 539.922119140625 , 34629.9375
Valid auc & aupr: 0.9494853720915726 0.612253865584599 ;   Test auc & aupr: 0.9436285089402134 0.6318128538030259
step 1100 : Total_loss & DTIloss & L2_loss: 119648.34 , 532.4805297851562 , 34756.1171875
Valid auc & aupr: 0.9500813527956251 0.613532974879713 ;   Test auc & aupr: 0.9444177182959317 0.632006108076364
step 1125 : Total_loss & DTIloss & L2_loss: 117901.57 , 526.2767944335938

step 2025 : Total_loss & DTIloss & L2_loss: 90722.6 , 413.31298828125 , 38243.51953125
Valid auc & aupr: 0.939729014617512 0.6245010432052464 ;   Test auc & aupr: 0.941973177409896 0.6413531134679872
step 2050 : Total_loss & DTIloss & L2_loss: 90887.82 , 411.367431640625 , 38304.8046875
Valid auc & aupr: 0.9392835235897115 0.6248338827926528 ;   Test auc & aupr: 0.941973177409896 0.6413531134679872
step 2075 : Total_loss & DTIloss & L2_loss: 90212.77 , 410.41314697265625 , 38386.44140625
Valid auc & aupr: 0.9392934002869441 0.6251386995534426 ;   Test auc & aupr: 0.941973177409896 0.6413531134679872
step 2100 : Total_loss & DTIloss & L2_loss: 89699.31 , 408.59014892578125 , 38449.2890625
Valid auc & aupr: 0.9392273823633378 0.6256384775305369 ;   Test auc & aupr: 0.9421394319252341 0.641536955755724
step 2125 : Total_loss & DTIloss & L2_loss: 89361.016 , 407.2244873046875 , 38509.7890625
Valid auc & aupr: 0.938432308236126 0.6259351443336132 ;   Test auc & aupr: 0.9432549602431888 0.64

Valid auc & aupr: 0.9370849188031521 0.6365789883607308 ;   Test auc & aupr: 0.9509880616487925 0.6488939603373943
step 3050 : Total_loss & DTIloss & L2_loss: 81392.45 , 370.5084533691406 , 40261.07421875
Valid auc & aupr: 0.9370628261909218 0.6358518896982089 ;   Test auc & aupr: 0.9509880616487925 0.6488939603373943
step 3075 : Total_loss & DTIloss & L2_loss: 81676.516 , 369.957275390625 , 40289.08203125
Valid auc & aupr: 0.9371907034287733 0.6373113710521761 ;   Test auc & aupr: 0.9507641956303381 0.6486541835533259
step 3100 : Total_loss & DTIloss & L2_loss: 81129.37 , 369.33465576171875 , 40316.96875
Valid auc & aupr: 0.9372528226561038 0.637046941871287 ;   Test auc & aupr: 0.9507641956303381 0.6486541835533259
step 3125 : Total_loss & DTIloss & L2_loss: 80890.44 , 368.47027587890625 , 40342.6953125
Valid auc & aupr: 0.9378147547460128 0.6377131454580166 ;   Test auc & aupr: 0.9510131631074261 0.6500315109859675
step 3150 : Total_loss & DTIloss & L2_loss: 80817.875 , 367.77850341

100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 684.72it/s]
  Drug_Di_emb = th.tensor(Drug_Di_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:00<00:00, 718.78it/s]
  Drug_PDi_emb = th.tensor(Drug_PDi_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 637.26it/s]
  Drug_D2_emb = th.tensor(Drug_D2_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 637.50it/s]
  Drug_D4_emb = th.tensor(Drug_D4_emb).to(self.device)
100%|█████████████████████████████████████████████████████████████████████████████| 1512/1512 [00:01<00:00, 808.55it/s]
  Protein_Di_emb = th.tensor(Protein_Di_emb).to(self.device)
100%|█████████████████████████████████████████████████████████████████████████████| 1512/1512 [00:02<00:00, 730.08it/

step 0 : Total_loss & DTIloss & L2_loss: 2072838.1 , 1652.953857421875 , 120177.6015625
Valid auc & aupr: 0.485676278359243 0.0018717217381316277 ;   Test auc & aupr: 0.46628277031980137 0.0016560681113814556
step 25 : Total_loss & DTIloss & L2_loss: 1149187.8 , 4503.40625 , 117209.25
Valid auc & aupr: 0.40783210080218146 0.0015340973682548117 ;   Test auc & aupr: 0.46628277031980137 0.0016560681113814556
step 50 : Total_loss & DTIloss & L2_loss: 856912.44 , 2595.232666015625 , 108600.703125
Valid auc & aupr: 0.7267439493422395 0.0073346051058987075 ;   Test auc & aupr: 0.7523105527116983 0.005407601276134788
step 75 : Total_loss & DTIloss & L2_loss: 787311.5 , 2921.82958984375 , 97761.28125
Valid auc & aupr: 0.7477453995264001 0.023097227854910606 ;   Test auc & aupr: 0.7901428619055195 0.018724951760479162
step 100 : Total_loss & DTIloss & L2_loss: 757555.5 , 1629.2008056640625 , 87925.3046875
Valid auc & aupr: 0.8776645529764178 0.06850157170774614 ;   Test auc & aupr: 0.89077819200

Valid auc & aupr: 0.9340071062121607 0.48465312350282036 ;   Test auc & aupr: 0.9259241820921222 0.5733384996480296
step 1025 : Total_loss & DTIloss & L2_loss: 127185.48 , 531.2709350585938 , 34184.8203125
Valid auc & aupr: 0.9344288411846992 0.48882185832056374 ;   Test auc & aupr: 0.9245526676368016 0.575104635803664
step 1050 : Total_loss & DTIloss & L2_loss: 125486.17 , 523.511474609375 , 34323.36328125
Valid auc & aupr: 0.9349481549205585 0.4907550138475369 ;   Test auc & aupr: 0.9238479259079028 0.5771625465651871
step 1075 : Total_loss & DTIloss & L2_loss: 123345.97 , 517.1658935546875 , 34475.80078125
Valid auc & aupr: 0.9358854420893273 0.4928586183507336 ;   Test auc & aupr: 0.9233750924123604 0.5784919939193343
step 1100 : Total_loss & DTIloss & L2_loss: 121345.14 , 511.1755676269531 , 34608.93359375
Valid auc & aupr: 0.93386660164366 0.49616875049591974 ;   Test auc & aupr: 0.9216622468603194 0.5808956058838873
step 1125 : Total_loss & DTIloss & L2_loss: 119514.766 , 505.10

step 2025 : Total_loss & DTIloss & L2_loss: 91581.51 , 404.3831787109375 , 38201.6875
Valid auc & aupr: 0.9305219291122369 0.5192992236025927 ;   Test auc & aupr: 0.9373548989468891 0.6037177919451784
step 2050 : Total_loss & DTIloss & L2_loss: 91295.375 , 402.50347900390625 , 38262.125
Valid auc & aupr: 0.9288482652443032 0.5192514142063505 ;   Test auc & aupr: 0.9373548989468891 0.6037177919451784
step 2075 : Total_loss & DTIloss & L2_loss: 91254.92 , 400.85198974609375 , 38329.03515625
Valid auc & aupr: 0.9300707656237761 0.5207254333856691 ;   Test auc & aupr: 0.9372840299743584 0.6063333885880817
step 2100 : Total_loss & DTIloss & L2_loss: 90800.2 , 399.6393127441406 , 38408.828125
Valid auc & aupr: 0.9302278652200523 0.5199511487692446 ;   Test auc & aupr: 0.9372840299743584 0.6063333885880817
step 2125 : Total_loss & DTIloss & L2_loss: 90216.97 , 398.41143798828125 , 38477.19921875
Valid auc & aupr: 0.9306146400007436 0.5212178212168842 ;   Test auc & aupr: 0.9376316486209113 0.

100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:00<00:00, 797.48it/s]
  Drug_Di_emb = th.tensor(Drug_Di_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 607.20it/s]
  Drug_PDi_emb = th.tensor(Drug_PDi_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:00<00:00, 834.91it/s]
  Drug_D2_emb = th.tensor(Drug_D2_emb).to(self.device)
100%|███████████████████████████████████████████████████████████████████████████████| 708/708 [00:01<00:00, 638.99it/s]
  Drug_D4_emb = th.tensor(Drug_D4_emb).to(self.device)
100%|█████████████████████████████████████████████████████████████████████████████| 1512/1512 [00:01<00:00, 803.99it/s]
  Protein_Di_emb = th.tensor(Protein_Di_emb).to(self.device)
100%|█████████████████████████████████████████████████████████████████████████████| 1512/1512 [00:02<00:00, 719.15it/

step 0 : Total_loss & DTIloss & L2_loss: 2072431.2 , 1646.767578125 , 120158.90625
Valid auc & aupr: 0.5381881863073004 0.0020207215320403715 ;   Test auc & aupr: 0.5272051070891589 0.002125988426029931
step 25 : Total_loss & DTIloss & L2_loss: 1142238.5 , 4365.8857421875 , 117235.25
Valid auc & aupr: 0.749767253668448 0.007453578179105176 ;   Test auc & aupr: 0.730256094341712 0.006423455905142364
step 50 : Total_loss & DTIloss & L2_loss: 849404.25 , 2041.3812255859375 , 108906.9296875
Valid auc & aupr: 0.8661674771450116 0.04524836840272414 ;   Test auc & aupr: 0.8643199386101181 0.06772493804009601
step 75 : Total_loss & DTIloss & L2_loss: 801671.0 , 4957.634765625 , 99119.6015625
Valid auc & aupr: 0.8210521436964326 0.0248433775368562 ;   Test auc & aupr: 0.8643199386101181 0.06772493804009601
step 100 : Total_loss & DTIloss & L2_loss: 758157.1 , 1821.94921875 , 88998.234375
Valid auc & aupr: 0.8656114163453564 0.03146139238154334 ;   Test auc & aupr: 0.8643199386101181 0.067724938

In [None]:
import time
import torch as th
import numpy as np
import matplotlib.pyplot as plt
import copy
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.metrics import average_precision_score, precision_recall_curve

'''network_path = 'MSCMF/'
fpr_o = np.loadtxt(network_path + 'fpr.csv', delimiter=',')
tpr_o = np.loadtxt(network_path + 'tpr.csv', delimiter=',')
recall_o = np.loadtxt(network_path + 'recall.csv', delimiter=',')
precision_o = np.loadtxt(network_path + 'precision.csv', delimiter=',')'''

network_path = ''
fpr = np.loadtxt(network_path +'fpr_.csv', delimiter=',')
tpr = np.loadtxt(network_path +'tpr_.csv', delimiter=',')
recall = np.loadtxt(network_path +'recall_.csv', delimiter=',')
precision = np.loadtxt(network_path +'precision_.csv', delimiter=',')

'''roc, prc = [], []
auc, aupr = [], []
color = ['blue', 'orange', 'gray', 'gold', 'royalblue']
label = ['MSCMF', 'TL_HGBI', 'DTINet', 'NeoDTI', 'GADTI']
for i in range(5):
    roc.append(np.loadtxt('roc' + str(i + 1) + '.csv'))
    prc.append(np.loadtxt('prc' + str(i + 1) + '.csv'))
for i in range(5):
    for x, y in roc[i]:
        test_auc = 0.
        prev_x = 0.
        if x != prev_x:
            test_auc += (x - prev_x) * y
            prev_x = x
        auc.append(test_auc)
    for x, y in prc[i]:
        test_aupr = 0.
        prev_x = 0.
        if x != prev_x:
            test_aupr += (x - prev_x) * y
            prev_x = x
        aupr.append(test_aupr)
plt.title("ROC curve")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
for i in range(5):
    x = [x_[0] for x_ in roc[i]]
    y = [x_[1] for x_ in roc[i]]
    plt.plot(x, y, linewidth='1', label=label[i], color=color[i])
    # plt.plot(x, y, linewidth='1', label="test", color=' coral ', linestyle=':', marker='|')
plt.legend()
plt.show()
plt.title("PR curve")
plt.xlabel("Recall")
plt.ylabel("Precision")
for i in range(5):
    x = [x_[0] for x_ in prc[i]]
    y = [x_[1] for x_ in prc[i]]
    plt.plot(x, y, linewidth='1', label=label[i], color=color[i])
plt.legend()
plt.show()'''

'''fpr = []
tpr = []
recall=[]
precision=[]
for i in (range(len(fpr_o))):
    if i % 10== 0:
        fpr.append(fpr_o[i])
        tpr.append(tpr_o[i])
for i in (range(len(recall_o))):
    if i % 10== 0:
        recall.append(recall_o[i])
        precision.append(precision_o[i])
np.savetxt('fpr1.csv', fpr, fmt='%-.4f', delimiter=',')
np.savetxt('tpr1.csv', tpr, fmt='%-.4f', delimiter=',')
np.savetxt('recall1.csv', recall, fmt='%-.4f', delimiter=',')
np.savetxt('precision1.csv', precision, fmt='%-.4f', delimiter=',')'''

auc = 0.
prev_x = 0.
for i in range(len(fpr)):
    if fpr[i] != prev_x:
        auc += (fpr[i] - prev_x) * tpr[i]
        prev_x = fpr[i]

aupr = 0.
prev_x = 0.
for i in range(len(recall)):
    if recall[i] != prev_x:
        aupr += (recall[i] - prev_x) * precision[i]
        prev_x = recall[i]

plt.title("ROC curve of %s (AUC = %.4f)" % ('GADTI', auc))
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.plot(fpr, tpr, 'r')
plt.show()

plt.title("PR curve of %s (AUPR = %.4f)" % ('GADTI', aupr))
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.plot(recall, precision)
plt.show()