# Random Forest Analyses

In [7]:
# Standard library
import os
import warnings
import logging
from itertools import combinations
import string

# Scientific computing
import numpy as np
import pandas as pd
from numpy import array
import scipy
import scipy.stats as ss
from scipy import interp
from scipy.stats import wilcoxon, ttest_rel

# Visualization
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

# scikit-bio
from skbio.stats.distance import permanova

# BIOM format
import biom
from biom import load_table

# Scikit-learn
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, confusion_matrix, classification_report,
    roc_curve, auc, RocCurveDisplay
)
from sklearn.model_selection import GroupKFold, StratifiedKFold
from sklearn.preprocessing import label_binarize

# For confusion matrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.metrics import classification_report, ConfusionMatrixDisplay

# Set overall styling for plots: seaborn "paper" context with font sizes scaled
# by 1.5, and the "ticks" axes style. Configured once here, before any figure is drawn.
sns.set_context("paper", font_scale=1.5)
sns.set_style("ticks")

## Random Forest Classification Tasks

In [8]:
# Read in the feature table at ASV level (BIOM format)
biom_path = '..//Data/Tables/Absolute_Abundance_Tables/209766_filtered_feature_table.biom'
# biom_path = '../Data/Tables/Absolute_Abundance_Tables/209766_filtered_by_prevalence_1pct_rare_Genus-ASV-non-collapse.biom'

biom_tbl = load_table(biom_path)
# Transpose so that rows are samples and columns are ASV sequences
df = pd.DataFrame(biom_tbl.to_dataframe().T)

# Delete the '15564.' prefix from the sample IDs in the index.
# regex=False treats the pattern as a literal string: with regex matching the
# '.' would be a wildcard, and the pandas default for `regex` differs between
# major versions, so the behaviour is pinned explicitly here.
df.index = df.index.str.replace('15564.', '', regex=False)

# Convert to relative abundance by dividing each row by its row sum
df_dense = df.div(df.sum(axis=1), axis=0)
# Rebuild from the raw values so the resulting frame is backed by a dense array
df = pd.DataFrame(df_dense.values, index=df_dense.index, columns=df_dense.columns)  # Force dense

df

