# Settings

In [None]:
import math
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import rcParams
import plotly.graph_objects as go
from statannot import add_stat_annotation
from scipy.stats import pearsonr, mannwhitneyu

In [None]:
from IPython.core.interactiveshell import InteractiveShell
# Ask IPython to echo the value of every top-level expression statement in a cell,
# not only the last one.
InteractiveShell.ast_node_interactivity = 'all'

In [None]:
# Matplotlib defaults shared by every figure produced below.
rcParams.update({
    'font.size': 12,
    'figure.figsize': (9, 6),
    'savefig.dpi': 300,
    'savefig.format': 'png',
})

In [None]:
# Palette passed to seaborn as `palette=colors`; keys are the hue levels used in the plots.
colors = {
    # diet arms
    'MED diet': 'cornflowerblue',
    'PPT diet': 'orange',
    # sampling environments
    'Gut': 'lightpink',
    'Oral': 'mediumseagreen',
}

In [None]:
# Fixed orderings reused as hue/category orders and as loop orders throughout the notebook.
diet_order = ['PPT diet', 'MED diet']
time_point_order = ['Pre-intervention', 'Post-intervention']
env_order = ['Oral', 'Gut']
# Mediation analyses: every mediator level crossed with every outcome group,
# i.e. 'diet-species-glucose', 'diet-species-serum', 'diet-pathways-glucose', ...
mediation_order = [
    f'diet-{mediator}-{outcome}'
    for mediator in ('species', 'pathways', 'strains')
    for outcome in ('glucose', 'serum')
]

In [None]:
# Data types (also used as workbook sheet names), in the order that they are introduced in the paper.
data_types = [
    'diet',
    'gut species',
    'gut pathways',
    'oral species',
    'oral pathways',
    'metabolites',
    'cytokines',
]

In [None]:
# Significance threshold: corrected p-values below this value are treated as significant.
alpha = 0.05

In [None]:
def get_delta(df):
    """Return the post-intervention values minus the pre-intervention values.

    `df` must carry an index level named 'Time Point' holding the labels
    'Pre-intervention' and 'Post-intervention'. That level is dropped from
    the result and the two slices are aligned on the remaining index levels.
    """
    before, after = (
        df.xs(label, level='Time Point')
        for label in ('Pre-intervention', 'Post-intervention')
    )
    return after - before

# Data

In [None]:
# Load every supplementary workbook into `data`, a dict of DataFrames keyed by data-type / sheet label.
data = {}

# File 1: one sheet per data type, indexed by (Diet, Participant ID, Time Point).
s1 = 'Supplementary file 1 - Source data.xlsx'
index = ['Diet', 'Participant ID', 'Time Point']
for data_type in data_types:
    # The 'metabolites' sheet is read with header row 12; every other sheet uses its first row.
    header = 12 if data_type == 'metabolites' else 0
    data[data_type] = pd.read_excel(s1, sheet_name=data_type, header=header).set_index(index)

# File 2: per data type and diet, a table of statistical test results indexed by 'feature'.
s2 = 'Supplementary file 2 - Statistical tests.xlsx'
index = ['Diet', 'Participant ID', 'Time Point']  # NOTE(review): not used by the loop below
for data_type in data_types:
    for diet in diet_order:
        data[f'{data_type} {diet}'] = pd.read_excel(s2, sheet_name=f'{data_type} {diet}').set_index('feature')

# File 3: mediation results; strain analyses are split by environment, the others by diet.
s3 = 'Supplementary file 3 - Mediation analyses.xlsx'
index = ['path', 'x', 'm', 'y']  # NOTE(review): not used below; the index is taken positionally via index_col
for med in mediation_order:
    suffix = env_order if 'strains' in med else diet_order
    for suf in suffix:
        # Sheet names carry only the first word of the suffix (e.g. 'PPT', 'Oral').
        data[f'{med} {suf}'] = pd.read_excel(s3, sheet_name=f'{med} {suf.split(" ")[0]}', index_col=[0, 1, 2, 3])
        
# File 4: the sheet has a two-row header; the first header level is moved into the first data row
# before indexing (Figure 5 later reads that first row as the per-metabolite R2).
s4 = 'Supplementary file 4 - Metabolites prediction by the microbiome.xlsx'
index = ['Diet', 'Participant ID', 'Time Point']
data_type = 'metabolites prediction'
data[data_type] = pd.read_excel(s4, sheet_name=data_type, header=[0, 1]).T.reset_index(0).T.set_index(index)

