# Feature Importance Rankings from Random Forest Classification

In [1]:
# Standard library
import warnings
import logging
from itertools import combinations
from functools import reduce
import re

# Scientific computing
import numpy as np
import pandas as pd
from numpy import array
import scipy
import scipy.stats as ss
from scipy import interp
from scipy.stats import wilcoxon, ttest_rel
from scipy import stats

# Visualization
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns
import matplotlib.gridspec as gridspec
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
from matplotlib.colors import LinearSegmentedColormap

# scikit-bio
from skbio.stats.distance import permanova

# BIOM format
import biom
from biom import load_table

# Scikit-learn
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, confusion_matrix, classification_report,
    roc_curve, auc, RocCurveDisplay
)
from sklearn.model_selection import GroupKFold, StratifiedKFold
from sklearn.preprocessing import label_binarize


In [2]:
# Define file paths: one feature-importance CSV per classification comparison
files = {
    "skin_vs_nares": "../Data/RF_Feature_Importances/feature_importance_skin_vs_nares_ASV-name_known.csv",
    "skin_ADL_vs_H": "../Data/RF_Feature_Importances/feature_importance_skin-ADL_vs_skin-H_ASV-name_known.csv",
    "skin_ADNL_vs_H": "../Data/RF_Feature_Importances/feature_importance_skin-ADNL_vs_skin-H_ASV-name_known.csv",
    "skin_ADNL_vs_ADL": "../Data/RF_Feature_Importances/feature_importance_skin-ADNL_vs_skin-ADL_ASV-name_known.csv",
    "nares_AD_vs_H": "../Data/RF_Feature_Importances/feature_importance_nares_AD_vs_H_ASV-name_known.csv"
}

# Read each file and convert its row order into a 1-based rank.
# NOTE(review): the rank is purely positional, so this presumably relies on
# every CSV already being sorted from most to least important - confirm upstream.
rank_dfs = {}

for key, path in files.items():
    df = pd.read_csv(path)
    df["ASV_Name"] = df["ASV_Name"].astype(str)

    # Drop rows whose (whitespace-stripped) ASV_Name starts with 'g___'.
    # This happens before ranking, so the ranks of the remaining rows are
    # contiguous. .copy() makes the filtered frame independent of the original,
    # so adding the "rank" column below does not raise SettingWithCopyWarning.
    df = df[~df["ASV_Name"].str.strip().str.startswith('g___')].copy()

    df["rank"] = range(1, len(df) + 1)
    df = df[["ASV_Name", "rank"]]
    rank_dfs[key] = df

# List of comparisons to include
keys_to_merge = ["skin_vs_nares", "skin_ADL_vs_H", "skin_ADNL_vs_H", "skin_ADNL_vs_ADL", "nares_AD_vs_H"]

# Outer-merge all rank tables on ASV_Name; each comparison becomes one column
merged = reduce(
    lambda left, right: pd.merge(left, right, on="ASV_Name", how="outer"),
    [rank_dfs[key].rename(columns={"rank": key}) for key in keys_to_merge]
)

# A feature absent from a comparison gets a rank one worse than the worst
# rank observed in any comparison (max_rank + 1)
max_rank = max(ranks["rank"].max() for ranks in rank_dfs.values())
merged = merged.set_index("ASV_Name").fillna(max_rank + 1)


In [3]:
# Pool the 10 best-ranked features of every comparison into one set
top_genera = set()
for comparison in keys_to_merge:
    # Stable sort: rows that tie on rank keep their existing order in `merged`
    top10_in_col = merged.sort_values(comparison, kind='stable').head(10).index
    top_genera.update(top10_in_col)

# Subset merged to the pooled features. sorted() pins the row order: iterating
# a set of strings can yield a different order on every interpreter run, which
# would otherwise leak into the order of tied rows below.
top10 = merged.loc[sorted(top_genera)].copy()

# Order by total summed rank (for nicer display); the stable sort keeps tied
# features in the alphabetical order established above
top10['sum'] = top10.sum(axis=1)
top10 = top10.sort_values('sum', kind='stable').drop(columns='sum')

# Replace specific genus names
top10 = top10.rename(index={
    'g__F0422_ASV-1': 'g__Veillonella_F0422_ASV-1',
})