Unnamed: 0,GTGCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCCGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCCGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTAATCGGAATTATTGGGCGTAAAGCGAGTGCAGACGGTTACTTAAGCCAGATGTGAAATCCCC,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAATTATTGGGCGTAAAGCGCGCGCAGGCGGTTTCTTAAGTCTGATGTGAAAGCCCC,GTGCCAGCAGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATTTATTGGGCGTAAAGGGCTCGTAGGTGGTTGATCGCGTCGGAAGTGTAATCTTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTCCAAGCGTTAATCGGAATTACTGGGCGTAAAGCGTGCGCAGGCGGTTGTGCAAGACCGATGTGAAATCCCC,GTGCCAGCCGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGATTTATTGGGCGTAAAGGGAGCGCAGGTGGTTTCTTAAGTCTGATGTGAAAGCCCA,GTGCCAGCCGCCGCGGTAATACGGAAGGTCCAGGCGTTATCCGGATTTATTGGGTTTAAAGGGAGCGTAGGCGGATTATTAAGTCAGTGGTGAAAGACGG,...,GTGCCAGCCGCCGCGGTAATACGTAGGGGGCAAGCGTTATCCGGATTTACTGGGTGTAAAGGGAGCGTAGACGGCGCAGCAAGTCTGATGTGAAAGGCAG,GTGCCAGCAGCCGCGGTAAGACAGAGGGTGCAAACGTTGCTCGGAATCACTGGGCGTAAAGGGCGTGTAGGCGGGAGAGAAAGTCGGGCGTGAAATCCCT,GTGCCAGCCGCGGTAATACGTAGGGGGCTAGCGTTGTCCGGAATCACTGGGCGTAAAGGGTTCGCAGGCGGAAATGCAAGTCAGGTGTAAAAGGCAGTAG,GTGCCAGCAGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCCGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTATCCGGAATCATTGGGTTTAAAGGGTCCGCAGGCGGATTTATAAGTCAGTGGTGAAAGCCTA,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCGAGCGTTGTCCGGAATTACTGGGTGTAAAGGGCGTGTAGGCGGGAAGGTAAGTCAGATGTGAAATACCG,GTGCCAGCCGCCGCGGTAATACGGAGGATGCGAGCGTTATTCGGAATCATTGGGTTTAAAGGGTCTGTAGGCGGGCTATTAAGTCAGAGGTGAAAGGTTT,GTGCCAGCCGCCGCGGTAAGACGAAGGGGGCTAGCGTTGTTCGGAATTACTGGGCGTAAAGCGCGTGCAGGCGGTTATCCAAGTCGGGTGTGAAAGCCTT,GTCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCG
CAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTGT
900344,0.529602,0.328848,0.061356,0.044133,0.011841,0.008073,0.004306,0.004306,0.003229,0.001615,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900459,0.068445,0.061485,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900221,0.000744,0.000000,0.000000,0.000000,0.000000,0.000000,0.000541,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900570,0.058922,0.000000,0.000000,0.000000,0.001212,0.000000,0.001666,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900092,0.623945,0.342909,0.011852,0.006428,0.000603,0.000000,0.000000,0.000000,0.000000,0.001406,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9003972,0.115804,0.058794,0.001586,0.000000,0.002776,0.000000,0.072972,0.000000,0.000000,0.003569,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900097,0.011594,0.000000,0.000000,0.000000,0.000000,0.000000,0.015942,0.000000,0.000000,0.000000,...,0.003865,0.002415,0.000483,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900498,0.015707,0.017801,0.000000,0.000000,0.000000,0.000000,0.035602,0.000000,0.014660,0.000000,...,0.000000,0.000000,0.000000,0.015707,0.010471,0.008377,0.000000,0.000000,0.000000,0.000000
900276,0.000000,0.000000,0.041322,0.000000,0.000000,0.000000,0.207989,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.015152,0.004132,0.002755,0.001377


In [9]:
# Load the tab-separated sample metadata
metadata_path = '../Data/Metadata/updated_clean_ant_skin_metadata.tab'

# Short group names used throughout the notebook, keyed by the raw case_type value
case_type_to_group = {
    'case-lesional skin': 'skin-ADL',
    'case-nonlesional skin': 'skin-ADNL',
    'control-nonlesional skin': 'skin-H',
    'case-anterior nares': 'nares-AD',
    'control-anterior nares': 'nares-H'
}

metadata = pd.read_csv(metadata_path, sep='\t')

# Drop underscores from the sample IDs, then make them the index of the frame
metadata = (
    metadata
    .assign(**{'#sample-id': metadata['#sample-id'].str.replace('_', '')})
    .set_index('#sample-id')
)

# Add the simplified group column derived from case_type
metadata = metadata.assign(group=metadata['case_type'].map(case_type_to_group))

metadata

