# The full analysis pipeline using TDA
### This takes the fully filtered dataset as input

In [None]:
### Imports ### 
import pandas as pd
import os
import numpy as np

from sklearn.preprocessing import MinMaxScaler
from sklearn.cluster import DBSCAN
from tmap.tda import mapper, Filter
from tmap.tda.cover import Cover
from tmap.tda.plot import Color, vis_progressX
from tmap.tda.metric import Metric
from tmap.tda.utils import optimize_dbscan_eps
from scipy.spatial.distance import squareform,pdist
from tmap.test import load_data
from tmap.netx.SAFE import SAFE_batch, get_SAFE_summary, get_significant_nodes

import plotly.graph_objs as go
from collections import Counter
pd.options.display.expand_frame_repr = False
pd.options.display.max_rows = 1000
import scipy.stats as scs
import warnings
import plotly
import plotly.express as px

In [None]:
### Data import and parsing ###
# Filtered abundance table; the first CSV column becomes the row index
# (presumably one row per sample, one column per taxon -- confirm).
X = pd.read_csv("curat_otu_filtered.csv", index_col=0)
# Taxon names are used as plot labels later, so swap underscores for spaces.
X = X.rename(columns=lambda name: name.replace("_", " "))
X

In [None]:
# Host metadata, indexed by the first CSV column; disease/country as strings.
metadata = pd.read_csv(
    "curat_metadata_filtered.csv",
    index_col=0,
    dtype={"disease": "string", "country": "string"},
)
# Remove columns not used downstream, then give the remaining ones readable names.
# NOTE(review): set_axis renames by position, so this list must follow the
# CSV's column order exactly -- confirm whenever the export changes.
metadata = metadata.drop(columns=["age", "disease", "country", "hosp_by_pop"])
readable_names = [
    '% over 70', 'GDP', 'HDI',
    'female', 'male',
    'normal', 'obese', 'overweight', 'underweight',
    'high severity', 'low severity',
    'adult', 'child', 'schoolage', 'senior',
]
metadata = metadata.set_axis(readable_names, axis=1)
metadata

In [None]:
"""TDA ANALYSIS"""

# Pairwise Bray-Curtis dissimilarities between the rows of X: pdist returns
# the condensed vector form, squareform expands it to a full square matrix.
condensed_bc = pdist(X, metric='braycurtis')
dm = squareform(condensed_bc)

# Step 1: Initiate the Mapper algorithm
tm = mapper.Mapper(verbose=1)

# Step 2: Projection into low dimensional space.
# The lens is a two-component MDS computed on the precomputed matrix dm.
metric = Metric(metric="precomputed")
mds_lens = Filter.MDS(components=[0, 1], metric=metric, random_state=100)
lens = [mds_lens]
projected_X = tm.filter(dm, lens=lens)

# Step 3: Covering, clustering & mapping.
# DBSCAN's eps comes from tmap's optimize_dbscan_eps (threshold=95) on X.
eps = optimize_dbscan_eps(X, threshold=95)
clusterer = DBSCAN(eps=eps, min_samples=7)

In [None]:
# Rescale the lens coordinates to [0, 1] before building the cover.
scaled_projection = MinMaxScaler().fit_transform(projected_X)
cover = Cover(projected_data=scaled_projection, resolution=100, overlap=1.1)

# Cluster inside each cover element and assemble the Mapper graph.
graph = tm.map(data=X, cover=cover, clusterer=clusterer)
print(graph.info())
graph.show()

In [None]:
"""SAFE ANALYSIS"""

# Setting number of permutations (shared by every SAFE call in this notebook)
n_iter = 4000

## Enriching the metadata
safe_scores = SAFE_batch(
    graph,
    metadata=metadata,
    n_iter=n_iter,
    nr_threshold=0.5,
    _mode='both',
)
# _mode='both' yields enrichment and decline scores under separate keys.
enriched_SAFE_metadata = safe_scores['enrich']
declined_SAFE_metadata = safe_scores['decline']

In [None]:
## Enriching the taxa
# Same SAFE settings as for the metadata, with the taxa table as "metadata".
safe_scores = SAFE_batch(
    graph,
    metadata=X,
    n_iter=n_iter,
    nr_threshold=0.5,
    _mode='both',
)
enriched_SAFE = safe_scores['enrich']
declined_SAFE = safe_scores['decline']

