In [None]:
import os
import time
import torch
import random
import warnings
import argparse
import numpy as np
from torch_geometric.nn import conv
from sklearn.metrics import roc_curve, roc_auc_score, auc, precision_recall_curve, average_precision_score
from torch_geometric.nn import GINConv
import torch.nn as nn
from tqdm import tqdm
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from collections import defaultdict as ddict, Counter
warnings.filterwarnings('ignore')

In [None]:
# Command-line style configuration for the notebook. Every option is declared
# in one table as (flag, type, default) and registered in table order.
parser = argparse.ArgumentParser(description='Parser for Arguments')

_ARGUMENT_TABLE = (
    # reproducibility
    ('-seed', int, 0),
    # entity counts: num_ent is the sum of the three per-type counts below
    ('-num_ent', int, 1209 + 172 + 154),
    ('-num_drug', int, 1209),
    ('-num_micr', int, 172),
    ('-num_dise', int, 154),
    # entity name files
    ('-drug_name_path', str, '../mdd/drug/drug_name.txt'),
    ('-micr_name_path', str, '../mdd/microbe/microbe_name.txt'),
    ('-dise_name_path', str, '../mdd/disease/disease_name.txt'),
    # drug data files
    ('-drug_micr_adj_path', str, '../mdd/adj/microbe_drug_adj.txt'),
    ('-drug_struct_simi_path', str, '../mdd/drug/drug_struct_simi.txt'),
    ('-drug_inter_adj_path', str, '../mdd/drug/drug_interact_adj.txt'),
    ('-drug_fringer_simi_path', str, '../mdd/drug/drug_fringer_simi.txt'),
    ('-drug_dise_adj_path', str, '../mdd/adj/drug_disease_adj.txt'),
    # microbe data files
    ('-micr_ani_path', str, '../mdd/microbe/microbe_ani_simi.txt'),
    ('-micr_inter_adj_path', str, '../mdd/microbe/microbe_interact_adj.txt'),
    # NOTE(review): same default as -micr_gene_simi_path below -- confirm this is intended.
    ('-micr_embedding_path', str, '../mdd/microbe/microbe_gene_simi.txt'),
    ('-micr_dise_adj_path', str, '../mdd/adj/microbe_disease_adj.txt'),
    ('-micr_gene_simi_path', str, '../mdd/microbe/microbe_gene_simi.txt'),
    ('-micr_gene_embedding_path', str, '../mdd/microbe/microbe_emb.txt'),
    # disease data files
    ('-dise_simi_path', str, '../mdd/disease/disease_dag_simi.txt'),
    # data split
    ('-train_ratio', float, 0.8),
    ('-valid_ratio', float, 0.1),
    ('-test_ratio', float, 0.1),
    # graph construction and optimisation
    ('-batch_size', int, 128),
    ('-threshold', float, 0.8),
    ('-epochs', int, 500),
    ('-patience', int, 20),
    ('-lr', float, 5e-3),
    ('-weight_decay', float, 1e-4),
    # runtime and output locations
    ('-device', str, 'cuda:0'),
    ('-pt_file', str, 'checkpoint/'),
    ('-memo_file4gnaemda', str, 'memo/gin.txt'),
    ('-pt_file_name', str, 'gin.pt'),
    ('-test_result_file', str, 'result/gin_test_result.txt'),
)

for _flag, _kind, _default in _ARGUMENT_TABLE:
    parser.add_argument(_flag, type=_kind, default=_default)

# An empty argv is parsed so that the defaults above are always used inside the notebook.
params = parser.parse_args([])

In [None]:
# Seed every random number generator the notebook uses from the one configured seed.
_seed = params.seed
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed,
                  torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(_seed)
