### Copy data and checkpoint files into the working directory

In [31]:
# Copy mdd-dataset into /kaggle/working so the relative './mdd-dataset/...' paths used by the parser resolve (assumes the kernel's cwd is /kaggle/working).
cp -r /kaggle/input/mdd-dataset /kaggle/working/

In [None]:
# Create the checkpoint directory; -p makes re-running this cell a no-op if it already exists.
mkdir -p /kaggle/working/checkpoint/

In [33]:
# Copy the pretrained first-stage checkpoint (params.pt_file_name_net1 = 'net1s1.pt') into checkpoint/.
cp /kaggle/input/pcmda/pytorch/default/1/net1s1.pt /kaggle/working/checkpoint/

In [34]:
# Copy the pretrained second-stage checkpoint (params.pt_file_name_net2 = 'net2s1.pt') into checkpoint/.
cp /kaggle/input/pcmda/pytorch/default/1/net2s1.pt /kaggle/working/checkpoint/

### Import packages

In [35]:
import os
import time
import torch
import random
import warnings
import argparse
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import roc_curve, roc_auc_score, auc, precision_recall_curve
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from collections import defaultdict as ddict
# Suppress all Python warnings for the rest of the session (keeps notebook output short, but also hides deprecation notices).
warnings.filterwarnings('ignore')

### Set up the argument parser

