# For Colab

In [1]:
from google.colab import drive

# Mount Google Drive at /content/drive so files stored there are reachable from this runtime.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Change the working directory to the project folder on the mounted Drive.
# The previous form of this cell (a bare `cd` followed by a `/`-separated
# expression with stray backslashes) is not valid Python/IPython input.
# An absolute path is used so that re-running the cell is idempotent; the
# target is the directory reported by the original cell's output.
# NOTE(review): presumably the later `src.*` / `base.*` imports rely on this
# being the current directory — confirm.
import os

PROJECT_ROOT = (
    '/content/drive/MyDrive/Project/Bioinformatics/'
    'Microbe Disease Association/Previous Work : KGNMDA/SKGNMDA'
)
os.chdir(PROJECT_ROOT)
os.getcwd()

/content/drive/MyDrive/Project/Bioinformatics/Microbe Disease Association/Previous Work : KGNMDA/SKGNMDA


# Prerequisites

In [3]:
# Dataset key; passed as `dataset=` to format_filename() in later cells to pick the processed files.
dataset = 'mdkg_hmdad'

In [4]:
from tensorflow.python.client import device_lib

# Query TensorFlow for the devices visible to this runtime and print them,
# so it is clear from the output whether a GPU is attached.
local_devices = device_lib.list_local_devices()
print(local_devices)

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 1825593657987859404
xla_global_id: -1
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 14357954560
locality {
  bus_id: 1
  links {
  }
}
incarnation: 9317435490282877389
physical_device_desc: "device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5"
xla_global_id: 416903419
]


# Load Data

In [5]:
from src.config import DISEASE_MICROBE_EXAMPLE, PROCESSED_DATA_DIR
from src.utils import format_filename
import numpy as np

# Build the path of the processed example file for the selected dataset,
# then load it as a NumPy array.
examples_file = format_filename(
    PROCESSED_DATA_DIR, DISEASE_MICROBE_EXAMPLE, dataset=dataset
)
examples = np.load(examples_file)

In [6]:
# Dimensions of the loaded example array.
examples.shape

(898, 3)

In [7]:
# Preview the first three example rows.
examples[:3]

array([[50863, 33211,     1],
       [43621, 40832,     1],
       [33293, 47880,     1]])

In [8]:
from src.data import MicrobeDiseaseData

# Split the example array column-wise: columns 0 and 1 are kept as two
# single-column 2-D inputs, column 2 is flattened to a 1-D vector.
# NOTE(review): presumably columns 0/1 are the two entity ids of a pair and
# column 2 is the association label — confirm against MicrobeDiseaseData.
first_term_ids = examples[:, :1]
second_term_ids = examples[:, 1:2]
pair_labels = examples[:, 2:3].reshape(-1)

data = MicrobeDiseaseData([first_term_ids, second_term_ids], pair_labels)

In [9]:
from keras import backend as K
from src.config import MICROBE_SIMILARITY_FILE, DISEASE_SIMILARITY_FILE, PROCESSED_DATA_DIR, ENTITY_VOCAB_TEMPLATE
import pandas as pd
from src.utils import pickle_load
import tensorflow as tf

# Load the two similarity tables; the first CSV column becomes the DataFrame index.
microbe_similarity_df = pd.read_csv(MICROBE_SIMILARITY_FILE, index_col=0)
disease_similarity_df = pd.read_csv(DISEASE_SIMILARITY_FILE, index_col=0)

# Number of entries in the pickled entity vocabulary of the selected dataset;
# used below as the row count of the similarity lookup matrices.
entity_vocab_size = len(
    pickle_load(
        format_filename(PROCESSED_DATA_DIR, ENTITY_VOCAB_TEMPLATE, dataset=dataset)
    )
)

