# Settings

In [None]:
import math
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rcParams
from statannot import add_stat_annotation
from scipy.stats import pearsonr, mannwhitneyu

In [None]:
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression statement in a cell,
# not only the last one (IPython's default is 'last_expr').
InteractiveShell.ast_node_interactivity = 'all'

In [None]:
# Global Matplotlib defaults shared by every figure in this notebook:
# text size, canvas size in inches, and the resolution/format used by savefig.
rcParams.update({
    'font.size': 12,
    'figure.figsize': (9, 6),
    'savefig.dpi': 200,
    'savefig.format': 'png',
})

In [None]:
# Fixed colour per hue level, passed as `palette=` to the seaborn plots below.
# The first two keys are the diet labels, the last two the sampled environments.
colors = {
    'MED diet': 'cornflowerblue', 
    'PPT diet': 'orange',
    'Gut': 'lightpink',
    'Oral': 'mediumseagreen'}

In [None]:
# Canonical orderings: used as `hue_order=` in the plots and as the
# iteration order of the per-diet / per-time-point / per-environment loops.
diet_order = ['PPT diet', 'MED diet']
time_point_order = ['Pre-intervention', 'Post-intervention']
env_order = ['Oral', 'Gut']

In [None]:
# Names of the measured data types; they double as Excel sheet names when the
# supplementary files are read and as the x-axis categories of Figure 3.
data_types = ['gut species', 'oral species', 'metabolites', 'cytokines']

In [None]:
# Significance threshold applied to the corrected p-values when counting
# significant features for Figure 3 (not related to the `alpha=` transparency
# keyword used in the plotting calls).
alpha = 0.05

# Data

In [None]:
# Every table used by the figures is loaded once into this dict, keyed by name.
data = {}

index = ['Diet', 'Participant ID', 'Time Point']

# Supplementary file 1: one sheet per data type (plus the 'diet' sheet),
# each indexed by diet, participant and time point.
s1 = 'Supplementary file 1 - Diet microbiome metabolites and cytokines data.xlsx'
for data_type in ['diet']+data_types:
    # The metabolites sheet takes its column names from the 13th row (header=12);
    # all other sheets use the first row.
    header = 12 if data_type == 'metabolites' else 0
    data[data_type] = pd.read_excel(s1, sheet_name=data_type, header=header).set_index(index)

# Supplementary file 2: one sheet of statistical test results per
# (data type, diet) pair, indexed by the tested feature.
s2 = 'Supplementary file 2 - Microbiome metabolites and cytokines statistical tests.xlsx'
for data_type in data_types:
    for diet in diet_order:
        data[f'{data_type} {diet}'] = pd.read_excel(s2, sheet_name=f'{data_type} {diet}').set_index('feature')

# Supplementary file 3: the sheet has a two-row header. Transposing, moving the
# first header level into a column and transposing back turns that first header
# row into the first data row, leaving the second header row as column names.
# NOTE(review): Figure 4 reads that first row as per-metabolite R2 - confirm
# against the spreadsheet layout.
s3 = 'Supplementary file 3 - Serum metabolites predicted by the gut microbiome composition.xlsx'
data_type = 'metabolites prediction'
data[data_type] = pd.read_excel(s3, sheet_name=data_type, header=[0, 1]).T.reset_index(0).T.set_index(index)

# From here on the tables are indexed per species and participant.
index = ['Species', 'Participant ID']

# Supplementary file 4: one strain-level sheet per environment
# ('oral strains', 'gut strains').
s4 = 'Supplementary file 4 - Oral and gut microbial strains data.xlsx'
for env in env_order:
    data_type = f'{env.lower()} strains'
    data[data_type] = pd.read_excel(s4, sheet_name=data_type).set_index(index)

In [None]:
# Sanity check: print the name and (rows, columns) shape of every loaded table.
for key, table in data.items():
    print(key, table.shape)

# Figure 2

In [None]:
# Move the index levels back into ordinary columns so that 'Diet' and
# 'Time Point' can be used directly for filtering and as the plot hue.
df = data['diet'].reset_index()