# Remove 'g__' and turn '_ASV' into ' ASV' in the index names
# (regex=False: both patterns are plain literals)
top10.index = top10.index.str.replace('g__', '', regex=False)
top10.index = top10.index.str.replace('_ASV', ' ASV', regex=False)

top10

Unnamed: 0_level_0,skin_vs_nares,skin_ADL_vs_H,skin_ADNL_vs_H,skin_ADNL_vs_ADL,nares_AD_vs_H
ASV_Name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Streptococcus ASV-1,4,1,4,1,1
Staphylococcus ASV-1,3,7,1,5,2
Cutibacterium ASV-1,7,8,3,13,6
Staphylococcus ASV-2,5,19,2,18,11
Cutibacterium ASV-2,9,17,6,16,10
Micrococcus ASV-1,8,4,10,15,21
Haemophilus_D_734546 ASV-1,6,30,15,10,4
Streptococcus ASV-2,15,16,37,8,7
Prevotella ASV-1,22,6,12,29,18
Brachybacterium ASV-1,23,13,5,30,26


In [4]:
# Location of the tab-separated sample metadata
metadata_path = '../Metadata/16S_AD_South-Africa_metadata_subset.tsv'

# Short group label for each case_type value; a case_type missing from this
# table ends up as NaN in the 'group' column
case_type_to_group = {
    'case-lesional_skin': 'skin-ADL',
    'case-nonlesional_skin': 'skin-ADNL', 
    'control-nonlesional_skin': 'skin-H',
    'case-anterior_nares': 'nares-AD',
    'control-anterior_nares': 'nares-H'
}

# Load, remove underscores from the sample IDs, index by sample ID, and
# append the simplified 'group' column
metadata = (
    pd.read_csv(metadata_path, sep='\t')
    .assign(**{'#sample-id': lambda frame: frame['#sample-id'].str.replace('_', '')})
    .set_index('#sample-id')
    .assign(group=lambda frame: frame['case_type'].map(case_type_to_group))
)

metadata

Unnamed: 0_level_0,PlateNumber,PlateLocation,i5,i5Sequence,i7,i7Sequence,identifier,Sequence,Plate ID,Well location,...,specimen,age_months,sex,enrolment_date,enrolment_season,hiv_exposure,hiv_status,household_size,o_scorad,group
#sample-id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Ca009STL,1,A1,SA501,ATCGTACG,SA701,CGAGAGTT,SA701SA501,CGAGAGTT-ATCGTACG,1.010000e+21,A1,...,skin,24.0,male,4/16/2015,Autumn,Unexposed,negative,4.0,40,skin-ADL
900221,1,B1,SA502,ACTATCTG,SA701,CGAGAGTT,SA701SA502,CGAGAGTT-ACTATCTG,1.010000e+21,B1,...,skin,9.0,female,8/11/2015,Winter,Unexposed,negative,7.0,34,skin-ADL
Ca010EBL,1,C1,SA503,TAGCGAGT,SA701,CGAGAGTT,SA701SA503,CGAGAGTT-TAGCGAGT,1.010000e+21,C1,...,skin,24.0,female,11/20/2014,Spring,Unexposed,negative,7.0,21,skin-ADL
900460,1,D1,SA504,CTGCGTGT,SA701,CGAGAGTT,SA701SA504,CGAGAGTT-CTGCGTGT,1.010000e+21,D1,...,skin,18.0,female,9/23/2015,Spring,Unexposed,,4.0,40,skin-ADL
900051,1,E1,SA505,TCATCGAG,SA701,CGAGAGTT,SA701SA505,CGAGAGTT-TCATCGAG,1.010000e+21,E1,...,skin,31.0,male,4/21/2015,Autumn,Unexposed,negative,7.0,41,skin-ADL
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
900401,5,C12,SB503,AGAGTCAC,SB712,CGTAGCGA,SB712SB503,CGTAGCGA-AGAGTCAC,1.010000e+21,C12,...,skin,21.0,female,9/17/2015,Spring,Exposed,negative,12.0,38,skin-ADNL
900402,6,B4,SA502,ACTATCTG,SB704,TCTCTATG,SB704SA502,TCTCTATG-ACTATCTG,1.010000e+21,B4,...,nasal,21.0,,,,,,,,nares-AD
Ca006ONL,6,F1,SA506,CGTGAGTG,SB701,CTCGACTT,SB701SA506,CTCGACTT-CGTGAGTG,1.010000e+21,F1,...,skin,35.0,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,skin-ADL
Ca006ONNL,6,F2,SA506,CGTGAGTG,SB702,CGAAGTAT,SB702SA506,CGAAGTAT-CGTGAGTG,1.010000e+21,F2,...,skin,35.0,female,3/25/2015,Autumn,Unexposed,negative,3.0,34,skin-ADNL


