In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import numpy as np
import os
import re
import pandas as pd
import scipy.sparse as sp
import torch as th

#import dgl
#from dgl.data.utils import download, extract_archive, get_download_dir

from itertools import product
from collections import Counter
from copy import deepcopy
from sklearn.model_selection import KFold
from tqdm import tqdm
from sklearn.metrics import accuracy_score

import random
random.seed(1234)
np.random.seed(1234)


In [3]:
def load_data(directory):
    D_SSM = np.loadtxt(directory + '/D_SM.txt')


    M_FSM = np.loadtxt(directory + '/M_SM.txt')

    print('D_SSM',D_SSM)
    print('M_FSM',M_FSM)

    ID = np.zeros(shape=(D_SSM.shape[0], D_SSM.shape[1]))
    IM = np.zeros(shape=(M_FSM.shape[0], M_FSM.shape[1]))
    for i in range(D_SSM.shape[0]):
        for j in range(D_SSM.shape[1]):
            if D_SSM[i][j] == 0:
                ID[i][j] = D_GSM[i][j]###
            else:
                ID[i][j] = D_SSM[i][j]
    for i in range(M_FSM.shape[0]):
        for j in range(M_FSM.shape[1]):
            if M_FSM[i][j] == 0:
                IM[i][j] = M_GSM[i][j]##3
            else:
                IM[i][j] = M_FSM[i][j]
                
    ID = pd.DataFrame(ID).reset_index()
    IM = pd.DataFrame(IM).reset_index()
    print('ID',ID)
    print('IM',IM)
    ID.rename(columns = {'index':'id'}, inplace = True)
    IM.rename(columns = {'index':'id'}, inplace = True)
    ID['id'] = ID['id'] + 1
    IM['id'] = IM['id'] + 1
    print('ID',ID)
    print('IM',IM)
    #print(ID.shape)
    #print(IM.shape)
    return ID, IM


In [4]:
def sample(directory, random_seed):
    all_associations = pd.read_csv(directory + '/drug_mutation_pairs.csv', names=['Drug', 'Mutation', 'label'])
    known_associations = all_associations.loc[all_associations['label'] == 1]
    unknown_associations = all_associations.loc[all_associations['label'] == 0]
    random_negative = unknown_associations.sample(n=known_associations.shape[0], random_state=random_seed, axis=0)

    sample_df = pd.concat([known_associations, random_negative], ignore_index=True)
    sample_df.reset_index(drop=True, inplace=True)
    #print(sample_df)
                 
    return sample_df

In [5]:
def obtain_data(directory, isbalance):
    ID, IM = load_data(directory)
    
    if isbalance:
        dtp = sample(directory, random_seed = 1234)
    else:
        dtp = pd.read_csv(directory + '/drug_mutation_pairs.csv', names=['Drug', 'Mutation', 'label'])
        
    mirna_ids = list(set(dtp['Drug']))
    disease_ids = list(set(dtp['Mutation']))
    
    print('mirna_ids',len(mirna_ids))
    print('disease_ids',len(disease_ids))
    random.shuffle(mirna_ids)
    random.shuffle(disease_ids)
    print('# Drug = {} | Mutation = {}'.format(len(mirna_ids), len(disease_ids)))

    mirna_test_num = int(len(mirna_ids) / 5)
    disease_test_num = int(len(disease_ids) / 5)
    print('# Test: Drug = {} | Mutation = {}'.format(mirna_test_num, disease_test_num))
    
    #print(ID.shape)
    #print(IM.shape)
    #print('dtp',dtp)
    #cf=pd.merge(dtp, IM, left_on = 'miRNA', right_on = 'id')
    #print('cf',cf)
    #print('ID',ID)
    #df=pd.merge(cf, ID, left_on = 'disease', right_on = 'id')
    #print('df',df)
    #print(dtp.to_csv('C:/Users/Administrator/Desktop/图采样data/text/dtp.csv'))
    #print(ID.to_csv('C:/Users/Administrator/Desktop/图采样data/text/ID.csv'))
    #print(IM.to_csv('C:/Users/Administrator/Desktop/图采样data/text/IM.csv'))
    
    knn_x = pd.merge(pd.merge(dtp, ID, left_on = 'Drug', right_on = 'id'), IM, left_on = 'Mutation', right_on = 'id')
    #print('knn_x',knn_x)
    label = dtp['label']
    knn_x.drop(labels = ['Drug', 'Mutation', 'label', 'id_x', 'id_y'], axis = 1, inplace = True)
    assert ID.shape[0] + IM.shape[0] == knn_x.shape[1]
    #print(knn_x.shape, Counter(label))
    #print(label.shape)
    return ID, IM, dtp, mirna_ids, disease_ids, mirna_test_num, disease_test_num, knn_x, label

