Baseline model that uses genus presence/absence or abundance as features to predict age.

Prior knowledge
1. Diversity: increases in young adults, plateaus around age 40, and decreases in older adults.
2. Older adults show a decrease in beneficial bacteria and an increase in harmful bacteria.




De La Cuesta-Zuluaga, J., Kelley, S., Chen, Y., Escobar, J., Mueller, N., Ley, R., McDonald, D., Huang, S., Swafford, A., Knight, R., & Thackray, V. (2019). Age- and Sex-Dependent Patterns of Gut Microbial Diversity in Human Adults. mSystems, 4. https://doi.org/10.1128/mSystems.00261-19.

In [0]:
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split


from xgboost import XGBRegressor
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns



## Load data

In [0]:
# Per-sample ages: tab-separated, no header row, first column used as the row index
# (the same index is later used to align the genus feature table).
df_age = pd.read_csv("../data/age.csv", header=None, index_col=0, sep='\t')
# Keep adults only (age >= 18). Rows with a missing age are dropped as well,
# because NaN >= 18 evaluates to False.
df_age = df_age[df_age[1] >= 18]  # TODO: remove after test
# Take the age column explicitly so y holds exactly one value per sample.
# Flattening the whole frame would interleave any additional columns into y.
y = df_age[1].to_numpy()

y.shape

