In [None]:
import os
import time
import torch
import random
import warnings
import argparse
import numpy as np
from torch_geometric.nn import conv
from sklearn.metrics import roc_curve, roc_auc_score, auc, precision_recall_curve, average_precision_score
import torch.nn as nn
from tqdm import tqdm
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from collections import defaultdict as ddict, Counter
warnings.filterwarnings('ignore')

In [None]:
parser = argparse.ArgumentParser(description='Parser for Arguments')

# Every command-line option as (flag, type, default), in declaration order.
# Relation ids used by the knowledge-graph export:
# drug_micr_rel, 0; drug_dise_rel, 1; micr_dise_rel, 2; drug_inter_rel, 3; micr_inter_rel, 4;
# micr_drug_rel, 5; dise_drug_rel, 6; dise_micr_rel, 7.
ARG_SPECS = [
    # entity counts (num_ent = drugs + microbes + diseases)
    ('-seed', int, 0),
    ('-num_ent', int, 1209 + 172 + 154),
    ('-num_drug', int, 1209),
    ('-num_micr', int, 172),
    ('-num_dise', int, 154),
    # input files
    ('-drug_name_path', str, '../mdd/drug/drug_name.txt'),
    ('-micr_name_path', str, '../mdd/microbe/microbe_name.txt'),
    ('-dise_name_path', str, '../mdd/disease/disease_name.txt'),
    ('-drug_micr_adj_path', str, '../mdd/adj/microbe_drug_adj.txt'),
    ('-drug_struct_simi_path', str, '../mdd/drug/drug_struct_simi.txt'),
    ('-drug_inter_adj_path', str, '../mdd/drug/drug_interact_adj.txt'),
    ('-drug_fringer_simi_path', str, '../mdd/drug/drug_fringer_simi.txt'),
    ('-drug_dise_adj_path', str, '../mdd/adj/drug_disease_adj.txt'),
    ('-micr_ani_path', str, '../mdd/microbe/microbe_ani_simi.txt'),
    ('-microbe_gene_simi_path', str, '../mdd/microbe/microbe_gene_simi.txt'),
    ('-micr_inter_adj_path', str, '../mdd/microbe/microbe_interact_adj.txt'),
    ('-micr_dise_adj_path', str, '../mdd/adj/microbe_disease_adj.txt'),
    ('-dise_simi_path', str, '../mdd/disease/disease_dag_simi.txt'),
    # data split
    ('-train_ratio', float, 0.8),
    ('-valid_ratio', float, 0.1),
    ('-test_ratio', float, 0.1),
    ('-kg_file', str, 'kg_data/'),
    # model size
    ('-embed_dim', int, 128),
    ('-gcn_layer_num', int, 3),
    ('-layer1_hidden_units', int, 128),
    ('-layer2_hidden_units', int, 64),
    ('-layer3_hidden_units', int, 32),
    # optimisation and regularisation
    ('-epochs', int, 300),
    ('-patience', int, 6),
    ('-lr', float, 5e-3),
    ('-weight_decay', float, 1e-4),
    ('-lambda1', float, 2 ** (-3)),
    ('-lambda2', float, 2 ** (-4)),
    ('-h1_gamma', float, 2 ** (-5)),
    ('-h2_gamma', float, 2 ** (-3)),
    ('-h3_gamma', float, 2 ** (-3)),
    ('-threshold', float, 0.8),
    # runtime and output locations
    ('-device', str, 'cpu'),
    ('-pt_file', str, 'checkpoint/'),
    ('-memo_file4mkgcn', str, 'memo/mkgcn.txt'),
    ('-pt_file_name', str, 'mkgcn.pt'),
    ('-test_result_file', str, 'result/mkgcn_test_result.txt'),
]
for flag, arg_type, default_value in ARG_SPECS:
    parser.add_argument(flag, type=arg_type, default=default_value)

# Parse an empty argv so the notebook always runs with the defaults above.
params = parser.parse_args([])