In [None]:
# Figure 2: % lipids against % carbohydrates in the diet, coloured by diet.
# One panel per time point: 'A' for 'Pre-intervention', 'B' otherwise.
for time_point in time_point_order:
    panel = 'A' if time_point == 'Pre-intervention' else 'B'
    subset = df[df['Time Point'] == time_point]

    plt.figure()
    ax = sns.scatterplot(
        x='% Carbohydrates', y='% Lipids',
        hue='Diet', hue_order=diet_order,
        data=subset, palette=colors, alpha=0.5, s=100)

    plt.legend(title=False, loc='upper right', frameon=True)

    # The post-intervention panel is titled 'Intervention'; any other time
    # point keeps its own name as the title.
    if time_point == 'Post-intervention':
        plt.title('Intervention')
    else:
        plt.title(time_point)
    plt.xlabel('% Carbohydrates in diet')
    plt.ylabel('% Lipids in diet')
    # Identical axis limits on both panels.
    plt.xlim([9, 55])
    plt.ylim([15, 75])

    # Bold panel letter just above the top-left corner (axes coordinates).
    plt.text(x=0, y=1.03, s=panel, transform=ax.transAxes, size=20, weight='bold')

    plt.savefig(f'Figure 2{panel}')

# Figure 3

In [None]:
# Figure 3 summary table: one row per (data type, diet) with the number of
# tested features and the number whose corrected p-value is below `alpha`.
# Rows are collected first and the frame is built once, so the count columns
# get a numeric dtype instead of being grown row by row on an empty frame.
rows = []
for data_type in data_types:
    for diet in diet_order:
        key = f'{data_type} {diet}'
        # Cytokines are judged on the 'p_FDR' column, every other data type on
        # 'p_bonferroni'.
        col = 'p_FDR' if data_type == 'cytokines' else 'p_bonferroni'
        rows.append({
            # Capitalise only the first letter for display on the x axis.
            'data_type': data_type[0].upper()+data_type[1:],
            'diet': diet,
            'n_tested': data[key].shape[0],
            'n_sig': int((data[key][col] < alpha).sum()),
        })

df = pd.DataFrame(rows, columns=['data_type', 'diet', 'n_tested', 'n_sig'])

# Percentage of tested features that are significant.
df['%'] = 100*df['n_sig']/df['n_tested']

In [None]:
# Figure 3: percentage of significant features per data type, one bar per diet.
sns.barplot(x='data_type', y='%', hue='diet', hue_order=diet_order, data=df, palette=colors)

plt.legend(title=None, loc='upper right')

plt.xlabel('')
plt.ylabel('% Significant features')

# Annotate each bar with the number of features tested. Tick locations and
# labels are paired with zip, not by indexing the label list with the tick
# location, so float tick locations cannot break the lookup.
ticks, labels = plt.xticks()
for t, label in zip(ticks, labels):
    data_type = label.get_text()
    # The bar of the first diet is annotated 0.2 to the left of the tick, the
    # bar of the second diet 0.2 to the right.
    for offset, diet in zip((-0.2, 0.2), diet_order):
        text_df = df.loc[(df['data_type'] == data_type) & (df['diet'] == diet)].iloc[0]
        plt.text(x=t+offset, y=text_df['%'], s=f'n={text_df["n_tested"]:,}', ha='center')

plt.savefig('Figure 3')

# Figure 4

In [None]:
def get_delta(df):
    """Return the post-intervention minus pre-intervention values of `df`.

    `df` must have an index level named 'Time Point' containing the labels
    'Pre-intervention' and 'Post-intervention'. That level is dropped from the
    result; rows are aligned on the remaining index levels, so an entry that
    is missing one of the two time points comes out as NaN.
    """
    baseline, followup = (
        df.xs(label, level='Time Point')
        for label in ('Pre-intervention', 'Post-intervention')
    )
    return followup - baseline

In [None]:
def get_data(r):
    """Build the Figure 4 table of mean observed vs. predicted change.

    The first row of data['metabolites prediction'] is read as the R2 value
    of each column and the remaining rows as predicted values. Observed and
    predicted pre-to-post changes are averaged per diet and joined per column.

    r: R2 cut-off. A negative `r` keeps the rows with R2 < -r, any other
       value keeps the rows with R2 > r.

    Returns a frame with the former index levels ('Diet', 'variable') as
    columns plus 'observed change', 'predicted change' and 'R2'.
    """
    prediction = data['metabolites prediction']

    preds = get_delta(prediction.iloc[1:].astype(float))
    # Restrict the observed changes to the rows and columns that have a prediction.
    actuals = get_delta(data['metabolites']).loc[preds.index, preds.columns]

    def mean_change_per_diet(changes, value_name):
        # Long format: one row per (Diet, variable) holding the per-diet mean.
        long = changes.groupby('Diet').mean().melt(ignore_index=False)
        long = long.set_index('variable', append=True)
        return long.rename(columns={'value': value_name})

    df = mean_change_per_diet(actuals, 'observed change').join(
        mean_change_per_diet(preds, 'predicted change'))

    r2_per_variable = prediction.iloc[0].T.to_frame('R2').astype(float).dropna()
    df = df.join(r2_per_variable, on='variable')

    if r < 0:
        df = df.loc[df['R2'] < -r]
    else:
        df = df.loc[df['R2'] > r]

    return df.reset_index()