In [36]:
# Pipeline configuration. parse_args([]) ignores the notebook kernel's own argv,
# so every value below is the effective setting unless edited here.
parser= argparse.ArgumentParser(description= 'Parser for Arguments')
parser.add_argument('-seed', type= int, default= 0)
# Entity ids: drugs [0, num_drug), then microbes, then diseases.
parser.add_argument('-num_ent', type= int, default= 1209+ 172+ 154)
parser.add_argument('-num_drug', type= int, default= 1209)
parser.add_argument('-num_micr', type= int, default= 172)
parser.add_argument('-num_dise', type= int, default= 154)
# Relation ids: drug_micr_rel, 0; drug_dise_rel, 1; micr_dise_rel, 2; drug_inter_rel, 3; micr_inter_rel, 4;
# micr_drug_rel, 5; dise_drug_rel, 6; dise_micr_rel, 7 (ids 5-7 are the reverses of 0-2).
parser.add_argument('-num_rel', type= int, default= 8)
parser.add_argument('-drug_name_path', type= str, default= './mdd-dataset/mdd/drug/drug_name.txt')
parser.add_argument('-micr_name_path', type= str, default= './mdd-dataset/mdd/microbe/microbe_name.txt')
parser.add_argument('-dise_name_path', type= str, default= './mdd-dataset/mdd/disease/disease_name.txt')
parser.add_argument('-drug_micr_adj_path', type= str, default= './mdd-dataset/mdd/adj/microbe_drug_adj.txt')
parser.add_argument('-drug_struct_simi_path', type= str, default= './mdd-dataset/mdd/drug/drug_struct_simi.txt')
parser.add_argument('-drug_inter_adj_path', type= str, default= './mdd-dataset/mdd/drug/drug_interact_adj.txt')
parser.add_argument('-drug_dise_adj_path', type= str, default= './mdd-dataset/mdd/adj/drug_disease_adj.txt')
parser.add_argument('-micr_ani_path', type= str, default= './mdd-dataset/mdd/microbe/microbe_ani_simi.txt')
parser.add_argument('-micr_inter_adj_path', type= str, default= './mdd-dataset/mdd/microbe/microbe_interact_adj.txt')
parser.add_argument('-micr_gene_simi_path', type= str, default= './mdd-dataset/mdd/microbe/microbe_gene_simi.txt')
parser.add_argument('-micr_dise_adj_path', type= str, default= './mdd-dataset/mdd/adj/microbe_disease_adj.txt')
parser.add_argument('-dise_simi_path', type= str, default= './mdd-dataset/mdd/disease/disease_dag_simi.txt')
parser.add_argument('-train_ratio', type= float, default= 0.9)
parser.add_argument('-valid_ratio', type= float, default= 0.09)
# Informational only: the split gives test whatever remains after train_ratio + valid_ratio.
parser.add_argument('-test_ratio', type= float, default= 0.01)
# Directory prefix for the KG triple files; keep the trailing '/' because file names are appended directly.
parser.add_argument('-kg_file', type= str, default= 'kg_data/')
parser.add_argument('-batch_size', type= int, default= 128)
parser.add_argument('-lbl_smooth', type= float, default= 0.2, help= 'Label smoothing factor for training labels (0 disables smoothing)')
parser.add_argument('-embed_dim', type= int, default= 128)
parser.add_argument('-device', type= str, default= 'cuda:0')
# First-stage (knowledge-graph) model optimisation and early stopping.
parser.add_argument('-lr_net1', type= float, default= 1e-3)
parser.add_argument('-weight_decay_net1', type= float, default= 0)
parser.add_argument('-patience_net1', type= int, default= 50)
parser.add_argument('-epoch_net1', type= int, default= 10000)
# Second-stage model optimisation and early stopping.
parser.add_argument('-lr_net2', type= float, default= 1e-3)
parser.add_argument('-weight_decay_net2', type= float, default= 1e-4)
parser.add_argument('-patience_net2', type= int, default= 6)
parser.add_argument('-epoch_net2', type= int, default= 300)
parser.add_argument('-pt_file', type= str, default= 'checkpoint/')
# Weight of the similarity-MSE term in the first-stage loss, and the drug-similarity mask threshold.
parser.add_argument('-balance_factor_net1', type= float, default= 1)
parser.add_argument('-threshold_net1', type= float, default= 0.8)
# Random walk with restart: walk_prob keeps the walk, (1 - walk_prob) restarts; walk_epoch caps iterations;
# walk_err is the per-entry convergence tolerance; walk_eps is passed through but not used by the walk.
parser.add_argument('-walk_prob', type= float, default= 0.8)
parser.add_argument('-walk_epoch', type= int, default= 500)
parser.add_argument('-walk_eps', type= float, default= 1e-10)
parser.add_argument('-walk_err', type= float, default= 1e-6)
parser.add_argument('-pt_file_name_net1', type= str, default= 'net1s1.pt')
parser.add_argument('-pt_file_name_net2', type= str, default= 'net2s1.pt')
parser.add_argument('-test_result_file', type= str, default= './result/pcmda_test_result.txt')
params= parser.parse_args([])
# Override -device with a CPU fallback when CUDA is unavailable.
params.device= 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [37]:
class dataloader(object):
    """Load the MDD drug / microbe / disease data and prepare every model input.

    Construction (all in ``__init__``):
      1. reads the drug, microbe and disease matrices named in ``params``;
      2. randomly splits every (drug, microbe) index pair into train/valid/test;
      3. zeroes valid/test associations in a copy of the drug-microbe matrix;
      4. builds a heterogeneous drug+microbe+disease matrix and runs three
         neighborhood random walks (nei = 1, 2, 3) on it;
      5. writes knowledge-graph triple files into ``params.kg_file``.

    Knowledge-graph entity ids are drugs first, then microbes (offset
    ``num_drug``), then diseases (offset ``num_drug + num_micr``).
    """

    def __init__(self, params):
        super().__init__()
        self.params= params
        self.drug_micr_asso_mat, self.drug_dise_asso_mat, self.drug_inter_mat, self.drug_struct_simi_mat= self.load_drug_data()
        self.train_xy, self.valid_xy, self.test_xy= self.split_dataset()
        # Association matrix with valid/test pairs zeroed, so graph-derived features only see training links.
        self.drug_micr_asso_mat_zy= self.get_asso_mat_zy()
        self.micr_ani_mat, self.micr_inter_mat, self.micr_dise_asso_mat, self.micr_asso_simi_mat, self.micr_gene_simi_mat, self.micr_inte_simi_mat, self.drug_asso_simi_mat= self.load_micr_data()
        self.dise_simi_mat, self.drug_dise_drug_simi_mat, self.micr_dise_micr_simi_mat= self.load_dise_data()
        # The microbe block of the heterogeneous graph uses micr_ani_mat rather than the
        # integrated microbe similarity (micr_inte_simi_mat) to keep the graph less noisy.
        self.hete_graph_mat= torch.cat([torch.cat([self.drug_struct_simi_mat, self.drug_micr_asso_mat_zy, self.drug_dise_asso_mat], dim= 1),\
                                    torch.cat([self.drug_micr_asso_mat_zy.T, self.micr_ani_mat, self.micr_dise_asso_mat], dim= 1),\
                                    torch.cat([self.drug_dise_asso_mat.T, self.micr_dise_asso_mat.T, self.dise_simi_mat], dim= 1)], dim= 0).float()
        self.struct_coef_mat1= self.random_walk_root(self.hete_graph_mat, params.walk_prob, params.walk_epoch, 1, params.walk_eps, params.walk_err)
        self.struct_coef_mat2= self.random_walk_root(self.hete_graph_mat, params.walk_prob, params.walk_epoch, 2, params.walk_eps, params.walk_err)
        self.struct_coef_mat3= self.random_walk_root(self.hete_graph_mat, params.walk_prob, params.walk_epoch, 3, params.walk_eps, params.walk_err)
        self.create_kg_data()
        self.introduce()

    # @introduce data
    def introduce(self):
        """Print the number of positive entries in each association / interaction matrix."""
        print(f'Drug microbe association num: {self.drug_micr_asso_mat.sum()}\nDrug interaction num: {self.drug_inter_mat.sum()}\nMicrobe interaction num: {self.micr_inter_mat.sum()}')
        print(f'Drug disease association num: {self.drug_dise_asso_mat.sum()}\nMicrobe disease association num: {self.micr_dise_asso_mat.sum()}')

    # @ mask
    def get_asso_mat_zy(self):
        """Return a copy of the drug-microbe matrix with every valid/test pair set to 0."""
        asso_mat_zy= self.drug_micr_asso_mat.clone()
        asso_mat_zy[self.valid_xy[:, 0], self.valid_xy[:, 1]]= 0
        asso_mat_zy[self.test_xy[:, 0], self.test_xy[:, 1]]= 0
        return asso_mat_zy

    # @ Compute Hamming interaction-profile similarity.
    def hip_sim(self, mat):
        """Row-pair similarity: fraction of columns on which two rows of ``mat`` are equal."""
        sim_ls, dim= [], mat.shape[1]
        for i in range(mat.shape[0]):
            sim_ls.append(((mat[i]- mat) == 0).sum(dim= 1)/ dim)
        return torch.stack(sim_ls)

    # norm1
    def norm_z(self, x, eps= 1e-10):
        """Z-score ``x`` over all of its elements (eps guards a zero std)."""
        return (x- x.mean())/ (x.std()+ eps)

    # norm2
    def norm_min_max(self, x, eps= 1e-10):
        """Min-max scale ``x`` into [0, 1] over all of its elements (eps guards a constant input)."""
        return (x- x.min())/ (x.max()- x.min()+ eps)

    # @create knowledge graph data
    def create_kg_data(self):
        """Write ``train.txt`` / ``valid.txt`` / ``test.txt`` triple files into ``self.params.kg_file``.

        Each row is ``src<TAB>dst<TAB>rel`` with global entity ids.  Train holds the
        training drug-microbe positives (rel 0) plus all drug-disease (1),
        microbe-disease (2), drug-drug (3) and microbe-microbe (4) links; valid/test
        hold only the drug-microbe positives of those splits (rel 0).
        """
        os.makedirs(self.params.kg_file, exist_ok= True)
        # drug microbe association data: keep only the positive pairs of each split
        drug_micr_train, drug_micr_valid, drug_micr_test= self.train_xy[self.drug_micr_asso_mat[self.train_xy[:, 0], self.train_xy[:, 1]]== 1],\
        self.valid_xy[self.drug_micr_asso_mat[self.valid_xy[:, 0], self.valid_xy[:, 1]]== 1],\
        self.test_xy[self.drug_micr_asso_mat[self.test_xy[:, 0], self.test_xy[:, 1]]== 1]
        # the other relations are taken whole (all of them go into train)
        drug_dise_train, micr_dise_train, drug_inter_train, micr_inter_train= self.drug_dise_asso_mat.nonzero(), self.micr_dise_asso_mat.nonzero(), self.drug_inter_mat.nonzero(), self.micr_inter_mat.nonzero()
        # add offsets to turn per-type indices into global entity ids
        drug_micr_valid+= torch.tensor([0, self.params.num_drug]); drug_micr_test+= torch.tensor([0, self.params.num_drug]); drug_micr_train+= torch.tensor([0, self.params.num_drug])
        drug_dise_train+= torch.tensor([0, self.params.num_drug+ self.params.num_micr]); micr_dise_train+= torch.tensor([self.params.num_drug, self.params.num_drug+ self.params.num_micr]); micr_inter_train+= torch.tensor([self.params.num_drug, self.params.num_drug])
        # add a third column and fill it with the relation id
        drug_micr_train, drug_micr_valid, drug_micr_test= drug_micr_train[:, [0, 1, 1]], drug_micr_valid[:, [0, 1, 1]], drug_micr_test[:, [0, 1, 1]]
        drug_dise_train, micr_dise_train, drug_inter_train, micr_inter_train= drug_dise_train[:, [0, 1, 1]], micr_dise_train[:, [0, 1, 1]], drug_inter_train[:, [0, 1, 1]], micr_inter_train[:, [0, 1, 1]]
        drug_micr_train[:, 2], drug_micr_valid[:, 2], drug_micr_test[:, 2], drug_dise_train[:, 2], micr_dise_train[:, 2], drug_inter_train[:, 2], micr_inter_train[:, 2]= 0, 0, 0, 1, 2, 3, 4
        # save files (always under this instance's params, never a module-level global)
        train= torch.cat([drug_micr_train, drug_dise_train, micr_dise_train, drug_inter_train, micr_inter_train], dim= 0)
        for split_name, rows in (('train', train), ('valid', drug_micr_valid), ('test', drug_micr_test)):
            np.savetxt(os.path.join(self.params.kg_file, f'{split_name}.txt'), rows.numpy(), fmt= '%d', delimiter= '\t', encoding= 'utf-8-sig')
        print(f'Knowledge graph data has prepared...')

    # @split data set
    def split_dataset(self):
        """Randomly split every (drug, microbe) index pair into train / valid / test.

        For each drug, its first associated microbe always goes to train; every
        other pair draws ``u = torch.rand(1)``: ``u < train_ratio`` -> train,
        ``u < train_ratio + valid_ratio`` -> valid, otherwise test.

        Returns:
            Three int64 tensors of shape (k, 2) holding (drug, microbe) indices;
            an empty split still has shape (0, 2).
        """
        train_xy, valid_xy, test_xy= [], [], []
        for i in range(self.params.num_drug):
            first= True
            for j in range(self.params.num_micr):
                if self.drug_micr_asso_mat[i, j]== 1 and first:
                    train_xy.append([i, j])
                    first= False
                else:
                    num= torch.rand(1)
                    if num< self.params.train_ratio:
                        train_xy.append([i, j])
                    elif num>= self.params.train_ratio and num< self.params.train_ratio+ self.params.valid_ratio:
                        valid_xy.append([i, j])
                    else:
                        test_xy.append([i, j])
        print(f'Spliting data has finished...')
        # reshape keeps an empty split 2-D so callers can still index [:, 0] / [:, 1]
        return tuple(torch.tensor(xy, dtype= torch.long).reshape(-1, 2) for xy in (train_xy, valid_xy, test_xy))

    # @load disease data
    def load_dise_data(self):
        """Load the disease similarity matrix and two cosine-similarity matrices (unit diagonal)."""
        dise_simi_mat= torch.from_numpy(np.loadtxt(self.params.dise_simi_path, encoding= 'utf-8-sig'))
        # NOTE(review): despite their names, these are cosine similarities of the drug-drug and
        # microbe-microbe *interaction* matrices, not of the disease-association matrices;
        # confirm which was intended before relying on them.
        drug_dise_drug_simi_mat, micr_dise_micr_simi_mat= torch.matmul(nn.functional.normalize(self.drug_inter_mat, p= 2, dim= 1), nn.functional.normalize(self.drug_inter_mat, p= 2, dim= 1).T),\
        torch.matmul(nn.functional.normalize(self.micr_inter_mat, p= 2, dim= 1), nn.functional.normalize(self.micr_inter_mat, p= 2, dim= 1).T)
        drug_diag, micr_diag= torch.arange(self.params.num_drug), torch.arange(self.params.num_micr)
        drug_dise_drug_simi_mat[drug_diag, drug_diag]= 1.0
        micr_dise_micr_simi_mat[micr_diag, micr_diag]= 1.0
        return dise_simi_mat, drug_dise_drug_simi_mat, micr_dise_micr_simi_mat

    # @load micr data
    def load_micr_data(self):
        """Load microbe matrices and derive association-profile similarities.

        ANI values are divided by 100; gene similarity is mapped from [-1, 1] to
        [0, 1] via (x + 1) / 2.  The integrated microbe similarity takes the ANI
        value where it is positive and the gene similarity elsewhere.  Association
        similarities are cosine similarities of the *masked* (training-only)
        drug-microbe matrix.
        """
        micr_ani_mat= torch.from_numpy(np.loadtxt(self.params.micr_ani_path, encoding= 'utf-8-sig'))/ 100
        micr_gene_simi_mat= (torch.from_numpy(np.loadtxt(self.params.micr_gene_simi_path, encoding= 'utf-8-sig'))+ 1)/ 2
        micr_inter_mat= self.load_adj_data(self.params.micr_inter_adj_path, sp= (self.params.num_micr, self.params.num_micr))
        micr_dise_mat= self.load_adj_data(self.params.micr_dise_adj_path, sp= (self.params.num_micr, self.params.num_dise))
        micr_asso_simi_mat= torch.matmul(nn.functional.normalize(self.drug_micr_asso_mat_zy.T, p= 2, dim= 1), nn.functional.normalize(self.drug_micr_asso_mat_zy.T, p= 2, dim= 1).T)
        drug_asso_simi_mat= torch.matmul(nn.functional.normalize(self.drug_micr_asso_mat_zy, p= 2, dim= 1), nn.functional.normalize(self.drug_micr_asso_mat_zy, p= 2, dim= 1).T)
        drug_diag, micr_diag= torch.arange(self.params.num_drug), torch.arange(self.params.num_micr)
        drug_asso_simi_mat[drug_diag, drug_diag]= 1
        micr_asso_simi_mat[micr_diag, micr_diag]= 1
        micr_ani_mat[micr_diag, micr_diag]= 1
        return micr_ani_mat, micr_inter_mat, micr_dise_mat, micr_asso_simi_mat, micr_gene_simi_mat, torch.where(micr_ani_mat> 0, micr_ani_mat, micr_gene_simi_mat), drug_asso_simi_mat

    # @load drug data
    def load_drug_data(self):
        """Load drug-microbe, drug-disease and drug-drug 0/1 matrices plus drug structure similarity."""
        drug_dise_asso_mat= self.load_adj_data(self.params.drug_dise_adj_path, sp= (self.params.num_drug, self.params.num_dise))
        drug_micr_asso_mat= self.load_adj_data(self.params.drug_micr_adj_path, sp= (self.params.num_drug, self.params.num_micr))
        drug_inter_mat= self.load_adj_data(self.params.drug_inter_adj_path, sp= (self.params.num_drug, self.params.num_drug))
        drug_struct_simi_mat= torch.from_numpy(np.loadtxt(self.params.drug_struct_simi_path, encoding= 'utf-8-sig'))
        return drug_micr_asso_mat, drug_dise_asso_mat, drug_inter_mat, drug_struct_simi_mat

    # @load adj data
    def load_adj_data(self, path, sp= (1209, 172)):
        """Build a float 0/1 matrix of shape ``sp`` from a text file of 1-based (row, col) pairs."""
        # ndmin=2 keeps a file with a single pair as shape (1, 2) instead of collapsing to 1-D.
        idx= torch.from_numpy(np.loadtxt(path, encoding= 'utf-8-sig', ndmin= 2)).long()- 1
        mat= torch.zeros((sp[0], sp[1]))
        mat[idx[:, 0], idx[:, 1]]= 1
        return mat

    # @NRWR, neighborhood random walk with restart
    def random_walk_root(self, A, alpha, epoch, nei, eps, err):
        """Neighborhood random walk with restart on the square matrix ``A``.

        Each iteration computes ``alpha * S @ W + (1 - alpha) * I`` where ``W`` is
        ``A`` L1-normalised per row, multiplies it elementwise by
        ``sign(A^nei - A^(nei - 1))`` (matrix powers) and L1-normalises each row.
        Stops after ``epoch`` iterations, or earlier once the summed absolute change
        is at most ``err * n**2``.  ``eps`` is accepted but not used.

        Returns the identity matrix when ``epoch`` is 0.
        """
        Mask_mat= torch.sign(torch.matrix_power(A, nei)- torch.matrix_power(A, nei- 1))
        W= nn.functional.normalize(A, p= 1, dim= 1)
        S_last= torch.eye(A.shape[0])
        S_new= S_last  # defined even if the loop body never runs
        for i in range(epoch):
            S_new= alpha* torch.matmul(S_last, W)+ (1- alpha)* torch.eye(A.shape[0])
            S_new= nn.functional.normalize(Mask_mat* S_new, p= 1, dim= 1)
            if (S_new- S_last).abs().sum()<= err* A.shape[0]** 2:
                print('converge...');break
            S_last= S_new
        return S_new