# File 5: strain tables, one sheet per environment ('oral strains', 'gut strains').
s5 = 'Supplementary file 5 - Oral and gut microbial strains.xlsx'
index = ['Species', 'Participant ID', 'Diet']
for env in env_order:
    data_type = f'{env.lower()} strains'
    data[data_type] = pd.read_excel(s5, sheet_name=data_type).set_index(index)

# Figure 2 - Diet

In [None]:
# Flatten the index so that 'Diet' and 'Time Point' are plain columns usable by seaborn.
df = data['diet'].reset_index()

In [None]:
# Figure 2: % carbohydrates vs % lipids, one scatter panel per time point
# (panel 'a' = pre-intervention, panel 'b' = post-intervention).
for time_point in time_point_order:

    panel = 'a' if time_point == 'Pre-intervention' else 'b'
    subset = df[df['Time Point'] == time_point]

    plt.figure()
    ax = sns.scatterplot(
        data=subset,
        x='% Carbohydrates',
        y='% Lipids',
        hue='Diet',
        hue_order=diet_order,
        palette=colors,
        alpha=0.5,
        s=100,
    )
    plt.legend(title=False, loc='upper right', frameon=True)

    # The post-intervention panel is titled simply 'Intervention'.
    plt.title('Intervention' if time_point == 'Post-intervention' else time_point)
    plt.xlabel('% Carbohydrates in diet')
    plt.ylabel('% Lipids in diet')
    # Same axis limits on both panels so they are directly comparable.
    plt.xlim([9, 55])
    plt.ylim([15, 75])

    # Bold panel letter just above the top-left corner of the axes.
    plt.text(x=0, y=1.03, s=panel, transform=ax.transAxes, size=20, weight='bold')

    plt.savefig(f'Figure 2{panel}')

# Figure 3 - Features

In [None]:
def get_data(data_type):
    """Return the mean post-minus-pre change of the significant features, per diet.

    For each diet in `diet_order`, features whose corrected p-value is below
    `alpha` are selected ('p_FDR' for cytokines, 'p_bonferroni' otherwise),
    their change is computed with `get_delta` and averaged over participants.
    The result has one row per feature and one column per diet; a feature
    that is significant in only one diet gets 0 in the other column.
    """
    col = 'p_FDR' if data_type == 'cytokines' else 'p_bonferroni'

    mean_changes = []
    for diet in diet_order:
        stats = data[f'{data_type} {diet}']
        significant = stats.index[stats[col] < alpha]
        change = get_delta(data[data_type].loc[diet, significant])
        mean_changes.append(change.mean().to_frame(diet))

    df = mean_changes[0].copy()
    for other in mean_changes[1:]:
        df = df.join(other, how='outer').fillna(0)

    return df

In [None]:
# Build a {metabolite name: super pathway} lookup from the annotation rows of the 'metabolites' sheet.
# The sheet is read with its first row as header here, so the annotation rows stay in the table.
metabolite2super_pathway = pd.read_excel(s1, sheet_name='metabolites')
# Keep only the row holding the data column names ('Diet') and the 'SUPER_PATHWAY' row, one line per column.
metabolite2super_pathway = metabolite2super_pathway.set_index('PATHWAY_SORTORDER').loc[['Diet', 'SUPER_PATHWAY']].T
# Index by column name, label missing pathways, and drop the first two entries
# (presumably the remaining index columns -- verify against the workbook).
metabolite2super_pathway = metabolite2super_pathway.set_index('Diet').fillna('unknown').iloc[2:]
# Shorten one long pathway label so it fits in the figure legend.
metabolite2super_pathway = metabolite2super_pathway.replace('Partially Characterized Molecules', 'Partially Characterized')
metabolite2super_pathway = metabolite2super_pathway.to_dict()['SUPER_PATHWAY']