In [None]:
class dataloader(object):
    """Load the drug / microbe / disease data and build the inputs for MKGCN.

    On construction it:
      * loads the drug-microbe, drug-disease and drug-drug 0/1 matrices and the
        two drug similarity matrices (``load_drug_data``);
      * randomly splits every (drug, microbe) pair into train / valid / test
        (``split_dataset``);
      * zeroes the valid/test entries of the association matrix
        (``drug_micr_asso_mat_zy``);
      * loads the microbe and disease data plus derived similarities;
      * assembles the heterogeneous graph ``hete_graph_mat`` (drug block first,
        microbe block second), its edge list ``hete_graph_idx`` (entries
        >= ``params.threshold``) and the node feature matrix ``fea``.

    Args:
        params: parsed arguments (file paths, entity counts, split ratios,
            ``threshold``, ``kg_file``).
    """

    def __init__(self, params):
        super().__init__()
        self.params = params
        (self.drug_micr_asso_mat, self.drug_dise_asso_mat, self.drug_inter_mat,
         self.drug_struct_simi_mat, self.drug_fringer_simi_mat) = self.load_drug_data()
        self.train_xy, self.valid_xy, self.test_xy = self.split_dataset()
        # Training view of the associations: valid/test pairs are zeroed out.
        self.drug_micr_asso_mat_zy = self.get_asso_mat_zy()
        (self.micr_ani_mat, self.micr_inter_mat, self.micr_dise_asso_mat,
         self.micr_asso_simi_mat, self.micr_gene_simi_mat) = self.load_micr_data()
        self.dise_simi_mat, self.drug_dise_drug_simi_mat, self.micr_dise_micr_simi_mat = self.load_dise_data()
        # Prefer the ANI similarity where it is positive, else fall back to gene similarity.
        self.micr_inte_simi_mat = torch.where(self.micr_ani_mat > 0, self.micr_ani_mat, self.micr_gene_simi_mat)
        # Block matrix [[drug structure sim, masked drug-microbe],
        #               [masked microbe-drug, microbe association sim]].
        self.hete_graph_mat = torch.cat([
            torch.cat([self.drug_struct_simi_mat, self.drug_micr_asso_mat_zy], dim=1),
            torch.cat([self.drug_micr_asso_mat_zy.T, self.micr_asso_simi_mat], dim=1),
        ], dim=0).float()
        # Edges are all (row, col) positions whose weight reaches the threshold.
        self.hete_graph_idx = torch.nonzero(self.hete_graph_mat >= self.params.threshold).to(torch.long)
        # Node features: the masked bipartite adjacency, with zero diagonal blocks.
        n_drug, n_micr = self.params.num_drug, self.params.num_micr
        self.fea = torch.cat((
            torch.cat((torch.zeros((n_drug, n_drug)), self.drug_micr_asso_mat_zy), dim=1),
            torch.cat((self.drug_micr_asso_mat_zy.T, torch.zeros((n_micr, n_micr))), dim=1),
        ), dim=0)
        self.introduce()

    # @introduce data
    def introduce(self):
        """Print the number of known associations / interactions of each kind."""
        print(f'Drug microbe association num: {self.drug_micr_asso_mat.sum()}\nDrug interaction num: {self.drug_inter_mat.sum()}\nMicrobe interaction num: {self.micr_inter_mat.sum()}')
        print(f'Drug disease association num: {self.drug_dise_asso_mat.sum()}\nMicrobe disease association num: {self.micr_dise_asso_mat.sum()}')

    # @ mask
    def get_asso_mat_zy(self):
        """Return a copy of the drug-microbe matrix with every valid/test pair set to 0."""
        asso_mat_zy = self.drug_micr_asso_mat.clone()
        asso_mat_zy[self.valid_xy[:, 0], self.valid_xy[:, 1]] = 0
        asso_mat_zy[self.test_xy[:, 0], self.test_xy[:, 1]] = 0
        return asso_mat_zy

    # @create knowledge graph data
    def create_kg_data(self):
        """Write (head, tail, relation) triples to ``<kg_file>/{train,valid,test}.txt``.

        Entity ids: drugs keep their index, microbes are offset by ``num_drug``,
        diseases by ``num_drug + num_micr``. Relation ids: 0 drug-microbe,
        1 drug-disease, 2 microbe-disease, 3 drug-drug interaction,
        4 microbe-microbe interaction. Only positive drug-microbe pairs of each
        split are written; train also receives every disease / interaction edge.
        """
        p = self.params
        os.makedirs(p.kg_file, exist_ok=True)
        micr_off, dise_off = p.num_drug, p.num_drug + p.num_micr

        def positives(xy):
            # Keep the split's pairs whose true (unmasked) label is 1.
            return xy[self.drug_micr_asso_mat[xy[:, 0], xy[:, 1]] == 1]

        def triples(pairs, head_off, tail_off, rel):
            # Shift (head, tail) into the global entity id space and append the relation id.
            shifted = pairs + torch.tensor([head_off, tail_off])
            rel_col = torch.full((shifted.shape[0], 1), rel, dtype=shifted.dtype)
            return torch.cat([shifted, rel_col], dim=1)

        train = torch.cat([
            triples(positives(self.train_xy), 0, micr_off, 0),
            triples(self.drug_dise_asso_mat.nonzero(), 0, dise_off, 1),
            triples(self.micr_dise_asso_mat.nonzero(), micr_off, dise_off, 2),
            triples(self.drug_inter_mat.nonzero(), 0, 0, 3),
            triples(self.micr_inter_mat.nonzero(), micr_off, micr_off, 4),
        ], dim=0)
        val = triples(positives(self.valid_xy), 0, micr_off, 0)
        test = triples(positives(self.test_xy), 0, micr_off, 0)
        # Use self.params (not a notebook-global `params`) so the method works on its own.
        for name, data in (('train.txt', train), ('valid.txt', val), ('test.txt', test)):
            np.savetxt(os.path.join(p.kg_file, name), data.numpy(), fmt='%d', delimiter='\t', encoding='utf-8-sig')
        print(f'Knowledge graph data has prepared...')

    # @split data set
    def split_dataset(self):
        """Randomly assign every (drug, microbe) pair to train / valid / test.

        The first known association of each drug always goes to train; every
        other pair draws one ``torch.rand(1)`` and goes to train
        (< train_ratio), valid (< train_ratio + valid_ratio) or test. The draw
        order is part of the seeded behaviour, so the loop stays scalar.

        Returns:
            (train_xy, valid_xy, test_xy): int64 tensors of shape (N, 2) with
            [drug, microbe] index pairs; an empty split has shape (0, 2) so that
            ``xy[:, 0]`` indexing downstream still works.
        """
        train_xy, valid_xy, test_xy = [], [], []
        for i in range(self.params.num_drug):
            first = True
            for j in range(self.params.num_micr):
                if self.drug_micr_asso_mat[i, j] == 1 and first:
                    train_xy.append([i, j])
                    first = False
                    continue
                num = torch.rand(1)
                if num < self.params.train_ratio:
                    train_xy.append([i, j])
                elif num < self.params.train_ratio + self.params.valid_ratio:
                    valid_xy.append([i, j])
                else:
                    test_xy.append([i, j])
        print(f'Spliting data has finished...')

        def as_pairs(rows):
            return torch.tensor(rows, dtype=torch.long).reshape(-1, 2)

        return as_pairs(train_xy), as_pairs(valid_xy), as_pairs(test_xy)

    @staticmethod
    def _row_cosine_similarity(mat):
        """Cosine similarity between all pairs of rows of ``mat`` (all-zero rows give 0)."""
        unit_rows = nn.functional.normalize(mat, p=2, dim=1)
        return torch.matmul(unit_rows, unit_rows.T)

    @staticmethod
    def _fill_leading_diagonal(mat, n, value):
        """Set ``mat[i, i] = value`` for ``i < n`` in place and return ``mat``."""
        idx = torch.arange(n)
        mat[idx, idx] = value
        return mat

    # @load disease data
    def load_dise_data(self):
        """Load the disease similarity matrix and two derived similarity matrices.

        Returns:
            (dise_simi_mat, drug_dise_drug_simi_mat, micr_dise_micr_simi_mat);
            the last two are row cosine similarities with a unit diagonal.
        """
        dise_simi_mat = torch.from_numpy(np.loadtxt(self.params.dise_simi_path, encoding='utf-8-sig'))
        # NOTE(review): despite their names, these are computed from the drug-drug /
        # microbe-microbe *interaction* matrices, not from the disease associations.
        # Kept unchanged; confirm whether drug_dise_asso_mat / micr_dise_asso_mat was meant.
        drug_dise_drug_simi_mat = self._fill_leading_diagonal(
            self._row_cosine_similarity(self.drug_inter_mat), self.params.num_drug, 1.0)
        micr_dise_micr_simi_mat = self._fill_leading_diagonal(
            self._row_cosine_similarity(self.micr_inter_mat), self.params.num_micr, 1.0)
        return dise_simi_mat, drug_dise_drug_simi_mat, micr_dise_micr_simi_mat

    # @load micr data
    def load_micr_data(self):
        """Load microbe similarities / adjacencies and the masked association similarity.

        Returns:
            (micr_ani_mat, micr_inter_mat, micr_dise_mat, micr_asso_simi_mat,
            micr_gene_simi_mat). ``micr_asso_simi_mat`` is the cosine similarity
            of the microbes' masked drug profiles; it and ``micr_ani_mat`` get a
            unit diagonal.
        """
        micr_ani_mat = torch.from_numpy(np.loadtxt(self.params.micr_ani_path, encoding='utf-8-sig'))
        # (x + 1) / 2 rescaling; presumably maps a [-1, 1] similarity into [0, 1] — confirm.
        micr_gene_simi_mat = (torch.from_numpy(np.loadtxt(self.params.microbe_gene_simi_path, encoding='utf-8-sig')) + 1) / 2
        micr_inter_mat = self.load_adj_data(self.params.micr_inter_adj_path, sp=(self.params.num_micr, self.params.num_micr))
        micr_dise_mat = self.load_adj_data(self.params.micr_dise_adj_path, sp=(self.params.num_micr, self.params.num_dise))
        # Uses the masked matrix so valid/test associations do not leak into the graph.
        micr_asso_simi_mat = self._row_cosine_similarity(self.drug_micr_asso_mat_zy.T)
        self._fill_leading_diagonal(micr_asso_simi_mat, self.params.num_micr, 1)
        self._fill_leading_diagonal(micr_ani_mat, self.params.num_micr, 1)
        return micr_ani_mat, micr_inter_mat, micr_dise_mat, micr_asso_simi_mat, micr_gene_simi_mat

    # @load drug data
    def load_drug_data(self):
        """Load drug adjacencies and similarity matrices.

        Returns:
            (drug_micr_asso_mat, drug_dise_asso_mat, drug_inter_mat,
            drug_struct_simi_mat, drug_fringer_simi_mat).
        """
        drug_dise_asso_mat = self.load_adj_data(self.params.drug_dise_adj_path, sp=(self.params.num_drug, self.params.num_dise))
        drug_micr_asso_mat = self.load_adj_data(self.params.drug_micr_adj_path, sp=(self.params.num_drug, self.params.num_micr))
        drug_inter_mat = self.load_adj_data(self.params.drug_inter_adj_path, sp=(self.params.num_drug, self.params.num_drug))
        drug_struct_simi_mat = torch.from_numpy(np.loadtxt(self.params.drug_struct_simi_path, encoding='utf-8-sig'))
        drug_fringer_simi_mat = torch.from_numpy(np.loadtxt(self.params.drug_fringer_simi_path, encoding='utf-8-sig'))
        return drug_micr_asso_mat, drug_dise_asso_mat, drug_inter_mat, drug_struct_simi_mat, drug_fringer_simi_mat

    # @load adj data
    def load_adj_data(self, path, sp=(1209, 172)):
        """Build a dense 0/1 matrix of shape ``sp`` from a file of 1-based (row, col) pairs.

        ``ndmin=2`` keeps a file holding a single pair as a (1, 2) array instead
        of a 1-D one, which would break the column indexing below.
        """
        idx = torch.from_numpy(np.loadtxt(path, encoding='utf-8-sig', ndmin=2)).long() - 1
        mat = torch.zeros((sp[0], sp[1]))
        mat[idx[:, 0], idx[:, 1]] = 1
        return mat

    # @NRWR, neighborhood random walk with restart
    def random_walk_root(self, A, alpha, epoch, nei, eps, err):
        """Random walk with restart on ``A``, masked by ``A**nei - A**(nei-1) > 0``.

        Iterates ``S = alpha * S @ W + (1 - alpha) * I`` (W = row-L1-normalised A),
        masking and row-normalising each step, for at most ``epoch`` steps or
        until the total absolute change is <= ``err * n**2``. ``eps`` is unused.
        With ``epoch == 0`` the initial identity is returned instead of raising.
        """
        Mask_mat = (torch.matrix_power(A, nei) - torch.matrix_power(A, nei - 1)) > 0
        W = nn.functional.normalize(A, p=1, dim=1)
        S_last = torch.eye(A.shape[0])
        S_new = S_last
        for i in range(epoch):
            S_new = alpha * torch.matmul(S_last, W) + (1 - alpha) * torch.eye(A.shape[0])
            S_new = nn.functional.normalize(Mask_mat * S_new, p=1, dim=1)
            if (S_new - S_last).abs().sum() <= err * A.shape[0] ** 2:
                print('converge...')
                break
            S_last = S_new
        return S_new

    # write into memo
    def write2memo(self, mr, mrr, hits10):
        """Append a tab-separated line (lr, weight decay, MR, MRR, Hits@10) to the memo file.

        NOTE(review): reads ``memo_file4kg``, ``lr_kg`` and ``weight_decay_kg`` from
        params, which this notebook's parser does not define — calling it as-is
        raises AttributeError.
        """
        with open(f'{self.params.memo_file4kg}', 'a+') as f:
            f.write(f'{self.params.lr_kg}\t{self.params.weight_decay_kg}\t{mr}\t{mrr}\t{hits10}\n')

