# GEOGRAPHY MANUSCRIPT CLEANED VERSION.

Code for Lymberopoulos et al., 2021.

Prepared for the tmap_geography repository.

In [None]:
### Imports ### 
import pandas as pd
import os
import numpy as np

from sklearn.preprocessing import MinMaxScaler
from sklearn.cluster import DBSCAN
from tmap.tda import mapper, Filter
from tmap.tda.cover import Cover
from tmap.tda.plot import Color, vis_progressX
from tmap.tda.metric import Metric
from tmap.tda.utils import optimize_dbscan_eps
from scipy.spatial.distance import squareform,pdist
from tmap.test import load_data
from tmap.netx.SAFE import SAFE_batch, get_SAFE_summary, get_significant_nodes
from tmap.netx.coenrichment_analysis import *

import plotly.graph_objs as go
from collections import Counter
pd.options.display.expand_frame_repr = False
pd.options.display.max_rows = 1000
import scipy.stats as scs
import warnings
import plotly
import plotly.express as px

## Data import and parsing

In [None]:
# Load the genus-level abundance table (X) and the per-sample metadata table;
# in both files the first CSV column is used as the row index.
X, metadata_import = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in ("genus_health_all.csv", "metadata_health_all.csv")
)

In [None]:
## only select dummy coded variables and host age
cols = [
    'host_age',
    'sex: Female', 'sex: Male',
    'BMI: normal', 'BMI: obese', 'BMI: overweight', 'BMI: underweight',
    'country: Brazil', 'country: Canada', 'country: China', 'country: Denmark',
    'country: France', 'country: Germany', 'country: Italy', 'country: New Zealand', 'country: Spain',
    'country: Tanzania, United Republic of', 'country: United Kingdom', 'country: United States of America',
]

# Short display names for the country indicator columns; the age, sex and BMI
# columns keep their original names.
country_short_names = {
    'country: Brazil': 'Brazil',
    'country: Canada': "Canada",
    'country: China': "China",
    'country: Denmark': "Denmark",
    'country: France': "France",
    'country: Germany': "Germany",
    'country: Italy': "Italy",
    'country: New Zealand': "New Zealand",
    'country: Spain': "Spain",
    'country: Tanzania, United Republic of': "Tanzania",
    'country: United Kingdom': "UK",
    'country: United States of America': "USA",
}

metadata = metadata_import[cols].copy()
metadata = metadata.rename(columns=country_short_names)

## TDA analysis

In [None]:
# ----- MAPPER -----

# Bray-Curtis dissimilarity between every pair of rows of X; pdist returns the
# condensed vector, squareform expands it to a full square matrix.
dm = squareform(
    pdist(X, metric='braycurtis')
)

# Step 1: create the Mapper driver (verbose=1, as in the original run).
tm = mapper.Mapper(verbose=1)

# Step 2: lens = MDS components 0 and 1 computed on the precomputed distance matrix.
metric = Metric(metric="precomputed")
lens = [
    Filter.MDS(components=[0, 1], metric=metric, random_state=100),
]
projected_X = tm.filter(dm, lens=lens)

In [None]:
# Step 3: covering, clustering & mapping.

# DBSCAN's eps is chosen by tmap's optimize_dbscan_eps on X (threshold=95).
eps = optimize_dbscan_eps(X, threshold=95)
clusterer = DBSCAN(eps=eps, min_samples=5)

# The lens coordinates are min-max scaled before covering
# (resolution 85, overlap 0.85).
scaled_lens = MinMaxScaler().fit_transform(projected_X)
cover = Cover(projected_data=scaled_lens, resolution=85, overlap=0.85)

graph = tm.map(data=X, cover=cover, clusterer=clusterer)
print(graph.info())
graph.show()

In [None]:
# ----- SAFE -----

# Number of permutations used by the SAFE calls.
n_iter = 5000