In [None]:
## Obtaining and summarising the enrichment data (scores and enriched nodes, etc.)
# Arguments shared by both summaries.
summary_kwargs = dict(graph=graph, n_iter=n_iter, p_value=0.05)

# enrichment summary for metadata:
safe_summary_metadata = get_SAFE_summary(metadata=metadata,
                                         safe_scores=enriched_SAFE_metadata,
                                         **summary_kwargs)

# enrichment summary for taxa:
safe_summary = get_SAFE_summary(metadata=X,
                                safe_scores=enriched_SAFE,
                                **summary_kwargs)

# Node-level SAFE scores side by side: taxa columns first, then metadata.
enriched_SAFE_total = pd.concat([enriched_SAFE, enriched_SAFE_metadata], axis=1)

# Stacked summary: metadata rows first, then taxa.
safe_summary_total = pd.concat([safe_summary_metadata, safe_summary], axis=0)

In [None]:
"""VISUALIATIONS"""
# Shared plotting routine behind metadata_visual and taxa_visual (the two
# functions were previously line-for-line copies differing only in the table).
def _plot_original_vs_safe(values, fea1):
    """Plot one feature on the Mapper graph: original values (left) vs SAFE scores (right).

    values: table holding column ``fea1`` whose values are passed to tmap's
        Color with target_by='sample' (``metadata`` or ``X`` in this notebook).
    fea1: column name, also looked up in the global ``enriched_SAFE_total``.
    """
    from plotly.subplots import make_subplots

    fig = make_subplots(
        1, 2, subplot_titles=['{}<Br>original values'.format(fea1),
                              '{}<Br>SAFE values'.format(fea1)])

    color_original = Color(values.loc[:, fea1], target_by='sample', dtype='numerical')
    color_safe = Color(enriched_SAFE_total.loc[:, fea1], target_by='node', dtype='numerical')
    original = vis_progressX(graph, color=color_original, simple=True, mode='obj')
    safe = vis_progressX(graph, color=color_safe, simple=True, mode='obj')

    # Left panel: the first two traces of the original-values figure; the
    # second trace's colour bar is moved to the far left so it does not sit
    # between the two panels.
    fig.add_trace(original.data[0], row=1, col=1)
    original.data[1]['marker']['colorbar']['x'] = -0.02
    fig.add_trace(original.data[1], row=1, col=1)
    # Right panel: every trace of the SAFE figure.
    for trace in safe.data:
        fig.add_trace(trace, row=1, col=2)

    fig.update_layout(width=1800, height=900, hovermode='closest')
    # Hide tick labels and zero lines on both panels' axes.
    fig.update_xaxes(showticklabels=False, zeroline=False)
    fig.update_yaxes(showticklabels=False, zeroline=False)
    plotly.offline.iplot(fig)


# Function to make heat maps of metadata
def metadata_visual(parameter):
    """Show host-metadata column ``parameter``: raw values vs SAFE enrichment."""
    _plot_original_vs_safe(metadata, parameter)


# Function to make heat maps for taxa
def taxa_visual(parameter):
    """Show taxon column ``parameter`` of X: raw abundances vs SAFE enrichment."""
    _plot_original_vs_safe(X, parameter)

In [None]:
### PD groups comparison
# Overlay the nodes significantly enriched for low / high PD severity on the
# Mapper graph layout.

node_pos = graph.nodePos
sizes = graph.size

# n_iter must equal the number of permutations used in SAFE_batch: the p-value
# cut-off is converted into a SAFE-score threshold via 1 / (n_iter + 1).
# It was hard-coded to 5000 while SAFE ran with n_iter permutations.
enriched_centroides = get_significant_nodes(graph=graph,
                                            safe_scores=enriched_SAFE_total,
                                            pvalue=0.05,
                                            n_iter=n_iter)

# Edge segments; the None after each edge breaks the polyline between edges.
xs, ys = [], []
for source, target in graph.edges:
    xs.extend([node_pos[source, 0], node_pos[target, 0], None])
    ys.extend([node_pos[source, 1], node_pos[target, 1], None])
fig = plotly.subplots.make_subplots(1, 1)

node_line = go.Scatter(
    # ordination line
    visible=True,
    x=xs,
    y=ys,
    marker=dict(color="#8e9da2",
                opacity=0.7),
    line=dict(width=1),
    showlegend=False,
    mode="lines")
