In [None]:
import os
import time
import torch
import random
import warnings
import argparse
import numpy as np
from torch_geometric.nn import conv
from sklearn.metrics import roc_curve, roc_auc_score, auc, precision_recall_curve, average_precision_score
import torch.nn as nn
from tqdm import tqdm
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from collections import defaultdict as ddict, Counter
warnings.filterwarnings('ignore')

In [None]:
# Hyper-parameters and file locations. `parse_args([])` ignores sys.argv so the
# cell runs unchanged inside Jupyter; edit the defaults (or pass a list) to override.
parser= argparse.ArgumentParser(description= 'Parser for Arguments')
parser.add_argument('-seed', type= int, default= 0)
# Entities share one id space: drugs first, then microbes, then diseases.
parser.add_argument('-num_ent', type= int, default= 1209+ 172+ 154)
parser.add_argument('-num_drug', type= int, default= 1209)
parser.add_argument('-num_micr', type= int, default= 172)
parser.add_argument('-num_dise', type= int, default= 154)
# Relation ids: drug_micr_rel, 0; drug_dise_rel, 1; micr_dise_rel, 2; drug_inter_rel, 3; micr_inter_rel, 4;
# inverse relations (id + 5): micr_drug_rel, 5; dise_drug_rel, 6; dise_micr_rel, 7.
parser.add_argument('-num_rel', type= int, default= 8)
parser.add_argument('-drug_name_path', type= str, default= '../mdd/drug/drug_name.txt')
parser.add_argument('-micr_name_path', type= str, default= '../mdd/microbe/microbe_name.txt')
parser.add_argument('-dise_name_path', type= str, default= '../mdd/disease/disease_name.txt')
parser.add_argument('-drug_micr_adj_path', type= str, default= '../mdd/adj/microbe_drug_adj.txt')
parser.add_argument('-drug_struct_simi_path', type= str, default= '../mdd/drug/drug_struct_simi.txt')
parser.add_argument('-drug_inter_adj_path', type= str, default= '../mdd/drug/drug_interact_adj.txt')
parser.add_argument('-drug_fringer_simi_path', type= str, default= '../mdd/drug/drug_fringer_simi.txt')
parser.add_argument('-drug_dise_adj_path', type= str, default= '../mdd/adj/drug_disease_adj.txt')
parser.add_argument('-micr_ani_path', type= str, default= '../mdd/microbe/microbe_ani_simi.txt')
parser.add_argument('-micr_inter_adj_path', type= str, default= '../mdd/microbe/microbe_interact_adj.txt')
parser.add_argument('-micr_dise_adj_path', type= str, default= '../mdd/adj/microbe_disease_adj.txt')
parser.add_argument('-dise_simi_path', type= str, default= '../mdd/disease/disease_dag_simi.txt')
# The random split reads train_ratio and valid_ratio; the test share is whatever
# remains, so -test_ratio is not read by the split in this notebook.
parser.add_argument('-train_ratio', type= float, default= 0.8)
parser.add_argument('-valid_ratio', type= float, default= 0.1)
parser.add_argument('-test_ratio', type= float, default= 0.1)
parser.add_argument('-kg_file', type= str, default= 'kg_data/')
parser.add_argument('-batch_size', type= int, default= 128)
parser.add_argument('-lbl_smooth', type= float, default= 0.2, help= 'Label smoothing coefficient (0 disables smoothing)')
parser.add_argument('-embed_dim', type= int, default= 128)
parser.add_argument('-device', type= str, default= 'cuda:0')
parser.add_argument('-lr_kg', type= float, default= 0.001)
parser.add_argument('-weight_decay_kg', type= float, default= 0)
parser.add_argument('-patience_kg', type= int, default= 50)
parser.add_argument('-epoch_kg', type= int, default= 300)
parser.add_argument('-pt_file', type= str, default= 'checkpoint/')
parser.add_argument('-memo_file4kg', type= str, default= 'memo/memo.txt')
parser.add_argument('-threshold4kg', type= float, default= 0.8)
parser.add_argument('-pt_file_name4net1', type= str, default= 'conve.pt')
parser.add_argument('-test_result_file', type= str, default= 'result/conve_test_result.txt')
params= parser.parse_args([])

In [None]:
# Seed every RNG used in this notebook so the random data split, weight
# initialisation and dropout are reproducible.
random.seed(params.seed)
np.random.seed(params.seed)
torch.manual_seed(params.seed)
torch.cuda.manual_seed(params.seed)
torch.cuda.manual_seed_all(params.seed)
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
# affects subprocesses launched afterwards, not the running kernel.
os.environ['PYTHONHASHSEED'] = str(params.seed)
torch.backends.cudnn.deterministic = True
# cuDNN benchmark mode picks convolution algorithms by timing them, which can
# select different (non-bitwise-identical) kernels from run to run.
torch.backends.cudnn.benchmark = False