In [None]:
# Figure 3: for every data type with at least 10 significant features, draw concentric rings
# (one ring per diet, plus an optional annotation ring), one wedge per feature, coloured by
# the mean change returned by get_data.
for data_type in data_types:

    df = get_data(data_type)

    # Too few significant features to be worth a panel.
    if df.shape[0] < 10:
        continue

    # Outer radius grows with the (log) number of features; every ring is 1/8 of it wide.
    radius = np.log10(df.shape[0])
    step = radius/8

    # Diverging colour map centred on 0; diet changes use a fixed +-10 range.
    cmap = plt.get_cmap('bwr')
    norm = mpl.colors.Normalize(vmin=-df.abs().max().max() if data_type != 'diet' else -10, vmax=df.abs().max().max() if data_type != 'diet' else 10)

    # Add one annotation column per data type and sort the features by it.
    if data_type == 'diet':
        df['Level'] = 'Food category'
        df.loc[df.index.str.contains('%'), 'Level'] = 'Macro-nutrient'
        df.loc[df.index.str.contains(']'), 'Level'] = 'Micro-nutrient'
        # sort
        df = df.loc[[idx for idx in ['% Carbohydrates', '% Lipids', '% Proteins', '% Daily caloric target', '% Fibers', 
                                     '% Mono un-saturated fatty acids', '% Poly un-saturated fatty acids', '% Saturated fatty acids']
                         if idx in df.index] + 
                     sorted([idx for idx in df.index if '[' in idx]) + 
                     sorted([idx for idx in df.index if '%' not in idx and '[' not in idx])]
    elif 'species' in data_type:
        # Feature names are '|'-separated lineages: family is 6th from the end, species 4th, SGB id last.
        df['Family'] = df.index.str.split('|').str[-6].str.replace('f__', '').str.replace('_unclassified', '').str.replace('unknown', 'Unclassified')
        df.index = df.index.str.split('|').str[-4].str.replace('s__', '').str.replace('_', ' ').str.replace('unknown', 'Unclassified') + ' (' + df.index.str.split('|').str[-1].str.replace('sSGB__', 'SGB_') + ')'
        # NOTE(review): the substring test also matches genera that merely contain 'sp'
        # (e.g. 'Lachnospira') -- confirm this is intended.
        df.index = ['Unclassified ' + index.split(' ')[-1] if 'sp' in index or 'CAG' in index else index for index in df.index]
        df = df.sort_values('Family')
    elif data_type == 'metabolites':
        df['Super Pathway'] = df.index.map(metabolite2super_pathway).str.replace('unknown', 'Uncharacterized')
        df = df.sort_values('Super Pathway')
    elif 'pathways' in data_type:
        super_class = {  # curated manually from MetCyc
            'ARGSYNBSUB-PWY: L-arginine biosynthesis II (acetyl cycle)' : 'Amino Acid Biosynthesis',
            'ARGSYN-PWY: L-arginine biosynthesis I (via L-ornithine)': 'Amino Acid Biosynthesis',
            'GLUTORN-PWY: L-ornithine biosynthesis I': 'Amino Acid Biosynthesis',
            'PWY-6292: superpathway of L-cysteine biosynthesis (mammalian)': 'Amino Acid Biosynthesis',
            'PWY-6507: 4-deoxy-L-threo-hex-4-enopyranuronate degradation': 'Sugar Degradation',
            'GALACTUROCAT-PWY: D-galacturonate degradation I': 'Sugar Degradation',
            'PWY-7356: thiamine diphosphate salvage IV (yeast)': 'Cofactor&Vitamin Biosynthesis',#'Cofactor, Carrier, and Vitamin Biosynthesis',
            'GLUCUROCAT-PWY: superpathway of &beta;-D-glucuronosides degradation': 'Sugar Degradation',
            'PWY-7242: D-fructuronate degradation': 'Sugar Degradation',
            'PWY-7456: &beta;-(1,4)-mannan degradation': 'Polysaccharide Degradation',
            'PWY66-399: gluconeogenesis III': 'Sugar Biosynthesis',
            'PWY-7383: anaerobic energy metabolism (invertebrates, cytosol)': 'zOther',#'Fermentation',
            'PWY490-3: nitrate reduction VI (assimilatory)': 'zOther',#'Inorganic Nutrient Metabolism',
            'PWY-6305: superpathway of putrescine biosynthesis': 'zOther',#'Amide, Amidine, Amine, and Polyamine Biosynthesis',
            'P124-PWY: Bifidobacterium shunt': 'Sugar Degradation',
            'PWY-5941: glycogen degradation II': 'Polysaccharide Degradation',
            'PWY-7238: sucrose biosynthesis II': 'Sugar Biosynthesis',
            'PWY0-1296: purine ribonucleosides degradation': 'zOther',#'Nucleoside and Nucleotide Degradation',
            'PWY-6147: 6-hydroxymethyl-dihydropterin diphosphate biosynthesis I': 'zOther',#'Other Biosynthesis',
            'PWY-6549: L-glutamine biosynthesis III': 'zOther',#'Transport',
            'PWY-6703: preQ0 biosynthesis': 'zOther',#'Secondary Metabolite Biosynthesis',
            'PWY-6823: molybdopterin biosynthesis': 'Cofactor&Vitamin Biosynthesis',#'Cofactor, Carrier, and Vitamin Biosynthesis',
            'PWY-241: C4 photosynthetic carbon assimilation cycle, NADP-ME type': 'Carbon Fixation',#'Generation of Precursor Metabolites and Energy',
            'PWY-7115: C4 photosynthetic carbon assimilation cycle, NAD-ME type': 'Carbon Fixation',#'Generation of Precursor Metabolites and Energy',
            'PWY-7117: C4 photosynthetic carbon assimilation cycle, PEPCK type': 'Carbon Fixation'}#'Generation of Precursor Metabolites and Energy'}
        df['Super Class'] = df.index.map(super_class)
        # The 'z' prefix only makes 'Other' sort last; it is stripped right after sorting.
        df = df.sort_values('Super Class').replace('zOther', 'Other')

    # Three trailing pseudo-features: blank, ring label, blank. They leave a gap in each ring
    # where the ring is identified by its colour.
    df.loc[' '] = ''
    df.loc['Diet'] = df.columns
    df.loc['   '] = ''

    fig, ax = plt.subplots()

    # Draw the rings from the outside in, one per column of df.
    for i, diet in enumerate(df.columns):
        if diet in diet_order:
            # Diet ring: wedge colour encodes the mean change; the label wedge uses the diet colour.
            c = cmap(norm(df[diet_order[i]][:-3].astype(float))).tolist() + ['white', colors[diet], 'white']
        else:
            # Annotation ring: one pastel colour per category. cmap/norm are rebound here, which is
            # safe only because the annotation column is the last one.
            cmap = mpl.cm.Pastel1
            norm = mpl.colors.Normalize(vmin=0, vmax=8)
            c = {cat: cmap(norm(i)) for i, cat in enumerate(df[diet][:-3].unique())}
            c = [c[v] for v in df[diet][:-3]] + ['white', 'white', 'white']

        # Equal-sized wedges; feature names are written only around the outermost ring.
        pie, labels = ax.pie([1 for i in range(df.shape[0])], colors=c,
                             radius=radius-i*step, startangle=0,
                             labels=df[diet_order[i]].index if i == 0 else None, rotatelabels=True, labeldistance=1,
                             wedgeprops={'width':step, 'edgecolor':'w'})

        _ = plt.setp(pie, width=step, edgecolor='white')

    # Legend for the annotation ring (if any): one entry per category, taken at its first occurrence.
    if diet not in diet_order:
        index = df[diet][:-3].reset_index().index[~df[diet][:-3].duplicated()]
        # `index` holds integer positions, so look the labels up positionally with .iloc;
        # plain df[diet][i] relies on a fallback that newer pandas deprecates.
        legend = ax.legend([pie[i] for i in index], [df[diet].iloc[i] for i in index], 
                           title=diet, loc='center', bbox_to_anchor=(0., 0, 1, 1.), frameon=False)

    # Hand-tuned position (axes coordinates), panel letter and title for each plotted data type.
    extra = {'diet':        {'x': -0.40, 'y': 1.55, 'l': 'a', 'title': '     Diet'},
             'gut species': {'x': -0.78, 'y': 1.45, 'l': 'b', 'title': '     Gut species'},
             'gut pathways': {'x': -1.20, 'y': 2.20, 'l': 'c', 'title': '     Gut pathways'},
             'metabolites': {'x': -1.20, 'y': 2.20, 'l': 'd', 'title': '     Metabolites'}}

    _ = plt.text(x=extra[data_type]['x'], y=extra[data_type]['y'], s=extra[data_type]['l'], transform=ax.transAxes, size=20, weight='bold')   
    _ = plt.text(x=extra[data_type]['x'], y=extra[data_type]['y'], s=extra[data_type]['title'], transform=ax.transAxes)   

    plt.savefig(f'Figure 3{extra[data_type]["l"]}', bbox_inches='tight')