In [None]:
# Figure 4: mean predicted vs. mean observed change, one point per
# (diet, metabolite). The positive threshold gives panel A (R2 above it), the
# negative threshold panel B (R2 below its absolute value).
for r2 in [-0.05, 0.05]:
    
    df = get_data(r2)
    
    plt.figure()
    # Marker size encodes R2 only in the panel built from the positive threshold.
    size = 'R2' if r2 > 0 else None
    ax = sns.scatterplot(x='observed change', y='predicted change', hue='Diet', hue_order=diet_order, data=df, s=100, size=size, sizes=(50 ,300), palette=colors, alpha=0.5)

    # A first legend is created only to get at its handles, which are relabelled
    # with the per-diet Pearson correlation and then reused in the final legend.
    l = plt.legend(title=False, loc='lower right', frameon=True)
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and later
    # removed the old name; support both.
    legend_handles = l.legend_handles if hasattr(l, 'legend_handles') else l.legendHandles
    for handle in legend_handles:
        diet = handle.get_label()
        if diet in diet_order:
            per_diet = df.loc[df['Diet'] == diet]
            r, p = pearsonr(per_diet['observed change'], per_diet['predicted change'])
            if p == 0:
                # p underflowed to zero; log10(0) would raise.
                p_text = 'p<1e-300'
            elif p < 0.01:
                # Report small p-values as an upper bound at the next power of ten.
                p_text = f'p<1e{math.floor(math.log10(p))+1}'
            else:
                p_text = f'p={p:.2f}'
            handle.set_label(f'{diet}\nr={r:.2f}, {p_text}')
    # With more than five entries the first one is dropped.
    # NOTE(review): presumably the hue subtitle seaborn adds when both hue and
    # size are mapped - confirm.
    handles = legend_handles[1:] if len(legend_handles) > 5 else legend_handles
    plt.legend(handles=handles, title=False, loc='upper left', frameon=True)
        
    plt.xlabel('Mean observed change per metabolite')
    plt.ylabel('Mean predicted change per metabolite')

    # Bold panel letter just above the top-left corner (axes coordinates).
    plt.text(x=0, y=1.03, s='A' if r2 > 0 else 'B', transform=ax.transAxes, size=20, weight='bold')   
    
    plt.title(f'{"Poorly" if r2 < 0 else "Well"} predicted metabolites (R2{"<" if r2 < 0 else ">"}{abs(r2)})')
    plt.savefig(f'Figure 4{"A" if r2 > 0 else "B"}')

# Figure 5

In [None]:
def get_data(b):
    """Summarise strain replacement per `b` group and environment (Figure 5).

    NOTE: this rebinds the name `get_data` that the Figure 4 cells also use.

    b: name to group by, 'Participant ID' or 'Species'.

    Returns one row per (b, env) with the columns:
      'Strain replacement' - the raw (n, %) tuple, kept for compatibility
      'n'     - number of rows in the group
      '%'     - mean of the 0/1 strain-replacement flag, times 100
      'n_bin' - `n` cut into bins of width 18 ('Participant ID') or 30 (otherwise)
    """
    df = pd.concat([data['gut strains'].assign(env='Gut'), data['oral strains'].assign(env='Oral')])
    df['Strain replacement'] = df['Strain replacement'].map({True: 1, False: 0})
    df = df.groupby([b, 'env'])['Strain replacement'].apply(lambda g: (g.shape[0], g.mean()*100)).to_frame()

    # Unpack the (n, %) tuples explicitly. Iterating the `.str` accessor, which
    # the tuple-unpacking assignment relied on, was removed in pandas 2.0. Plain
    # lists also give the two columns a numeric dtype.
    pairs = df.iloc[:, 0]
    df['n'] = [n for n, _ in pairs]
    df['%'] = [pct for _, pct in pairs]

    # np.arange excludes its stop value, so any `n` above the last bin edge
    # falls outside every bin and gets NaN in 'n_bin'.
    df['n_bin'] = pd.cut(df['n'], bins=np.arange(0, df['n'].max(), 18 if b == 'Participant ID' else 30))
    df = df.reset_index()
    
    return df