### Enriching the metadata
# _mode='both' returns a mapping holding both the 'enrich' and 'decline' score tables.
safe_scores = SAFE_batch(graph, metadata=metadata, n_iter=n_iter, nr_threshold=0.5, _mode='both')
enriched_SAFE_metadata = safe_scores['enrich']
declined_SAFE_metadata = safe_scores['decline']

In [None]:
### Enriching the taxa
# Same SAFE run as for the metadata, but each taxon (column of X) is the variable.
safe_scores = SAFE_batch(graph, metadata=X, n_iter=n_iter, nr_threshold=0.5, _mode='both')
enriched_SAFE = safe_scores['enrich']
declined_SAFE = safe_scores['decline']

In [None]:
## Obtaining and summarising the enrichment data (scores and enriched nodes, etc.)

# Network-wide enrichment summary for each metadata variable (p < 0.05).
safe_summary_metadata = get_SAFE_summary(
    graph=graph,
    metadata=metadata,
    safe_scores=enriched_SAFE_metadata,
    n_iter=n_iter,
    p_value=0.05,
)

# The same summary for each taxon.
safe_summary = get_SAFE_summary(
    graph=graph,
    metadata=X,
    safe_scores=enriched_SAFE,
    n_iter=n_iter,
    p_value=0.05,
)

# Node-level enrichment scores: taxa columns followed by metadata columns.
enriched_SAFE_total = pd.concat([enriched_SAFE, enriched_SAFE_metadata], axis=1)

# Combined summary, stacked row-wise: metadata rows first, then taxa.
safe_summary_total = pd.concat([safe_summary_metadata, safe_summary], axis=0)

In [None]:
### Calculate pairwise co-enrichment for all variables

# Significant nodes for every variable (taxa + metadata). With r_neighbor=True
# the call yields two results: the enriched centroids and the enriched nodes.
enriched_centroides, enriched_nodes = get_significant_nodes(
    graph=graph,
    safe_scores=enriched_SAFE_total,
    n_iter=n_iter,
    pvalue=0.05,
    r_neighbor=True,
)

# Pairwise co-enrichment between all variables; the centroids computed above are
# handed over via _pre_cal_enriched (presumably so they are not recomputed).
corrected_fe_dis = pairwise_coenrichment(
    graph,
    safe_scores=enriched_SAFE_total,
    _pre_cal_enriched=enriched_centroides,
)

## save to csv for later use in another script
corrected_fe_dis.to_csv('scores_coenrichment.csv')

## Visualising the results

#### Metadata visualisations

In [None]:
# Country comparison: for each country, the nodes significantly enriched for it
# are overlaid as coloured markers on the grey Mapper graph skeleton.

node_pos = graph.nodePos
nodes = graph.node
sizes = graph.size

# Use the SAFE permutation count (n_iter) rather than a hard-coded 5000, so the
# significance cut-off always matches the SAFE_batch run that produced the scores.
# Called without r_neighbor, the result is indexed by variable name below.
enriched_centroides = get_significant_nodes(graph=graph,
                                            safe_scores=enriched_SAFE_total,
                                            pvalue=0.05,
                                            n_iter=n_iter
                                            )

# All graph edges as one line trace; the None entries break the line between edges.
xs = []
ys = []
for edge in graph.edges:
    xs += [node_pos[edge[0], 0], node_pos[edge[1], 0], None]
    ys += [node_pos[edge[0], 1], node_pos[edge[1], 1], None]
fig = plotly.subplots.make_subplots(1, 1)

node_line = go.Scatter(
    # ordination line
    visible=True, x=xs, y=ys, marker=dict(color="#8E9DA2", opacity=0.7),
    line=dict(width=1), showlegend=False, mode="lines")
fig.append_trace(node_line, 1, 1)

color_tmp = {'Brazil': '#ff4040', 'Canada': '#33d2ff', 'China': '#fcba03',
             'Denmark': '#011941', 'France': '#41a941', 'Germany': '#FF8D15',
             'Italy': '#8D6E63', 'New Zealand': '#C70039', 'Spain': '#6C3483',
             'Tanzania': '#E8DAEF', 'UK': '#1B5E20',
             'USA': '#0053C7'}

