In [None]:
import os
import math
import time
import torch
import random
import argparse
import numpy as np
import torch.nn as nn
from scipy.io import loadmat
import torch.nn.functional as F
from sklearn.metrics import roc_curve, roc_auc_score, auc, precision_recall_curve, average_precision_score

In [None]:
# Command-line / notebook configuration. parse_args([]) below means only the defaults are used in the notebook.
parser= argparse.ArgumentParser(description= 'Parser for Arguments')
parser.add_argument('-seed', type= int, default= 0)
parser.add_argument('-num_ent', type= int, default= 1209+ 172+ 154)
parser.add_argument('-num_drug', type= int, default= 1209)
parser.add_argument('-num_micr', type= int, default= 172)
parser.add_argument('-num_dise', type= int, default= 154)
# Relation ids: drug_micr_rel, 0; drug_dise_rel, 1; micr_dise_rel, 2; drug_inter_rel, 3; micr_inter_rel, 4; micr_drug_rel, 5; dise_drug_rel, 6; dise_micr_rel, 7.
parser.add_argument('-drug_name_path', type= str, default= '../mdd/drug/drug_name.txt')
parser.add_argument('-micr_name_path', type= str, default= '../mdd/microbe/microbe_name.txt')
parser.add_argument('-dise_name_path', type= str, default= '../mdd/disease/disease_name.txt')
parser.add_argument('-drug_micr_adj_path', type= str, default= '../mdd/adj/microbe_drug_adj.txt')
parser.add_argument('-drug_struct_simi_path', type= str, default= '../mdd/drug/drug_struct_simi.txt')
parser.add_argument('-drug_inter_adj_path', type= str, default= '../mdd/drug/drug_interact_adj.txt')
parser.add_argument('-drug_dise_adj_path', type= str, default= '../mdd/adj/drug_disease_adj.txt')
parser.add_argument('-micr_ani_path', type= str, default= '../mdd/microbe/microbe_ani_simi.txt')
parser.add_argument('-micr_fun_path', type= str, default= '../mdd/microbe/microbe_gene_simi.txt')
parser.add_argument('-micr_inter_adj_path', type= str, default= '../mdd/microbe/microbe_interact_adj.txt')
parser.add_argument('-micr_dise_adj_path', type= str, default= '../mdd/adj/microbe_disease_adj.txt')
parser.add_argument('-dise_simi_path', type= str, default= '../mdd/disease/disease_dag_simi.txt')
parser.add_argument('-train_ratio', type= float, default= 0.8)
parser.add_argument('-valid_ratio', type= float, default= 0.1)
parser.add_argument('-test_ratio', type= float, default= 0.1)
parser.add_argument('-batch_size', type= int, default= 128)
parser.add_argument('-hid_dim', type= int, default= 128)
parser.add_argument('-out_dim', type= int, default= 3)
parser.add_argument('-heads', type= int, default= 128)
# Dropout is a probability in [0, 1) and alpha is the LeakyReLU negative slope, so both are floats.
# Defaults match the values the model construction cell passes to GSAMDA (0.4, 0.2).
parser.add_argument('-dropout', type= float, default= 0.4)
parser.add_argument('-alpha', type= float, default= 0.2)
parser.add_argument('-threshold', type= float, default= 0)
parser.add_argument('-epochs', type= int, default= 300)
parser.add_argument('-patience', type= int, default= 6)
parser.add_argument('-lr', type= float, default= 1e-4)
parser.add_argument('-weight_decay', type= float, default= 0)
parser.add_argument('-device', type= str, default= 'cuda:0')
parser.add_argument('-pt_file', type= str, default= 'checkpoint/')
parser.add_argument('-memo_file4mkgcn', type= str, default= 'memo/gsamda.txt')
parser.add_argument('-pt_file_name', type= str, default= 'gsamda.pt')
parser.add_argument('-test_result_file', type= str, default= 'result/gsamda_test_result.txt')
params= parser.parse_args([])