In [None]:
class dataloader(object):
    """Build all drug / microbe / disease matrices and export knowledge-graph triples.

    Construction reads every file named in ``params``, randomly splits the
    drug x microbe pairs (``torch.rand``), writes ``train.txt`` / ``valid.txt`` /
    ``test.txt`` into ``params.kg_file`` and prints summary counts.

    Note: despite its name this is not a ``torch.utils.data.DataLoader``.
    """
    def __init__(self, params):
        super().__init__()
        self.params= params
        self.drug_micr_asso_mat, self.drug_dise_asso_mat, self.drug_inter_mat, self.drug_struct_simi_mat, self.drug_fringer_simi_mat= self.load_drug_data()
        self.train_xy, self.valid_xy, self.test_xy= self.split_dataset()
        # Drug-microbe matrix with every valid/test pair zeroed, so those pairs are
        # hidden from the graph and similarity features built below.
        self.drug_micr_asso_mat_zy= self.get_asso_mat_zy()
        self.micr_ani_mat, self.micr_inter_mat, self.micr_dise_asso_mat, self.micr_asso_simi_mat= self.load_micr_data()
        self.dise_simi_mat, self.drug_dise_drug_simi_mat, self.micr_dise_micr_simi_mat= self.load_dise_data()
        # Block matrix over [drugs | microbes | diseases]: similarity blocks on the
        # diagonal, (masked) association blocks off the diagonal.
        self.hete_graph_mat= torch.cat([torch.cat([self.drug_struct_simi_mat, self.drug_micr_asso_mat_zy, self.drug_dise_asso_mat], dim= 1),
                                        torch.cat([self.drug_micr_asso_mat_zy.T, self.micr_ani_mat, self.micr_dise_asso_mat], dim= 1),
                                        torch.cat([self.drug_dise_asso_mat.T, self.micr_dise_asso_mat.T, self.dise_simi_mat], dim= 1)], dim= 0).float()
        self.create_kg_data()
        self.introduce()

    # @introduce data
    def introduce(self):
        """Print the number of known edges in each association / interaction matrix."""
        print(f'Drug microbe association num: {self.drug_micr_asso_mat.sum()}\nDrug interaction num: {self.drug_inter_mat.sum()}\nMicrobe interaction num: {self.micr_inter_mat.sum()}')
        print(f'Drug disease association num: {self.drug_dise_asso_mat.sum()}\nMicrobe disease association num: {self.micr_dise_asso_mat.sum()}')

    # @ mask
    def get_asso_mat_zy(self):
        """Return a copy of the drug-microbe matrix with every valid/test pair set to 0."""
        asso_mat_zy= self.drug_micr_asso_mat.clone()
        asso_mat_zy[self.valid_xy[:, 0], self.valid_xy[:, 1]]= 0
        asso_mat_zy[self.test_xy[:, 0], self.test_xy[:, 1]]= 0
        return asso_mat_zy

    # @create knowledge graph data
    def create_kg_data(self):
        """Write (head, tail, relation) triples to ``train.txt``, ``valid.txt`` and ``test.txt``.

        Only known drug-microbe associations (value 1) of each split become
        relation-0 triples. All drug-disease (1), microbe-disease (2), drug-drug (3)
        and microbe-microbe (4) edges go to the training file. Columns are written
        in the order head, tail, relation, into ``self.params.kg_file`` (created if
        missing).
        """
        kg_dir= self.params.kg_file
        os.makedirs(kg_dir, exist_ok= True)
        asso= self.drug_micr_asso_mat
        # drug microbe association data: keep only the positive pairs of each split
        drug_micr_train= self.train_xy[asso[self.train_xy[:, 0], self.train_xy[:, 1]]== 1]
        drug_micr_valid= self.valid_xy[asso[self.valid_xy[:, 0], self.valid_xy[:, 1]]== 1]
        drug_micr_test= self.test_xy[asso[self.test_xy[:, 0], self.test_xy[:, 1]]== 1]
        # auxiliary relations, all used for training
        drug_dise_train, micr_dise_train, drug_inter_train, micr_inter_train= self.drug_dise_asso_mat.nonzero(), self.micr_dise_asso_mat.nonzero(), self.drug_inter_mat.nonzero(), self.micr_inter_mat.nonzero()
        # shift local indices into the global entity id space
        # (drug ids start at 0, so drug_inter_train needs no offset)
        micr_off, dise_off= self.params.num_drug, self.params.num_drug+ self.params.num_micr
        drug_micr_valid+= torch.tensor([0, micr_off]); drug_micr_test+= torch.tensor([0, micr_off]); drug_micr_train+= torch.tensor([0, micr_off])
        drug_dise_train+= torch.tensor([0, dise_off]); micr_dise_train+= torch.tensor([micr_off, dise_off]); micr_inter_train+= torch.tensor([micr_off, micr_off])
        # append a relation column: (head, tail) -> (head, tail, rel)
        drug_micr_train, drug_micr_valid, drug_micr_test= drug_micr_train[:, [0, 1, 1]], drug_micr_valid[:, [0, 1, 1]], drug_micr_test[:, [0, 1, 1]]
        drug_dise_train, micr_dise_train, drug_inter_train, micr_inter_train= drug_dise_train[:, [0, 1, 1]], micr_dise_train[:, [0, 1, 1]], drug_inter_train[:, [0, 1, 1]], micr_inter_train[:, [0, 1, 1]]
        drug_micr_train[:, 2], drug_micr_valid[:, 2], drug_micr_test[:, 2], drug_dise_train[:, 2], micr_dise_train[:, 2], drug_inter_train[:, 2], micr_inter_train[:, 2]= 0, 0, 0, 1, 2, 3, 4
        # save files (relative to this instance's own params, not a global)
        train= torch.cat([drug_micr_train, drug_dise_train, micr_dise_train, drug_inter_train, micr_inter_train], dim= 0)
        for name, triples in (('train', train), ('valid', drug_micr_valid), ('test', drug_micr_test)):
            np.savetxt(os.path.join(kg_dir, f'{name}.txt'), triples, fmt= '%d', delimiter= '\t', encoding= 'utf-8-sig')
        print(f'Knowledge graph data has prepared...')

    # @split data set
    def split_dataset(self):
        """Randomly assign every (drug, microbe) pair to train / valid / test.

        Each pair draws ``torch.rand(1)``. Below ``train_ratio`` it goes to train,
        below ``train_ratio + valid_ratio`` to valid, and otherwise to test. The
        first positive pair in each drug row is always put in train without a draw.

        Returns three ``LongTensor``s of shape ``(k, 2)`` holding (drug, microbe)
        indices. An empty split has shape ``(0, 2)``, so column indexing still works.
        """
        train_xy, valid_xy, test_xy= [], [], []
        for i in range(self.params.num_drug):
            first= True
            for j in range(self.params.num_micr):
                if self.drug_micr_asso_mat[i, j]== 1 and first:
                    train_xy.append([i, j])
                    first= False
                else:
                    num= torch.rand(1)
                    if num< self.params.train_ratio:
                        train_xy.append([i, j])
                    elif num< self.params.train_ratio+ self.params.valid_ratio:
                        valid_xy.append([i, j])
                    else:
                        test_xy.append([i, j])
        print(f'Spliting data has finished...')
        return tuple(torch.tensor(xy, dtype= torch.long).reshape(-1, 2) for xy in (train_xy, valid_xy, test_xy))

    # @load disease data
    def load_dise_data(self):
        """Load the disease similarity matrix and build two cosine-similarity matrices.

        Returns ``(dise_simi_mat, drug_dise_drug_simi_mat, micr_dise_micr_simi_mat)``.
        The last two are inner products of L2-row-normalised matrices, with their
        diagonals set to 1.
        NOTE(review): despite their names, they are computed from the drug / microbe
        *interaction* matrices, not the disease association matrices. Confirm intent.
        """
        dise_simi_mat= torch.from_numpy(np.loadtxt(self.params.dise_simi_path, encoding= 'utf-8-sig'))
        drug_rows= nn.functional.normalize(self.drug_inter_mat, p= 2, dim= 1)
        micr_rows= nn.functional.normalize(self.micr_inter_mat, p= 2, dim= 1)
        drug_dise_drug_simi_mat= torch.matmul(drug_rows, drug_rows.T)
        micr_dise_micr_simi_mat= torch.matmul(micr_rows, micr_rows.T)
        for i in range(self.params.num_drug): drug_dise_drug_simi_mat[i, i]= 1.0
        for i in range(self.params.num_micr): micr_dise_micr_simi_mat[i, i]= 1.0
        return dise_simi_mat, drug_dise_drug_simi_mat, micr_dise_micr_simi_mat

    # @load micr data
    def load_micr_data(self):
        """Load microbe matrices.

        Returns ``(micr_ani_mat, micr_inter_mat, micr_dise_mat, micr_asso_simi_mat)``.
        ``micr_asso_simi_mat`` is the cosine similarity between microbes computed from
        the *masked* drug-microbe matrix. It and ``micr_ani_mat`` get a unit diagonal.
        """
        micr_ani_mat= torch.from_numpy(np.loadtxt(self.params.micr_ani_path, encoding= 'utf-8-sig'))
        micr_inter_mat= self.load_adj_data(self.params.micr_inter_adj_path, sp= (self.params.num_micr, self.params.num_micr))
        micr_dise_mat= self.load_adj_data(self.params.micr_dise_adj_path, sp= (self.params.num_micr, self.params.num_dise))
        micr_profiles= nn.functional.normalize(self.drug_micr_asso_mat_zy.T, p= 2, dim= 1)
        micr_asso_simi_mat= torch.matmul(micr_profiles, micr_profiles.T)
        for i in range(self.params.num_micr):micr_asso_simi_mat[i, i]= 1
        for i in range(self.params.num_micr):micr_ani_mat[i, i]= 1
        return micr_ani_mat, micr_inter_mat, micr_dise_mat, micr_asso_simi_mat

    # @load drug data
    def load_drug_data(self):
        """Load drug matrices.

        Returns ``(drug_micr_asso_mat, drug_dise_asso_mat, drug_inter_mat,
        drug_struct_simi_mat, drug_fringer_simi_mat)``.
        NOTE(review): ``drug_micr_adj_path`` defaults to ``microbe_drug_adj.txt`` but is
        read with drugs as rows (first column). Confirm the file's column order.
        """
        drug_dise_asso_mat= self.load_adj_data(self.params.drug_dise_adj_path, sp= (self.params.num_drug, self.params.num_dise))
        drug_micr_asso_mat= self.load_adj_data(self.params.drug_micr_adj_path, sp= (self.params.num_drug, self.params.num_micr))
        drug_inter_mat= self.load_adj_data(self.params.drug_inter_adj_path, sp= (self.params.num_drug, self.params.num_drug))
        drug_struct_simi_mat= torch.from_numpy(np.loadtxt(self.params.drug_struct_simi_path, encoding= 'utf-8-sig'))
        drug_fringer_simi_mat= torch.from_numpy(np.loadtxt(self.params.drug_fringer_simi_path, encoding= 'utf-8-sig'))
        return drug_micr_asso_mat, drug_dise_asso_mat, drug_inter_mat, drug_struct_simi_mat, drug_fringer_simi_mat

    # @load adj data
    def load_adj_data(self, path, sp= (1209, 172)):
        """Read a 1-based (row, col) index list from ``path`` into a dense 0/1 float matrix of shape ``sp``.

        ``ndmin=2`` keeps a file with a single pair two-dimensional. A file with no
        pairs yields an all-zero matrix.
        """
        idx= torch.from_numpy(np.loadtxt(path, encoding= 'utf-8-sig', ndmin= 2)).long()- 1
        mat= torch.zeros((sp[0], sp[1]))
        if idx.numel()> 0:
            mat[idx[:, 0], idx[:, 1]]= 1
        return mat

    # write into memo
    def write2memo(self, mr, mrr, hits10):
        """Append one tab-separated line (lr, weight decay, MR, MRR, Hits@10) to ``params.memo_file4kg``."""
        with open(f'{self.params.memo_file4kg}', 'a+') as f:
            f.write(f'{self.params.lr_kg}\t{self.params.weight_decay_kg}\t{mr}\t{mrr}\t{hits10}\n')