os.environ['PYTHONHASHSEED'] = str(_seed)
# Ask cuDNN to pick deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [None]:
class dataloader(object):
    """Load the drug / microbe / disease data files and assemble the model inputs.

    Construction reads every file named in ``params``, splits all drug-microbe
    pairs into train / valid / test coordinates, zeroes the valid and test
    pairs in a copy of the association matrix, and builds the node feature
    matrix ``fea``, the dense graph ``g`` and the thresholded ``edge_index`` /
    ``edge_weigh`` tensors.
    """
    def __init__(self, params):
        """Build every matrix eagerly.

        Args:
            params: namespace providing the ``num_drug`` / ``num_micr`` /
                ``num_dise`` sizes, the ``*_path`` file locations,
                ``train_ratio`` / ``valid_ratio`` and ``threshold``.
        """
        super().__init__()
        self.params= params
        self.drug_micr_asso_mat, self.drug_dise_asso_mat, self.drug_inter_mat, self.drug_struct_simi_mat, self.drug_fringer_simi_mat= self.load_drug_data()
        self.train_xy, self.valid_xy, self.test_xy= self.split_dataset()
        # Association matrix with the valid / test entries zeroed; must exist before load_micr_data() runs.
        self.drug_micr_asso_mat_zy= self.get_asso_mat_zy()
        self.micr_ani_mat, self.micr_inter_mat, self.micr_dise_asso_mat, self.micr_asso_simi_mat, self.micr_gene_simi_mat, self.micr_gene_simi_mat_g= self.load_micr_data()
        self.dise_simi_mat, self.drug_dise_drug_simi_mat, self.micr_dise_micr_simi_mat= self.load_dise_data()
        self.drug_inte_simi_mat= self.drug_struct_simi_mat
        self.drug_topo_fea, self.micr_topo_fea= self.random_walk_root(self.drug_inter_mat), self.random_walk_root(self.micr_inter_mat)
        # Use the ANI similarity where it is positive, otherwise the gene similarity.
        self.micr_inte_simi_mat= torch.where(self.micr_ani_mat> 0, self.micr_ani_mat, self.micr_gene_simi_mat)
        # Block matrix [[drug-drug, drug-microbe], [microbe-drug, microbe-microbe]]: drug rows come first.
        x_simi= torch.cat([torch.cat([self.drug_inte_simi_mat, self.drug_micr_asso_mat_zy], dim= 1),\
                           torch.cat([self.drug_micr_asso_mat_zy.T, self.micr_inte_simi_mat], dim= 1)], dim= 0)
        # x_seco= torch.cat([torch.cat([self.drug_topo_fea, torch.zeros((self.params.num_drug, self.params.num_micr))], dim= 1),\
                        #    torch.cat([torch.zeros((self.params.num_micr, self.params.num_drug)), self.micr_topo_fea], dim= 1)], dim= 0)
        # self.fea= torch.cat([x_simi, x_seco], dim= 1).float()
        self.fea= x_simi.float()
        self.g= torch.cat([torch.cat([self.drug_struct_simi_mat, self.drug_micr_asso_mat_zy], dim= 1),\
                                    torch.cat([self.drug_micr_asso_mat_zy.T, self.micr_inte_simi_mat], dim= 1)], dim= 0).float()
        # One (row, col) pair per entry of g that reaches the threshold, plus that entry's value.
        self.edge_index= (self.g>= params.threshold).nonzero()
        self.edge_weigh= self.g[self.edge_index[:, 0], self.edge_index[:, 1]]
        self.introduce()

    # @introduce data
    def introduce(self):
        """Print the number of non-zero entries of the loaded 0/1 matrices."""
        print(f'Drug microbe association num: {self.drug_micr_asso_mat.sum()}\nDrug interaction num: {self.drug_inter_mat.sum()}\nMicrobe interaction num: {self.micr_inter_mat.sum()}')
        print(f'Drug disease association num: {self.drug_dise_asso_mat.sum()}\nMicrobe disease association num: {self.micr_dise_asso_mat.sum()}')

    # @ mask
    def get_asso_mat_zy(self):
        """Return a copy of the drug-microbe association matrix with the valid and test pairs set to 0."""
        asso_mat_zy= self.drug_micr_asso_mat.clone()
        asso_mat_zy[self.valid_xy[:, 0], self.valid_xy[:, 1]]= 0
        asso_mat_zy[self.test_xy[:, 0], self.test_xy[:, 1]]= 0
        return asso_mat_zy

    def gauss_simi(self, mat, eps= 1e-15):
        """Return the Gaussian-kernel similarity between the rows of ``mat``.

        Args:
            mat: 2-D tensor, one sample per row.
            eps: unused by the active code path (only by a commented-out variant).

        Returns:
            Square tensor ``exp(-sigma * d)`` where ``d`` is the squared
            Euclidean distance between rows and ``sigma`` is the reciprocal of
            the mean squared row norm.
        """
        # Vectorised squared distances: |a|^2 + |b|^2 - 2 a.b
        sim_mat= torch.mul(mat, mat).sum(dim= 1, keepdims= True)+ torch.mul(mat, mat).sum(dim= 1, keepdims= True).T- 2* torch.matmul(mat, mat.T)
        # Gaussian kernel bandwidth
        sigma= 1/ torch.diag(torch.matmul(mat, mat.T)).mean()        
        # sigma= 1/ (sim_mat.mean()+ eps)
        # Gaussian-kernel similarity matrix
        sim_mat= torch.exp(-1* sigma* sim_mat)
        # MTK...
        # sim_mat= 1/ (1+ torch.exp(-15* sim_mat+ math.log(9999)))
        return sim_mat

    # @split data set
    def split_dataset(self):
        """Assign every (drug, microbe) pair to the train, valid or test split.

        For each drug, the first microbe whose association equals 1 always goes
        to train. Every other pair draws one uniform random number: below
        ``train_ratio`` it goes to train, below ``train_ratio + valid_ratio`` to
        valid, otherwise to test (``test_ratio`` itself is not read).

        Returns:
            Three tensors of ``[drug, microbe]`` index pairs (train, valid, test).
        """
        train_xy, valid_xy, test_xy= [], [], []
        for i in range(self.params.num_drug):
            first= True
            for j in range(self.params.num_micr):
                if self.drug_micr_asso_mat[i, j]== 1 and first:
                    train_xy.append([i, j])
                    first= False
                else:
                    num= torch.rand(1)
                    if num< self.params.train_ratio:
                        train_xy.append([i, j])
                    elif num>= self.params.train_ratio and num< self.params.train_ratio+ self.params.valid_ratio:
                        valid_xy.append([i, j])
                    else:
                        test_xy.append([i, j])        
        print(f'Spliting data has finished...')
        return torch.tensor(train_xy), torch.tensor(valid_xy), torch.tensor(test_xy)

    # @load disease data
    def load_dise_data(self):
        """Load the disease similarity file and build two cosine-similarity matrices.

        NOTE(review): despite their names, ``drug_dise_drug_simi_mat`` and
        ``micr_dise_micr_simi_mat`` are computed from the drug / microbe
        *interaction* matrices, not from disease associations -- confirm intent.

        Returns:
            ``(dise_simi_mat, drug_dise_drug_simi_mat, micr_dise_micr_simi_mat)``;
            the two cosine matrices have their diagonal forced to 1.
        """
        dise_simi_mat= torch.from_numpy(np.loadtxt(self.params.dise_simi_path, encoding= 'utf-8-sig'))
        drug_dise_drug_simi_mat, micr_dise_micr_simi_mat= torch.matmul(nn.functional.normalize(self.drug_inter_mat, p= 2, dim= 1), nn.functional.normalize(self.drug_inter_mat, p= 2, dim= 1).T),\
        torch.matmul(nn.functional.normalize(self.micr_inter_mat, p= 2, dim= 1), nn.functional.normalize(self.micr_inter_mat, p= 2, dim= 1).T)
        for i in range(self.params.num_drug): drug_dise_drug_simi_mat[i, i]= 1.0
        for i in range(self.params.num_micr): micr_dise_micr_simi_mat[i, i]= 1.0
        return dise_simi_mat, drug_dise_drug_simi_mat, micr_dise_micr_simi_mat

    # @load micr data
    def load_micr_data(self):
        """Load the microbe files and derive the microbe similarity matrices.

        Requires ``self.drug_micr_asso_mat_zy`` to be set already.

        Returns:
            ``(micr_ani_mat, micr_inter_mat, micr_dise_mat, micr_asso_simi_mat,
            micr_gene_simi_mat, micr_gene_simi_mat_g)``.
        """
        micr_ani_mat= torch.from_numpy(np.loadtxt(self.params.micr_ani_path, encoding= 'utf-8-sig'))
        micr_gene_embedding= torch.from_numpy(np.loadtxt(self.params.micr_gene_embedding_path, encoding= 'utf-8-sig'))
        micr_gene_simi_mat_g= self.gauss_simi(micr_gene_embedding)
        # Rescale the stored values with (x + 1) / 2 -- presumably mapping [-1, 1] onto [0, 1]; TODO confirm.
        micr_gene_simi_mat= (torch.from_numpy(np.loadtxt(self.params.micr_gene_simi_path, encoding= 'utf-8-sig'))+ 1)/ 2
        micr_inter_mat= self.load_adj_data(self.params.micr_inter_adj_path, sp= (self.params.num_micr, self.params.num_micr))
        micr_dise_mat= self.load_adj_data(self.params.micr_dise_adj_path, sp= (self.params.num_micr, self.params.num_dise))
        # Cosine similarity between microbes over their masked drug-association profiles.
        micr_asso_simi_mat= torch.matmul(nn.functional.normalize(self.drug_micr_asso_mat_zy.T, p= 2, dim= 1), nn.functional.normalize(self.drug_micr_asso_mat_zy.T, p= 2, dim= 1).T)
        for i in range(self.params.num_micr):micr_asso_simi_mat[i, i]= 1
        for i in range(self.params.num_micr):micr_ani_mat[i, i]= 1            
        return micr_ani_mat, micr_inter_mat, micr_dise_mat, micr_asso_simi_mat, micr_gene_simi_mat, micr_gene_simi_mat_g

    # @load drug data
    def load_drug_data(self):
        """Load the drug adjacency and similarity files.

        Returns:
            ``(drug_micr_asso_mat, drug_dise_asso_mat, drug_inter_mat,
            drug_struct_simi_mat, drug_fringer_simi_mat)``.
        """
        drug_dise_asso_mat= self.load_adj_data(self.params.drug_dise_adj_path, sp= (self.params.num_drug, self.params.num_dise))
        drug_micr_asso_mat= self.load_adj_data(self.params.drug_micr_adj_path, sp= (self.params.num_drug, self.params.num_micr))
        drug_inter_mat= self.load_adj_data(self.params.drug_inter_adj_path, sp= (self.params.num_drug, self.params.num_drug))
        drug_struct_simi_mat= torch.from_numpy(np.loadtxt(self.params.drug_struct_simi_path, encoding= 'utf-8-sig'))
        drug_fringer_simi_mat= torch.from_numpy(np.loadtxt(self.params.drug_fringer_simi_path, encoding= 'utf-8-sig'))
        return drug_micr_asso_mat, drug_dise_asso_mat, drug_inter_mat, drug_struct_simi_mat, drug_fringer_simi_mat
    
    # @load adj data
    def load_adj_data(self, path, sp= (1209, 172)):
        """Read a file of 1-based ``row col`` index pairs into a dense 0/1 matrix.

        Args:
            path: text file readable by ``np.loadtxt``; the first two columns
                are used as 1-based row and column indices.
            sp: ``(rows, cols)`` shape of the returned matrix.

        Returns:
            Float tensor of shape ``sp`` with 1 at every listed position.
        """
        # ndmin=2 keeps a file that holds a single pair two-dimensional, so the column indexing below still works.
        idx= torch.from_numpy(np.loadtxt(path, encoding= 'utf-8-sig', ndmin= 2)).long()- 1
        mat= torch.zeros((sp[0], sp[1]))
        mat[idx[:, 0], idx[:, 1]]= 1
        return mat

    # @ RWR
    def random_walk_root(self, A, epoch= 500, prob= 0.5, eps= 1e-15):
        """Iterate a random walk with restart on the adjacency matrix ``A``.

        Args:
            A: square adjacency tensor.
            epoch: maximum number of iterations.
            prob: weight of the walk step; ``1 - prob`` is the restart weight.
            eps: added to the row sums to avoid division by zero, and also used
                as the convergence tolerance on the mean absolute change.

        Returns:
            The matrix after the last iteration (the identity when ``epoch`` is 0).
        """
        LSM= torch.eye(A.shape[0])
        # Defined up front so that epoch=0 returns the starting matrix instead of raising.
        SM= LSM
        T_Mat= A/ (A.sum(dim= 1, keepdims= True)+ eps)
        for i in range(epoch):
            # Probability of reaching node A = probability of reaching an intermediate
            # node (row) combined with the probability of stepping from it to A (column).
            SM= prob* LSM@ T_Mat+ (1- prob)* torch.eye(A.shape[0])
            if torch.mean(torch.abs(SM- LSM))<= eps:
                print('converge...');break
            LSM= SM
        return SM

    def _first_param(self, *names):
        """Return the first attribute of ``self.params`` that exists among ``names``."""
        for name in names:
            if hasattr(self.params, name):
                return getattr(self.params, name)
        raise AttributeError(f'params defines none of: {", ".join(names)}')

    # write into memo
    def write2memo(self, mr, mrr, hits10):
        """Append one tab-separated line ``lr, weight_decay, mr, mrr, hits10`` to the memo file.

        The ``*_kg`` parameter names are tried first and the plain names
        (``memo_file4gnaemda``, ``lr``, ``weight_decay``) are used when those
        are absent. The memo file's directory is created if needed.
        """
        memo_path= self._first_param('memo_file4kg', 'memo_file4gnaemda')
        lr= self._first_param('lr_kg', 'lr')
        weight_decay= self._first_param('weight_decay_kg', 'weight_decay')
        memo_dir= os.path.dirname(memo_path)
        if memo_dir: os.makedirs(memo_dir, exist_ok= True)
        with open(f'{memo_path}', 'a+') as f:
            f.write(f'{lr}\t{weight_decay}\t{mr}\t{mrr}\t{hits10}\n')