In [None]:
# Seed every RNG the notebook uses (Python, NumPy, torch CPU and CUDA) so the data split,
# weight initialisation, shuffling and dropout are reproducible.
random.seed(params.seed)
np.random.seed(params.seed)
torch.manual_seed(params.seed)
torch.cuda.manual_seed(params.seed)
torch.cuda.manual_seed_all(params.seed)
# NOTE: this only affects subprocesses started later; the running interpreter's hash
# randomisation was fixed at startup.
os.environ['PYTHONHASHSEED'] = str(params.seed)
# Deterministic cuDNN requires both flags: deterministic kernels AND no benchmark-based
# algorithm autotuning (which can pick different kernels run to run).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
class dataloader(object):
    """Load drug / microbe / disease data, split drug-microbe pairs, and build similarity matrices.

    Everything is computed eagerly in ``__init__`` from the file paths in ``params``. The
    statement order in ``__init__`` is load-bearing:
    - ``split_dataset`` reads ``self.drug_micr_asso_mat``;
    - ``load_drug_data`` masks the held-out pairs using the split;
    - ``load_micr_data`` reads ``self.drug_micr_asso_mat_zy``;
    - ``load_dise_data`` reads the interaction matrices.
    ``*_zy`` matrices have the validation and test pairs set to 0, so similarities derived
    from them only see training associations.
    """
    def __init__(self, params):
        super().__init__()
        self.params= params
        # Loaded first because split_dataset() reads it.
        self.drug_micr_asso_mat= self.load_adj_data(self.params.drug_micr_adj_path, sp= (self.params.num_drug, self.params.num_micr))
        self.train_xy, self.valid_xy, self.test_xy= self.split_dataset()
        self.drug_micr_asso_mat, self.drug_micr_asso_mat_zy, self.drug_dise_asso_mat, self.drug_inter_mat, self.drug_struct_simi_mat= self.load_drug_data()
        self.micr_ani_mat, self.micr_inter_mat, self.micr_dise_asso_mat, self.micr_asso_simi_mat, self.micr_fun_simi_mat= self.load_micr_data()
        self.dise_simi_mat, self.drug_dise_drug_simi_mat, self.micr_dise_micr_simi_mat= self.load_dise_data()
        # Cosine similarity of the disease-association profiles.
        self.drug_dise_simi_mat, self.micr_dise_simi_mat= self.cos_sim(self.drug_dise_asso_mat).float(), self.cos_sim(self.micr_dise_asso_mat).float()
        # Hamming / Gaussian interaction-profile similarities from the masked (training-only) associations.
        self.drug_hip_simi_mat, self.micr_hip_simi_mat= self.hip_sim(self.drug_micr_asso_mat_zy).float(), self.hip_sim(self.drug_micr_asso_mat_zy.T).float()
        self.drug_gip_simi_mat, self.micr_gip_simi_mat= self.gauss_sim(self.drug_micr_asso_mat_zy).float(), self.gauss_sim(self.drug_micr_asso_mat_zy.T).float()
        # Integrated similarity: mean of HIP and GIP.
        self.drug_int_simi_mat, self.micr_int_simi_mat= (self.drug_hip_simi_mat+ self.drug_gip_simi_mat).float()/ 2, (self.micr_hip_simi_mat+ self.micr_gip_simi_mat).float()/2
        self.drug_rwr_mat, self.micr_rwr_mat= self.rwr(self.drug_int_simi_mat).float(), self.rwr(self.micr_int_simi_mat).float()
        # Heterogeneous graph, drugs first: [[drug-drug sim, drug-microbe assoc], [microbe-drug assoc, microbe-microbe sim]].
        self.hete_graph_mat= torch.cat([torch.cat([self.drug_int_simi_mat, self.drug_micr_asso_mat_zy], dim= 1),\
                                    torch.cat([self.drug_micr_asso_mat_zy.T, self.micr_int_simi_mat], dim= 1)], dim= 0).float()

    # @ cosine similarity
    def cos_sim(self, mat):
        """Row-wise cosine similarity: entry (i, j) is cos(mat[i], mat[j])."""
        normed= nn.functional.normalize(mat, p= 2, dim= 1)
        return torch.matmul(normed, normed.T)

    # @ Gaussian kernel similarity
    def gauss_sim(self, mat):
        """Gaussian interaction-profile (GIP) kernel similarity between the rows of ``mat``.

        K[i, j] = exp(-gamma * ||mat[i] - mat[j]||^2) with gamma = 1 / mean_i ||mat[i]||^2.
        NOTE(review): an all-zero ``mat`` makes gamma infinite and the result NaN.
        """
        # Kernel bandwidth (named sigma, but it multiplies the squared distance, i.e. gamma).
        sigma= 1/ torch.diag(torch.matmul(mat, mat.T)).mean()
        # Pairwise squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b.
        sq_norms= torch.mul(mat, mat).sum(dim= 1, keepdim= True)
        sim_mat= sq_norms+ sq_norms.T- 2* torch.matmul(mat, mat.T)
        return torch.exp(-1* sigma* sim_mat)

    # @ random walk with restart, mat like (n, n)
    def rwr(self, mat, times= 100):
        """Iterate ``state = alpha * T @ state + (1 - alpha) * I`` ``times`` times from ``state = I``.

        ``T`` is ``mat`` row-normalised (rows summing to 0 stay 0). With alpha = 0.1 the walk
        continues with probability 0.1 and restarts with probability 0.9.
        """
        alpha= 0.1
        # Row normalisation; the epsilon avoids 0/0 for empty rows.
        trans_mat= mat/ (mat.sum(dim= 1, keepdim= True)+ 1e-15)
        # The restart term is loop-invariant, so build the identity once.
        eye= torch.eye(mat.shape[0])
        state_mat= eye
        for _ in range(times):
            state_mat= alpha* torch.matmul(trans_mat, state_mat)+ (1- alpha)* eye
        return state_mat

    # @ Hamming interaction profile similarity
    def hip_sim(self, mat):
        """Fraction of columns on which row i and row j of ``mat`` are equal, for every pair (i, j)."""
        sim_ls, dim= [], mat.shape[1]
        # Row by row rather than one broadcast, to keep peak memory at O(rows * cols).
        for i in range(mat.shape[0]):
            sim_ls.append(((mat[i]- mat) == 0).sum(dim= 1)/ dim)
        return torch.stack(sim_ls)

    # @ introduce data
    def introduce(self):
        """Print the number of known associations / interactions in each loaded matrix."""
        print(f'Drug microbe association num: {self.drug_micr_asso_mat.sum()}\nDrug interaction num: {self.drug_inter_mat.sum()}\nMicrobe interaction num: {self.micr_inter_mat.sum()}')
        print(f'Drug disease association num: {self.drug_dise_asso_mat.sum()}\nMicrobe disease association num: {self.micr_dise_asso_mat.sum()}')

    def _mask_held_out(self, asso_mat):
        """Return a copy of ``asso_mat`` with every validation and test pair set to 0."""
        masked= asso_mat.clone()
        for held_out in (self.valid_xy, self.test_xy):
            masked[held_out[:, 0], held_out[:, 1]]= 0
        return masked

    # @ mask
    def get_asso_mat_zy(self):
        """Drug-microbe association matrix with the validation and test pairs zeroed out."""
        return self._mask_held_out(self.drug_micr_asso_mat)

    # @ split data set
    def split_dataset(self):
        """Randomly split every (drug, microbe) pair into train / valid / test.

        For each drug, its first known association (scanning microbes in order) always goes
        to train, so every drug with an association has a training positive. Every other pair
        draws ``torch.rand(1)``:
        - below ``train_ratio`` goes to train;
        - below ``train_ratio + valid_ratio`` goes to valid;
        - the rest goes to test.
        Returns three int64 tensors of shape (k, 2) of [drug, microbe] indices. An empty split
        is shape (0, 2).
        """
        train_xy, valid_xy, test_xy= [], [], []
        train_cut= self.params.train_ratio
        valid_cut= self.params.train_ratio+ self.params.valid_ratio
        # Plain nested lists: far cheaper than per-element tensor indexing in this double loop.
        asso= self.drug_micr_asso_mat.tolist()
        for i in range(self.params.num_drug):
            first= True
            for j in range(self.params.num_micr):
                if asso[i][j]== 1 and first:
                    train_xy.append([i, j])
                    first= False
                else:
                    num= torch.rand(1)
                    if num< train_cut:
                        train_xy.append([i, j])
                    elif num< valid_cut:
                        valid_xy.append([i, j])
                    else:
                        test_xy.append([i, j])
        print(f'Spliting data has finished...')
        # reshape keeps an empty split 2-D so callers can still index [:, 0] / [:, 1].
        return tuple(torch.tensor(xy, dtype= torch.long).reshape(-1, 2) for xy in (train_xy, valid_xy, test_xy))

    # @ load disease data
    def load_dise_data(self):
        """Load the disease similarity file and build drug-drug / microbe-microbe cosine similarities.

        NOTE(review): despite their names, ``drug_dise_drug_simi_mat`` and
        ``micr_dise_micr_simi_mat`` are computed from the drug / microbe *interaction* matrices,
        not from disease associations. Kept as-is; confirm intent.
        """
        dise_simi_mat= torch.from_numpy(np.loadtxt(self.params.dise_simi_path, encoding= 'utf-8-sig'))
        drug_dise_drug_simi_mat= self.cos_sim(self.drug_inter_mat)
        micr_dise_micr_simi_mat= self.cos_sim(self.micr_inter_mat)
        # Force self-similarity to 1 (rows with no interactions would otherwise give 0).
        drug_diag, micr_diag= torch.arange(self.params.num_drug), torch.arange(self.params.num_micr)
        drug_dise_drug_simi_mat[drug_diag, drug_diag]= 1.0
        micr_dise_micr_simi_mat[micr_diag, micr_diag]= 1.0
        return dise_simi_mat, drug_dise_drug_simi_mat, micr_dise_micr_simi_mat

    # @ load microbe data
    def load_micr_data(self):
        """Load microbe similarity / interaction / disease-association data (all returned as float32)."""
        micr_ani_mat= torch.from_numpy(np.loadtxt(self.params.micr_ani_path, encoding= 'utf-8-sig'))
        micr_inter_mat= self.load_adj_data(self.params.micr_inter_adj_path, sp= (self.params.num_micr, self.params.num_micr))
        micr_dise_mat= self.load_adj_data(self.params.micr_dise_adj_path, sp= (self.params.num_micr, self.params.num_dise))
        # Cosine similarity of microbes' (masked) drug-association profiles.
        micr_asso_simi_mat= self.cos_sim(self.drug_micr_asso_mat_zy.T)
        micr_fun_simi_mat= torch.from_numpy(np.loadtxt(self.params.micr_fun_path, encoding= 'utf-8-sig'))
        diag= torch.arange(self.params.num_micr)
        micr_asso_simi_mat[diag, diag]= 1
        micr_ani_mat[diag, diag]= 1
        return micr_ani_mat.float(), micr_inter_mat.float(), micr_dise_mat.float(), micr_asso_simi_mat.float(), micr_fun_simi_mat.float()

    # @ load drug data
    def load_drug_data(self):
        """Load drug association / interaction / structure data plus the masked drug-microbe matrix."""
        drug_dise_asso_mat= self.load_adj_data(self.params.drug_dise_adj_path, sp= (self.params.num_drug, self.params.num_dise))
        drug_micr_asso_mat= self.load_adj_data(self.params.drug_micr_adj_path, sp= (self.params.num_drug, self.params.num_micr))
        drug_micr_asso_mat_zy= self._mask_held_out(drug_micr_asso_mat)
        drug_inter_mat= self.load_adj_data(self.params.drug_inter_adj_path, sp= (self.params.num_drug, self.params.num_drug))
        drug_struct_simi_mat= torch.from_numpy(np.loadtxt(self.params.drug_struct_simi_path, encoding= 'utf-8-sig'))
        return drug_micr_asso_mat.float(), drug_micr_asso_mat_zy.float(), drug_dise_asso_mat.float(), drug_inter_mat.float(), drug_struct_simi_mat.float()

    # @ load adj data
    def load_adj_data(self, path, sp= (1209, 172)):
        """Build a dense 0/1 float matrix of shape ``sp`` from a file of 1-based ``row col`` index pairs.

        ``ndmin=2`` keeps a single-line file 2-D so ``idx[:, 0]`` still works.
        """
        idx= torch.from_numpy(np.loadtxt(path, encoding= 'utf-8-sig', ndmin= 2)).long()- 1
        mat= torch.zeros((sp[0], sp[1]))
        mat[idx[:, 0], idx[:, 1]]= 1
        return mat

    # write into memo
    def write2memo(self, mr, mrr, hits10):
        """Append ``lr, weight_decay, mr, mrr, hits10`` as one tab-separated line to the memo file.

        Reads ``params.memo_file4mkgcn``, ``params.lr`` and ``params.weight_decay``. These are
        the argument names the parser defines. The memo file's directory is created if missing.
        """
        memo_path= self.params.memo_file4mkgcn
        memo_dir= os.path.dirname(memo_path)
        if memo_dir:
            os.makedirs(memo_dir, exist_ok= True)
        with open(memo_path, 'a+') as f:
            f.write(f'{self.params.lr}\t{self.params.weight_decay}\t{mr}\t{mrr}\t{hits10}\n')

