# Random Forest Analyses

In [1]:
# Standard library
import warnings
import logging
from itertools import combinations

# Scientific computing
import numpy as np
import pandas as pd
from numpy import array
import scipy
import scipy.stats as ss
from scipy import interp
from scipy.stats import wilcoxon, ttest_rel

# Visualization
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

# scikit-bio
from skbio.stats.distance import permanova

# BIOM format
import biom
from biom import load_table

# Scikit-learn
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, confusion_matrix, classification_report,
    roc_curve, auc, RocCurveDisplay
)
from sklearn.model_selection import GroupKFold, StratifiedKFold
from sklearn.preprocessing import label_binarize

# Set overall styling for plots: seaborn "paper" context with fonts scaled by 1.5,
# plus the "ticks" axes style. These are session-wide defaults, so they affect
# every figure created in the cells below.
sns.set_context("paper", font_scale=1.5)
sns.set_style("ticks")

## Random Forest Classification Tasks

In [2]:
# Read in the feature table at ASV level
biom_path = '..//Data/Tables/Absolute_Abundance_Tables/209766_filtered_feature_table.biom'
# biom_path = '../Data/Tables/Absolute_Abundance_Tables/209766_filtered_by_prevalence_1pct_rare_Genus-ASV-non-collapse.biom'

biom_tbl = load_table(biom_path)
# Transpose so that rows are samples and columns are ASVs
df = pd.DataFrame(biom_tbl.to_dataframe().T)

# Delete the study prefix from the sample IDs in the index.
# regex=False makes the match literal: under regex matching the '.' would match
# ANY character, and the default for `regex` differs between pandas versions.
df.index = df.index.str.replace('15564.', '', regex=False)
df

Unnamed: 0,GTGCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCCGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCCGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTAATCGGAATTATTGGGCGTAAAGCGAGTGCAGACGGTTACTTAAGCCAGATGTGAAATCCCC,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAATTATTGGGCGTAAAGCGCGCGCAGGCGGTTTCTTAAGTCTGATGTGAAAGCCCC,GTGCCAGCAGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATTTATTGGGCGTAAAGGGCTCGTAGGTGGTTGATCGCGTCGGAAGTGTAATCTTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTCCAAGCGTTAATCGGAATTACTGGGCGTAAAGCGTGCGCAGGCGGTTGTGCAAGACCGATGTGAAATCCCC,GTGCCAGCCGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGATTTATTGGGCGTAAAGGGAGCGCAGGTGGTTTCTTAAGTCTGATGTGAAAGCCCA,GTGCCAGCCGCCGCGGTAATACGGAAGGTCCAGGCGTTATCCGGATTTATTGGGTTTAAAGGGAGCGTAGGCGGATTATTAAGTCAGTGGTGAAAGACGG,...,GTGCCAGCCGCCGCGGTAATACGTAGGGGGCAAGCGTTATCCGGATTTACTGGGTGTAAAGGGAGCGTAGACGGCGCAGCAAGTCTGATGTGAAAGGCAG,GTGCCAGCAGCCGCGGTAAGACAGAGGGTGCAAACGTTGCTCGGAATCACTGGGCGTAAAGGGCGTGTAGGCGGGAGAGAAAGTCGGGCGTGAAATCCCT,GTGCCAGCCGCGGTAATACGTAGGGGGCTAGCGTTGTCCGGAATCACTGGGCGTAAAGGGTTCGCAGGCGGAAATGCAAGTCAGGTGTAAAAGGCAGTAG,GTGCCAGCAGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCCGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTATCCGGAATCATTGGGTTTAAAGGGTCCGCAGGCGGATTTATAAGTCAGTGGTGAAAGCCTA,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCGAGCGTTGTCCGGAATTACTGGGTGTAAAGGGCGTGTAGGCGGGAAGGTAAGTCAGATGTGAAATACCG,GTGCCAGCCGCCGCGGTAATACGGAGGATGCGAGCGTTATTCGGAATCATTGGGTTTAAAGGGTCTGTAGGCGGGCTATTAAGTCAGAGGTGAAAGGTTT,GTGCCAGCCGCCGCGGTAAGACGAAGGGGGCTAGCGTTGTTCGGAATTACTGGGCGTAAAGCGCGTGCAGGCGGTTATCCAAGTCGGGTGTGAAAGCCTT,GTCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCG
CAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTGT
900344,984.0,611.0,114.0,82.0,22.0,15.0,8.0,8.0,6.0,3.0,...,0,0,0,0,0,0,0,0,0,0
900459,118.0,106.0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
900221,22.0,0,0,0,0,0,16.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
900570,389.0,0,0,0,8.0,0,11.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
900092,3106.0,1707.0,59.0,32.0,3.0,0,0,0,0,7.0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9003972,1168.0,593.0,16.0,0,28.0,0,736.0,0,0,36.0,...,0,0,0,0,0,0,0,0,0,0
900097,24.0,0,0,0,0,0,33.0,0,0,0,...,8.0,5.0,1.0,0,0,0,0,0,0,0
900498,15.0,17.0,0,0,0,0,34.0,0,14.0,0,...,0,0,0,15.0,10.0,8.0,0,0,0,0
900276,0,0,30.0,0,0,0,151.0,0,0,0,...,0,0,0,0,0,0,11.0,3.0,2.0,1.0


In [3]:
# Load the metadata (tab-separated)
metadata_path = '../Data/Metadata/updated_clean_ant_skin_metadata.tab'
# Sample IDs mix alphanumeric values (e.g. 'Ca009STL') with purely numeric ones
# (e.g. '900221'). Force string parsing so the `.str` accessor below never meets
# an integer-typed ID, which it would silently turn into NaN.
metadata = pd.read_csv(metadata_path, sep='\t', dtype={'#sample-id': str})