fig.add_trace(node_line, row=1, col=1)

color_tmp = {'low severity': '#0053C7',
             'high severity': '#C70039'}

for fea in ['low severity', 'high severity']:
    significant = enriched_centroides[fea]
    node_position = go.Scatter(
        # node position
        visible=True,
        x=node_pos[significant, 0],
        y=node_pos[significant, 1],
        hoverinfo="text",
        marker=dict(color=color_tmp[fea],
                    size=[5 + sizes[node] for node in significant],
                    opacity=0.7),
        showlegend=True,
        name=fea,
        mode="markers")
    fig.add_trace(node_position, row=1, col=1)
fig.layout.width = 1000
fig.layout.height = 1000
fig.layout.font.size = 15
plotly.offline.iplot(fig)

In [None]:
## Ranking the host metadata according to network-wide enrichment
score_col = 'SAFE enriched score'
# Highest network-wide enrichment first.
meta_sort = safe_summary_metadata.sort_values(by=score_col, ascending=False)
meta_bar = px.bar(meta_sort, x=meta_sort.index, y=score_col, color=score_col)
meta_bar.show()

In [None]:
# Original-vs-SAFE maps for these three metadata columns, in this order.
for _meta_feature in ("% over 70", "GDP", "HDI"):
    metadata_visual(_meta_feature)

In [None]:
# gender dummy grouped
# Overlay the nodes significantly enriched for each sex on the Mapper graph
# layout.

node_pos = graph.nodePos
sizes = graph.size

# n_iter must equal the number of permutations used in SAFE_batch: the p-value
# cut-off is converted into a SAFE-score threshold via 1 / (n_iter + 1).
# It was hard-coded to 5000 while SAFE ran with n_iter permutations.
enriched_centroides = get_significant_nodes(graph=graph,
                                            safe_scores=enriched_SAFE_total,
                                            pvalue=0.05,
                                            n_iter=n_iter)

# Edge segments; the None after each edge breaks the polyline between edges.
xs, ys = [], []
for source, target in graph.edges:
    xs.extend([node_pos[source, 0], node_pos[target, 0], None])
    ys.extend([node_pos[source, 1], node_pos[target, 1], None])
# plotly.tools.make_subplots is deprecated in favour of plotly.subplots.
fig = plotly.subplots.make_subplots(1, 1)

node_line = go.Scatter(
    # ordination line
    visible=True,
    x=xs,
    y=ys,
    marker=dict(color="#8E9DA2",
                opacity=0.7),
    line=dict(width=1),
    showlegend=False,
    mode="lines")
fig.add_trace(node_line, row=1, col=1)

color_tmp = {'female': '#ff4040',
             'male': '#9e76c3'}

for fea in ['female', 'male']:
    significant = enriched_centroides[fea]
    node_position = go.Scatter(
        # node position
        visible=True,
        x=node_pos[significant, 0],
        y=node_pos[significant, 1],
        hoverinfo="text",
        marker=dict(color=color_tmp[fea],
                    size=[5 + sizes[node] for node in significant],
                    opacity=0.7),
        showlegend=True,
        name=fea,
        mode="markers")
    fig.add_trace(node_position, row=1, col=1)
fig.layout.width = 1200
fig.layout.height = 1000
fig.layout.font.size = 16
plotly.offline.iplot(fig)

In [None]:
# age dummy grouped
# Overlay the nodes significantly enriched for each age group on the Mapper
# graph layout.

node_pos = graph.nodePos
sizes = graph.size

# n_iter must equal the number of permutations used in SAFE_batch: the p-value
# cut-off is converted into a SAFE-score threshold via 1 / (n_iter + 1).
# It was hard-coded to 5000 while SAFE ran with n_iter permutations.
enriched_centroides = get_significant_nodes(graph=graph,
                                            safe_scores=enriched_SAFE_total,
                                            pvalue=0.05,
                                            n_iter=n_iter)

# Edge segments; the None after each edge breaks the polyline between edges.
xs, ys = [], []
for source, target in graph.edges:
    xs.extend([node_pos[source, 0], node_pos[target, 0], None])
    ys.extend([node_pos[source, 1], node_pos[target, 1], None])
# plotly.tools.make_subplots is deprecated in favour of plotly.subplots.
fig = plotly.subplots.make_subplots(1, 1)