In [6]:
def generate_task_Tp_train_test_idx(knn_x):
    kf = KFold(n_splits = 5, shuffle = True, random_state = 1234)

    train_index_all, test_index_all, n = [], [], 0
    train_id_all, test_id_all = [], []
    fold = 0
    for train_idx, test_idx in tqdm(kf.split(knn_x)): #train_index与test_index为下标
        print('-------Fold ', fold)
        train_index_all.append(train_idx) 
        test_index_all.append(test_idx)

        train_id_all.append(np.array(dtp.iloc[train_idx][['Drug', 'Mutation']]))
        test_id_all.append(np.array(dtp.iloc[test_idx][['Drug', 'Mutation']]))

        print('# Pairs: Train = {} | Test = {}'.format(len(train_idx), len(test_idx)))
        fold += 1
    return train_index_all, test_index_all, train_id_all, test_id_all

In [7]:
def generate_task_Tm_Td_train_test_idx(item, ids, dtp):
    
    test_num = int(len(ids) / 5)
    
    train_index_all, test_index_all = [], []
    train_id_all, test_id_all = [], []
    
    for fold in range(5):
        print('-------Fold ', fold)
        if fold != 4:
            test_ids = ids[fold * test_num : (fold + 1) * test_num]
        else:
            test_ids = ids[fold * test_num :]

        train_ids = list(set(ids) ^ set(test_ids))
        print('# {}: Train = {} | Test = {}'.format(item, len(train_ids), len(test_ids)))

        test_idx = dtp[dtp[item].isin(test_ids)].index.tolist()
        train_idx = dtp[dtp[item].isin(train_ids)].index.tolist()
        
        random.shuffle(test_idx)
        random.shuffle(train_idx)
        print('# Pairs: Train = {} | Test = {}'.format(len(train_idx), len(test_idx)))
        
        assert len(train_idx) + len(test_idx) == len(dtp)

        train_index_all.append(train_idx) 
        test_index_all.append(test_idx)
        train_id_all.append(train_ids)
        test_id_all.append(test_ids)
        
    train_index_all = np.array(train_index_all, dtype=object)
    test_index_all = np.array(test_index_all, dtype=object)
    train_id_all = np.array(train_id_all, dtype=object)
    test_id_all = np.array(test_id_all, dtype=object)
        
    return train_index_all, test_index_all, train_id_all, test_id_all

# KNN

In [8]:
from sklearn.neighbors import KNeighborsClassifier

In [9]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score, auc
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report