In [None]:
class EarlyStopping:
    """Track a validation score, checkpoint the best model and signal when to stop.

    Each call receives a loss-like value (lower is better). The model is saved
    whenever the value improves; after ``patience`` consecutive calls without
    improvement ``flag`` becomes True.
    """
    def __init__(self, patience, pt_file= 'checkpoint/', file_name= 'checkpoint.pt', mess_out= True, eps= 0):
        """Store the settings and create the checkpoint directory if it is missing.

        Args:
            patience: number of non-improving calls after which ``flag`` is set.
            pt_file: directory that receives the checkpoint.
            file_name: checkpoint file name inside ``pt_file``.
            mess_out: print the counter on every non-improving call when True.
            eps: margin subtracted from the best score in the comparison.
        """
        super().__init__()
        self.patience = patience
        self.eps = eps
        self.pt_file = pt_file
        self.file_name = file_name
        self.mess_out = mess_out
        self.best_score = None
        self.counter = 0
        self.flag = False
        if not os.path.exists(self.pt_file):
            os.makedirs(self.pt_file)

    def __call__(self, val_loss, model):
        """Record one validation result; checkpoint ``model`` if it is the best so far."""
        score = -val_loss
        is_first_call = self.best_score is None

        # No improvement: count it and raise the stop flag once patience is used up.
        if not is_first_call and score <= self.best_score - self.eps:
            self.counter += 1
            if self.mess_out:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.flag = True
            return

        # New best score: remember it and save the model.
        self.best_score = score
        self.save_checkpoint(model)
        if not is_first_call:
            self.counter = 0

    def save_checkpoint(self, model):
        """Serialise the whole ``model`` object to ``<pt_file>//<file_name>``."""
        target = '{}//{}'.format(self.pt_file, self.file_name)
        torch.save(model, target)