In [38]:
class EarlyStopping:
    """Track a validation loss, checkpoint the best model, and flag when to stop.

    Args:
        patience: consecutive non-improving calls tolerated before ``flag`` becomes True.
        pt_file: directory checkpoints are written to (created if missing).
        file_name: checkpoint file name inside ``pt_file``.
        mess_out: if True, print the patience counter whenever the loss does not improve.
        eps: a loss must drop by more than ``eps`` below the best loss to count as an improvement.
    """
    def __init__(self, patience, pt_file= 'checkpoint/', file_name= 'checkpoint.pt', mess_out= True, eps= 0):
        super().__init__()
        self.patience, self.eps, self.pt_file, self.file_name, self.mess_out= patience, eps, pt_file, file_name, mess_out
        self.best_score, self.counter, self.flag= None, 0, False
        os.makedirs(self.pt_file, exist_ok= True)

    def __call__(self, val_loss, model):
        """Record ``val_loss``; save ``model`` on improvement, otherwise advance the counter."""
        score= -val_loss
        if self.best_score is None:
            self.best_score= score
            self.save_checkpoint(model)
        elif score<= self.best_score+ self.eps:
            # Not better than the best by more than eps. With eps == 0 this is "score <= best",
            # and a slightly worse score can never replace the best one.
            self.counter+= 1
            if self.mess_out:print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter>= self.patience:
                self.flag= True
        else:
            self.best_score= score
            self.save_checkpoint(model)
            self.counter= 0

    def save_checkpoint(self, model):
        """Pickle the whole model object (not only its state_dict) to ``pt_file/file_name``."""
        torch.save(model, os.path.join(self.pt_file, self.file_name))

In [39]:
class MDDataset(Dataset):
    """Map-style dataset over knowledge-graph query records.

    Each record is a dict with ``'triple'`` (a ``(src, rel, dst)`` tuple) and
    ``'label'`` (an iterable of answer entity ids).  An item is returned as
    ``(LongTensor triple, FloatTensor multi-hot label of length num_ent)``; on the
    ``'train'`` split with non-zero ``params.lbl_smooth`` the label is smoothed as
    ``(1 - lbl_smooth) * label + 1 / num_ent``.
    """
    def __init__(self, triples, split, params):
        self.triples, self.split, self.params = triples, split, params

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        record = self.triples[idx]
        query = torch.LongTensor(record['triple'])
        target = self.get_label(np.int32(record['label']))
        smoothing = self.params.lbl_smooth
        if self.split == 'train' and smoothing != 0.0:
            target = (1.0 - smoothing) * target + (1.0 / self.params.num_ent)
        return query, target

    @staticmethod
    def collate_fn(data):
        """Stack a list of ``(triple, label)`` items into batched tensors."""
        queries = [pair[0] for pair in data]
        targets = [pair[1] for pair in data]
        return torch.stack(queries, dim=0), torch.stack(targets, dim=0)

    def get_label(self, label):
        """Return a float multi-hot vector of length ``num_ent`` with 1.0 at every id in ``label``."""
        multi_hot = np.zeros([self.params.num_ent], dtype=np.float32)
        hot_ids = list(label)
        multi_hot[hot_ids] = 1.0
        return torch.FloatTensor(multi_hot)

In [40]:
def load_data(params):
    """Read the knowledge-graph triple files and wrap every split in a DataLoader.

    Reads ``{params.kg_file}{split}.txt`` for train/valid/test.  Each non-blank line
    is ``src<TAB>dst<TAB>rel`` (the column order written by ``dataloader.create_kg_data``).

    Training records group answers per query: ``(src, rel) -> every dst seen in
    train``, plus, for rel 0/1/2, the reverse query ``(dst, rel + 5) -> src``.
    Valid/test records keep each individual triple and, as label, every known
    answer to ``(src, rel)`` across train+valid+test (consumed by filtered ranking).

    Returns:
        ``(data_iter, triples)``: DataLoaders keyed by split, and the per-split record lists.
    """
    sr2d, data= ddict(set), ddict(list)
    for split in ['train', 'valid', 'test']:
        with open(f'{params.kg_file}{split}.txt', encoding= 'utf-8-sig') as fin:
            for line in fin:
                line= line.strip()
                if not line:
                    continue  # tolerate blank / trailing lines
                src_id, dst_id, rel_id= line.split('\t')
                src_id, rel_id, dst_id= int(float(src_id)), int(float(rel_id)), int(float(dst_id))
                data[split].append((src_id, rel_id, dst_id))
                if split== 'train':
                    sr2d[(src_id, rel_id)].add(dst_id)
                    # reverse queries for the three association relations (ids 5, 6, 7)
                    if rel_id in [0, 1, 2]:sr2d[(dst_id, rel_id+ 5)].add(src_id)
    # Snapshot the train-only answer sets before valid/test answers are merged in.
    sr2d4tr= {k: list(v) for k, v in sr2d.items()}
    triples= ddict(list)
    for (src_id, rel_id), dst_id in sr2d4tr.items():
        triples['train'].append({'triple': (src_id, rel_id, -1), 'label': dst_id})
    for split in ['test', 'valid']:
        for src_id, rel_id, dst_id in data[split]:
            sr2d[(src_id, rel_id)].add(dst_id)
            sr2d[(dst_id, rel_id+ 5)].add(src_id)
    sr2d4val_te= {k: list(v) for k, v in sr2d.items()}
    for split in ['valid', 'test']:
        for src_id, rel_id, dst_id in data[split]:
            triples[split].append({'triple': (src_id, rel_id, dst_id), 'label': sr2d4val_te[(src_id, rel_id)]})
    triples= dict(triples)
    def get_data_loader(dataset_class, split, batch_size, shuffle= True):
        return DataLoader(dataset_class(triples[split], split, params), batch_size= batch_size, shuffle= shuffle, collate_fn= dataset_class.collate_fn)
    data_iter= {
        'train': get_data_loader(MDDataset, 'train', params.batch_size),
        'valid': get_data_loader(MDDataset, 'valid', params.batch_size),
        'test': get_data_loader(MDDataset, 'test', params.batch_size)}
    return data_iter, triples

In [41]:
class SAKG(torch.nn.Module):
    """Knowledge-graph link scorer with structure-propagated entity embeddings.

    Every entity embedding E is combined with two propagated views ``scm1 @ E``
    and ``scm2 @ E`` using a per-entity softmax attention.  A (source, relation)
    query is encoded by small MLPs and scored against all entities with
    ``sigmoid(query @ X.T)``.

    Args:
        params: namespace with ``num_ent``, ``num_rel``, ``num_drug``, ``num_micr``, ``embed_dim``.
        struct_coef_mat1, struct_coef_mat2: (num_ent, num_ent) structure-coefficient
            matrices; each is L2-normalised row-wise on construction.
    """
    def __init__(self, params, struct_coef_mat1, struct_coef_mat2):
        super().__init__()
        self.params= params
        self.scm1= nn.functional.normalize(struct_coef_mat1, p= 2, dim= 1)
        self.scm2= nn.functional.normalize(struct_coef_mat2, p= 2, dim= 1)
        dim= self.params.embed_dim
        # Attention over the two propagated views, computed from [E, scm1@E, scm2@E].
        self.fc4att= nn.Sequential(nn.Linear(dim* 3, dim), nn.ReLU(), nn.Dropout(0.3), nn.Linear(dim, 2))
        # Mixes the (entity, relation) pair per embedding coordinate; fc1res is the linear shortcut.
        self.fc1= nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Dropout(0.3), nn.Linear(16, 1))
        self.fc1res= nn.Linear(2, 1)
        self.fc2= nn.Sequential(nn.Linear(dim, 2* dim), nn.ReLU(), nn.Dropout(0.3), nn.Linear(2* dim, dim))
        self.ent_embed= torch.nn.Embedding(self.params.num_ent, dim, padding_idx= None)
        nn.init.xavier_normal_(self.ent_embed.weight)
        self.rel_embed= torch.nn.Embedding(self.params.num_rel, dim, padding_idx= None)
        nn.init.xavier_normal_(self.rel_embed.weight)
        self.bceloss= torch.nn.BCELoss()
        self.init_para()

    def init_para(self):
        """Xavier-normal init of every Linear weight (ReLU gain where a ReLU follows the layer)."""
        relu_gain= nn.init.calculate_gain('relu')
        schedule= ((self.fc4att[0].weight, relu_gain), (self.fc4att[3].weight, 1.0),
                   (self.fc1[0].weight, relu_gain), (self.fc1[3].weight, 1.0),
                   (self.fc1res.weight, 1.0),
                   (self.fc2[0].weight, relu_gain), (self.fc2[3].weight, 1.0))
        for weight, gain in schedule:
            nn.init.xavier_normal_(weight, gain)

    def loss(self, pred, true_label, mats, balance_factor= 0.3, threshold= 0.8):
        """BCE on ``pred`` plus ``balance_factor`` times a similarity-reconstruction MSE.

        ``mats`` holds target similarity matrices for drugs, microbes and diseases.
        Drug errors only count where the target is non-zero or the embedding cosine
        similarity reaches ``threshold``.
        """
        emb= self.ent_embed.weight
        n_drug, n_micr= self.params.num_drug, self.params.num_micr
        drug_embed= nn.functional.normalize(emb[:n_drug], p= 2, dim= 1)
        micr_embed= nn.functional.normalize(emb[n_drug: n_drug+ n_micr], p= 2, dim= 1)
        dise_embed= nn.functional.normalize(emb[n_drug+ n_micr:], p= 2, dim= 1)
        drug_gram= drug_embed @ drug_embed.T
        micr_gram= micr_embed @ micr_embed.T
        dise_gram= dise_embed @ dise_embed.T
        drug_mask= (mats[0]!= 0)| (drug_gram>= threshold)
        drug_term= (((drug_gram- mats[0])* drug_mask)** 2).mean()
        micr_term= ((micr_gram- mats[1])** 2).mean()
        dise_term= ((dise_gram- mats[2])** 2).mean()
        mse_loss= drug_term+ micr_term+ dise_term
        return balance_factor* mse_loss+ self.bceloss(pred, true_label)

    def forward(self, src, rel):
        """Return sigmoid scores of shape (batch, num_ent) for the queries (src, rel)."""
        base= self.ent_embed.weight
        hop1= self.scm1 @ base
        hop2= self.scm2 @ base
        weights= torch.softmax(self.fc4att(torch.cat([base, hop1, hop2], dim= 1)), dim= 1)
        fused= base+ weights[:, 0].view(-1, 1)* hop1+ weights[:, 1].view(-1, 1)* hop2
        src_emb= fused[src]
        rel_emb= self.rel_embed(rel)
        # (batch, dim, 2): entity and relation values side by side for every coordinate
        pair= torch.stack([src_emb, rel_emb], dim= 1).transpose(1, 2)
        mixed= (self.fc1(pair)+ self.fc1res(pair)).squeeze(2)
        return torch.sigmoid(self.fc2(mixed) @ fused.T)

