In [1]:
# Attach Google Drive to the Colab VM; the project directory used below lives under this mount.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install torch_geometric



In [3]:
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html

Looking in links: https://data.pyg.org/whl/torch-2.1.0+cu121.html


In [4]:
cd /content/drive/MyDrive/Academic/Topics/AI/Machine\ Learning\ Dr.\ Montazeri/Project/ml_mda

/content/drive/MyDrive/Academic/Topics/AI/Machine Learning Dr. Montazeri/Project/ml_mda


# Requirements

In [5]:
import logging
import sys

# Send INFO-and-above records to stdout so they show up in the cell output.
# force=True replaces any handlers already installed on the root logger.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
    force=True,
)

In [6]:
from src.models import MDMDAClassifier, MDMDAClassifierFactory
from src.optimization import MDMDAClassifierTester, MDMDAClassifierTrainer
from src.config import MatrixDecomposerConfig, MDMDAClassifierOptimizerConfig
from src.data import MicrobeDiseaseAssociationData, MicrobeDiseaseAssociationTrainTestSpliter
from src.utils import train_test_sampler
from src.features import get_associations, get_entities
from base import cross_validation

2024-02-10 21:03:51,504 [INFO] NumExpr defaulting to 2 threads.


# Classification

## Data

In [7]:
# Share of association rows passed to the sampler for the training split.
TRAIN_FRACTION = 0.7

associations = get_associations()

train_indices, test_indices = train_test_sampler(associations.shape[0], TRAIN_FRACTION)

# Leakage guard. In the recorded hold-out run most test predictions are exactly 0.0 / 1.0
# and equal to their labels (AUC ~0.97 versus ~0.85-0.89 under cross-validation), which is
# the pattern expected when test rows were also part of the training split.
# NOTE(review): confirm whether train_test_sampler guarantees disjoint index sets.
train_index_set = set(train_indices)
leaked_indices = [idx for idx in test_indices if idx in train_index_set]
if leaked_indices:
    logging.warning(
        "%d of %d test indices also appear in the training split; dropping them from the test split",
        len(leaked_indices),
        len(test_indices),
    )
    test_indices = [idx for idx in test_indices if idx not in train_index_set]

# Wrapper over every association; the cross-validation section reads data.associations.
data = MicrobeDiseaseAssociationData(associations)

# Positional (iloc) subsets for the single train/test experiment.
train_data = MicrobeDiseaseAssociationData(associations.iloc[train_indices])
test_data = MicrobeDiseaseAssociationData(associations.iloc[test_indices])

## Classifier

In [8]:
# Load the entity table once; both ID lists are filtered from the same frame.
entities = get_entities()

# IDs of the rows whose 'type' column is 'Microbe' / 'Disease', as plain Python lists.
microbe_ids = entities.loc[entities['type'] == 'Microbe', 'id'].tolist()
disease_ids = entities.loc[entities['type'] == 'Disease', 'id'].tolist()

In [9]:
# Settings for the NMF-based classifier, applied to a default config in a fixed order.
# NOTE(review): n_components and random_state are presumably forwarded to the
# decomposer - confirm in MatrixDecomposerConfig.
classifier_config = MatrixDecomposerConfig()
for option, value in (
    ("model_name", "NMF MDA Classifier"),
    ("microbe_ids", microbe_ids),
    ("disease_ids", disease_ids),
    ("n_components", 10),
    ("random_state", 1),
    ("decomposer", 'NMF'),
):
    setattr(classifier_config, option, value)

In [10]:
# Build the classifier from classifier_config. The recorded log for this cell shows
# only initialisation messages (classifier and feature extractors, decomposer NMF).
mda_classifier = MDMDAClassifier(classifier_config)

2024-02-10 21:03:52,189 [INFO] Initializing MCMDAClassifier
2024-02-10 21:03:52,191 [INFO] Initializing MatrixFeatureExtractor
2024-02-10 21:03:52,193 [INFO] Initializing MFFeatureExtractor with model : None and decomposer : NMF


## Optimizer