In [None]:
class GIN(torch.nn.Module):
    """Two stacked GINConv layers plus an MLP that scores pairs of nodes.

    The node features and the edge index are stored on the module, so
    ``forward`` only takes the indices of the two nodes of each pair.
    """
    def __init__(self, in_channels, out_channels, fea, edge_index= None):
        """Create the layers.

        Args:
            in_channels: width of ``fea``.
            out_channels: width of the second convolution's output.
            fea: node feature matrix fed to the first convolution.
            edge_index: edge index passed to both convolutions.
        """
        super().__init__()
        hidden_channels = in_channels // 2
        self.fea = fea
        self.edge_index = edge_index

        # Per-layer MLPs wrapped by the two GIN convolutions.
        self.l1 = nn.Sequential(nn.Linear(in_channels, hidden_channels), nn.ReLU())
        self.l2 = nn.Sequential(nn.Linear(hidden_channels, out_channels), nn.ReLU())
        self.cv1 = GINConv(self.l1, train_eps=False)
        self.cv2 = GINConv(self.l2, train_eps=False)

        # A pair is the concatenation of two node vectors, each made of
        # [input features, first-layer output, second-layer output].
        pair_width = 2 * (in_channels + hidden_channels + out_channels)
        self.fc4out = nn.Sequential(
            nn.Linear(pair_width, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 2),
        )

    def init_para(self):
        """Re-initialise the linear weights with Xavier-normal (ReLU gain except for the output layer)."""
        relu_gain = nn.init.calculate_gain('relu')
        for layer in (self.l1[0], self.l2[0], self.fc4out[0]):
            nn.init.xavier_normal_(layer.weight, relu_gain)
        nn.init.xavier_normal_(self.fc4out[3].weight)

    def forward(self, l, r):
        """Return the two-class logits for the node pairs ``(l[i], r[i])``."""
        hop1 = self.cv1(self.fea, self.edge_index)
        hop2 = self.cv2(hop1, self.edge_index)
        node_repr = torch.cat([self.fea, hop1, hop2], dim=1)
        pair_repr = torch.cat((node_repr[l], node_repr[r]), dim=1)
        return self.fc4out(pair_repr)