In [None]:
# Build the splits, look up each pair's 0/1 label, and wrap every split in a DataLoader.
dl= dataloader(params)

def _pair_labels(xy):
    """Ground-truth labels (int64) of the [drug, microbe] index pairs in ``xy``."""
    return dl.drug_micr_asso_mat[xy[:, 0], xy[:, 1]].long()

train_label, valid_label, test_label= (_pair_labels(dl.train_xy), _pair_labels(dl.valid_xy), _pair_labels(dl.test_xy))
train_xy_label_dataset= torch.utils.data.TensorDataset(dl.train_xy, train_label)
valid_xy_label_dataset= torch.utils.data.TensorDataset(dl.valid_xy, valid_label)
test_xy_label_dataset= torch.utils.data.TensorDataset(dl.test_xy, test_label)
# Only training batches are shuffled; valid/test keep dl.*_xy order so predictions line up with the pairs.
train_loader= torch.utils.data.DataLoader(train_xy_label_dataset, batch_size= params.batch_size, shuffle= True)
valid_loader= torch.utils.data.DataLoader(valid_xy_label_dataset, batch_size= params.batch_size, shuffle= False)
test_loader= torch.utils.data.DataLoader(test_xy_label_dataset, batch_size= params.batch_size, shuffle= False)

In [None]:
class EarlyStopping:
	"""Early stopping on a monitored loss (lower is better).

	A call counts as an improvement when ``val_loss <= best_loss - eps``. With the default
	``eps=0``, ties count as improvements. On improvement the whole model is saved with
	``torch.save(model, ...)`` to ``os.path.join(pt_file, file_name)`` and the counter resets.
	Otherwise the counter increments, and ``flag`` becomes True once it reaches ``patience``.
	"""
	def __init__(self, patience, pt_file= 'checkpoint/', file_name= 'checkpoint.pt', mess_out= True, eps= 0):
		self.patience, self.eps, self.pt_file, self.file_name, self.mess_out= patience, eps, pt_file, file_name, mess_out
		self.best_score, self.counter, self.flag= None, 0, False
		os.makedirs(self.pt_file, exist_ok= True)

	def __call__(self, val_loss, model):
		# Work with score = -loss so that higher is better.
		score= -val_loss
		if self.best_score is None:
			self.best_score= score
			self.save_checkpoint(model)
		elif score< self.best_score+ self.eps:
			# Not better than the best by at least eps.
			self.counter+= 1
			if self.mess_out:print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
			if self.counter>= self.patience:
				self.flag= True
		else:
			self.best_score= score
			self.save_checkpoint(model)
			self.counter= 0

	def save_checkpoint(self, model):
		"""Pickle the whole model (not just its state_dict) to the checkpoint path."""
		torch.save(model, os.path.join(self.pt_file, self.file_name))