Unnamed: 0_level_0,PlateNumber,PlateLocation,i5,i5Sequence,i7,i7Sequence,identifier,Sequence,Plate ID,Well location,...,sex,enrolment_date,enrolment_season,hiv_exposure,hiv_status,household_size,o_scorad,FWD_filepath,REV_filepath,group
#sample-id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Ca009STL,1,A1,SA501,ATCGTACG,SA701,CGAGAGTT,SA701SA501,CGAGAGTT-ATCGTACG,1.010000e+21,A1,...,male,4/16/2015,Autumn,Unexposed,negative,4.0,40,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
900221,1,B1,SA502,ACTATCTG,SA701,CGAGAGTT,SA701SA502,CGAGAGTT-ACTATCTG,1.010000e+21,B1,...,female,8/11/2015,Winter,Unexposed,negative,7.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
Ca010EBL,1,C1,SA503,TAGCGAGT,SA701,CGAGAGTT,SA701SA503,CGAGAGTT-TAGCGAGT,1.010000e+21,C1,...,female,11/20/2014,Spring,Unexposed,negative,7.0,21,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
900460,1,D1,SA504,CTGCGTGT,SA701,CGAGAGTT,SA701SA504,CGAGAGTT-CTGCGTGT,1.010000e+21,D1,...,female,9/23/2015,Spring,Unexposed,,4.0,40,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
900051,1,E1,SA505,TCATCGAG,SA701,CGAGAGTT,SA701SA505,CGAGAGTT-TCATCGAG,1.010000e+21,E1,...,male,4/21/2015,Autumn,Unexposed,negative,7.0,41,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Ca006ONL2,6,H1,SA508,GACACCGT,SB701,CTCGACTT,SB701SA508,CTCGACTT-GACACCGT,1.010000e+21,H1,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
Ca006ONNL,6,F2,SA506,CGTGAGTG,SB702,CGAAGTAT,SB702SA506,CGAAGTAT-CGTGAGTG,1.010000e+21,F2,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADNL
Ca006ONNL2,6,H2,SA508,GACACCGT,SB702,CGAAGTAT,SB702SA508,CGAAGTAT-GACACCGT,1.010000e+21,H2,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADNL
Ca006ONPN,6,F3,SA506,CGTGAGTG,SB703,TAGCAGCT,SB703SA506,TAGCAGCT-CGTGAGTG,1.010000e+21,F3,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,nares-AD


In [10]:
# ------------------------------
# Helper for Confusion Matrix and Report
# ------------------------------
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

def summarize_cv_predictions(cv_results, title_prefix, output_path=None):
    """
    Print a classification report and draw one confusion matrix pooled over folds.

    Parameters:
    - cv_results: list of per-fold dicts, each holding 'y_true' and 'y_pred'
      (as produced by run_group_stratified_cv)
    - title_prefix: text used in the printed header and as the figure title
    - output_path: path to save the figure (optional)

    Raises:
    - ValueError if cv_results is empty (there is nothing to pool).
    """
    if len(cv_results) == 0:
        raise ValueError("cv_results is empty: no folds to summarize")

    # Pool the held-out labels and predictions from every fold
    y_true = np.concatenate([r['y_true'] for r in cv_results])
    y_pred = np.concatenate([r['y_pred'] for r in cv_results])

    print(f"\n=== Classification Report: {title_prefix} ===")
    print(classification_report(y_true, y_pred))

    # from_predictions() already draws the matrix, so the colormap is passed here.
    # Calling disp.plot() afterwards would open a second figure and leave the
    # first one behind as a stray, untitled plot.
    disp = ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap='Blues')

    # Address the matrix axes explicitly instead of relying on pyplot's "current axes"
    disp.ax_.set_title(f'{title_prefix}', fontsize=22)
    disp.ax_.set_xlabel('Predicted Label', fontsize=20)
    disp.ax_.set_ylabel('True Label', fontsize=20)
    if output_path:
        disp.figure_.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()