In [None]:
# Load every matrix, split the drug-microbe pairs and write the KG triple files.
dl = dataloader(params)

In [None]:
class EarlyStopping:
    """Track a validation loss and raise ``flag`` after ``patience`` bad calls in a row.

    A call counts as "bad" when the loss is worse than the best seen so far by more
    than ``eps``. Any other call (including the first) records the new best and
    pickles ``model`` with ``torch.save`` to ``<pt_file>//<file_name>``.
    ``flag`` is never reset once set.
    """
    def __init__(self, patience, pt_file= 'checkpoint/', file_name= 'checkpoint.pt', mess_out= True, eps= 0):
        super().__init__()
        self.patience = patience
        self.eps = eps
        self.pt_file = pt_file
        self.file_name = file_name
        self.mess_out = mess_out
        self.best_score = None
        self.counter = 0
        self.flag = False
        if not os.path.exists(self.pt_file):
            os.makedirs(self.pt_file)

    def __call__(self, val_loss, model):
        current = -val_loss
        worse = self.best_score is not None and current < self.best_score - self.eps
        if worse:
            self.counter += 1
            if self.mess_out:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.flag = True
            return
        # First call or (near-)improvement: remember it and checkpoint the model.
        self.best_score = current
        self.save_checkpoint(model)
        self.counter = 0

    def save_checkpoint(self, model):
        """Pickle the whole ``model`` object to ``<pt_file>//<file_name>``."""
        torch.save(model, f'{self.pt_file}//{self.file_name}')