In [0]:
# Decade-of-age bins, used only to stratify the train/test split.
# All bins >= 9 (ages 90 and above, presumably rare) are merged into the 80s bin.
# Any bin with a single member makes stratified train_test_split raise.
y_class = np.minimum(y // 10, 8)

In [0]:
# Split sample positions (not the data itself) once, stratified by age decade.
# The same train/test rows can then be reused for every feature matrix X built later.
sample_positions = range(len(y))
train_idx, test_idx = train_test_split(
    sample_positions,
    test_size=0.2,
    stratify=y_class,
    random_state=42,
)
test_idx[:10]

age >= 18 : [2140, 2182, 2987, 2103, 5, 3370, 3208, 2281, 41, 4598]

[4436, 2292, 4448, 4903, 2378, 842, 2625, 3097, 4898, 1911]

In [0]:
# X = pd.read_csv("../data/processed_log_drop08_scaled.csv", header=0, index_col=0, sep='\t').loc[df_age.index, :].to_numpy()
# Genus-level feature table, reindexed by df_age's index so that
# row i of X lines up with y[i] (raises KeyError if a sample is missing).
genus_table = pd.read_csv("../data/processed_genus_log_drop08_scaled.csv", header=0, index_col=0, sep='\t')
df_genus = genus_table.loc[df_age.index, :]
X = df_genus.to_numpy()

In [0]:
# from scipy.stats import spearmanr
# correlations = df_genus.apply(lambda x: spearmanr(x, y)[0])
# correlations_sorted = correlations.dropna().sort_values()

# plt.figure(figsize=(8, 12))
# correlations_sorted.plot(kind='barh')
# plt.xlabel('Spearman Correlation with Age')
# plt.ylabel('Genus')
# plt.title('Genus by their Correlation to Age')
# plt.yticks(fontsize=6);

In [0]:
# Presence/absence baseline: a genus counts as "present" when its value is > 0.
# NOTE(review): X is read from a "..._scaled" table. Confirm that absent genera
# are exactly 0 after scaling; otherwise `> 0` is not a true presence test.
X_presence = X > 0
X_train = X_presence[train_idx]
X_test = X_presence[test_idx]
y_train = y[train_idx]
y_test = y[test_idx]

xgb_model = XGBRegressor(n_estimators=200, learning_rate=0.05, max_depth=4)
xgb_model.fit(X_train, y_train)
y_pred = xgb_model.predict(X_test)

# Held-out goodness of fit, shown in the corner of the plot.
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

# Predicted vs. actual age, with the y = x reference line.
fig, ax = plt.subplots(figsize=(6, 4.5))
ax.scatter(y_test, y_pred, alpha=0.5)
ax.set_xlabel('Actual Age')
ax.set_ylabel('Predicted Age')
ax.set_title('Actual Age vs Predicted Age')
age_lo, age_hi = y_test.min(), y_test.max()
ax.plot([age_lo, age_hi], [age_lo, age_hi], 'k--', lw=2)

metrics_box = dict(boxstyle='round,pad=0.3', edgecolor='black', facecolor='white')
ax.text(0.05, 0.95, f'R2: {r2:.2f}\nMSE: {mse:.2f}', transform=ax.transAxes,
        fontsize=12, verticalalignment='top', bbox=metrics_box)

fig.tight_layout()
plt.show()

In [0]:
import shap

# Explain the presence/absence model. The training matrix is the background data,
# and the genus column names label the features.
background = X_train
genus_names = df_genus.columns
explainer = shap.Explainer(xgb_model, background, feature_names=genus_names)
shap_values = explainer(X_test)
shap.summary_plot(shap_values, features=X_test)

In [0]:
# correlations_sorted.tail(20)

In [0]:
# correlations_sorted.head(20)

In [0]:
# Sparsity of the training features. After the presence/absence cell, X_train is
# the boolean presence matrix, so `== 0` marks absent entries.
p_zero = (X_train == 0).mean()
p_nonzero = 1 - p_zero
print(p_zero)              # fraction of zero entries
print(p_zero / p_nonzero)  # odds of an entry being zero
print(1 / p_nonzero)       # reciprocal of the non-zero fraction
print(1 / p_zero)          # reciprocal of the zero fraction

In [0]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
# import umap

def plot_latent_space(latent, method='tsne', colorby=None, cbar_label=None):
    """Reduce ``latent`` to 2-D and draw it as a scatter plot on the current axes.

    Parameters
    ----------
    latent : array-like
        Sample-by-feature matrix passed to the reducer's ``fit_transform``.
    method : {'tsne', 'pca', 'umap'}
        Dimensionality-reduction method. 'tsne' and 'umap' use
        ``random_state=42`` so repeated calls give the same layout.
    colorby : array-like, optional
        One value per point, mapped through the 'viridis' colormap.
        When None, points are not colour-mapped and no colour bar is drawn.
    cbar_label : str, optional
        Label for the colour bar.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    ImportError
        If ``method='umap'`` and the ``umap`` package is not installed.
    """
    if method == 'tsne':
        reducer = TSNE(n_components=2, random_state=42)
    elif method == 'pca':
        reducer = PCA(n_components=2)
    elif method == 'umap':
        # umap is an optional dependency and is not imported at notebook level,
        # so import it only when this branch is actually requested.
        import umap
        reducer = umap.UMAP(n_components=2, random_state=42)
    else:
        raise ValueError("Invalid method")

    reduced = reducer.fit_transform(latent)

    scatter = plt.scatter(reduced[:, 0], reduced[:, 1],
                          c=colorby, cmap='viridis',
                          s=10, alpha=0.6, edgecolors='w', linewidths=0.5)
    # A colour bar only makes sense when the points are actually colour-mapped.
    if colorby is not None:
        plt.colorbar(scatter, label=cbar_label)
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')

In [0]:
# 2-D projections of the test features, coloured by how many features
# are non-zero in each sample.
non_zero_counts = (X_test != 0).sum(axis=1)

plt.figure(figsize=(15, 4))
for panel, (reducer_name, panel_title) in enumerate([('pca', 'PCA'), ('tsne', 'TSNE')], start=1):
    plt.subplot(1, 2, panel)
    plot_latent_space(X_test, method=reducer_name, colorby=non_zero_counts,
                      cbar_label='Number of non-zero Features')
    plt.title(panel_title)
plt.tight_layout()

In [0]:
# The same 2-D projections of the test features, coloured by each sample's age.
# The colour bar is labelled 'Age' because the colour values are y_test.
plt.figure(figsize=(15, 4))

plt.subplot(1, 2, 1)
plot_latent_space(X_test, method='pca', colorby=y_test, cbar_label='Age')
plt.title('PCA')
plt.subplot(1, 2, 2)
plot_latent_space(X_test, method='tsne', colorby=y_test, cbar_label='Age')
plt.title('TSNE')
plt.tight_layout()