# 16S Random Forest Classifier

### Objective:
Use genus-level 16S V4 rRNA amplicon sequencing data to train a Random Forest classifier
that predicts skin condition type (healthy, acne non-lesional, or acne lesional)
based on skin, nose, and skin + nose bacterial abundance profiles.

Inputs:
- 16S-level microbial feature tables (absolute abundance CLR transformed)

Outputs:
- Predicted skin condition label per sample

Goals:
1. Evaluate classification performance for skin AD L vs H, skin AD NL vs H, skin AD NL vs AD L: We would expect AD L vs H to likely have the best model performance, followed by AD NL vs H, then AD NL vs L. What are the driving taxa which contribute to classification of feature importance by each model?
2. Evaluate classification performance for nares AD positive vs AD negative: Can the model predict AD condition based on nares samples alone? This would be interesting. Driving taxa?
3. If 1 and 2 look pretty good, evaluate classification performance combining skin and nares samples improves classification performance further. If true, this would be very interesting. 

### Things to note:
- Use absolute abundance with center log ratio (CLR) transformation with a small pseudocount (see function below) (not relative abundance, and not relative abundance and absolute concatenated).
- Be cautious about class imbalance: All samples from a single person should be entirely in the training set OR entirely in the test set — never split across both. Currently, the random splitting means there are going to be samples from the same person in both groups. This can superinflate results because your model can essentially "cheat" by learning person-specific features rather than what truly distinguishes classes like lesional vs. non-lesional. See code below from sklearn.model_selection import GroupShuffleSplit.
- It may be better to use ASV level features rather than Genera-level collapsed features (but we can discuss this later if needed).
- I'm no expert in ML, so use online resources ;)

### Visualization outcomes:
- A ROC curve graph with 3 lines of different colors, skin AD L vs H, AD NL vs H, AD NL vs L. Show Area Under Curve (AUC) values for each binary classification.
- A ROC curve graph with 2 lines of different colors, nares AD vs nares healthy control. Show Area Under Curve (AUC) values for the binary classification.
- Perhaps some horizontal barplots to show features which highest classification performance for each model.

In [163]:
# Standard library
import warnings
import logging
from itertools import combinations

# Scientific computing
import numpy as np
import pandas as pd
from numpy import array
import scipy
import scipy.stats as ss
from scipy import interp
from scipy.stats import wilcoxon, ttest_rel

# Visualization
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

# scikit-bio
from skbio.stats.distance import permanova

# BIOM format
import biom
from biom import load_table

# Scikit-learn
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, confusion_matrix, classification_report,
    roc_curve, auc, RocCurveDisplay
)
from sklearn.model_selection import GroupKFold, StratifiedKFold
from sklearn.preprocessing import label_binarize


### Define functions for analyses

In [164]:
def read_and_convert_biom(biom_path: str) -> pd.DataFrame:
    """
    Read a BIOM-format table and convert it into a Pandas dataframe.

    Parameters:
    biom_path (str): The file path to the biom table.

    Returns:
    pd.DataFrame: The dataframe corresponding to the biom table.
    """
    logging.info(f"Loading BIOM file: {biom_path}")
    try:
        biom_table = load_table(biom_path)
        df = pd.DataFrame(biom_table.to_dataframe().T)
        logging.info(f"BIOM table shape: {df.shape}")
        logging.info("BIOM file successfully converted to DataFrame")
        return df
    except Exception as e:
        logging.error(f"Error in processing BIOM file: {e}")
        raise

In [165]:
def clr_transform_with_pseudocount(df, pseudocount=1e-6, use_inf=False):
    """
    Applies Centered Log-Ratio (CLR) transformation to a DataFrame of abundance data.

    Parameters:
    - df: pandas DataFrame
        Microbial abundance table (samples as rows, features as columns).
    - pseudocount: float
        Small constant to add to avoid log(0). Default is 1e-6.

    Returns:
    - clr_df: pandas DataFrame
        CLR-transformed DataFrame.
    """
    # Add pseudocount
    df_pseudo = df + pseudocount

    # Take the natural log
    log_df = np.log(df_pseudo)

    if not use_inf:
        # Subtract geometric mean per sample (i.e., row-wise)
        clr_values = log_df.subtract(log_df.mean(axis=1), axis=0)

    else:
        # the log values could be -inf, so we won't take them into the mean calculation, but put them back to -inf
        # replace -inf with nan
        log_df = log_df.replace(-6, np.nan)
        log_mean = log_df.apply(np.nanmean, axis=1) # maybe slow?
        clr_values = log_df.subtract(log_mean, axis=0)
        clr_values = clr_values.fillna(-np.inf)

    return clr_values


In [166]:
def relative_abundance(df):
    """
    Convert a DataFrame of absolute abundances to relative abundances.

    Parameters:
    - df: pandas DataFrame
        Microbial abundance table (samples as rows, features as columns).

    Returns:
    - rel_df: pandas DataFrame
        Relative abundance table (samples as rows, features as columns).
    """
    return df.div(df.sum(axis=1), axis=0)


In [167]:
def filter_data(X_clr, y, train_idx, test_idx, y_train, y_test, groups, case_types):
    train_idx_filtered = train_idx[y_train.isin(case_types)]
    test_idx_filtered = test_idx[y_test.isin(case_types)]

    X_train_filtered = X_clr.iloc[train_idx_filtered]
    X_test_filtered = X_clr.iloc[test_idx_filtered]
    y_train_filtered = y.iloc[train_idx_filtered]
    y_test_filtered = y.iloc[test_idx_filtered]

    print(f"Training set size: {X_train_filtered.shape[0]} samples (original: {len(X_train)})")
    print(f"Testing set size: {X_test_filtered.shape[0]} samples (original: {len(X_test)})")

    assert set(groups[y_train_filtered.index]) & set(groups[y_test_filtered.index]) == set(), "Participant leakage detected!"
    return X_train_filtered, y_train_filtered, X_test_filtered, y_test_filtered

In [168]:
# Read in table at ASV level
biom_path = '..//Data/Tables/Absolute_Abundance_Tables/209766_filtered_feature_table.biom'
biom_tbl = load_table(biom_path)
df = pd.DataFrame(biom_tbl.to_dataframe().T)

# delete the prefix from the index
df.index = df.index.str.replace('15564.', '')
df