In [None]:
class Loss4MKGCN(nn.Module):
    """Summed squared error plus two weighted Laplacian trace regularisers.

    loss = sum((preds - labels) ** 2)
           + lambda1 * tr(alpha1^T @ lnc_lap @ alpha1)
           + lambda2 * tr(alpha2^T @ dis_lap @ alpha2)

    Args:
        args: object exposing ``lambda1`` and ``lambda2``.
    """

    def __init__(self, args):
        super().__init__()
        self.lambda1 = args.lambda1
        self.lambda2 = args.lambda2

    @staticmethod
    def _laplacian_trace(alpha, lap):
        """Return tr(alpha^T @ lap @ alpha)."""
        return torch.trace(alpha.T @ lap @ alpha)

    def forward(self, labels, preds, lnc_lap, dis_lap, alpha1, alpha2):
        squared_error = (preds - labels).pow(2).sum()
        first_penalty = self._laplacian_trace(alpha1, lnc_lap)
        second_penalty = self._laplacian_trace(alpha2, dis_lap)
        regulariser = (self.lambda1 * first_penalty + self.lambda2 * second_penalty).sum()
        return squared_error + regulariser

In [None]:
def normalized_embedding(embeddings):
    """Min-max scale each row of a 2-D tensor.

    Every row ``x`` becomes ``(x - min(x) + 1e-15) / (max(x) - min(x) + 1e-15)``.
    The 1e-15 keeps constant rows finite; they map to all ones.

    Vectorised with per-row ``min``/``max`` instead of a Python loop over rows
    that called builtin ``min``/``max``, which iterated every element in Python.
    The result keeps the input's floating dtype and remains differentiable
    w.r.t. ``embeddings``.

    Args:
        embeddings: 2-D tensor; each row is scaled independently.

    Returns:
        Tensor with the same shape as ``embeddings``.
    """
    if embeddings.dim() != 2:
        raise ValueError(f'normalized_embedding expects a 2-D tensor, got {embeddings.dim()}-D')
    row_min = embeddings.min(dim=1, keepdim=True).values
    row_max = embeddings.max(dim=1, keepdim=True).values
    return (embeddings - row_min + 1e-15) / (row_max - row_min + 1e-15)