# The list order fixes the drawing and legend order of the countries.
for idx, fea in enumerate(['USA', 'UK', 'New Zealand',
                           'Canada', 'China', 'Denmark', 'France',
                           'Germany', 'Italy', 'Spain', 'Brazil',
                           'Tanzania']):
    # One marker trace per country; marker size is 5 + the node's graph size.
    node_position = go.Scatter(
        visible=True,
        x=node_pos[enriched_centroides[fea], 0],
        y=node_pos[enriched_centroides[fea], 1],
        hoverinfo="text",
        marker=dict(color=color_tmp[fea], size=[5 + sizes[_] for _ in enriched_centroides[fea]], opacity=0.7),
        showlegend=True, name=fea, mode="markers")
    fig.append_trace(node_position, 1, 1)
fig.layout.width = 1200
fig.layout.height = 1000
fig.layout.font.size = 16
plotly.offline.iplot(fig)

In [None]:
# Ranking the host metadata according to network-wide enrichment:
# bar chart of each variable's SAFE enriched score, highest first.
meta_sort = safe_summary_metadata.sort_values('SAFE enriched score', ascending=False)
meta_bar = px.bar(
    meta_sort,
    x=meta_sort.index,
    y='SAFE enriched score',
    color='SAFE enriched score',
)
meta_bar.show()

In [None]:
# ----- Other metadata -----

def metadata_visual(parameter):
    """Show one metadata variable on the Mapper graph as two side-by-side panels.

    Left panel: the raw per-sample values taken from ``metadata``
    (coloured with ``target_by='sample'``).
    Right panel: the node-level SAFE scores taken from ``enriched_SAFE_total``
    (coloured with ``target_by='node'``).
    Both panels are built with ``vis_progressX`` and displayed inline with
    ``plotly.offline.iplot``. Nothing is returned.

    Args:
        parameter: column name that must exist in both ``metadata`` and
            ``enriched_SAFE_total``.
    """
    fig = plotly.subplots.make_subplots(
        1, 2,
        subplot_titles=[f'{parameter}<Br>original values', f'{parameter}<Br>SAFE values'],
    )

    raw_color = Color(metadata.loc[:, parameter], target_by='sample', dtype='numerical')
    safe_color = Color(enriched_SAFE_total.loc[:, parameter], target_by='node', dtype='numerical')
    raw_plot = vis_progressX(graph, color=raw_color, simple=True, mode='obj')
    safe_plot = vis_progressX(graph, color=safe_color, simple=True, mode='obj')

    # Left panel takes only the first two traces of the raw-value plot.
    fig.append_trace(raw_plot.data[0], 1, 1)
    # Move the left panel's colour bar to x=-0.02 (presumably to keep it clear of the right panel).
    raw_plot.data[1]['marker']['colorbar']['x'] = -0.02
    fig.append_trace(raw_plot.data[1], 1, 1)
    # Right panel takes every trace of the SAFE plot.
    for trace in safe_plot.data:
        fig.append_trace(trace, 1, 2)

    fig.layout.width = 1800
    fig.layout.height = 900
    # Hide tick labels and zero lines on both subplots' axes.
    for axis in (fig.layout.xaxis1, fig.layout.yaxis1, fig.layout.xaxis2, fig.layout.yaxis2):
        axis.showticklabels = False
        axis.zeroline = False
    fig.layout.hovermode = 'closest'
    plotly.offline.iplot(fig)

In [None]:
### Stratification by Age
# Raw host age per sample next to its SAFE enrichment per node.
metadata_visual(parameter='host_age')

In [None]:
### Sex Comparison
# For each sex, the nodes significantly enriched for it are overlaid as coloured
# markers on the grey Mapper graph skeleton.

node_pos = graph.nodePos
nodes = graph.node
sizes = graph.size