In [None]:
# Figure 5A/B: distribution of the strain replacement percentage per
# participant (A) and per species (B), oral vs. gut.
for by in ['Participant ID', 'Species']:

    df = get_data(by)

    plt.figure()
    # common_norm=False: each environment's histogram sums to 100% on its own.
    ax = sns.histplot(df, x='%', hue='env', hue_order=env_order, palette=colors, kde=False, element='step', binwidth=10, alpha=0.3, common_norm=False, stat='percent')

    # The legend is built twice: the first call only provides the handles, which
    # are then passed in reversed order together with the labels.
    # NOTE(review): presumably the reversal compensates for the order in which
    # histplot draws the hue levels - confirm colours match labels.
    l = ax.legend(labels=env_order)
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and later
    # removed the old name; support both.
    legend_handles = l.legend_handles if hasattr(l, 'legend_handles') else l.legendHandles
    ax.legend(labels=env_order, handles=legend_handles[::-1], loc='upper right', frameon=True)
    
    plt.xlabel(f'% Strain replacements per {by.replace("Participant ID", "participant").lower()}')
    # NOTE(review): every row of df is one `by` group, so the histogram appears
    # to count participants when by == 'Participant ID' and species otherwise;
    # confirm this label is not inverted.
    plt.ylabel(f'% {"Species" if by == "Participant ID" else "Participants"}')
    plt.xlim([0, 100])
    plt.ylim([0, 100])

    # Two-sided Mann-Whitney U test between the two environments' percentages.
    _, p = mannwhitneyu(x=df.loc[df['env'] == env_order[0], '%'].tolist(),
                        y=df.loc[df['env'] == env_order[1], '%'].tolist(),
                        use_continuity=True, alternative='two-sided', axis=0, method='auto')
    plt.text(x=0.865, y=0.8, s='p'+r'$\leq$'+f'{p:.0e}', transform=ax.transAxes)

    # Bold panel letter just above the top-left corner (axes coordinates).
    plt.text(x=0, y=1.03, s='A' if by == 'Participant ID' else 'B', transform=ax.transAxes, size=20, weight='bold')

    plt.savefig(f'Figure 5{"A" if by == "Participant ID" else "B"}')

In [None]:
# Figure 5C/D: strain replacement percentage against the number of available
# comparisons (binned), oral vs. gut, per participant (C) and per species (D).
for by in ['Participant ID', 'Species']:
    
    df = get_data(by)
    
    plt.figure()
    # fliersize=0 hides the box plot's outlier markers; the individual points
    # are drawn by the strip plot below.
    ax = sns.boxplot(x='n_bin', y='%', hue='env', hue_order=env_order, data=df, palette=colors, boxprops={'alpha': 0.7}, fliersize=0)
    # Compare the two environments within each non-missing bin, using
    # Mann-Whitney tests without multiple-comparison correction.
    box_pairs = [((b, env_order[0]), (b, env_order[1])) for b in df['n_bin'].unique().dropna()]
    ax, test_results = add_stat_annotation(ax, x='n_bin', y='%', hue='env', data=df, box_pairs=box_pairs, test='Mann-Whitney', text_format='simple', comparisons_correction=None)
    sns.stripplot(x='n_bin', y='%', hue='env', hue_order=env_order, data=df, palette=colors, dodge=True, legend=False, color='lightgrey', s=8, alpha=0.3, ax=ax)

    # Pearson correlation between the unbinned count and the percentage, per
    # environment; shown in the legend labels.
    corr = df.groupby('env').apply(lambda env: pearsonr(env['n'], env['%']))
    l = ax.legend(title=False, loc='upper right', frameon=True)
    # Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and later
    # removed the old name; support both.
    legend_handles = l.legend_handles if hasattr(l, 'legend_handles') else l.legendHandles
    for handle in legend_handles:
        env = handle.get_label()
        r, p = corr.loc[env]
        if p == 0:
            # p underflowed to zero; log10(0) would raise.
            p_text = 'p<1e-300'
        elif p < 0.01:
            # Report small p-values as an upper bound at the next power of ten.
            p_text = f'p<1e{math.floor(math.log10(p))+1}'
        else:
            p_text = f'p={p:.2f}'
        handle.set_label(f'{env}\nr={r:.2f}, {p_text}')
    plt.legend(handles=legend_handles, title=False, loc='upper right', frameon=True)

    # Turn interval tick labels such as '(0, 18]' into '0-18'.
    new_labels = [t.get_text().replace('(', '').replace(']', '').replace(', ', '-') for t in ax.get_xticklabels()]
    ax.set_xticklabels(new_labels)
    
    plt.xlabel(f'Number of {"species" if by == "Participant ID" else "participants"} available for comparison')
    plt.ylabel(f'% Strain replacements per {by.replace("Participant ID", "participant").lower()}')
    
    # Bold panel letter just above the top-left corner (axes coordinates).
    plt.text(x=0, y=1.03, s='C' if by == 'Participant ID' else 'D', transform=ax.transAxes, size=20, weight='bold')
    
    plt.savefig(f'Figure 5{"C" if by == "Participant ID" else "D"}')