In [None]:
def getGipKernel(y, trans, gamma, normalized=False):
    """Gaussian kernel over the rows (or columns) of ``y``.

    Args:
        y: 2-D tensor whose rows are the profiles (columns if ``trans`` is truthy).
        trans: when truthy, work on ``y.T`` so that columns become the profiles.
        gamma: bandwidth multiplier applied to the pairwise distances.
        normalized: when True, min-max scale each profile with
            ``normalized_embedding`` first.

    Returns:
        Square tensor ``exp(-d * gamma)``. Here ``d`` holds the pairwise squared
        distances implied by the Gram matrix of the profiles, after the Gram
        matrix has been divided by the mean of its diagonal.
    """
    profiles = y.T if trans else y
    if normalized:
        profiles = normalized_embedding(profiles)
    gram = profiles @ profiles.T
    gram = gram / gram.diagonal().mean()
    distance = kernelToDistance(gram)
    return torch.exp(-distance * gamma)

In [None]:
def normalized_kernel(K):
    """Cosine-normalise a kernel matrix: ``S_ij = K_ij / sqrt(K_ii * K_jj)``.

    Absolute values are taken first. Exact zeros are then replaced by the
    smallest positive entry, so every diagonal entry (and its square root) is
    positive.

    Bug fix: the previous implementation computed ``K / (D * D.T)`` with a 1-D
    ``D``. ``.T`` is a no-op on 1-D tensors, so each column j was divided by
    ``K_jj`` instead of ``sqrt(K_ii * K_jj)``, and the result was not symmetric.
    The outer product is now formed explicitly. This changes numerical results
    relative to the earlier code.

    Args:
        K: square 2-D tensor; it is not modified.

    Returns:
        Symmetric (for symmetric ``K``) tensor with a unit diagonal.

    Raises:
        ValueError: if ``K`` has no non-zero entry.
    """
    K = K.abs()
    positive = K[K > 0]
    if positive.numel() == 0:
        raise ValueError('normalized_kernel: kernel has no non-zero entries')
    min_v = positive.min()
    # Out-of-place replacement keeps autograd happy and avoids a full sort.
    K = torch.where(K == 0, min_v, K)
    d = K.diagonal().sqrt()
    return K / (d[:, None] * d[None, :])