Unnamed: 0,GTGCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCCGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCCGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTAATCGGAATTATTGGGCGTAAAGCGAGTGCAGACGGTTACTTAAGCCAGATGTGAAATCCCC,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAATTATTGGGCGTAAAGCGCGCGCAGGCGGTTTCTTAAGTCTGATGTGAAAGCCCC,GTGCCAGCAGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATTTATTGGGCGTAAAGGGCTCGTAGGTGGTTGATCGCGTCGGAAGTGTAATCTTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTCCAAGCGTTAATCGGAATTACTGGGCGTAAAGCGTGCGCAGGCGGTTGTGCAAGACCGATGTGAAATCCCC,GTGCCAGCCGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGATTTATTGGGCGTAAAGGGAGCGCAGGTGGTTTCTTAAGTCTGATGTGAAAGCCCA,GTGCCAGCCGCCGCGGTAATACGGAAGGTCCAGGCGTTATCCGGATTTATTGGGTTTAAAGGGAGCGTAGGCGGATTATTAAGTCAGTGGTGAAAGACGG,...,GTGCCAGCCGCCGCGGTAATACGTAGGGGGCAAGCGTTATCCGGATTTACTGGGTGTAAAGGGAGCGTAGACGGCGCAGCAAGTCTGATGTGAAAGGCAG,GTGCCAGCAGCCGCGGTAAGACAGAGGGTGCAAACGTTGCTCGGAATCACTGGGCGTAAAGGGCGTGTAGGCGGGAGAGAAAGTCGGGCGTGAAATCCCT,GTGCCAGCCGCGGTAATACGTAGGGGGCTAGCGTTGTCCGGAATCACTGGGCGTAAAGGGTTCGCAGGCGGAAATGCAAGTCAGGTGTAAAAGGCAGTAG,GTGCCAGCAGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCCGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTATCCGGAATCATTGGGTTTAAAGGGTCCGCAGGCGGATTTATAAGTCAGTGGTGAAAGCCTA,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCGAGCGTTGTCCGGAATTACTGGGTGTAAAGGGCGTGTAGGCGGGAAGGTAAGTCAGATGTGAAATACCG,GTGCCAGCCGCCGCGGTAATACGGAGGATGCGAGCGTTATTCGGAATCATTGGGTTTAAAGGGTCTGTAGGCGGGCTATTAAGTCAGAGGTGAAAGGTTT,GTGCCAGCCGCCGCGGTAAGACGAAGGGGGCTAGCGTTGTTCGGAATTACTGGGCGTAAAGCGCGTGCAGGCGGTTATCCAAGTCGGGTGTGAAAGCCTT,GTCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTGT
900344,984.0,611.0,114.0,82.0,22.0,15.0,8.0,8.0,6.0,3.0,...,0,0,0,0,0,0,0,0,0,0
900459,118.0,106.0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
900221,22.0,0,0,0,0,0,16.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
900570,389.0,0,0,0,8.0,0,11.0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
900092,3106.0,1707.0,59.0,32.0,3.0,0,0,0,0,7.0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9003972,1168.0,593.0,16.0,0,28.0,0,736.0,0,0,36.0,...,0,0,0,0,0,0,0,0,0,0
900097,24.0,0,0,0,0,0,33.0,0,0,0,...,8.0,5.0,1.0,0,0,0,0,0,0,0
900498,15.0,17.0,0,0,0,0,34.0,0,14.0,0,...,0,0,0,15.0,10.0,8.0,0,0,0,0
900276,0,0,30.0,0,0,0,151.0,0,0,0,...,0,0,0,0,0,0,11.0,3.0,2.0,1.0


In [169]:
# Load the metadata
metadata_path = '../Data/Metadata/updated_clean_ant_skin_metadata.tab'
metadata = pd.read_csv(metadata_path, sep='\t')
# metadata['case_type'].value_counts()

metadata['#sample-id'] = metadata['#sample-id'].str.replace('_', '')


# Set Sample-ID as the index for the metadata dataframe 
metadata = metadata.set_index('#sample-id')


# Create group column based on case_type to simplify group names
metadata['group'] = metadata['case_type'].map({
    'case-lesional skin': 'skin-ADL',
    'case-nonlesional skin': 'skin-ADNL', 
    'control-nonlesional skin': 'skin-H',
    'case-anterior nares': 'nares-AD',
    'control-anterior nares': 'nares-H'
})

metadata
# split data into training and testing

Unnamed: 0_level_0,PlateNumber,PlateLocation,i5,i5Sequence,i7,i7Sequence,identifier,Sequence,Plate ID,Well location,...,sex,enrolment_date,enrolment_season,hiv_exposure,hiv_status,household_size,o_scorad,FWD_filepath,REV_filepath,group
#sample-id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Ca009STL,1,A1,SA501,ATCGTACG,SA701,CGAGAGTT,SA701SA501,CGAGAGTT-ATCGTACG,1.010000e+21,A1,...,male,4/16/2015,Autumn,Unexposed,negative,4.0,40,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
900221,1,B1,SA502,ACTATCTG,SA701,CGAGAGTT,SA701SA502,CGAGAGTT-ACTATCTG,1.010000e+21,B1,...,female,8/11/2015,Winter,Unexposed,negative,7.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
Ca010EBL,1,C1,SA503,TAGCGAGT,SA701,CGAGAGTT,SA701SA503,CGAGAGTT-TAGCGAGT,1.010000e+21,C1,...,female,11/20/2014,Spring,Unexposed,negative,7.0,21,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
900460,1,D1,SA504,CTGCGTGT,SA701,CGAGAGTT,SA701SA504,CGAGAGTT-CTGCGTGT,1.010000e+21,D1,...,female,9/23/2015,Spring,Unexposed,,4.0,40,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
900051,1,E1,SA505,TCATCGAG,SA701,CGAGAGTT,SA701SA505,CGAGAGTT-TCATCGAG,1.010000e+21,E1,...,male,4/21/2015,Autumn,Unexposed,negative,7.0,41,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
Ca006ONL2,6,H1,SA508,GACACCGT,SB701,CTCGACTT,SB701SA508,CTCGACTT-GACACCGT,1.010000e+21,H1,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADL
Ca006ONNL,6,F2,SA506,CGTGAGTG,SB702,CGAAGTAT,SB702SA506,CGAAGTAT-CGTGAGTG,1.010000e+21,F2,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADNL
Ca006ONNL2,6,H2,SA508,GACACCGT,SB702,CGAAGTAT,SB702SA508,CGAAGTAT-GACACCGT,1.010000e+21,H2,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,skin-ADNL
Ca006ONPN,6,F3,SA506,CGTGAGTG,SB703,TAGCAGCT,SB703SA506,TAGCAGCT-CGTGAGTG,1.010000e+21,F3,...,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,/Users/yac027/Gallo_lab/16S_AD_Dube_Dupont/ato...,nares-AD


In [170]:
# Set overall styling for plots
sns.set_context("paper", font_scale=1.5)
sns.set_style("ticks")

# Custom function for group-stratified k-fold
def group_stratified_kfold(X, y, groups, n_splits=5, random_state=42):
    """
    Custom implementation of cross-validation that respects both groups and stratification
    
    Parameters:
    -----------
    X : DataFrame
        Feature matrix
    y : Series
        Target labels
    groups : Series
        Group labels for samples (e.g., pid)
    n_splits : int
        Number of folds
    random_state : int
        Random seed
    
    Returns:
    --------
    list of tuples
        Each tuple contains (train_indices, test_indices)
    """
    # Get unique groups
    unique_groups = np.unique(groups)
    np.random.seed(random_state)
    np.random.shuffle(unique_groups)
    
    # Create label distribution per group
    group_label_dist = {}
    for group in unique_groups:
        group_mask = groups == group
        group_y = y[group_mask]
        group_label_dist[group] = {label: np.sum(group_y == label) for label in np.unique(y)}
    
    # Initialize folds with empty lists
    folds = [[] for _ in range(n_splits)]
    
    # Track current distribution of labels in each fold
    fold_label_dist = [{label: 0 for label in np.unique(y)} for _ in range(n_splits)]
    
    # Sort groups by size (number of samples) in descending order to place larger groups first
    sorted_groups = sorted(unique_groups, key=lambda g: sum(groups == g), reverse=True)
    
    # Assign groups to folds
    for group in sorted_groups:
        # Calculate which fold would benefit most from this group
        # by minimizing the imbalance across all labels
        best_fold = 0
        min_imbalance = float('inf')
        
        group_size = sum(groups == group)
        
        for fold_idx in range(n_splits):
            # Calculate current imbalance if we add this group
            temp_fold_dist = fold_label_dist[fold_idx].copy()
            for label, count in group_label_dist[group].items():
                temp_fold_dist[label] += count
            
            # Calculate imbalance as variance of label proportions
            fold_size = sum(temp_fold_dist.values())
            if fold_size == 0:
                proportions = [0] * len(temp_fold_dist)
            else:
                proportions = [count / fold_size for count in temp_fold_dist.values()]
            
            imbalance = np.var(proportions) + fold_size / (sum(groups.shape) / n_splits)
            
            if imbalance < min_imbalance:
                min_imbalance = imbalance
                best_fold = fold_idx
        
        # Assign group to best fold
        folds[best_fold].extend(np.where(groups == group)[0])
        # Update fold distribution
        for label, count in group_label_dist[group].items():
            fold_label_dist[best_fold][label] += count
    
    # Create train/test indices
    train_test_indices = []
    for i in range(n_splits):
        test_idx = np.array(folds[i])
        train_idx = np.concatenate([folds[j] for j in range(n_splits) if j != i])
        train_test_indices.append((train_idx, test_idx))
    
    return train_test_indices

