In [None]:
import os
import math
import dgl
import time
import torch
import random
import argparse
import numpy as np
import torch.nn as nn
import dgl.function as fn
import scipy.sparse as sp
import torch.nn.functional as F
from dgl.utils import expand_as_pair
from dgl.nn.pytorch import edge_softmax
from sklearn.model_selection import KFold
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import roc_curve, roc_auc_score, auc, precision_recall_curve

In [None]:
# Run configuration, collected in an argparse namespace (`params`).
parser= argparse.ArgumentParser(description= 'EHGN4MDA')

# Relation ids: drug_micr_rel, 0; drug_dise_rel, 1; micr_dise_rel, 2; drug_inter_rel, 3;
# micr_inter_rel, 4; micr_drug_rel, 5; dise_drug_rel, 6; dise_micr_rel, 7.

# One (flag, type, default) row per option; options are registered in exactly this order.
_option_table= [
    # model / optimisation hyper-parameters
    ('--root-path', str, os.path.abspath('..')),
    ('--seed', int, 0),
    ('--device', str, 'cuda:0'),
    ('--hd4gat', int, 128),
    ('--hd4tf', int, 64),
    ('--layers_num', int, 2),
    ('--head4gat', int, 4),
    ('--head4transformer', int, 8),
    ('--feat_dropout', float, 0.5),
    ('--attn_dropout', float, 0.5),
    ('--decoder', str, 'mlp'),
    ('--lr', float, 1e-4),
    ('--weight_decay', float, 5e-4),
    ('--patience', int, 6),
    ('--epochs', int, 300),
    ('--batch_size', int, 128),
    ('--kflod_num', int, 5),
    ('--lbda', int, 4),
    ('--wr_time', int, 2),
    ('--i', int, 0),
    ('--threshold', float, 0.8),
    # entity counts
    ('-num_ent', int, 1209+ 172+ 154),
    ('-num_drug', int, 1209),
    ('-num_micr', int, 172),
    ('-num_dise', int, 154),
    # input files
    ('-drug_name_path', str, '../mdd/drug/drug_name.txt'),
    ('-micr_name_path', str, '../mdd/microbe/microbe_name.txt'),
    ('-dise_name_path', str, '../mdd/disease/disease_name.txt'),
    ('-drug_micr_adj_path', str, '../mdd/adj/microbe_drug_adj.txt'),
    ('-drug_struct_simi_path', str, '../mdd/drug/drug_struct_simi.txt'),
    ('-drug_inter_adj_path', str, '../mdd/drug/drug_interact_adj.txt'),
    ('-drug_fringer_simi_path', str, '../mdd/drug/drug_fringer_simi.txt'),
    ('-drug_dise_adj_path', str, '../mdd/adj/drug_disease_adj.txt'),
    ('-micr_ani_path', str, '../mdd/microbe/microbe_ani_simi.txt'),
    ('-micr_inter_adj_path', str, '../mdd/microbe/microbe_interact_adj.txt'),
    ('-micr_dise_adj_path', str, '../mdd/adj/microbe_disease_adj.txt'),
    ('-micr_gene_simi_path', str, '../mdd/microbe/microbe_gene_simi.txt'),
    ('-dise_simi_path', str, '../mdd/disease/disease_dag_simi.txt'),
    # split ratios
    ('-train_ratio', float, 0.8),
    ('-valid_ratio', float, 0.1),
    ('-test_ratio', float, 0.1),
    # output files
    ('-pt_file', str, 'checkpoint/'),
    ('-memo_file4ngmda', str, 'memo/ngmda.txt'),
    ('-pt_file_name', str, 'ngmda.pt'),
    ('-test_result_file', str, 'result/ngmda_test_result.txt'),
]
for _flag, _type, _default in _option_table:
    parser.add_argument(_flag, type= _type, default= _default)

# Parse an explicit empty argv so that only the defaults above are used (sys.argv is not read).
params= parser.parse_args([])

In [None]:
# Seed every random number generator used by the notebook from the single configured seed.
_seed= params.seed
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(_seed)
# Exported for child processes; the hash seed of the already-running interpreter is fixed at start-up.
os.environ['PYTHONHASHSEED'] = str(_seed)
# Ask cuDNN to pick deterministic kernels.
torch.backends.cudnn.deterministic = True