In [11]:
# Optimizer settings shared by the hold-out run and the cross-validation below.
# conv_threshold: in the recorded logs training usually ends at the first iteration whose
#   mse drops below 5 - presumably the convergence bound; confirm in the trainer.
# threshold: presumably the cut-off used to binarise predicted scores - TODO confirm.
optimizer_config = MDMDAClassifierOptimizerConfig()
for option, value in (
    ("conv_threshold", 5),
    ("exp_name", "Optimizer for NMF MDA Classifier"),
    ("threshold", 0.5),
):
    setattr(optimizer_config, option, value)

## Train Test Approach

### Train

In [12]:
# Fit mda_classifier on the training split with the optimizer settings.
# In the recorded run this took about ten minutes (11 iterations, roughly one minute each).
# The test cell reuses mda_classifier afterwards, so training presumably updates it in place.
train_result = MDMDAClassifierTrainer().train(mda_classifier, train_data, optimizer_config)

2024-02-10 21:03:52,214 [INFO] Call Training with Optimizer for NMF MDA Classifier
2024-02-10 21:03:52,226 [INFO] Calling build with associations :      disease  microbe  increased
728    44112    37477          0
262    33293    13641          1
127    12403    20754          1
804    63129    20627          0
661     1667    46605          0
..       ...      ...        ...
329      654    15670          1
35     43621    27509          1
846    10506     9788          0
228    12403    28933          1
43     50863    65017          1

[628 rows x 3 columns]
2024-02-10 21:03:52,312 [INFO] interaction matrix with shape (5179, 5645) has built
2024-02-10 21:03:52,607 [INFO] mask matrix with shape (5179, 5645) has built. This matrix shows not non elements.
2024-02-10 21:03:54,355 [INFO] interaction has been imputed to delete nans




2024-02-10 21:04:54,169 [INFO] interation 1 mse : 30.506063740307727




2024-02-10 21:05:51,652 [INFO] interation 2 mse : 21.43196538647058




2024-02-10 21:06:48,181 [INFO] interation 3 mse : 16.25441876864512




2024-02-10 21:07:43,760 [INFO] interation 4 mse : 12.899859076083429




2024-02-10 21:08:38,965 [INFO] interation 5 mse : 10.590660148943828




2024-02-10 21:09:34,010 [INFO] interation 6 mse : 8.918546151173015




2024-02-10 21:10:27,499 [INFO] interation 7 mse : 7.644743875043131




2024-02-10 21:11:22,094 [INFO] interation 8 mse : 6.656693754719345




2024-02-10 21:12:16,919 [INFO] interation 9 mse : 5.86358891402908




2024-02-10 21:13:12,127 [INFO] interation 10 mse : 5.207539417544243




2024-02-10 21:14:04,502 [INFO] interation 11 mse : 4.639842029828027
2024-02-10 21:14:04,507 [INFO] training finished


### Test

In [13]:
# Score the trained classifier on the held-out split using the same optimizer settings.
test_result = MDMDAClassifierTester().test(
    model=mda_classifier,
    data=test_data,
    config=optimizer_config,
)

2024-02-10 21:14:04,528 [INFO] y_test has been built : [0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1] ...
2024-02-10 21:14:04,725 [INFO] y_predict has been built : [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.7816286454109487, 0.0, 0.9852962209881905, 0.0, 0.0, 0.0, 1.0, 1.0, 0.2055831622000405, 1.0, 0.9888979533412917, 0.0, 0.0, 1.0] ...
2024-02-10 21:14:04,749 [INFO] Test Result : {'AUC': 0.972936512306591, 'ACC': 0.9444444444444444, 'F1 Score': 0.9442233270441118, 'AUPR': 0, 'Loss': 0}


In [14]:
# Display the hold-out metrics as a dict. AUPR and Loss are reported as 0 in the recorded
# output - NOTE(review): presumably not computed by the tester; confirm before citing them.
test_result.get_result()

{'AUC': 0.972936512306591,
 'ACC': 0.9444444444444444,
 'F1 Score': 0.9442233270441118,
 'AUPR': 0,
 'Loss': 0}

## Cross Validation

In [15]:
# 5-fold cross-validation over all associations with the same classifier and
# optimizer settings as the hold-out experiment.
trainer = MDMDAClassifierTrainer()
tester = MDMDAClassifierTester()
factory = MDMDAClassifierFactory(classifier_config)
spliter = MicrobeDiseaseAssociationTrainTestSpliter(data.associations)