# Modified function to run group-stratified cross-validation with feature importance
def run_group_stratified_cv(X, y, groups, n_splits=5):
    # Get group-stratified folds
    folds = group_stratified_kfold(X, y, groups, n_splits=n_splits)
    
    # Initialize arrays to store results
    cv_results = []
    feature_importances = pd.DataFrame(index=X.columns)
    fold_aucs = []
    
    # Run cross-validation
    for i, (train_idx, test_idx) in enumerate(folds):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
        
        # Handle cases where train set might contain only one class
        if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
            print(f"Skipping fold {i+1} due to insufficient class representation")
            continue
            
        # Train classifier
        clf = RandomForestClassifier(n_estimators=1000, random_state=42)
        clf.fit(X_train, y_train)
        
        # Predict probabilities
        probas = clf.predict_proba(X_test)
        
        # Store feature importance for this fold
        feature_importances[f'fold_{i}'] = clf.feature_importances_
        
        # Store results
        fpr, tpr, _ = roc_curve(y_test, probas[:, 1])
        roc_auc = auc(fpr, tpr)
        fold_aucs.append(roc_auc)
        
        cv_results.append({
            'y_true': y_test,
            'y_proba': probas[:, 1],
            'fpr': fpr,
            'tpr': tpr,
            'auc': roc_auc,
            'fold': i
        })
    
    # Calculate mean feature importance across folds
    feature_importances['mean_importance'] = feature_importances.mean(axis=1)
    feature_importances['std_importance'] = feature_importances.std(axis=1)
    feature_importances = feature_importances.sort_values('mean_importance', ascending=False)
    
    return cv_results, feature_importances, fold_aucs

# Function to compute and plot ROC curves with error bars plus perform pairwise comparisons
def plot_roc_curves_with_comparisons(tables_dict, metadata, pair_comparisons, n_splits=5):
    # Separate comparisons
    skin_comparisons = [("skin-ADL", "skin-H"), ("skin-ADNL", "skin-ADL"), ("skin-ADNL", "skin-H")]
    nares_comparisons = [("nares-AD", "nares-H")]

    fig, axs = plt.subplots(1, 2, figsize=(12, 6), sharey=True)

    color_map_skin = {
        'skin-ADL_vs_skin-H': 'blue',
        'skin-ADNL_vs_skin-ADL': 'green',
        'skin-ADNL_vs_skin-H': 'purple'
    }
    color_map_nares = {
        'nares-AD_vs_nares-H': 'orange'
    }

    all_feature_importances = {}
    all_fold_aucs = {}

    for label1, label2 in pair_comparisons:
        comparison_key = f'{label1}_vs_{label2}'
        is_skin = (label1.startswith("skin") and label2.startswith("skin"))
        ax = axs[0] if is_skin else axs[1]

        ax.set_title("Skin Comparisons" if is_skin else "Nares Comparison", fontsize=16)
        all_feature_importances[comparison_key] = {}
        all_fold_aucs[comparison_key] = {}

        for table_name, table in tables_dict.items():
            meta_subset = metadata[metadata['group'].isin([label1, label2])]
            common_samples = table.index.intersection(meta_subset.index)
            X = table.loc[common_samples]
            meta_filtered = meta_subset.loc[common_samples]

            if len(common_samples) < 10:
                print(f"Skipping {table_name} for {label1} vs {label2}: insufficient samples ({len(common_samples)})")
                continue

            y = meta_filtered['group'].map({label1: 0, label2: 1})
            groups = meta_filtered['pid']
            cv_results, feature_imp, fold_aucs = run_group_stratified_cv(X, y, groups, n_splits=n_splits)
            all_feature_importances[comparison_key][table_name] = feature_imp
            all_fold_aucs[comparison_key][table_name] = fold_aucs

            if len(cv_results) < 2:
                print(f"Skipping {table_name} for {label1} vs {label2}: CV returned insufficient results")
                continue

            mean_fpr = np.linspace(0, 1, 100)
            tprs, aucs = [], []
            for result in cv_results:
                tprs.append(interp(mean_fpr, result['fpr'], result['tpr']))
                tprs[-1][0] = 0.0
                aucs.append(result['auc'])

            mean_tpr = np.mean(tprs, axis=0)
            mean_tpr[-1] = 1.0
            std_tpr = np.std(tprs, axis=0)
            tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
            tprs_lower = np.maximum(mean_tpr - std_tpr, 0)

            # Color based on type
            plot_color = color_map_skin.get(comparison_key) if is_skin else color_map_nares.get(comparison_key)

            ax.plot(mean_fpr, mean_tpr, lw=2,
                    label=f'{label1} vs {label2} (AUC = {np.mean(aucs):.2f} ± {np.std(aucs):.2f})',
                    color=plot_color)
            ax.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.3, color=plot_color)

    for ax in axs:
        ax.plot([0, 1], [0, 1], 'k--', lw=1)
        ax.set_xlabel('False Positive Rate', fontsize=14)
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.tick_params(axis='both', which='major', labelsize=12)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.legend(loc='lower right', fontsize=10)

    axs[0].set_ylabel('True Positive Rate', fontsize=14)
    plt.tight_layout()

    pairwise_comparisons = compute_pairwise_comparisons(all_fold_aucs)
    return fig, all_feature_importances, pairwise_comparisons



# Function to perform pairwise statistical tests
def compute_pairwise_comparisons(fold_aucs_dict):
    """
    Perform pairwise statistical tests between methods for each task
    
    Parameters:
    -----------
    fold_aucs_dict : dict
        Dictionary with fold-wise AUC values for each method
    
    Returns:
    --------
    DataFrame
        Table with pairwise comparisons and p-values
    """
    results = []
    
    for task, methods_dict in fold_aucs_dict.items():
        # Get list of methods that have AUC values
        methods = list(methods_dict.keys())
        
        # Perform pairwise comparisons
        for method1, method2 in combinations(methods, 2):
            # Get AUC values for both methods
            aucs1 = methods_dict[method1]
            aucs2 = methods_dict[method2]
            
            # Ensure equal length (use only common folds)
            min_len = min(len(aucs1), len(aucs2))
            if min_len < 2:
                continue
                
            aucs1 = aucs1[:min_len]
            aucs2 = aucs2[:min_len]
            
            # Calculate mean AUCs
            mean_auc1 = np.mean(aucs1)
            mean_auc2 = np.mean(aucs2)
            diff_auc = mean_auc1 - mean_auc2
            
            # Perform statistical tests
            # Wilcoxon signed-rank test (non-parametric)
            try:
                _, p_wilcoxon = wilcoxon(aucs1, aucs2)
            except:
                p_wilcoxon = np.nan
                
            # Paired t-test (parametric)
            _, p_ttest = ttest_rel(aucs1, aucs2)
            
            # Store results
            results.append({
                'Task': task,
                'Method 1': method1,
                'Method 2': method2,
                'Mean AUC 1': mean_auc1,
                'Mean AUC 2': mean_auc2,
                'AUC Difference': diff_auc,
                'p-value (Wilcoxon)': p_wilcoxon,
                'p-value (t-test)': p_ttest,
                'Significant (p<0.05)': (p_wilcoxon < 0.05) if not np.isnan(p_wilcoxon) else (p_ttest < 0.05)
            })
    
    # Create DataFrame
    results_df = pd.DataFrame(results)
    
    return results_df