def _build_similarity_table(similarity_df, n_rows):
    """Scatter the rows of a similarity table into a zero-filled lookup matrix.

    Args:
        similarity_df: DataFrame whose index labels are integer row ids of the
            result and whose columns hold the similarity values.
        n_rows: Number of rows of the returned matrix (must exceed every
            index label in ``similarity_df``).

    Returns:
        A float64 array of shape ``(n_rows, similarity_df.shape[1])`` in which
        row ``i`` equals the DataFrame row labelled ``i``; rows whose id does
        not occur in the index stay zero.

    Raises:
        IndexError: If an index label falls outside ``[-n_rows, n_rows)``.
    """
    table = np.zeros((n_rows, similarity_df.shape[1]), dtype="float64")
    # Values are taken strictly by position, which avoids the deprecated
    # `row[j]` integer-key fallback on a Series with non-integer labels.
    # One indexed assignment replaces the element-by-element Python loops.
    row_ids = similarity_df.index.to_numpy(dtype="int64")
    table[row_ids] = similarity_df.to_numpy(dtype="float64")
    return table


microbe_similarity_matrix = _build_similarity_table(microbe_similarity_df, entity_vocab_size)
disease_similarity_matrix = _build_similarity_table(disease_similarity_df, entity_vocab_size)

# Wrap both lookup matrices as non-trainable float32 TensorFlow variables.
# NOTE(review): this rebinds the two names from NumPy arrays to tf.Variable
# objects; the gather helpers defined in the next cell read these globals.
microbe_similarity_matrix = tf.Variable(microbe_similarity_matrix,
                                        name='pre_term_microbe_embedding',
                                        dtype='float32',
                                        trainable=False)
disease_similarity_matrix = tf.Variable(disease_similarity_matrix,
                                        name='pre_term_disease_embedding',
                                        dtype='float32',
                                        trainable=False)

Logging Info - Loaded: /content/drive/MyDrive/Project/Bioinformatics/Microbe Disease Association/Previous Work : KGNMDA/SKGNMDA/data_repository/processed/mdkg_hmdad_entity_vocab.pkl


In [10]:
def get_first_term_embedding(x):
    """Gather rows of ``microbe_similarity_matrix`` for the ids in ``x``.

    ``x`` is cast to int64 before being used as the gather index.
    """
    row_ids = K.cast(x, dtype='int64')
    return K.gather(microbe_similarity_matrix, row_ids)


def get_second_term_embedding(x):
    """Gather rows of ``disease_similarity_matrix`` for the ids in ``x``.

    ``x`` is cast to int64 before being used as the gather index.
    """
    row_ids = K.cast(x, dtype='int64')
    return K.gather(disease_similarity_matrix, row_ids)

# Configure Model

In [11]:
from src.config import KGCNModelConfig

# Model hyper-parameters; get_summary() in the next cell echoes exactly these fields.
kgcn_config = KGCNModelConfig()

kgcn_config.model_name = 'Previous 1'
kgcn_config.embed_dim = 32
kgcn_config.neighbor_sample_size = 8
kgcn_config.n_depth = 2
kgcn_config.l2_weight = 0.01
kgcn_config.aggregator_type = 'sum'

In [12]:
# Display the model configuration as a dict.
kgcn_config.get_summary()

{'model_name': 'Previous 1',
 'embed_dim': 32,
 'neighbor_sample_size': 8,
 'n_depth': 2,
 'l2_weight': 0.01,
 'aggregator_type': 'sum'}

# Configure Data

In [13]:
from src.config import DataConfig, PROCESSED_DATA_DIR, ENTITY_VOCAB_TEMPLATE, \
    RELATION_VOCAB_TEMPLATE, ADJ_ENTITY_TEMPLATE, ADJ_RELATION_TEMPLATE
from src.utils import pickle_load, format_filename
import numpy as np

def _processed_path(template):
    """Return the processed-data path built from `template` for the active dataset."""
    return format_filename(PROCESSED_DATA_DIR, template, dataset=dataset)


data_config = DataConfig()

# Vocabulary sizes are the lengths of the pickled vocabulary objects.
data_config.entity_vocab_size = len(pickle_load(_processed_path(ENTITY_VOCAB_TEMPLATE)))
data_config.relation_vocab_size = len(pickle_load(_processed_path(RELATION_VOCAB_TEMPLATE)))