In [10]:
def generate_knn_graph_save(knn_x, label, n_neigh, train_index_all, test_index_all, pwd, task, balance):
    
    fold = 0
    knn_y = None  # 初始化 knn_y 以避免引用未赋值的情况
    knn_neighbors_graph = None  # 也初始化 knn_neighbors_graph
    knn = None  # 确保 knn 变量被初始化
    
    for train_idx, test_idx in zip(train_index_all, test_index_all): 
        print('-------Fold ', fold)
        
        knn_y = deepcopy(label)  ###深层复制label
        knn_y[test_idx] = 0
        
        print('Label: ', Counter(label))
        print('knn_y: ', Counter(knn_y))


        # 确保 knn_x 的列名是字符串类型
        knn_x.columns = knn_x.columns.astype(str)
        knn = KNeighborsClassifier(n_neighbors = n_neigh)

        # 使用训练集的特征和标签
        knn_train_x = knn_x.iloc[train_idx]  # 选择训练集数据
        knn_train_y = knn_y[train_idx]      # 选择训练集标签
        
        # 检查训练数据的形状，确保有数据
        print(f"Training data shape: {knn_train_x.shape}, Training labels shape: {knn_train_y.shape}")
        
        # 确保 KNN 分类器的训练数据和标签都有数据
        if knn_train_x.shape[0] > 0 and knn_train_y.shape[0] > 0:
            knn.fit(knn_train_x, knn_train_y)  # 训练模型
        else:
            print(f"Warning: Fold {fold} has no training data")
            continue  # 跳过这个fold，防止错误

        # 预测和计算邻接图
        knn_neighbors_graph = knn.kneighbors_graph(knn_train_x, n_neighbors=n_neigh)  # 使用训练集计算 KNN 图


        '''
        # 确保 knn_x 的列名是字符串类型
        knn_x.columns = knn_x.columns.astype(str)
        # 然后继续使用 KNeighborsClassifier 训练模型
        #knn.fit(knn_x, knn_y)

        #knn_y_pred = knn.predict(knn_x)
        #knn_y_prob = knn.predict_proba(knn_x)
        #knn_neighbors_graph = knn.kneighbors_graph(knn_x, n_neighbors = n_neigh)
        '''


        
        #print(knn_neighbors_graph)
        #print(knn_y)
        #print(accuracy_score(knn_y, knn_y_pred))

        #prec_reca_f1_supp_report = classification_report(knn_y, knn_y_pred, target_names = ['label_0', 'label_1'])
        #tn, fp, fn, tp = confusion_matrix(knn_y, knn_y_pred).ravel()        
        
        
        #prec_reca_f1_supp_report = classification_report(knn_y, knn_y_pred, target_names = ['label_0', 'label_1','label_-1'])
        #print(prec_reca_f1_supp_report)
        #cf=confusion_matrix(knn_y, knn_y_pred,labels=["0", "1","-1"])
        #print(cf)
        #tn, fp, fn, tp = confusion_matrix(knn_y, knn_y_pred,labels=['label_0', 'label_1','label_-1']).ravel()

        #pos_acc = tp / sum(knn_y)
        #neg_acc = tn / (len(knn_y_pred) - sum(knn_y_pred)) # [y_true=0 & y_pred=0] / y_pred=0
        #accuracy = (tp+tn)/(tn+fp+fn+tp)

        #recall = tp / (tp+fn)
        #precision = tp / (tp+fp)
        #f1 = 2*precision*recall / (precision+recall)

        #roc_auc = roc_auc_score(knn_y, knn_y_prob[:, 1])
        #prec, reca, _ = precision_recall_curve(knn_y, knn_y_prob[:, 1])
        #aupr = auc(reca, prec)

        #print('acc={:.4f}|precision={:.4f}|recall={:.4f}|f1={:.4f}|auc={:.4f}|aupr={:.4f}|pos_acc={:.4f}|neg_acc={:.4f}'.format(accuracy, precision, recall, f1, roc_auc, aupr, pos_acc, neg_acc))
        #print('tn = {}, fp = {}, fn = {}, tp = {}'.format(tn, fp, fn, tp))
        #print('y_pred: ', Counter(knn_y_pred))
        #print('y_true: ', Counter(knn_y))
#         print('knn_score = {:.4f}'.format(knn.score(knn_x, knn_y)))
        knn_graph_file = os.path.join(pwd, f'task_{task}{balance}__testlabel0_knn{n_neigh}neighbors_edge__fold{fold}.npz')
        sp.save_npz(knn_graph_file, knn_neighbors_graph)

        #sp.save_npz(pwd + 'task_' + task + balance + '__testlabel0_knn' + str(n_neigh) + 'neighbors_edge__fold' + str(fold) + '.npz', knn_neighbors_graph)
        fold += 1
    return knn_x, knn_y, knn, knn_neighbors_graph