# Remove underscores from the sample IDs (literal match, no regex involved)
metadata['#sample-id'] = metadata['#sample-id'].str.replace('_', '', regex=False)
# Set Sample-ID as the index for the metadata dataframe
metadata = metadata.set_index('#sample-id')


# Create group column based on case_type to simplify group names.
# Any case_type value missing from this mapping becomes NaN in 'group'.
metadata['group'] = metadata['case_type'].map({
    'case-lesional skin': 'skin-ADL',
    'case-nonlesional skin': 'skin-ADNL',
    'control-nonlesional skin': 'skin-H',
    'case-anterior nares': 'nares-AD',
    'control-anterior nares': 'nares-H'
})

metadata

Unnamed: 0_level_0,PlateNumber,PlateLocation,i5,i5Sequence,i7,i7Sequence,identifier,Sequence,Plate ID,Well location,...,sex,enrolment_date,enrolment_season,hiv_exposure,hiv_status,household_size,o_scorad,FWD_filepath,REV_filepath,group
#sample-id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Ca009STL,1,A1,SA501,ATCGTACG,SA701,CGAGAGTT,SA701SA501,CGAGAGTT-ATCGTACG,1.010000e+21,A1,...,male,4/16/2015,Autumn,Unexposed,negative,4.0,40,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
900221,1,B1,SA502,ACTATCTG,SA701,CGAGAGTT,SA701SA502,CGAGAGTT-ACTATCTG,1.010000e+21,B1,...,female,8/11/2015,Winter,Unexposed,negative,7.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
Ca010EBL,1,C1,SA503,TAGCGAGT,SA701,CGAGAGTT,SA701SA503,CGAGAGTT-TAGCGAGT,1.010000e+21,C1,...,female,11/20/2014,Spring,Unexposed,negative,7.0,21,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
900460,1,D1,SA504,CTGCGTGT,SA701,CGAGAGTT,SA701SA504,CGAGAGTT-CTGCGTGT,1.010000e+21,D1,...,female,9/23/2015,Spring,Unexposed,,4.0,40,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
900051,1,E1,SA505,TCATCGAG,SA701,CGAGAGTT,SA701SA505,CGAGAGTT-TCATCGAG,1.010000e+21,E1,...,male,4/21/2015,Autumn,Unexposed,negative,7.0,41,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Ca006ONL2,6,H1,SA508,GACACCGT,SB701,CTCGACTT,SB701SA508,CTCGACTT-GACACCGT,1.010000e+21,H1,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
Ca006ONNL,6,F2,SA506,CGTGAGTG,SB702,CGAAGTAT,SB702SA506,CGAAGTAT-CGTGAGTG,1.010000e+21,F2,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADNL
Ca006ONNL2,6,H2,SA508,GACACCGT,SB702,CGAAGTAT,SB702SA508,CGAAGTAT-GACACCGT,1.010000e+21,H2,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADNL
Ca006ONPN,6,F3,SA506,CGTGAGTG,SB703,TAGCAGCT,SB703SA506,TAGCAGCT-CGTGAGTG,1.010000e+21,F3,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,nares-AD


In [4]:
# Print, for every group label, how often each o_scorad value occurs (sorted by value)
for group in metadata['group'].unique():
    print(f"\nO-SCORAD distribution for {group}:")
    print(
        metadata.loc[metadata['group'] == group, 'o_scorad']
        .value_counts()
        .sort_index()
    )



O-SCORAD distribution for skin-ADL:
o_scorad
21    4
22    2
23    3
24    2
25    4
26    2
28    3
29    2
30    3
32    4
33    2
34    5
35    1
36    3
37    2
38    2
39    3
40    4
41    4
42    5
43    7
44    9
45    1
47    1
48    1
49    1
51    1
52    1
53    3
54    3
56    2
58    2
61    2
66    2
67    1
75    2
78    4
82    1
NR    1
Name: count, dtype: int64

O-SCORAD distribution for skin-ADNL:
o_scorad
10    1
21    4
22    2
23    3
24    2
25    6
26    2
28    4
29    2
30    3
32    4
33    2
34    5
35    1
36    3
37    2
38    2
39    3
40    4
41    4
42    5
43    7
44    9
45    1
47    1
48    1
49    2
51    1
52    1
53    3
54    3
56    2
58    2
61    2
66    2
67    1
78    2
NR    1
Name: count, dtype: int64

O-SCORAD distribution for nares-AD:
o_scorad
21    4
22    2
23    3
24    2
25    5
26    2
28    4
29    2
30    3
32    4
33    2
34    5
35    1
36    3
37    2
38    1
39    3
40    4
41    4
42    5
43    6
44    9
45    1
47    1
4