# Figure 4, S2, S3 - Mediation

In [None]:
# Figures 4, S2, S3: one Sankey diagram (x -> mediator -> y) per mediation analysis and suffix.
letters = 'abcd'
l = 0  # index of the next panel letter

for med in mediation_order:

    # Strain analyses are split by environment (Gut first), the others by diet.
    suffix = env_order[::-1] if 'strains' in med else diet_order
    for suf in suffix:

        df = data[f'{med} {suf}'].copy()
        # Keep an (x, m, y) triplet only if both its 'M ~ X' and 'Indirect' paths are significant.
        df = df.groupby(['x', 'm', 'y']).filter(lambda g: (g.loc[['M ~ X', 'Indirect'], 'p'] < alpha).all())
        if df.empty:
            continue
        df = df.loc['Indirect'].reset_index(['x', 'm', 'y'])

        # Node ids: x labels first, then mediators, then outcomes.
        label2int = {label: i for i, label in enumerate(list(df['x'].unique()) + list(df['m'].unique()) + list(df['y'].unique()))}

        # Both halves of a path (x -> m and m -> y) are coloured by the path's x node.
        x_colors = [f'rgba{mpl.cm.tab10(label2int[label]%10)}' for label in df['x']]

        link = dict(source=pd.concat([df['x'], df['m']]).map(label2int).values,
                    target=pd.concat([df['m'], df['y']]).map(label2int).values,
                    value=[1]*df.shape[0]*2,
                    color=x_colors + x_colors)

        # x nodes get their tab10 colour; mediator and outcome nodes are grey.
        node = dict(pad=35, thickness=20, 
                    label=list(label2int.keys()), 
                    color=[f'rgba{mpl.cm.tab10(label2int[label]%10)}' for label in df['x'].unique()]+['darkgrey']*(len(df['m'].unique())+len(df['y'].unique())))

        fig = go.Figure(go.Sankey(link=link, node=node))
        med_nice = med.replace('diet-', ('Diet - '+suf+' ') if 'strains' in med else (suf+' - Microbial ')).replace('-glucose', ' - Glycemic measurements').replace('-serum', ' - Metabolites and Cytokines')
        fig.update_layout(title=dict(text=f'{letters[l]}   {med_nice}', font=dict(color='Black')))
        figure = '4' if 'species' in med else ('S2' if 'pathways' in med else 'S3')

        # Saving is best-effort (static export needs an optional plotly backend), but a failure
        # must be reported: a bare string expression inside a loop is never displayed.
        out_path = f'Figure {figure}{letters[l]}.png'
        try:
            fig.write_image(out_path, width=1500)
        except Exception as err:
            print(f'FAILED SAVING {out_path}: {err}')

        # Advance the panel letter; wrap back to 'a' after the last panel of a figure.
        l = 0 if (l >= 2) and (suf != suffix[0]) else l + 1