In [None]:
# @ GAT layer
class GraphAttentionLayer(nn.Module):
	"""Single-head graph attention layer over a dense adjacency matrix.

	Args:
		in_dim: size of each input node feature vector.
		out_dim: size of each output node feature vector.
		dropout: dropout probability on the attention coefficients (training mode only).
		alpha: negative slope of the LeakyReLU applied to the raw attention scores.
		concat: if True (default) a ReLU is applied to the aggregated features;
			if False they are returned without activation.
	"""
	def __init__(self, in_dim, out_dim, dropout, alpha, concat= True):
		super().__init__()
		self.out_dim= out_dim
		self.concat= concat
		self.leakyrelu= nn.LeakyReLU(alpha)
		self.dropout= dropout
		self.W= nn.Parameter(torch.zeros(size= (in_dim, out_dim)))
		self.a= nn.Parameter(torch.zeros(size= (2* out_dim, 1)))
		nn.init.xavier_uniform_(self.W.data, nn.init.calculate_gain('leaky_relu'))
		nn.init.xavier_uniform_(self.a.data, nn.init.calculate_gain('leaky_relu'))

	def forward(self, x, adj_mat):
		"""Aggregate neighbour features; node j is a neighbour of node i iff ``adj_mat[i, j] > 0``.

		``x`` is (N, in_dim) and ``adj_mat`` is (N, N); returns (N, out_dim).
		"""
		x= torch.matmul(x, self.W)
		# Score e_ij = a[:out]·Wx_i + a[out:]·Wx_j, built as (N, 1) + (1, N).
		el= torch.matmul(x, self.a[0: self.out_dim])
		er= torch.matmul(x, self.a[self.out_dim:]).T
		att_hat= self.leakyrelu(el+ er)
		# Non-edges get a huge negative score so the row softmax gives them (numerically) zero weight.
		att_hat= att_hat.masked_fill(~(adj_mat> 0), -9e15)
		att= F.softmax(att_hat, dim= 1)
		att= F.dropout(att, self.dropout, training= self.training)
		# Aggregate messages: (N, out_dim).
		x= torch.matmul(att, x)
		# getattr keeps modules pickled before `concat` was stored loadable (they behave as concat=True).
		return F.relu(x) if getattr(self, 'concat', True) else x