In [5]:
# Read in the feature table used in the RF analysis
biom_path = '../Data/Tables/Count_Tables/2_209766_feature_table_dedup.biom'

biom_tbl = load_table(biom_path)
# Transposed so that samples are rows and ASV sequences are columns
df = pd.DataFrame(biom_tbl.to_dataframe().T)

# Delete the '15564.' prefix from the sample IDs. regex=False makes the '.' a
# literal dot; with regex matching (the default before pandas 2.0) it would
# match any character following '15564'.
df.index = df.index.str.replace('15564.', '', regex=False)

# Convert to relative abundance by dividing each row by its row sum
df_dense = df.div(df.sum(axis=1), axis=0)
df = pd.DataFrame(df_dense.values, index=df_dense.index, columns=df_dense.columns)  # Force dense

df

Unnamed: 0,GTGCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCCGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCCGCCGCGGTAATACGTAGGGTGCAAGCGTTGTCCGGAATTACTGGGCGTAAAGAGCTCGTAGGTGGTTTGTCACGTCGTCTGTGAAATTCCA,GTGCCAGCAGCCGCGGTAATACGTAGGGTGCAAGCGTTAATCGGAATTATTGGGCGTAAAGCGAGTGCAGACGGTTACTTAAGCCAGATGTGAAATCCCC,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGAATTATTGGGCGTAAAGCGCGCGCAGGCGGTTTCTTAAGTCTGATGTGAAAGCCCC,GTGCCAGCAGCCGCGGTGATACGTAGGGTGCGAGCGTTGTCCGGATTTATTGGGCGTAAAGGGCTCGTAGGTGGTTGATCGCGTCGGAAGTGTAATCTTG,GTGCCAGCAGCCGCGGTAATACGTAGGGTCCAAGCGTTAATCGGAATTACTGGGCGTAAAGCGTGCGCAGGCGGTTGTGCAAGACCGATGTGAAATCCCC,GTGCCAGCCGCCGCGGTAATACGTAGGTGGCAAGCGTTGTCCGGATTTATTGGGCGTAAAGGGAGCGCAGGTGGTTTCTTAAGTCTGATGTGAAAGCCCA,GTGCCAGCCGCCGCGGTAATACGGAAGGTCCAGGCGTTATCCGGATTTATTGGGTTTAAAGGGAGCGTAGGCGGATTATTAAGTCAGTGGTGAAAGACGG,...,GTGCCAGCCGCCGCGGTAATACGTAGGGGGCAAGCGTTATCCGGATTTACTGGGTGTAAAGGGAGCGTAGACGGCGCAGCAAGTCTGATGTGAAAGGCAG,GTGCCAGCAGCCGCGGTAAGACAGAGGGTGCAAACGTTGCTCGGAATCACTGGGCGTAAAGGGCGTGTAGGCGGGAGAGAAAGTCGGGCGTGAAATCCCT,GTGCCAGCCGCGGTAATACGTAGGGGGCTAGCGTTGTCCGGAATCACTGGGCGTAAAGGGTTCGCAGGCGGAAATGCAAGTCAGGTGTAAAAGGCAGTAG,GTGCCAGCAGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCCGCCGCGGTAATACGTAGGGCGCGAGCGTTGTCCGGAATTATTGGGCGTAAAGAGCTTGTAGGCGGTTTGTTGCGTCTGCTGTGAAAGACCG,GTGCCAGCAGCCGCGGTAATACGGAGGGTGCAAGCGTTATCCGGAATCATTGGGTTTAAAGGGTCCGCAGGCGGATTTATAAGTCAGTGGTGAAAGCCTA,GTGCCAGCAGCCGCGGTAATACGTAGGTGGCGAGCGTTGTCCGGAATTACTGGGTGTAAAGGGCGTGTAGGCGGGAAGGTAAGTCAGATGTGAAATACCG,GTGCCAGCCGCCGCGGTAATACGGAGGATGCGAGCGTTATTCGGAATCATTGGGTTTAAAGGGTCTGTAGGCGGGCTATTAAGTCAGAGGTGAAAGGTTT,GTGCCAGCCGCCGCGGTAAGACGAAGGGGGCTAGCGTTGTTCGGAATTACTGGGCGTAAAGCGCGTGCAGGCGGTTATCCAAGTCGGGTGTGAAAGCCTT,GTCCAGCAGCCGCGGTAATACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCG
CAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTGT
900344,0.529602,0.328848,0.061356,0.044133,0.011841,0.008073,0.004306,0.004306,0.003229,0.001615,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900459,0.068445,0.061485,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900221,0.000744,0.000000,0.000000,0.000000,0.000000,0.000000,0.000541,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900570,0.058922,0.000000,0.000000,0.000000,0.001212,0.000000,0.001666,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900092,0.623945,0.342909,0.011852,0.006428,0.000603,0.000000,0.000000,0.000000,0.000000,0.001406,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
900294,0.017937,0.010463,0.000000,0.000000,0.000000,0.000000,0.144993,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900097,0.011594,0.000000,0.000000,0.000000,0.000000,0.000000,0.015942,0.000000,0.000000,0.000000,...,0.003865,0.002415,0.000483,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
900498,0.015707,0.017801,0.000000,0.000000,0.000000,0.000000,0.035602,0.000000,0.014660,0.000000,...,0.000000,0.000000,0.000000,0.015707,0.010471,0.008377,0.000000,0.000000,0.000000,0.000000
900276,0.000000,0.000000,0.041322,0.000000,0.000000,0.000000,0.207989,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.015152,0.004132,0.002755,0.001377