In [None]:
def kernelToDistance(k):
    """Pairwise squared distances implied by a square kernel matrix.

    ``d_ij = k_ii + k_jj - 2 * k_ij``. Built with broadcasting instead of two
    ``repeat``-ed n x n copies and the deprecated ``.T`` on a 1-D tensor.

    Args:
        k: square 2-D tensor.

    Returns:
        Tensor of the same shape as ``k``.
    """
    diag = torch.diag(k)
    return diag[:, None] + diag[None, :] - 2 * k

In [None]:
def laplacian(kernel):
    """Symmetrically normalised Laplacian ``D^{-1/2} (D - K) D^{-1/2}`` of a kernel.

    The degree of node j is the sum of column j (``kernel.sum(dim=0)``).
    Zero-degree nodes get 0 instead of an infinite ``D^{-1/2}`` entry. Scaling
    uses broadcasting rather than the previous Python-level ``sum`` over rows
    and two dense matmuls against diagonal matrices.

    Args:
        kernel: square 2-D tensor.

    Returns:
        Tensor of the same shape as ``kernel``.
    """
    # Column sums give the degree vector.
    degree = kernel.sum(dim=0)
    # Unnormalised Laplacian D - K.
    lap = torch.diag(degree) - kernel
    # D^{-1/2}; zero degrees would give inf, so map those to 0.
    inv_sqrt = degree.rsqrt()
    inv_sqrt = torch.where(torch.isinf(inv_sqrt), torch.zeros_like(inv_sqrt), inv_sqrt)
    # Row scaling then column scaling == D^{-1/2} @ lap @ D^{-1/2}.
    return inv_sqrt[:, None] * lap * inv_sqrt[None, :]