In [None]:
class MDDataset(Dataset):
    """Map-style dataset of 1-N knowledge-graph examples.

    Each element of ``triples`` is a dict with ``'triple'`` (src, rel, dst) and
    ``'label'`` (iterable of true dst ids). An item is a ``LongTensor`` triple plus a
    dense float multi-hot vector of length ``params.num_ent``. For the ``'train'``
    split, with a non-zero ``lbl_smooth``, that vector becomes
    ``(1 - lbl_smooth) * y + 1 / num_ent``.
    """

    def __init__(self, triples, split, params):
        self.triples, self.split, self.params = triples, split, params

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        example = self.triples[idx]
        triple_t = torch.LongTensor(example['triple'])
        target = self.get_label(np.int32(example['label']))
        smoothing = self.params.lbl_smooth
        if self.split == 'train' and smoothing != 0.0:
            target = (1.0 - smoothing) * target + (1.0 / self.params.num_ent)
        return triple_t, target

    @staticmethod
    def collate_fn(data):
        """Stack a list of (triple, label) items into two batch tensors."""
        triples = [item[0] for item in data]
        labels = [item[1] for item in data]
        return torch.stack(triples, dim=0), torch.stack(labels, dim=0)

    def get_label(self, label):
        """Return a float32 multi-hot vector with 1.0 at every index in ``label``."""
        multi_hot = np.zeros(self.params.num_ent, dtype=np.float32)
        multi_hot[np.asarray([int(e2) for e2 in label], dtype=np.intp)] = 1.0
        return torch.FloatTensor(multi_hot)