In [None]:
class dataloader(object):
    """Load the drug/microbe matrices, split the drug-microbe pairs and build the model inputs.

    After construction the instance exposes:
      * drug_micr_asso_mat      - full 0/1 drug x microbe association matrix
      * drug_micr_asso_mat_zy   - the same matrix with validation/test pairs zeroed out
      * train_xy / valid_xy / test_xy - (n, 2) long tensors of (drug, microbe) index pairs
      * hete_graph01_mat        - 0/1 adjacency of the joint drug+microbe graph
      * fea_mats                - [drug sim, microbe sim, drug 0/1 adj, microbe 0/1 adj, masked associations]
      * topo_fea                - per-node return probabilities from acquire_topo_fea
      * g, e_feat               - DGL graph (with self loops) and the relation id of every edge
    """

    def __init__(self, params):
        super().__init__()
        self.params= params
        self.drug_micr_asso_mat, self.drug_struct_simi_mat= self.load_drug_data()
        # split_dataset draws from torch's global RNG; keep it at this position so a seeded
        # run produces the same split.
        self.train_xy, self.valid_xy, self.test_xy= self.split_dataset()
        self.drug_micr_asso_mat_zy= self.get_asso_mat_zy()
        self.micr_ani_mat, self.micr_gene_simi_mat, self.micr_inte_simi_mat= self.load_micr_data()
        # Similarities are binarised with the configured threshold to obtain intra-type edges.
        drug_adj= self.drug_struct_simi_mat>= self.params.threshold
        micr_adj= self.micr_inte_simi_mat>= self.params.threshold
        self.hete_graph01_mat= torch.cat([torch.cat([drug_adj, self.drug_micr_asso_mat_zy], dim= 1),
                                          torch.cat([self.drug_micr_asso_mat_zy.T, micr_adj], dim= 1)], dim= 0).float()
        self.fea_mats= [self.drug_struct_simi_mat, self.micr_inte_simi_mat, drug_adj, micr_adj, self.drug_micr_asso_mat_zy]
        self.topo_fea= self.acquire_topo_fea(self.hete_graph01_mat, self.params.wr_time)
        self.g, self.e_feat= self.get_hete_graph(self.hete_graph01_mat)

    def get_hete_graph(self, adj_mat):
        """Build the DGL graph of `adj_mat` (plus self loops) and label every edge with its relation.

        Relation ids: 0 drug->drug, 1 drug->microbe, 2 microbe->drug, 3 microbe->microbe, where
        node ids below `params.num_drug` are drugs.

        Returns:
            (g, e_feat): the graph and a long tensor with one relation id per edge of `g`.
        """
        n_node, n_drug= adj_mat.shape[0], self.params.num_drug
        x_y= adj_mat.nonzero()
        x, y= x_y[:, 0].cpu().numpy(), x_y[:, 1].cpu().numpy()
        adj= sp.coo_matrix((np.ones(len(x), dtype= np.float32), (x, y)), shape= (n_node, n_node)).tocsr()
        # NOTE(review): constructing DGLGraph directly from a scipy matrix is kept as in the
        # original; confirm it is still supported by the installed DGL version.
        g= dgl.DGLGraph(adj)
        g= dgl.add_self_loop(g)
        # Relation id of every edge, computed for all edges at once: 2 * [src is microbe] + [dst is microbe].
        u, v= g.edges()
        e_feat= (2* (u>= n_drug).long()+ (v>= n_drug).long()).to(torch.long)
        return g, e_feat

    def acquire_topo_fea(self, ass_mat, k):
        """Return an (n_nodes, k) tensor whose i-th column is the diagonal of the i-th walk matrix.

        The walk matrix starts as `ass_mat @ D^-1` (D = diagonal of row sums) and is multiplied
        by itself after every column.
        NOTE(review): squaring yields powers 1, 2, 4, ... of the transition matrix rather than
        1, 2, 3, ...; both coincide for k <= 2 (the default wr_time). Confirm this is intended.
        `Tensor.inverse()` fails if any node has a zero row sum.
        """
        topo_fea= torch.zeros((ass_mat.shape[0], k))
        d_mat= torch.diag(ass_mat.sum(dim= 1)).inverse()
        topo_mat= torch.matmul(ass_mat, d_mat)
        for i in range(k):
            topo_fea[:, i]= torch.diag(topo_mat)
            topo_mat= torch.matmul(topo_mat, topo_mat)
        return topo_fea

    # @introduce data
    def introduce(self):
        """Print the number of non-zero entries of every association/interaction matrix that is loaded.

        Matrices that this loader never assigned are reported as 'not loaded' instead of raising
        AttributeError.
        """
        stats= [('Drug microbe association num', 'drug_micr_asso_mat'),
                ('Drug interaction num', 'drug_inter_mat'),
                ('Microbe interaction num', 'micr_inter_mat'),
                ('Drug disease association num', 'drug_dise_asso_mat'),
                ('Microbe disease association num', 'micr_dise_asso_mat')]
        for label, attr in stats:
            mat= getattr(self, attr, None)
            print(f'{label}: {mat.sum() if mat is not None else "not loaded"}')

    # @ mask
    def get_asso_mat_zy(self):
        """Return a copy of the association matrix with validation and test pairs set to 0."""
        asso_mat_zy= self.drug_micr_asso_mat.clone()
        asso_mat_zy[self.valid_xy[:, 0], self.valid_xy[:, 1]]= 0
        asso_mat_zy[self.test_xy[:, 0], self.test_xy[:, 1]]= 0
        return asso_mat_zy

    # @split data set
    def split_dataset(self):
        """Randomly assign every (drug, microbe) pair to train/valid/test.

        The first known association of each drug always goes to the training set; every other
        pair (positive or negative) is assigned by one uniform draw against train_ratio and
        train_ratio + valid_ratio.

        Returns:
            Three (n, 2) long tensors (an empty split has shape (0, 2)).
        """
        train_xy, valid_xy, test_xy= [], [], []
        train_cut= self.params.train_ratio
        valid_cut= self.params.train_ratio+ self.params.valid_ratio
        for i in range(self.params.num_drug):
            first= True
            for j in range(self.params.num_micr):
                if self.drug_micr_asso_mat[i, j]== 1 and first:
                    train_xy.append([i, j])
                    first= False
                else:
                    # One draw per pair, in row-major order (the split depends on this order).
                    num= torch.rand(1)
                    if num< train_cut:
                        train_xy.append([i, j])
                    elif num< valid_cut:
                        valid_xy.append([i, j])
                    else:
                        test_xy.append([i, j])
        print(f'Spliting data has finished...')
        # reshape(-1, 2) keeps the (n, 2) layout even when a split is empty.
        return tuple(torch.tensor(pairs, dtype= torch.long).reshape(-1, 2) for pairs in (train_xy, valid_xy, test_xy))

    # @load disease data
    def load_dise_data(self):
        """Load the disease similarity matrix and two cosine-similarity matrices.

        NOTE(review): this method reads `self.drug_inter_mat`, which nothing in this class
        assigns, so it raises AttributeError unless the caller sets that attribute first
        (`self.micr_inter_mat` is set by load_micr_data). It is not called by __init__.
        """
        dise_simi_mat= torch.from_numpy(np.loadtxt(self.params.dise_simi_path, encoding= 'utf-8-sig'))
        drug_unit= nn.functional.normalize(self.drug_inter_mat, p= 2, dim= 1)
        micr_unit= nn.functional.normalize(self.micr_inter_mat, p= 2, dim= 1)
        drug_dise_drug_simi_mat= torch.matmul(drug_unit, drug_unit.T)
        micr_dise_micr_simi_mat= torch.matmul(micr_unit, micr_unit.T)
        # Self-similarity is forced to 1 (all-zero rows would otherwise give 0).
        drug_idx, micr_idx= torch.arange(self.params.num_drug), torch.arange(self.params.num_micr)
        drug_dise_drug_simi_mat[drug_idx, drug_idx]= 1.0
        micr_dise_micr_simi_mat[micr_idx, micr_idx]= 1.0
        return dise_simi_mat, drug_dise_drug_simi_mat, micr_dise_micr_simi_mat

    # @load micr data
    def load_micr_data(self):
        """Load the microbe similarity matrices and the microbe interaction/disease adjacency.

        Side effects: stores `self.micr_inter_mat` and `self.micr_dise_asso_mat`.

        Returns:
            (ani_mat, gene_simi_mat, merged) where `merged` takes the ANI value where it is
            positive and the gene similarity elsewhere.
        """
        n_micr, n_dise= self.params.num_micr, self.params.num_dise
        micr_ani_mat= torch.from_numpy(np.loadtxt(self.params.micr_ani_path, encoding= 'utf-8-sig'))
        micr_gene_simi_mat= torch.from_numpy(np.loadtxt(self.params.micr_gene_simi_path, encoding= 'utf-8-sig'))
        self.micr_inter_mat= self.load_adj_data(self.params.micr_inter_adj_path, sp= (n_micr, n_micr))
        self.micr_dise_asso_mat= self.load_adj_data(self.params.micr_dise_adj_path, sp= (n_micr, n_dise))
        diag= torch.arange(n_micr)
        micr_ani_mat[diag, diag]= 1
        return micr_ani_mat, micr_gene_simi_mat, torch.where(micr_ani_mat> 0, micr_ani_mat, micr_gene_simi_mat)

    # @load drug data
    def load_drug_data(self):
        """Return (drug-microbe association matrix, drug structure similarity matrix)."""
        drug_micr_asso_mat= self.load_adj_data(self.params.drug_micr_adj_path, sp= (self.params.num_drug, self.params.num_micr))
        drug_struct_simi_mat= torch.from_numpy(np.loadtxt(self.params.drug_struct_simi_path, encoding= 'utf-8-sig'))
        return drug_micr_asso_mat, drug_struct_simi_mat

    # @load adj data
    def load_adj_data(self, path, sp= (1209, 172)):
        """Read an index-pair file and return the dense 0/1 matrix of shape `sp`.

        The first two columns of the file are treated as 1-based (row, column) indices.
        `ndmin= 2` keeps a file with a single pair two-dimensional.
        """
        idx= torch.from_numpy(np.loadtxt(path, encoding= 'utf-8-sig', ndmin= 2)).long()- 1
        mat= torch.zeros((sp[0], sp[1]))
        mat[idx[:, 0], idx[:, 1]]= 1
        return mat

    # @NRWR, neighborhood random walk with restart
    def random_walk_root(self, A, alpha, epoch, nei, eps, err):
        """Neighbourhood-masked random walk with restart on adjacency `A`.

        Each iteration computes `alpha * S @ W + (1 - alpha) * I` (W = L1 row-normalised A),
        masks it with the entries where `A^nei - A^(nei-1) > 0`, and L1 row-normalises it.
        Iteration stops after `epoch` rounds or once the total absolute change is at most
        `err * n^2`. `eps` is accepted for interface compatibility and is not used.

        Returns:
            The last computed S; the initial identity matrix when `epoch` is 0.
        """
        n= A.shape[0]
        Mask_mat= (torch.matrix_power(A, nei)- torch.matrix_power(A, nei- 1))> 0
        W= nn.functional.normalize(A, p= 1, dim= 1)
        S_last= torch.eye(n)
        for i in range(epoch):
            S_new= alpha* torch.matmul(S_last, W)+ (1- alpha)* torch.eye(n)
            S_new= nn.functional.normalize(Mask_mat* S_new, p= 1, dim= 1)
            converged= (S_new- S_last).abs().sum()<= err* n** 2
            S_last= S_new
            if converged:
                print('converge...')
                break
        return S_last

    # write into memo
    def write2memo(self, mr, mrr, hits10):
        """Append one tab-separated line (lr, weight_decay, mr, mrr, hits10) to the memo file.

        Uses the options that the configuration actually defines (`memo_file4ngmda`, `lr`,
        `weight_decay`) and creates the memo directory if it does not exist.
        """
        memo_path= self.params.memo_file4ngmda
        os.makedirs(os.path.dirname(memo_path) or '.', exist_ok= True)
        with open(memo_path, 'a+') as f:
            f.write(f'{self.params.lr}\t{self.params.weight_decay}\t{mr}\t{mrr}\t{hits10}\n')