# Use the SAFE permutation count (n_iter) rather than a hard-coded 5000, so the
# significance cut-off always matches the SAFE_batch run that produced the scores.
enriched_centroides = get_significant_nodes(graph=graph, safe_scores=enriched_SAFE_total, pvalue=0.05, n_iter=n_iter)

# All graph edges as one line trace; the None entries break the line between edges.
xs = []
ys = []
for edge in graph.edges:
    xs += [node_pos[edge[0], 0], node_pos[edge[1], 0], None]
    ys += [node_pos[edge[0], 1], node_pos[edge[1], 1], None]
fig = plotly.subplots.make_subplots(1, 1)

node_line = go.Scatter(
    # ordination line
    visible=True,
    x=xs,
    y=ys,
    marker=dict(color="#8E9DA2",
                opacity=0.7),
    line=dict(width=1),
    showlegend=False,
    mode="lines")
fig.append_trace(node_line, 1, 1)

color_tmp = {'sex: Female': '#da3c3d',
             'sex: Male': '#011941'}

for idx, fea in enumerate(["sex: Female", "sex: Male"]):
    # One marker trace per sex; marker size is 5 + the node's graph size.
    node_position = go.Scatter(
        visible=True,
        x=node_pos[enriched_centroides[fea], 0],
        y=node_pos[enriched_centroides[fea], 1],
        hoverinfo="text",
        marker=dict(color=color_tmp[fea],
                    size=[5 + sizes[_] for _ in enriched_centroides[fea]],
                    opacity=0.7),
        showlegend=True,
        name=fea,
        mode="markers")
    fig.append_trace(node_position, 1, 1)
fig.layout.width = 1000
fig.layout.height = 1000
fig.layout.font.size = 15
plotly.offline.iplot(fig)

In [None]:
### BMI Comparison
# For each BMI class, the nodes significantly enriched for it are overlaid as
# coloured markers on the grey Mapper graph skeleton.

node_pos = graph.nodePos
nodes = graph.node
sizes = graph.size

# Use the SAFE permutation count (n_iter) rather than a hard-coded 5000, so the
# significance cut-off always matches the SAFE_batch run that produced the scores.
enriched_centroides = get_significant_nodes(graph=graph,
                                            safe_scores=enriched_SAFE_total,
                                            pvalue=0.05,
                                            n_iter=n_iter
                                            )

# All graph edges as one line trace; the None entries break the line between edges.
xs = []
ys = []
for edge in graph.edges:
    xs += [node_pos[edge[0], 0],
           node_pos[edge[1], 0],
           None]
    ys += [node_pos[edge[0], 1],
           node_pos[edge[1], 1],
           None]
fig = plotly.subplots.make_subplots(1, 1)

node_line = go.Scatter(
    # ordination line
    visible=True,
    x=xs,
    y=ys,
    marker=dict(color="#8E9DA2",
                opacity=0.7),
    line=dict(width=1),
    showlegend=False,
    mode="lines")
fig.append_trace(node_line, 1, 1)

color_tmp = {'BMI: underweight': '#ff4040',
             'BMI: normal': '#33d2ff',
             'BMI: overweight': '#fcba03',
             'BMI: obese': '#011941'}
## NOTE: this dataset has no severely or morbidly obese participants; keep this in mind when comparing with the Atlas.

for idx, fea in enumerate(['BMI: underweight', 'BMI: normal', 'BMI: overweight', 'BMI: obese']):
    # One marker trace per BMI class; marker size is 5 + the node's graph size.
    node_position = go.Scatter(
        visible=True,
        x=node_pos[enriched_centroides[fea], 0],
        y=node_pos[enriched_centroides[fea], 1],
        hoverinfo="text",
        marker=dict(color=color_tmp[fea],
                    size=[5 + sizes[_] for _ in enriched_centroides[fea]],
                    opacity=0.7),
        showlegend=True,
        name=fea,
        mode="markers")
    fig.append_trace(node_position, 1, 1)