In [6]:
# # Truncate function to remove very white parts
# def truncate_colormap(cmap, minval=0.0, maxval=0.8, n=100):
#     new_cmap = LinearSegmentedColormap.from_list(
#         f'trunc({cmap.name},{minval:.2f},{maxval:.2f})',
#         cmap(np.linspace(minval, maxval, n))
#     )
#     return new_cmap

# # Your custom colormaps (non-reversed!)
# colormaps = {
#     'skin_vs_nares': 'Greys',
#     'skin_ADL_vs_H': 'Blues',
#     'skin_ADNL_vs_H': 'Greens',
#     'skin_ADNL_vs_ADL': 'Purples',
#     'nares_AD_vs_H': 'Oranges'
# }

# # Transpose so comparisons are rows, features are columns
# top10_t = top10.T

# # Create the figure
# fig, ax = plt.subplots(figsize=(12, 4))  # wide figure

# # Normalize rows
# normed_data = pd.DataFrame(index=top10_t.index, columns=top10_t.columns)

# for comparison in top10_t.index:
#     row = top10_t.loc[comparison]
#     normed = (row - row.min()) / (row.max() - row.min())
#     normed_data.loc[comparison] = normed

# # Plot manually
# for idx, comparison in enumerate(top10_t.index):
#     base_cmap = plt.get_cmap(colormaps[comparison])
#     cmap = truncate_colormap(base_cmap, 0.0, 0.8)  # Cut off super light top
#     row_data = normed_data.loc[comparison].astype(float)
    
#     for jdx, value in enumerate(row_data):
#         corrected_value = 1 - value  # <-- flip! low values colorful, high values light
#         color = cmap(corrected_value)
        
#         rect = plt.Rectangle((jdx, idx), 1, 1, facecolor=color, edgecolor='white', linewidth=0.5)
#         ax.add_patch(rect)
        
#         original_val = top10_t.loc[comparison].iloc[jdx]
#         ax.text(jdx + 0.5, idx + 0.5, f"{int(original_val)}",
#                 ha='center', va='center', color='black', fontsize=8)

# # Set axis limits
# ax.set_xlim(0, top10_t.shape[1])
# ax.set_ylim(0, top10_t.shape[0])

# # Set ticks
# ax.set_xticks(np.arange(top10_t.shape[1]) + 0.5)
# ax.set_yticks(np.arange(top10_t.shape[0]) + 0.5)
# ax.set_xticklabels(top10_t.columns, rotation=45, ha='right')
# ax.set_yticklabels(["All Skin vs. All Nares", "Skin ADL vs Skin H", "Skin ADNL vs Skin H", "Skin ADNL vs Skin ADL", "Nares AD vs Nares H"])