# Create a dictionary of tables
tables = {
    'V4': df
}

# Define pairwise comparisons
comparisons = [('skin-ADL', 'skin-H'), ('skin-ADNL', 'skin-ADL'), ('skin-ADNL', 'skin-H'), ('nares-AD', 'nares-H')]


# Set number of CV splits
n_cv_splits = 3

# Run analysis and plot
fig, feature_importances, pairwise_stats = plot_roc_curves_with_comparisons(tables, metadata, comparisons, n_splits=n_cv_splits)

# Display pairwise performance comparison table
print("\n" + "="*80)
print("Pairwise Performance Comparison of Methods")
print("="*80)
print(pairwise_stats.to_string(index=False, float_format=lambda x: f"{x:.4f}"))

# Display the top 10 most important features for each comparison and data type
for comparison, data_types in feature_importances.items():
    print(f"\n{'='*50}")
    print(f"Top 10 important features for {comparison}:")
    print(f"{'='*50}")
    
    for data_type, features_df in data_types.items():
        print(f"\n{data_type}:")
        print("-" * 40)
        top_features = features_df.sort_values('mean_importance', ascending=False).head(10)
        print(top_features[['mean_importance', 'std_importance']])

# Add supertitle to the plot
plt.suptitle('Random Forest Classifications by 16S V4 ASVs', fontsize=18, y=1.02)

plt.savefig('../Plots/Analysis_figures/Random_Forest/rf_ASV_skin-groups_nares-groups.png', dpi=600, bbox_inches='tight')
plt.show()



Pairwise Performance Comparison of Methods
Empty DataFrame
Columns: []
Index: []

Top 10 important features for skin-ADL_vs_skin-H:

V4:
----------------------------------------
                                                    mean_importance  \
GTGCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATT...         0.022228   
GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAAT...         0.014513   
GTGCCAGCAGCCGCGGTAATACGGAGGGTGCGAGCGTTAATCGGAAT...         0.014384   
GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTATCCGGAAT...         0.011816   
GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAAT...         0.011095   
GTGTCAGCAGCCGCGGTAATACGGAAGGTCCGGGCGTTATCCGGATT...         0.010805   
GTGCCAGCCGCCGCGGTAATACGTAGGTGGCAAGCGTTATCCGGAAT...         0.010598   
GTGCCAGCAGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATT...         0.010437   
GTGCCAGCCGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATT...         0.009562   
GTGCCAGCAGCCGCGGTAATACGGAAGGTCCAGGCGTTATCCGGATT...         0.009415   

                                       

In [171]:
# Add microbiome_type column based on group values
metadata['microbiome_type'] = metadata['group'].apply(lambda x: 'skin' if x.startswith('skin') else 'nares' if x.startswith('nares') else None)
metadata['microbiome_type'].value_counts()

microbiome_type
skin     305
nares    197
Name: count, dtype: int64

In [172]:
# Set overall styling for plots
sns.set_context("paper", font_scale=1.5)
sns.set_style("ticks")

# Custom function for group-stratified k-fold
def group_stratified_kfold(X, y, groups, n_splits=5, random_state=42):
    """
    Custom implementation of cross-validation that respects both groups and stratification
    
    Parameters:
    -----------
    X : DataFrame
        Feature matrix
    y : Series
        Target labels
    groups : Series
        Group labels for samples (e.g., pid)
    n_splits : int
        Number of folds
    random_state : int
        Random seed
    
    Returns:
    --------
    list of tuples
        Each tuple contains (train_indices, test_indices)
    """
    # Get unique groups
    unique_groups = np.unique(groups)
    np.random.seed(random_state)
    np.random.shuffle(unique_groups)
    
    # Create label distribution per group
    group_label_dist = {}
    for group in unique_groups:
        group_mask = groups == group
        group_y = y[group_mask]
        group_label_dist[group] = {label: np.sum(group_y == label) for label in np.unique(y)}
    
    # Initialize folds with empty lists
    folds = [[] for _ in range(n_splits)]
    
    # Track current distribution of labels in each fold
    fold_label_dist = [{label: 0 for label in np.unique(y)} for _ in range(n_splits)]
    
    # Sort groups by size (number of samples) in descending order to place larger groups first
    sorted_groups = sorted(unique_groups, key=lambda g: sum(groups == g), reverse=True)
    
    # Assign groups to folds
    for group in sorted_groups:
        # Calculate which fold would benefit most from this group
        # by minimizing the imbalance across all labels
        best_fold = 0
        min_imbalance = float('inf')
        
        group_size = sum(groups == group)
        
        for fold_idx in range(n_splits):
            # Calculate current imbalance if we add this group
            temp_fold_dist = fold_label_dist[fold_idx].copy()
            for label, count in group_label_dist[group].items():
                temp_fold_dist[label] += count
            
            # Calculate imbalance as variance of label proportions
            fold_size = sum(temp_fold_dist.values())
            if fold_size == 0:
                proportions = [0] * len(temp_fold_dist)
            else:
                proportions = [count / fold_size for count in temp_fold_dist.values()]
            
            imbalance = np.var(proportions) + fold_size / (sum(groups.shape) / n_splits)
            
            if imbalance < min_imbalance:
                min_imbalance = imbalance
                best_fold = fold_idx
        
        # Assign group to best fold
        folds[best_fold].extend(np.where(groups == group)[0])
        # Update fold distribution
        for label, count in group_label_dist[group].items():
            fold_label_dist[best_fold][label] += count
    
    # Create train/test indices
    train_test_indices = []
    for i in range(n_splits):
        test_idx = np.array(folds[i])
        train_idx = np.concatenate([folds[j] for j in range(n_splits) if j != i])
        train_test_indices.append((train_idx, test_idx))
    
    return train_test_indices

# Modified function to run group-stratified cross-validation with feature importance
def run_group_stratified_cv(X, y, groups, n_splits=5):
    # Get group-stratified folds
    folds = group_stratified_kfold(X, y, groups, n_splits=n_splits)
    
    # Initialize arrays to store results
    cv_results = []
    feature_importances = pd.DataFrame(index=X.columns)
    fold_aucs = []
    
    # Run cross-validation
    for i, (train_idx, test_idx) in enumerate(folds):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
        
        # Handle cases where train set might contain only one class
        if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
            print(f"Skipping fold {i+1} due to insufficient class representation")
            continue
            
        # Train classifier
        clf = RandomForestClassifier(n_estimators=1000, random_state=42)
        clf.fit(X_train, y_train)
        
        # Predict probabilities
        probas = clf.predict_proba(X_test)
        
        # Store feature importance for this fold
        feature_importances[f'fold_{i}'] = clf.feature_importances_
        
        # Store results
        fpr, tpr, _ = roc_curve(y_test, probas[:, 1])
        roc_auc = auc(fpr, tpr)
        fold_aucs.append(roc_auc)
        
        cv_results.append({
            'y_true': y_test,
            'y_proba': probas[:, 1],
            'fpr': fpr,
            'tpr': tpr,
            'auc': roc_auc,
            'fold': i
        })
    
    # Calculate mean feature importance across folds
    feature_importances['mean_importance'] = feature_importances.mean(axis=1)
    feature_importances['std_importance'] = feature_importances.std(axis=1)
    feature_importances = feature_importances.sort_values('mean_importance', ascending=False)
    
    return cv_results, feature_importances, fold_aucs