In [None]:
class HGNConv(nn.Module):
	"""Multi-head graph attention layer over the joint drug/microbe graph.

	Attention on an edge combines a content score, a location/topology cosine similarity
	between the two endpoints, and a learnable weight per edge type. The layer propagates
	three node signals with the same attention weights: content features, location
	embeddings (`sloc`) and topology features (`topo_fea`).

	NOTE(review): `forward` splits node rows at the literal index 1209 (rows before it use
	projection 0, the rest projection 1) - presumably the number of drugs; confirm against
	the caller.
	"""
	def __init__(self, in_feats, out_feats, num_heads, feat_drop= 0., attn_drop= 0.):
		super(HGNConv, self).__init__()
		self._num_heads, self._out_feats, self.node_type_num= num_heads, out_feats, 2
		# One input projection per node type, producing all heads at once.
		self.fc4node_proj= nn.ModuleList([nn.Linear(in_feats, out_feats* num_heads, bias= False) for i in range(self.node_type_num)])
		# One (out_feats -> out_feats) map per head, used to build the destination attention vector.
		self.fc4prob= nn.ModuleList([nn.Linear(out_feats, out_feats, bias= False) for i in range(self._num_heads)])
		# Learnable attention weight per edge type; there are four relations: d2d, d2m, m2d, m2m.
		self.etype_attn= nn.Parameter(torch.tensor([1, 1, 1, 1]).to(torch.float))
		self.feat_drop4cont, self.feat_drop4loc, self.attn_drop= nn.Dropout(feat_drop), nn.Dropout(feat_drop), nn.Dropout(attn_drop)
		# fc 4 residual
		self.fc4res= nn.ModuleList([nn.Linear(in_feats, out_feats* num_heads, bias= False) for i in range(self.node_type_num)])
		# tao weights the location similarity against the topology similarity; tao = 1 means no topology term.
		self.relu, self.tao= nn.ReLU(), 0.5
		self.reset_parameters()

	# Initialise the weights.
	def reset_parameters(self):
		"""Xavier-normal initialisation; the per-head attention maps use the tanh gain."""
		for fc in self.fc4node_proj: nn.init.xavier_normal_(fc.weight)
		for fc in self.fc4res: nn.init.xavier_normal_(fc.weight)
		for fc in self.fc4prob: nn.init.xavier_normal_(fc.weight, gain= nn.init.calculate_gain('tanh'))

	# Compute the (un-normalised) attention of every edge.
	def edge_attention(self, edges):
		"""DGL edge UDF returning {'e': score} for every edge in the batch."""
		src_feat_nrom, src_sloc, dst_sloc, src_tp_fea, dst_tp_fea= edges.src['ft_norm'], edges.src['sloc'], edges.dst['sloc'], edges.src['topo_fea'], edges.dst['topo_fea']
		# Content score: dot product of the destination attention vector and the normalised source feature, per head.
		a= torch.mul(edges.dst['ft_attn'], src_feat_nrom).sum(dim= 2, keepdim= True)
		# Blend of location and topology cosine similarity, shifted and halved ((x + 1) / 2).
		sloc_sim= ((self.tao* self.cosine(src_sloc, dst_sloc)+ (1- self.tao)* self.cosine(src_tp_fea, dst_tp_fea))+ 1)/ 2
		# sloc_sim= 1
		# The per-edge type weight is repeated across heads before scaling the score.
		return {'e': (sloc_sim* a)* edges.data['e_type_attn'].unsqueeze(-1).unsqueeze(-1).repeat(1, a.shape[1], 1)}

	# Cosine similarity.
	def cosine(self, a, b):
		"""Cosine similarity of `a` and `b` along dim 2 (kept as a size-1 dimension)."""
		return torch.mul(F.normalize(a, p= 2, dim= 2), F.normalize(b, p= 2, dim= 2)).sum(dim= 2, keepdim= True)

	def forward(self, graph, feat, sloc, topo_fea):
		"""Run one attention layer.

		Args:
			graph: DGL graph whose edge data contains 'e_feat' (edge-type ids used to index `etype_attn`).
			feat: node content features, one row per node.
			sloc: node location embeddings, one row per node.
			topo_fea: node topology features, one row per node.

		Returns:
			(features, sloc, topo_fea): the updated content features (ReLU of message + residual,
			averaged over heads) and the updated location / topology signals (aggregated value
			plus the dropped-out input, averaged over heads).
		"""
		with graph.local_scope():
			# A head axis of size 1 is inserted so these signals broadcast against the per-head attention.
			sloc= self.feat_drop4loc(sloc.unsqueeze(1))
			topo_fea= self.feat_drop4loc(topo_fea.unsqueeze(1))
			h_src= h_dst= self.feat_drop4cont(feat)
			# feat_src, feat_dst: (N, num_heads, out_feats); rows [0, 1209) use projection 0, the rest projection 1.
			feat_src= feat_dst= torch.cat([self.fc4node_proj[0](h_src[0: 1209, :]), self.fc4node_proj[1](h_src[1209: , :])], dim= 0).view(-1, self._num_heads, self._out_feats)
			# Per head: softmax(tanh(fc(feature))) over the feature axis; heads are concatenated and viewed back to (N, num_heads, out_feats).
			feat_dst_attn= torch.cat([F.softmax(torch.tanh(self.fc4prob[i](feat_dst[:, i, :])), dim= 1) for i in range(self._num_heads)], dim= 1).view(-1, self._num_heads, self._out_feats)
			graph.edata.update({'e_type_attn': self.etype_attn[graph.edata['e_feat']]})
			# 'ft_norm' is the projected feature L2-normalised along dim 0 (i.e. across nodes).
			graph.srcdata.update({'ft_norm': F.normalize(feat_src, p= 2, dim= 0), 'ft': feat_src, 'sloc': sloc, 'topo_fea': topo_fea})
			# 'ft_attn': (N, num_heads, out_feats).
			graph.dstdata.update({'ft_attn': feat_dst_attn, 'sloc': sloc, 'topo_fea': topo_fea})
			# Compute the raw attention on every edge, then normalise it with edge_softmax and apply dropout.
			graph.apply_edges(self.edge_attention)
			graph.edata['a']= self.attn_drop(edge_softmax(graph, graph.edata.pop('e')))
			# Message passing: the same attention weights aggregate all three signals.
			graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
			graph.update_all(fn.u_mul_e('sloc', 'a', 'm2'), fn.sum('m2', 'sloc'))
			graph.update_all(fn.u_mul_e('topo_fea', 'a', 'm3'), fn.sum('m3', 'topo_fea'))
			# Aggregated values read back from the destination nodes.
			rst= graph.dstdata['ft']
			sloc2= graph.dstdata['sloc']
			topo_fea2= graph.dstdata['topo_fea']
			# Residual connection: a separate per-node-type projection of the input features.
			resval= torch.cat([self.fc4res[0](h_dst[0: 1209]), self.fc4res[1](h_dst[1209:])], dim= 0).view(h_dst.shape[0], -1, self._out_feats)
			# Heads are averaged (mean over dim 1) for all three outputs.
			restult_fea= self.relu(rst+ resval).mean(1)
			sloc2= (sloc2+ sloc).mean(1)
			topo_fea= (topo_fea2+ topo_fea).mean(1)
			# Return the updated features.
			return restult_fea, sloc2, topo_fea