# ------------------------------
# 1. Custom Group-Stratified K-Fold Function
# ------------------------------
def group_stratified_kfold(X, y, groups, n_splits=5, random_state=42):
    """
    Greedy K-fold split that keeps every group inside a single fold while
    trying to balance label proportions and fold sizes.

    Parameters:
    - X: feature table (not used for the split; kept for a CV-splitter-like signature)
    - y: 1-D labels, one per sample (array-like or pandas Series)
    - groups: 1-D group identifiers, one per sample (e.g. participant IDs)
    - n_splits: number of folds
    - random_state: seed for the shuffle that decides ties between equally sized groups

    Returns:
    - list of (train_idx, test_idx) tuples of positional integer index arrays.
      A fold that received no group yields an empty test array.

    Notes:
    - Reseeds NumPy's global random state with `random_state` (np.random.seed).
    - Samples are matched by position, consistent with the positional indices returned.
    """
    y_arr = np.asarray(y)
    groups_arr = np.asarray(groups)
    n_samples = len(groups_arr)

    unique_groups = np.unique(groups_arr)
    np.random.seed(random_state)
    np.random.shuffle(unique_groups)

    # Hoisted: the label set is the same for every group and every fold
    labels = np.unique(y_arr)

    # One pass per group: remember where its samples sit and how its labels are distributed
    group_positions = {}
    group_label_dist = {}
    for group in unique_groups:
        positions = np.flatnonzero(groups_arr == group)
        group_positions[group] = positions
        group_y = y_arr[positions]
        group_label_dist[group] = {label: int(np.sum(group_y == label)) for label in labels}

    folds = [[] for _ in range(n_splits)]
    fold_label_dist = [{label: 0 for label in labels} for _ in range(n_splits)]

    # Largest groups first; the sort is stable, so equally sized groups keep their shuffled order
    sorted_groups = sorted(unique_groups, key=lambda g: len(group_positions[g]), reverse=True)

    for group in sorted_groups:
        best_fold = 0
        min_imbalance = float('inf')

        for fold_idx in range(n_splits):
            # Label counts this fold would have if the group were added to it
            temp_fold_dist = fold_label_dist[fold_idx].copy()
            for label, count in group_label_dist[group].items():
                temp_fold_dist[label] += count
            fold_size = sum(temp_fold_dist.values())
            proportions = [count / fold_size for count in temp_fold_dist.values()] if fold_size else [0] * len(temp_fold_dist)
            # Score = spread of label proportions + fold size relative to the ideal fold size
            imbalance = np.var(proportions) + fold_size / (n_samples / n_splits)

            if imbalance < min_imbalance:
                min_imbalance = imbalance
                best_fold = fold_idx

        folds[best_fold].extend(group_positions[group])
        for label, count in group_label_dist[group].items():
            fold_label_dist[best_fold][label] += count

    # Force an integer dtype: an empty fold would otherwise become a float64 array and
    # contaminate the concatenated training indices, which cannot be used with .iloc.
    fold_arrays = [np.asarray(fold, dtype=np.intp) for fold in folds]

    train_test_indices = []
    for i in range(n_splits):
        test_idx = fold_arrays[i]
        other_folds = [fold_arrays[j] for j in range(n_splits) if j != i]
        train_idx = np.concatenate(other_folds) if other_folds else np.empty(0, dtype=np.intp)
        train_test_indices.append((train_idx, test_idx))

    return train_test_indices

# ------------------------------
# 2. Function to Run CV and Get ROC/Confusion/Importances
# ------------------------------
def run_group_stratified_cv(X, y, groups, n_splits=5):
    """
    Run group-stratified cross-validation with a random forest classifier.

    Parameters:
    - X: DataFrame of features (rows are samples, columns are features)
    - y: Series of labels aligned with X; the ROC step uses column 1 of
      predict_proba as the score (assumes binary labels coded 0/1 — confirm for new callers)
    - groups: Series of group identifiers aligned with X; a group never spans train and test
    - n_splits: number of folds

    Returns:
    - cv_results: list with one dict per evaluated fold
      ('y_true', 'y_pred', 'y_proba', 'fpr', 'tpr', 'auc', 'fold')
    - feature_importances: DataFrame indexed by feature with one 'fold_<i>' column per
      evaluated fold plus 'mean_importance' and 'std_importance', sorted by mean descending
    - fold_aucs: list of the per-fold AUC values

    Folds whose training or test part contains fewer than two classes are skipped.
    """
    folds = group_stratified_kfold(X, y, groups, n_splits=n_splits)

    cv_results = []
    feature_importances = pd.DataFrame(index=X.columns)
    fold_aucs = []

    for i, (train_idx, test_idx) in enumerate(folds):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

        # ROC needs both classes in the test fold; the classifier needs both in training
        if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
            print(f"Skipping fold {i+1} due to insufficient class representation")
            continue

        clf = RandomForestClassifier(n_estimators=1000, random_state=42)
        clf.fit(X_train, y_train)

        probas = clf.predict_proba(X_test)
        y_pred = clf.predict(X_test)

        feature_importances[f'fold_{i}'] = clf.feature_importances_

        fpr, tpr, _ = roc_curve(y_test, probas[:, 1])
        roc_auc = auc(fpr, tpr)
        fold_aucs.append(roc_auc)

        cv_results.append({
            'y_true': y_test,
            'y_pred': y_pred,
            'y_proba': probas[:, 1],
            'fpr': fpr,
            'tpr': tpr,
            'auc': roc_auc,
            'fold': i
        })

    # Snapshot the per-fold columns first so that both summary statistics are computed
    # from the folds only. Computing the std after 'mean_importance' has been added
    # would treat the mean as one more fold and bias the standard deviation.
    per_fold = feature_importances.copy()
    feature_importances['mean_importance'] = per_fold.mean(axis=1)
    feature_importances['std_importance'] = per_fold.std(axis=1)
    feature_importances = feature_importances.sort_values('mean_importance', ascending=False)

    return cv_results, feature_importances, fold_aucs