In [None]:
def avg_auc_aupr_cpt(test_xy, test_label, pred, ass_mat_shape):
    """Average the per-row ROC-AUC and PR-AUC over the rows of an association matrix.

    Args:
        test_xy: tensor of ``[row, col]`` coordinates of the evaluated pairs.
        test_label: label of each evaluated pair.
        pred: predicted score of each evaluated pair.
        ass_mat_shape: ``(rows, cols)`` of the full association matrix.

    Returns:
        ``(mean AUC, mean AUPR)`` over the rows that contain at least one
        positive and at least one non-positive evaluated pair.
    """
    rows, cols = test_xy[:, 0], test_xy[:, 1]

    # -1 marks positions that are not part of the evaluated set.
    label_mat = torch.zeros((ass_mat_shape)) - 1
    pred_mat = torch.zeros((ass_mat_shape)) - 1
    label_values, pred_values = test_label * 1.0, pred
    label_mat[rows, cols] = label_values
    pred_mat[rows, cols] = pred_values
    evaluated = (label_mat != -1)

    aucs, auprs = [], []
    for row_labels, row_scores, row_mask in zip(label_mat, pred_mat, evaluated):
        y_true = row_labels[row_mask]
        pos_num = y_true.sum()
        neg_num = row_mask.sum() - pos_num
        if not (pos_num > 0 and neg_num > 0):
            continue
        y_score = row_scores[row_mask]
        fpr, tpr, _ = roc_curve(y_true, y_score)
        prec, recall, _ = precision_recall_curve(y_true, y_score)
        # Overwrite the curve's final precision value: 0 when the one before it is 0, else 1.
        prec[-1] = 0 if prec[-2] == 0 else 1
        aucs.append(auc(fpr, tpr))
        auprs.append(auc(recall, prec))
    return np.mean(aucs), np.mean(auprs)