# Adjacency matrices (entity and relation) are loaded from NumPy files.
data_config.adj_entity = np.load(_processed_path(ADJ_ENTITY_TEMPLATE))
data_config.adj_relation = np.load(_processed_path(ADJ_RELATION_TEMPLATE))


Logging Info - Loaded: /content/drive/MyDrive/Project/Bioinformatics/Microbe Disease Association/Previous Work : KGNMDA/SKGNMDA/data_repository/processed/mdkg_hmdad_entity_vocab.pkl
Logging Info - Loaded: /content/drive/MyDrive/Project/Bioinformatics/Microbe Disease Association/Previous Work : KGNMDA/SKGNMDA/data_repository/processed/mdkg_hmdad_relation_vocab.pkl


In [14]:
# Display the data configuration summary.
data_config.get_summary()

{'entity_vocab_size': 66911, 'relation_vocab_size': 39}

# Build Model

In [15]:
from src.models.graph_models import PairKGCN

# Instantiate the PairKGCN model from the model and data configurations set up above.
model = PairKGCN(kgcn_config=kgcn_config,
                 data_config=data_config)

KerasTensor(type_spec=TensorSpec(shape=(None, 32), dtype=tf.float32, name=None), name='lambda_1/Squeeze:0', description="created by layer 'lambda_1'")


In [16]:
# Print the layer-by-layer summary of the model.
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 second_input (InputLayer)   [(None, 1)]                  0         []                            
                                                                                                  
 receptive_filed_for_second  [(None, 1),                  0         ['second_input[0][0]']        
 _ent (Lambda)                (None, 8),                                                          
                              (None, 64)]                                                         
                                                                                                  
 receptive_filed_for_second  [(None, 8),                  0         ['second_input[0][0]']        
 _rel (Lambda)                (None, 64)]                                                     

# Configure Optimizer

In [17]:
from base.config import OptimizerConfig
from src.config import MODEL_SAVED_DIR

In [18]:
# Training settings; this object is passed to trainer.train() and cross_validation() below.
optimizer_config = OptimizerConfig()
optimizer_config.optimizer = 'adam'
optimizer_config.lr = 1e-3
optimizer_config.batch_size = 32
optimizer_config.n_epoch = 50
optimizer_config.checkpoint_dir = MODEL_SAVED_DIR
# NOTE(review): empty list, presumably meaning no extra callbacks — confirm in OptimizerConfig.
optimizer_config.callbacks_to_add = []

# Train Model

In [19]:
from src.optimization.optimization import KGCNTrainer

In [20]:
# Train the model on the full `data` object and keep the returned result for inspection.
# NOTE(review): the meaning of the trailing `[]` positional argument is not visible here
# — check KGCNTrainer.train's signature.
trainer = KGCNTrainer()
result = trainer.train(model, data, optimizer_config, [])

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
Logging Info - Training time: 00:00:24


In [21]:
# Display the metrics stored in the training result.
result.get_result()

{'AUC': 0.8934901116561922,
 'ACC': 0.8084632516703786,
 'F1 Score': 0.8036529680365297,
 'AUPR': 0.9055901195546971}

# Cross Validation

In [22]:
from src.data import MicrobeDiseaseTrainTestSplit

# Splitter over the full example array, created with the gaussian-similarity option enabled.
# NOTE(review): what `with_gaussian_similarity=True` changes is defined in src.data — confirm.
train_test_spliter = MicrobeDiseaseTrainTestSplit(examples=examples,
                                                  with_gaussian_similarity=True)

In [23]:
from src.optimization.optimization import KGCNTrainer, KGCNTester
from src.models.graph_models import PairKGCNFactory

# Fresh trainer/tester plus a model factory, all handed to cross_validation() below.
trainer = KGCNTrainer()
tester = KGCNTester()
# NOTE(review): 291 and 39 are hard-coded term sizes; presumably they must match the
# dataset in use — confirm, and derive them from the data if possible.
factory = PairKGCNFactory(kgcn_config,
                          data_config,
                          first_term_size=291,
                          second_term_size=39)

In [24]:
from base.optimization import cross_validation