node_line = go.Scatter(
    # ordination line
    visible=True,
    x=xs,
    y=ys,
    marker=dict(color="#8E9DA2",
                opacity=0.7),
    line=dict(width=1),
    showlegend=False,
    mode="lines")
fig.add_trace(node_line, row=1, col=1)

color_tmp = {'child': '#ff4040',
             'schoolage': '#33d2ff',
             'adult': '#fcba03',
             'senior': '#011941'}

for fea in ['child', 'schoolage', 'adult', 'senior']:
    significant = enriched_centroides[fea]
    node_position = go.Scatter(
        # node position
        visible=True,
        x=node_pos[significant, 0],
        y=node_pos[significant, 1],
        hoverinfo="text",
        marker=dict(color=color_tmp[fea],
                    size=[5 + sizes[node] for node in significant],
                    opacity=0.7),
        showlegend=True,
        name=fea,
        mode="markers")
    fig.add_trace(node_position, row=1, col=1)
fig.layout.width = 1200
fig.layout.height = 1000
fig.layout.font.size = 16
plotly.offline.iplot(fig)

In [None]:
# BMI dummy grouped
# Overlay the nodes significantly enriched for each BMI category on the
# Mapper graph layout.

node_pos = graph.nodePos
sizes = graph.size

# n_iter must equal the number of permutations used in SAFE_batch: the p-value
# cut-off is converted into a SAFE-score threshold via 1 / (n_iter + 1).
# It was hard-coded to 5000 while SAFE ran with n_iter permutations.
enriched_centroides = get_significant_nodes(graph=graph,
                                            safe_scores=enriched_SAFE_total,
                                            pvalue=0.05,
                                            n_iter=n_iter)

# Edge segments; the None after each edge breaks the polyline between edges.
xs, ys = [], []
for source, target in graph.edges:
    xs.extend([node_pos[source, 0], node_pos[target, 0], None])
    ys.extend([node_pos[source, 1], node_pos[target, 1], None])
# plotly.tools.make_subplots is deprecated in favour of plotly.subplots.
fig = plotly.subplots.make_subplots(1, 1)

node_line = go.Scatter(
    # ordination line
    visible=True,
    x=xs,
    y=ys,
    marker=dict(color="#8E9DA2",
                opacity=0.7),
    line=dict(width=1),
    showlegend=False,
    mode="lines")
fig.add_trace(node_line, row=1, col=1)

color_tmp = {'underweight': '#ff4040',
             'normal': '#33d2ff',
             'overweight': '#fcba03',
             'obese': '#011941'}

for fea in ['underweight', 'normal', 'overweight', 'obese']:
    significant = enriched_centroides[fea]
    node_position = go.Scatter(
        # node position
        visible=True,
        x=node_pos[significant, 0],
        y=node_pos[significant, 1],
        hoverinfo="text",
        marker=dict(color=color_tmp[fea],
                    size=[5 + sizes[node] for node in significant],
                    opacity=0.7),
        showlegend=True,
        name=fea,
        mode="markers")
    fig.add_trace(node_position, row=1, col=1)
fig.layout.width = 1200
fig.layout.height = 1000
fig.layout.font.size = 16
plotly.offline.iplot(fig)

In [None]:
### TAXA ###

## Stratification of taxa with most enriched nodes
# Label each node with the taxon that has the highest SAFE score on it (if that
# score passes the significance threshold) and plot every taxon that labels at
# least s_nod nodes.
node_pos = graph.nodePos
sizes = graph.size

# Edge segments; the None after each edge breaks the polyline between edges.
xs, ys = [], []
for source, target in graph.edges:
    xs.extend([node_pos[source, 0], node_pos[target, 0], None])
    ys.extend([node_pos[source, 1], node_pos[target, 1], None])
fig = plotly.subplots.make_subplots(1, 1)

node_line = go.Scatter(
    # ordination line
    visible=True,
    x=xs,
    y=ys,
    marker=dict(color="#8E9DA2",
                opacity=0.7),
    line=dict(width=1),
    showlegend=False,
    mode="lines")
fig.add_trace(node_line, row=1, col=1)

# Setting the threshold for what constitutes a significantly enriched node:
# the SAFE score corresponding to p = 0.05, given that the smallest attainable
# permutation p-value is 1 / (n_iter + 1). This must use the same n_iter as
# SAFE_batch (it was hard-coded to 5000).
safe_score_df = pd.DataFrame.from_dict(enriched_SAFE)
min_p_value = 1.0 / (n_iter + 1.0)
threshold = np.log10(0.05) / np.log10(min_p_value)