In [None]:
dl = dataloader(params)

# Labels come from the unmasked association matrix, read at each split's (drug, microbe) coordinates.
train_label, valid_label, test_label = (
    dl.drug_micr_asso_mat[_xy[:, 0], _xy[:, 1]].long()
    for _xy in (dl.train_xy, dl.valid_xy, dl.test_xy)
)

# train: reshuffled on every pass
train_xy_label_dataset = torch.utils.data.TensorDataset(dl.train_xy, train_label)
train_loader = torch.utils.data.DataLoader(
    train_xy_label_dataset,
    batch_size=params.batch_size,
    shuffle=True,
)

# valid: fixed order
valid_xy_label_dataset = torch.utils.data.TensorDataset(dl.valid_xy, valid_label)
valid_loader = torch.utils.data.DataLoader(
    valid_xy_label_dataset,
    batch_size=params.batch_size,
    shuffle=False,
)

# test: fixed order
test_xy_label_dataset = torch.utils.data.TensorDataset(dl.test_xy, test_label)
test_loader = torch.utils.data.DataLoader(
    test_xy_label_dataset,
    batch_size=params.batch_size,
    shuffle=False,
)

In [None]:
# Move the graph tensors to the configured device.
fea, g, edge_index, edge_weight= dl.fea.to(params.device), dl.g.to(params.device), dl.edge_index.to(params.device), dl.edge_weigh.to(params.device)
# The input width is taken from the feature matrix itself instead of a hard-coded 1381,
# so the model stays consistent when the entity counts change.
# edge_index is stored as (num_edges, 2); the model receives it transposed.
net= GIN(fea.shape[1], 128, fea, edge_index.T).to(params.device)
optimizer= torch.optim.Adam(net.parameters(), lr= params.lr, weight_decay= params.weight_decay)
earlystopping= EarlyStopping(patience= params.patience, pt_file= params.pt_file, file_name= params.pt_file_name, mess_out= True)
loss_func1= nn.CrossEntropyLoss()
pred= []