# Run 5-fold cross-validation over all examples with the splitter, factory,
# trainer and tester defined above; the call's return value is the cell output.
# NOTE(review): no random seed is set in this notebook, so the fold assignment is
# presumably not reproducible between runs — confirm in the splitter.
cross_validation(k=5,
                 data_size=len(examples),
                 train_test_spliter=train_test_spliter,
                 model_factory=factory,
                 trainer=trainer,
                 tester=tester,
                 optimization_config=optimizer_config)


Logging Info - Fold 1 >>>>>>>>>>>>>>

test_indices: [2, 3, 11, 22, 27, 30, 46, 56, 58, 61, 66, 68, 70, 73, 85, 94, 101, 106, 109, 110, 113, 114, 115, 119, 120, 124, 127, 132, 140, 151, 159, 165, 166, 169, 177, 189, 192, 195, 197, 198, 199, 202, 204, 219, 220, 225, 226, 227, 228, 234, 238, 239, 241, 244, 246, 247, 248, 254, 261, 266, 273, 275, 276, 282, 285, 286, 291, 296, 300, 301, 303, 314, 320, 321, 322, 325, 331, 335, 344, 348, 349, 366, 368, 381, 384, 387, 396, 397, 398, 399, 412, 415, 416, 419, 424, 425, 429, 439, 442, 446, 450, 453, 455, 456, 459, 460, 467, 475, 476, 477, 484, 486, 499, 503, 505, 506, 515, 523, 525, 530, 532, 540, 544, 547, 550, 558, 562, 564, 567, 573, 579, 584, 587, 598, 601, 602, 604, 605, 635, 654, 660, 663, 666, 673, 676, 685, 686, 689, 691, 711, 712, 713, 724, 733, 748, 759, 763, 770, 784, 786, 790, 801, 807, 815, 823, 829, 830, 840, 844, 851, 855, 869, 878, 880, 881, 885, 891, 894, 897]
train_indices: [0, 1, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 1

<base.evaluation.Result at 0x7adcc420e470>

In [25]:
# Scratch check: build a small set and materialise it as a list.
a = set((1, 2, 3))
list(a)

[1, 2, 3]

In [26]:
# Hard-coded test fold: 182 row indices into the 898-row example array.
# NOTE(review): presumably pasted from a cross-validation log for inspection — confirm.
test_indices = {9, 11, 25, 26, 27, 29, 40, 45, 50, 57, 59, 66, 82, 86, 88, 96, 105, 106, 107, 114, 116, 121, 122, 134,
                137, 146, 148, 149, 157, 160, 161, 162, 164, 166, 175, 177, 185, 186, 187, 188, 190, 197, 198, 203, 218,
                222, 237, 239, 240, 245, 246, 250, 261, 262, 269, 270, 272, 281, 284, 292, 298, 301, 302, 303, 305, 306,
                307, 309, 311, 312, 314, 318, 320, 331, 336, 343, 349, 354, 359, 361, 368, 372, 375, 384, 386, 387, 388,
                390, 392, 407, 409, 410, 412, 414, 417, 418, 420, 434, 441, 443, 446, 448, 454, 455, 456, 465, 468, 471,
                474, 486, 490, 491, 526, 532, 544, 550, 555, 557, 558, 563, 564, 574, 578, 581, 583, 586, 588, 589, 590,
                600, 602, 614, 616, 620, 624, 628, 630, 632, 638, 640, 651, 660, 663, 664, 671, 672, 680, 681, 690, 695,
                700, 705, 708, 718, 721, 723, 725, 727, 731, 742, 748, 751, 757, 768, 774, 786, 788, 790, 791, 796, 797,
                809, 827, 834, 835, 838, 844, 849, 864, 878, 885, 894}

# The training fold is the complement of the test fold over all example rows
# (0 .. 897). Deriving it replaces a 716-element pasted literal with the same
# contents and guarantees the two folds stay disjoint and exhaustive.
N_EXAMPLES = 898  # number of rows in `examples`
train_indices = set(range(N_EXAMPLES)) - test_indices


In [27]:
# Size of the training fold.
len(train_indices)

716