# ------------------------------
# 3. Run Comparisons + Save Confusion Matrices
# ------------------------------
# Plotting helper for the combined figure; the CV runs for each comparison follow it

def plot_combined_confusion_matrices(confusion_matrix_data, output_path=None):
    """
    Plots a horizontal row of confusion matrices.

    Parameters:
    - confusion_matrix_data: list of tuples in the form (title, y_true, y_pred)
    - output_path: path to save the resulting figure (optional)

    Raises:
    - ValueError if confusion_matrix_data is empty.
    """
    n_panels = len(confusion_matrix_data)
    if n_panels == 0:
        raise ValueError("confusion_matrix_data is empty: nothing to plot")

    # squeeze=False always returns a 2-D array of axes, so a single panel needs no special case
    fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)

    for ax, (title, y_true, y_pred) in zip(axs[0], confusion_matrix_data):
        # Use the union of true and predicted labels for both the matrix and the tick
        # labels; taking the ticks from y_true alone mismatches the matrix size whenever
        # a label occurs only among the predictions.
        labels = np.unique(np.concatenate([np.asarray(y_true), np.asarray(y_pred)]))
        cm = confusion_matrix(y_true, y_pred, labels=labels)

        ax.imshow(cm, interpolation='nearest', cmap='Blues')
        ax.set_xticks(np.arange(cm.shape[1]))
        ax.set_yticks(np.arange(cm.shape[0]))
        ax.set_xticklabels(labels)
        ax.set_yticklabels(labels)

        # Write the count into each cell; white text on dark cells, black on light ones
        thresh = cm.max() / 2.
        for (row, col), count in np.ndenumerate(cm):
            ax.text(col, row, format(count, 'd'),
                    ha="center", va="center",
                    color="white" if count > thresh else "black",
                    fontsize=24)

        # Title and axis labels are set once, at the sizes the figure ends up with
        ax.set_title(f"{title}", fontsize=28)
        ax.set_xlabel("Predicted Label", fontsize=24)
        ax.set_ylabel("True Label", fontsize=24)

    fig.suptitle("Confusion Matrices for Random Forest Classifications", fontsize=32, y=1.05)
    fig.tight_layout()

    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')

    plt.show()

# --- Skin vs Nares ---
# Every sample whose group starts with 'skin' or 'nares'; skin is coded 0, nares 1
site_mask = metadata['group'].str.startswith(('skin', 'nares'))
meta_skin_nares = metadata[site_mask].copy()
meta_skin_nares['site_label'] = meta_skin_nares['group'].map(
    lambda group_name: 0 if group_name.startswith('skin') else 1
)
# Restrict to samples present in both the feature table and the metadata
common_samples = df.index.intersection(meta_skin_nares.index)
X_skin_nares = df.loc[common_samples]
y_skin_nares = meta_skin_nares.loc[common_samples, 'site_label']
groups_skin_nares = meta_skin_nares.loc[common_samples, 'pid']
skin_vs_nares_cv_results, _, _ = run_group_stratified_cv(
    X_skin_nares, y_skin_nares, groups_skin_nares, n_splits=3
)