# Function to compute and plot ROC curves with error bars plus perform pairwise comparisons
def plot_roc_curves_with_comparisons(tables_dict, metadata, pair_comparisons, n_splits=5):
    # Separate comparisons
    comparisons = [("skin", "nares")]

    fig, axs = plt.subplots(1, 1, figsize=(6, 6), sharey=True)

    color_map = {
        'skin_vs_nares': 'black',
        'skin-ADNL_vs_skin-H': 'purple'
    }


    all_feature_importances = {}
    all_fold_aucs = {}

    for label1, label2 in pair_comparisons:
        comparison_key = f'{label1}_vs_{label2}'
        is_skin = (label1.startswith("skin") and label2.startswith("skin"))
        ax = axs[0] if is_skin else axs[1]

        ax.set_title("Skin Comparisons" if is_skin else "Nares Comparison", fontsize=16)
        all_feature_importances[comparison_key] = {}
        all_fold_aucs[comparison_key] = {}

        for table_name, table in tables_dict.items():
            meta_subset = metadata[metadata['microbiome_type'].isin([label1, label2])]
            common_samples = table.index.intersection(meta_subset.index)
            X = table.loc[common_samples]
            meta_filtered = meta_subset.loc[common_samples]

            if len(common_samples) < 10:
                print(f"Skipping {table_name} for {label1} vs {label2}: insufficient samples ({len(common_samples)})")
                continue

            y = meta_filtered['microbiome_type'].map({label1: 0, label2: 1})
            groups = meta_filtered['pid']
            cv_results, feature_imp, fold_aucs = run_group_stratified_cv(X, y, groups, n_splits=n_splits)
            all_feature_importances[comparison_key][table_name] = feature_imp
            all_fold_aucs[comparison_key][table_name] = fold_aucs

            if len(cv_results) < 2:
                print(f"Skipping {table_name} for {label1} vs {label2}: CV returned insufficient results")
                continue

            mean_fpr = np.linspace(0, 1, 100)
            tprs, aucs = [], []
            for result in cv_results:
                tprs.append(interp(mean_fpr, result['fpr'], result['tpr']))
                tprs[-1][0] = 0.0
                aucs.append(result['auc'])

            mean_tpr = np.mean(tprs, axis=0)
            mean_tpr[-1] = 1.0
            std_tpr = np.std(tprs, axis=0)
            tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
            tprs_lower = np.maximum(mean_tpr - std_tpr, 0)

            # Color based on type
            plot_color = color_map.get(comparison_key) if is_skin else color_map.get(comparison_key)

            ax.plot(mean_fpr, mean_tpr, lw=2,
                    label=f'{label1} vs {label2} (AUC = {np.mean(aucs):.2f} ± {np.std(aucs):.2f})',
                    color=plot_color)
            ax.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.3, color=plot_color)

    for ax in axs:
        ax.plot([0, 1], [0, 1], 'k--', lw=1)
        ax.set_xlabel('False Positive Rate', fontsize=14)
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.tick_params(axis='both', which='major', labelsize=12)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.legend(loc='lower right', fontsize=10)

    axs[0].set_ylabel('True Positive Rate', fontsize=14)
    plt.tight_layout()

    pairwise_comparisons = compute_pairwise_comparisons(all_fold_aucs)
    return fig, all_feature_importances, pairwise_comparisons



# Function to perform pairwise statistical tests
def compute_pairwise_comparisons(fold_aucs_dict):
    """
    Perform pairwise statistical tests between methods for each task
    
    Parameters:
    -----------
    fold_aucs_dict : dict
        Dictionary with fold-wise AUC values for each method
    
    Returns:
    --------
    DataFrame
        Table with pairwise comparisons and p-values
    """
    results = []
    
    for task, methods_dict in fold_aucs_dict.items():
        # Get list of methods that have AUC values
        methods = list(methods_dict.keys())
        
        # Perform pairwise comparisons
        for method1, method2 in combinations(methods, 2):
            # Get AUC values for both methods
            aucs1 = methods_dict[method1]
            aucs2 = methods_dict[method2]
            
            # Ensure equal length (use only common folds)
            min_len = min(len(aucs1), len(aucs2))
            if min_len < 2:
                continue
                
            aucs1 = aucs1[:min_len]
            aucs2 = aucs2[:min_len]
            
            # Calculate mean AUCs
            mean_auc1 = np.mean(aucs1)
            mean_auc2 = np.mean(aucs2)
            diff_auc = mean_auc1 - mean_auc2
            
            # Perform statistical tests
            # Wilcoxon signed-rank test (non-parametric)
            try:
                _, p_wilcoxon = wilcoxon(aucs1, aucs2)
            except:
                p_wilcoxon = np.nan
                
            # Paired t-test (parametric)
            _, p_ttest = ttest_rel(aucs1, aucs2)
            
            # Store results
            results.append({
                'Task': task,
                'Method 1': method1,
                'Method 2': method2,
                'Mean AUC 1': mean_auc1,
                'Mean AUC 2': mean_auc2,
                'AUC Difference': diff_auc,
                'p-value (Wilcoxon)': p_wilcoxon,
                'p-value (t-test)': p_ttest,
                'Significant (p<0.05)': (p_wilcoxon < 0.05) if not np.isnan(p_wilcoxon) else (p_ttest < 0.05)
            })
    
    # Create DataFrame
    results_df = pd.DataFrame(results)
    
    return results_df

# Create a dictionary of tables
tables = {
    'V4': df
}

# Define pairwise comparisons
comparisons = [('skin', 'nares')]


# Set number of CV splits
n_cv_splits = 3

# Run analysis and plot
fig, feature_importances, pairwise_stats = plot_roc_curves_with_comparisons(tables, metadata, comparisons, n_splits=n_cv_splits)

# Display pairwise performance comparison table
print("\n" + "="*80)
print("Pairwise Performance Comparison of Methods")
print("="*80)
print(pairwise_stats.to_string(index=False, float_format=lambda x: f"{x:.4f}"))

# Display the top 10 most important features for each comparison and data type
for comparison, data_types in feature_importances.items():
    print(f"\n{'='*50}")
    print(f"Top 10 important features for {comparison}:")
    print(f"{'='*50}")
    
    for data_type, features_df in data_types.items():
        print(f"\n{data_type}:")
        print("-" * 40)
        top_features = features_df.sort_values('mean_importance', ascending=False).head(10)
        print(top_features[['mean_importance', 'std_importance']])

# Add supertitle to the plot
plt.suptitle('Random Forest Classifications by 16S V4 ASVs', fontsize=18, y=1.02)

plt.savefig('../Plots/Analysis_figures/Random_Forest/rf_ASV_skin_vs_nares.png', dpi=600, bbox_inches='tight')
plt.show()


No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.



Pairwise Performance Comparison of Methods
Empty DataFrame
Columns: []
Index: []

Top 10 important features for skin_vs_nares:

V4:
----------------------------------------
                                                    mean_importance  \
GTGCCAGCAGCCGCGGTAATACGTAGGTGACAAGCGTTGTCCGGATT...         0.055981   
GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAAT...         0.050087   
GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTATCCGGAAT...         0.033611   
GTGCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATT...         0.030663   
GTGCCAGCAGCCGCGGTAATACGGAGGGTGCGAGCGTTAATCGGAAT...         0.023109   
GTGCCAGCAGCCGCGGTAATACGTAGGGTGCGAGCGTTATCCGGAAT...         0.022101   
GTGCCAGCCGCCGCGGTAATACGTAGGTGGCAAGCGTTATCCGGAAT...         0.020351   
GTGCCAGCAGCCGCGGTAATACAGAGGGTGCGAGCGTTAATCGGAAT...         0.019715   
GTGCCAGCCGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATT...         0.018544   
GTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTAATCGGAAT...         0.018052   

                                            

In [173]:
# Save or print the top 10 ASVs per comparison
top_asv_summary = []