In [5]:
# split data into training and testing
# ------------------------------
# 1. Custom Group-Stratified K-Fold Function
# ------------------------------
def group_stratified_kfold(X, y, groups, n_splits=5, random_state=42):
    """
    Split samples into folds so that each group stays within a single fold.

    Groups are shuffled with a seeded generator, ordered from largest to smallest,
    and greedily assigned to the fold that minimises
    ``var(label proportions in fold) + fold_size / (n_samples / n_splits)``.

    Parameters:
    - X: unused; kept so the signature matches the callers.
    - y: 1-D array-like of labels, positionally aligned with `groups`.
    - groups: 1-D array-like of group IDs (one per sample).
    - n_splits: number of folds (must be >= 2).
    - random_state: seed for the group shuffle (affects tie-breaking between
      groups of equal size).

    Returns:
    - list of (train_idx, test_idx) tuples of integer position arrays, one per fold.
      A fold's test array is empty if no group was assigned to it.

    Raises:
    - ValueError: if n_splits < 2 or if `y` and `groups` differ in length.
    """
    if n_splits < 2:
        raise ValueError("n_splits must be at least 2")

    # Work on plain arrays so all indexing is positional (callers use .iloc).
    y_arr = np.asarray(y)
    groups_arr = np.asarray(groups)
    if y_arr.shape[0] != groups_arr.shape[0]:
        raise ValueError("y and groups must have the same length")
    n_samples = groups_arr.shape[0]

    # A local generator yields the same permutation as seeding the global one,
    # but does not clobber NumPy's global random state for the rest of the session.
    unique_groups = np.unique(groups_arr)
    rng = np.random.RandomState(random_state)
    rng.shuffle(unique_groups)

    # Pre-compute, once per group, its member positions and per-label counts.
    labels = np.unique(y_arr)
    member_idx = [np.flatnonzero(groups_arr == group) for group in unique_groups]
    label_counts = [
        np.array([np.sum(y_arr[idx] == label) for label in labels], dtype=np.int64)
        for idx in member_idx
    ]

    # Largest groups first; the sort is stable, so ties keep the shuffled order.
    assignment_order = sorted(
        range(len(unique_groups)), key=lambda k: len(member_idx[k]), reverse=True
    )

    fold_members = [[] for _ in range(n_splits)]
    fold_counts = np.zeros((n_splits, len(labels)), dtype=np.int64)
    target_fold_size = n_samples / n_splits

    for k in assignment_order:
        best_fold = 0
        min_imbalance = float('inf')

        for fold_idx in range(n_splits):
            # Label counts of this fold if the group were added to it
            candidate = fold_counts[fold_idx] + label_counts[k]
            fold_size = candidate.sum()
            if fold_size:
                proportions = candidate / fold_size
            else:
                proportions = np.zeros(len(labels))
            imbalance = np.var(proportions) + fold_size / target_fold_size

            # Strict '<' keeps the first fold among equally good candidates
            if imbalance < min_imbalance:
                min_imbalance = imbalance
                best_fold = fold_idx

        fold_members[best_fold].extend(member_idx[k].tolist())
        fold_counts[best_fold] += label_counts[k]

    # Force an integer dtype: an empty fold would otherwise become a float array
    # and turn the concatenated train indices into floats as well.
    fold_arrays = [np.asarray(members, dtype=np.intp) for members in fold_members]

    train_test_indices = []
    for i in range(n_splits):
        test_idx = fold_arrays[i]
        train_idx = np.concatenate([fold_arrays[j] for j in range(n_splits) if j != i])
        train_test_indices.append((train_idx, test_idx))

    return train_test_indices

# ------------------------------
# 2. Function to Run CV and Get ROC/Feature Importance
# ------------------------------
def run_group_stratified_cv(X, y, groups, n_splits=5):
    """
    Run group-stratified cross-validation with a RandomForestClassifier.

    Parameters:
    - X: pd.DataFrame of features (rows are samples, positionally aligned with y/groups)
    - y: pd.Series of binary labels; column 1 of predict_proba is used as the score
    - groups: group IDs, passed to group_stratified_kfold so a group never spans folds
    - n_splits: number of CV folds

    Returns:
    - cv_results: list of dicts per evaluated fold with keys
      'y_true', 'y_proba', 'fpr', 'tpr', 'auc', 'fold'
    - feature_importances: pd.DataFrame with one 'fold_<i>' column per evaluated fold
      plus 'mean_importance' and 'std_importance', sorted by mean importance (descending)
    - fold_aucs: list of AUCs, one per evaluated fold

    Folds whose train or test part contains fewer than two classes are skipped.
    """
    folds = group_stratified_kfold(X, y, groups, n_splits=n_splits)

    cv_results = []
    feature_importances = pd.DataFrame(index=X.columns)
    fold_aucs = []

    for i, (train_idx, test_idx) in enumerate(folds):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

        # ROC/AUC is undefined unless both classes are present on each side
        if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
            print(f"Skipping fold {i+1} due to insufficient class representation")
            continue

        clf = RandomForestClassifier(n_estimators=1000, random_state=42)
        clf.fit(X_train, y_train)
        probas = clf.predict_proba(X_test)

        feature_importances[f'fold_{i}'] = clf.feature_importances_
        fpr, tpr, _ = roc_curve(y_test, probas[:, 1])
        roc_auc = auc(fpr, tpr)
        fold_aucs.append(roc_auc)

        cv_results.append({
            'y_true': y_test,
            'y_proba': probas[:, 1],
            'fpr': fpr,
            'tpr': tpr,
            'auc': roc_auc,
            'fold': i
        })

    # Summarise over the per-fold columns only. Both statistics are computed before
    # either summary column is added, so 'mean_importance' cannot leak into the
    # standard deviation.
    per_fold = feature_importances[list(feature_importances.columns)]
    mean_importance = per_fold.mean(axis=1)
    std_importance = per_fold.std(axis=1)
    feature_importances['mean_importance'] = mean_importance
    feature_importances['std_importance'] = std_importance
    feature_importances = feature_importances.sort_values('mean_importance', ascending=False)

    return cv_results, feature_importances, fold_aucs

# ------------------------------
# 3. Regenerate ROC Results
# ------------------------------
# Make sure your df (feature table) and metadata DataFrame are in scope.
# Here, 'df' is assumed to contain your features (per-sample ASV abundances as loaded
# from the absolute-abundance table above, not relative abundances),
# and 'metadata' is assumed to contain the columns 'group' and 'pid'
# ('microbiome_type' is not used in this section).