In [42]:
def evaluate(net, data_iter, params, split= 'valid'):
    """Filtered link-prediction ranking metrics of ``net`` on ``data_iter[split]``.

    For each query, the scores of all *other* known answers (label > 0.5) are set
    to 0 so they cannot outrank the target (the "filtered" setting).  The
    target's score is then re-inserted at a random position before a stable
    descending sort, so ties with the target are broken at random (uses the
    global NumPy RNG).

    Returns:
        dict with ``count``, ``mr`` (mean rank), ``mrr`` (mean reciprocal rank) and
        ``hits@1`` .. ``hits@10``; all but ``count`` are rounded to 5 decimals.

    Raises:
        ValueError: if the split yields no samples.
    """
    net.eval()
    results= {}
    with torch.no_grad():
        for batch in data_iter[split]:
            triple, label= [t.to(params.device) for t in batch]
            src, rel, dst= triple[:, 0], triple[:, 1], triple[:, 2]
            pred= net.forward(src, rel)
            b_range= torch.arange(pred.size()[0], device= params.device)
            target_pred= pred[b_range, dst]
            # Hide every known answer (boolean mask; > 0.5 also works for smoothed labels),
            # then restore the target's own score.
            pred= torch.where(label> 0.5, torch.zeros_like(pred), pred)
            pred[b_range, dst]= target_pred
            pred, dst= pred.cpu().numpy(), dst.cpu().numpy()
            for scores, target in zip(pred, dst):
                tar_scr= scores[target]
                scores= np.delete(scores, target)
                rand= np.random.randint(scores.shape[0])
                scores= np.insert(scores, rand, tar_scr)
                sorted_indices= np.argsort(-scores, kind= 'stable')
                _filter= np.where(sorted_indices== rand)[0][0]  # 0-based rank of the target
                results['count']= results.get('count', 0.0)+ 1
                results['mr']= results.get('mr', 0.0)+ (_filter+ 1)
                results['mrr']= results.get('mrr', 0.0)+ 1.0/ (_filter+ 1)
                for k in range(_filter, 10):  # the target is a hit for every cutoff k+1 > rank
                    results[f'hits@{k+ 1}']= results.get(f'hits@{k+ 1}', 0.0)+ 1
    count= results.get('count', 0)
    if not count:
        raise ValueError(f"evaluate: split '{split}' yielded no samples")
    results['mr']= round(results['mr']/ float(count), 5)
    results['mrr']= round(results['mrr']/ float(count), 5)
    for k in range(10):
        results[f'hits@{k+1}']= round(results.get(f'hits@{k+ 1}', 0)/ float(count), 5)
    return results

In [None]:
# Seed every RNG the pipeline uses (Python, NumPy, PyTorch CPU and CUDA) before any random work.
random.seed(params.seed)
np.random.seed(params.seed)
torch.manual_seed(params.seed)
torch.cuda.manual_seed(params.seed)
torch.cuda.manual_seed_all(params.seed)
# NOTE: Python reads PYTHONHASHSEED only at interpreter start-up, so this line affects child
# processes, not string hashing in the current kernel.
os.environ['PYTHONHASHSEED']= str(params.seed)    
torch.backends.cudnn.deterministic= True
# Load all matrices, split drug-microbe pairs, run the random walks and write the KG triple files.
dl= dataloader(params)
# split dataset: 0/1 drug-microbe association label of every (drug, microbe) pair in each split
train_label, valid_label, test_label= dl.drug_micr_asso_mat[dl.train_xy[:, 0], dl.train_xy[:, 1]].long(), dl.drug_micr_asso_mat[dl.valid_xy[:, 0], dl.valid_xy[:, 1]].long(), dl.drug_micr_asso_mat[dl.test_xy[:, 0], dl.test_xy[:, 1]].long()
# Pair-level (drug, microbe) loaders; the KG training cell that follows iterates data_iter instead.
train_xy_label_dataset= torch.utils.data.TensorDataset(dl.train_xy, train_label)
train_loader= torch.utils.data.DataLoader(train_xy_label_dataset, batch_size= params.batch_size, shuffle= True)
valid_xy_label_dataset= torch.utils.data.TensorDataset(dl.valid_xy, valid_label)
valid_loader= torch.utils.data.DataLoader(valid_xy_label_dataset, batch_size= params.batch_size, shuffle= False)
test_xy_label_dataset= torch.utils.data.TensorDataset(dl.test_xy, test_label)
test_loader= torch.utils.data.DataLoader(test_xy_label_dataset, batch_size= params.batch_size, shuffle= False)
# knowledge graph data: query loaders built from the triple files written by dataloader(params)
data_iter, triples= load_data(params)
# First-stage model on the nei=1 and nei=2 random-walk structure-coefficient matrices.
net= SAKG(params, dl.struct_coef_mat1.to(params.device), dl.struct_coef_mat2.to(params.device)).to(params.device)
optimizer= torch.optim.Adam(net.parameters(), lr= params.lr_net1, weight_decay= params.weight_decay_net1)
# Saves the best model to params.pt_file/params.pt_file_name_net1.
earlystopping4kg= EarlyStopping(patience= params.patience_net1, pt_file= params.pt_file, file_name= params.pt_file_name_net1, mess_out= True)
# Similarity targets for the MSE term of SAKG.loss: drug structure, integrated microbe, disease.
mats= [dl.drug_struct_simi_mat.clone().to(params.device), (dl.micr_inte_simi_mat).clone().to(params.device), dl.dise_simi_mat.clone().to(params.device)]

In [None]:
# Train the knowledge-graph model until early stopping (monitored score: validation MRR + Hits@10).
for epoch in range(params.epoch_net1):
    net.train()
    losses= []
    for step, batch in enumerate(data_iter['train']):
        optimizer.zero_grad()
        triple, label= [t.to(params.device) for t in batch]
        src, rel, dst= triple[:, 0], triple[:, 1], triple[:, 2]
        pred= net.forward(src, rel)
        loss= net.loss(pred, label, mats, params.balance_factor_net1, params.threshold_net1)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if step% 5000== 0:
            # element-wise accuracy of thresholded predictions vs thresholded (smoothed) labels
            acc= (torch.where(pred< 0.5, 0, 1)== torch.where(label< 0.5, 0, 1)).sum()/ label.shape[0]/ label.shape[1]
            print(f'epoch: {epoch}, Train loss: {np.mean(losses)}, rel ratio: {((rel== 0)).sum()/ len(rel)}, acc: {acc}')
    results= evaluate(net, data_iter, params, 'valid')
    print(f'val mr: {results["mr"]}, mrr: {results["mrr"]}, hits10: {results["hits@10"]}')
    # EarlyStopping minimises its input, so pass the negated (MRR + Hits@10).
    earlystopping4kg(-(results['mrr']+ results['hits@10']), net)
    if earlystopping4kg.flag: break
# kg valid: reload the best checkpoint. It is a pickled full module, so weights_only=False is
# required on PyTorch >= 2.6 (where torch.load defaults to weights_only=True); only load trusted files.
net= torch.load(f'{params.pt_file}//{params.pt_file_name_net1}', map_location= params.device, weights_only= False)
print(f"valid results, {evaluate(net, data_iter, params, 'valid')}\ntest results, {evaluate(net, data_iter, params, 'test')}")

### Explain the knowledge-graph model

In [14]:
# Load the best first-stage checkpoint on CPU for inspection. It is a pickled full module, so
# weights_only=False is required on PyTorch >= 2.6; only do this for trusted checkpoint files.
net= torch.load(f'{params.pt_file}//{params.pt_file_name_net1}', map_location= 'cpu', weights_only= False)
ent_emb, rel_emb= net.ent_embed.weight.detach(), net.rel_embed.weight.detach()