for comparison_key, method_dict in feature_importances.items():
    print(f"\nTop 10 ASVs for comparison: {comparison_key}")
    print("=" * 60)

    for method_name, imp_df in method_dict.items():
        top_asvs = imp_df.sort_values('mean_importance', ascending=False).head(10)
        print(f"\n{method_name} dataset:")
        print(top_asvs[['mean_importance', 'std_importance']])

        # Optional: store in a summary list for later export
        for feature_name, row in top_asvs.iterrows():
            top_asv_summary.append({
                'Comparison': comparison_key,
                'Method': method_name,
                'ASV': feature_name,
                'Mean Importance': row['mean_importance'],
                'Std Importance': row['std_importance']
            })

# Convert to DataFrame for CSV or further use
top_asv_df = pd.DataFrame(top_asv_summary)

# Save to CSV
top_asv_df.to_csv('../Plots/Analysis_figures/Random_Forest/top10_ASVs_per_comparison.csv', index=False)



Top 10 ASVs for comparison: skin_vs_nares

V4 dataset:
                                                    mean_importance  \
GTGCCAGCAGCCGCGGTAATACGTAGGTGACAAGCGTTGTCCGGATT...         0.055981   
GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAAT...         0.050087   
GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTATCCGGAAT...         0.033611   
GTGCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATT...         0.030663   
GTGCCAGCAGCCGCGGTAATACGGAGGGTGCGAGCGTTAATCGGAAT...         0.023109   
GTGCCAGCAGCCGCGGTAATACGTAGGGTGCGAGCGTTATCCGGAAT...         0.022101   
GTGCCAGCCGCCGCGGTAATACGTAGGTGGCAAGCGTTATCCGGAAT...         0.020351   
GTGCCAGCAGCCGCGGTAATACAGAGGGTGCGAGCGTTAATCGGAAT...         0.019715   
GTGCCAGCCGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATT...         0.018544   
GTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTAATCGGAAT...         0.018052   

                                                    std_importance  
GTGCCAGCAGCCGCGGTAATACGTAGGTGACAAGCGTTGTCCGGATT...        0.004162  
GTGCCAGCAGCCGCGGTAATACGT

In [174]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_curve, auc
from scipy.stats import wilcoxon, ttest_rel
from itertools import combinations
from scipy import interp
import warnings
warnings.filterwarnings('ignore')

sns.set_context("paper", font_scale=1.5)
sns.set_style("ticks")

def group_stratified_kfold(X, y, groups, n_splits=5, random_state=42):
    unique_groups = np.unique(groups)
    np.random.seed(random_state)
    np.random.shuffle(unique_groups)

    group_label_dist = {}
    for group in unique_groups:
        group_mask = groups == group
        group_y = y[group_mask]
        group_label_dist[group] = {label: np.sum(group_y == label) for label in np.unique(y)}

    folds = [[] for _ in range(n_splits)]
    fold_label_dist = [{label: 0 for label in np.unique(y)} for _ in range(n_splits)]
    sorted_groups = sorted(unique_groups, key=lambda g: sum(groups == g), reverse=True)

    for group in sorted_groups:
        best_fold, min_imbalance = 0, float('inf')
        for fold_idx in range(n_splits):
            temp_fold_dist = fold_label_dist[fold_idx].copy()
            for label, count in group_label_dist[group].items():
                temp_fold_dist[label] += count
            fold_size = sum(temp_fold_dist.values())
            proportions = [count / fold_size if fold_size else 0 for count in temp_fold_dist.values()]
            imbalance = np.var(proportions) + fold_size / (len(groups) / n_splits)
            if imbalance < min_imbalance:
                min_imbalance = imbalance
                best_fold = fold_idx
        folds[best_fold].extend(np.where(groups == group)[0])
        for label, count in group_label_dist[group].items():
            fold_label_dist[best_fold][label] += count

    train_test_indices = []
    for i in range(n_splits):
        test_idx = np.array(folds[i])
        train_idx = np.concatenate([folds[j] for j in range(n_splits) if j != i])
        train_test_indices.append((train_idx, test_idx))

    return train_test_indices

def run_group_stratified_cv(X, y, groups, n_splits=5):
    folds = group_stratified_kfold(X, y, groups, n_splits=n_splits)
    cv_results = []
    feature_importances = pd.DataFrame(index=X.columns)
    fold_aucs = []

    for i, (train_idx, test_idx) in enumerate(folds):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

        if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
            print(f"Skipping fold {i+1} due to insufficient class representation")
            continue

        clf = RandomForestClassifier(n_estimators=1000, random_state=42)
        clf.fit(X_train, y_train)
        probas = clf.predict_proba(X_test)
        feature_importances[f'fold_{i}'] = clf.feature_importances_

        fpr, tpr, _ = roc_curve(y_test, probas[:, 1])
        roc_auc = auc(fpr, tpr)
        fold_aucs.append(roc_auc)

        cv_results.append({
            'y_true': y_test,
            'y_proba': probas[:, 1],
            'fpr': fpr,
            'tpr': tpr,
            'auc': roc_auc,
            'fold': i
        })

    feature_importances['mean_importance'] = feature_importances.mean(axis=1)
    feature_importances['std_importance'] = feature_importances.std(axis=1)
    feature_importances = feature_importances.sort_values('mean_importance', ascending=False)
    return cv_results, feature_importances, fold_aucs

def plot_roc_curves_with_comparisons(tables_dict, metadata, pair_comparisons, n_splits=5):
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))

    color_map_skin = {
        'skin-ADL_vs_skin-H': 'green',
        'skin-ADNL_vs_skin-ADL': 'blue',
        'skin-ADNL_vs_skin-H': 'purple'
    }

    all_feature_importances = {}
    all_fold_aucs = {}

    for label1, label2 in pair_comparisons:
        comparison_key = f'{label1}_vs_{label2}'
        all_feature_importances[comparison_key] = {}
        all_fold_aucs[comparison_key] = {}

        for table_name, table in tables_dict.items():
            meta_subset = metadata[metadata['group'].isin([label1, label2])]
            common_samples = table.index.intersection(meta_subset.index)
            X = table.loc[common_samples]
            meta_filtered = meta_subset.loc[common_samples]

            if len(common_samples) < 10:
                print(f"Skipping {table_name} for {label1} vs {label2}: insufficient samples ({len(common_samples)})")
                continue

            y = meta_filtered['group'].map({label1: 0, label2: 1})
            groups = meta_filtered['pid']
            cv_results, feature_imp, fold_aucs = run_group_stratified_cv(X, y, groups, n_splits=n_splits)

            if len(cv_results) < 2:
                print(f"Skipping {table_name} for {label1} vs {label2}: CV returned insufficient results")
                continue

            all_feature_importances[comparison_key][table_name] = feature_imp
            all_fold_aucs[comparison_key][table_name] = fold_aucs

            mean_fpr = np.linspace(0, 1, 100)
            tprs, aucs = [], []
            for result in cv_results:
                tprs.append(interp(mean_fpr, result['fpr'], result['tpr']))
                tprs[-1][0] = 0.0
                aucs.append(result['auc'])

            mean_tpr = np.mean(tprs, axis=0)
            mean_tpr[-1] = 1.0
            std_tpr = np.std(tprs, axis=0)
            tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
            tprs_lower = np.maximum(mean_tpr - std_tpr, 0)

            ax.plot(mean_fpr, mean_tpr, lw=2,
                    label=f'{label1} vs {label2} (AUC = {np.mean(aucs):.2f} ± {np.std(aucs):.2f})',
                    color=color_map_skin.get(comparison_key, 'grey'))
            ax.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.3,
                            color=color_map_skin.get(comparison_key, 'grey'))

    ax.plot([0, 1], [0, 1], 'k--', lw=1)
    ax.set_xlabel('False Positive Rate', fontsize=14)
    ax.set_ylabel('True Positive Rate', fontsize=14)
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend(loc='lower right', fontsize=10)
    plt.tight_layout()

    pairwise_comparisons = compute_pairwise_comparisons(all_fold_aucs)
    return fig, all_feature_importances, pairwise_comparisons