In [None]:
class MKGCN(nn.Module):
    """Three GCN layers on the drug-microbe graph, combined via multi-kernel learning.

    Constructor args: params object; drug similarity; microbe similarity;
    heterogeneous-graph edge index; edge weights; node feature matrix.
      * params: parsed arguments (entity counts, hidden units, lambdas, GIP gammas,
        ``gcn_layer_num``).
      * drug_sim / micr_sim: similarity matrices mixed in as the last kernel of
        each side.
      * hete_ass_idx: edge list of shape (E, 2); it is transposed before being
        passed to ``GCNConv``.
      * hete_ass_weight: one weight per edge.
      * fea: node features; ``gcn1`` expects ``num_drug + num_micr`` input channels.

    ``forward`` stores the combined kernels (``kernel4drug`` / ``kernel4micro``)
    and their Laplacians on the module. The training loop reads them.
    """

    def __init__(self, params, drug_sim, micr_sim, hete_ass_idx, hete_ass_weight, fea):
        super().__init__()
        self.params = params
        self.drug_nums = params.num_drug
        self.micro_nums = params.num_micr
        self.gcn1_units = params.layer1_hidden_units
        self.gcn2_units = params.layer2_hidden_units
        self.gcn3_units = params.layer3_hidden_units
        self.lambda1 = params.lambda1
        self.lambda2 = params.lambda2
        self.h1_gamma = params.h1_gamma
        self.h2_gamma = params.h2_gamma
        self.h3_gamma = params.h3_gamma
        # One kernel per GCN layer plus the precomputed similarity, equally weighted.
        self.kernel_size = params.gcn_layer_num + 1
        self.drug_sim = drug_sim
        self.micro_sim = micr_sim
        self.drug_kernel_weights = torch.ones(self.kernel_size) / self.kernel_size
        self.micro_kernel_weights = torch.ones(self.kernel_size) / self.kernel_size
        # Construction order (gcn1, gcn2, gcn3, then alpha1, alpha2) fixes the RNG
        # stream and the parameter order seen by the optimizer.
        self.gcn1 = conv.GCNConv(self.micro_nums + self.drug_nums, self.gcn1_units)
        self.gcn2 = conv.GCNConv(self.gcn1_units, self.gcn2_units)
        self.gcn3 = conv.GCNConv(self.gcn2_units, self.gcn3_units)
        self.alpha1 = torch.randn(self.drug_nums, self.micro_nums)
        self.alpha2 = torch.randn(self.micro_nums, self.drug_nums)
        self.relu = nn.ReLU()
        self.hete_ass_idx = hete_ass_idx
        self.hete_ass_weight = hete_ass_weight
        self.fea = fea
        self.lap_drug_kernel = []
        self.lap_micro_kernel = []
        self.kernel4drug = torch.zeros((self.drug_nums, self.drug_nums))
        self.kernel4micro = torch.zeros((self.micro_nums, self.micro_nums))

    def forward(self, left, right):
        """Return the predicted scores at positions (left[k], right[k]) of the drug x microbe matrix."""
        edge_index = self.hete_ass_idx.T
        n_drug = self.drug_nums
        layers = (self.gcn1, self.gcn2, self.gcn3)
        gammas = (self.h1_gamma, self.h2_gamma, self.h3_gamma)

        hidden = self.fea
        drug_kernel = micro_kernel = None
        for depth, (layer, gamma) in enumerate(zip(layers, gammas)):
            hidden = self.relu(layer(hidden, edge_index, self.hete_ass_weight))
            # Rows [0, n_drug) are drugs, the remaining rows are microbes.
            drug_term = self.drug_kernel_weights[depth] * getGipKernel(hidden[:n_drug], 0, gamma, True)
            micro_term = self.micro_kernel_weights[depth] * getGipKernel(hidden[n_drug:], 0, gamma, True)
            if depth == 0:
                drug_kernel, micro_kernel = drug_term, micro_term
            else:
                drug_kernel += drug_term
                micro_kernel += micro_term

        # In-place add keeps the GCN kernels' dtype even if the similarity matrix
        # has a wider one (e.g. float64 loaded from text).
        last = len(layers)
        drug_kernel += self.drug_kernel_weights[last] * self.drug_sim
        micro_kernel += self.micro_kernel_weights[last] * self.micro_sim

        # Normalise the mixed kernels, then build their Laplacians.
        self.kernel4drug = normalized_kernel(drug_kernel)
        self.kernel4micro = normalized_kernel(micro_kernel)
        self.lap_drug_kernel = laplacian(self.kernel4drug)
        self.lap_micro_kernel = laplacian(self.kernel4micro)

        # Average of the drug-side and microbe-side reconstructions.
        scores = (torch.matmul(self.kernel4drug, self.alpha1) + torch.matmul(self.alpha2.T, self.kernel4micro)) / 2
        return scores[left, right]

