## 4. Important microbial marker identification

In this section, we calculate the global feature importance (GFI) to identify the key microbes that contribute to the CRC detection model. The *simply-explainer* is used to calculate the GFI; more about AggMapNet model explanation can be found [**here**](https://bidd-aggmap.readthedocs.io/en/latest/_HPs/hps_content.html#AggMapNet-Explainers). By calculating an importance score for each microbe, we can draw a saliency map to locate the hot zones in the **2D-microbiomeprints**.

[**Saliency-Map**](https://www.geeksforgeeks.org/what-is-saliency-map/) is an image in which the brightness of a pixel represents how salient the pixel is, i.e., the brightness of a pixel is directly proportional to its saliency. It is generally a grayscale image. Saliency maps are also called heat maps, where hotness refers to the regions of the image that have a large impact on predicting the class to which the object belongs.

The purpose of the saliency-map is to find the regions which are prominent or noticeable at every location in the visual field and to guide the selection of attended locations, based on the spatial distribution of saliency. 
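
As a minimal illustration of the idea above, the sketch below turns a flat vector of importance scores into a 2D map whose values are scaled to [0, 1], so that brightness is proportional to saliency. The helper name `to_saliency_map` and the toy scores are assumptions made for this example only; they are not part of aggmap.

```python
import numpy as np

def to_saliency_map(scores, shape):
    """Reshape flat importance scores into a 2D grid and min-max scale
    them to [0, 1], so that a brighter pixel means a more salient one."""
    grid = np.asarray(scores, dtype=float).reshape(shape)
    span = grid.max() - grid.min()
    if span == 0:
        # all scores are equal: nothing is more salient than anything else
        return np.zeros_like(grid)
    return (grid - grid.min()) / span

# hypothetical importance scores for six features laid out on a 2 x 3 grid
toy_scores = [0.0, 0.5, 1.0, 4.0, 2.0, 0.0]
saliency = to_saliency_map(toy_scores, (2, 3))

# location of the hottest pixel (row, column)
hot_row, hot_col = np.unravel_index(saliency.argmax(), saliency.shape)
```

The scaled grid can be passed to any heatmap function; the hottest pixel marks the feature with the largest score.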





[4.1 Calculate the global feature importance](#4.1-Calculate-the-global-feature-importance)

* [4.1.1 GFI for model trained on overall MEGMA Fmaps](#4.1.1-GFI-for-model-trained-on-overall-MEGMA-Fmaps)

* [4.1.2 GFI for model trained on country specific MEGMA Fmaps](#4.1.2-GFI-for-model-trained-on-country-specific-MEGMA-Fmaps)

[4.2 Generate the explanation saliency map](#4.2-Generate-the-explaination-saliency-map)

* [4.2.1 Saliency map for overall MEGMA Fmaps](#4.2.1-Saliency-map-for-overall-MEGMA-Fmaps)

* [4.2.2 Saliency map country specific MEGMA Fmaps](#4.2.2-Saliency-map-country-specific-MEGMA-Fmaps)

In [169]:
from matplotlib.ticker import FormatStrFormatter
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
import pandas as pd
import numpy as np
import os
sns.set(style='white',  font='sans-serif', font_scale=2)


from aggmap import loadmap, AggMapNet
from aggmap.AggMapNet import load_model

### 4.1 Calculate the global feature importance

Let's calculate the GFI first. We need to reload the trained models that were dumped to disk. After that, we use the `simply_explainer` to calculate the GFI based on the training set data (i.e., the data from the one country used to train each model).
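
To give an intuition for what a perturbation-style global importance score looks like, here is a small self-contained sketch. It assumes that the importance of a feature is the increase in log loss when that feature is replaced by a background value, and it uses a single global minimum as the background, which is only one plausible reading of `'global_min'`. This is not the aggmap implementation; `perturbation_importance` and the toy model are hypothetical names introduced for illustration.

```python
import numpy as np

def log_loss(y_true, p):
    """Binary cross-entropy, with clipping to avoid log(0)."""
    p = np.clip(p, 1e-12, 1 - 1e-12)
    return float(-np.mean(y_true * np.log(p) + (1 - y_true) * np.log(1 - p)))

def perturbation_importance(predict_proba, X, y, background):
    """Score each feature by how much the loss grows when the feature
    is overwritten with its background value in every sample."""
    base = log_loss(y, predict_proba(X))
    scores = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        X_masked = X.copy()
        X_masked[:, j] = background[j]
        scores[j] = log_loss(y, predict_proba(X_masked)) - base
    return scores

def toy_predict_proba(X):
    """Toy classifier (assumption): a fixed logistic model that only
    looks at feature 0."""
    return 1.0 / (1.0 + np.exp(-4.0 * X[:, 0]))

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = (X[:, 0] > 0).astype(float)

# assumed background: the smallest value in the whole matrix, for every feature
background = np.full(X.shape[1], X.min())
importance = perturbation_importance(toy_predict_proba, X, y, background)
```

In this toy setting only feature 0 receives a positive score, because masking the other two features does not change the predictions at all.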

#### 4.1.1 GFI for model trained on overall MEGMA Fmaps

In [None]:
countries = ['CHN', 'USA', 'DEU', 'FRA',  'AUS'] 

model_dir = './megma_overall_model'
# load the pre-fitted megma_all object
megma = loadmap('./megma/megma.all')

gfis = []
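for country in countries:
    # reload the model that was trained on this country's data
    clf = load_model(os.path.join(model_dir, 'aggmapnet.%s' % country),  gpuid=0)
    sxp = AggMapNet.simply_explainer(clf, megma, 
                                     backgroud = 'global_min', 
                                     apply_smoothing = True)
    # called without data, so the training set is explained (see the log output)
    gfi = sxp.global_explain()
    # keep the class-0 importance as one column named after the country
    gfis.append(gfi.simply_importance_class_0.to_frame(name = country))

# one column of importance scores per country
dfimp1 = pd.concat(gfis, axis=1)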

2022-08-17 16:41:08,323 - INFO - [bidd-aggmap] - Explaining the whole samples of the training Set
2022-08-17 16:41:08,388 - INFO - [bidd-aggmap] - calculating feature importance for class 0 ...
100%|################################################################################| 870/870 [00:08<00:00, 105.37it/s]
2022-08-17 16:41:16,647 - INFO - [bidd-aggmap] - calculating feature importance for class 1 ...
100%|################################################################################| 870/870 [00:07<00:00, 113.23it/s]
2022-08-17 16:41:24,397 - INFO - [bidd-aggmap] - Explaining the whole samples of the training Set
2022-08-17 16:41:24,467 - INFO - [bidd-aggmap] - calculating feature importance for class 0 ...
 58%|##############################################                                  | 501/870 [00:03<00:02, 123.52it/s]

#### 4.1.2 GFI for model trained on country specific MEGMA Fmaps

In [None]:
model_dir = './megma_country_model'

gfis2 = []
reshape_indexes = {}
for country in countries:
    
    clf = load_model(os.path.join(model_dir, 'aggmapnet.%s' % country),  gpuid=0)
    megma = loadmap(os.path.join(model_dir, 'megma.%s' % country))

    sxp = AggMapNet.simply_explainer(clf, megma, 
                                     backgroud = 'global_min', 
                                     apply_smoothing = True)
    gfi2 = sxp.global_explain(clf.X_, clf.y_)
    gfis2.append(gfi2.simply_importance_class_0.to_frame(name = country))
    
    # each country has its own megma, so the reshape index and fmap_shape also differ
    reshape_index = megma.feature_names_reshape
    reshape_indexes.update({country: (megma.fmap_shape, reshape_index)})
    
dfimp2 = pd.concat(gfis2, axis=1)
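
The reason for storing a reshape index per country can be seen in a toy example. When frames with differently ordered indexes are concatenated column-wise, pandas aligns them on the index labels, so a column no longer follows its own grid order and has to be re-selected with that order before reshaping. The labels `sp_a` to `sp_d`, the column names `X` and `Y`, and the 2 x 2 shape below are made up for this sketch.

```python
import pandas as pd

# two importance vectors over the same features, each in its own grid order
order_x = ['sp_a', 'sp_b', 'sp_c', 'sp_d']
order_y = ['sp_c', 'sp_a', 'sp_d', 'sp_b']
imp_x = pd.Series([1.0, 2.0, 3.0, 4.0], index=order_x, name='X')
imp_y = pd.Series([10.0, 20.0, 30.0, 40.0], index=order_y, name='Y')

# column-wise concat aligns rows by label, not by position
df = pd.concat([imp_x, imp_y], axis=1)

# restore the grid order of column Y before reshaping into its feature map
fmap_shape = (2, 2)
grid_y = df.loc[order_y, 'Y'].values.reshape(*fmap_shape)
```

Reshaping `df['Y'].values` directly would place the scores at the wrong grid positions, which is why the saliency maps for the country-specific models index the table with the stored order first.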

In [None]:
country

### 4.2 Generate the explaination saliency map



#### 4.2.1 Saliency map for overall MEGMA Fmaps



In [None]:
megma = loadmap('./megma/megma.all')

for country in countries:
    fig, ax  = plt.subplots(1, 1, figsize=(10, 9))
    IMPM = dfimp1[country].values.reshape(*megma.fmap_shape)
    print(IMPM.max().round(1))

    sns.heatmap(IMPM,  
                cmap = 'rainbow', alpha = 0.8, xticklabels=4, ax =ax,
                yticklabels=4, vmin = -0.1, vmax = 4.2,
                cbar_kws = {'fraction':0.046, 'shrink':0.9, 'aspect': 40, 'pad':0.02, })

    bottom, top = ax.get_ylim()
    #ax.set_ylim(bottom + 0.5, top - 0.5)

    ax.set_title(country)
    cbar = ax.collections[0].colorbar
    cbar.ax.set_title('FI')
    cbar.ax.yaxis.set_major_formatter(FormatStrFormatter("%.1f"))

    plt.subplots_adjust(wspace = 0.18)

#### 4.2.2 Saliency map country specific MEGMA Fmaps

In [None]:
for country in countries:

    fig, ax  = plt.subplots(1, 1, figsize=(10, 9))

    fmap_shape, reshape_idx =  reshape_indexes[country]
    IMPM = dfimp2.loc[reshape_idx][country].values.reshape(*fmap_shape)
    print(IMPM.max().round(1))

    sns.heatmap(IMPM,  
                cmap = 'rainbow', alpha = 0.8, xticklabels=4, ax =ax,
                yticklabels=4, vmin = -0.1, vmax = 3.5,
                cbar_kws = {'fraction':0.046, 'shrink':0.9, 'aspect': 40, 'pad':0.02, })

    bottom, top = ax.get_ylim()
    #ax.set_ylim(bottom + 0.5, top - 0.5)

    ax.set_title(country)
    cbar = ax.collections[0].colorbar
    cbar.ax.set_title('FI')
    cbar.ax.yaxis.set_major_formatter(FormatStrFormatter("%.1f"))

    plt.subplots_adjust(wspace = 0.18)

In [None]:
sns.heatmap(dfimp2.corr(), cmap = 'rainbow_r')
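
The correlation heatmap compares the importance profiles of the countries pairwise. A small sketch with made-up scores shows what the underlying table looks like and how it can be complemented by a rank-based correlation and a top-k overlap; the values and feature labels are hypothetical and only the column names are borrowed from the country list.

```python
import pandas as pd

# hypothetical importance scores: rows are features, columns are countries
toy = pd.DataFrame(
    {
        'CHN': [0.1, 0.8, 0.3, 0.5],
        'USA': [0.2, 0.9, 0.2, 0.6],
        'DEU': [0.9, 0.1, 0.7, 0.2],
    },
    index=['sp_a', 'sp_b', 'sp_c', 'sp_d'],
)

# Pearson correlation between country columns (the pandas default)
pearson = toy.corr()

# rank-based alternative, less sensitive to a few very large scores
spearman = toy.corr(method='spearman')

# the two highest-scoring features per country, as sets for easy comparison
top2 = {c: set(toy[c].nlargest(2).index) for c in toy.columns}
shared_chn_usa = top2['CHN'] & top2['USA']
```

A high correlation together with a large top-k overlap suggests that two models rely on similar features, while a low or negative value suggests they do not.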

In [None]:
megma

In [None]:
megmas = []
for country in countries:
    megma = loadmap(os.path.join(model_dir, 'megma.%s' % country))
    megmas.append(megma)
    
chn, usa, deu, fra, aus = megmas


In [None]:
aus = dict(zip(countries,megmas))['AUS']

In [None]:
countries

In [None]:
aus.plot_grid()