# 3a. Skin vs Nares (Binary Classification)
meta_skin_nares = metadata.loc[
    metadata['group'].str.startswith(('skin', 'nares'))
].copy()
# Encode the body site as the label: skin -> 0, nares -> 1
meta_skin_nares['site_label'] = meta_skin_nares['group'].map(
    lambda group_name: 0 if group_name.startswith('skin') else 1
)
# Keep only samples present in both the feature table and the metadata
common_samples = df.index.intersection(meta_skin_nares.index)
X_skin_nares = df.loc[common_samples]
groups_skin_nares = meta_skin_nares.loc[common_samples, 'pid']
y_skin_nares = pd.to_numeric(
    meta_skin_nares.loc[common_samples, 'site_label'], errors='coerce'
)
skin_vs_nares_cv_results, _, _ = run_group_stratified_cv(
    X_skin_nares, y_skin_nares, groups_skin_nares, n_splits=3
)

# 3b. Skin Comparisons (Pairwise Among Skin Groups)
skin_comparisons = [
    ('skin-ADL', 'skin-H'),
    ('skin-ADNL', 'skin-ADL'),
    ('skin-ADNL', 'skin-H'),
]
skin_cv_results_dict = {}
for label1, label2 in skin_comparisons:
    meta_subset = metadata.loc[metadata['group'].isin([label1, label2])]
    common_samples = df.index.intersection(meta_subset.index)
    X_skin = df.loc[common_samples]
    groups_skin = meta_subset.loc[common_samples, 'pid']
    # Binary target: the first label of the pair is 0, the second is 1
    y_skin = meta_subset.loc[common_samples, 'group'].map({label1: 0, label2: 1})
    cv_results, _, _ = run_group_stratified_cv(
        X_skin, y_skin, groups_skin, n_splits=3
    )
    key = f"{label1}_vs_{label2}"
    skin_cv_results_dict[key] = cv_results

# 3c. Nares Comparison (nares-AD vs nares-H)
meta_nares = metadata.loc[metadata['group'].isin(['nares-AD', 'nares-H'])]
common_samples = df.index.intersection(meta_nares.index)
X_nares = df.loc[common_samples]
groups_nares = meta_nares.loc[common_samples, 'pid']
y_nares = meta_nares.loc[common_samples, 'group'].map({'nares-AD': 0, 'nares-H': 1})
nares_ad_vs_h_cv_results, _, _ = run_group_stratified_cv(
    X_nares, y_nares, groups_nares, n_splits=3
)

# ------------------------------
# 4. Assemble ROC Results Dictionary
# ------------------------------
# Panel title -> {comparison key -> list of per-fold CV results}
roc_results_dict = {
    'Prediction of Skin vs. Nares Samples': {'skin_vs_nares': skin_vs_nares_cv_results},
    'Prediction of AD Status From Skin Samples': skin_cv_results_dict,
    'Prediction of AD Status From Nares Samples': {'nares-AD_vs_nares-H': nares_ad_vs_h_cv_results},
}


# ------------------------------
# 5. Combined ROC Plotting Function (3-Panel)
# ------------------------------
def plot_combined_roc_3panel(roc_results_dict, comparison_order, color_map, output_path=None):
    """
    Plots a multi-panel ROC curve figure from grouped ROC results.

    One panel is drawn per entry of `comparison_order` (three in this notebook).
    Within a panel, each sub-comparison is shown as the mean ROC curve across folds
    with a +/- 1 standard deviation band clipped to [0, 1].

    Parameters:
    - roc_results_dict: dict
        A dictionary mapping panel titles to sub-dicts. Each sub-dict maps sublabels
        to a list of cv_results dictionaries (each with 'fpr', 'tpr' and 'auc').
    - comparison_order: list
        List of panel titles in order; every title must be a key of roc_results_dict.
    - color_map: dict
        Dictionary mapping each sublabel to a color ('gray' is used when missing).
    - output_path: str or None
        If provided, path to save the figure.

    Raises:
    - ValueError: if comparison_order is empty.
    """
    n_panels = len(comparison_order)
    if n_panels == 0:
        raise ValueError("comparison_order must contain at least one panel title")

    # squeeze=False keeps a 2-D axes array even for a single panel;
    # 6 inches per panel reproduces the 18x6 figure for three panels.
    fig, axs = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6), sharey=True, squeeze=False)
    axs = axs[0]

    # Legend text per comparison key. Each text must describe the same pair as its key.
    pretty_labels = {
        'skin_vs_nares': 'All Skin vs All Nares',
        'skin-ADL_vs_skin-H': 'Skin ADL vs Skin H',
        'skin-ADNL_vs_skin-ADL': 'Skin ADNL vs Skin ADL',
        'skin-ADNL_vs_skin-H': 'Skin ADNL vs Skin H',
        'nares-AD_vs_nares-H': 'Nares AD vs Nares H'
    }

    # Common FPR grid onto which every fold's ROC curve is interpolated
    mean_fpr = np.linspace(0, 1, 100)

    for ax, panel_title in zip(axs, comparison_order):
        sub_dict = roc_results_dict[panel_title]
        # Loop over each sub-comparison within this panel
        for sublabel, curves in sub_dict.items():
            # Accept a single result dict as well as a list of them
            if not isinstance(curves, list):
                curves = [curves]
            tprs = []
            aucs = []
            for result in curves:
                # Skip results that carry no ROC coordinates
                if 'fpr' not in result or 'tpr' not in result:
                    continue
                tpr_interp = np.interp(mean_fpr, result['fpr'], result['tpr'])
                tpr_interp[0] = 0.0  # anchor every curve at the origin
                tprs.append(tpr_interp)
                aucs.append(result['auc'])
            if len(tprs) == 0:
                continue
            mean_tpr = np.mean(tprs, axis=0)
            mean_tpr[-1] = 1.0  # anchor the mean curve at (1, 1)
            std_tpr = np.std(tprs, axis=0)
            tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
            tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
            curve_color = color_map.get(sublabel, 'gray')
            ax.plot(mean_fpr, mean_tpr, lw=3,
                    label=f'{pretty_labels.get(sublabel, sublabel)} (AUC = {np.mean(aucs):.2f} ± {np.std(aucs):.2f})',
                    color=curve_color)
            ax.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.3,
                            color=curve_color)

        # Chance diagonal and shared cosmetics
        ax.plot([0, 1], [0, 1], 'k--', lw=1)
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=16)
        ax.set_title(panel_title, fontsize=18)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.tick_params(axis='both', labelsize=14)
        ax.legend(loc='lower right', fontsize=12)

    axs[0].set_ylabel('True Positive Rate', fontsize=16)
    plt.suptitle('Random Forest Classifications by 16S ASVs', fontsize=26, x=0.5, y=1)
    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=600, bbox_inches='tight')