In [None]:
def load_data(params):
    """Read the KG triple files and build 1-N examples plus one DataLoader per split.

    Expects ``train.txt``, ``valid.txt`` and ``test.txt`` inside ``params.kg_file``.
    Each non-blank line is ``src<TAB>dst<TAB>rel``.

    * train: grouped by (src, rel), with every known dst as the label. Relations
      0-2 also get an inverse example (dst, rel + 5) -> src.
    * valid / test: one example per triple. The label is every dst known for
      (src, rel) across all three splits, used to filter the ranking.

    Returns:
        data_iter: dict split -> ``DataLoader`` over ``MDDataset`` (shuffled).
        triples: dict split -> list of ``{'triple': ..., 'label': ...}`` dicts.
    """
    sr2d, data= ddict(set), ddict(list)
    for split in ['train', 'valid', 'test']:
        # os.path.join works whether or not kg_file ends with a separator
        path= os.path.join(params.kg_file, f'{split}.txt')
        with open(path, encoding= 'utf-8-sig') as f:
            for line in f:
                if not line.strip():
                    continue
                src_id, dst_id, rel_id= line.strip().split('\t')
                src_id, rel_id, dst_id= int(float(src_id)), int(float(rel_id)), int(float(dst_id))
                data[split].append((src_id, rel_id, dst_id))
                if split== 'train':
                    sr2d[(src_id, rel_id)].add(dst_id)
                    if rel_id in [0, 1, 2]:sr2d[(dst_id, rel_id+ 5)].add(src_id)
    # snapshot of train-only answers for the 1-N training targets
    sr2d4tr= {k: list(v) for k, v in sr2d.items()}
    triples= ddict(list)
    for (src_id, rel_id), dst_id in sr2d4tr.items():
        triples['train'].append({'triple': (src_id, rel_id, -1), 'label': dst_id})
    # add valid/test answers so evaluation can filter every known true tail
    for split in ['test', 'valid']:
        for src_id, rel_id, dst_id in data[split]:
            sr2d[(src_id, rel_id)].add(dst_id)
            sr2d[(dst_id, rel_id+ 5)].add(src_id)
    sr2d4val_te= {k: list(v) for k, v in sr2d.items()}
    for split in ['valid', 'test']:
        for src_id, rel_id, dst_id in data[split]:
            triples[split].append({'triple': (src_id, rel_id, dst_id), 'label': sr2d4val_te[(src_id, rel_id)]})
    triples= dict(triples)

    def get_data_loader(dataset_class, split, batch_size, shuffle= True):
        return DataLoader(dataset_class(triples[split], split, params), batch_size= batch_size, shuffle= shuffle, collate_fn= dataset_class.collate_fn)

    data_iter= {
        'train': get_data_loader(MDDataset, 'train', params.batch_size),
        'valid': get_data_loader(MDDataset, 'valid', params.batch_size),
        'test': get_data_loader(MDDataset, 'test', params.batch_size)}
    return data_iter, triples