In [None]:
# View relation embeddings: project rel_emb (presumably the (num_rel, embed_dim) matrix loaded
# above) onto its first two left singular vectors from torch.svd, then plot four relation types.
import matplotlib.pyplot as plt
u, s, v= torch.svd(rel_emb)
reduced_rel_emb= u[:, :2]
# Keep only relation ids 0, 3, 4 and 5; labels[i] below names relation idx[i].
idx= torch.tensor([0, 3, 4, 5])
reduced_rel_emb= reduced_rel_emb[idx]
plt.figure(figsize=(4, 4))
# drug_micr_rel, 0; drug_inter_rel, 3; micr_inter_rel, 4; micr_drug_rel, 5.
labels= ['drug&micr', 'drug&drug', 'micr&micr', 'micr&drug']
for i, point in enumerate(reduced_rel_emb):
    plt.scatter(point[0], point[1])
    plt.annotate(labels[i], (point[0], point[1]))
plt.xlabel('Dimension 1')
plt.ylabel('Dimension 2')
plt.title('Visualization of Reduced Data')
plt.show()

In [None]:
# Plot drug and microbe entity embeddings in 2-D: project the whole entity-embedding matrix
# onto its first two left singular vectors (torch.svd), then mark each group's centroid and
# the Euclidean distance between the two centroids.
import matplotlib.pyplot as plt
n_drug, n_micr= params.num_drug, params.num_micr
u, s, v= torch.svd(ent_emb)
reduced_ent_emb= u[:, :2]
drug_pts= reduced_ent_emb[:n_drug]
micr_pts= reduced_ent_emb[n_drug: n_drug+ n_micr]
plt.figure(figsize= (4, 4))
# One scatter call per entity type instead of one call per point (same markers, far fewer artists).
plt.scatter(drug_pts[:, 0].numpy(), drug_pts[:, 1].numpy(), c= 'red', label= 'Drug')
plt.scatter(micr_pts[:, 0].numpy(), micr_pts[:, 1].numpy(), c= 'green', label= 'Microbe')
drug_center= drug_pts.mean(dim= 0)
micr_center= micr_pts.mean(dim= 0)
plt.scatter(drug_center[0], drug_center[1], c= 'black', marker= 'x', s= 100, label= 'Drug center')
plt.scatter(micr_center[0], micr_center[1], c= 'black', marker= 'x', s= 100, label= 'Microbe center')
plt.plot([drug_center[0], micr_center[0]], [drug_center[1], micr_center[1]], c= 'blue')
# compute dist between center
distance= torch.sqrt(((drug_center- micr_center)** 2).sum())
midpoint= ((drug_center+ micr_center)/ 2).tolist()
plt.text(micr_center[0], midpoint[1], f'Distance: {distance.item():.3f}', color= 'blue', fontsize= 10)
# Fixed axis window; points outside it are not shown.
plt.xlim(-0.08, 0.02)
plt.ylim(-0.08, 0.08)
plt.title('Drug and microbe scatter plot', fontsize= 12)
plt.legend(fontsize= 10)
plt.savefig('Drug and microbe scatter plot.png', dpi=300)
plt.show()

In [17]:
# Fuse each drug with the drug->microbe relation (rel id 0): recompute the attention-fused
# entity matrix on CPU and push every drug through the same query encoder
# (fc1 + fc1res, then fc2) that the model's forward pass scores against all entities.
net.eval()  # disable dropout so the encoding is deterministic
net.scm1= net.scm1.to('cpu')
net.scm2= net.scm2.to('cpu')
with torch.no_grad():
    X1, X2, X3= net.ent_embed.weight, net.scm1@ net.ent_embed.weight, net.scm2@ net.ent_embed.weight
    att_mat= torch.softmax(net.fc4att(torch.cat([X1, X2, X3], dim= 1)), dim= 1)
    X= X1+ att_mat[:, 0].view(-1, 1)* X2+ att_mat[:, 1].view(-1, 1)* X3
    src= torch.arange(0, params.num_drug, 1)
    rel= torch.zeros(params.num_drug, dtype= torch.long)
    src_emb, rel_emb= X[src], net.rel_embed(rel)
    fea1= torch.cat([src_emb.unsqueeze(1), rel_emb.unsqueeze(1)], dim= 1).transpose(1, 2)
    drug_emb_fuse= net.fc2((net.fc1(fea1)+ net.fc1res(fea1)).squeeze(2))
    # Microbes keep their raw (un-fused) entity embeddings.
    micr_emb_fuse= net.ent_embed.weight[params.num_drug: params.num_drug+ params.num_micr, :]

In [None]:
# Plot the relation-fused drug encodings and the microbe embeddings in 2-D, with each
# group's centroid and the distance between the centroids.
import matplotlib.pyplot as plt
# NOTE(review): drugs and microbes are projected with two *separate* SVDs, so their 2-D axes
# are not a shared coordinate system and the centroid distance compares points from different
# bases — confirm this is intended.
u, s, v= torch.svd(drug_emb_fuse.detach().cpu())
reduced_drug_emb_fuse= u[:, :2]
u, s, v= torch.svd(micr_emb_fuse.detach().cpu())
reduced_micr_emb_fuse= u[:, :2]
plt.figure(figsize= (4, 4))
# One scatter call per group instead of one call per point.
plt.scatter(reduced_drug_emb_fuse[:, 0].numpy(), reduced_drug_emb_fuse[:, 1].numpy(), c= 'red', label= 'Drug')
plt.scatter(reduced_micr_emb_fuse[:, 0].numpy(), reduced_micr_emb_fuse[:, 1].numpy(), c= 'green', label= 'Microbe')
drug_center= reduced_drug_emb_fuse.mean(dim= 0)
micr_center= reduced_micr_emb_fuse.mean(dim= 0)
plt.scatter(drug_center[0], drug_center[1], c= 'black', marker= 'x', s= 100, label= 'Drug center')
plt.scatter(micr_center[0], micr_center[1], c= 'black', marker= 'x', s= 100, label= 'Microbe center')
plt.plot([drug_center[0], micr_center[0]], [drug_center[1], micr_center[1]], c= 'blue')
# compute dist between center
distance= torch.sqrt(((drug_center- micr_center)** 2).sum())
midpoint= ((drug_center+ micr_center)/ 2).tolist()
plt.text(micr_center[0], midpoint[1], f'Distance: {distance.item():.3f}', color= 'blue', fontsize= 10)
# Fixed axis window; points outside it are not shown.
plt.xlim(-0.08, 0.02)
plt.ylim(-0.08, 0.08)
plt.title('Relation-fused drug and microbe scatter plot', fontsize= 12)
plt.legend(loc='upper left', fontsize= 10)
plt.savefig('Relation-fused drug and microbe scatter plot.png', dpi=300)
plt.show()

In [44]:
class Conv2d(nn.Module):
    """2-D convolution with a (ksa, ksb) kernel that zero-pads only the last (width) axis.

    Padding (ksb - 1) // 2 on both sides keeps the width unchanged; the height
    shrinks by ksa - 1 because that axis is not padded.

    Raises:
        ValueError: if ``ksb`` is even (symmetric padding needs an odd kernel width).
    """
    def __init__(self, in_ch, out_ch, ksa, ksb):
        super().__init__()
        # Explicit check instead of assert, so validation survives `python -O`.
        if ksb% 2!= 1:
            raise ValueError(f'kernel_size should be odd, got ksb={ksb}.')
        self.ksb= ksb
        self.conv= nn.Conv2d(in_ch, out_ch, kernel_size= (ksa, ksb))

    def forward(self, x):
        half= (self.ksb- 1)// 2
        x_pad= F.pad(x, [half, half, 0, 0])  # [left, right, top, bottom]
        return self.conv(x_pad)

In [45]:
class CrossEmbedLayer(nn.Module):
    """Multi-scale convolution: parallel ``Conv2d`` branches with different kernel widths.

    The channel budget is out_ch/2, out_ch/4, ..., with the last scale taking the
    remainder. Budgets are assigned to the kernel widths in ascending order, so
    the smallest kernel gets the most channels. Branch outputs are concatenated
    along channels and mixed by a 1x1 convolution.

    Args:
        in_ch, out_ch: input / output channels.
        ksa: kernel height shared by all branches.
        ksbs: kernel widths (Conv2d requires odd widths); any order is accepted.
        stride: only 1 is supported.

    Raises:
        NotImplementedError: if ``stride != 1``.
    """
    def __init__(self, in_ch, out_ch, ksa= 2, ksbs= (1, 3, 7, 11), stride= 1):
        super().__init__()
        if stride!= 1:
            raise NotImplementedError('It is under development.')
        kernel_sizes= sorted(ksbs)
        num_scales= len(kernel_sizes)
        # calculate the dimension at each scale
        channel_scales= [int(out_ch / (2 ** i)) for i in range(1, num_scales)]
        channel_scales= [*channel_scales, out_ch- sum(channel_scales)]
        self.convs= nn.ModuleList([])
        # Pair the *sorted* widths with the budgets so allocation does not depend on the order of ksbs.
        for ks, ch_scale in zip(kernel_sizes, channel_scales):
            self.convs.append(Conv2d(in_ch, ch_scale, ksa= ksa, ksb= ks))
        self.conv_end= nn.Conv2d(out_ch, out_ch, 1)

    def forward(self, x):
        fmaps= [conv(x) for conv in self.convs]
        return self.conv_end(torch.cat(fmaps, dim= 1))