In [None]:
# @ CNN layer
class CNN(nn.Module):
	"""Two 3x3 conv blocks (padding 1, so H and W are preserved) followed by a 2-way linear head.

	Args:
		in_features: flattened size of the second conv block's output, i.e. ``6 * H * W`` for
			an input of shape (batch, 1, H, W). Defaults to 38640, the previously hard-coded value.
	"""
	def __init__(self, in_features= 38640):
		super().__init__()
		self.seq1=nn.Sequential(nn.Conv2d(1,16,kernel_size=3,padding=1),nn.BatchNorm2d(16),nn.ReLU())
		self.seq2=nn.Sequential(nn.Conv2d(16,6,kernel_size=3,padding=1),nn.BatchNorm2d(6),nn.ReLU())
		self.fc= nn.Sequential(nn.Linear(in_features, 256), nn.Dropout(0.5), nn.ReLU(), nn.Linear(256, 2))
		nn.init.xavier_uniform_(self.fc[0].weight, nn.init.calculate_gain('relu'))
		nn.init.xavier_uniform_(self.fc[3].weight)

	# x presumably (batch_size, 1, 2, -1) — TODO confirm against the caller.
	def forward(self, x):
		"""Return (batch, 2) logits."""
		x=self.seq1(x)
		x=self.seq2(x)
		return self.fc(torch.flatten(x, 1))

