In [1]:
import pandas as pd
# import numpy as np
from pathlib import Path
from definitions import ROOT_DIR
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rc, rc_context
import matplotlib
from seriate import seriate
from scipy.spatial.distance import pdist

# Global figure styling:
#  - pdf.fonttype 42 embeds TrueType fonts so PDF text remains editable text
#  - every figure uses Arial at 12 pt
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial'],
    'font.size': 12,
})

In [2]:
def get_class_size(metadata, class_column):
    '''
    Return a copy of `metadata` with an added `class_size` column.

    For each row, `class_size` is the number of rows in `metadata` that share
    that row's `class_column` value.

    The input frame is not modified. This avoids side effects on the caller's
    data and SettingWithCopyWarning when `metadata` is a slice.

    Rows whose class value is missing (NaN/None) get a NaN class size.
    `value_counts` drops missing values, so looking them up would otherwise fail.

    Parameters
    ----------
    metadata : pandas.DataFrame
    class_column : str
        Name of the column holding the class label.

    Returns
    -------
    pandas.DataFrame
    '''
    sizes = metadata[class_column].value_counts()
    return metadata.assign(class_size=metadata[class_column].map(sizes))


def filter_neutral_losses(df, neutral_losses=('',)):
    '''
    Keep only rows whose `neutral_loss` value matches `neutral_losses`.

    An empty string in `neutral_loss` marks an ion without a neutral loss.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a `neutral_loss` column.
    neutral_losses : str, iterable of str, or None
        - "only_nl": keep only ions that have a neutral loss (non-empty value)
        - None: no filtering, `df` is returned as is
        - any other single string: keep rows with exactly that neutral loss
        - an iterable of strings: keep rows whose neutral loss is in it
        The default keeps only ions without a neutral loss.

    Returns
    -------
    pandas.DataFrame
        The filtered rows; `df` itself is not modified.
    '''
    if neutral_losses is None:
        return df
    if isinstance(neutral_losses, str):
        if neutral_losses == "only_nl":
            return df[df.neutral_loss != ""]
        # A bare string would make `isin` raise TypeError; treat it as one value
        neutral_losses = [neutral_losses]
    return df[df.neutral_loss.isin(neutral_losses)]


def filter_adducts(df, adducts=('',)):
    '''
    Keep only rows whose `adduct` value matches `adducts`.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain an `adduct` column.
    adducts : str, iterable of str, or None
        - None: no filtering, `df` is returned as is
        - a single string: keep rows with exactly that adduct
        - an iterable of strings: keep rows whose adduct is in it
        The default keeps only rows whose adduct is the empty string.

    Returns
    -------
    pandas.DataFrame
        The filtered rows; `df` itself is not modified.
    '''
    if adducts is None:
        return df
    if isinstance(adducts, str):
        # A bare string would make `isin` raise TypeError; treat it as one value
        adducts = [adducts]
    return df[df.adduct.isin(adducts)]


def filter_polarity(df, polarity=None):
    '''
    Keep only rows whose `Polarity` value matches `polarity`.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a `Polarity` column (e.g. 'positive' / 'negative').
    polarity : str, iterable of str, or None
        - None: no filtering, `df` is returned as is
        - a single string: keep rows with exactly that polarity
        - an iterable of strings: keep rows whose polarity is in it

    Returns
    -------
    pandas.DataFrame
        The filtered rows; `df` itself is not modified.
    '''
    if polarity is None:
        return df
    if isinstance(polarity, str):
        return df[df.Polarity == polarity]
    return df[df.Polarity.isin(polarity)]


def filter_data(data, polarity=None, adducts=None, neutral_losses=None):
    '''
    Run the polarity, adduct and neutral-loss filters one after another.

    Each argument is forwarded to the matching filter. With these defaults
    (all None) every filter is a no-op and `data` passes through unchanged.
    '''
    return (
        data
        .pipe(filter_polarity, polarity)
        .pipe(filter_adducts, adducts)
        .pipe(filter_neutral_losses, neutral_losses)
    )


def group_by_molecule(df, groupby_columns):
    '''
    Collapse rows to one row per combination of `groupby_columns`.

    Only `detectability` is aggregated, taking its maximum. For boolean
    detectability this means a group counts as detected if any of its rows
    (e.g. any ion of a metabolite) was detected.

    Returns a DataFrame with the grouping columns followed by `detectability`,
    with a fresh 0..n-1 index. All other columns are dropped.
    '''
    per_group = df.groupby(groupby_columns, as_index=False)
    return per_group[['detectability']].max()

