In [1]:
from google.colab import drive
# Attach the user's Google Drive to the Colab VM so the project folder
# stored on Drive becomes reachable through the local filesystem.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install torch_geometric



In [3]:
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html

Looking in links: https://data.pyg.org/whl/torch-2.1.0+cu121.html


In [4]:
cd /content/drive/MyDrive/Academic/Topics/AI/Machine\ Learning\ Dr.\ Montazeri/Project/ml_mda

/content/drive/MyDrive/Academic/Topics/AI/Machine Learning Dr. Montazeri/Project/ml_mda


# Requirements

In [5]:
import logging
import sys

# Send log records of level INFO and above to stdout so they show up in the
# cell output. `force=True` removes any handlers already attached to the root
# logger, so this configuration takes effect even if logging was configured
# earlier in the session.
stdout_handler = logging.StreamHandler(stream=sys.stdout)
LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"

logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
    handlers=[stdout_handler],
    force=True,
)

In [6]:
from src.models import MDMDAClassifier, MDMDAClassifierFactory
from src.optimization import MDMDAClassifierTrainer, MDMDAClassifierTester
from src.config import MatrixDecomposerConfig, MDMDAClassifierOptimizerConfig
from src.data import MicrobeDiseaseAssociationData, MicrobeDiseaseAssociationTrainTestSpliter
from src.utils import train_test_sampler
from src.features import get_associations, get_entities
from base import cross_validation

2024-02-10 17:34:30,560 [INFO] NumExpr defaulting to 2 threads.


# Classification

## Data

In [7]:
# Load the full association table (the training log shows the columns
# disease / microbe / increased).
associations = get_associations()

# Draw positional row indices for the hold-out split.
# NOTE(review): 0.7 is presumably the fraction of rows assigned to the training
# set, and no seed is passed here - confirm both against `train_test_sampler`.
n_associations = len(associations)
train_indices, test_indices = train_test_sampler(n_associations, 0.7)

# Wrapper around the complete table; used by the cross-validation section.
data = MicrobeDiseaseAssociationData(associations)

# Wrappers around the two hold-out partitions (rows selected by position).
train_data, test_data = (
    MicrobeDiseaseAssociationData(associations.iloc[row_positions])
    for row_positions in (train_indices, test_indices)
)

## Classifier

In [8]:
# Fetch the entity table once instead of four times, then select the `id`
# column for each entity type with a single `.loc[rows, column]` lookup
# (avoids chained indexing).
entities = get_entities()
microbe_ids = entities.loc[entities['type'] == 'Microbe', 'id'].tolist()
disease_ids = entities.loc[entities['type'] == 'Disease', 'id'].tolist()

In [9]:
# Decomposition method used by the classifier. The model label below is derived
# from this value so the name that appears in logs/results always matches the
# decomposer that actually runs (previously the label said "NMF" while the
# decomposer was 'PCA').
# NOTE(review): if NMF was the intended method, change DECOMPOSER rather than
# the label - confirm which values the project's decomposer option accepts.
DECOMPOSER = 'PCA'

classifier_config = MatrixDecomposerConfig()
classifier_config.model_name = f"{DECOMPOSER} MDA Classifier"
classifier_config.microbe_ids = microbe_ids
classifier_config.disease_ids = disease_ids
# presumably the number of latent components kept by the decomposer - confirm
classifier_config.n_components = 5
# presumably the seed forwarded to the decomposer for reproducibility - confirm
classifier_config.random_state = 1
classifier_config.decomposer = DECOMPOSER

In [10]:
# Build the classifier from the configuration assembled in the previous cell.
# This instance is the one trained and evaluated in the "Train Test Approach"
# section below.
mda_classifier = MDMDAClassifier(classifier_config)

2024-02-10 17:34:33,438 [INFO] Initializing MCMDAClassifier
2024-02-10 17:34:33,443 [INFO] Initializing MatrixFeatureExtractor
2024-02-10 17:34:33,446 [INFO] Initializing MFFeatureExtractor with model : None and decomposer : PCA


## Optimizer

In [11]:
# Settings shared by the trainer, the tester and the cross-validation run.
optimizer_config = MDMDAClassifierOptimizerConfig()
# NOTE(review): presumably the error level at which iterative training is
# considered converged (the training log reports a per-iteration MSE) - confirm.
optimizer_config.conv_threshold = 4.5
# The experiment label is derived from the decomposer configured on the
# classifier, so the name written to the logs cannot drift from the method that
# actually runs (the label used to say "NMF" while the decomposer was 'PCA').
optimizer_config.exp_name = f"Optimizer for {classifier_config.decomposer} MDA Classifier"
# NOTE(review): presumably the cut-off that turns predicted scores into class
# labels - confirm.
optimizer_config.threshold = 0.5