# ------------------------------
# 6. Define Panel Order and Color Map, Then Plot
# ------------------------------
# Panels appear left to right in this order
comparison_order = [
    'Prediction of Skin vs. Nares Samples',
    'Prediction of AD Status From Skin Samples',
    'Prediction of AD Status From Nares Samples',
]
# One line colour per comparison key
color_map = dict([
    ('skin_vs_nares', 'black'),
    ('skin-ADL_vs_skin-H', 'blue'),
    ('skin-ADNL_vs_skin-ADL', 'green'),
    ('skin-ADNL_vs_skin-H', 'purple'),
    ('nares-AD_vs_nares-H', 'orange'),
])

plot_combined_roc_3panel(
    roc_results_dict,
    comparison_order,
    color_map,
    # output_path='../Plots/Analysis_figures/Random_Forest/rf_classifications_ROC_curves_filtered_by_prevalence_1pct_rare_Genus-ASV-non-collapse.png'
    output_path='../Plots/Analysis_figures/Random_Forest/rf_classifications_final.png',
)


## Combined skin and nares classification

In [8]:
# Combined dataset: AD samples (Skin ADL + Nares AD) vs healthy samples (Skin H + Nares H)
meta_skin_nares_AD = metadata.loc[metadata['group'].isin(['skin-ADL', 'nares-AD'])].copy()
meta_skin_nares_H = metadata.loc[metadata['group'].isin(['skin-H', 'nares-H'])].copy()
meta_combined = pd.concat([meta_skin_nares_AD, meta_skin_nares_H])

# Binary target: 1 = AD, 0 = Healthy
meta_combined['combined_label'] = meta_combined['group'].map({
    'skin-ADL': 1,
    'nares-AD': 1,
    'skin-H': 0,
    'nares-H': 0,
})
# Body site derived from the group name
meta_combined['site'] = meta_combined['group'].map(
    lambda group_name: 'skin' if 'skin' in group_name else 'nares'
)

# Restrict to samples that have both features and metadata
common_samples = df.index.intersection(meta_combined.index)
X_combined = df.loc[common_samples]
y_combined = meta_combined.loc[common_samples, 'combined_label']
groups_combined = meta_combined.loc[common_samples, 'pid']

# Append the one-hot encoded site to the ASV features. Both frames get a fresh
# positional index so the column-wise concat pairs rows by position.
site_dummies = pd.get_dummies(meta_combined.loc[common_samples, 'site'], prefix='site')
X_combined_with_site = pd.concat(
    [X_combined.reset_index(drop=True), site_dummies.reset_index(drop=True)],
    axis=1,
)

# Cross-validate on the enhanced feature set
(combined_AD_vs_H_results,
 combined_feature_importance,
 combined_aucs) = run_group_stratified_cv(
    X_combined_with_site, y_combined, groups_combined, n_splits=3
)

combined_AD_vs_H_results