In [None]:
# Build the 1-N examples and the train/valid/test DataLoaders from the KG files.
data_iter, triples = load_data(params)

In [None]:
class ConvE(torch.nn.Module):
    """ConvE scorer: a 2-D convolution over stacked (src, rel) embeddings, scored 1-N.

    The entity and relation embeddings are each reshaped to a ``(k_h, k_w)`` image
    and stacked along the height axis. ``k_h`` is ``params.k_h`` when present
    (default 8) and ``k_w = embed_dim // k_h``. With the default ``embed_dim=128``
    this is an 8 x 16 layout, and the fc input is ``32 * 14 * 14``.
    ``forward`` returns sigmoid probabilities of shape ``(batch, num_ent)``.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``k_h``, or the resulting
            image is too small for the 3x3 convolution.
    """
    def __init__(self, params):
        super().__init__()
        self.params= params
        self.k_h= getattr(params, 'k_h', 8)
        if self.k_h< 2 or self.params.embed_dim% self.k_h!= 0 or self.params.embed_dim// self.k_h< 3:
            raise ValueError(f'embed_dim ({self.params.embed_dim}) must be divisible by k_h ({self.k_h}) '
                             f'with k_h >= 2 and embed_dim // k_h >= 3')
        self.k_w= self.params.embed_dim// self.k_h
        # 3x3 conv with no padding over a (2*k_h, k_w) image -> (2*k_h-2, k_w-2) per channel
        flat_dim= 32* (2* self.k_h- 2)* (self.k_w- 2)
        self.ent_embed= torch.nn.Embedding(self.params.num_ent, self.params.embed_dim); nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed= torch.nn.Embedding(self.params.num_rel, self.params.embed_dim); nn.init.xavier_normal_(self.rel_embed.weight)
        self.input_drop, self.feature_drop, self.hidden_drop= torch.nn.Dropout(0.3), torch.nn.Dropout2d(0.3), torch.nn.Dropout(0.3)
        self.bn0, self.bn1, self.bn2= torch.nn.BatchNorm2d(1), torch.nn.BatchNorm2d(32), torch.nn.BatchNorm1d(self.params.embed_dim)
        self.cv1= nn.Conv2d(1, out_channels= 32, kernel_size= (3, 3), stride= 1, padding= 0, bias= True)
        self.fc= nn.Sequential(nn.Linear(flat_dim, 2* self.params.embed_dim), nn.ReLU(), nn.Dropout(0.3), nn.Linear(2* self.params.embed_dim, self.params.embed_dim))
        nn.init.xavier_normal_(self.fc[0].weight, nn.init.calculate_gain('relu'))
        nn.init.xavier_normal_(self.fc[3].weight)
        self.act1, self.act2= nn.ReLU(), nn.Sigmoid()
        # per-entity output bias added to every score row
        self.register_parameter('bias', nn.Parameter(torch.zeros(self.params.num_ent)))
        self.loss_fuc= torch.nn.BCELoss()

    def loss(self, pred, true_label):
        """Binary cross-entropy between sigmoid outputs and (possibly smoothed) 1-N targets."""
        return self.loss_fuc(pred, true_label)

    def forward(self, src, rel):
        """Score every entity as the tail of (src, rel); returns ``(batch, num_ent)`` probabilities."""
        semb= self.ent_embed(src).view(-1, 1, self.k_h, self.k_w)
        remb= self.rel_embed(rel).view(-1, 1, self.k_h, self.k_w)
        # (batch, 1, 2*k_h, k_w)
        emb= torch.cat([semb, remb], dim= 2)
        x= self.input_drop(self.bn0(emb))
        # (batch, 32, 2*k_h-2, k_w-2)
        x= self.feature_drop(self.act1(self.bn1(self.cv1(x))))
        x= self.act1(self.bn2(self.hidden_drop(self.fc(x.view(src.shape[0], -1)))))
        # dot product against every entity embedding
        x= torch.mm(x, self.ent_embed.weight.transpose(1, 0))
        x+= self.bias.expand_as(x)
        return self.act2(x)