In [46]:
class MPFI(torch.nn.Module):
    """Drug-side features for (drug ``x``, microbe ``y``) pairs.

    Every drug is scored against query drug ``x`` from five standardised signals:
    a learned bilinear embedding similarity, the similarity and interaction rows,
    and their second-order products. Attention is restricted to drugs associated
    with microbe ``y``. The module returns
    ``[drug_emb[x], micr_emb[y], drug_emb[x] - attention-weighted sum of drug_emb]``.

    Args:
        params: namespace with ``num_drug``, ``num_micr``, ``embed_dim``, ``device``.
        ent_emb: entity embeddings; rows [0, num_drug) are drugs, the next num_micr rows microbes.
        mats: ``(asso_mat, drug_inte_mat, drug_simi_mat)``; ``asso_mat`` is drug x microbe
            (its transpose is indexed by ``y``).
    """
    def __init__(self, params, ent_emb, mats):
        super().__init__()
        self.params, self.asso_mat, self.drug_inte_mat, self.drug_simi_mat= params, mats[0], mats[1], mats[2]
        self.drug_emb, self.micr_emb= ent_emb[0: self.params.num_drug], ent_emb[self.params.num_drug: self.params.num_drug+ self.params.num_micr]
        self.att1drug= nn.Parameter(torch.eye(params.embed_dim))
        self.fc4fuse= nn.Sequential(nn.Linear(5, 128), nn.ReLU(), nn.Linear(128, 1))

    def init_para(self):
        """Xavier-normal init of the Linear layers in ``fc4fuse`` (not called by ``__init__``)."""
        for layer in self.fc4fuse:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)

    def forward(self, x, y):
        """Pair features for drug indices ``x`` and microbe indices ``y`` (1-D, equal length)."""
        fea= torch.cat([
            (self.drug_emb[x]@ self.att1drug@ self.drug_emb.T).unsqueeze(-1),
            self.drug_simi_mat[x].unsqueeze(-1),
            self.drug_inte_mat[x].unsqueeze(-1),
            (self.drug_inte_mat[x]@ self.drug_inte_mat.T).unsqueeze(-1),
            (self.drug_simi_mat[x]@ self.drug_simi_mat.T).unsqueeze(-1)
        ], dim= -1)
        # Standardise the five signals for every (query, candidate drug) pair.
        fea= (fea- fea.mean(dim= -1, keepdim= True))/ (fea.std(dim= -1, keepdim= True)+ 1e-10)
        # squeeze(-1), not squeeze(): a batch of one must keep its batch axis.
        att_mat4drug= self.fc4fuse(fea).squeeze(-1)
        # Only drugs associated with microbe y may receive attention.
        att_mat4drug[(self.asso_mat.T)[y]== 0]= -1e10
        # Extra constant logit of 1 in front; its softmax weight is dropped below, so it absorbs
        # attention mass that would otherwise go to the drugs.
        ones= torch.ones((x.shape[0], 1), device= att_mat4drug.device, dtype= att_mat4drug.dtype)
        att_mat4drug_= torch.softmax(torch.cat([ones, att_mat4drug], dim= 1), dim= 1)
        return torch.cat([self.drug_emb[x], self.micr_emb[y], self.drug_emb[x]- att_mat4drug_[:, 1:]@ self.drug_emb], dim= 1)

In [47]:
class BCE(torch.nn.Module):
    """Convolutional branch over the pairwise feature tensor built by ``get_fea``.

    The input goes through ``CrossEmbedLayer(3, 256, ksa=2, stride=1)`` (defined
    elsewhere) and is re-weighted by a normalised gate computed over the
    channel sum. The output is the L2-normalised, width-summed channel vector
    (``fea2``).

    NOTE(review): a gated fusion ``fea4`` of ``fea2`` and ``fea3`` is computed,
    but ``forward`` returns ``fea2``. As a result, ``fc4fea3`` and ``fc4gt``
    receive no gradient from the output. This is kept as-is so that trained
    behaviour (and dropout RNG consumption) is unchanged. Confirm whether
    ``fea4`` was the intended output.
    """

    def __init__(self, params):
        super().__init__()
        self.params = params
        self.cv4fea1 = nn.Sequential(CrossEmbedLayer(3, 256, ksa=2, stride=1), nn.ReLU(), nn.Dropout(0.5))
        self.fc4fea1 = nn.Sequential(nn.Linear(1381, 1381), nn.ReLU(), nn.Dropout(0.5))
        self.fc4fea3 = nn.Sequential(nn.Linear(1381, 256), nn.ReLU(), nn.Dropout(0.3))
        self.fc4gt = nn.Sequential(nn.Linear(512, 256), nn.Sigmoid(), nn.Dropout(0.3))
        self.init_para()

    def init_para(self):
        """Xavier-normal initialise the Linear layers, with activation-matched gains."""
        nn.init.xavier_normal_(self.fc4fea1[0].weight, nn.init.calculate_gain('relu'))
        nn.init.xavier_normal_(self.fc4fea3[0].weight, nn.init.calculate_gain('relu'))
        nn.init.xavier_normal_(self.fc4gt[0].weight, nn.init.calculate_gain('sigmoid'))

    def forward(self, fea):
        """Return the (B, 256) L2-normalised channel summary of ``fea``."""
        fea = self.cv4fea1(fea)
        # Gate across the last axis, computed from the channel-summed map; residual add.
        fea1 = nn.functional.normalize(self.fc4fea1(fea.sum(dim=1, keepdim=True)), p=2, dim=-1) * fea + fea
        # Squeeze only the singleton spatial axis. A bare .squeeze() also drops
        # the batch axis when B == 1, which breaks the torch.cat below.
        fea2 = nn.functional.normalize(fea1.sum(-1).squeeze(-1), p=2, dim=-1)
        fea3 = nn.functional.normalize(self.fc4fea3(fea1.sum(1).squeeze(1)), p=2, dim=-1)
        gt = self.fc4gt(torch.cat([fea2, fea3], dim=1))
        fea4 = gt * fea2 + (1 - gt) * fea3  # unused; see class NOTE(review)
        return fea2

In [48]:
class CVNet(torch.nn.Module):
    """Single-input convolutional classifier over a ``get_fea``-style tensor.

    ``cv4fea`` uses a (2, 1) kernel and ``fc4fea1`` acts on a last axis of
    width 1381. The input is therefore expected to be (B, 3, 2, 1381), the
    layout ``get_fea`` produces. The output is (B, 2) class logits for
    ``loss`` (cross-entropy).
    """

    def __init__(self, params):
        super().__init__()
        self.params = params
        self.cv4fea = nn.Sequential(nn.Conv2d(3, 256, kernel_size=(2, 1)), nn.ReLU(), nn.Dropout(0.5))
        self.fc4fea1 = nn.Sequential(nn.Linear(1381, 1381), nn.ReLU(), nn.Dropout(0.5))
        self.fc4out = nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 2))
        self.loss_fuc = nn.CrossEntropyLoss(reduction='mean')
        self.init_para()

    def init_para(self):
        """Xavier-normal initialise conv/linear weights (ReLU gain where followed by ReLU)."""
        nn.init.xavier_normal_(self.cv4fea[0].weight, nn.init.calculate_gain('relu'))
        nn.init.xavier_normal_(self.fc4fea1[0].weight, nn.init.calculate_gain('relu'))
        nn.init.xavier_normal_(self.fc4out[0].weight, nn.init.calculate_gain('relu'))
        nn.init.xavier_normal_(self.fc4out[2].weight)

    def loss(self, pred, label):
        """Mean cross-entropy between (B, 2) logits and (B,) integer labels."""
        return self.loss_fuc(pred, label)

    def forward(self, fea):
        """Map a (B, 3, 2, 1381) feature tensor to (B, 2) logits."""
        fea = self.cv4fea(fea)  # height 2 -> 1 under the (2, 1) kernel
        gate = nn.functional.normalize(self.fc4fea1(fea.sum(dim=1, keepdim=True)), p=2, dim=-1)
        fea = gate * fea + fea
        # Squeeze only the singleton height axis. A bare .squeeze() also drops
        # the batch axis when B == 1 and yields (2,) logits instead of (1, 2).
        fea = nn.functional.normalize(fea.sum(-1).squeeze(-1), p=2, dim=-1)
        return self.fc4out(fea)

In [49]:
class PCMDA(torch.nn.Module):
    """Two-branch microbe-drug association classifier.

    The BCE branch output (computed from the ``get_fea`` tensor) is
    concatenated with the MPFI branch output (computed from the drug and
    microbe indices). The result is mapped to two logits.

    Note: ``init_para`` is not called by the constructor, so ``fc4gt`` keeps
    PyTorch's default initialisation unless a caller runs it explicitly.
    """
    def __init__(self, params, ent_emb, mats):
        super().__init__()
        self.params = params
        # Built in this order (BCE, then MPFI) to keep RNG use and parameter order stable.
        self.bce = BCE(params)
        self.mpfi = MPFI(params, ent_emb, mats)
        hidden = self.params.embed_dim
        # Input width 5 * embed_dim. MPFI contributes 3 * embed_dim, so the BCE
        # output is expected to be 2 * embed_dim wide.
        self.fc4gt = nn.Sequential(
            nn.Linear(5 * hidden, hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden, 2),
        )
        self.loss_fuc = nn.CrossEntropyLoss(reduction='mean')

    def init_para(self):
        """Xavier-normal initialise both Linear layers of the output head."""
        hidden_layer, output_layer = self.fc4gt[0], self.fc4gt[3]
        nn.init.xavier_normal_(hidden_layer.weight, nn.init.calculate_gain('relu'))
        nn.init.xavier_normal_(output_layer.weight)

    def loss(self, pred, label):
        """Mean cross-entropy between logits ``pred`` and integer ``label``."""
        return self.loss_fuc(pred, label)

    def forward(self, x, y, fea):
        """Return (B, 2) logits for drug indices ``x``, microbe indices ``y`` and features ``fea``."""
        conv_branch = self.bce(fea)
        attn_branch = self.mpfi(x, y)
        joint = torch.cat([conv_branch, attn_branch], dim=-1)
        return self.fc4gt(joint)

In [None]:
# prepare data
# Stage 1: load the pretrained first-stage network (net1) on the CPU and fuse its
# three entity-embedding views into one fixed embedding table for PCMDA.
params.device= 'cuda:0'
net= torch.load(f'{params.pt_file}//{params.pt_file_name_net1}', map_location=torch.device('cpu'))
# X1: raw entity embeddings. X2 / X3: the same embeddings propagated through
# net1's scm1 / scm2 matrices (their meaning is defined in net1's class, not here).
X1, X2, X3= net.ent_embed.weight, net.scm1@ net.ent_embed.weight, net.scm2@ net.ent_embed.weight
# Per-entity softmax weights from net1.fc4att. Only columns 0 and 1 are used below.
att_mat= torch.softmax(net.fc4att(torch.cat([X1, X2, X3], dim= 1)), dim= 1)
# detach().clone() cuts the graph back to net1. MPFI stores this table as a plain
# tensor attribute (not an nn.Parameter), so the optimizer below does not update it.
ent_emb= (X1+ att_mat[:, 0].view(-1, 1)* X2+ att_mat[:, 1].view(-1, 1)* X3).detach().clone()
drug_inte_mat, drug_struct_simi_mat=  dl.drug_inter_mat, dl.drug_struct_simi_mat
# Order matters: MPFI unpacks mats[0], mats[1], mats[2] as
# (association, drug interaction, drug similarity).
mats= [dl.drug_micr_asso_mat_zy, drug_inte_mat, drug_struct_simi_mat]
mats= [mat.to(torch.float).to(params.device) for mat in mats]
# define network
# Note: `net` is rebound here from net1 to the new PCMDA model.
net= PCMDA(params, ent_emb.to(params.device), mats).to(params.device)
optimizer= torch.optim.Adam(net.parameters(), lr= params.lr_net2, weight_decay= params.weight_decay_net2)
# EarlyStopping is defined outside this chunk. The training cell later reloads the
# best model from pt_file/pt_file_name_net2, presumably the file this object writes.
earlystopping= EarlyStopping(patience= params.patience_net2, pt_file= params.pt_file, file_name= params.pt_file_name_net2, mess_out= True)