# # Reverse y-axis
# ax.invert_yaxis()

# # Clean up
# ax.set_xticks(np.arange(top10_t.shape[1]), minor=True)
# ax.set_yticks(np.arange(top10_t.shape[0]), minor=True)
# ax.grid(False)
# ax.tick_params(which="minor", bottom=False, left=False)

# # Title
# plt.title("Top 10 Features from Each Classification Pooled", fontsize=16, pad=35, x=0.45)

# # Smaller subtitle
# plt.text(
#     0.45, 1.25, 
#     "(lower ranks and greater color intensity correspond to higher feature importance)",
#     ha='center', va='center',
#     transform=ax.transAxes,
#     fontsize=12
# )

# plt.tight_layout()

# plt.savefig('../Figures/Main/Fig_2B.jpg', dpi=600, bbox_inches='tight')

### Calculate LFC from feature table between pairwise groups to annotate heatmap

In [7]:
# Read in the combined LFC results
# NOTE(review): this CSV is written by the LFC-calculation cell further down
# (combined_lfc.to_csv), so on a fresh top-to-bottom run this line loads
# whatever an earlier session left on disk, or fails if the file is absent.
lfc_all = pd.read_csv('../Data/RF_Feature_Importances/lfc_all_comparisons.csv')

# Transpose so comparisons are rows, features are columns
top10_t = top10.T

# Get the heatmap feature names
heatmap_feats = top10_t.columns.to_list()

# Rebuild the ASV_Name format from the display names: spaces back to
# underscores and the 'g__' prefix restored.
# NOTE(review): features renamed for display (e.g. the Veillonella_F0422
# rename) will not round-trip to their original ASV_Name here - verify if
# such a feature enters the heatmap.
heatmap_feats_formatted = ['g__' + feat.replace(' ', '_') for feat in heatmap_feats]

# Define custom colormaps (non-reversed!)
colormaps = {
    'skin_vs_nares': 'Greys',
    'skin_ADL_vs_H': 'Blues',
    'skin_ADNL_vs_H': 'Greens',
    'skin_ADNL_vs_ADL': 'Purples',
    'nares_AD_vs_H': 'Oranges'
}


In [8]:
def calculate_lfc_for_comparisons(df, metadata, comparisons):
    """
    Calculate the log2 fold change (LFC) and a Mann-Whitney U p-value for each
    ASV across a fixed set of pairwise group comparisons.

    Parameters
    ----------
    df : DataFrame
        Abundance data (samples x ASVs), indexed by sample ID.
    metadata : DataFrame
        Indexed by sample ID, with a 'group' column holding the group labels
        used below ('skin-ADL', 'skin-ADNL', 'skin-H', 'nares-AD', 'nares-H').
    comparisons : dict
        Kept for interface compatibility; it is not read. The comparisons are
        always the five hard-coded in ``comparison_groups`` below.

    Returns
    -------
    lfc_results : dict
        Maps each comparison name to a DataFrame with columns 'ASV', 'LFC',
        'mean_group1', 'mean_group2', 'comparison' and 'pvalue', sorted by
        descending absolute LFC. A comparison is omitted when either of its
        groups has no sample present in ``df``.
    """

    # comparison name -> (labels forming group 1, labels forming group 2)
    comparison_groups = {
        'skin_vs_nares': (['skin-ADL', 'skin-ADNL', 'skin-H'], ['nares-AD', 'nares-H']),
        'skin_ADL_vs_H': (['skin-ADL'], ['skin-H']),
        'skin_ADNL_vs_H': (['skin-ADNL'], ['skin-H']),
        'skin_ADNL_vs_ADL': (['skin-ADNL'], ['skin-ADL']),
        'nares_AD_vs_H': (['nares-AD'], ['nares-H'])
    }

    # Added to both group means so that the ratio and its log stay finite
    # when an ASV is absent from a group
    pseudocount = 1e-6

    lfc_results = {}

    for comparison_name, (group1, group2) in comparison_groups.items():
        print(f"\nCalculating LFC for {comparison_name}...")
        print(f"  Group 1: {group1}")
        print(f"  Group 2: {group2}")

        # Get samples for each group
        samples_group1 = metadata[metadata['group'].isin(group1)].index
        samples_group2 = metadata[metadata['group'].isin(group2)].index

        # Filter to only include samples present in df
        samples_group1 = [s for s in samples_group1 if s in df.index]
        samples_group2 = [s for s in samples_group2 if s in df.index]

        print(f"  N Group 1: {len(samples_group1)}")
        print(f"  N Group 2: {len(samples_group2)}")

        if len(samples_group1) == 0 or len(samples_group2) == 0:
            print("  WARNING: One or both groups have no samples!")
            continue

        # Get abundance data for each group
        group1_data = df.loc[samples_group1]
        group2_data = df.loc[samples_group2]

        # Mean abundance of each ASV in each group. The raw means are kept
        # and reported as-is; the pseudocount only enters the ratio, so the
        # reported means carry no add-then-subtract rounding error.
        mean_group1 = group1_data.mean(axis=0)
        mean_group2 = group2_data.mean(axis=0)

        # Log fold change: log2(group1/group2); positive = higher in group 1
        lfc = np.log2((mean_group1 + pseudocount) / (mean_group2 + pseudocount))

        # Create results dataframe
        results_df = pd.DataFrame({
            'ASV': lfc.index,
            'LFC': lfc.values,
            'mean_group1': mean_group1.values,
            'mean_group2': mean_group2.values,
            'comparison': comparison_name
        })

        # Two-sided Mann-Whitney U test per ASV. ASVs that sum to zero in
        # both groups are assigned p = 1.0 without running the test.
        pvalues = []
        for asv in df.columns:
            if group1_data[asv].sum() == 0 and group2_data[asv].sum() == 0:
                pvalues.append(1.0)
            else:
                _, pval = stats.mannwhitneyu(group1_data[asv], group2_data[asv],
                                             alternative='two-sided')
                pvalues.append(pval)

        # pvalues follows df.columns order, which is also the row order of
        # results_df at this point (it is sorted only afterwards)
        results_df['pvalue'] = pvalues

        # Sort by absolute LFC
        results_df = results_df.sort_values('LFC', key=abs, ascending=False)

        lfc_results[comparison_name] = results_df

    return lfc_results