In [3]:
# Project directories (output directory is created if missing)
p_root_dir = Path(ROOT_DIR)
p_data = p_root_dir / "data"
p_out = p_root_dir.joinpath("plots", "heatmap")
p_out.mkdir(parents=True, exist_ok=True)

# Chemical classification table
p_chem_class = p_data.joinpath("custom_classification_v2.csv")

# Detectability predictions
# (alternative input used previously: "Interlab_data_19Apr2023.csv")
p_predictions = p_data.joinpath("All_data_19Apr2023.csv")
# Input file stem, embedded in output file names to record which input was used
source = p_predictions.stem

# Interlaboratory comparison heatmap

In [4]:
def summarise_per_class(df):
    '''
    Compute, for every (sample, chemical class) pair, the fraction of detected class members.

    `detectability` is summed within each pair, giving the number of detected rows.
    `class_size` is taken from the first row of the pair. Their ratio is returned
    as `fraction_detected`.

    Returns a new DataFrame with columns
    ['Sample name', 'main_coarse_class', 'fraction_detected'].
    '''
    per_class = df.pivot_table(
        index=['Sample name', 'main_coarse_class'],
        values=['detectability', 'class_size'],
        aggfunc={'class_size': 'first', 'detectability': 'sum'},
    ).reset_index()

    per_class['fraction_detected'] = per_class['detectability'] / per_class['class_size']
    return per_class.drop(columns=['detectability', 'class_size'])

In [5]:
# Load classification, add class size info
# (one row per unique metabolite/class pair; class_size = number of metabolites in the class)
classes = pd.read_csv(p_chem_class, index_col='internal_id')
chem_class = get_class_size(
    classes[['name_short', 'main_coarse_class']].drop_duplicates(),
    'main_coarse_class',
)

# Load predictions and format neutral loss column.
# read_csv turns empty cells into NaN; restore them as '' so they match the
# '' convention used by filter_neutral_losses (presumably "no neutral loss").
# Assign back explicitly: chained `df.neutral_loss.fillna(..., inplace=True)`
# does not update `df` under pandas Copy-on-Write.
df = pd.read_csv(p_predictions)
df['neutral_loss'] = df['neutral_loss'].fillna('')

# Only consider data of detected ions:
# an ion counts as detected when its predicted value is at least `threshold`
threshold = 0.8
df['detectability'] = df.pred_val >= threshold
data = df[df.detectability]

In [6]:
for polarity in ['positive', 'negative']:

    # Choose polarity; keep all adducts but only ions without a neutral loss
    filtered_data = filter_data(data,
                                polarity=polarity,
                                neutral_losses=[''])

    # One row per (dataset, metabolite): detected if any of its ions was detected
    molecule_data = group_by_molecule(filtered_data, groupby_columns=['Sample name', 'name_short'])

    # Map chemical class to the metabolite
    mapped_data = molecule_data.merge(chem_class, on='name_short', how='left')

    # Per (dataset, class): fraction of the class's metabolites that were detected
    class_data = summarise_per_class(mapped_data)

    # Prepare data for plotting: datasets as rows, classes as columns,
    # 0 where a class has no detected metabolite in a dataset
    plot_data = class_data.pivot_table(index='Sample name',
                                       columns='main_coarse_class',
                                       values='fraction_detected',
                                       fill_value=0)
    if plot_data.empty:
        print(f"No detected ions in {polarity} mode; skipping heatmap")
        continue

    # Change row order by seriation so datasets with similar profiles are adjacent.
    # With a single dataset there is nothing to order (pdist yields no distances).
    if len(plot_data) > 1:
        new_row_order = seriate(pdist(plot_data.to_numpy()))
        plot_data = plot_data.iloc[new_row_order]
    # Transpose: classes on the y axis, datasets on the x axis
    plot_data = plot_data.T

    # Plot on an explicit figure so exactly this figure is saved and then closed
    fname = f"heatmap_interlab_{polarity}_{source}"
    fig, ax = plt.subplots()
    sns.heatmap(data=plot_data,
                cmap='viridis',
                cbar_kws={'label': 'Fraction Detected'},
                xticklabels=True,
                ax=ax)
    ax.set(title=f"{polarity} mode", xlabel="", ylabel="", aspect='equal')
    # bbox_inches='tight' keeps long dataset/class tick labels from being clipped
    fig.savefig(p_out / f"{fname}.png", bbox_inches='tight')
    fig.savefig(p_out / f"{fname}.pdf", transparent=True, bbox_inches='tight')
    plt.close(fig)

    plot_data.to_csv(p_out / f"{fname}.csv")

meta NOT subset; don't know how to subset; dropped
meta NOT subset; don't know how to subset; dropped