In [51]:
def get_fea(params, xy, dl):
    """Build the 3-channel, 2-row feature tensor for a batch of (drug, microbe) pairs.

    Args:
        params: namespace; only ``params.device`` is read.
        xy: integer tensor of shape (B, 2). Column 0 holds drug indices and
            column 1 holds microbe indices.
        dl: data holder exposing the drug/microbe matrices read below.

    Returns:
        Float tensor of shape (B, 3, 2, n_drug + n_micr) on ``params.device``.
        In every channel, row 0 is ``[drug-side block | drug x's associations
        over microbes]`` and row 1 is ``[microbe y's associations over drugs |
        microbe-side block]``. The blocks per channel are:

        * channel 0: drug structural similarity / microbe ``micr_inte_simi_mat``
        * channel 1: drug interaction / microbe interaction
        * channel 2: zeros (``0 * drug_asso_simi_mat``) / microbe association similarity

        The pair's own association entry is zeroed in both rows, presumably so
        the model cannot read the label it is asked to predict.
    """
    drug, micr = xy[:, 0], xy[:, 1]

    def _pair(top, bottom):
        # (B, L) and (B, L) -> (B, 1, 2, L): one channel holding two feature rows.
        return torch.cat([top.unsqueeze(1).unsqueeze(1), bottom.unsqueeze(1).unsqueeze(1)], dim=2)

    micr_over_drugs = dl.drug_micr_asso_mat_zy[:, micr].T  # (B, n_drug)
    drug_over_micrs = dl.drug_micr_asso_mat_zy[drug, :]    # (B, n_micr)
    a = _pair(dl.drug_struct_simi_mat[drug, :], micr_over_drugs)
    b = _pair(dl.drug_inter_mat[drug, :], micr_over_drugs)
    # NOTE(review): multiplied by 0, so this drug-side block is always zero; looks
    # like a deliberate ablation, confirm.
    c = _pair(0 * dl.drug_asso_simi_mat[drug, :], micr_over_drugs)
    d = _pair(drug_over_micrs, dl.micr_inter_mat[micr, :])
    e = _pair(drug_over_micrs, dl.micr_asso_simi_mat[micr, :])
    f = _pair(drug_over_micrs, dl.micr_inte_simi_mat[micr, :])
    fea = torch.cat([torch.cat([a, f], dim=-1), torch.cat([b, d], dim=-1), torch.cat([c, e], dim=-1)], dim=1).to(torch.float)
    # Width of the drug block, i.e. the offset of the microbe block.
    # Derived from the data instead of hard-coding 1209.
    n_drug = a.shape[-1]
    rows = torch.arange(xy.shape[0])
    # Zero the (drug, microbe) association entry in all three channels:
    # row 0 at microbe column n_drug + y, and row 1 at drug column x.
    fea[rows, :, 0, n_drug + micr] = 0
    fea[rows, :, 1, drug] = 0
    return fea.to(params.device)

In [52]:
def avg_auc_aupr_cpt(test_xy, test_label, pred, ass_mat_shape):
    """Average per-row ROC-AUC and AUPR over the evaluated (row, column) pairs.

    The labels and scores are scattered into matrices of shape
    ``ass_mat_shape``, with -1 marking cells that were not evaluated. Each row
    is scored on its evaluated cells. Rows without at least one positive and
    one negative are skipped. In this notebook's call sites, rows are drugs.

    Returns:
        ``(mean_auc, mean_aupr)`` as numpy floats. Both are NaN when no row
        qualifies.
    """
    label_mat = torch.full(ass_mat_shape, -1.0)
    pred_mat = torch.full(ass_mat_shape, -1.0)
    rows, cols = test_xy[:, 0], test_xy[:, 1]
    label_mat[rows, cols], pred_mat[rows, cols] = test_label * 1.0, pred
    evaluated = label_mat != -1

    aucs, auprs = [], []
    for row_labels, row_scores, row_mask in zip(label_mat, pred_mat, evaluated):
        y_true, y_score = row_labels[row_mask], row_scores[row_mask]
        n_pos = y_true.sum()
        n_neg = row_mask.sum() - n_pos
        if not (n_pos > 0 and n_neg > 0):
            continue
        fpr, tpr, _ = roc_curve(y_true, y_score)
        precision, recall, _ = precision_recall_curve(y_true, y_score)
        # sklearn fixes the final (recall = 0) precision at 1. It is reset to 0
        # when the preceding precision is already 0.
        precision[-1] = 0 if precision[-2] == 0 else 1
        aucs.append(auc(fpr, tpr))
        auprs.append(auc(recall, precision))
    return np.mean(aucs), np.mean(auprs)