# Run the LFC calculation (the colormap dict is passed as the `comparisons` argument)
lfc_results = calculate_lfc_for_comparisons(df, metadata, colormaps)

# Stack the per-comparison tables into one frame and write it to disk
per_comparison_tables = list(lfc_results.values())
combined_lfc = pd.concat(per_comparison_tables, ignore_index=True)
combined_lfc.to_csv('../Data/RF_Feature_Importances/lfc_all_comparisons.csv', index=False)



Calculating LFC for skin_vs_nares...
  Group 1: ['skin-ADL', 'skin-ADNL', 'skin-H']
  Group 2: ['nares-AD', 'nares-H']
  N Group 1: 282
  N Group 2: 180

Calculating LFC for skin_ADL_vs_H...
  Group 1: ['skin-ADL']
  Group 2: ['skin-H']
  N Group 1: 99
  N Group 2: 84

Calculating LFC for skin_ADNL_vs_H...
  Group 1: ['skin-ADNL']
  Group 2: ['skin-H']
  N Group 1: 99
  N Group 2: 84

Calculating LFC for skin_ADNL_vs_ADL...
  Group 1: ['skin-ADNL']
  Group 2: ['skin-ADL']
  N Group 1: 99
  N Group 2: 99

Calculating LFC for nares_AD_vs_H...
  Group 1: ['nares-AD']
  Group 2: ['nares-H']
  N Group 1: 96
  N Group 2: 84


In [9]:
# Read in the ASV name mapping
asv_mapping_path = '../Data/Taxonomy/ASV_readable_name_mapping_abundance_ranked.csv'
asv_mapping = pd.read_csv(asv_mapping_path)

# Create a dictionary for easy mapping from ASV_Sequence to ASV_Name
asv_name_dict = dict(zip(asv_mapping['ASV_Sequence'], asv_mapping['ASV_Name']))

print(f"Loaded {len(asv_name_dict)} ASV name mappings")

# Reload the combined LFC table here instead of reusing the `lfc_all` loaded
# in an earlier cell: that earlier read happens before the LFC-calculation
# cell rewrites this file, so it can hold results from a previous session.
lfc_all = pd.read_csv('../Data/RF_Feature_Importances/lfc_all_comparisons.csv')