# Figure 5 - Prediction

In [None]:
def get_data(r):
    """Return mean observed and predicted metabolite changes per diet, filtered by R2.

    The first row of `data['metabolites prediction']` is read as the
    per-metabolite R2; the remaining rows are the predicted values.
    With `r >= 0` only metabolites with R2 > r are kept; with `r < 0`
    only those with R2 < -r are kept. The result has one row per
    (Diet, variable) with columns 'observed change', 'predicted change'
    and 'R2'.
    """
    prediction = data['metabolites prediction']

    def mean_change_per_diet(changes, column):
        # Average over participants within each diet, then reshape to one row per (Diet, variable).
        long = changes.groupby('Diet').mean().melt(ignore_index=False)
        return long.set_index('variable', append=True).rename(columns={'value': column})

    preds = get_delta(prediction.iloc[1:].astype(float))
    # Restrict the measured changes to the participants and metabolites that were predicted.
    actuals = get_delta(data['metabolites']).loc[preds.index, preds.columns]

    df = mean_change_per_diet(actuals, 'observed change').join(
        mean_change_per_diet(preds, 'predicted change'))

    r2 = prediction.iloc[0].to_frame('R2').astype(float).dropna()
    df = df.join(r2, on='variable')

    keep = df['R2'] < -r if r < 0 else df['R2'] > r
    return df.loc[keep].reset_index()