In [None]:
# @ GAT encoder
class GAT(nn.Module):
	"""Two-layer graph attention encoder.

	Layer 1 runs ``heads`` single-head attention layers of width ``hid_dim // heads`` and
	concatenates them (total width ``hid_dim``). Layer 2 maps that to ``out_dim``.
	``forward`` returns ``(emb, recon)``, where ``recon = sigmoid(emb @ emb.T)`` is an
	inner-product reconstruction of the graph.

	Raises:
		ValueError: if ``hid_dim`` is not a positive multiple of ``heads``. Otherwise the
			concatenated width would not equal the ``hid_dim`` that layer 2 expects.
	"""
	def __init__(self, in_dim, hid_dim, out_dim, heads, dropout, alpha= 0.2):
		super().__init__()
		if heads<= 0 or hid_dim<= 0 or hid_dim% heads!= 0:
			raise ValueError(f'hid_dim ({hid_dim}) must be a positive multiple of heads ({heads})')
		# Layer 1: multi-head attention.
		self.attention_layer1= nn.ModuleList([GraphAttentionLayer(in_dim, hid_dim// heads, dropout, alpha) for i in range(heads)])
		# Layer 2: single-head output layer.
		self.attention_layer2= GraphAttentionLayer(hid_dim, out_dim, dropout, alpha)

	def forward(self, x, adj_mat):
		# Concatenate all heads along the feature axis -> (N, hid_dim).
		emb= torch.cat([attention_layer(x, adj_mat) for attention_layer in self.attention_layer1], dim= 1)
		x= self.attention_layer2(emb, adj_mat)
		return x, torch.sigmoid(torch.matmul(x, x.T))

In [None]:
# 自动编码器
class AE(nn.Module):
	def __init__(self, in_dim):
		super().__init__()
		self.encoder= nn.Sequential(nn.Linear(in_dim, 128), nn.ReLU(), nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU())
		self.decoder= nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, in_dim), nn.Sigmoid())
		nn.init.xavier_uniform_(self.encoder[0].weight, nn.init.calculate_gain('relu'))
		nn.init.xavier_uniform_(self.encoder[2].weight, nn.init.calculate_gain('relu'))
		nn.init.xavier_uniform_(self.encoder[4].weight, nn.init.calculate_gain('relu'))		
		nn.init.xavier_uniform_(self.decoder[0].weight, nn.init.calculate_gain('relu'))
		nn.init.xavier_uniform_(self.decoder[2].weight, nn.init.calculate_gain('relu'))
		nn.init.xavier_uniform_(self.decoder[4].weight, nn.init.calculate_gain('sigmoid'))
	def forward(self, x):
		emb= self.encoder(x)
		x_hat= self.decoder(emb)
		return emb, x_hat

In [None]:
class MLP(nn.Module):
	def __init__(self, in_dim= 4798):
		super().__init__()
		self.fc= nn.Sequential(nn.Linear(in_dim* 2, 1024), nn.ReLU(), nn.Dropout(0.2), nn.Linear(1024, 64), nn.ReLU(), nn.Dropout(0.2), nn.Linear(64, 1), nn.Sigmoid())
		nn.init.xavier_uniform_(self.fc[0].weight, nn.init.calculate_gain('relu'))
		nn.init.xavier_uniform_(self.fc[3].weight, nn.init.calculate_gain('relu'))
		nn.init.xavier_uniform_(self.fc[6].weight, nn.init.calculate_gain('sigmoid'))
	def forward(self, x):
		return self.fc(x)

In [None]:
class GSAMDA(nn.Module):
	"""Drug-microbe association predictor.

	Combines a GAT over the heterogeneous graph, one autoencoder per entity type, and fixed
	similarity / association blocks.

	``forward(left, right)`` scores the pairs given by the index tensors ``left`` (drug rows)
	and ``right`` (microbe columns). Each score is the cosine similarity between the drug's and
	the microbe's concatenated feature vectors, clamped to [0, 1] so it is a valid
	``nn.BCELoss`` input. It also returns an auxiliary loss made of:
	- MSE reconstruction of the graph;
	- MSE reconstruction of both autoencoder inputs;
	- 0.1 x a KL penalty pulling each autoencoder code's softmax toward uniform.

	Drug rows are expected to come first in ``x`` / ``graph1``. The number of drug rows is
	taken from ``fea4drug.shape[0]``.
	"""
	def __init__(self, fea4drug, fea4micro, x, graph1, in_dim, hid_dim, out_dim, heads, dropout, alpha, loss_mse, loss_kl, ass_mat1, drug_struct_sim, micro_funct_sim, drug_feaByrwr1, micro_feaByrwr1, Sr_dis, Sm_dis):
		super().__init__()
		self.fea4drug, self.fea4micro, self.g1, self.mse, self.loss_kl, self.x= fea4drug, fea4micro, graph1, loss_mse, loss_kl, x
		self.gat4graph= GAT(in_dim, hid_dim, out_dim, heads, dropout, alpha)
		self.ae4drug, self.ae4micro= AE(fea4drug.shape[1]), AE(fea4micro.shape[1])
		# NOTE(review): self.fc is never used in forward(). It is kept because removing it changes the
		# optimizer's parameter set, the seeded RNG draws at construction, and saved checkpoints.
		self.fc= MLP()
		self.ass_mat1, self.drug_struct_sim, self.micro_funct_sim, self.drug_feaByrwr1, self.micro_feaByrwr1, self.Sr_dis, self.Sm_dis= ass_mat1, drug_struct_sim, micro_funct_sim, drug_feaByrwr1, micro_feaByrwr1, Sr_dis, Sm_dis

	def forward(self, left, right):
		# GAT over the heterogeneous graph, then split node embeddings into drugs / microbes.
		emb1, recon1= self.gat4graph(self.x, self.g1)
		num_drug= self.fea4drug.shape[0]
		gat_emb4drug, gat_emb4micro= emb1[:num_drug], emb1[num_drug:]
		# Autoencoder codes and reconstructions for each entity type.
		(emb4drug, recon4drug), (emb4micro, recon4micro)= self.ae4drug(self.fea4drug), self.ae4micro(self.fea4micro)
		# Sparsity term: the target softmax of a constant 0.05 vector is uniform.
		emb4drug_exp_act_val, emb4micro_exp_act_val= 0.05* torch.ones_like(emb4drug), 0.05* torch.ones_like(emb4micro)
		kl_loss4drug= self.loss_kl(F.log_softmax(emb4drug, dim= 1), F.softmax(emb4drug_exp_act_val, dim= 1))
		kl_loss4micro= self.loss_kl(F.log_softmax(emb4micro, dim= 1), F.softmax(emb4micro_exp_act_val, dim= 1))
		kl_loss= kl_loss4drug+ kl_loss4micro
		# Build per-entity features. Block widths must line up pairwise between drug and microbe
		# (e.g. drug_struct_sim <-> ass_mat1.T, ass_mat1 <-> micro_funct_sim) for the cosine below.
		drug_fea= torch.cat((gat_emb4drug, emb4drug, self.drug_struct_sim, self.ass_mat1, self.Sr_dis, self.ass_mat1, self.drug_feaByrwr1, self.ass_mat1), dim= 1)
		micro_fea= torch.cat((gat_emb4micro, emb4micro, self.ass_mat1.T, self.micro_funct_sim, self.ass_mat1.T, self.Sm_dis, self.ass_mat1.T, self.micro_feaByrwr1), dim= 1)
		# Cosine similarity between every drug and every microbe: (num_drug, num_micro).
		ass_mat_hat= torch.matmul(F.normalize(drug_fea, p= 2, dim= 1), F.normalize(micro_fea, p= 2, dim= 1).T)
		# Reconstruction losses.
		recon_loss= self.mse(recon1, self.g1)+ self.mse(recon4drug, self.fea4drug)+ self.mse(recon4micro, self.fea4micro)+ 0.1* kl_loss
		# Clamp guards BCELoss against rounding overshoot (e.g. 1.0000001) or negative similarities.
		scores= ass_mat_hat[left, right].to(torch.float32).clamp(0.0, 1.0)
		return scores, recon_loss

In [None]:
def avg_auc_aupr_cpt(test_xy, test_label, pred, ass_mat_shape):
    """Mean per-row ROC-AUC and AUPR over the rows of the association matrix.

    The pairs in ``test_xy`` (one ``[row, col]`` per line), their labels ``test_label`` and
    scores ``pred`` are scattered into dense matrices of shape ``ass_mat_shape``. All other
    entries are -1. Only rows whose test entries include at least one positive and one
    negative contribute.

    The final precision point of each PR curve is forced to 0 when the preceding precision
    is 0, and to 1 otherwise.

    Returns ``(np.mean(aucs), np.mean(auprs))``. Both are NaN (with a NumPy warning) when no
    row qualifies.
    """
    rows, cols= test_xy[:, 0], test_xy[:, 1]
    label_mat= torch.zeros((ass_mat_shape))- 1
    pred_mat= torch.zeros((ass_mat_shape))- 1
    label_mat[rows, cols]= test_label* 1.0
    pred_mat[rows, cols]= pred
    is_test= label_mat!= -1
    aucs, auprs= [], []
    for r in range(ass_mat_shape[0]):
        in_test= is_test[r]
        y_true, y_score= label_mat[r, in_test], pred_mat[r, in_test]
        n_pos= y_true.sum()
        # Skip rows where AUC is undefined (no positives or no negatives).
        if not (n_pos> 0 and in_test.sum()- n_pos> 0):
            continue
        fpr, tpr, _= roc_curve(y_true, y_score)
        precision, recall, _= precision_recall_curve(y_true, y_score)
        precision[-1]= 0 if precision[-2]== 0 else 1
        aucs.append(auc(fpr, tpr))
        auprs.append(auc(recall, precision))
    return np.mean(aucs), np.mean(auprs)

In [None]:
# Build input features, the GSAMDA model, loss functions, optimizer and early stopper.
# params.device= 'cpu'
# Per-entity autoencoder inputs (rows: drugs / microbes).
drug_fea= torch.cat([dl.drug_micr_asso_mat_zy, dl.drug_struct_simi_mat, dl.drug_micr_asso_mat_zy, dl.drug_rwr_mat, dl.drug_dise_simi_mat], dim= 1)
micr_fea= torch.cat([dl.drug_micr_asso_mat_zy.T, dl.micr_ani_mat, dl.drug_micr_asso_mat_zy.T, dl.micr_rwr_mat, dl.micr_dise_simi_mat], dim= 1)
loss_fuc, loss_mse, loss_kl= nn.BCELoss(), nn.MSELoss(reduction= 'mean'), nn.KLDivLoss(reduction= 'batchmean')
# GAT hyper-parameters used for this run. They are set here, not read from params (-hid_dim, -heads, ...).
GAT_HID_DIM, GAT_OUT_DIM, GAT_HEADS, GAT_DROPOUT, GAT_ALPHA= 256, 128, 1, 0.4, 0.2
# NOTE(review): several GSAMDA argument names do not describe what is passed:
# - micro_funct_sim receives the ANI matrix;
# - drug_feaByrwr1 / micro_feaByrwr1 receive disease-profile cosine similarities;
# - Sr_dis / Sm_dis receive the HIP+GIP integrated similarities.
# Kept as in the original run.
net= GSAMDA(
    fea4drug= drug_fea.to(params.device),
    fea4micro= micr_fea.to(params.device),
    x= dl.hete_graph_mat.clone().to(params.device),
    graph1= dl.hete_graph_mat.to(params.device),
    in_dim= dl.hete_graph_mat.shape[1],
    hid_dim= GAT_HID_DIM, out_dim= GAT_OUT_DIM, heads= GAT_HEADS, dropout= GAT_DROPOUT, alpha= GAT_ALPHA,
    loss_mse= loss_mse, loss_kl= loss_kl,
    ass_mat1= dl.drug_micr_asso_mat_zy.to(params.device),
    drug_struct_sim= dl.drug_struct_simi_mat.to(params.device),
    micro_funct_sim= dl.micr_ani_mat.to(params.device),
    drug_feaByrwr1= dl.drug_dise_simi_mat.to(params.device),
    micro_feaByrwr1= dl.micr_dise_simi_mat.to(params.device),
    Sr_dis= dl.drug_int_simi_mat.to(params.device),
    Sm_dis= dl.micr_int_simi_mat.to(params.device),
).to(params.device)
# weight_decay was previously ignored; its default (0) matches Adam's own default.
optimizer= torch.optim.Adam(net.parameters(), lr= params.lr, weight_decay= params.weight_decay)
earlystopping= EarlyStopping(patience= params.patience, pt_file= params.pt_file, file_name= params.pt_file_name, mess_out= True)
pred= []

In [None]:
# Train with BCE + auxiliary reconstruction loss. Early-stop on validation (AUC + AUPR).
for ep in range(params.epochs):
    # ---- train ----
    net.train()
    for step, (xy, label) in enumerate(train_loader):
        time_start= time.time()
        xy, label= xy.to(params.device), label.to(params.device)
        logp, loss2= net(xy[:, 0], xy[:, 1])
        loss= loss_fuc(logp, label.float())+ loss2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        time_end= time.time()
        if step% 500== 0:
            print(f'epoch: {ep+ 1}, step: {step+ 1}, train loss: {loss.item()}, time: {time_end- time_start}')
    # ---- validate ----
    net.eval(); val_loss= 0.0; pred= []
    with torch.no_grad():
        for xy, label in valid_loader:
            xy, label= xy.to(params.device), label.to(params.device)
            logp, loss2= net(xy[:, 0], xy[:, 1])
            # Accumulate over all batches (previously only the last batch's loss was reported).
            val_loss+= (loss_fuc(logp, label.float())+ loss2).item()
            pred.append(logp)
    # Mean per-batch validation loss.
    val_loss/= max(len(valid_loader), 1)
    pred= torch.cat(pred).cpu()
    roc_auc, aupr_auc= avg_auc_aupr_cpt(dl.valid_xy, valid_label, pred, (params.num_drug, params.num_micr))
    print(f'epoch: {ep+ 1}, valid loss: {val_loss}, auc: {roc_auc}, aupr: {aupr_auc}')
    # EarlyStopping minimises its input, so pass -(AUC + AUPR).
    earlystopping(-(roc_auc+ aupr_auc), net)
    if earlystopping.flag:
        print(f'early_stopping')
        break

In [None]:
# Reload the best checkpoint and report (mean per-drug AUC, mean per-drug AUPR) on the test split.
ckpt_path= os.path.join(params.pt_file, params.pt_file_name)
# The checkpoint is a whole pickled module (EarlyStopping saves torch.save(model, ...)). PyTorch >= 2.6
# defaults to weights_only=True, which rejects it. Unpickling is acceptable only because this notebook
# wrote the file itself; never do this with untrusted checkpoints.
try:
    net= torch.load(ckpt_path, map_location= params.device, weights_only= False)
except TypeError:
    # PyTorch < 1.13 has no weights_only argument.
    net= torch.load(ckpt_path, map_location= params.device)
net.eval(); test_preds= []
with torch.no_grad():
    for xy, _ in test_loader:
        xy= xy.to(params.device)
        logp, _= net(xy[:, 0], xy[:, 1])
        test_preds.append(logp)
pred= torch.cat(test_preds, dim= 0).cpu()
# Columns: drug idx, microbe idx, label, predicted score.
results= torch.cat([dl.test_xy.cpu(), test_label.view(-1, 1).cpu(), pred.view(-1, 1)], dim= 1)
print(avg_auc_aupr_cpt(dl.test_xy.cpu(), test_label.cpu(), pred, (params.num_drug, params.num_micr)))
# np.savetxt(fname= params.test_result_file, X= results, delimiter= '\t', encoding= 'utf-8-sig')