# Keep the returned Result in a named variable: the recorded run needed roughly
# 48 minutes for the five folds, so the outcome should not live only in the cell output.
cv_result = cross_validation(
    k=5,
    data_size=data.associations.shape[0],
    train_test_spliter=spliter,
    model_factory=factory,
    trainer=trainer,
    tester=tester,
    config=optimizer_config,
)

# NOTE(review): if this Result exposes get_result() like test_result does, display
# that instead of the bare object to get the metrics rather than an object repr.
cv_result

2024-02-10 21:14:04,774 [INFO] Initializing MCMDAClassifierFactory
2024-02-10 21:14:04,776 [INFO] Initializing MicrobeDiseaseAssociationTrainTestSpliter
2024-02-10 21:14:04,779 [INFO] Start 5-fold Cross Validation with config : Optimizer for NMF MDA Classifier
2024-02-10 21:14:04,782 [INFO] ---- Fold 1 ----
2024-02-10 21:14:04,785 [INFO] Initializing MCMDAClassifier
2024-02-10 21:14:04,787 [INFO] Initializing MatrixFeatureExtractor
2024-02-10 21:14:04,789 [INFO] Initializing MFFeatureExtractor with model : None and decomposer : NMF
2024-02-10 21:14:04,790 [INFO] Call Training with Optimizer for NMF MDA Classifier
2024-02-10 21:14:04,796 [INFO] Calling build with associations :      disease  microbe  increased
0      50863    33211          1
1      43621    40832          1
3      13213    53186          1
4      33293    14909          1
5      33293    35937          1
..       ...      ...        ...
892    22068    20153          0
893    64642    53920          0
894    25026    6



2024-02-10 21:15:00,608 [INFO] interation 1 mse : 32.45964921551163




2024-02-10 21:15:55,361 [INFO] interation 2 mse : 26.089845769794994




2024-02-10 21:16:50,289 [INFO] interation 3 mse : 19.09443632470318




2024-02-10 21:17:45,574 [INFO] interation 4 mse : 15.424695153343821




2024-02-10 21:18:38,783 [INFO] interation 5 mse : 13.57306334387447




2024-02-10 21:19:33,006 [INFO] interation 6 mse : 11.249892133411816




2024-02-10 21:20:28,031 [INFO] interation 7 mse : 9.968871269021957




2024-02-10 21:21:23,098 [INFO] interation 8 mse : 8.415667964487378




2024-02-10 21:22:18,274 [INFO] interation 9 mse : 7.262489299706241




2024-02-10 21:23:11,595 [INFO] interation 10 mse : 6.376818019113093




2024-02-10 21:24:07,042 [INFO] interation 11 mse : 5.664226064042676




2024-02-10 21:25:02,488 [INFO] interation 12 mse : 5.067347766077726