[{'y_true': 9004002    1
  9004022    1
  900400     1
  900402     1
  9003932    0
            ..
  900466     1
  900463     1
  900469     1
  900224     1
  900258     1
  Name: combined_label, Length: 130, dtype: int64,
  'y_proba': array([0.733, 0.786, 0.396, 0.552, 0.226, 0.585, 0.432, 0.295, 0.501,
         0.491, 0.61 , 0.287, 0.386, 0.596, 0.649, 0.318, 0.527, 0.623,
         0.483, 0.404, 0.512, 0.372, 0.5  , 0.524, 0.503, 0.638, 0.35 ,
         0.5  , 0.598, 0.449, 0.364, 0.462, 0.721, 0.501, 0.744, 0.216,
         0.275, 0.724, 0.293, 0.762, 0.493, 0.472, 0.718, 0.503, 0.726,
         0.527, 0.7  , 0.33 , 0.501, 0.663, 0.359, 0.449, 0.294, 0.675,
         0.885, 0.702, 0.392, 0.392, 0.324, 0.576, 0.605, 0.272, 0.549,
         0.639, 0.445, 0.848, 0.514, 0.438, 0.731, 0.2  , 0.492, 0.735,
         0.711, 0.506, 0.565, 0.511, 0.63 , 0.609, 0.674, 0.528, 0.258,
         0.256, 0.39 , 0.258, 0.54 , 0.291, 0.501, 0.498, 0.789, 0.526,
         0.276, 0.765, 0.814, 0.463, 0.556,

In [9]:
def plot_single_roc(cv_results, label, color='black', output_path=None,
                    title='ROC Curve: Combined Skin+Nares AD vs H'):
    """
    Plot the mean ROC curve (+/- 1 SD band) from cross-validation results.

    Parameters:
    - cv_results: list of dicts from run_group_stratified_cv
      (each needs 'fpr', 'tpr' and 'auc')
    - label: str, label for legend
    - color: str, line color
    - output_path: str, optional path to save plot
    - title: str, figure title (defaults to the combined skin+nares title)

    Raises:
    - ValueError: if cv_results is empty (no fold to average).
    """
    if len(cv_results) == 0:
        raise ValueError("cv_results is empty; there is no ROC curve to plot")

    # Common FPR grid onto which every fold's curve is interpolated
    mean_fpr = np.linspace(0, 1, 100)
    tprs, aucs = [], []

    for result in cv_results:
        fpr, tpr, auc_val = result['fpr'], result['tpr'], result['auc']
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0  # anchor every curve at the origin
        tprs.append(interp_tpr)
        aucs.append(auc_val)

    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0  # anchor the mean curve at (1, 1)
    std_tpr = np.std(tprs, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)

    plt.figure(figsize=(6, 6))
    plt.plot(mean_fpr, mean_tpr, color=color, lw=3,
             label=f'{label} (AUC = {np.mean(aucs):.2f} ± {np.std(aucs):.2f})')
    plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color=color, alpha=0.3)
    plt.plot([0, 1], [0, 1], 'k--', lw=1)  # chance diagonal

    plt.xlabel('False Positive Rate', fontsize=14)
    plt.ylabel('True Positive Rate', fontsize=14)
    plt.title(title, fontsize=16)
    plt.legend(loc='lower right', fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=600, bbox_inches='tight')


In [10]:
# ROC of the unweighted combined (skin + nares) AD vs H classifier
plot_single_roc(
    cv_results=combined_AD_vs_H_results,
    label='Skin+Nares AD vs H',
    color='darkred',
    output_path='../Plots/Analysis_figures/Random_Forest/roc_combined_skin_nares_AD_vs_H.png',
)


## Weighted skin and nares classification

In [36]:
# Per-site sample weights used by the weighted classifier below
skin_weight, nares_weight = 1.18, 0.82


In [37]:
def run_group_stratified_cv(X, y, groups, site_series=None, n_splits=5):
    """
    Run group-stratified CV using RandomForestClassifier, optionally with
    site-based sample weights.

    This definition replaces the earlier function of the same name. `site_series`
    is optional so that calls written for the earlier signature,
    ``run_group_stratified_cv(X, y, groups, n_splits=...)``, keep working
    (they run unweighted).

    Parameters:
    - X: pd.DataFrame of features
    - y: pd.Series of binary labels
    - groups: pd.Series of group IDs (e.g., subject IDs)
    - site_series: pd.Series indicating 'skin' or 'nares' for each sample, in the
      same row order as X; None disables sample weighting. The weights are read
      from the module-level `skin_weight` and `nares_weight`.
    - n_splits: number of CV splits

    Returns:
    - cv_results: list of ROC results per fold
    - feature_importances: pd.DataFrame of feature importances
    - fold_aucs: list of AUCs per fold

    Raises:
    - ValueError: if site_series does not have one entry per row of X, or contains
      a value other than 'skin' / 'nares'.
    """
    folds = group_stratified_kfold(X, y, groups, n_splits=n_splits)

    cv_results = []
    feature_importances = pd.DataFrame(index=X.columns)
    fold_aucs = []

    if site_series is None:
        sample_weights_all = None
    else:
        # Translate each sample's site into its weight
        site_weights = {'skin': skin_weight, 'nares': nares_weight}
        sample_weights_all = np.asarray(site_series.map(site_weights), dtype=float)
        if len(sample_weights_all) != len(X):
            raise ValueError("site_series must have one entry per row of X")
        # Unmapped sites become NaN; reject them here instead of passing NaN weights on
        if np.isnan(sample_weights_all).any():
            raise ValueError("site_series may only contain 'skin' or 'nares'")

    for i, (train_idx, test_idx) in enumerate(folds):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

        # ROC/AUC is undefined unless both classes are present on each side
        if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
            print(f"Skipping fold {i+1} due to insufficient class representation")
            continue

        train_weights = None if sample_weights_all is None else sample_weights_all[train_idx]
        clf = RandomForestClassifier(n_estimators=1000, random_state=42)
        clf.fit(X_train, y_train, sample_weight=train_weights)
        probas = clf.predict_proba(X_test)

        feature_importances[f'fold_{i}'] = clf.feature_importances_
        fpr, tpr, _ = roc_curve(y_test, probas[:, 1])
        roc_auc = auc(fpr, tpr)
        fold_aucs.append(roc_auc)

        cv_results.append({
            'y_true': y_test,
            'y_proba': probas[:, 1],
            'fpr': fpr,
            'tpr': tpr,
            'auc': roc_auc,
            'fold': i
        })

    # Summarise over the per-fold columns only. Both statistics are computed before
    # either summary column is added, so 'mean_importance' cannot leak into the
    # standard deviation.
    per_fold = feature_importances[list(feature_importances.columns)]
    mean_importance = per_fold.mean(axis=1)
    std_importance = per_fold.std(axis=1)
    feature_importances['mean_importance'] = mean_importance
    feature_importances['std_importance'] = std_importance
    feature_importances = feature_importances.sort_values('mean_importance', ascending=False)

    return cv_results, feature_importances, fold_aucs


In [38]:
# Body site of every sample (the 'site' column created earlier), in the same
# sample order as the combined feature matrix
site_series = meta_combined.loc[common_samples, 'site']

# Weighted run on the same feature set (ASVs plus the one-hot encoded site)
(combined_AD_vs_H_weighted_results,
 combined_weighted_importance,
 combined_weighted_aucs) = run_group_stratified_cv(
    X_combined_with_site, y_combined, groups_combined, site_series, n_splits=3
)




In [39]:
def plot_roc_comparison(
    unweighted_results, weighted_results,
    labels=('Unweighted', 'Weighted'),
    colors=('black', 'darkred'),
    output_path=None
):
    """
    Plot side-by-side ROC curves for unweighted vs. weighted classifiers.

    Each panel shows the mean ROC curve across folds and a +/- 1 standard deviation
    band clipped to [0, 1], as in the other ROC helpers of this notebook.

    Parameters:
    - unweighted_results: list of dicts from run_group_stratified_cv (unweighted)
    - weighted_results: list of dicts from run_group_stratified_cv (weighted)
    - labels: tuple of (label_unweighted, label_weighted)
    - colors: tuple of (color_unweighted, color_weighted)
    - output_path: path to save image (optional)

    Raises:
    - ValueError: if either result list is empty.
    """
    def compute_mean_roc(cv_results):
        # Average the fold ROC curves on a common FPR grid
        if len(cv_results) == 0:
            raise ValueError("cv_results is empty; there is no ROC curve to plot")
        mean_fpr = np.linspace(0, 1, 100)
        tprs, aucs = [], []
        for r in cv_results:
            interp_tpr = np.interp(mean_fpr, r['fpr'], r['tpr'])
            interp_tpr[0] = 0.0  # anchor every curve at the origin
            tprs.append(interp_tpr)
            aucs.append(r['auc'])
        mean_tpr = np.mean(tprs, axis=0)
        mean_tpr[-1] = 1.0  # anchor the mean curve at (1, 1)
        std_tpr = np.std(tprs, axis=0)
        return mean_fpr, mean_tpr, std_tpr, np.mean(aucs), np.std(aucs)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharey=True)

    for ax, results, label, color in zip(axs, [unweighted_results, weighted_results], labels, colors):
        mean_fpr, mean_tpr, std_tpr, auc_mean, auc_std = compute_mean_roc(results)
        # A true positive rate cannot leave [0, 1], so clip the band accordingly
        band_lower = np.maximum(mean_tpr - std_tpr, 0)
        band_upper = np.minimum(mean_tpr + std_tpr, 1)
        ax.plot(mean_fpr, mean_tpr, lw=3, color=color,
                label=f'{label} (AUC = {auc_mean:.2f} ± {auc_std:.2f})')
        ax.fill_between(mean_fpr, band_lower, band_upper,
                        color=color, alpha=0.3)
        ax.plot([0, 1], [0, 1], 'k--', lw=1)  # chance diagonal
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=14)
        ax.set_title(f'{label} Classifier', fontsize=16)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.tick_params(axis='both', labelsize=12)
        ax.legend(loc='lower right', fontsize=12)

    axs[0].set_ylabel('True Positive Rate', fontsize=14)
    plt.suptitle('ROC Comparison: Skin+Nares AD vs H', fontsize=18)
    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, dpi=600, bbox_inches='tight')