def compute_pairwise_comparisons(fold_aucs_dict):
    results = []
    for task, methods_dict in fold_aucs_dict.items():
        methods = list(methods_dict.keys())
        for method1, method2 in combinations(methods, 2):
            aucs1, aucs2 = methods_dict[method1], methods_dict[method2]
            min_len = min(len(aucs1), len(aucs2))
            if min_len < 2:
                continue
            aucs1, aucs2 = aucs1[:min_len], aucs2[:min_len]
            try:
                _, p_wilcoxon = wilcoxon(aucs1, aucs2)
            except:
                p_wilcoxon = np.nan
            _, p_ttest = ttest_rel(aucs1, aucs2)
            results.append({
                'Task': task,
                'Method 1': method1,
                'Method 2': method2,
                'Mean AUC 1': np.mean(aucs1),
                'Mean AUC 2': np.mean(aucs2),
                'AUC Difference': np.mean(aucs1) - np.mean(aucs2),
                'p-value (Wilcoxon)': p_wilcoxon,
                'p-value (t-test)': p_ttest,
                'Significant (p<0.05)': (p_wilcoxon < 0.05) if not np.isnan(p_wilcoxon) else (p_ttest < 0.05)
            })
    return pd.DataFrame(results)

# Sample execution
tables = {'V4': df}
comparisons = [('skin-ADL', 'skin-H'), ('skin-ADNL', 'skin-ADL'), ('skin-ADNL', 'skin-H')]
n_cv_splits = 3

fig, feature_importances, pairwise_stats = plot_roc_curves_with_comparisons(tables, metadata, comparisons, n_splits=n_cv_splits)

print("\n" + "="*80)
print("Pairwise Performance Comparison of Methods")
print("="*80)
print(pairwise_stats.to_string(index=False, float_format=lambda x: f"{x:.4f}"))

for comparison, data_types in feature_importances.items():
    print(f"\n{'='*50}\nTop 10 important features for {comparison}:\n{'='*50}")
    for data_type, features_df in data_types.items():
        print(f"\n{data_type}:\n{'-'*40}")
        print(features_df[['mean_importance', 'std_importance']].head(10))

plt.suptitle('Random Forest Classifications by 16S V4 ASVs', fontsize=16, y=1.02)
plt.savefig('../Plots/Analysis_figures/Random_Forest/rf_ASV_skin-groups-only.png', dpi=600, bbox_inches='tight')
plt.show()



Pairwise Performance Comparison of Methods
Empty DataFrame
Columns: []
Index: []

Top 10 important features for skin-ADL_vs_skin-H:

V4:
----------------------------------------
                                                    mean_importance  \
GTGCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATT...         0.022228   
GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAAT...         0.014513   
GTGCCAGCAGCCGCGGTAATACGGAGGGTGCGAGCGTTAATCGGAAT...         0.014384   
GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTATCCGGAAT...         0.011816   
GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAAT...         0.011095   
GTGTCAGCAGCCGCGGTAATACGGAAGGTCCGGGCGTTATCCGGATT...         0.010805   
GTGCCAGCCGCCGCGGTAATACGTAGGTGGCAAGCGTTATCCGGAAT...         0.010598   
GTGCCAGCAGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATT...         0.010437   
GTGCCAGCCGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATT...         0.009562   
GTGCCAGCAGCCGCGGTAATACGGAAGGTCCAGGCGTTATCCGGATT...         0.009415   

                                       

In [175]:
# Set overall styling for plots
sns.set_context("paper", font_scale=1.5)
sns.set_style("ticks")

# Custom function for group-stratified k-fold
def group_stratified_kfold(X, y, groups, n_splits=5, random_state=42):
    """
    Custom implementation of cross-validation that respects both groups and stratification
    
    Parameters:
    -----------
    X : DataFrame
        Feature matrix
    y : Series
        Target labels
    groups : Series
        Group labels for samples (e.g., pid)
    n_splits : int
        Number of folds
    random_state : int
        Random seed
    
    Returns:
    --------
    list of tuples
        Each tuple contains (train_indices, test_indices)
    """
    # Get unique groups
    unique_groups = np.unique(groups)
    np.random.seed(random_state)
    np.random.shuffle(unique_groups)
    
    # Create label distribution per group
    group_label_dist = {}
    for group in unique_groups:
        group_mask = groups == group
        group_y = y[group_mask]
        group_label_dist[group] = {label: np.sum(group_y == label) for label in np.unique(y)}
    
    # Initialize folds with empty lists
    folds = [[] for _ in range(n_splits)]
    
    # Track current distribution of labels in each fold
    fold_label_dist = [{label: 0 for label in np.unique(y)} for _ in range(n_splits)]
    
    # Sort groups by size (number of samples) in descending order to place larger groups first
    sorted_groups = sorted(unique_groups, key=lambda g: sum(groups == g), reverse=True)
    
    # Assign groups to folds
    for group in sorted_groups:
        # Calculate which fold would benefit most from this group
        # by minimizing the imbalance across all labels
        best_fold = 0
        min_imbalance = float('inf')
        
        group_size = sum(groups == group)
        
        for fold_idx in range(n_splits):
            # Calculate current imbalance if we add this group
            temp_fold_dist = fold_label_dist[fold_idx].copy()
            for label, count in group_label_dist[group].items():
                temp_fold_dist[label] += count
            
            # Calculate imbalance as variance of label proportions
            fold_size = sum(temp_fold_dist.values())
            if fold_size == 0:
                proportions = [0] * len(temp_fold_dist)
            else:
                proportions = [count / fold_size for count in temp_fold_dist.values()]
            
            imbalance = np.var(proportions) + fold_size / (sum(groups.shape) / n_splits)
            
            if imbalance < min_imbalance:
                min_imbalance = imbalance
                best_fold = fold_idx
        
        # Assign group to best fold
        folds[best_fold].extend(np.where(groups == group)[0])
        # Update fold distribution
        for label, count in group_label_dist[group].items():
            fold_label_dist[best_fold][label] += count
    
    # Create train/test indices
    train_test_indices = []
    for i in range(n_splits):
        test_idx = np.array(folds[i])
        train_idx = np.concatenate([folds[j] for j in range(n_splits) if j != i])
        train_test_indices.append((train_idx, test_idx))
    
    return train_test_indices

# Modified function to run group-stratified cross-validation with feature importance
def run_group_stratified_cv(X, y, groups, n_splits=5):
    # Get group-stratified folds
    folds = group_stratified_kfold(X, y, groups, n_splits=n_splits)
    
    # Initialize arrays to store results
    cv_results = []
    feature_importances = pd.DataFrame(index=X.columns)
    fold_aucs = []
    
    # Run cross-validation
    for i, (train_idx, test_idx) in enumerate(folds):
        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
        y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]
        
        # Handle cases where train set might contain only one class
        if len(np.unique(y_train)) < 2 or len(np.unique(y_test)) < 2:
            print(f"Skipping fold {i+1} due to insufficient class representation")
            continue
            
        # Train classifier
        clf = RandomForestClassifier(n_estimators=1000, random_state=42)
        clf.fit(X_train, y_train)
        
        # Predict probabilities
        probas = clf.predict_proba(X_test)
        
        # Store feature importance for this fold
        feature_importances[f'fold_{i}'] = clf.feature_importances_
        
        # Store results
        fpr, tpr, _ = roc_curve(y_test, probas[:, 1])
        roc_auc = auc(fpr, tpr)
        fold_aucs.append(roc_auc)
        
        cv_results.append({
            'y_true': y_test,
            'y_proba': probas[:, 1],
            'fpr': fpr,
            'tpr': tpr,
            'auc': roc_auc,
            'fold': i
        })
    
    # Calculate mean feature importance across folds
    feature_importances['mean_importance'] = feature_importances.mean(axis=1)
    feature_importances['std_importance'] = feature_importances.std(axis=1)
    feature_importances = feature_importances.sort_values('mean_importance', ascending=False)
    
    return cv_results, feature_importances, fold_aucs