2024-02-10 21:25:58,014 [INFO] interation 13 mse : 4.549787060023141
2024-02-10 21:25:58,015 [INFO] training finished
2024-02-10 21:25:58,021 [INFO] y_test has been built : [1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0] ...
2024-02-10 21:25:58,083 [INFO] y_predict has been built : [1.0557261344200606, 0.1689410818530739, 0.8538583041465231, 0.9918047278894817, 0.4201508427313654, 0.8477535961745123, 0.2700216377536503, 0.6846135397190409, 0.8718949546246714, 0.35328288410589215, 0.5724263921105697, 1.0834826309461498, 0.09510477224452228, 0.1104713620861786, 0.9923820589636334, 0.3005046583531824, 0.5482508808229457, 0.0, 0.6669160710539743, 0.047589493694501946] ...
2024-02-10 21:25:58,104 [INFO] Test Result : {'AUC': 0.8611944027986007, 'ACC': 0.8044692737430168, 'F1 Score': 0.8044448606298966, 'AUPR': 0, 'Loss': 0}
2024-02-10 21:25:58,105 [INFO] Result of fold 1 : {'AUC': 0.8611944027986007, 'ACC': 0.8044692737430168, 'F1 Score': 0.8044448606298966, 'AUPR': 0, 'Loss': 



2024-02-10 21:26:55,221 [INFO] interation 1 mse : 29.32492623852403




2024-02-10 21:27:49,968 [INFO] interation 2 mse : 19.456056749226736




2024-02-10 21:28:44,321 [INFO] interation 3 mse : 14.933450083848316




2024-02-10 21:29:40,147 [INFO] interation 4 mse : 12.773994367523755




2024-02-10 21:30:35,693 [INFO] interation 5 mse : 10.964560860680297




2024-02-10 21:31:31,385 [INFO] interation 6 mse : 9.160867513289174




2024-02-10 21:32:26,624 [INFO] interation 7 mse : 8.038162542269134




2024-02-10 21:33:19,863 [INFO] interation 8 mse : 7.0391595733537615




2024-02-10 21:34:14,958 [INFO] interation 9 mse : 6.309258508586742




2024-02-10 21:35:10,494 [INFO] interation 10 mse : 5.558137259471841




2024-02-10 21:36:05,345 [INFO] interation 11 mse : 4.986086753681932
2024-02-10 21:36:05,347 [INFO] training finished
2024-02-10 21:36:05,353 [INFO] y_test has been built : [0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1] ...
2024-02-10 21:36:05,417 [INFO] y_predict has been built : [0.08238535354506005, 0.33179083074975857, 0.9564468475228407, 0.13343825594096204, 0.37629984719596765, 0.04447212373251924, 0.04200662412309571, 0.217541747316019, 0.7272262793912792, 0.9697297156501801, 0.08260715827451048, 1.1960116004654124, 0.0, 0.0, 0.8357465707585107, 0.7418112166775359, 0.7891410483040356, 0.5272625992030466, 1.0367500297744932, 1.1103757812557538] ...
2024-02-10 21:36:05,429 [INFO] Test Result : {'AUC': 0.8522408963585434, 'ACC': 0.7988826815642458, 'F1 Score': 0.794881588999236, 'AUPR': 0, 'Loss': 0}
2024-02-10 21:36:05,431 [INFO] Result of fold 2 : {'AUC': 0.8522408963585434, 'ACC': 0.7988826815642458, 'F1 Score': 0.794881588999236, 'AUPR': 0, 'Loss': 0}
2024-02-10 2



2024-02-10 21:37:02,545 [INFO] interation 1 mse : 35.50658571289591




2024-02-10 21:37:56,541 [INFO] interation 2 mse : 24.960746916356012




2024-02-10 21:38:50,115 [INFO] interation 3 mse : 18.58868602117466




2024-02-10 21:39:45,330 [INFO] interation 4 mse : 14.08141773376028




2024-02-10 21:40:40,700 [INFO] interation 5 mse : 11.041613015837465




2024-02-10 21:41:35,829 [INFO] interation 6 mse : 8.946576683927152




2024-02-10 21:42:29,260 [INFO] interation 7 mse : 7.440467638805057




2024-02-10 21:43:24,346 [INFO] interation 8 mse : 6.2905069222825105




2024-02-10 21:44:20,048 [INFO] interation 9 mse : 5.3795182807297754




2024-02-10 21:45:15,554 [INFO] interation 10 mse : 4.648689160842466
2024-02-10 21:45:15,555 [INFO] training finished
2024-02-10 21:45:15,560 [INFO] y_test has been built : [0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0] ...
2024-02-10 21:45:15,638 [INFO] y_predict has been built : [0.1461561452947493, 0.2632823300729358, 0.8510376186693911, 0.7635214094853038, 0.0, 1.0103868247596672, 0.1494566787852188, 0.3204675765706818, 0.48996658409916566, 0.3699045268613879, 0.16836289470038773, 0.9781707708545734, 0.31727319881485444, 0.2659655098294812, 0.6452735711858482, 0.4314459239787531, 0.5624743640762927, 0.0628295508774719, 0.3944256835148102, 0.08045076430090146] ...
2024-02-10 21:45:15,650 [INFO] Test Result : {'AUC': 0.8807116104868915, 'ACC': 0.7932960893854749, 'F1 Score': 0.7930635838150288, 'AUPR': 0, 'Loss': 0}
2024-02-10 21:45:15,652 [INFO] Result of fold 3 : {'AUC': 0.8807116104868915, 'ACC': 0.7932960893854749, 'F1 Score': 0.7930635838150288, 'AUPR': 0, 'Loss': 



2024-02-10 21:46:12,534 [INFO] interation 1 mse : 32.37949083601859




2024-02-10 21:47:06,729 [INFO] interation 2 mse : 21.37013213966815




2024-02-10 21:48:00,077 [INFO] interation 3 mse : 19.542159404937646




2024-02-10 21:48:55,100 [INFO] interation 4 mse : 14.337232646430403




2024-02-10 21:49:49,817 [INFO] interation 5 mse : 12.333289297080677




2024-02-10 21:50:44,896 [INFO] interation 6 mse : 10.39716853303389




2024-02-10 21:51:37,356 [INFO] interation 7 mse : 8.826149270270356




2024-02-10 21:52:32,473 [INFO] interation 8 mse : 7.705655079415957




2024-02-10 21:53:27,555 [INFO] interation 9 mse : 6.782209158170792




2024-02-10 21:54:23,193 [INFO] interation 10 mse : 6.023143744744676




2024-02-10 21:55:18,919 [INFO] interation 11 mse : 5.377780752513439




2024-02-10 21:56:12,405 [INFO] interation 12 mse : 4.814411752568075
2024-02-10 21:56:12,411 [INFO] training finished
2024-02-10 21:56:12,416 [INFO] y_test has been built : [0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1] ...
2024-02-10 21:56:12,533 [INFO] y_predict has been built : [0.0, 0.8733441109420745, 0.6256775281565695, 1.1534484408790144, 0.6912969023546787, 1.0194275517945381, 0.0, 0.0, 0.8833673883546277, 0.4758483178850738, 0.44933496271467915, 0.38315328942093596, 0.643159815209396, 0.2200525550949404, 1.0217279173158782, 0.6617187930547154, 0.1454066710109258, 1.02032230635017, 1.0392148107671777, 0.7114287604140008] ...
2024-02-10 21:56:12,552 [INFO] Test Result : {'AUC': 0.8873522755846115, 'ACC': 0.7988826815642458, 'F1 Score': 0.7988261738261737, 'AUPR': 0, 'Loss': 0}
2024-02-10 21:56:12,557 [INFO] Result of fold 4 : {'AUC': 0.8873522755846115, 'ACC': 0.7988826815642458, 'F1 Score': 0.7988261738261737, 'AUPR': 0, 'Loss': 0}
2024-02-10 21:56:12,562 [INFO] -



2024-02-10 21:57:08,381 [INFO] interation 1 mse : 32.203220403357655




2024-02-10 21:58:03,507 [INFO] interation 2 mse : 21.633928253658134




2024-02-10 21:58:58,912 [INFO] interation 3 mse : 19.85552028138665




2024-02-10 21:59:54,019 [INFO] interation 4 mse : 14.681669013374854




2024-02-10 22:00:48,801 [INFO] interation 5 mse : 11.881331386148332




2024-02-10 22:01:41,953 [INFO] interation 6 mse : 10.144086648136136




2024-02-10 22:02:36,535 [INFO] interation 7 mse : 14.595135068663268
2024-02-10 22:02:36,537 [INFO] training finished
2024-02-10 22:02:36,545 [INFO] y_test has been built : [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] ...
2024-02-10 22:02:36,613 [INFO] y_predict has been built : [0.9305616405263825, 0.9070033734148907, 0.8799434550567753, 0.8884326722791112, 1.0715298963550886, 0.624905375506367, 0.8287310788988537, 0.9338074186785231, 0.8677082889586158, 1.1286292945228849, 0.3130211813355816, 0.7754029936175756, 0.5152637359136676, 0.5368573607835606, 1.0514093920472576, 1.1147287330769915, 0.8119815566216587, 0.8796177914640178, 0.15849595271474295, 0.8756697983646963] ...
2024-02-10 22:02:36,626 [INFO] Test Result : {'AUC': 0.8916223404255319, 'ACC': 0.8461538461538461, 'F1 Score': 0.8461352657004831, 'AUPR': 0, 'Loss': 0}
2024-02-10 22:02:36,628 [INFO] Result of fold 5 : {'AUC': 0.8916223404255319, 'ACC': 0.8461538461538461, 'F1 Score': 0.8461352657004831, 'AUPR': 

<base.evaluation.Result at 0x7c1f30237e20>