# --- Skin comparisons ---
# Pairwise comparisons among the skin groups; first label is coded 0, second 1
skin_comparisons = [
    ('skin-ADL', 'skin-H'),
    ('skin-ADNL', 'skin-ADL'),
    ('skin-ADNL', 'skin-H'),
]
skin_cv_results_dict = {}

for label1, label2 in skin_comparisons:
    binary_coding = {label1: 0, label2: 1}
    meta_subset = metadata[metadata['group'].isin([label1, label2])]
    common_samples = df.index.intersection(meta_subset.index)
    X_skin = df.loc[common_samples]
    y_skin = meta_subset.loc[common_samples, 'group'].map(binary_coding)
    groups_skin = meta_subset.loc[common_samples, 'pid']
    cv_results, _, _ = run_group_stratified_cv(X_skin, y_skin, groups_skin, n_splits=3)
    skin_cv_results_dict[f"{label1}_vs_{label2}"] = cv_results

# --- Nares AD vs H ---
meta_nares = metadata[metadata['group'].isin(['nares-AD', 'nares-H'])]
common_samples = df.index.intersection(meta_nares.index)
X_nares = df.loc[common_samples]
y_nares = meta_nares.loc[common_samples, 'group'].map({'nares-AD': 0, 'nares-H': 1})
groups_nares = meta_nares.loc[common_samples, 'pid']
nares_ad_vs_h_cv_results, _, _ = run_group_stratified_cv(
    X_nares, y_nares, groups_nares, n_splits=3
)

# --- Combine and plot all confusion matrices ---
# Panels in display order: (title, per-fold CV results for that comparison)
panels = [
    ("Skin vs Nares", skin_vs_nares_cv_results),
    ("skin-ADL vs skin-H", skin_cv_results_dict['skin-ADL_vs_skin-H']),
    ("skin-ADNL vs skin-H", skin_cv_results_dict['skin-ADNL_vs_skin-H']),
    ("skin-ADNL vs skin-ADL", skin_cv_results_dict['skin-ADNL_vs_skin-ADL']),
    ("Nares AD vs Nares H", nares_ad_vs_h_cv_results),
]

# Pool the held-out labels and predictions over the folds of each comparison
confusion_matrix_data = [
    (
        title,
        np.concatenate([r['y_true'] for r in results]),
        np.concatenate([r['y_pred'] for r in results]),
    )
    for title, results in panels
]

plot_combined_confusion_matrices(confusion_matrix_data, "../Figures/Main/Fig_5B.png")



  plt.show()


## Classification Split by Urban (Cape Town) vs Rural (Umtata) (Supplementary Plots)

In [11]:
# Sample IDs for each study area (Cape Town vs Umtata), taken from the metadata index
cape_town_samples = metadata.index[metadata['area'] == 'Cape Town']
umtata_samples = metadata.index[metadata['area'] == 'Umtata']

# Keep only the rows of the ASV table that belong to each area
df_cape_town = df[df.index.isin(cape_town_samples)]
df_umtata = df[df.index.isin(umtata_samples)]



In [12]:
# NOTE: group_stratified_kfold() and run_group_stratified_cv() are defined once in the
# confusion-matrix cell above and are reused here; they are no longer re-defined
# (identically) on every loop iteration.