# Per node (row): the top-scoring taxon if its score exceeds the threshold,
# NaN otherwise; then count nodes per label, most frequent first.
best_col = np.argmax(safe_score_df.values, axis=1)
tmp = [safe_score_df.columns[col] if safe_score_df.iloc[row, col] > threshold else np.nan
       for row, col in enumerate(best_col)]
t = Counter(tmp)
t_sort = sorted(t.items(), key=lambda x: x[1], reverse=True)

# specifying the colours for the taxa (reused cyclically if more taxa qualify,
# instead of raising KeyError past the 19th taxon)
clr = {0: '#ff4040', 1: '#33d2ff', 2: '#fcba03', 3: '#011941', 4: '#41a941', 5: '#9e76c3', 6: '#0053C7', 7: '#C70039',
       8: '#6C3483', 9: '#E8DAEF', 10: '#1B5E20', 11: '#8D6E63', 12: '#FF8D15', 13: '#FFF59D', 14: '#66FF33',
       15: '#0000FF', 16: '#CC6600', 17: '#660066', 18: '#00897B'
       }

# visualisation
# Include every taxon that labels at least s_nod nodes. The NaN entry (nodes
# without a significant taxon) is skipped: it is not a taxon, and
# `fea + ' (...)'` would raise TypeError on it.
s_nod = 25
labels = np.array(tmp, dtype=object)
node_ids = np.arange(node_pos.shape[0])
enriched_taxa = [taxon for taxon, n in t_sort if pd.notna(taxon) and n >= s_nod]
for idx, fea in enumerate(enriched_taxa):
    on_taxon = labels == fea
    node_position = go.Scatter(
        # node position: significant nodes labelled with this taxon
        visible=True,
        x=node_pos[on_taxon, 0],
        y=node_pos[on_taxon, 1],
        hoverinfo="text",
        marker=dict(color=clr[idx % len(clr)],
                    size=[5 + sizes[_] for _ in node_ids[on_taxon]],
                    opacity=0.9),
        showlegend=True,
        name=fea + ' (%s)' % str(t[fea]),
        mode="markers")
    fig.add_trace(node_position, row=1, col=1)
fig.layout.width = 1300
fig.layout.height = 1000
fig.layout.font.size = 18
fig.layout.hovermode = 'closest'
plotly.offline.iplot(fig)

In [None]:
## Ranking the taxa according to network-wide enrichment
score_col = 'SAFE enriched score'
# Ascending sort, then keep the last 40 rows: the 40 highest-scoring taxa.
taxa_sort = safe_summary.sort_values(by=score_col, ascending=True).iloc[-40:]
taxa_bar = px.bar(
    taxa_sort,
    x=score_col,
    y=taxa_sort.index,
    color=score_col,
    orientation='h',
    height=1000,
    text=score_col,
)
taxa_bar.update_traces(texttemplate='%{text:.3s}', textposition='inside')
# Also writes the figure to an HTML file and opens it in a browser.
plotly.offline.plot(taxa_bar, auto_open=True)
taxa_bar.show()

In [None]:
## Collecting the top 10 data
# taxa_sort holds the 40 highest-scoring taxa; keep the n_top best of them.
# (iloc[:11] previously returned 11 taxa, one more than the intended top 10.)
n_top = 10
taxa_desc = taxa_sort.sort_values(by='SAFE enriched score', ascending=False)
top = taxa_desc.iloc[:n_top]
top_taxa = top.index.values
for taxa in top_taxa:
    taxa_visual(taxa)

In [None]:
top_taxa = [*top_taxa]  # materialise as a plain list so more taxa can be appended

In [None]:
# adding in the taxa with most significantly enriched nodes to get all taxa of interest
# (taxa labelling at least s_nod nodes in t_sort). The NaN key in t_sort marks
# nodes without a significant taxon; appending it would later make
# values[voi] fail with a KeyError, so it is skipped.
for fea, n_nodes in t_sort:
    if pd.notna(fea) and n_nodes >= s_nod and fea not in top_taxa:
        top_taxa.append(fea)

top_taxa

In [None]:
"""CO-ENRICHMENT ANALYSIS"""