In [None]:
# Train for up to params.epochs epochs; after each epoch evaluate on the validation
# split and let early stopping checkpoint the model with the best AUC + AUPR.
for ep in range(params.epochs):
    # train
    net.train()
    for step, (xy, label) in enumerate(train_loader):
        time_start= time.time()
        xy, label= xy.to(params.device), label.to(params.device)
        # Microbe rows follow the drug rows in the node feature matrix, hence the offset.
        logp= net(xy[:, 0], xy[:, 1]+ params.num_drug)
        loss= loss_func1(logp, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        time_end= time.time()
        if step% 500== 0:
            print(f'epoch: {ep+ 1}, step: {step+ 1}, train loss: {loss}, acc: {(torch.max(logp, dim= 1)[1]== label).float().sum()* 1.0/ len(xy)}, time: {time_end- time_start}')
    # valid
    net.eval(); val_loss_sum= 0.0; pred= []
    with torch.no_grad():
        for step, (xy, label) in enumerate(valid_loader):
            xy, label= xy.to(params.device), label.to(params.device)
            logp= net(xy[:, 0], xy[:, 1]+ params.num_drug)
            # Accumulate the batch-mean loss weighted by batch size, so the reported value
            # is the mean over the whole validation set rather than the last batch only.
            val_loss_sum+= loss_func1(logp, label).item()* len(label)
            pred.append(logp)
    pred= torch.cat(pred, dim= 0).cpu()
    val_loss= val_loss_sum/ max(len(valid_label), 1)
    roc_auc, aupr_auc= avg_auc_aupr_cpt(dl.valid_xy, valid_label, pred[:, 1], (params.num_drug, params.num_micr))
    print(f'epoch: {ep+ 1}, valid loss: {val_loss}, acc: {(torch.max(pred, dim= 1)[1]== valid_label).float().sum()* 1.0/ len(valid_label)}, auc: {roc_auc}, aupr: {aupr_auc}')
    # EarlyStopping treats lower as better, so the summed metrics are negated.
    earlystopping(-(roc_auc+ aupr_auc), net)
    if earlystopping.flag== True:
        print(f'early_stopping')
        break

In [None]:
# Reload the best checkpoint written by early stopping and score the test split.
# The checkpoint is a fully pickled module produced by this notebook (trusted input), so
# weights_only=False is required on PyTorch versions whose torch.load defaults to
# weights-only loading; older versions reject the keyword, hence the fallback.
_ckpt_path= f'{params.pt_file}//{params.pt_file_name}'
try:
    net= torch.load(_ckpt_path, map_location= params.device, weights_only= False)
except TypeError:
    net= torch.load(_ckpt_path, map_location= params.device)
net.eval();results= []
with torch.no_grad():
    for xy, _ in test_loader:
        xy= xy.to(params.device)
        # Microbe rows follow the drug rows in the node feature matrix, hence the offset.
        logp= net(xy[:, 0], xy[:, 1]+ params.num_drug)
        results.append(logp)
pred= torch.cat(results, dim= 0)
# One row per test pair: drug index, microbe index, label, positive-class score.
results= torch.cat([dl.test_xy.to('cpu'), test_label.view(-1, 1).to('cpu'), pred[:, 1].view(-1, 1).to('cpu')], dim= 1)
print(avg_auc_aupr_cpt(dl.test_xy.to('cpu'), test_label.to('cpu'), pred[:, 1].to('cpu'), (params.num_drug, params.num_micr)))
# np.savetxt(fname= params.test_result_file, X= results, delimiter= '\t', encoding= 'utf-8-sig')