# ------------------------------
# 1. Combined ROC Plotting Function (one panel per comparison family)
# ------------------------------
def plot_combined_roc_3panel(roc_results_dict, comparison_order, color_map, output_path=None, suptitle=None):
    """
    Plots a multi-panel ROC curve figure (three panels in this notebook) from grouped ROC results.

    Parameters:
    - roc_results_dict: dict
        A dictionary mapping panel titles to sub-dicts. Each sub-dict maps sublabels
        to a list of cv_results dictionaries (each with 'fpr', 'tpr' and 'auc').
    - comparison_order: list
        List of panel titles in order; one panel is drawn per entry.
    - color_map: dict
        Dictionary mapping each sublabel to a color ('gray' is used if missing).
    - output_path: str or None
        If provided, path to save the figure.
    - suptitle: str or None
        If provided, title placed above all panels. Passed in explicitly instead of
        being read from a loop variable in the enclosing scope.
    """
    # Legend text for each comparison key. The two ADNL entries previously carried
    # each other's text, so the curves were labelled with the wrong comparison.
    pretty_labels = {
        'skin_vs_nares': 'All Skin vs All Nares',
        'skin-ADL_vs_skin-H': 'Skin ADL vs Skin H',
        'skin-ADNL_vs_skin-ADL': 'Skin ADNL vs Skin ADL',
        'skin-ADNL_vs_skin-H': 'Skin ADNL vs Skin H',
        'nares-AD_vs_nares-H': 'Nares AD vs Nares H'
    }

    n_panels = len(comparison_order)
    # squeeze=False keeps axs two-dimensional regardless of the number of panels
    fig, axs = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6), sharey=True, squeeze=False)
    axs = axs[0]

    # Common FPR grid onto which every fold's ROC curve is interpolated before averaging
    mean_fpr = np.linspace(0, 1, 100)

    for ax, panel_title in zip(axs, comparison_order):
        sub_dict = roc_results_dict[panel_title]
        # Loop over each sub-comparison within this panel
        for sublabel, curves in sub_dict.items():
            # Ensure curves is a list
            if not isinstance(curves, list):
                curves = [curves]
            tprs = []
            aucs = []
            for result in curves:
                # Ignore entries that carry no ROC curve
                if 'fpr' not in result or 'tpr' not in result:
                    continue
                tpr_interp = np.interp(mean_fpr, result['fpr'], result['tpr'])
                tpr_interp[0] = 0.0  # anchor every curve at the origin
                tprs.append(tpr_interp)
                aucs.append(result['auc'])
            if len(tprs) == 0:
                continue
            mean_tpr = np.mean(tprs, axis=0)
            mean_tpr[-1] = 1.0  # anchor the mean curve at (1, 1)
            std_tpr = np.std(tprs, axis=0)
            # +/- one standard deviation band, clipped to the valid [0, 1] range
            tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
            tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
            curve_color = color_map.get(sublabel, 'gray')
            ax.plot(mean_fpr, mean_tpr, lw=3,
                    label=f'{pretty_labels.get(sublabel, sublabel)} (AUC = {np.mean(aucs):.2f} ± {np.std(aucs):.2f})',
                    color=curve_color)
            ax.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.3,
                            color=curve_color)

        # Chance diagonal and per-panel formatting
        ax.plot([0, 1], [0, 1], 'k--', lw=1)
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=16)
        ax.set_title(panel_title, fontsize=18)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.tick_params(axis='both', labelsize=14)
        ax.legend(loc='lower right', fontsize=12)

    axs[0].set_ylabel('True Positive Rate', fontsize=16)
    if suptitle:
        fig.suptitle(suptitle, fontsize=26, x=0.5, y=1)

    fig.tight_layout()

    if output_path:
        fig.savefig(output_path, dpi=600, bbox_inches='tight')


# ------------------------------
# 2. Per-area configuration, panel order and color map
# ------------------------------
area_df_dict = {'Capetown': df_cape_town, 'Umtata': df_umtata}
# Name used in the figure title for each area key
area_display_names = {'Capetown': 'Cape Town', 'Umtata': 'Umtata'}
# Output file for each area key
area_output_paths = {
    'Capetown': '../Figures/Supplementary/Suppl_Fig_5A.png',
    'Umtata': '../Figures/Supplementary/Suppl_Fig_5B.png'
}