from tmap.netx.SAFE import get_significant_nodes
# Explicit import instead of `import *`: a star import can silently overwrite
# notebook globals with whatever names the module happens to define.
from tmap.netx.coenrichment_analysis import pairwise_coenrichment

# With r_neighbor=True, get_significant_nodes returns two objects: the
# significant centroids and a second collection (kept as enriched_nodes).
enriched_centroides, enriched_nodes = get_significant_nodes(graph=graph,
                                                            safe_scores=enriched_SAFE_total,
                                                            n_iter=n_iter,
                                                            pvalue=0.05,
                                                            r_neighbor=True
                                                            )
corrected_fe_dis = pairwise_coenrichment(graph,
                                         safe_scores=enriched_SAFE_total,
                                         _pre_cal_enriched=enriched_centroides)

corrected_fe_dis.to_csv('raw_coenrichment_GMrepo_TDA_PD.csv')


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Reload the raw co-enrichment matrix (first column is the row index).
coenrichment_csv = 'raw_coenrichment_GMrepo_TDA_PD.csv'
values = pd.read_csv(coenrichment_csv, index_col=0)
values

In [None]:
# Metadata columns shown in the co-enrichment heatmap, in display order.
metadata_of_interest = [
    'high severity', 'low severity',
    '% over 70', 'GDP', 'HDI',
    'female', 'male',
    'normal', 'obese', 'overweight', 'underweight',
    'adult', 'child', 'schoolage', 'senior',
]

In [None]:
### constructing a heatmap co-enrichment matrix with raw data
# variables of interest: metadata first, then the selected taxa
voi = [*metadata_of_interest, *top_taxa]

# Square sub-matrix with rows and columns both in voi order.
df_voi = values.loc[voi, voi]

In [None]:
# Hide the upper triangle including the diagonal so only the lower triangle
# is drawn. NumPy removed the `np.bool` alias in 1.24; the builtin bool is
# the equivalent dtype.
mask = np.triu(np.ones(df_voi.shape, dtype=bool))

# Set the font scale before drawing so it applies to this figure (calling it
# after sns.heatmap only affected figures created later).
sns.set(font_scale=2)
fig, ax = plt.subplots(figsize=(30, 22))
sns.heatmap(df_voi, mask=mask, cmap='GnBu_r', ax=ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90, horizontalalignment='center')
# plt.savefig('heatmap_voi.png')  # uncomment to save the figure
plt.show()

In [None]:
# Binarise the raw co-enrichment matrix: cells at or below the threshold become
# 1 (co-enriched pair), all others 0. corrected_fe_dis holds the same matrix in
# memory if the CSV round-trip is not wanted.

# Exclude the self-pairs by setting only the diagonal to 1. The previous
# df.replace([0], 1) turned every 0 into 1, so an off-diagonal 0 (the lowest
# possible value) could never pass the <= threshold test below.
# NOTE(review): assumes rows and columns list the features in the same order.
diag_free = values.to_numpy(dtype=float, copy=True)
np.fill_diagonal(diag_free, 1)
df = pd.DataFrame(diag_free, index=values.index, columns=values.columns)

# threshold and binarise
# NOTE(review): np.percentile takes q on a 0-100 scale, so 0.5 keeps the lowest
# 0.5 % of all cells -- confirm this (rather than the median, q=50) is intended.
threshold = np.percentile(df.to_numpy(), 0.5)
df = (df <= threshold).astype(int)  # vectorised; applymap is deprecated

df.to_csv('binarised_coenrichment_GMrepo_TDA_PD.csv')
print(threshold)

In [None]:
## print all significant combinations
from itertools import combinations

# Walk each unordered pair of variables of interest once and report those
# flagged 1 in the binarised matrix.
combos = []
for pair in combinations(voi, 2):  # 2 for pairs, 3 for triplets, etc
    row_label, col_label = pair
    if df.loc[row_label, col_label] != 1:
        continue
    print(pair)
    combos.append(pair)

print(len(combos))

In [None]:
## all co-enrichments of PD groups (over every feature in df, not only voi)
# The loop variable is named `feature` rather than `metadata` so it no longer
# overwrites the global metadata DataFrame used by metadata_visual and the
# SAFE cells when they are re-run.
for feature in df.columns:
    if df.loc["high severity", feature] == 1:
        print("High + ", feature)

for feature in df.columns:
    if df.loc["low severity", feature] == 1:
        print("Low + ", feature)