fig.layout.width = 1200
fig.layout.height = 1000
fig.layout.font.size = 16
plotly.offline.iplot(fig)

#### Taxa visualisations

In [None]:
### Stratification of taxa with most enriched nodes
# Each node is labelled with its single highest-scoring taxon when that score is
# significant. Taxa that label at least `min_enriched_nodes` nodes are drawn in
# their own colour on top of the grey Mapper graph skeleton.

node_pos = graph.nodePos
nodes = graph.node
sizes = graph.size

# All graph edges as one line trace; the None entries break the line between edges.
xs = []
ys = []
for edge in graph.edges:
    xs += [node_pos[edge[0], 0],
           node_pos[edge[1], 0],
           None]
    ys += [node_pos[edge[0], 1],
           node_pos[edge[1], 1],
           None]
fig = plotly.subplots.make_subplots(1, 1)

node_line = go.Scatter(
    # ordination line
    visible=True,
    x=xs,
    y=ys,
    marker=dict(color="#8E9DA2",
                opacity=0.7),
    line=dict(width=1),
    showlegend=False,
    mode="lines")
fig.append_trace(node_line, 1, 1)

# Significance threshold on the SAFE-score scale:
# log10(alpha) / log10(smallest attainable permutation p-value).
# Uses n_iter (the permutation count passed to SAFE_batch) so the cut-off can
# never drift from the SAFE run that produced the scores.
safe_score_df = pd.DataFrame.from_dict(enriched_SAFE)
min_p_value = 1.0 / (n_iter + 1.0)
threshold = np.log10(0.05) / np.log10(min_p_value)

# For every node (row), find the taxon with the highest SAFE score. Keep its name
# if the score exceeds the threshold; otherwise mark the node with NaN.
scores = safe_score_df.values
best_col = np.argmax(scores, axis=1)
best_score = scores[np.arange(scores.shape[0]), best_col]
tmp = [safe_score_df.columns[c] if s > threshold else np.nan
       for c, s in zip(best_col, best_score)]
# Built once here and reused for every per-taxon node mask below.
tmp_labels = np.array(tmp)

# Count significant nodes per taxon, most frequent first. NaN placeholders
# (non-significant nodes) are excluded. Otherwise they form their own category,
# and once it passes the node cut below, building the legend name fails on a
# float label.
t = Counter(label for label in tmp if pd.notna(label))
t_sort = sorted(t.items(), key=lambda x: x[1], reverse=True)

# Colours assigned by rank; indexing wraps around if more taxa pass the cut than
# there are colours.
clr = {0: '#ff4040', 1: '#33d2ff', 2: '#fcba03', 3: '#011941', 4: '#41a941', 5: '#9e76c3', 6: '#0053C7', 7: '#C70039',
       8: '#6C3483', 9: '#E8DAEF', 10: '#1B5E20', 11: '#8D6E63', 12: '#FF8D15', 13: '#FFF59D', 14: '#66FF33',
       15: '#0000FF', 16: '#CC6600', 17: '#660066', 18: '#00897B'
       }

# Only taxa that are the significant top taxon of at least this many nodes are plotted.
min_enriched_nodes = 25

# visualisation: one marker trace per qualifying taxon (its nodes only).
for idx, fea in enumerate([_ for _, v in t_sort if v >= min_enriched_nodes]):
    mask = tmp_labels == fea
    node_position = go.Scatter(
        visible=True,
        x=node_pos[mask, 0],
        y=node_pos[mask, 1],
        hoverinfo="text",
        marker=dict(
            color=clr[idx % len(clr)],
            size=[5 + sizes[_] for _ in np.arange(node_pos.shape[0])[mask]],
            opacity=0.9),
        showlegend=True,
        name=fea + ' (%s)' % str(t[fea]),
        mode="markers")
    fig.append_trace(node_position, 1, 1)
fig.layout.width = 1300
fig.layout.height = 1000
fig.layout.font.size = 18
fig.layout.hovermode = 'closest'
plotly.offline.iplot(fig)