In [None]:
def avg_auc_aupr_cpt(test_xy, test_label, pred, ass_mat_shape):
    """Mean per-row ROC-AUC and AUPR over the evaluated pairs.

    The labels and scores are scattered into matrices of shape
    ``ass_mat_shape``; unevaluated cells hold -1. Each row is scored only if
    it has at least one evaluated positive and one evaluated negative. The
    last precision value from ``precision_recall_curve`` is forced to 0 when
    the one before it is 0, and to 1 otherwise.

    Args:
        test_xy: (N, 2) integer tensor of (row, col) positions.
        test_label: N true labels (0/1).
        pred: N predicted scores.
        ass_mat_shape: shape of the full association matrix.

    Returns:
        (mean_auc, mean_aupr); both are NaN when no row qualifies.
    """
    rows, cols = test_xy[:, 0], test_xy[:, 1]
    label_mat = torch.full(ass_mat_shape, -1.0)
    pred_mat = torch.full(ass_mat_shape, -1.0)
    label_mat[rows, cols] = test_label * 1.0
    pred_mat[rows, cols] = pred
    evaluated = label_mat != -1

    row_aucs, row_auprs = [], []
    for labels_row, preds_row, mask in zip(label_mat, pred_mat, evaluated):
        y_true, y_score = labels_row[mask], preds_row[mask]
        n_pos = y_true.sum()
        if n_pos > 0 and mask.sum() - n_pos > 0:
            fpr, tpr, _ = roc_curve(y_true, y_score)
            precision, recall, _ = precision_recall_curve(y_true, y_score)
            precision[-1] = 0 if precision[-2] == 0 else 1
            row_aucs.append(auc(fpr, tpr))
            row_auprs.append(auc(recall, precision))
    return np.mean(row_aucs), np.mean(row_auprs)

In [None]:
class EarlyStopping:
    """Track a loss and raise ``flag`` after ``patience`` calls without improvement.

    A call improves on the best score unless ``-val_loss <= best_score - eps``.
    Every improvement (and the very first call) saves the whole model with
    ``torch.save`` to ``<pt_file>//<file_name>``.

    Args:
        patience: number of non-improving calls tolerated before ``flag`` is set.
        pt_file: checkpoint directory; created if it does not exist.
        file_name: checkpoint file name inside ``pt_file``.
        mess_out: print the counter on every non-improving call.
        eps: tolerance subtracted from the best score before comparing.
    """

    def __init__(self, patience, pt_file= 'checkpoint/', file_name= 'checkpoint.pt', mess_out= True, eps= 0):
        super().__init__()
        self.patience = patience
        self.eps = eps
        self.pt_file = pt_file
        self.file_name = file_name
        self.mess_out = mess_out
        self.best_score = None
        self.counter = 0
        self.flag = False
        if not os.path.exists(self.pt_file):
            os.makedirs(self.pt_file)

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is not None and score <= self.best_score - self.eps:
            self.counter += 1
            if self.mess_out:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.flag = True
            return
        # First call or an improvement: remember it and checkpoint the model.
        self.best_score = score
        self.save_checkpoint(model)
        self.counter = 0

    def save_checkpoint(self, model):
        """Pickle the whole ``model`` object to the checkpoint path."""
        target = f'{self.pt_file}//{self.file_name}'
        torch.save(model, target)

In [None]:
# --- Reproducibility ----------------------------------------------------------
# Seed every RNG before building the data split (torch.rand in split_dataset)
# and the model weights.
random.seed(params.seed)
np.random.seed(params.seed)
torch.manual_seed(params.seed)
torch.cuda.manual_seed(params.seed)
torch.cuda.manual_seed_all(params.seed)
# NOTE(review): PYTHONHASHSEED set at runtime only affects child processes,
# not the already-running kernel.
os.environ['PYTHONHASHSEED'] = str(params.seed)
torch.backends.cudnn.deterministic = True

# Shape of the drug x microbe association matrix (was hard-coded as (1209, 172)).
ASSO_SHAPE = (params.num_drug, params.num_micr)

# --- Data ---------------------------------------------------------------------
dl = dataloader(params)
# Edge weights of the heterogeneous graph, aligned with dl.hete_graph_idx.
hete_graph_weight = dl.hete_graph_mat[dl.hete_graph_idx[:, 0], dl.hete_graph_idx[:, 1]].clone()
train_xy, valid_xy, test_xy = dl.train_xy.clone(), dl.valid_xy.clone(), dl.test_xy.clone()
# Labels are read from the unmasked association matrix.
train_label = dl.drug_micr_asso_mat[train_xy[:, 0], train_xy[:, 1]]
valid_label = dl.drug_micr_asso_mat[valid_xy[:, 0], valid_xy[:, 1]]
test_label = dl.drug_micr_asso_mat[test_xy[:, 0], test_xy[:, 1]]

# --- Model --------------------------------------------------------------------
# NOTE(review): .to(device) only moves registered parameters/buffers; MKGCN keeps
# alpha1/alpha2, fea and the graph tensors as plain attributes, so non-CPU
# devices likely need extra work.
net = MKGCN(params, dl.drug_struct_simi_mat.clone(), dl.micr_asso_simi_mat.clone(),
            dl.hete_graph_idx.clone(), hete_graph_weight.clone(), dl.fea.clone()).to(params.device)