In [40]:
# Compare the unweighted and the site-weighted combined classifiers side by side
weighted_label = f'Weighted (skin={skin_weight}, nares={nares_weight})'
plot_roc_comparison(
    combined_AD_vs_H_results,
    combined_AD_vs_H_weighted_results,
    labels=('Unweighted', weighted_label),
    colors=('black', 'darkred'),
    output_path='../Plots/Analysis_figures/Random_Forest/roc_comparison_unweighted_vs_weighted.png',
)


## Separate training for Skin and Nares, then combine

In [43]:
def run_group_stratified_cv_unweighted(X, y, groups, n_splits=5):
    """
    Run group-stratified cross-validation with an unweighted RandomForestClassifier.

    Parameters:
    - X: pd.DataFrame of features (rows are samples, positionally aligned with y/groups)
    - y: pd.Series of binary labels; column 1 of predict_proba is used as the score
    - groups: group IDs, passed to group_stratified_kfold so a group never spans folds
    - n_splits: number of CV folds

    Returns:
    - cv_results: list of dicts per evaluated fold with keys
      'y_true', 'y_proba', 'fpr', 'tpr', 'auc', 'fold'
    - feature_importances: pd.DataFrame with one 'fold_<i>' column per evaluated fold
      plus 'mean_importance' and 'std_importance', sorted by mean importance (descending)
    - fold_aucs: list of AUCs, one per evaluated fold

    Folds whose train or test part contains fewer than two classes are skipped.
    """
    folds = group_stratified_kfold(X, y, groups, n_splits=n_splits)

    cv_results = []
    feature_importances = pd.DataFrame(index=X.columns)
    fold_aucs = []

    for i, (train_idx, test_idx) in enumerate(folds):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

        # ROC/AUC is undefined unless both classes are present on each side
        if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
            print(f"Skipping fold {i+1} due to insufficient class representation")
            continue

        clf = RandomForestClassifier(n_estimators=1000, random_state=42)
        clf.fit(X_train, y_train)
        probas = clf.predict_proba(X_test)

        feature_importances[f'fold_{i}'] = clf.feature_importances_
        fpr, tpr, _ = roc_curve(y_test, probas[:, 1])
        roc_auc = auc(fpr, tpr)
        fold_aucs.append(roc_auc)

        cv_results.append({
            'y_true': y_test,
            'y_proba': probas[:, 1],
            'fpr': fpr,
            'tpr': tpr,
            'auc': roc_auc,
            'fold': i
        })

    # Summarise over the per-fold columns only. Both statistics are computed before
    # either summary column is added, so 'mean_importance' cannot leak into the
    # standard deviation.
    per_fold = feature_importances[list(feature_importances.columns)]
    mean_importance = per_fold.mean(axis=1)
    std_importance = per_fold.std(axis=1)
    feature_importances['mean_importance'] = mean_importance
    feature_importances['std_importance'] = std_importance
    feature_importances = feature_importances.sort_values('mean_importance', ascending=False)

    return cv_results, feature_importances, fold_aucs


