In [1]:
from google.colab import drive

# Make Google Drive available at /content/drive; the project directory that
# this notebook switches into further down lives under that mount point.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.4.0-py3-none-any.whl (1.0 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/1.0 MB[0m [31m2.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.4.0


In [3]:
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html

Looking in links: https://data.pyg.org/whl/torch-2.1.0+cu121.html
Collecting pyg_lib
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu121/pyg_lib-0.4.0%2Bpt21cu121-cp310-cp310-linux_x86_64.whl (2.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m47.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_scatter
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu121/torch_scatter-2.1.2%2Bpt21cu121-cp310-cp310-linux_x86_64.whl (10.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m61.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_sparse
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu121/torch_sparse-0.6.18%2Bpt21cu121-cp310-cp310-linux_x86_64.whl (5.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m81.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_cluster
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu121/torch_cluster-1.6.3%2Bp

In [4]:
import os

# Work from the project root so the project-local packages imported below
# (`src`, `base`) can be found. Using os.chdir (what `%cd` does internally)
# avoids depending on IPython automagic for a bare `cd`, and a quoted string
# needs no backslash-escaping for the spaces in the path.
PROJECT_DIR = '/content/drive/MyDrive/Academic/Topics/AI/Machine Learning Dr. Montazeri/Project/ml_mda'
os.chdir(PROJECT_DIR)
os.getcwd()

/content/drive/MyDrive/Academic/Topics/AI/Machine Learning Dr. Montazeri/Project/ml_mda


# Requirements

In [5]:
import logging
import sys

# Route INFO-and-above records from every logger to stdout so they show up
# inline in cell output. force=True discards whatever root handlers already
# exist (e.g. from an earlier run of this cell) before installing ours.
_stdout_handler = logging.StreamHandler(stream=sys.stdout)
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
    handlers=[_stdout_handler],
    force=True,
)

In [6]:
# Notebook-level logger; it inherits the stdout handler configured on the root logger.
logger = logging.getLogger(name=__name__)

In [7]:
import random

import numpy as np
import torch

# The pipeline is stochastic (train/test sampling, MLP initialisation, dropout,
# batch shuffling), so seed every global RNG it may draw from; otherwise
# metrics change on each Restart & Run All. 1 matches md_config.random_state.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
from src.optimization import MatrixFeatureBasedMDAClassifierTrainer, MatrixFeatureBasedMDAClassifierTester
from src.config import MatrixDecomposerConfig, SimpleClassifierConfig
from src.models import MDFeatureBasedMDAClassifier, MDFeatureBasedMDAClassifierFactory
from src.data import MicrobeDiseaseAssociationData, MicrobeDiseaseAssociationTrainTestSpliter
from src.features import get_associations, get_entities
from src.utils import train_test_sampler
from base import cross_validation, OptimizerConfig


2024-02-10 18:29:31,749 [INFO] NumExpr defaulting to 2 threads.


# Classification

## Data

In [9]:
TRAIN_RATIO = 0.7  # fraction of associations used for the hold-out training split

associations = get_associations()
n_associations = associations.shape[0]

# NOTE(review): train_test_sampler is called without an explicit seed, so the
# split is reproducible only if the global RNGs were seeded beforehand.
train_indices, test_indices = train_test_sampler(n_associations, TRAIN_RATIO)
# Guard against a sampler change silently dropping or duplicating rows: the two
# split sizes must add up to the number of associations.
assert len(train_indices) + len(test_indices) == n_associations, \
    "train/test split sizes do not add up to the number of associations"

# Full dataset (consumed by the cross-validation section) plus the hold-out splits.
data = MicrobeDiseaseAssociationData(associations)
train_data = MicrobeDiseaseAssociationData(associations.iloc[train_indices])
test_data = MicrobeDiseaseAssociationData(associations.iloc[test_indices])

## Classifier

In [10]:
# Load the entity table once; it was previously re-fetched four times.
# `.loc[mask, 'id']` selects the column in one indexing step, avoiding chained indexing.
entities = get_entities()
microbe_ids = entities.loc[entities['type'] == 'Microbe', 'id'].tolist()
disease_ids = entities.loc[entities['type'] == 'Disease', 'id'].tolist()

In [11]:
# Feature-extractor settings: NMF decomposition with 40 components over the
# complete microbe / disease id lists gathered above.
md_config = MatrixDecomposerConfig()
md_config.model_name = "NMF MDA Classifier"
md_config.decomposer = 'NMF'
md_config.n_components = 40
md_config.random_state = 1  # presumably forwarded to the decomposer as its seed — confirm
md_config.microbe_ids = microbe_ids
md_config.disease_ids = disease_ids

In [12]:
# MLP head applied to the decomposition features. Its input width must equal
# the feature width the extractor emits; the training log shows 80 features
# per sample, i.e. two n_components-sized vectors (presumably one per microbe
# and one per disease — confirm in the feature extractor).
simple_classifier_config = SimpleClassifierConfig()
simple_classifier_config.model_name = "simple classifier"
simple_classifier_config.input_dim = 2 * md_config.n_components
simple_classifier_config.hidden_dim = 32
simple_classifier_config.num_layers = 2
simple_classifier_config.output_dim = 1  # a single logit, paired with BCEWithLogitsLoss below
simple_classifier_config.dropout = 0.1

In [13]:
# Combine the MLP head config with the NMF feature-extractor config into the
# end-to-end classifier that the hold-out train/test cells below operate on.
mda_classifier = MDFeatureBasedMDAClassifier(simple_classifier_config, md_config)

2024-02-10 18:29:37,792 [INFO] Initializing MDFeatureBasedMDAClassifier with model : simple classifier
2024-02-10 18:29:37,796 [INFO] Initializing SimpleMDAClassifier with model : simple classifier
2024-02-10 18:29:37,798 [INFO] Initial SimpleMLP with 80 input dimension, 32 hidden dimension, 1 
            output dimension, 2 layers and with 0.1 dropout
2024-02-10 18:29:37,834 [INFO] Initializing MatrixFeatureExtractor
2024-02-10 18:29:37,836 [INFO] Initializing MFFeatureExtractor with model : None and decomposer : NMF


## Optimizer

In [14]:
# Optimisation settings, shared by the hold-out run and the cross-validation run.
classifier_optimizer_config = OptimizerConfig()
classifier_optimizer_config.exp_name = "adam optimizer"

# Optimizer and loss; the loss takes the raw logit from the single-output head.
classifier_optimizer_config.optimizer = torch.optim.Adam
classifier_optimizer_config.criterion = torch.nn.BCEWithLogitsLoss()
classifier_optimizer_config.lr = 0.01

# Training schedule.
classifier_optimizer_config.batch_size = 32
classifier_optimizer_config.n_epoch = 50

# Runtime, reporting and checkpointing (saving disabled).
classifier_optimizer_config.device = device
classifier_optimizer_config.report_size = 10  # batch to report ratio
classifier_optimizer_config.threshold = 0.5  # decision threshold for ACC/F1 — presumably on sigmoid probabilities; confirm
classifier_optimizer_config.save = False
classifier_optimizer_config.save_path = None

## Hold-out Train/Test Approach

### Train

In [15]:
# Fit the hold-out model on the training split; the trainer logs the matrix
# build/imputation step followed by periodic batch losses.
train_result = MatrixFeatureBasedMDAClassifierTrainer().train(
    model=mda_classifier,
    data=train_data,
    config=classifier_optimizer_config,
)

2024-02-10 18:29:37,857 [INFO] Call Training with adam optimizer
2024-02-10 18:29:37,869 [INFO] Calling build with associations :      disease  microbe  increased
190    59444    54894          1
293    66623    10559          1
251    43621    42610          1
93     50863    31268          1
424    43621      431          1
..       ...      ...        ...
845     7877     8218          0
769      654    20724          0
784    12403    50024          0
71     33293    39272          1
189    13213    14120          1

[628 rows x 3 columns]
2024-02-10 18:29:37,960 [INFO] interaction matrix with shape (5179, 5645) has built
2024-02-10 18:29:38,231 [INFO] mask matrix with shape (5179, 5645) has built. This matrix shows not non elements.
2024-02-10 18:29:39,742 [INFO] interaction has been imputed to delete nans




2024-02-10 18:31:03,614 [INFO] Initializing SimplePytorchData with X shape : torch.Size([628, 80]) and y shape : torch.Size([628, 1])
2024-02-10 18:31:03,616 [INFO] Running Simple Trainer with config : adam optimizer
2024-02-10 18:31:03,619 [INFO] moving data and model to cpu
2024-02-10 18:31:03,766 [INFO] loss: 0.0545    [1,    10]
2024-02-10 18:31:03,784 [INFO] loss: 0.0270    [1,    20]
2024-02-10 18:31:03,801 [INFO] loss: 0.0225    [2,    10]
2024-02-10 18:31:03,819 [INFO] loss: 0.0217    [2,    20]
2024-02-10 18:31:03,838 [INFO] loss: 0.0146    [3,    10]
2024-02-10 18:31:03,855 [INFO] loss: 0.0185    [3,    20]
2024-02-10 18:31:03,873 [INFO] loss: 0.0145    [4,    10]
2024-02-10 18:31:03,899 [INFO] loss: 0.0158    [4,    20]
2024-02-10 18:31:03,917 [INFO] loss: 0.0138    [5,    10]
2024-02-10 18:31:03,936 [INFO] loss: 0.0153    [5,    20]
2024-02-10 18:31:03,953 [INFO] loss: 0.0135    [6,    10]
2024-02-10 18:31:03,972 [INFO] loss: 0.0148    [6,    20]
2024-02-10 18:31:03,990 [IN

### Test

In [16]:
# Evaluate the fitted hold-out model on the unseen test split.
test_result = MatrixFeatureBasedMDAClassifierTester().test(
    model=mda_classifier,
    data=test_data,
    config=classifier_optimizer_config,
)

2024-02-10 18:31:05,634 [INFO] Call Testing with adam optimizer
2024-02-10 18:31:05,725 [INFO] Initializing SimplePytorchData with X shape : torch.Size([270, 80]) and y shape : torch.Size([270, 1])
2024-02-10 18:31:05,729 [INFO] Running Simple Tester with config : adam optimizer
2024-02-10 18:31:05,732 [INFO] moving data and model to cpu
2024-02-10 18:31:05,769 [INFO] Result on Test Data : {'AUC': 0.9586282578875172, 'ACC': 0.8888888888888888, 'F1 Score': 0.8888644970089457, 'AUPR': 0, 'Loss': 0.28061404824256897}


In [17]:
# Display the hold-out test metrics as a plain dict.
# NOTE(review): 'AUPR' comes back as exactly 0 while AUC is ~0.96, which
# suggests the tester does not actually compute AUPR — verify in the tester.
test_result.get_result()

{'AUC': 0.9586282578875172,
 'ACC': 0.8888888888888888,
 'F1 Score': 0.8888644970089457,
 'AUPR': 0,
 'Loss': 0.28061404824256897}

## Cross-Validation

In [18]:
# Cross-validation over the full association set (`data`), rebuilding the model
# for each fold through `factory` with the same classifier/decomposer configs.
trainer = MatrixFeatureBasedMDAClassifierTrainer()
tester = MatrixFeatureBasedMDAClassifierTester()
factory = MDFeatureBasedMDAClassifierFactory(simple_classifier_config, md_config)
spliter = MicrobeDiseaseAssociationTrainTestSpliter(data.associations)

N_FOLDS = 5

# Keep the returned Result in a named variable: as a bare expression it was
# reachable only through `_` / Out[...], which does not survive a re-run.
cv_result = cross_validation(
    k=N_FOLDS,
    data_size=data.associations.shape[0],
    train_test_spliter=spliter,
    model_factory=factory,
    trainer=trainer,
    tester=tester,
    config=classifier_optimizer_config,
)
# NOTE(review): displaying the object only shows its repr; if this Result offers
# the same `get_result()` as `test_result`, display that instead — confirm.
cv_result

2024-02-10 18:31:05,798 [INFO] Initializing MDFeatureBasedMDAClassifierFactory
2024-02-10 18:31:05,801 [INFO] Initializing MicrobeDiseaseAssociationTrainTestSpliter
2024-02-10 18:31:05,803 [INFO] Start 5-fold Cross Validation with config : adam optimizer
2024-02-10 18:31:05,805 [INFO] ---- Fold 1 ----
2024-02-10 18:31:05,808 [INFO] Initializing MDFeatureBasedMDAClassifier with model : simple classifier
2024-02-10 18:31:05,810 [INFO] Initializing SimpleMDAClassifier with model : simple classifier
2024-02-10 18:31:05,812 [INFO] Initial SimpleMLP with 80 input dimension, 32 hidden dimension, 1 
            output dimension, 2 layers and with 0.1 dropout
2024-02-10 18:31:05,814 [INFO] Initializing MatrixFeatureExtractor
2024-02-10 18:31:05,815 [INFO] Initializing MFFeatureExtractor with model : None and decomposer : NMF
2024-02-10 18:31:05,816 [INFO] Call Training with adam optimizer
2024-02-10 18:31:05,822 [INFO] Calling build with associations :      disease  microbe  increased
0      50



2024-02-10 18:32:29,791 [INFO] Initializing SimplePytorchData with X shape : torch.Size([719, 80]) and y shape : torch.Size([719, 1])
2024-02-10 18:32:29,793 [INFO] Running Simple Trainer with config : adam optimizer
2024-02-10 18:32:29,796 [INFO] moving data and model to cpu
2024-02-10 18:32:29,821 [INFO] loss: 0.0243    [1,    10]
2024-02-10 18:32:29,844 [INFO] loss: 0.0246    [1,    20]
2024-02-10 18:32:29,870 [INFO] loss: 0.0175    [2,    10]
2024-02-10 18:32:29,886 [INFO] loss: 0.0147    [2,    20]
2024-02-10 18:32:29,909 [INFO] loss: 0.0137    [3,    10]
2024-02-10 18:32:29,927 [INFO] loss: 0.0156    [3,    20]
2024-02-10 18:32:29,950 [INFO] loss: 0.0130    [4,    10]
2024-02-10 18:32:29,969 [INFO] loss: 0.0155    [4,    20]
2024-02-10 18:32:29,991 [INFO] loss: 0.0123    [5,    10]
2024-02-10 18:32:30,010 [INFO] loss: 0.0160    [5,    20]
2024-02-10 18:32:30,032 [INFO] loss: 0.0137    [6,    10]
2024-02-10 18:32:30,049 [INFO] loss: 0.0134    [6,    20]
2024-02-10 18:32:30,071 [IN



2024-02-10 18:33:56,802 [INFO] Initializing SimplePytorchData with X shape : torch.Size([719, 80]) and y shape : torch.Size([719, 1])
2024-02-10 18:33:56,804 [INFO] Running Simple Trainer with config : adam optimizer
2024-02-10 18:33:56,807 [INFO] moving data and model to cpu
2024-02-10 18:33:56,829 [INFO] loss: 0.0270    [1,    10]
2024-02-10 18:33:56,845 [INFO] loss: 0.0219    [1,    20]
2024-02-10 18:33:56,865 [INFO] loss: 0.0194    [2,    10]
2024-02-10 18:33:56,885 [INFO] loss: 0.0175    [2,    20]
2024-02-10 18:33:56,911 [INFO] loss: 0.0151    [3,    10]
2024-02-10 18:33:56,927 [INFO] loss: 0.0150    [3,    20]
2024-02-10 18:33:56,951 [INFO] loss: 0.0147    [4,    10]
2024-02-10 18:33:56,970 [INFO] loss: 0.0135    [4,    20]
2024-02-10 18:33:56,992 [INFO] loss: 0.0135    [5,    10]
2024-02-10 18:33:57,008 [INFO] loss: 0.0127    [5,    20]
2024-02-10 18:33:57,030 [INFO] loss: 0.0147    [6,    10]
2024-02-10 18:33:57,047 [INFO] loss: 0.0133    [6,    20]
2024-02-10 18:33:57,066 [IN



2024-02-10 18:35:22,675 [INFO] Initializing SimplePytorchData with X shape : torch.Size([719, 80]) and y shape : torch.Size([719, 1])
2024-02-10 18:35:22,678 [INFO] Running Simple Trainer with config : adam optimizer
2024-02-10 18:35:22,680 [INFO] moving data and model to cpu
2024-02-10 18:35:22,703 [INFO] loss: 0.0340    [1,    10]
2024-02-10 18:35:22,719 [INFO] loss: 0.0203    [1,    20]
2024-02-10 18:35:22,740 [INFO] loss: 0.0152    [2,    10]
2024-02-10 18:35:22,762 [INFO] loss: 0.0172    [2,    20]
2024-02-10 18:35:22,783 [INFO] loss: 0.0152    [3,    10]
2024-02-10 18:35:22,802 [INFO] loss: 0.0183    [3,    20]
2024-02-10 18:35:22,823 [INFO] loss: 0.0155    [4,    10]
2024-02-10 18:35:22,840 [INFO] loss: 0.0150    [4,    20]
2024-02-10 18:35:22,861 [INFO] loss: 0.0159    [5,    10]
2024-02-10 18:35:22,878 [INFO] loss: 0.0144    [5,    20]
2024-02-10 18:35:22,904 [INFO] loss: 0.0141    [6,    10]
2024-02-10 18:35:22,920 [INFO] loss: 0.0132    [6,    20]
2024-02-10 18:35:22,940 [IN



2024-02-10 18:36:38,035 [INFO] Initializing SimplePytorchData with X shape : torch.Size([719, 80]) and y shape : torch.Size([719, 1])
2024-02-10 18:36:38,037 [INFO] Running Simple Trainer with config : adam optimizer
2024-02-10 18:36:38,041 [INFO] moving data and model to cpu
2024-02-10 18:36:38,059 [INFO] loss: 0.0487    [1,    10]
2024-02-10 18:36:38,075 [INFO] loss: 0.0282    [1,    20]
2024-02-10 18:36:38,096 [INFO] loss: 0.0200    [2,    10]
2024-02-10 18:36:38,114 [INFO] loss: 0.0157    [2,    20]
2024-02-10 18:36:38,137 [INFO] loss: 0.0183    [3,    10]
2024-02-10 18:36:38,155 [INFO] loss: 0.0172    [3,    20]
2024-02-10 18:36:38,185 [INFO] loss: 0.0181    [4,    10]
2024-02-10 18:36:38,201 [INFO] loss: 0.0151    [4,    20]
2024-02-10 18:36:38,221 [INFO] loss: 0.0132    [5,    10]
2024-02-10 18:36:38,239 [INFO] loss: 0.0142    [5,    20]
2024-02-10 18:36:38,259 [INFO] loss: 0.0141    [6,    10]
2024-02-10 18:36:38,275 [INFO] loss: 0.0129    [6,    20]
2024-02-10 18:36:38,294 [IN



2024-02-10 18:37:49,498 [INFO] Initializing SimplePytorchData with X shape : torch.Size([716, 80]) and y shape : torch.Size([716, 1])
2024-02-10 18:37:49,499 [INFO] Running Simple Trainer with config : adam optimizer
2024-02-10 18:37:49,502 [INFO] moving data and model to cpu
2024-02-10 18:37:49,522 [INFO] loss: 0.0199    [1,    10]
2024-02-10 18:37:49,539 [INFO] loss: 0.0183    [1,    20]
2024-02-10 18:37:49,560 [INFO] loss: 0.0147    [2,    10]
2024-02-10 18:37:49,579 [INFO] loss: 0.0165    [2,    20]
2024-02-10 18:37:49,602 [INFO] loss: 0.0125    [3,    10]
2024-02-10 18:37:49,618 [INFO] loss: 0.0154    [3,    20]
2024-02-10 18:37:49,638 [INFO] loss: 0.0130    [4,    10]
2024-02-10 18:37:49,654 [INFO] loss: 0.0143    [4,    20]
2024-02-10 18:37:49,674 [INFO] loss: 0.0113    [5,    10]
2024-02-10 18:37:49,691 [INFO] loss: 0.0134    [5,    20]
2024-02-10 18:37:49,711 [INFO] loss: 0.0122    [6,    10]
2024-02-10 18:37:49,727 [INFO] loss: 0.0131    [6,    20]
2024-02-10 18:37:49,747 [IN

<base.evaluation.Result at 0x7f15c51daf50>