In [None]:
def evaluate(net, data_iter, params, split= 'valid'):
    """Filtered link-prediction ranking of ``net`` on ``data_iter[split]``.

    For each (src, rel, dst) example, the scores of all other known true tails
    (``label > 0.5``) are zeroed and the target's own score is kept. The target is
    then ranked among all entities. Ties are broken by inserting the target at a
    uniformly random position before a stable descending sort.

    Returns:
        dict with ``count``, ``mr``, ``mrr`` and ``hits@1`` .. ``hits@10``
        (all except ``count`` averaged over examples and rounded to 5 decimals).

    Raises:
        ValueError: if the split yields no examples.
    """
    net.eval()
    results= {}
    with torch.no_grad():
        for batch in data_iter[split]:
            triple, label= [t.to(params.device) for t in batch]
            src, rel, dst= triple[:, 0], triple[:, 1], triple[:, 2]
            pred= net.forward(src, rel)
            b_range= torch.arange(pred.size(0), device= params.device)
            target_pred= pred[b_range, dst]
            # Filtered setting: zero every other known true tail. torch.where needs a
            # bool condition (uint8 masks are rejected by recent PyTorch); comparing
            # with 0.5 also keeps smoothed labels from masking every entity.
            pred= torch.where(label> 0.5, torch.zeros_like(pred), pred)
            pred[b_range, dst]= target_pred
            pred, dst= pred.cpu().numpy(), dst.cpu().numpy()
            for scores, target in zip(pred, dst):
                tar_scr= scores[target]
                others= np.delete(scores, target)
                # the target may land in any of the len(others) + 1 slots, including the last
                rand= np.random.randint(others.shape[0]+ 1)
                ranked= np.argsort(-np.insert(others, rand, tar_scr), kind= 'stable')
                _filter= np.where(ranked== rand)[0][0]   # 0-based rank of the target
                results['count']= 1+ results.get('count', 0.0)
                results['mr']= (_filter+ 1)+ results.get('mr', 0.0)
                results['mrr']= (1.0/ (_filter+ 1))+ results.get('mrr', 0.0)
                for k in range(10):
                    if _filter<= k:
                        results[f'hits@{k+ 1}']= 1+ results.get(f'hits@{k+ 1}', 0.0)
    count= float(results.get('count', 0.0))
    if count== 0:
        raise ValueError(f'evaluate: split {split!r} contains no examples')
    results['mr']= round(results['mr']/ count, 5)
    results['mrr']= round(results['mrr']/ count, 5)
    for k in range(10):
        results[f'hits@{k+1}']= round(results.get(f'hits@{k+ 1}', 0)/ count, 5)
    return results

In [None]:
# Fresh model, Adam optimiser and early-stopping monitor; the best model is
# checkpointed to params.pt_file / params.pt_file_name4net1.
net = ConvE(params).to(params.device)
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=params.lr_kg,
    weight_decay=params.weight_decay_kg,
)
earlystopping4kg = EarlyStopping(
    patience=params.patience_kg,
    pt_file=params.pt_file,
    file_name=params.pt_file_name4net1,
    mess_out=True,
)

In [None]:
for epoch in range(params.epoch_kg):
    # ---- one training pass over the 1-N examples ----
    net.train()
    losses= []
    for step, batch in enumerate(data_iter['train']):
        optimizer.zero_grad()
        triple, label= [t.to(params.device) for t in batch]
        src, rel, dst= triple[:, 0], triple[:, 1], triple[:, 2]
        pred= net.forward(src, rel)
        loss= net.loss(pred, label)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if step% 5000== 0:
            # share of relation-0 queries in the batch and thresholded (0.5) accuracy
            rel_share= ((rel== 0)).sum()/ len(rel)
            hard_acc= (torch.where(pred< 0.5, 0, 1)== torch.where(label< 0.5, 0, 1)).sum()/ label.shape[0]/ label.shape[1]
            print(f'epoch: {epoch}, Train loss: {np.mean(losses)}, rel ratio: {rel_share}, acc: {hard_acc}')
    # ---- validation; early stopping monitors -(MRR + Hits@10) ----
    results= evaluate(net, data_iter, params, 'valid')
    print(f'val mr: {results["mr"]}, mrr: {results["mrr"]}, hits10: {results["hits@10"]}')
    earlystopping4kg(-(results['mrr']+ results['hits@10']), net)
    if earlystopping4kg.flag:
        break