In [None]:
### Ranking the taxa according to network-wide enrichment

# Only the 40 most enriched taxa are kept; with all taxa the bars are barely
# visible. Ascending order puts the highest score at the top of the horizontal chart.
taxa_sort = safe_summary.sort_values(by='SAFE enriched score', ascending=True).iloc[-40:]
taxa_bar = px.bar(
    taxa_sort,
    x='SAFE enriched score',
    y=taxa_sort.index,
    color='SAFE enriched score',
    orientation='h',
    height=1000,
    text='SAFE enriched score',
)
# Print each score inside its bar, formatted with 3 significant digits.
taxa_bar.update_traces(texttemplate='%{text:.3s}', textposition='inside')
# Render once as a standalone HTML page opened in the browser, then inline.
plotly.offline.plot(taxa_bar, auto_open=True)
taxa_bar.show()

In [None]:
# Display the SAFE summary row for Prevotella.
safe_summary.loc["Prevotella", :]

In [None]:
# Display the SAFE summary row for Dorea.
safe_summary.loc["Dorea", :]

In [None]:
### Function to make heat maps for taxa
def taxa_visual(parameter):
    """Show one taxon on the Mapper graph as two side-by-side panels.

    Left panel: the raw per-sample abundances taken from ``X``
    (coloured with ``target_by='sample'``).
    Right panel: the node-level SAFE scores taken from ``enriched_SAFE_total``
    (coloured with ``target_by='node'``).
    Both panels are built with ``vis_progressX`` and displayed inline with
    ``plotly.offline.iplot``. Nothing is returned.

    Args:
        parameter: taxon name that must be a column of both ``X`` and
            ``enriched_SAFE_total``.
    """
    fig = plotly.subplots.make_subplots(
        1, 2,
        subplot_titles=[f'{parameter}<Br>original values', f'{parameter}<Br>SAFE values'],
    )

    raw_color = Color(X.loc[:, parameter], target_by='sample', dtype='numerical')
    safe_color = Color(enriched_SAFE_total.loc[:, parameter], target_by='node', dtype='numerical')
    raw_plot = vis_progressX(graph, color=raw_color, simple=True, mode='obj')
    safe_plot = vis_progressX(graph, color=safe_color, simple=True, mode='obj')

    # Left panel takes only the first two traces of the raw-value plot.
    fig.append_trace(raw_plot.data[0], 1, 1)
    # Move the left panel's colour bar to x=-0.02 (presumably to keep it clear of the right panel).
    raw_plot.data[1]['marker']['colorbar']['x'] = -0.02
    fig.append_trace(raw_plot.data[1], 1, 1)
    # Right panel takes every trace of the SAFE plot.
    for trace in safe_plot.data:
        fig.append_trace(trace, 1, 2)

    fig.layout.width = 1800
    fig.layout.height = 900
    # Hide tick labels and zero lines on both subplots' axes.
    for axis in (fig.layout.xaxis1, fig.layout.yaxis1, fig.layout.xaxis2, fig.layout.yaxis2):
        axis.showticklabels = False
        axis.zeroline = False
    fig.layout.hovermode = 'closest'
    plotly.offline.iplot(fig)

In [None]:
### Collecting the top 10 taxa and visualising them

# taxa_sort holds the 40 highest-scoring taxa in ascending order;
# re-sort them descending and keep the first 10.
taxa_desc = taxa_sort.sort_values(by='SAFE enriched score', ascending=False)
top = taxa_desc.head(10)
top_taxa = top.index.tolist()

# One side-by-side heat map per taxon, best first.
for taxa in top_taxa:
    taxa_visual(taxa)

In [None]:
### Visualising additional relevant taxa

# Hand-picked genera of interest, plotted in this order.
taxa = [
    "Paludibacter",
    "Bacteroides",
    "Lachnoclostridium",
    "Ruminoclostridium",
    "Prevotella",
    "Alistipes",
]
for taxon in taxa:
    taxa_visual(taxon)