In [None]:
# Figure 5: mean predicted vs mean observed metabolite change,
# panel 'b' for poorly predicted (r2 < 0 selects R2 < 0.05) and panel 'a' for well predicted (R2 > 0.05).
for r2 in [-0.05, 0.05]:

    df = get_data(r2)

    plt.figure()
    # Marker size encodes R2 only in the well-predicted panel.
    size = 'R2' if r2 > 0 else None
    ax = sns.scatterplot(x='observed change', y='predicted change', hue='Diet', hue_order=diet_order, data=df, s=100, size=size, sizes=(50 ,300), palette=colors, alpha=0.5)

    legend = plt.legend(title=False, loc='lower right', frameon=True)
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and 3.9 removed the old name;
    # support both.
    legend_handles = getattr(legend, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = legend.legendHandles

    # Append the per-diet Pearson correlation to the diet legend entries.
    for handle in legend_handles:
        diet = handle.get_label()
        if diet in diet_order:
            r, p = pearsonr(df.loc[df['Diet'] == diet, 'observed change'], df.loc[df['Diet'] == diet, 'predicted change'])
            handle.set_label(f'{diet}\nr={r:.2f}, ' + (f'p<1e{math.floor(math.log10(p))+1}' if p < 0.01 else f'p={p:.2f}'))
    # With the size legend present there are more than 5 handles; drop the first one in that case.
    handles = legend_handles[1:] if len(legend_handles) > 5 else legend_handles
    plt.legend(handles=handles, title=False, loc='upper left', frameon=True)

    plt.xlabel('Mean observed change')
    plt.ylabel('Mean predicted change')

    # Annotate the points with the 3 largest absolute observed and predicted changes
    # (plus the 3 highest R2 in the well-predicted panel).
    df_abs = df.copy()
    df_abs[['observed change', 'predicted change']] = df_abs[['observed change', 'predicted change']].abs()
    index = df_abs.sort_values('observed change').tail(3).index.union(df_abs.sort_values('predicted change').tail(3).index)
    if r2 > 0:
        index = index.union(df.sort_values('R2').tail(3).index)
        # NOTE(review): 117 (and 856/857 below) are hard-coded row labels of df, presumably
        # specific to the published dataset -- confirm before reusing with other data.
        index = list(index) + [117]
    for i in index:
        s, x, y = df.loc[i, ['variable', 'observed change', 'predicted change']]
        # Nudge two overlapping labels apart.
        if i == 856:
            va = 'top'
        elif i == 857:
            va = 'bottom'
        else:
            va = 'center'
        _ = plt.text(x=x, y=y, s=s, va=va)

    plt.text(x=0, y=1.03, s='a' if r2 > 0 else 'b', transform=ax.transAxes, size=20, weight='bold')   

    plt.title(f'{"Poorly" if r2 < 0 else "Well"} predicted metabolites (R2{"<" if r2 < 0 else ">"}{abs(r2)})')
    plt.savefig(f'Figure 5{"a" if r2 > 0 else "b"}')

# Figure 6 - Strains

In [None]:
# The Figure 6 plots are split by sampling environment ('env' column, ordered as env_order).
hue = 'env'
hue_order = env_order

In [None]:
def get_data(b):
    """Summarise strain replacement per group of `b` and environment.

    `b` is the grouping key ('Participant ID' or 'Species'). Gut and oral
    strain tables are stacked with an 'env' column, and for every
    (b, env) group the function reports:

    - 'n': number of comparisons in the group,
    - '%': percentage of them that are strain replacements,
    - 'n_bin': 'n' binned in steps of 18 (by participant) or 30 (otherwise).

    The raw (n, %) tuple is kept in the 'Strain replacement' column.
    Returns a DataFrame with the group keys as regular columns.
    """
    df = pd.concat([data['gut strains'].assign(env='Gut'), data['oral strains'].assign(env='Oral')])
    df['Strain replacement'] = df['Strain replacement'].map({True: 1, False: 0})
    df = df.groupby([b, hue])['Strain replacement'].apply(lambda g: (g.shape[0], g.mean()*100)).to_frame()

    # Split the (n, %) tuples into numeric columns. Element access via .str[i] works on tuples;
    # unpacking the .str accessor itself is no longer supported by pandas >= 2.0.
    summary = df.iloc[:, 0]
    df['n'] = summary.str[0].astype(int)
    df['%'] = summary.str[1].astype(float)

    # Values of n above the last bin edge fall outside every bin and get NaN.
    df['n_bin'] = pd.cut(df['n'], bins=np.arange(0, df['n'].max(), 18 if b == 'Participant ID' else 30))
    df = df.reset_index()

    return df

In [None]:
# Figure 6a/b: distribution of the % of strain replacements per participant (a) and per species (b),
# gut vs oral.
for by in ['Participant ID', 'Species']:

    df = get_data(by)

    plt.figure()
    ax = sns.histplot(df, x='%', hue=hue, hue_order=hue_order, palette=colors, kde=False, element='step', binwidth=10, alpha=0.3, common_norm=False, stat='percent')

    # Rebuild the legend with the handles in reversed order and without a title.
    legend = ax.legend(labels=hue_order)
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and 3.9 removed the old name;
    # support both.
    legend_handles = getattr(legend, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = legend.legendHandles
    ax.legend(labels=hue_order, handles=legend_handles[::-1], loc='upper right', frameon=True)

    plt.xlabel(f'% Strain replacements per {by.replace("Participant ID", "participant").lower()}')
    plt.ylabel(f'% {"Species" if by == "Participant ID" else "Participants"}')
    plt.xlim([0, 100])
    plt.ylim([0, 100])

    # Two-sided Mann-Whitney U test between the two environments, shown under the legend.
    _, p = mannwhitneyu(x=df.loc[df[hue] == hue_order[0], '%'].tolist(),
                        y=df.loc[df[hue] == hue_order[1], '%'].tolist(),
                        use_continuity=True, alternative='two-sided', axis=0, method='auto')
    plt.text(x=0.865, y=0.8, s='p'+r'$\leq$'+f'{p:.0e}', transform=ax.transAxes)

    plt.text(x=0, y=1.03, s='a' if by == 'Participant ID' else 'b', transform=ax.transAxes, size=20, weight='bold')

    plt.savefig(f'Figure 6{"a" if by == "Participant ID" else "b"}')

In [None]:
# Figure 6c/d: % of strain replacements against the number of available comparisons (binned),
# per participant (c) and per species (d), gut vs oral.
for by in ['Participant ID', 'Species']:

    df = get_data(by)

    plt.figure()
    ax = sns.boxplot(x='n_bin', y='%', hue=hue, hue_order=hue_order, data=df, palette=colors, boxprops={'alpha': 0.7}, fliersize=0)
    # Compare the two environments within every bin (Mann-Whitney, no multiple-testing correction).
    box_pairs = [((b, hue_order[0]), (b, hue_order[1])) for b in df['n_bin'].unique().dropna()]
    ax, test_results = add_stat_annotation(ax, x='n_bin', y='%', hue=hue, data=df, box_pairs=box_pairs, test='Mann-Whitney', text_format='simple', comparisons_correction=None)
    sns.stripplot(x='n_bin', y='%', hue=hue, hue_order=hue_order, data=df, palette=colors, dodge=True, legend=False, color='lightgrey', s=8, alpha=0.3, ax=ax)

    # Pearson correlation between n and % within each environment, appended to the legend labels.
    corr = df.groupby(hue).apply(lambda g: pearsonr(g['n'], g['%']))
    legend = ax.legend(title=False, loc='upper right', frameon=True)
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and 3.9 removed the old name;
    # support both.
    legend_handles = getattr(legend, 'legend_handles', None)
    if legend_handles is None:
        legend_handles = legend.legendHandles
    for handle in legend_handles:
        h = handle.get_label()
        r, p = corr.loc[h]
        handle.set_label(f'{h}\nr={r:.2f}, ' + (f'p<1e{math.floor(math.log10(p))+1}' if p < 0.01 else f'p={p:.2f}'))
    plt.legend(handles=legend_handles, title=False, loc='upper right', frameon=True)

    # Show the bins as 'low-high' instead of the interval notation '(low, high]'.
    new_labels = [t.get_text().replace('(', '').replace(']', '').replace(', ', '-') for t in ax.get_xticklabels()]
    ax.set_xticklabels(new_labels)

    plt.xlabel(f'Number of {"species" if by == "Participant ID" else "participants"} available for comparison\n{"Environments richness" if by == "Participant ID" else "Species prevalence"}')
    plt.ylabel(f'% Strain replacements per {by.replace("Participant ID", "participant").lower()}')

    plt.text(x=0, y=1.03, s='c' if by == 'Participant ID' else 'd', transform=ax.transAxes, size=20, weight='bold')

    plt.savefig(f'Figure 6{"c" if by == "Participant ID" else "d"}')

# Figure S1 - Features

In [None]:
def get_data(data_type, diet):
    """Return the raw values and p-values of the significant features of one diet.

    Features are significant when their corrected p-value ('p_FDR' for
    cytokines, 'p_bonferroni' otherwise) is below `alpha`.

    Returns a tuple `(df, ps)`: `df` holds the measurements of `diet`
    restricted to the significant features (one column per feature) and
    `ps` the matching corrected p-values. For species data both are
    relabelled as 'Species name (SGB_id)'.
    """
    stats = data[f'{data_type} {diet}']
    col = 'p_FDR' if data_type == 'cytokines' else 'p_bonferroni'
    significant = stats[col] < alpha

    df = data[data_type].loc[diet][stats.index[significant]]
    ps = stats.loc[significant, col]

    if 'species' in data_type:

        def short_name(lineages):
            # '|'-separated lineage: the species is 4th from the end, the SGB id is last.
            parts = lineages.str.split('|')
            species = parts.str[-4].str.replace('s__', '').str.replace('_', ' ').str.replace('unknown', 'Unclassified')
            sgb = parts.str[-1].str.replace('sSGB__', 'SGB_')
            return species + ' (' + sgb + ')'

        df.columns = short_name(df.columns)
        ps.index = short_name(ps.index)

    return df, ps

In [None]:
# Figure S1: for every data type and diet, a grid (5 per row) of pre vs post box plots,
# one panel per significant feature.
letters = 'abcdefghijklmnopqrstuvwxyz'
l = 0  # index of the next panel letter

for data_type in data_types:
    for diet in diet_order:

        df, ps = get_data(data_type, diet)
        if ps.empty:
            continue
        # Descending sort puts 'Pre-intervention' rows first, so the boxes are drawn pre then post.
        df = df.reset_index().sort_values('Time Point', ascending=False)

        # Number of non-missing pre-intervention values per feature.
        ns = (~df[df['Time Point'] == 'Pre-intervention'].isna()).sum()

        # squeeze=False keeps `axes` two-dimensional even when there is a single row.
        n_rows = int(np.ceil(len(ps.index)/5))
        fig, axes = plt.subplots(nrows=n_rows, ncols=5, figsize=(15, n_rows*5), squeeze=False)
        fig.subplots_adjust(wspace=0.8, hspace=0.4)

        for i, col in enumerate(ps.index):
            ax = axes[i//5, i%5]
            _ = sns.boxplot(x='Time Point', y=col, data=df, color='white', boxprops={'alpha': 0.7}, fliersize=0, ax=ax)
            _ = sns.stripplot(x='Time Point', y=col, data=df, dodge=True, legend=False, color='lightgrey', s=8, alpha=0.3, ax=ax)
            _ = ax.set_title(f'n={ns.loc[col]}\np={ps.loc[col]:.2e}', loc='left', ha='left')
            # On odd rows the y-label gets a trailing newline, which shifts it away from the axis.
            if (i//5)%2 == 1:
                _ = ax.set_ylabel(f'{col}\n')
            _ = ax.set_xticklabels([tp.split('-')[0] for tp in time_point_order])
            _ = ax.set_xlabel('')

        # Hide only the unused trailing axes of the last row. When the number of panels is an
        # exact multiple of 5 there are none, and the whole last row must stay visible.
        for unused in axes.flat[len(ps.index):]:
            unused.set_visible(False)

        # Panel letter and title are anchored to the top-left axes.
        ax = axes[0, 0]
        _ = ax.text(x=0, y=1.3, s=letters[l], transform=ax.transAxes, size=20, weight='bold')
        _ = ax.text(x=0, y=1.3, s=f'     {data_type[0].upper()}{data_type[1:]} - {diet}', transform=ax.transAxes, size=20)

        plt.savefig(f'Figure S1{letters[l]}', bbox_inches='tight')

        l = l+1