In [None]:
# Reload the best checkpoint written by EarlyStopping (a whole pickled nn.Module).
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses full
# module pickles, so opt out explicitly. Only do this for checkpoints you created
# yourself, because unpickling can execute arbitrary code.
net= torch.load(f'{params.pt_file}//{params.pt_file_name4net1}', map_location= params.device, weights_only= False)
valid_results= evaluate(net, data_iter, params, 'valid')
print(f'valid results, {valid_results}')

In [None]:
# Learned entity and relation embedding tables of the reloaded model.
ent_emb = net.ent_embed.weight
rel_emb = net.rel_embed.weight

In [None]:
def _pair_labels(xy):
    """Return the 0/1 drug-microbe label (LongTensor) for each (drug, microbe) row of ``xy``."""
    return dl.drug_micr_asso_mat[xy[:, 0], xy[:, 1]].long()

train_label, valid_label, test_label = (_pair_labels(xy) for xy in (dl.train_xy, dl.valid_xy, dl.test_xy))
# Unshuffled (index pair, label) datasets and loaders for each split.
train_xy_label_dataset = torch.utils.data.TensorDataset(dl.train_xy, train_label)
valid_xy_label_dataset = torch.utils.data.TensorDataset(dl.valid_xy, valid_label)
test_xy_label_dataset = torch.utils.data.TensorDataset(dl.test_xy, test_label)
train_loader, valid_loader, test_loader = (
    torch.utils.data.DataLoader(ds, batch_size=params.batch_size, shuffle=False)
    for ds in (train_xy_label_dataset, valid_xy_label_dataset, test_xy_label_dataset)
)

In [None]:
def avg_auc_aupr_cpt(test_xy, test_label, pred, ass_mat_shape):
    """Mean per-row ROC-AUC and PR-AUC over the (row, col) pairs listed in ``test_xy``.

    Labels and scores are scattered into ``ass_mat_shape`` matrices; unlisted cells
    hold -1 and are ignored. Each row with at least one positive and one negative
    listed pair contributes one AUC and one AUPR. The AUPR is the trapezoidal area
    under the precision-recall curve after its last precision value is replaced by
    1 (or by 0 when the preceding precision is 0).

    Returns ``(mean_auc, mean_aupr)``; ``numpy.mean`` of an empty list gives ``nan``.
    """
    rows_idx, cols_idx = test_xy[:, 0], test_xy[:, 1]
    label_mat = torch.full(ass_mat_shape, -1.0)
    pred_mat = torch.full(ass_mat_shape, -1.0)
    label_mat[rows_idx, cols_idx] = test_label * 1.0
    pred_mat[rows_idx, cols_idx] = pred
    listed = label_mat != -1
    aucs, auprs = [], []
    for row in range(ass_mat_shape[0]):
        mask = listed[row]
        y_true = label_mat[row, mask]
        y_score = pred_mat[row, mask]
        n_pos = y_true.sum()
        n_neg = mask.sum() - n_pos
        if not (n_pos > 0 and n_neg > 0):
            continue
        fpr, tpr, _ = roc_curve(y_true, y_score)
        precision, recall, _ = precision_recall_curve(y_true, y_score)
        precision[-1] = 0 if precision[-2] == 0 else 1
        aucs.append(auc(fpr, tpr))
        auprs.append(auc(recall, precision))
    return np.mean(aucs), np.mean(auprs)

In [None]:
# Score every entity as head under relation 0 (drug_micr_rel), keep the microbe
# columns, then pick the test rows by drug id. Sizes come from params rather than
# hard-coded 1535 / 1209 / 172.
net.eval()  # fix dropout/BatchNorm behaviour regardless of which cell ran last
with torch.no_grad():
    pred= net(torch.arange(params.num_ent, dtype= torch.long, device= params.device),
              torch.zeros(params.num_ent, dtype= torch.long, device= params.device))[:, params.num_drug: params.num_drug+ params.num_micr].cpu()
pred= pred[dl.test_xy[:, 0], dl.test_xy[:, 1]]
# columns: drug index, microbe index, true label, predicted score
results= torch.cat([dl.test_xy.to('cpu'), test_label.view(-1, 1).to('cpu'), pred.view(-1, 1).to('cpu')], dim= 1)
print(avg_auc_aupr_cpt(dl.test_xy.to('cpu'), test_label.to('cpu'), pred.to('cpu'), (params.num_drug, params.num_micr)))
# np.savetxt(fname= params.test_result_file, X= results, delimiter= '\t', encoding= 'utf-8-sig')