reg_mse = Loss4MKGCN(params).to(params.device)
optimizer = torch.optim.Adam(net.parameters(), lr=params.lr, weight_decay=params.weight_decay)
earlystopping = EarlyStopping(patience=params.patience, pt_file=params.pt_file, file_name=params.pt_file_name, mess_out=True)

# --- Training -----------------------------------------------------------------
pred = []
for ep in range(params.epochs):
    net.train()
    time_start = time.time()
    logp = net(train_xy[:, 0], train_xy[:, 1])
    loss = reg_mse(train_label, logp, net.lap_drug_kernel, net.lap_micro_kernel, net.alpha1, net.alpha2)
    # Closed-form refresh of alpha1 / alpha2 from this pass's kernels (alpha2
    # uses the freshly updated alpha1). They need no autograd graph; the loss
    # above was already computed with the previous alphas.
    with torch.no_grad():
        net.alpha1 = torch.matmul(
            torch.matmul((torch.matmul(net.kernel4drug, net.kernel4drug) + net.lambda1 * net.lap_drug_kernel).inverse(), net.kernel4drug),
            2 * dl.drug_micr_asso_mat_zy - torch.matmul(net.alpha2.T, net.kernel4micro.T)).detach()
        net.alpha2 = torch.mm(
            torch.mm((torch.mm(net.kernel4micro, net.kernel4micro) + net.lambda2 * net.lap_micro_kernel).inverse(), net.kernel4micro),
            2 * dl.drug_micr_asso_mat_zy.T - torch.mm(net.alpha1.T, net.kernel4drug.T)).detach()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    time_end = time.time()
    print(f'epoch: {ep+ 1}, train loss: {loss.item()}, time: {time_end- time_start}')

    # Validation
    net.eval()
    with torch.no_grad():
        logp = net(valid_xy[:, 0], valid_xy[:, 1])
        val_loss = reg_mse(valid_label, logp, net.lap_drug_kernel, net.lap_micro_kernel, net.alpha1, net.alpha2).item()
    pred = logp.cpu()
    roc_auc, aupr_auc = avg_auc_aupr_cpt(valid_xy, valid_label, pred, ASSO_SHAPE)
    print(f'epoch: {ep+ 1}, valid loss: {val_loss}, auc: {roc_auc}, aupr: {aupr_auc}')
    # EarlyStopping negates its argument, so this monitors auc + aupr (higher is better).
    earlystopping(-(roc_auc + aupr_auc), net)
    if earlystopping.flag:
        print('early_stopping')
        break

# --- Reload the best checkpoint and re-score the validation pairs --------------
ckpt_path = f'{params.pt_file}//{params.pt_file_name}'
try:
    # The checkpoint is a whole pickled nn.Module written locally by EarlyStopping
    # (trusted). torch>=2.6 defaults to weights_only=True, which cannot load it.
    net = torch.load(ckpt_path, weights_only=False)
except TypeError:
    # torch<1.13 has no weights_only argument.
    net = torch.load(ckpt_path)
net.eval()
with torch.no_grad():
    pred = net(valid_xy[:, 0], valid_xy[:, 1])
roc_auc, roc_aupr = avg_auc_aupr_cpt(valid_xy, valid_label, pred, ASSO_SHAPE)
# with open(file= params.memo_file4mkgcn, mode= 'a') as f:
    # f.write(f'{params.lr}\t{params.weight_decay}\t{roc_auc}\t{roc_aupr}\n')
print(f'{params.lr}\t{params.weight_decay}\t{roc_auc}\t{roc_aupr}\n')

In [None]:
# Reload the best checkpoint and score the held-out test pairs.
ckpt_path = f'{params.pt_file}//{params.pt_file_name}'
try:
    # Whole pickled nn.Module written locally by EarlyStopping (trusted);
    # torch>=2.6 defaults to weights_only=True, which cannot load it.
    net = torch.load(ckpt_path, weights_only=False)
except TypeError:
    # torch<1.13 has no weights_only argument.
    net = torch.load(ckpt_path)
net.eval()
with torch.no_grad():
    pred = net(test_xy[:, 0], test_xy[:, 1])
# One row per test pair: drug index, microbe index, true label, predicted score.
results = torch.cat([dl.test_xy.to('cpu'), test_label.view(-1, 1).to('cpu'), pred.view(-1, 1).to('cpu')], dim=1)
# Matrix shape comes from params instead of the previously hard-coded (1209, 172).
print(avg_auc_aupr_cpt(dl.test_xy.to('cpu'), test_label.to('cpu'), pred.to('cpu'), (params.num_drug, params.num_micr)))
# np.savetxt(fname= params.test_result_file, X= results, delimiter= '\t', encoding= 'utf-8-sig')