# Map ASV_Name to the LFC results (sequences without a mapping become NaN)
lfc_all['ASV_Name'] = lfc_all['ASV'].map(asv_name_dict)

print(f"\nSuccessfully mapped ASV_Name for {lfc_all['ASV_Name'].notna().sum()} rows")
print(f"Missing mappings: {lfc_all['ASV_Name'].isna().sum()} rows")

# Save file (the index is written as the first column; the reader below uses index_col=0)
lfc_all.to_csv('../Data/RF_Feature_Importances/lfc_all_comparisons_name.csv')

Loaded 2261 ASV name mappings

Successfully mapped ASV_Name for 11305 rows
Missing mappings: 110 rows


In [10]:
# Load the name-annotated LFC table; its first CSV column is the saved index
lfc_all = pd.read_csv('../Data/RF_Feature_Importances/lfc_all_comparisons_name.csv', index_col=0)

# Keep only the rows whose ASV_Name is one of the heatmap features
is_heatmap_feature = lfc_all['ASV_Name'].isin(heatmap_feats_formatted)
lfc_heatmap = lfc_all.loc[is_heatmap_feature].copy()

# Label each LFC row with the group it is enriched in
def get_direction(row):
    """Return the group a row's LFC points towards.

    `row` must provide 'comparison' and 'LFC'. A strictly positive LFC
    returns the first group of the comparison; anything else (negative,
    zero, or NaN) returns the second group. Comparisons that are not
    listed below return 'Unknown'.
    """
    comparison_key = row['comparison']
    fold_change = row['LFC']

    # comparison name -> (group for positive LFC, group otherwise)
    groups_by_comparison = {
        'skin_vs_nares': ('skin', 'nares'),
        'skin_ADL_vs_H': ('skin-ADL', 'skin-H'),
        'skin_ADNL_vs_H': ('skin-ADNL', 'skin-H'),
        'skin_ADNL_vs_ADL': ('skin-ADNL', 'skin-ADL'),
        'nares_AD_vs_H': ('nares-AD', 'nares-H')
    }

    if comparison_key not in groups_by_comparison:
        return 'Unknown'

    positive_group, other_group = groups_by_comparison[comparison_key]
    if fold_change > 0:
        return positive_group
    return other_group

# Compute the enrichment direction row by row and store it as 'Direction'
direction_labels = lfc_heatmap.apply(get_direction, axis=1)
lfc_heatmap['Direction'] = direction_labels

# Write the annotated table to disk without the index column
lfc_heatmap.to_csv('../Data/RF_Feature_Importances/lfc_heatmap_features_with_direction.csv', index=False)


In [16]:
# Load the heatmap-feature LFC table (with its Direction column) saved above
lfc_direction_path = '../Data/RF_Feature_Importances/lfc_heatmap_features_with_direction.csv'
lfc_data = pd.read_csv(lfc_direction_path)

# Build a colormap from only part of another one (used to drop the near-white end)
def truncate_colormap(cmap, minval=0.0, maxval=0.8, n=100):
    """Return a new colormap built from `n` colours sampled evenly from
    `cmap` between `minval` and `maxval`.

    The result is named 'trunc(<cmap.name>,<minval>,<maxval>)' with both
    bounds formatted to two decimals.
    """
    truncated_name = f'trunc({cmap.name},{minval:.2f},{maxval:.2f})'
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    return LinearSegmentedColormap.from_list(truncated_name, sampled_colors)

# Base colormap per comparison (non-reversed!)
colormaps = {
    'skin_vs_nares': 'Greys',
    'skin_ADL_vs_H': 'Blues',
    'skin_ADNL_vs_H': 'Greens',
    'skin_ADNL_vs_ADL': 'Purples',
    'nares_AD_vs_H': 'Oranges'
}

# Nested lookup: comparison -> {display feature name -> LFC}.
# Display names drop 'g__' and use spaces instead of underscores.
lfc_info = {}
for comparison, raw_name, lfc_value in zip(
    lfc_data['comparison'], lfc_data['ASV_Name'], lfc_data['LFC']
):
    asv_name = raw_name.replace('g__', '').replace('_', ' ')
    lfc_info.setdefault(comparison, {})[asv_name] = lfc_value


# Replace _ with space in the feature names (the index of top10) so they match
# the keys of lfc_info and read well as tick labels
top10.index = top10.index.str.replace('_', ' ')