# Function to compute and plot ROC curves with error bars plus perform pairwise comparisons
def plot_roc_curves_with_comparisons(tables_dict, metadata, pair_comparisons, n_splits=5):
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))

    color_map = {
        'skin-ADL_vs_skin-H': '#4343a3',
        'skin-ADNL_vs_skin-ADL': '#6ab2bd',
        'skin-ADNL_vs_skin-H': '#dd7966',
        'nares-AD_vs_nares-H': 'orange'
    }

    all_feature_importances = {}
    all_fold_aucs = {}

    for label1, label2 in pair_comparisons:
        comparison_key = f'{label1}_vs_{label2}'
        all_feature_importances[comparison_key] = {}
        all_fold_aucs[comparison_key] = {}

        for table_name, table in tables_dict.items():
            meta_subset = metadata[metadata['group'].isin([label1, label2])]
            common_samples = table.index.intersection(meta_subset.index)
            X = table.loc[common_samples]
            meta_filtered = meta_subset.loc[common_samples]

            if len(common_samples) < 10:
                print(f"Skipping {table_name} for {label1} vs {label2}: insufficient samples ({len(common_samples)})")
                continue

            y = meta_filtered['group'].map({label1: 0, label2: 1})
            groups = meta_filtered['pid']
            cv_results, feature_imp, fold_aucs = run_group_stratified_cv(X, y, groups, n_splits=n_splits)
            all_feature_importances[comparison_key][table_name] = feature_imp
            all_fold_aucs[comparison_key][table_name] = fold_aucs

            if len(cv_results) < 2:
                print(f"Skipping {table_name} for {label1} vs {label2}: CV returned insufficient results")
                continue

            mean_fpr = np.linspace(0, 1, 100)
            tprs, aucs = [], []
            for result in cv_results:
                tprs.append(np.interp(mean_fpr, result['fpr'], result['tpr']))
                tprs[-1][0] = 0.0
                aucs.append(result['auc'])

            mean_tpr = np.mean(tprs, axis=0)
            mean_tpr[-1] = 1.0
            std_tpr = np.std(tprs, axis=0)
            tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
            tprs_lower = np.maximum(mean_tpr - std_tpr, 0)

            plot_color = color_map.get(comparison_key, 'gray')

            ax.plot(mean_fpr, mean_tpr, lw=2,
                    label=f'{label1} vs {label2} (AUC = {np.mean(aucs):.2f} ± {np.std(aucs):.2f})',
                    color=plot_color)
            ax.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.3, color=plot_color)

    ax.plot([0, 1], [0, 1], 'k--', lw=1)
    ax.set_xlabel('False Positive Rate', fontsize=14)
    ax.set_ylabel('True Positive Rate', fontsize=14)
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend(loc='lower right', fontsize=10)
    plt.tight_layout()

    pairwise_comparisons = compute_pairwise_comparisons(all_fold_aucs)
    return fig, all_feature_importances, pairwise_comparisons


# Function to perform pairwise statistical tests
def compute_pairwise_comparisons(fold_aucs_dict):
    """
    Perform pairwise statistical tests between methods for each task
    
    Parameters:
    -----------
    fold_aucs_dict : dict
        Dictionary with fold-wise AUC values for each method
    
    Returns:
    --------
    DataFrame
        Table with pairwise comparisons and p-values
    """
    results = []
    
    for task, methods_dict in fold_aucs_dict.items():
        # Get list of methods that have AUC values
        methods = list(methods_dict.keys())
        
        # Perform pairwise comparisons
        for method1, method2 in combinations(methods, 2):
            # Get AUC values for both methods
            aucs1 = methods_dict[method1]
            aucs2 = methods_dict[method2]
            
            # Ensure equal length (use only common folds)
            min_len = min(len(aucs1), len(aucs2))
            if min_len < 2:
                continue
                
            aucs1 = aucs1[:min_len]
            aucs2 = aucs2[:min_len]
            
            # Calculate mean AUCs
            mean_auc1 = np.mean(aucs1)
            mean_auc2 = np.mean(aucs2)
            diff_auc = mean_auc1 - mean_auc2
            
            # Perform statistical tests
            # Wilcoxon signed-rank test (non-parametric)
            try:
                _, p_wilcoxon = wilcoxon(aucs1, aucs2)
            except:
                p_wilcoxon = np.nan
                
            # Paired t-test (parametric)
            _, p_ttest = ttest_rel(aucs1, aucs2)
            
            # Store results
            results.append({
                'Task': task,
                'Method 1': method1,
                'Method 2': method2,
                'Mean AUC 1': mean_auc1,
                'Mean AUC 2': mean_auc2,
                'AUC Difference': diff_auc,
                'p-value (Wilcoxon)': p_wilcoxon,
                'p-value (t-test)': p_ttest,
                'Significant (p<0.05)': (p_wilcoxon < 0.05) if not np.isnan(p_wilcoxon) else (p_ttest < 0.05)
            })
    
    # Create DataFrame
    results_df = pd.DataFrame(results)
    
    return results_df

# Create a dictionary of tables
tables = {
    'V4': df
}

# Define pairwise comparisons
comparisons = [('nares-AD', 'nares-H')]


# Set number of CV splits
n_cv_splits = 3

# Run analysis and plot
fig, feature_importances, pairwise_stats = plot_roc_curves_with_comparisons(tables, metadata, comparisons, n_splits=n_cv_splits)

# Display pairwise performance comparison table
print("\n" + "="*80)
print("Pairwise Performance Comparison of Methods")
print("="*80)
print(pairwise_stats.to_string(index=False, float_format=lambda x: f"{x:.4f}"))

# Display the top 10 most important features for each comparison and data type
for comparison, data_types in feature_importances.items():
    print(f"\n{'='*50}")
    print(f"Top 10 important features for {comparison}:")
    print(f"{'='*50}")
    
    for data_type, features_df in data_types.items():
        print(f"\n{data_type}:")
        print("-" * 40)
        top_features = features_df.sort_values('mean_importance', ascending=False).head(10)
        print(top_features[['mean_importance', 'std_importance']])

# Add supertitle to the plot
plt.suptitle('Random Forest Classification by 16S V4 ASVs', fontsize=18, y=1.02)

plt.savefig('../Plots/Analysis_figures/Random_Forest/rf_ASV__nares-groups-only.png', dpi=600, bbox_inches='tight')
plt.show()



Pairwise Performance Comparison of Methods
Empty DataFrame
Columns: []
Index: []

Top 10 important features for nares-AD_vs_nares-H:

V4:
----------------------------------------
                                                    mean_importance  \
GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTATCCGGAAT...         0.027070   
GTGCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATT...         0.022989   
GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAAT...         0.021538   
GTGCCAGCAGCCGCGGTAATACGTAGGTGACAAGCGTTGTCCGGATT...         0.021311   
GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAAT...         0.019339   
GTGCCAGCAGCCGCGGTAATACGGAGGGTGCGAGCGTTAATCGGAAT...         0.017363   
GTGCCAGCAGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATT...         0.017226   
GTGCCAGCCGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATT...         0.015726   
GTGCCAGCCGCCGCGGTAATACGTAGGTGGCAAGCGTTATCCGGAAT...         0.012700   
GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAAT...         0.012563   

                                      