## Train Test Approach

### Train

In [12]:
# Fit the classifier on the training partition of the hold-out split.
holdout_trainer = MDMDAClassifierTrainer()
train_result = holdout_trainer.train(
    mda_classifier,
    train_data,
    optimizer_config,
)

2024-02-10 17:34:33,467 [INFO] Call Training with Optimizer for NMF MDA Classifier
2024-02-10 17:34:33,479 [INFO] Calling build with associations :      disease  microbe  increased
677     9724     8766          0
344    12403    58532          1
628    43372    11010          0
786    13213    61523          0
676    55164    45746          0
..       ...      ...        ...
335    43621    64598          1
474    37496    18423          0
93     50863    31268          1
462    66623    34574          0
274    48777     6432          1

[628 rows x 3 columns]
2024-02-10 17:34:33,567 [INFO] interaction matrix with shape (5179, 5645) has built
2024-02-10 17:34:33,866 [INFO] mask matrix with shape (5179, 5645) has built. This matrix shows not non elements.
2024-02-10 17:34:35,679 [INFO] interaction has been imputed to delete nans
2024-02-10 17:34:40,836 [INFO] interation 1 mse : 275.84308113140946
2024-02-10 17:34:46,038 [INFO] interation 2 mse : 75.14309280883317
2024-02-10 17:34:49,03

### Test

In [13]:
# Evaluate the trained classifier on the held-out test partition.
holdout_tester = MDMDAClassifierTester()
test_result = holdout_tester.test(
    model=mda_classifier,
    data=test_data,
    config=optimizer_config,
)

2024-02-10 17:36:49,319 [INFO] y_test has been built : [1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1] ...
2024-02-10 17:36:49,421 [INFO] y_predict has been built : [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, -0.10757890337453661, -0.01902042728512248, 0.0, 0.10660305886768465, 1.0, 1.0, -0.00468862661361349, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, -0.006972876342092436] ...
2024-02-10 17:36:49,437 [INFO] Test Result : {'AUC': 0.8018077601410936, 'ACC': 0.8444444444444444, 'F1 Score': 0.8434047724259832, 'AUPR': 0, 'Loss': 0}


In [14]:
# Display the metrics gathered during the hold-out evaluation as the cell output.
test_result.get_result()

{'AUC': 0.8018077601410936,
 'ACC': 0.8444444444444444,
 'F1 Score': 0.8434047724259832,
 'AUPR': 0,
 'Loss': 0}

## Cross Validation

In [15]:
# Collaborators for cross validation: a factory is passed (rather than a single
# model instance) together with a splitter over the complete association table,
# plus the trainer and tester.
trainer = MDMDAClassifierTrainer()
tester = MDMDAClassifierTester()
factory = MDMDAClassifierFactory(classifier_config)
all_associations = data.associations
spliter = MicrobeDiseaseAssociationTrainTestSpliter(all_associations)

# 5-fold cross validation over all associations, using the same optimizer
# settings as the hold-out experiment above.
cv_result = cross_validation(
    k=5,
    data_size=all_associations.shape[0],
    train_test_spliter=spliter,
    model_factory=factory,
    trainer=trainer,
    tester=tester,
    config=optimizer_config,
)
cv_result

2024-02-10 17:36:49,459 [INFO] Initializing MCMDAClassifierFactory
2024-02-10 17:36:49,463 [INFO] Initializing MicrobeDiseaseAssociationTrainTestSpliter
2024-02-10 17:36:49,465 [INFO] Start 5-fold Cross Validation with config : Optimizer for NMF MDA Classifier
2024-02-10 17:36:49,469 [INFO] ---- Fold 1 ----
2024-02-10 17:36:49,472 [INFO] Initializing MCMDAClassifier
2024-02-10 17:36:49,474 [INFO] Initializing MatrixFeatureExtractor
2024-02-10 17:36:49,475 [INFO] Initializing MFFeatureExtractor with model : None and decomposer : PCA
2024-02-10 17:36:49,476 [INFO] Call Training with Optimizer for NMF MDA Classifier
2024-02-10 17:36:49,483 [INFO] Calling build with associations :      disease  microbe  increased
2      33293    47880          1
3      13213    53186          1
4      33293    14909          1
6      12403    26565          1
7      43621    57454          1
..       ...      ...        ...
892    22068    20153          0
893    64642    53920          0
894    25026    6

<base.evaluation.Result at 0x7a49ddb70b20>