# Run

In [13]:
import itertools
output_dir = './Data pre-processing result/balance/'
os.makedirs(output_dir, exist_ok=True)

for isbalance in [True]:
#for isbalance in [False, True]:
    print('************isbalance = ', isbalance)
    
    #for task in ['Td']:
    for task in ['Tp','Td', 'Tm']:
        print('=================task = ', task)
        
        ID, IM, dtp, mirna_ids, disease_ids, mirna_test_num, disease_test_num, knn_x, label = obtain_data('./Data', isbalance)

        if task == 'Tp':
            train_index_all, test_index_all, train_id_all, test_id_all = generate_task_Tp_train_test_idx(knn_x)
        elif task == 'Tm':
            item = 'Drug'
            ids = mirna_ids
            train_index_all, test_index_all, train_id_all, test_id_all = generate_task_Tm_Td_train_test_idx(item, ids, dtp)
        elif task == 'Td':
            item = 'Mutation'
            ids = disease_ids
            train_index_all, test_index_all, train_id_all, test_id_all = generate_task_Tm_Td_train_test_idx(item, ids, dtp)

        if isbalance:
            balance = ''
        else:
            balance = '__nobalance'



        # 确保将数据转换为 NumPy 数组后再保存
        train_index_all = np.array(train_index_all, dtype=object)
        test_index_all = np.array(test_index_all, dtype=object)
        train_id_all = np.array(train_id_all, dtype=object)
        test_id_all = np.array(test_id_all, dtype=object)

        # 打印形状
        print(f'train_index_all shape: {train_index_all.shape}')
        print(f'test_index_all shape: {test_index_all.shape}')
        print(f'train_id_all shape: {train_id_all.shape}')
        print(f'test_id_all shape: {test_id_all.shape}')
        


        np.savez_compressed('./Data pre-processing result/balance/task_' + task + balance + '__testlabel0_knn_edge_train_test_index_all.npz', 
                               train_index_all = train_index_all, 
                               test_index_all = test_index_all,
                               train_id_all = train_id_all, 
                               test_id_all = test_id_all)

        pwd = './Data pre-processing result/balance/'
        for n_neigh in [1, 3, 5, 7, 10, 15]: 
            print('--------------------------n_neighbors = ', n_neigh)
            knn_x, knn_y, knn, knn_neighbors_graph = generate_knn_graph_save(knn_x, label, n_neigh, train_index_all, test_index_all, pwd, task, balance)

#load data path：            
directory='./Data'
#ID,IM=load_data(directory)
#print(ID)
dtp.to_csv('./Data pre-processing result/balance/dtp.csv')
node_feature_label = pd.concat([dtp, knn_x], axis = 1)
node_feature_label



#pwd = './Data pre-processing result/balance/'
#node_feature_label.to_csv(pwd + 'node_feature_label.csv')

************isbalance =  True
D_SSM [[1.         0.89766196 0.9695392  ... 0.93836794 0.94563883 0.93286166]
 [0.89766196 1.         0.97386686 ... 0.99178252 0.98852405 0.9811674 ]
 [0.9695392  0.97386686 1.         ... 0.98927445 0.9959657  0.98758327]
 ...
 [0.93836794 0.99178252 0.98927445 ... 1.         0.9946604  0.9929528 ]
 [0.94563883 0.98852405 0.9959657  ... 0.9946604  1.         0.99153622]
 [0.93286166 0.9811674  0.98758327 ... 0.9929528  0.99153622 1.        ]]