In [None]:
# transformer block
class block_tf(nn.Module):
    """Transformer block over the joint drug/microbe node set with node-type-specific projections.

    Drug rows (the first `num_drug` rows) are projected with Wq1/Wk1/Wv1 and microbe rows with
    Wq2/Wk2/Wv2. Every node then attends over all drugs and microbes with multi-head scaled
    dot-product attention, followed by a residual + LayerNorm and a position-wise feed-forward
    with a second residual + LayerNorm. The same W1/W2/W3 and LayerNorm serve both node types.

    The projections take `node_dim + edge_dim` inputs; the extra `edge_dim` columns are a
    relation-embedding slot that is filled with zeros.

    Args:
        node_dim: width of the node features (input and output).
        edge_dim: width of the zero-filled relation-embedding slot.
        heads: number of attention heads; must divide `node_dim`.
        num_drug: number of leading rows of `node_mat` that are drugs.

    Raises:
        ValueError: if `node_dim` is not divisible by `heads`.
    """

    def __init__(self, node_dim= 64, edge_dim= 64, heads= 2, num_drug= 1209):
        super().__init__()
        if node_dim% heads!= 0:
            raise ValueError(f'node_dim ({node_dim}) must be divisible by heads ({heads})')
        self.node_dim, self.edge_dim, self.heads, self.num_drug= node_dim, edge_dim, heads, num_drug
        # Q, K, V projections: set 1 for drug rows, set 2 for microbe rows.
        [self.Wq1, self.Wk1, self.Wv1]= [nn.Linear(self.node_dim+ self.edge_dim, self.node_dim) for _ in range(3)]
        [self.Wq2, self.Wk2, self.Wv2]= [nn.Linear(self.node_dim+ self.edge_dim, self.node_dim) for _ in range(3)]
        # Feature fusion: W1 mixes the attention output, W2/W3 form the feed-forward part.
        self.W1, self.W2, self.W3= nn.Linear(self.node_dim, self.node_dim), nn.Linear(self.node_dim, self.node_dim), nn.Linear(self.node_dim, self.node_dim)
        self.relu= nn.ReLU()
        self.layer_norm= nn.LayerNorm(self.node_dim)
        # init parameters
        self.reset_parameters()

    # Initialise the weights.
    def reset_parameters(self):
        """Xavier-normal initialisation of all projection weights (W2 uses the ReLU gain)."""
        for proj in (self.Wq1, self.Wk1, self.Wv1, self.Wq2, self.Wk2, self.Wv2):
            nn.init.xavier_normal_(proj.weight)
        nn.init.xavier_normal_(self.W1.weight)
        nn.init.xavier_normal_(self.W2.weight, nn.init.calculate_gain('relu'))
        nn.init.xavier_normal_(self.W3.weight)

    def _split_heads(self, x):
        """(n, node_dim) -> (heads, n, node_dim // heads)."""
        return x.view(x.shape[0], self.heads, self.node_dim// self.heads).transpose(0, 1)

    def _pad_relation_slot(self, x):
        """Append `edge_dim` zero columns, created on the same device/dtype as `x`."""
        return torch.cat((x, x.new_zeros((x.shape[0], self.edge_dim))), dim= 1)

    def forward(self, node_mat, edge_type_mat):
        """Return the updated node features, same shape as `node_mat`.

        Args:
            node_mat: (N, node_dim) features, drug rows first.
            edge_type_mat: accepted for interface compatibility; not used (the relation slot
                is zero-filled).
        """
        drug_in= self._pad_relation_slot(node_mat[: self.num_drug])
        micro_in= self._pad_relation_slot(node_mat[self.num_drug:])
        # Because the relation slot is all zeros, the projection of a node does not depend on the
        # relation it takes part in, so each projection is computed once and shared by all relations.
        Q= torch.cat((self._split_heads(self.Wq1(drug_in)), self._split_heads(self.Wq2(micro_in))), dim= 1)
        K= torch.cat((self._split_heads(self.Wk1(drug_in)), self._split_heads(self.Wk2(micro_in))), dim= 1)
        V= torch.cat((self._split_heads(self.Wv1(drug_in)), self._split_heads(self.Wv2(micro_in))), dim= 1)
        # (heads, N, N): every node attends over all drugs and microbes; softmax over the keys.
        scores= torch.matmul(Q, K.transpose(1, 2))/ math.sqrt(Q.shape[-1])
        attn= F.softmax(scores, dim= 2)
        # (heads, N, d_head) -> (N, heads, d_head) -> (N, node_dim)
        mess= torch.matmul(attn, V).transpose(0, 1).reshape(node_mat.shape[0], -1)
        # Residual + norm around the attention output, then around the feed-forward part.
        fused= self.layer_norm(self.W1(mess)+ node_mat)
        return self.layer_norm(self.W3(self.relu(self.W2(fused)))+ fused)

In [None]:
# @
class Decoder(nn.Module):
    """Pairwise MLP scorer producing two logits per (left, right) node pair.

    Each input embedding is expected to carry 1381 leading "original" columns followed by the
    learned columns. Three towers are evaluated and concatenated (256 + 128 + 64 = 448):
      * seq4ori - on the summed original columns of both embedding sets,
      * seq41   - on the learned columns of the first embedding set (`in_dim1` per node),
      * seq42   - on the learned columns of the second embedding set (`in_dim2` per node).
    """

    def __init__(self, in_dim1, in_dim2):
        super().__init__()

        def tower(in_features, out_features):
            # Linear -> ReLU -> Dropout, twice, with a 512-wide hidden layer.
            return nn.Sequential(nn.Linear(in_features, 512), nn.ReLU(), nn.Dropout(0.5),
                                 nn.Linear(512, out_features), nn.ReLU(), nn.Dropout(0.5))

        self.seq41= tower(in_dim1* 2, 128)
        self.seq42= tower(in_dim2* 2, 64)
        self.seq4ori= tower(1381* 2, 256)
        self.seq4out= nn.Sequential(nn.Linear(448, 64), nn.ReLU(), nn.Dropout(0.5), nn.Linear(64, 2))
        self.reset_para()

    def reset_para(self):
        """Xavier-normal init: ReLU gain everywhere except the final logit layer (gain 1)."""
        relu_gain= nn.init.calculate_gain('relu')
        for branch in (self.seq41, self.seq42, self.seq4ori):
            for layer in branch:
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_normal_(layer.weight, gain= relu_gain)
        hidden, logits= self.seq4out[0], self.seq4out[3]
        nn.init.xavier_normal_(hidden.weight, gain= relu_gain)
        nn.init.xavier_normal_(logits.weight)

    def forward(self, left_emb1, right_emb1, left_emb2, right_emb2):
        """Return (batch, 2) logits for the given left/right embeddings of both embedding sets."""
        n_ori= 1381
        # Original columns of the two embedding sets are summed per side.
        left_ori= left_emb1[:, 0: n_ori]+ left_emb2[:, 0: n_ori]
        right_ori= right_emb1[:, 0: n_ori]+ right_emb2[:, 0: n_ori]
        ori_out= self.seq4ori(torch.cat((left_ori, right_ori), dim= 1))
        # Learned columns: one tower per embedding set.
        out1= self.seq41(torch.cat((left_emb1[:, n_ori:], right_emb1[:, n_ori:]), dim= 1))
        out2= self.seq42(torch.cat((left_emb2[:, n_ori:], right_emb2[:, n_ori:]), dim= 1))
        return self.seq4out(torch.cat((ori_out, out1, out2), dim= 1))

In [None]:
# ensemble heterogeneous model for microbe and drug associate prediction.
class EHGN4MDA(nn.Module):
    """Ensemble of a heterogeneous graph-attention branch and a transformer branch.

    Both branches start from auto-encoder embeddings of the drug/microbe matrices. The outputs of
    every layer of each branch are concatenated and scored pairwise by `Decoder`. Besides the pair
    logits, `forward` returns an auxiliary loss made of auto-encoder reconstruction terms and
    node-type classification terms.

    Args:
        g: DGL graph over drugs + microbes; `e_feat` is stored on it as edge data 'e_feat'.
        features_list: [drug similarity, microbe similarity, drug 0/1 adjacency,
            microbe 0/1 adjacency, drug-microbe associations] as tensors.
        e_feat: relation id of every edge of `g`.
        hd4gat / hd4tf: hidden width of the graph-attention / transformer branch.
        num_layers: number of layers in each branch.
        head4gat / head4tf: number of attention heads of each branch.
        feat_drop / attn_drop: dropout rates of the graph-attention layers.
        k: width of the topology feature passed to `forward` (used to size the decoder).
        decoder: accepted for interface compatibility; not used.

    NOTE(review): layer sizes are derived from `features_list`, but `HGNConv` and `Decoder`
    (and `block_tf`, which is built here with its default arguments) still assume 1209 drugs /
    1381 nodes internally, so other data sizes are not supported end-to-end.
    """

    def __init__(self, g, features_list, e_feat, hd4gat, hd4tf, num_layers, head4gat, head4tf, feat_drop, attn_drop, k, decoder= 'mlp'):
        super(EHGN4MDA, self).__init__()
        self.g= g
        self.g.edata['e_feat']= e_feat
        self.features_list= features_list
        self.gat_layers= nn.ModuleList()
        self.num_layers= num_layers
        # Entity counts come from the (square) similarity matrices.
        n_drug, n_micr= features_list[0].shape[0], features_list[1].shape[0]
        n_ent= n_drug+ n_micr
        self.num_drug= n_drug
        # AE 4 GAT
        self.encoder4sloc= nn.Sequential(nn.Linear(n_ent, 512), nn.ReLU(), nn.Linear(512, hd4gat), nn.ReLU())
        self.decoder4sloc= nn.Sequential(nn.Linear(hd4gat, 512), nn.ReLU(), nn.Linear(512, n_ent), nn.Sigmoid())
        self.encoder4micro= nn.Sequential(nn.Linear(n_micr, hd4gat), nn.ReLU())
        self.decoder4micro= nn.Sequential(nn.Linear(hd4gat, n_micr), nn.Sigmoid())
        self.encoder4drug= nn.Sequential(nn.Linear(n_drug, 512), nn.ReLU(), nn.Linear(512, hd4gat), nn.ReLU())
        self.decoder4drug= nn.Sequential(nn.Linear(hd4gat, 512), nn.ReLU(), nn.Linear(512, n_drug), nn.Sigmoid())
        self.decoder4classify= nn.Linear(hd4gat, 2)
        # AE 4 transformer
        self.encoder4drug2= nn.Sequential(nn.Linear(n_ent, 512), nn.ReLU(), nn.Linear(512, hd4tf), nn.ReLU())
        self.encoder4micro2= nn.Sequential(nn.Linear(n_ent, 512), nn.ReLU(), nn.Linear(512, hd4tf), nn.ReLU())
        self.decoder4drug2= nn.Sequential(nn.Linear(hd4tf, 512), nn.ReLU(), nn.Linear(512, n_ent), nn.Sigmoid())
        self.decoder4micro2= nn.Sequential(nn.Linear(hd4tf, 512), nn.ReLU(), nn.Linear(512, n_ent), nn.Sigmoid())
        self.decoder4classify2= nn.Linear(hd4tf, 2)
        # Edge-type embedding (zero-initialised); handed to the transformer layers.
        self.edge_type_emb= nn.Embedding(4, hd4tf)
        self.edge_type_emb.weight= nn.Parameter(torch.zeros((4, hd4tf)).to(torch.float))
        # transformer layers
        self.tf_layers= nn.ModuleList([block_tf(node_dim= hd4tf, edge_dim= hd4tf, heads= head4tf) for i in range(num_layers)])
        # gat layers
        for l in range(0, num_layers):
            self.gat_layers.append(HGNConv(hd4gat, hd4gat, head4gat, feat_drop, attn_drop))
        # Decoder layers. Per layer the graph branch contributes content (hd4gat) + location
        # (hd4gat) + topology (k) columns, so the topology width must be k, not a literal 2.
        self.decoder= Decoder((hd4gat* 2+ k)* (num_layers+ 1), (hd4tf)* (num_layers+ 1))
        # init parameters
        self.reset_parameters()

    # Initialise the weights.
    def reset_parameters(self):
        """Xavier-normal init of all auto-encoder and classifier weights, in a fixed order."""
        relu_gain, sigm_gain= nn.init.calculate_gain('relu'), nn.init.calculate_gain('sigmoid')
        # (layer, gain) pairs; the order is kept stable so a seeded run initialises identically.
        plan= [
            # encoder & decoder 4 gat
            (self.encoder4sloc[0], relu_gain), (self.encoder4sloc[2], relu_gain),
            (self.decoder4sloc[0], relu_gain), (self.decoder4sloc[2], sigm_gain),
            (self.encoder4drug[0], relu_gain), (self.encoder4drug[2], relu_gain),
            (self.decoder4drug[0], relu_gain), (self.decoder4drug[2], sigm_gain),
            (self.encoder4micro[0], relu_gain), (self.decoder4micro[0], sigm_gain),
            (self.decoder4classify, 1.0),
            # encoder & decoder 4 transformer
            (self.encoder4drug2[0], relu_gain), (self.encoder4drug2[2], relu_gain),
            (self.decoder4drug2[0], relu_gain), (self.decoder4drug2[2], sigm_gain),
            (self.encoder4micro2[0], relu_gain), (self.encoder4micro2[2], relu_gain),
            (self.decoder4micro2[0], relu_gain), (self.decoder4micro2[2], sigm_gain),
            (self.decoder4classify2, 1.0),
        ]
        for layer, gain in plan:
            nn.init.xavier_normal_(layer.weight, gain)

    def forward(self, left, right, topo_fea, loss_fun2, loss_fun1):
        """Score a batch of (drug, microbe) pairs.

        Args:
            left: 1-D tensor of drug node ids.
            right: 1-D tensor of microbe node ids in the joint numbering (microbe index + number of drugs).
            topo_fea: (n_nodes, k) topology features.
            loss_fun2: reconstruction loss, called as loss_fun2(prediction, target).
            loss_fun1: classification loss, called as loss_fun1(logits, class_ids).

        Returns:
            (aux_loss, logits): the summed auxiliary loss and the (batch, 2) pair logits.
        """
        feats, n_drug= self.features_list, self.num_drug
        # Joint matrices: 0/1 adjacency (ass_mat) and similarity-valued (ass_mat2), both with the
        # association block off the diagonal.
        ass_mat= torch.cat((torch.cat((feats[2], feats[-1]), dim= 1), torch.cat((feats[-1].T, feats[3]), dim= 1)), dim= 0).to(torch.float)
        ass_mat2= torch.cat((torch.cat((feats[0], feats[-1]), dim= 1), torch.cat((feats[-1].T, feats[1]), dim= 1)), dim= 0).to(torch.float)
        # Location encoding; sloc0 keeps the encoder output for the reconstruction loss.
        sloc= self.encoder4sloc(ass_mat)
        sloc0= sloc.clone()
        # Node features of the graph-attention branch.
        drug_emb, micro_emb= self.encoder4drug(feats[0]), self.encoder4micro(feats[1])
        h1= torch.cat((drug_emb, micro_emb), dim= 0)
        # Node features of the transformer branch.
        drug_emb2, micro_emb2= self.encoder4drug2(ass_mat2[0: n_drug]), self.encoder4micro2(ass_mat2[n_drug:])
        h2= torch.cat((drug_emb2, micro_emb2), dim= 0)
        emb1, emb2= [], []
        # Transformer branch: original columns + the output of every layer.
        emb2.append(torch.cat((ass_mat2, h2), dim= 1))
        for tf_layer in self.tf_layers:
            h2= tf_layer(h2, self.edge_type_emb)
            emb2.append(h2)
        # Graph-attention branch: original columns + (content, location, topology) of every layer.
        emb1.append(torch.cat((ass_mat, h1, sloc, topo_fea), dim= 1))
        for l in range(self.num_layers):
            h1, sloc, topo_fea= self.gat_layers[l](self.g, h1, sloc, topo_fea)
            emb1.append(torch.cat((h1, sloc, topo_fea), dim= 1))
        emb1, emb2= torch.cat(emb1, dim= 1), torch.cat(emb2, dim= 1)
        # Node-type targets (0 = drug, 1 = microbe), created on the model's device rather than
        # on a hard-coded CUDA device.
        dev= ass_mat.device
        drug_label= torch.zeros(left.shape[0], dtype= torch.long, device= dev)
        micr_label= torch.ones(right.shape[0], dtype= torch.long, device= dev)
        right4micro= right- n_drug
        loss4gat_recon= (loss_fun2(self.decoder4drug(drug_emb[left]), feats[0][left])
                         + loss_fun2(self.decoder4micro(micro_emb[right4micro]), feats[1][right4micro])
                         + loss_fun2(self.decoder4sloc(sloc0[left]), ass_mat[left])
                         + loss_fun2(self.decoder4sloc(sloc0[right]), ass_mat[right]))
        loss4gat_classify= (loss_fun1(self.decoder4classify(drug_emb[left]), drug_label)
                            + loss_fun1(self.decoder4classify(micro_emb[right4micro]), micr_label))
        loss4transformer_recon= (loss_fun2(self.decoder4drug2(drug_emb2[left]), ass_mat2[left])
                                 + loss_fun2(self.decoder4micro2(micro_emb2[right4micro]), ass_mat2[right]))
        loss4transformer_classify= (loss_fun1(self.decoder4classify2(drug_emb2[left]), drug_label)
                                    + loss_fun1(self.decoder4classify2(micro_emb2[right4micro]), micr_label))
        aux_loss= loss4gat_recon+ loss4gat_classify+ loss4transformer_recon+ loss4transformer_classify
        return aux_loss, self.decoder(emb1[left], emb1[right], emb2[left], emb2[right])

In [None]:
# Load everything and look up the 0/1 label of every (drug, microbe) pair of each split.
dl= dataloader(params)
train_label, valid_label, test_label= (
    dl.drug_micr_asso_mat[xy[:, 0], xy[:, 1]].long()
    for xy in (dl.train_xy, dl.valid_xy, dl.test_xy)
)

# train: pairs are reshuffled every epoch
train_xy_label_dataset= torch.utils.data.TensorDataset(dl.train_xy, train_label)
train_loader= torch.utils.data.DataLoader(
    train_xy_label_dataset,
    batch_size= params.batch_size,
    shuffle= True,
)

# valid: fixed order, so predictions line up with dl.valid_xy
valid_xy_label_dataset= torch.utils.data.TensorDataset(dl.valid_xy, valid_label)
valid_loader= torch.utils.data.DataLoader(
    valid_xy_label_dataset,
    batch_size= params.batch_size,
    shuffle= False,
)

# test: fixed order, so predictions line up with dl.test_xy
test_xy_label_dataset= torch.utils.data.TensorDataset(dl.test_xy, test_label)
test_loader= torch.utils.data.DataLoader(
    test_xy_label_dataset,
    batch_size= params.batch_size,
    shuffle= False,
)

In [None]:
def avg_auc_aupr_cpt(test_xy, test_label, pred, ass_mat_shape):
    """Average per-row ROC-AUC and PR-AUC over the rows of the association matrix.

    Args:
        test_xy: (n, 2) tensor of (row, column) indices of the evaluated pairs.
        test_label: n ground-truth labels.
        pred: n prediction scores.
        ass_mat_shape: (n_rows, n_cols) of the association matrix.

    Returns:
        (mean AUC, mean AUPR) over the rows that contain at least one positive and one
        negative evaluated pair.
    """
    rows, cols= test_xy[:, 0], test_xy[:, 1]
    # -1 marks cells that are not part of the evaluated set.
    label_mat= torch.full(ass_mat_shape, -1.0)
    pred_mat= torch.full(ass_mat_shape, -1.0)
    label_mat[rows, cols]= test_label* 1.0
    pred_mat[rows, cols]= pred
    evaluated= label_mat!= -1

    aucs, auprs= [], []
    for row in range(ass_mat_shape[0]):
        in_row= evaluated[row]
        y_true, y_score= label_mat[row, in_row], pred_mat[row, in_row]
        n_pos= y_true.sum()
        # A row needs both classes for the curves to be defined.
        if not (n_pos> 0 and in_row.sum()- n_pos> 0):
            continue
        fpr, tpr, _= roc_curve(y_true, y_score)
        prec, recall, _= precision_recall_curve(y_true, y_score)
        # Final precision point: 0 when the preceding point is 0, otherwise 1.
        prec[-1]= 0 if prec[-2]== 0 else 1
        aucs.append(auc(fpr, tpr))
        auprs.append(auc(recall, prec))
    return np.mean(aucs), np.mean(auprs)

In [None]:
class EarlyStopping:
    """Stop training when the monitored loss stops improving, keeping the best model on disk.

    Call the instance once per epoch with the value to minimise and the model. The model is
    saved whenever the value is at least as good as the best seen so far. An epoch that is worse
    than the best by more than `eps` increases a counter; after `patience` such epochs in a row
    `flag` becomes True. An epoch that is worse but within `eps` resets the counter without
    replacing the best score or the saved checkpoint.

    Args:
        patience: number of consecutive bad epochs tolerated before `flag` is set.
        pt_file: directory of the checkpoint (created if missing).
        file_name: checkpoint file name inside `pt_file`.
        mess_out: print the counter on every bad epoch.
        eps: tolerance below the best score that does not count as a bad epoch.
    """

    def __init__(self, patience, pt_file= 'checkpoint/', file_name= 'checkpoint.pt', mess_out= True, eps= 0):
        self.patience, self.eps, self.pt_file, self.file_name, self.mess_out= patience, eps, pt_file, file_name, mess_out
        self.best_score, self.counter, self.flag= None, 0, False
        os.makedirs(self.pt_file, exist_ok= True)

    def __call__(self, val_loss, model):
        # Higher score is better.
        score= -val_loss
        if self.best_score is None or score>= self.best_score:
            # New best (ties included): remember it and overwrite the checkpoint.
            self.best_score= score
            self.save_checkpoint(model)
            self.counter= 0
        elif score< self.best_score- self.eps:
            self.counter+= 1
            if self.mess_out:print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter>= self.patience:
                self.flag= True
        else:
            # Worse than the best but within tolerance: not a failure, but also not a reason to
            # lower the best score or to overwrite the best checkpoint.
            self.counter= 0

    def save_checkpoint(self, model):
        """Pickle the whole model object to `<pt_file>/<file_name>`."""
        torch.save(model, os.path.join(self.pt_file, self.file_name))

In [None]:
# Move the model inputs to the configured device.
dl.fea_mats= [mat.to(torch.float).to(params.device) for mat in dl.fea_mats]
dl.g= dl.g.to(params.device)
topo_fea= dl.topo_fea.to(params.device)
dl.e_feat= dl.e_feat.to(params.device)

# Model, optimiser, early stopping and the two loss functions used by the training loop.
net= EHGN4MDA(
    g= dl.g,
    features_list= dl.fea_mats,
    e_feat= dl.e_feat,
    hd4gat= params.hd4gat,
    hd4tf= params.hd4tf,
    num_layers= params.layers_num,
    head4gat= params.head4gat,
    head4tf= params.head4transformer,
    feat_drop= params.feat_dropout,
    attn_drop= params.attn_dropout,
    k= params.wr_time,
    decoder= params.decoder,
).to(params.device)
optimizer= torch.optim.Adam(net.parameters(), lr= params.lr, weight_decay= params.weight_decay)
earlystopping= EarlyStopping(
    patience= params.patience,
    pt_file= params.pt_file,
    file_name= params.pt_file_name,
    mess_out= True,
)
loss_func1= nn.CrossEntropyLoss()
loss_func2= nn.MSELoss()
pred= []

In [None]:
# Training loop: one pass over the training pairs, then validation and early stopping.
for ep in range(params.epochs):
    # train
    net.train()
    for step, (xy, label) in enumerate(train_loader):
        time_start= time.time()
        xy, label= xy.to(params.device), label.to(params.device)
        # Microbe indices are shifted by the number of drugs to get ids in the joint node numbering.
        train_loss2, logp= net(xy[:, 0], xy[:, 1]+ params.num_drug, topo_fea, loss_func2, loss_func1)
        train_loss1= loss_func1(logp, label)
        # Pair-classification loss plus the model's auxiliary loss.
        loss= 0.8* train_loss1+  0.2* train_loss2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        time_end= time.time()
        if step%500== 0:
            print(f'epoch: {ep+ 1}, step: {step+ 1}, train loss: {loss.item()}, time: {time_end- time_start}')
    # valid
    net.eval(); val_loss_sum, val_count= 0.0, 0; pred= []
    with torch.no_grad():
        for step, (xy, label) in enumerate(valid_loader):
            xy, label= xy.to(params.device), label.to(params.device)
            valid_loss2, logp= net(xy[:, 0], xy[:, 1]+ params.num_drug, topo_fea, loss_func2, loss_func1)
            valid_loss1= loss_func1(logp, label)
            # Accumulate a size-weighted sum so the reported loss is the average over the whole
            # validation set, not just the loss of the last batch.
            val_loss_sum+= (0.8* valid_loss1+ 0.2* valid_loss2).item()* label.shape[0]
            val_count+= label.shape[0]
            pred.append(logp)
    val_loss= val_loss_sum/ max(val_count, 1)
    pred= torch.cat(pred).cpu()
    roc_auc, aupr_auc= avg_auc_aupr_cpt(dl.valid_xy, valid_label, pred[:, 1], (1209, 172))
    print(f'epoch: {ep+ 1}, valid loss: {val_loss}, auc: {roc_auc}, aupr: {aupr_auc}')
    # Early stopping minimises its argument, so the (negated) sum of AUC and AUPR is monitored.
    earlystopping(-(roc_auc+ aupr_auc), net)
    if earlystopping.flag:
        print(f'early_stopping')
        break

In [None]:
# Evaluate the best checkpoint on the held-out test pairs.
# NOTE: torch.load unpickles a whole Python object here; only load checkpoints written by this
# notebook. NOTE(review): recent PyTorch releases default to weights-only loading, which rejects
# whole-model pickles - confirm against the installed version.
net= torch.load(f'{params.pt_file}//{params.pt_file_name}')
net.eval();results= []
with torch.no_grad():
    # Named throw-away variables are used so that IPython's `_` (last output) is not overwritten.
    for xy, _batch_label in test_loader:
        # Same device handling as in the training/validation loops.
        xy= xy.to(params.device)
        _aux_loss, logp= net(xy[:, 0], xy[:, 1]+ 1209, topo_fea, loss_func2, loss_func1)
        results.append(logp)
pred= torch.cat(results, dim= 0)
# One row per test pair: drug index, microbe index, label, score of the positive class.
results= torch.cat([dl.test_xy.to('cpu'), test_label.view(-1, 1).to('cpu'), pred[:, 1].view(-1, 1).to('cpu')], dim= 1)
print(avg_auc_aupr_cpt(dl.test_xy.to('cpu'), test_label.to('cpu'), pred[:, 1].to('cpu'), (1209, 172)))
# np.savetxt(fname= params.test_result_file, X= results, delimiter= '\t', encoding= 'utf-8-sig')