comparison_order = [
    'Prediction of Skin vs. Nares Samples',
    'Prediction of AD Status From Skin Samples',
    'Prediction of AD Status From Nares Samples'
]
color_map = {
    'skin_vs_nares': 'black',
    'skin-ADL_vs_skin-H': 'blue',
    'skin-ADNL_vs_skin-ADL': 'green',
    'skin-ADNL_vs_skin-H': 'purple',
    'nares-AD_vs_nares-H': 'orange'
}
skin_comparisons = [
    ('skin-ADL', 'skin-H'),
    ('skin-ADNL', 'skin-ADL'),
    ('skin-ADNL', 'skin-H')
]

# ------------------------------
# 3. Run the classifications and plot the ROC figure for each area
# ------------------------------
# The loop variable is named area_df so that the full feature table `df` is NOT
# overwritten; earlier cells can then be re-run after this one with the same result.
for area_name, area_df in area_df_dict.items():
    # 3a. Skin vs Nares (binary: skin -> 0, nares -> 1)
    meta_skin_nares = metadata[metadata['group'].str.startswith(('skin', 'nares'))].copy()
    meta_skin_nares['site_label'] = meta_skin_nares['group'].apply(lambda x: 0 if x.startswith('skin') else 1)
    common_samples = area_df.index.intersection(meta_skin_nares.index)
    X_skin_nares = area_df.loc[common_samples]
    y_skin_nares = meta_skin_nares.loc[common_samples, 'site_label']
    groups_skin_nares = meta_skin_nares.loc[common_samples, 'pid']
    y_skin_nares = pd.to_numeric(y_skin_nares, errors='coerce')
    skin_vs_nares_cv_results, _, _ = run_group_stratified_cv(X_skin_nares, y_skin_nares, groups_skin_nares, n_splits=3)

    # 3b. Skin comparisons (pairwise among skin groups)
    skin_cv_results_dict = {}
    for label1, label2 in skin_comparisons:
        meta_subset = metadata[metadata['group'].isin([label1, label2])]
        common_samples = area_df.index.intersection(meta_subset.index)
        X_skin = area_df.loc[common_samples]
        # Map to binary: 0 for first label, 1 for second label
        y_skin = meta_subset.loc[common_samples, 'group'].map({label1: 0, label2: 1})
        groups_skin = meta_subset.loc[common_samples, 'pid']
        cv_results, _, _ = run_group_stratified_cv(X_skin, y_skin, groups_skin, n_splits=3)
        key = f"{label1}_vs_{label2}"
        skin_cv_results_dict[key] = cv_results

    # 3c. Nares comparison (nares-AD vs nares-H)
    meta_nares = metadata[metadata['group'].isin(['nares-AD', 'nares-H'])]
    common_samples = area_df.index.intersection(meta_nares.index)
    X_nares = area_df.loc[common_samples]
    y_nares = meta_nares.loc[common_samples, 'group'].map({'nares-AD': 0, 'nares-H': 1})
    groups_nares = meta_nares.loc[common_samples, 'pid']
    nares_ad_vs_h_cv_results, _, _ = run_group_stratified_cv(X_nares, y_nares, groups_nares, n_splits=3)

    # 3d. Assemble the ROC results: panel title -> {comparison key -> per-fold results}
    roc_results_dict = {
        'Prediction of Skin vs. Nares Samples': {
            'skin_vs_nares': skin_vs_nares_cv_results
        },
        'Prediction of AD Status From Skin Samples': skin_cv_results_dict,
        'Prediction of AD Status From Nares Samples': {
            'nares-AD_vs_nares-H': nares_ad_vs_h_cv_results
        }
    }

    # 3e. Plot and save; title and output file come from the per-area lookups above,
    # so an area missing from them fails with a KeyError instead of reusing a stale path.
    output_path = area_output_paths[area_name]
    plot_combined_roc_3panel(
        roc_results_dict,
        comparison_order=comparison_order,
        color_map=color_map,
        output_path=output_path,
        suptitle=f'Random Forest Classifications by 16S ASVs of {area_display_names[area_name]} Samples'
    )