M_FSM [[1.         0.90766937 0.17608963 ... 0.89524413 0.01611704 0.03146066]
 [0.90766937 1.         0.0183536  ... 0.97583942 0.08079069 0.07906328]
 [0.17608963 0.0183536  1.         ... 0.04580561 0.83450014 0.84573587]
 ...
 [0.89524413 0.97583942 0.04580561 ... 1.         0.10737693 0.12459924]
 [0.01611704 0.08079069 0.83450014 ... 0.10737693 1.         0.99315765]
 [0.03146066 0.07906328 0.84573587 ... 0.12459924 0.99315765 1.        ]]
ID      index         0         1         2         3         4       

5it [00:00, 1672.50it/s]

-------Fold  0
# Pairs: Train = 1336 | Test = 334
-------Fold  1
# Pairs: Train = 1336 | Test = 334
-------Fold  2
# Pairs: Train = 1336 | Test = 334
-------Fold  3
# Pairs: Train = 1336 | Test = 334
-------Fold  4
# Pairs: Train = 1336 | Test = 334
train_index_all shape: (5, 1336)
test_index_all shape: (5, 334)
train_id_all shape: (5, 1336, 2)
test_id_all shape: (5, 334, 2)
--------------------------n_neighbors =  1
-------Fold  0
Label:  Counter({1: 835, 0: 835})
knn_y:  Counter({0: 1005, 1: 665})
Training data shape: (1336, 845), Training labels shape: (1336,)
-------Fold  1
Label:  Counter({1: 835, 0: 835})
knn_y:  Counter({0: 1009, 1: 661})
Training data shape: (1336, 845), Training labels shape: (1336,)
-------Fold  2
Label:  Counter({1: 835, 0: 835})
knn_y:  Counter({0: 994, 1: 676})
Training data shape: (1336, 845), Training labels shape: (1336,)
-------Fold  3
Label:  Counter({1: 835, 0: 835})
knn_y:  Counter({0: 1002, 1: 668})
Training data shape: (1336, 845), Training labels




-------Fold  4
Label:  Counter({1: 835, 0: 835})
knn_y:  Counter({0: 1000, 1: 670})
Training data shape: (1336, 845), Training labels shape: (1336,)
--------------------------n_neighbors =  3
-------Fold  0
Label:  Counter({1: 835, 0: 835})
knn_y:  Counter({0: 1005, 1: 665})
Training data shape: (1336, 845), Training labels shape: (1336,)
-------Fold  1
Label:  Counter({1: 835, 0: 835})
knn_y:  Counter({0: 1009, 1: 661})
Training data shape: (1336, 845), Training labels shape: (1336,)
-------Fold  2
Label:  Counter({1: 835, 0: 835})
knn_y:  Counter({0: 994, 1: 676})
Training data shape: (1336, 845), Training labels shape: (1336,)
-------Fold  3
Label:  Counter({1: 835, 0: 835})
knn_y:  Counter({0: 1002, 1: 668})
Training data shape: (1336, 845), Training labels shape: (1336,)
-------Fold  4
Label:  Counter({1: 835, 0: 835})
knn_y:  Counter({0: 1000, 1: 670})
Training data shape: (1336, 845), Training labels shape: (1336,)
--------------------------n_neighbors =  5
-------Fold  0
Label:

Unnamed: 0,Drug,Mutation,label,0_x,1_x,2_x,3_x,4_x,5_x,6_x,...,651,652,653,654,655,656,657,658,659,660
0,1,154,1,1.000000,0.897662,0.969539,0.977718,0.895714,0.955645,0.946404,...,0.643398,0.962466,0.841347,0.933204,0.675166,0.100557,0.366918,0.560881,0.693739,0.656848
1,2,528,1,0.897662,1.000000,0.973867,0.967525,0.998237,0.974889,0.990946,...,0.755256,0.241065,0.507345,0.026673,0.721578,0.990111,0.921796,0.116548,0.702949,0.725836
2,3,451,1,0.969539,0.973867,1.000000,0.997122,0.973182,0.986712,0.994273,...,0.990810,0.774353,0.956297,0.461909,0.996606,0.606847,0.868337,0.129613,0.997275,0.995295
3,5,277,1,0.895714,0.998237,0.973182,0.964158,1.000000,0.966592,0.990725,...,0.962755,0.525239,0.788234,0.216128,0.941587,0.831987,0.953660,0.027768,0.935119,0.936123
4,5,564,1,0.895714,0.998237,0.973182,0.964158,1.000000,0.966592,0.990725,...,0.900798,0.919938,0.991238,0.656184,0.934664,0.423648,0.752314,0.312864,0.927177,0.926781
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1665,117,130,0,0.985151,0.875795,0.946596,0.963267,0.865546,0.956866,0.922686,...,0.289497,0.802112,0.543908,0.987593,0.326497,0.003631,0.106396,0.859960,0.343587,0.319277
1666,33,541,0,0.995515,0.900413,0.968720,0.979297,0.894133,0.962970,0.945407,...,0.816304,0.406510,0.654416,0.119561,0.812865,0.928408,0.979666,0.166884,0.778358,0.816435
1667,30,486,0,0.986313,0.955002,0.995024,0.995447,0.955632,0.979596,0.985605,...,0.139093,0.643617,0.387044,0.844397,0.191773,0.146936,0.161266,0.988111,0.178353,0.197048
1668,57,426,0,0.969931,0.961463,0.982292,0.991723,0.953832,0.992798,0.980442,...,0.441928,0.418078,0.462504,0.322790,0.480621,0.708095,0.703823,0.607898,0.425159,0.489884


In [14]:
node_feature_label = pd.concat([dtp, knn_x], axis = 1)
node_feature_label

Unnamed: 0,Drug,Mutation,label,0_x,1_x,2_x,3_x,4_x,5_x,6_x,...,651,652,653,654,655,656,657,658,659,660
0,1,154,1,1.000000,0.897662,0.969539,0.977718,0.895714,0.955645,0.946404,...,0.643398,0.962466,0.841347,0.933204,0.675166,0.100557,0.366918,0.560881,0.693739,0.656848
1,2,528,1,0.897662,1.000000,0.973867,0.967525,0.998237,0.974889,0.990946,...,0.755256,0.241065,0.507345,0.026673,0.721578,0.990111,0.921796,0.116548,0.702949,0.725836
2,3,451,1,0.969539,0.973867,1.000000,0.997122,0.973182,0.986712,0.994273,...,0.990810,0.774353,0.956297,0.461909,0.996606,0.606847,0.868337,0.129613,0.997275,0.995295
3,5,277,1,0.895714,0.998237,0.973182,0.964158,1.000000,0.966592,0.990725,...,0.962755,0.525239,0.788234,0.216128,0.941587,0.831987,0.953660,0.027768,0.935119,0.936123
4,5,564,1,0.895714,0.998237,0.973182,0.964158,1.000000,0.966592,0.990725,...,0.900798,0.919938,0.991238,0.656184,0.934664,0.423648,0.752314,0.312864,0.927177,0.926781
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1665,117,130,0,0.985151,0.875795,0.946596,0.963267,0.865546,0.956866,0.922686,...,0.289497,0.802112,0.543908,0.987593,0.326497,0.003631,0.106396,0.859960,0.343587,0.319277
1666,33,541,0,0.995515,0.900413,0.968720,0.979297,0.894133,0.962970,0.945407,...,0.816304,0.406510,0.654416,0.119561,0.812865,0.928408,0.979666,0.166884,0.778358,0.816435
1667,30,486,0,0.986313,0.955002,0.995024,0.995447,0.955632,0.979596,0.985605,...,0.139093,0.643617,0.387044,0.844397,0.191773,0.146936,0.161266,0.988111,0.178353,0.197048
1668,57,426,0,0.969931,0.961463,0.982292,0.991723,0.953832,0.992798,0.980442,...,0.441928,0.418078,0.462504,0.322790,0.480621,0.708095,0.703823,0.607898,0.425159,0.489884


In [16]:
pwd = './Data pre-processing result/balance/'
node_feature_label.to_csv(pwd + 'node_feature_label.csv')