## 3. Training & Testing AggMapNet on MEGMA Fmaps

In this section, we introduce how to use the CNN-based AggMapNet to train CRC detection models on the **2D-microbiomeprints (Fmaps)** generated by `megma`. The CRC detection model is a classification model because our data carry binary labels (CRCs or CTRs). As with the Fmap generation, we can use the study-to-study transfer (STST) method to test the performance of our model.

Note that in the STST experiment, the Fmaps can be transformed by a country-specific `megma` (fitted on the unlabelled metagenomic data of one country) or by an overall `megma` (fitted on all unlabelled metagenomic data), because **MEGMA** is an unsupervised learning method that is fitted on unlabelled data only.

First, we train the classification model on the data of one country, and then we test its performance on the remaining countries.
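The STST design described above can be sketched as a plain enumeration of train/test pairs, together with the set of countries whose unlabelled data would be used to fit `megma` under each option. This is only an illustration under stated assumptions: the helper names `stst_pairs` and `megma_fit_countries` and the `mode` argument are hypothetical and not part of `aggmap`, and the country codes are the ones listed in this tutorial's training cell.

```python
from itertools import permutations

# Country codes as listed in this tutorial's training cell.
countries = ['AUS', 'CHN', 'DEU', 'FRA', 'USA']


def stst_pairs(countries):
    """Return every (train_country, test_country) pair with train != test."""
    return list(permutations(countries, 2))


def megma_fit_countries(train_country, countries, mode='country'):
    """Countries whose unlabelled data would be used to fit megma.

    mode='country': only the training country's samples.
    mode='overall': samples from all countries (labels are never used).
    Both the function and the mode names are hypothetical helpers.
    """
    if mode == 'country':
        return [train_country]
    if mode == 'overall':
        return list(countries)
    raise ValueError("mode must be 'country' or 'overall'")


pairs = stst_pairs(countries)
for train_country, test_country in pairs[:4]:
    fit_on = megma_fit_countries(train_country, countries, mode='country')
    print(train_country, '->', test_country, '| megma fitted on', fit_on)
```

With five countries this yields 20 ordered train/test pairs. Using unlabelled data from the test countries to fit an overall `megma` does not leak labels, because the fitting step never sees `CRC`/`CTR` annotations.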

### Training CRC detection model 

In [None]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from aggmap import AggMap, AggMapNet
from aggmap import show, loadmap


## megma parameters
metric = 'correlation' #distance metric
cluster_channels = 5 #channel number
embed_methd = 'umap' #embedding method
random_state = 888 #random seed
minv = 1e-8 #minimal value for log-transform
scale_method = 'standard' #data scaling method

## AggMapNet parameters
epochs = 30 #number of epochs
lr = 1e-4 #learning rate
batch_size = 2 #batch size
conv1_kernel_size = 5 #kernel size of the first CNN layer




countries = ['CHN',] # 'AUS', 'DEU', 'FRA', 'USA'

url = 'https://raw.githubusercontent.com/shenwanxiang/bidd-aggmap/master/docs/source/_example_MEGMA/dataset/'

for country in countries:
    dfx_vector = pd.read_csv(url + '%s_dfx.csv' % country, index_col='Sample_ID')
    dfx_vector = np.log(dfx_vector + minv)
    
    
    dfy = pd.read_csv(url + '%s_dfy.csv' % country, index_col='Sample_ID')

    
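    # fit megma on the unlabelled, log-transformed abundance vectors
    megma = AggMap(dfx_vector, metric = metric)
    megma = megma.fit(cluster_channels = cluster_channels,
                      emb_method = embed_methd,
                      random_state = random_state)
    
    save_name = './megma.%s' % country
    megma.save(save_name)
    
    # transform the abundance vectors into multi-channel Fmaps
    X = megma.batch_transform(dfx_vector.values, scale_method = scale_method)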
    
    Y = pd.get_dummies(dfy.Group).values
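The two preprocessing steps in the loop above (the log-transform with the pseudo-count `minv` and the one-hot encoding of the labels) can be tried on a toy table. The sketch below uses made-up sample IDs, taxa, and abundance values; only `minv`, the `Group` column, and the `CRC`/`CTR` labels are taken from this section.

```python
import numpy as np
import pandas as pd

minv = 1e-8  # same pseudo-count as in the cell above

# Toy abundance table (made-up values): rows are samples, columns are taxa.
dfx_toy = pd.DataFrame(
    {'taxon_a': [0.0, 0.2, 0.05], 'taxon_b': [0.7, 0.0, 0.3]},
    index=pd.Index(['S1', 'S2', 'S3'], name='Sample_ID'),
)
dfy_toy = pd.DataFrame({'Group': ['CRC', 'CTR', 'CRC']}, index=dfx_toy.index)

# Without the pseudo-count, log(0) would be -inf; adding minv keeps it finite.
dfx_log = np.log(dfx_toy + minv)

# One-hot labels; the columns are ordered alphabetically (CRC, CTR).
# dtype=int keeps 0/1 integers on pandas >= 2.0, where the default is bool.
Y_toy = pd.get_dummies(dfy_toy.Group, dtype=int).values

print(dfx_log.round(2))
print(Y_toy)
```

Zero abundances are common in metagenomic tables, so the pseudo-count matters: every zero is mapped to `log(1e-8)`, roughly -18.42, instead of negative infinity.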

### Test model performance

2022-08-06 15:10:30,490 - INFO - [bidd-aggmap] - Calculating distance ...
2022-08-06 15:10:30,510 - INFO - [bidd-aggmap] - the number of process is 16

100%|##########| 359976/359976 [00:28<00:00, 12756.09it/s]
100%|##########| 359976/359976 [00:00<00:00, 5400977.91it/s]
100%|##########| 849/849 [00:00<00:00, 5723.42it/s]

2022-08-06 15:10:59,093 - INFO - [bidd-aggmap] - applying hierarchical clustering to obtain group information ...
UMAP(a=None, angular_rp_forest=False, b=None, init='spectral',
     learning_rate=1.0, local_connectivity=1.0, metric='precomputed',
     metric_kwds=None, min_dist=0.1, n_components=2, n_epochs=None,
     n_neighbors=15, negative_sample_rate=5, random_state=888,
     repulsion_strength=1.0, set_op_mix_ratio=1.0, spread=1.0,
     target_metric='categorical', target_metric_kwds=None,
     target_n_neighbors=-1, target_weight=0.5, transform_queue_size=4.0,
     transform_seed=42, verbose=2)
Construct fuzzy simplicial set
Sat Aug  6 15:10:59 2022 Finding Nearest Neighbors
Sat Aug  6 15:10:59 2022 Finished Nearest Neighbor Search
Sat Aug  6 15:10:59 2022 Construct embedding
	completed  0  /  500 epochs
	completed  50  /  500 epochs
	completed  100  /  500 epochs
	completed  150  /  500 epochs
	completed  200  /  500 epochs
	completed  250  /  500 epochs
	comple

In [5]:
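def train(dfx_vector, dfy, country, megma = None):
    
    # fit and save a new megma only if no pre-fitted one is passed in
    if megma is None:
        megma = AggMap(dfx_vector, metric = metric)
        megma = megma.fit(cluster_channels = cluster_channels,
                          emb_method = embed_methd,
                          random_state = random_state)

        save_name = './megma.%s' % country
        megma.save(save_name)
    
    # transform the abundance vectors into multi-channel Fmaps
    X = megma.batch_transform(dfx_vector.values, scale_method = scale_method)
    
    # one-hot encode the labels (columns: CRC, CTR)
    Y = pd.get_dummies(dfy.Group).values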

| Sample_ID | CRC | CTR |
|:-----------|----:|----:|
| ERR1018185 | 1 | 0 |
| ERR1018186 | 1 | 0 |
| ERR1018187 | 1 | 0 |
| ERR1018188 | 1 | 0 |
| ERR1018189 | 1 | 0 |
| ... | ... | ... |
| ERR1018308 | 0 | 1 |
| ERR1018309 | 0 | 1 |
| ERR1018310 | 1 | 0 |
| ERR1018311 | 1 | 0 |
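One-hot labels like those in the table above can be scored against predicted class probabilities once a model has been trained. The sketch below is an assumption-laden illustration: this section does not state which metric is used, so accuracy and ROC AUC are picked here as common choices, the helper `roc_auc` is hypothetical, and the predicted probabilities are made up.

```python
import numpy as np


def roc_auc(y_true, y_score):
    """ROC AUC via the rank-sum (Mann-Whitney U) formulation.

    y_true holds 1 for the positive class and 0 otherwise.
    Tied scores receive their average rank.
    Assumes both classes are present in y_true.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    order = np.argsort(y_score)
    ranks = np.empty(len(y_score), dtype=float)
    ranks[order] = np.arange(1, len(y_score) + 1)
    # Average the ranks of tied scores.
    for s in np.unique(y_score):
        tied = y_score == s
        ranks[tied] = ranks[tied].mean()
    n_pos = int(y_true.sum())
    n_neg = len(y_true) - n_pos
    return (ranks[y_true == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


# One-hot labels with columns (CRC, CTR), as in the table above.
Y_true = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])
# Made-up predicted probabilities, for illustration only.
Y_prob = np.array([[0.9, 0.1], [0.4, 0.6], [0.2, 0.8], [0.3, 0.7]])

# Accuracy: compare the predicted class (argmax) with the true class.
accuracy = (Y_true.argmax(axis=1) == Y_prob.argmax(axis=1)).mean()
# ROC AUC for the CRC class, using the CRC column as the score.
auc = roc_auc(Y_true[:, 0], Y_prob[:, 0])
print(accuracy, auc)
```

In an STST setting, the same scoring would be repeated for each test country, with the model trained on a different country, so that the results reflect transfer between studies and not within-study fit.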