In [44]:
# Subset metadata and assign binary labels
meta_skin = metadata[metadata['group'].isin(['skin-ADL', 'skin-H'])].copy()
meta_nares = metadata[metadata['group'].isin(['nares-AD', 'nares-H'])].copy()

meta_skin['label'] = meta_skin['group'].map({'skin-ADL': 1, 'skin-H': 0})
meta_nares['label'] = meta_nares['group'].map({'nares-AD': 1, 'nares-H': 0})

# Match features
X_skin = df.loc[df.index.intersection(meta_skin.index)]
X_nares = df.loc[df.index.intersection(meta_nares.index)]
y_skin = meta_skin.loc[X_skin.index, 'label']
y_nares = meta_nares.loc[X_nares.index, 'label']
groups_skin = meta_skin.loc[X_skin.index, 'pid']
groups_nares = meta_nares.loc[X_nares.index, 'pid']


In [46]:
# Train and evaluate one classifier per body site (3-fold, grouped by patient)
skin_cv_results, _, _ = run_group_stratified_cv_unweighted(
    X=X_skin, y=y_skin, groups=groups_skin, n_splits=3
)
nares_cv_results, _, _ = run_group_stratified_cv_unweighted(
    X=X_nares, y=y_nares, groups=groups_nares, n_splits=3
)


In [47]:
def extract_fold_predictions(cv_results):
    """
    Pool the true labels and predicted scores of all folds.

    Parameters:
    - cv_results: iterable of per-fold dicts with 'y_true' and 'y_proba' entries

    Returns:
    - (labels, scores): two numpy arrays with the folds concatenated in order
    """
    pooled_labels = [value for fold in cv_results for value in fold['y_true']]
    pooled_scores = [value for fold in cv_results for value in fold['y_proba']]
    return np.array(pooled_labels), np.array(pooled_scores)

y_true_skin, proba_skin = extract_fold_predictions(skin_cv_results)
y_true_nares, proba_nares = extract_fold_predictions(nares_cv_results)

# Check that both label vectors are binary (0/1) and line up with their scores.
# The skin and nares vectors come from different samples, so they are not
# compared with each other.
assert np.isin(y_true_skin, [0, 1]).all(), "skin labels must be 0/1"
assert np.isin(y_true_nares, [0, 1]).all(), "nares labels must be 0/1"
assert len(y_true_skin) == len(proba_skin), "skin labels and scores differ in length"
assert len(y_true_nares) == len(proba_nares), "nares labels and scores differ in length"


In [48]:
# Combine the site-specific classifiers at the level of their ROC curves.
# (Already imported in the first cell; kept so this cell also runs on its own.)
from sklearn.metrics import roc_curve, auc

# Interpolate to uniform fpr grid for ensemble
def interpolate_roc(cv_results):
    """
    Average the per-fold ROC curves on a common 100-point FPR grid.

    Returns (mean_fpr, mean_tpr, std_tpr, mean_auc, std_auc).
    """
    mean_fpr = np.linspace(0, 1, 100)
    tprs, aucs = [], []
    for r in cv_results:
        interp_tpr = np.interp(mean_fpr, r['fpr'], r['tpr'])
        interp_tpr[0] = 0.0  # anchor every curve at the origin
        tprs.append(interp_tpr)
        aucs.append(r['auc'])
    mean_tpr = np.mean(tprs, axis=0)
    std_tpr = np.std(tprs, axis=0)
    return mean_fpr, mean_tpr, std_tpr, np.mean(aucs), np.std(aucs)

fpr_skin, tpr_skin, std_skin, auc_skin, _ = interpolate_roc(skin_cv_results)
fpr_nares, tpr_nares, std_nares, auc_nares, _ = interpolate_roc(nares_cv_results)

# Assign weights based on AUCs. The mean cross-validated AUCs computed just above
# are used, not hard-coded numbers, so the weights always match the current results.
total_auc = auc_skin + auc_nares

weight_skin = auc_skin / total_auc
weight_nares = auc_nares / total_auc

# Combine TPRs using weighted average (both curves share the same FPR grid).
# NOTE(review): this averages two ROC curves obtained on different sample sets; it is
# not the ROC of a model that combines per-sample probabilities. Confirm that this
# is the intended meaning of "ensemble".
combined_tpr = weight_skin * tpr_skin + weight_nares * tpr_nares
combined_auc = auc(fpr_skin, combined_tpr)


In [50]:
# Site-specific mean ROC curves (dashed) and their weighted combination (solid)
fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(fpr_skin, tpr_skin, '--', label=f'Skin (AUC={auc_skin:.2f})', color='blue')
ax.plot(fpr_nares, tpr_nares, '--', label=f'Nares (AUC={auc_nares:.2f})', color='orange')
ax.plot(fpr_skin, combined_tpr, '-', label=f'Ensemble (AUC={combined_auc:.2f})', color='darkred', lw=3)

# Chance diagonal
ax.plot([0, 1], [0, 1], 'k--')
ax.set_xlabel('False Positive Rate', fontsize=14)
ax.set_ylabel('True Positive Rate', fontsize=14)
ax.set_title('Ensemble ROC: Skin + Nares AD vs H', fontsize=16)
ax.legend(loc='lower right', fontsize=12)
ax.grid(True, linestyle='--', alpha=0.7)
fig.tight_layout()
fig.savefig('../Plots/Analysis_figures/Random_Forest/roc_ensemble_skin_nares_ADL_vs_H.png', dpi=600, bbox_inches='tight')