In [None]:
# Train PCMDA (net2) with early stopping on the validation split, then reload the
# best checkpoint and evaluate it on the test split.
#
# Ranking metrics (AUC / AUPR) use the softmax probability of class 1. For a
# two-logit CrossEntropy head, p1 = sigmoid(l1 - l0), so the raw class-1 logit
# alone is not a monotone score. Ranking by it can disagree with the model's
# own argmax predictions.
for ep in range(params.epoch_net2):
    net.train()
    for step, (xy, label) in enumerate(train_loader):
        label = label.to(params.device)
        logp = net(xy[:, 0], xy[:, 1], get_fea(params, xy, dl))
        loss = net.loss(logp, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 500 == 0:
            train_acc = (logp.argmax(dim=1) == label).float().mean()
            print(f'epoch: {ep+ 1}, step: {step+ 1}, train loss: {loss}, acc: {train_acc}')

    # Validation. The loss is a sample-weighted mean over all batches; it used to
    # be overwritten per batch, so only the last batch's loss was reported.
    net.eval()
    val_loss_sum, val_count, pred = 0.0, 0, []
    with torch.no_grad():
        for xy, label in valid_loader:
            label = label.to(params.device)
            logp = net(xy[:, 0], xy[:, 1], get_fea(params, xy, dl))
            val_loss_sum += net.loss(logp, label).item() * len(label)
            val_count += len(label)
            pred.append(logp)
    val_loss = val_loss_sum / max(val_count, 1)
    pred = torch.cat(pred, dim=0).cpu()
    valid_score = torch.softmax(pred, dim=1)[:, 1]
    roc_auc, aupr_auc = avg_auc_aupr_cpt(dl.valid_xy, valid_label, valid_score, (params.num_drug, params.num_micr))
    valid_acc = (pred.argmax(dim=1) == valid_label).float().mean()
    print(f'epoch: {ep+ 1}, valid loss: {val_loss}, acc: {valid_acc}, auc: {roc_auc}, aupr: {aupr_auc}')
    # EarlyStopping minimises its argument, so pass the negated AUC + AUPR.
    earlystopping(-(roc_auc + aupr_auc), net)
    if earlystopping.flag:
        print('early_stopping')
        break

# Test: reload the best checkpoint written during training.
net = torch.load(f'{params.pt_file}//{params.pt_file_name_net2}')
net.eval()
pred = []
with torch.no_grad():
    for xy, label in test_loader:
        pred.append(net(xy[:, 0], xy[:, 1], get_fea(params, xy, dl)))
pred = torch.cat(pred, dim=0).cpu()
test_score = torch.softmax(pred, dim=1)[:, 1]
precision, recall, threshold = precision_recall_curve(test_label, test_score)
roc_auc, aupr = roc_auc_score(test_label, test_score), auc(recall, precision)
print(f'test global auc: {roc_auc}, global aupr: {aupr}')
print(avg_auc_aupr_cpt(dl.test_xy, test_label, test_score, (params.num_drug, params.num_micr)))

### Explaining the attention module

In [None]:
# Load model
# Reload the best PCMDA checkpoint onto the CPU for the explanation cells below.
net= torch.load(f'{params.pt_file}//{params.pt_file_name_net2}', map_location= torch.device('cpu'))
# MPFI.forward allocates a tensor on params.device, so the stored device string must
# follow the weights to the CPU. Presumably net.params is the same object as
# net.mpfi.params after unpickling (PCMDA passes one params to MPFI); verify.
net.params.device= 'cpu'

In [None]:
# Load drug and microbe name files
# Same files as the params.drug_name_path / params.micr_name_path defaults. Entry i
# is taken to be the name of drug (microbe) index i; confirm the files are ordered
# that way.
drug_name= np.loadtxt('./mdd-dataset/mdd/drug/drug_name.txt', dtype= str, delimiter= '\t')
micr_name= np.loadtxt('./mdd-dataset/mdd/microbe/microbe_name.txt', dtype= str, delimiter= '\t')
# Drug name and microbe name
# Known drug indices, for reference:
# Baicalin, 449
# Moxifloxacin, 880
# Ciprofloxacin, 554
search_drug_name= 'Ciprofloxacin'
# search_micr_name= 'Streptococcus mutans'
search_micr_name= 'Escherichia coli'
# Print each name's index. `[0]` raises IndexError if the name is not in the file.
print(f'{search_drug_name}: {[i for i, item in enumerate(drug_name) if item== search_drug_name][0]}')
print(f'{search_micr_name}: {[i for i, item in enumerate(micr_name) if item== search_micr_name][0]}')

In [None]:
# Compute attn value between associated drugs and candidate drug
# Re-derives MPFI's drug attention by hand for drug 554 (Ciprofloxacin) and
# microbe 62 (E. coli index in micr_name; the index is not printed here — confirm).
# A dummy second pair (0, 0) is included, presumably so that `.squeeze()` below
# keeps the batch axis (a batch of one would drop it). Only row 0 is used.
x= torch.tensor([554, 0], dtype= torch.long)
y= torch.tensor([62, 0], dtype= torch.long)
# NOTE(review): this differs from MPFI.forward as defined in this notebook. Here
# an identity is added to att1drug, and features 1, 4 and 5 are divided by
# sqrt(num_drug). forward does neither, and the per-pair standardisation below
# does not cancel a per-feature rescale. Confirm which form the loaded
# checkpoint was trained with.
fea= torch.cat([(net.mpfi.drug_emb[x]@ (net.mpfi.att1drug+ torch.eye(net.mpfi.params.embed_dim).to(net.mpfi.params.device))@ net.mpfi.drug_emb.T/ net.mpfi.params.num_drug** 0.5).unsqueeze(-1),
                net.mpfi.drug_simi_mat[x].unsqueeze(-1), 
                net.mpfi.drug_inte_mat[x].unsqueeze(-1), 
                (net.mpfi.drug_inte_mat[x]@ net.mpfi.drug_inte_mat.T).unsqueeze(-1)/ net.mpfi.params.num_drug** 0.5, 
                (net.mpfi.drug_simi_mat[x]@ net.mpfi.drug_simi_mat.T).unsqueeze(-1)/ net.mpfi.params.num_drug** 0.5], dim= -1)
# Standardise the five features of each (candidate, other drug) pair.
fea= (fea- fea.mean(dim= -1, keepdims= True))/ (fea.std(dim= -1, keepdims= True)+ 1e-10)
att_mat4drug= net.mpfi.fc4fuse(fea).squeeze()
# Only drugs associated with the microbe may receive attention.
att_mat4drug[(net.mpfi.asso_mat.T)[y]== 0]= -1e10
# Prepend the constant "null" logit, as in MPFI.forward, then softmax.
att_mat4drug_= torch.cat([torch.ones((x.shape[0], 1)).to(net.mpfi.params.device), att_mat4drug], dim= 1)
att_mat4drug_= torch.softmax(att_mat4drug_, dim= 1)
# Attention weights of the real pair (row 0) over all drugs, excluding the null slot.
attn_vale= att_mat4drug_[0][1:].detach()

In [None]:
import matplotlib.pyplot as plt
# Sort attention descending and keep ranks 2-17. The single highest-attention drug
# is skipped; the reason is not shown here — verify which drug that is.
sorted_indices= np.argsort(attn_vale.numpy())[::-1][1:17]
# Assumes attention index i corresponds to drug_name[i].
top_drug_names= drug_name[sorted_indices]
top_attn_values= attn_vale.numpy()[sorted_indices]
plt.rcParams['font.sans-serif']= ['Times New Roman']
plt.rcParams['axes.unicode_minus']= False
plt.figure(figsize= (8, 6))
bars= plt.bar(top_drug_names, top_attn_values, color= 'skyblue', edgecolor= 'gray')
# plt.xlabel('Drug Name')
plt.ylabel('Attention value', fontsize= 10)
plt.title('Attention map for associated drugs and candidate drug ciprofloxacin', fontsize= 12)
# Label each bar with its attention value (3 decimals).
for bar, value in zip(bars, top_attn_values):
    plt.text(bar.get_x()+ bar.get_width()/ 2, bar.get_height(), f'{value:.3f}', ha= 'center', va= 'bottom', fontsize= 10)
plt.xticks(rotation= 75, fontsize= 10)
plt.tight_layout()
plt.grid(axis= 'y', linestyle= '--', alpha= 0.7)
plt.savefig('Attention map for associated drugs and candidate drug ciprofloxacin.png', dpi= 300)
plt.show()

In [None]:
# Print the index of another drug of interest. `[0]` raises IndexError if absent.
search_drug_name= 'Piperacillin-tazobactam'
print(f'{search_drug_name}: {[i for i, item in enumerate(drug_name) if item== search_drug_name][0]}')

In [None]:
# Microbes associated with both drug 1028 (presumably Piperacillin-tazobactam, from
# the lookup above — confirm) and drug 554 (Ciprofloxacin): the element-wise
# product of the two association rows is nonzero only where both are.
micr_name[(dl.drug_micr_asso_mat[1028]* dl.drug_micr_asso_mat[554]).to(torch.bool)]

### SHAP

In [None]:
import shap
# Load model
# Everything (model, its params, and the global params used by get_fea) is moved
# to the CPU for SHAP.
net= torch.load(f'{params.pt_file}//{params.pt_file_name_net2}', map_location= torch.device('cpu'))
net= net.to('cpu')
net.params.device= 'cpu'
params.device= 'cpu'
# Background set: 64 consecutive test pairs.
# xy_bg= dl.train_xy.clone()[85879: 85879+ 128]
xy_bg= dl.test_xy.clone()[987: 987+ 64]
background_feat= get_fea(params, xy_bg, dl)
# NOTE(review): DeepExplainer presumably calls the model with the feature tensor
# alone. PCMDA.forward as defined above requires (x, y, fea). Confirm the
# checkpoint at pt_file_name_net2 is a single-input model (e.g. CVNet, whose
# forward takes only fea), or wrap net before explaining.
exper= shap.DeepExplainer(net, background_feat)
# The first two background pairs are explained. The plots below label sample 1 as
# Ciprofloxacin / E. coli, which assumes test_xy[988] == (554, 62); verify.
# xy= torch.tensor([[554, 62], [554, 62]], dtype= torch.long)
xy= xy_bg[0: 2]
feat= get_fea(params, xy, dl)
# check_additivity, avoid additivity check.
shap_vales= exper.shap_values(feat, check_additivity= False)
print(f'len of shap_vales: {len(shap_vales)}, shape of shap_vales[0]: {shap_vales[0].shape}')

In [None]:
# Statistical feature analysis
for i in range(3):
    print(f'Premise{i+ 1} max:{shap_vales[0][1][i].max()}, min:{shap_vales[0][1][i].min()}, mean:{shap_vales[0][1][i].mean()}')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.sans-serif']= ['Times New Roman']
plt.rcParams['axes.unicode_minus']= False
# One colour scale shared by all channels so the three heatmaps are comparable.
min_value= shap_vales[0][1].min()
max_value= shap_vales[0][1].max()
# One 2-row heatmap per feature channel of sample index 1. The row labels assume
# that sample is (Ciprofloxacin, Escherichia coli).
for channel in range(0, 3, 1):
    plt.figure(figsize= (14, 2.8))
    heatmap_data= shap_vales[0][1][channel]
    ax= sns.heatmap(heatmap_data, cmap= "coolwarm", annot= False, vmin= min_value, vmax= max_value)
    plt.title(f"SHAP heatmap for ciprofloxacin and escherichia coli, biological premise{channel+ 1}", fontsize= 12)
    ax.set_yticklabels(['Ciprofloxacin', 'Escherichia coli'])
    plt.xticks(rotation= 75, fontsize= 10)
    plt.yticks(rotation= 35, fontsize= 10)
    plt.tight_layout()
    # No extension is given, so matplotlib falls back to its default savefig format.
    plt.savefig(f'SHAP heatmap for ciprofloxacin and escherichia coli, biological premise{channel+ 1}', dpi= 300)
    plt.show()

### Heatmap analysis

In [None]:
# (row, column) positions in channel index 0 whose SHAP value is >= 0.04.
# The bare name on the last line displays the result as cell output.
channel1_important_feat_index= np.where(shap_vales[0][1][0]>= 0.04)
channel1_important_feat_index

In [None]:
# Important columns < 1209 lie in the drug part of the feature row. Report those
# drugs, their structural similarity to drug 554, and their association with
# microbe 62.
print(f'{drug_name[channel1_important_feat_index[1][channel1_important_feat_index[1]< 1209]]} of similarity is important!')
print(f'similarity value: {dl.drug_struct_simi_mat[554, channel1_important_feat_index[1][channel1_important_feat_index[1]< 1209]]}.')
print(f'association value: {dl.drug_micr_asso_mat[channel1_important_feat_index[1][channel1_important_feat_index[1]< 1209], 62]} of those drug and microbe')

In [None]:
# Important columns >= 1209 lie in the microbe part; subtracting 1209 gives the
# microbe index. Report those microbes, their association with drug 554, and
# their micr_inte_simi_mat similarity to microbe 62.
print(f'{micr_name[channel1_important_feat_index[1][channel1_important_feat_index[1]>= 1209]- 1209]} of association is important!')
print(f'association value: {dl.drug_micr_asso_mat[554, channel1_important_feat_index[1][channel1_important_feat_index[1]>= 1209]- 1209]}')
print(f'similarity value: {dl.micr_inte_simi_mat[62, channel1_important_feat_index[1][channel1_important_feat_index[1]>= 1209]- 1209]}')

In [None]:
# (row, column) positions in channel index 1 whose SHAP value is >= 0.04.
channel2_important_feat_index= np.where(shap_vales[0][1][1]>= 0.04)
channel2_important_feat_index

In [None]:
# (row, column) positions in channel index 2 whose SHAP value is >= 0.05
# (a higher threshold than the other two channels).
channel3_important_feat_index= np.where(shap_vales[0][1][2]>= 0.05)
channel3_important_feat_index

In [None]:
# Report the important microbe-part columns of channel index 2. Columns below the
# drug/microbe boundary belong to the drug part. They are filtered out here, as
# the other two lines already did: subtracting the offset from them would give
# negative indices that numpy silently wraps around to the wrong microbes.
channel3_micr_cols= channel3_important_feat_index[1][channel3_important_feat_index[1]>= params.num_drug]- params.num_drug
print(f'{micr_name[channel3_micr_cols]} of microbe association similarity is important!')
print(f'association similarity value: {dl.micr_asso_simi_mat[62, channel3_micr_cols]}')
print(f'drug association value: {dl.drug_micr_asso_mat[554, channel3_micr_cols]}')