# Transpose so comparisons are rows, features are columns
top10_t = top10.T

# Create the figure
fig, ax = plt.subplots(figsize=(15, 5))  # wide figure

# Min-max normalize the ranks within each comparison (row) to [0, 1]
normed_data = pd.DataFrame(index=top10_t.index, columns=top10_t.columns)

for comparison in top10_t.index:
    row = top10_t.loc[comparison]
    rank_span = row.max() - row.min()
    if rank_span == 0:
        # Every rank in this row is equal: use 0 throughout instead of the
        # NaN that 0/0 would produce
        normed = row * 0.0
    else:
        normed = (row - row.min()) / rank_span
    normed_data.loc[comparison] = normed

# Draw one coloured, annotated rectangle per (comparison, feature) cell
for idx, comparison in enumerate(top10_t.index):
    base_cmap = plt.get_cmap(colormaps[comparison])
    cmap = truncate_colormap(base_cmap, 0.0, 0.8)  # Cut off super light top
    row_data = normed_data.loc[comparison].astype(float)
    
    for jdx, value in enumerate(row_data):
        corrected_value = 1 - value  # <-- flip! low values colorful, high values light
        color = cmap(corrected_value)
        
        rect = plt.Rectangle((jdx, idx), 1, 1, facecolor=color, edgecolor='white', linewidth=0.5)
        ax.add_patch(rect)
        
        original_val = top10_t.loc[comparison].iloc[jdx]
        feature_name = top10_t.columns[jdx]
        
        # Get the LFC value for this comparison and feature (None if absent)
        lfc_value = None
        if comparison in lfc_info and feature_name in lfc_info[comparison]:
            lfc_value = lfc_info[comparison][feature_name]
        
        # Cell text: the rank, plus the LFC sign when an LFC is available.
        # Any LFC that is not strictly positive (including exactly 0) is
        # shown as "(-)".
        text_label = f"{int(original_val)}"
        if lfc_value is not None:
            if lfc_value > 0:
                text_label = f"{int(original_val)} (+)"
            else:
                text_label = f"{int(original_val)} (-)"
        
        ax.text(jdx + 0.5, idx + 0.5, text_label,
                ha='center', va='center', color='black', fontsize=8, weight='bold')

# Set axis limits
ax.set_xlim(0, top10_t.shape[1])
ax.set_ylim(0, top10_t.shape[0])

# Set ticks
ax.set_xticks(np.arange(top10_t.shape[1]) + 0.5)
ax.set_yticks(np.arange(top10_t.shape[0]) + 0.5)
ax.set_xticklabels(top10_t.columns, rotation=45, ha='right', fontsize=14)

# y-axis labels with +/- indicators, looked up by comparison key so each label
# stays attached to its own row whatever the row order; a comparison without
# an entry falls back to its raw key
ytick_labels = {
    'skin_vs_nares': "Skin (+) vs. Nares (-)",
    'skin_ADL_vs_H': "Skin ADL (+) vs Skin H (-)",
    'skin_ADNL_vs_H': "Skin ADNL (+) vs Skin H (-)",
    'skin_ADNL_vs_ADL': "Skin ADNL (+) vs Skin ADL (-)",
    'nares_AD_vs_H': "Nares AD (+) vs Nares H (-)"
}
ax.set_yticklabels(
    [ytick_labels.get(row_key, row_key) for row_key in top10_t.index],
    fontsize=14
)

# Reverse y-axis so the first comparison is drawn at the top
ax.invert_yaxis()

# Clean up
ax.set_xticks(np.arange(top10_t.shape[1]), minor=True)
ax.set_yticks(np.arange(top10_t.shape[0]), minor=True)
ax.grid(False)
ax.tick_params(which="minor", bottom=False, left=False)

# Title
plt.title("Top 10 Features from Each Classification Pooled", fontsize=19, pad=15, x=0.45)

# Smaller subtitle
# plt.text(
#     0.45, 1.15, 
#     "(lower ranks and greater color intensity correspond to higher feature importance; +/- indicates LFC direction)",
#     ha='center', va='center',
#     transform=ax.transAxes,
#     fontsize=14
# )

plt.tight_layout()

plt.savefig('../Figures/Main/Fig_2B.jpg', dpi=600, bbox_inches='tight')