# Figures (final)

In [None]:
%matplotlib inline
import sys 
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import random
import pickle
from pathlib import Path
from itertools import product

modpath = "../scripts"
sys.path.append(os.path.abspath(modpath))

from misc import pickler, open_table
import plot_utils

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Display-name lookup tables used throughout the notebook.

# DEA method identifier -> label shown in figures.
prdea = {
    "edgerlrt": "edgeR LRT",
    "edgerqlf": "edgeR QLF",
    "deseq2": "DESeq2 Wald",
    "DESeq2": "DESeq2 Wald",
    "wilcox": "Wilcoxon rank-sum",
}

# Raw dataset key -> short dataset label.
prdata = {
    "BASHER": "BSHR",
    "BASLUMA": "BSLA",
    "LWPL": "GATB",
    "BASLUMB": "BSLB",
    "LUMAB": "LMAB",
    "HERLUMA": "HRLA",
    "HERLUMB": "HRLB",
}
# Short dataset label -> raw dataset key.
prdata_inv = dict((short, raw) for raw, short in prdata.items())

# Metric / predictor key -> axis label text.
pretty_met = {
    "mcc": "MCC",
    "rec": "Recall",
    "rep": "Replicability",
    "prec": "Precision",
    "deg": "#DEGs",
    "terms": "#Terms",
    "std": "std",
    "kurt": "Kurtosis",
    "Spear": "Spearman correlation",
    "KL": "KL Divergence",
    "Rep": "Replicability",
    "Prec": "Precision",
    "Rec": "Recall",
}

# param sets
def get_paramset(data):
    """Return the parameter-set tag ("p2c", "p3" or "p2") for a dataset label.

    `data` is compared against the short dataset labels listed below; any
    label not listed falls back to "p2".
    """
    members_by_paramset = (
        ("p2c", ("BSHR", "BSLA", "BSLB", "LMAB", "HRLA", "HRLB", "GIPF")),
        ("p3", ("SNF2", "GATB")),
    )
    for paramset, members in members_by_paramset:
        if data in members:
            return paramset
    return "p2"

In [None]:
# Load the combined per-dataset metric table.
combined_all = pd.read_csv("../data/multi/combined_all.csv", index_col=0)
# Fill missing "Out" values with the literal string "None", which the filters
# below compare against (c["Out"]=="None").
# Assign the result back instead of `combined_all["Out"].fillna(..., inplace=True)`:
# that chained in-place form operates on a temporary under pandas Copy-on-Write
# and leaves the frame unchanged, and the FutureWarning announcing this is
# silenced at the top of the notebook.
combined_all["Out"] = combined_all["Out"].fillna("None")
# Keep only rows not flagged as synthetic.
combined_all = combined_all[~combined_all["isSynthetic"]]
combined_all.head()

In [None]:
# Load the combined enrichment metric table.
combined_gsea = pd.read_csv("../data/multi/combined_gsea_td.rev3.csv", index_col=0)
# Assign results back rather than calling fillna/replace with inplace=True on a
# column selection: under pandas Copy-on-Write the chained in-place form
# modifies a temporary and leaves the frame untouched.
combined_gsea["Out"] = combined_gsea["Out"].fillna("None")
combined_gsea["Data"] = combined_gsea["Data"].replace(prdata)

# Default analysis configuration used by the figure cells below.
dea = "DESeq2 Wald"
logFC = 1
lfc_mode = "formal"
fdr = 0.05

# Dataset plotting order: ascending median_rep_method at N=15 for the default
# configuration. `order_rep` is read as a global by the plotting functions.
c = combined_all[~combined_all["isSynthetic"]]
c = c[(c["Out"]=="None") & (c["logFC"]==logFC) & (c["FDR"]==fdr) & (c["lfc_mode"] == lfc_mode) & (c["DEA"]==dea)]
order_rep = c[c["N"]==15].sort_values(by="median_rep_method")["Data"].values
print(order_rep)

combined_gsea.head()

In [None]:
# Colour palettes and global plot style. `palette` is read as a global by the
# plotting functions defined further down.
npg = plot_utils.npg_palette()
jco = plot_utils.jco_palette()
colors=plot_utils.matplotlib_init()
# NOTE(review): this assignment is immediately overwritten by the "crest"
# palette on the next line, so it has no effect.
palette = jco[:len(order_rep)]
# One colour per dataset in order_rep.
palette = sns.color_palette("crest", n_colors=len(order_rep))
#palette = palette[1:-1]
sns.set_style("whitegrid", {'axes.linewidth': 2, 'axes.edgecolor':'black'})

In [None]:
# Alternative input: "../data/multi/datasets_wilcox.txt"
datasetsfile = "../data/multi/datasets.txt"

# NOTE(review): pickle.load executes arbitrary code from the file; only use
# with files produced by this project.
with open(datasetsfile, "rb") as f:
    datasets = pickle.load(f)

# Drop every entry whose key contains "SBRCA". The key list is snapshotted
# first so the dict is not mutated while being iterated.
keys = list(datasets)
for data in keys:
    if "SBRCA" not in data:
        continue
    datasets.pop(data)
datasets.keys()

## Load bootstrapping df

In [None]:
import scipy.stats as stats
from scipy.optimize import curve_fit
from sklearn.metrics import auc

# Long-format bootstrap results; the cells below read its columns
# "Reference", "N", "Data", "Cohort", "KL" and "Spearman".
df_boot = pd.read_csv("../data/multi/df_boot_long.csv", index_col=0)

In [None]:
all_N = [5,10]
ref = "Cohort"

# dfm: one row per dataset label; columns are added per cohort size below and
# aligned on the dataset label when assigned.
dfm = pd.DataFrame(index = list(set(combined_all["Data"])))

# (column prefix in dfm, source column) pairs.
deg_metrics = [("Prec","median_prec"), ("Rec","median_rec"), ("Rep","median_rep"), ("MCC","median_mcc")]
enrich_metrics = [("Prec","median_prec"), ("Rec","median_rec"), ("Rep","median_rep")]
# (column tag in dfm, value of the "Library" column) pairs.
enrich_libraries = [("GO","GO_Biological_Process_2023"), ("KEGG","KEGG_2021_Human")]

for N_ in all_N:
    # Gene-level metrics for DESeq2 Wald, formal |logFC|>1 test.
    # NOTE(review): unlike the figure cells, no filter on "Out" or "FDR" is
    # applied here; confirm each dataset contributes exactly one row.
    # set_index returns a new frame; the previous inplace=True call mutated a
    # filtered slice, which triggers SettingWithCopyWarning.
    deg_tab = combined_all[
        (combined_all["N"]==N_)
        & (combined_all["DEA"]=="DESeq2 Wald")
        & (combined_all["logFC"]==1)
        & (combined_all["lfc_mode"]=="formal")
    ].set_index("Data")
    for prefix, col in deg_metrics:
        dfm[f"{prefix}_N{N_}"] = deg_tab[col]

    # Enrichment-level metrics, one column group per library (GO first, then KEGG).
    for tag, library in enrich_libraries:
        enr_tab = combined_gsea[
            (combined_gsea["N"]==N_) & (combined_gsea["Library"]==library)
        ].set_index("Data")
        for prefix, col in enrich_metrics:
            dfm[f"{prefix}_N{N_}_{tag}"] = enr_tab[col]

def aggregate(df_ref, metric, agg_func):
    """Summarise `metric` per dataset in two stages.

    First take the mean of `metric` within each ("Data", "Cohort") group, then
    reduce those per-cohort means within each "Data" group using `agg_func`.
    Returns a Series indexed by "Data".
    """
    per_cohort = df_ref.groupby(["Data","Cohort"])[metric].mean().to_frame().reset_index()

    def _reduce(values):
        # Thin wrapper so agg_func is always called as a plain Python callable.
        return agg_func(values)

    return per_cohort.groupby("Data")[metric].apply(_reduce)
    
k = df_boot

# For every (reference, cohort size) pair, add mean / median / std summaries of
# the bootstrap KL and Spearman values to dfm. Each summary is computed by
# aggregate(), i.e. over per-cohort means rather than over all bootstrap rows
# directly (the direct groupby variant was the earlier "crude" approach).
boot_measures = (("KL", "KL"), ("Spear", "Spearman"))  # (dfm prefix, df_boot column)
boot_stats = (("mean", np.nanmean), ("median", np.nanmedian), ("std", np.nanstd))

for ref, N in product(["Truth","Cohort"], all_N):
    df_ref = k[(k["Reference"]==ref) & (k["N"]==N)]
    for short, column in boot_measures:
        for stat, func in boot_stats:
            dfm[f"{short}_{ref}_N{N}_{stat}"] = aggregate(df_ref, column, func)

# Figure 1: Replicability

In [None]:
from matplotlib.ticker import MaxNLocator

def make4x4fig(c, suffix, metrics, deg_logscale=False):
    """Plot median metric vs. cohort size N, one panel per metric.

    Panels are laid out on a grid with 2 rows and len(metrics)//2 columns
    (2x2 for the four-metric calls in this notebook); with an odd number of
    metrics the last one gets no panel.

    Parameters
    ----------
    c : DataFrame
        Must contain columns "N", "Data" and f"median_{met}{suffix}" for
        every entry of `metrics`.
    suffix : str
        Appended to the metric column name (e.g. "_method" or "").
    metrics : sequence of str
        Keys of the global `pretty_met`, in panel order.
    deg_logscale : bool
        If True, the second panel (index 1) gets a logarithmic y axis.

    Returns
    -------
    matplotlib Figure

    Reads the notebook globals `order_rep`, `palette` and `pretty_met`.
    """
    n_cols = len(metrics)//2
    # Width is 4.5 per column. The earlier `4.5*len(metrics)//2` floor-divided
    # the product instead, which only coincides for n_cols == 2.
    fig, axes = plt.subplots(2, n_cols, figsize=(4.5*n_cols, 8), gridspec_kw={'hspace': 0.3,'wspace':0.39})
    axes = axes.flatten()

    all_N = sorted(set(c["N"]))
    # Order rows by the dataset's position in order_rep.
    data_rank = {k: i for i, k in enumerate(order_rep)}
    c = c.sort_values(by=['Data'], key=lambda col: col.map(data_rank))

    for i, (ax, met) in enumerate(zip(axes, metrics)):

        sns.lineplot(data=c, x="N", y=f"median_{met}{suffix}", hue="Data", style="Data", markers=True, ax=ax,lw=3,ms=11,palette=palette, hue_order=order_rep, style_order=order_rep)

        if i == 1 and deg_logscale: ax.set_yscale("log")
        # Hide the per-panel legend; a single shared legend is added below.
        ax.legend([],[], frameon=False)
        ax.set_ylabel(f"Median {pretty_met[met]}")
        ax.set_xlabel("Cohort Size N")
        # One tick per cohort size present in the data.
        ax.xaxis.set_ticks(all_N)
        # Panel letter (A, B, C, ...) at the top-left corner.
        ax.annotate(chr(ord('A')+i), xy=(-0.08, 1.08), xycoords="axes fraction", weight="bold", va='center',ha='center', fontsize=20)

    # Shared legend taken from the second panel, entries in reverse order.
    handles, labels = axes[1].get_legend_handles_labels()
    fig.legend(handles[::-1], labels[::-1], loc='center left', bbox_to_anchor=(0.92,0.5),framealpha=1,title=None,ncol=1,markerscale=1)
    return fig


# Main-text metrics-vs-N figure for the default configuration
# (dea / logFC / lfc_mode / fdr as set in the configuration cell).
c = combined_all[~combined_all["isSynthetic"]]
c = c[(c["Out"]=="None") & (c["logFC"]==logFC) & (c["FDR"]==fdr) & (c["lfc_mode"] == lfc_mode) & (c["DEA"]==dea)]

fig = make4x4fig(c, suffix = "_method", metrics = ["rep","deg","prec","rec"])

figpath = f"../figures/fig2_metrics_vs_N.pdf"
#fig.tight_layout()
fig.savefig(figpath, bbox_inches="tight")
print(figpath)
# NOTE(review): the figure plots "median_prec_method" while this summary reads
# "median_prec" — confirm which column is intended.
print(f"Median precision for N=15: {c[c['N']==15]['median_prec'].median():.2f}")

In [None]:
from matplotlib.ticker import MaxNLocator

c = combined_all

# One supplementary metrics-vs-N figure per (logFC threshold, test mode, DEA method).
# Dedicated loop-variable names are used so the notebook-level configuration
# variables `logFC`, `lfc_mode` and `dea` are not overwritten by this cell.
# Methods are iterated in sorted order so the output order is reproducible.
for lfc_thr in [0,1]:
    for mode in ["post_hoc","formal"]:
        # The formal test with threshold 0 is skipped.
        if lfc_thr == 0 and mode == "formal": continue
        for method in sorted(c["DEA"].dropna().unique()):
            print(lfc_thr, mode, method)

            cc = c[(c["Out"]=="None") & (c["logFC"]==lfc_thr) & (c["FDR"]==fdr) & (c["lfc_mode"] == mode) & (c["DEA"]==method)]

            fig = make4x4fig(cc, suffix = "_method", metrics = ["rep","deg","prec","rec"])

            # The former `fig2_...` branch required lfc_thr == 0 and
            # mode == "formal", which the `continue` above rules out, so every
            # figure goes to the supplementary path.
            figpath = f"../figures/sfig_metrics_vs_N.{method.replace(' ','')}.lfc{lfc_thr}.{mode}.pdf"

            fig.tight_layout()
            fig.savefig(figpath, bbox_inches="tight")
            print(figpath)
            # NOTE(review): the figure plots "median_prec_method" while this
            # summary reads "median_prec" — confirm which column is intended.
            print(f"Median precision for N=15: {cc[cc['N']==15]['median_prec'].median():.2f}")

In [None]:
# Smallest N per dataset at which median precision exceeds 0.95
# (rows sorted by N, first occurrence of each dataset kept).
# NOTE(review): `c` here is whatever the previous cell left behind (the
# unfiltered combined_all), so rows from all DEA methods / settings are mixed.
# NOTE(review): only the last expression of a cell is displayed, so this first
# table is computed but not shown.
cc=c[c["median_prec_method"]>0.95].sort_values(by="N")
cc[~cc["Data"].duplicated()]

# Same for median recall above 0.5.
cc=c[c["median_rec_method"]>0.5].sort_values(by="N")
cc[~cc["Data"].duplicated()]

# Figure 3: Enrichment

In [None]:
# Enrichment library whose metrics are plotted.
lib = "GO_Biological_Process_2023"
#lib = "KEGG_2021_Human"

# Reset the default configuration (earlier cells may have changed these).
# None of the four values is used in the filter below.
dea = "DESeq2 Wald"
logFC = 1
lfc_mode = "formal"
fdr = 0.05

c = combined_gsea
c = c[(c["Out"]=="None") & (c["Library"]==lib)]

# Enrichment tables use un-suffixed metric columns ("median_rep", "median_terms", ...).
fig = make4x4fig(c, suffix = "", metrics = ["rep","terms","prec","rec"])

figpath = f"../figures/fig3_enrich_metrics_{lib}_vs_N.pdf"
fig.savefig(figpath, bbox_inches="tight")
print(figpath)

print(f"Median precision for N=15: {c[c['N']==15]['median_prec'].median():.2f}")

# Figure 2: logFC

# Figure 5: Bootstrap

In [None]:
# Cache of Pearson statistics, keyed by (predictor column, metric column).
# Start from an empty dict and replace it with the pickled copy when one exists.
# NOTE(review): pickle.load executes arbitrary code from the file; only use
# with files produced by this project.

pearson_dict_fname = "../data/multi/pearson_dict.txt"
pearson_dict = dict()
if os.path.isfile(pearson_dict_fname):
    with open(pearson_dict_fname, "rb") as f:
        pearson_dict = pickle.load(f)
    print("Loaded dict")

In [None]:
import scipy.stats as stats

# Configuration for the bootstrap-predictor figure.
cohorts = range(50)  # only referenced by the commented-out suptitle below
reference = "Cohort"
metric_suffix = f"median"
metric_prefix = "Spear" # KL, Spear

# Suffix selecting which dfm target columns are plotted.
suffix = "" # for DEGs
#suffix = "_KEGG"
#suffix = "_GO"

# Style of the text box holding the correlation statistics.
boxprops = dict(boxstyle='round', facecolor='#e4eaf3', alpha=1, edgecolor="#2a3b76")

# Fit drawn on each panel; read as a global inside bootstrap_plot.
fit_prec = "linear"
#fit_prec = "binormal"

# KL always uses the linear fit.
if metric_prefix == "KL": fit_prec = "linear"

# dfm column names: x = bootstrap predictor at N=5 / N=10, y = target metrics.
y_prefixes = ["Prec", "Rec", "Rep"]
x1 = f"{metric_prefix}_{reference}_N5_{metric_suffix}"
x2 = f"{metric_prefix}_{reference}_N10_{metric_suffix}"
y1_suffix = f"N5{suffix}"
y2_suffix = f"N10{suffix}"

def bootstrap_plot(metric_prefix,metric_suffix,x1,x2,y1_suffix,y2_suffix, y_prefixes, N1=5, N2=10):
    """Scatter a predictor column of `dfm` against target metrics for two cohort sizes.

    Builds a figure with len(y_prefixes) rows and 2 columns. In each row the
    left panel plots dfm[x1] vs dfm[f"{y_prefix}_{y1_suffix}"] and the right
    panel dfm[x2] vs dfm[f"{y_prefix}_{y2_suffix}"]. Depending on the global
    `fit_prec`, each panel gets either a linear regression line plus a text box
    with Pearson r, r² and p, or a fitted binormal curve with a shaded band.

    Parameters
    ----------
    metric_prefix : key of the global `pretty_met`, used for the x-axis label.
    metric_suffix : its last "_"-separated part, capitalised, starts the x-axis label.
    x1, x2 : predictor column names in dfm for the left / right panels.
    y1_suffix, y2_suffix : suffixes of the target column names for the left / right panels.
    y_prefixes : target prefixes, one figure row each; also keys of `pretty_met`.
    N1, N2 : cohort sizes shown in the panel titles.

    Returns
    -------
    matplotlib Figure

    Reads the globals `dfm`, `order_rep`, `palette`, `fit_prec`, `boxprops` and
    `pretty_met`. In the non-binormal branch, stores (r, p) in the global
    `pearson_dict` under the key (x column, y column).
    """

    # Local mapping cohort size -> (x column, y suffix); shadows the
    # notebook-level list `all_N` inside this function.
    all_N = {N1: (x1,y1_suffix),
             N2: (x2,y2_suffix)
            }
    
    scale=1.24
    figsize = (scale*7.2,scale*(-1+4*len(y_prefixes)))
    print("Figsize",figsize)
    fig, axes = plt.subplots(len(y_prefixes), 2, figsize=figsize,sharex=False,sharey=False)
    
    # NOTE(review): iterating `axes` row by row presumably requires at least
    # two y_prefixes; with a single row plt.subplots returns a 1-D array — verify.
    for ax,  y_prefix in zip(axes, y_prefixes):
        ax = ax.flatten()
    
        sns.scatterplot(data=dfm, y=f"{y_prefix}_{y1_suffix}", x=x1, hue=dfm.index, style=dfm.index, hue_order=order_rep, style_order=order_rep, s=200, ax=ax[0], palette=palette)
        sns.scatterplot(data=dfm, y=f"{y_prefix}_{y2_suffix}", x=x2, hue=dfm.index, hue_order=order_rep, style_order=order_rep, style=dfm.index, s=200, ax=ax[1], palette=palette)
    
        if fit_prec == "linear":
            # Regression line only: marker size 0 hides regplot's own scatter.
            sns.regplot(data=dfm, y=f"{y_prefix}_{y1_suffix}", x=x1, ax=ax[0], scatter_kws={'s':0}, order=1)
            sns.regplot(data=dfm, y=f"{y_prefix}_{y2_suffix}", x=x2, scatter_kws={'s':0}, ax=ax[1], order=1)
    
    
    
        for N, a in zip(all_N, ax):
            xx = all_N[N][0]
            yy = f"{y_prefix}_{all_N[N][1]}"
            # Keep only datasets that have both a predictor and a target value.
            x = dfm[xx].dropna()
            y = dfm[yy].dropna()
            common = x.index.intersection(y.index)
            x, y = x.loc[common], y.loc[common]
        
            if fit_prec == "binormal":
                def binormal(x, a, b):
                    return stats.norm.cdf(a * stats.norm.ppf(x) + b)
            
                p0 = [3,-2] # this is a mandatory initial guess
                params, pcov = curve_fit(binormal, x, y, p0=p0, bounds=(-2.9, np.inf))
                # One-sigma parameter uncertainties from the covariance diagonal.
                sigma_ab = np.sqrt(np.diagonal(pcov))
                print(params)
                xlin = np.linspace(0.7,1,100)
                y_binormal = binormal(xlin, *params)
                sns.lineplot(x=xlin, y=y_binormal, color="#4d72b0",lw=2,zorder=99,ax=a)
                # Band between the curves with both parameters shifted by +/- one sigma.
                bound_upper = binormal(xlin, *(params + sigma_ab))
                bound_lower = binormal(xlin, *(params - sigma_ab))
                a.fill_between(xlin, bound_lower, bound_upper,
                             color = '#e4eaf3', alpha = 1, zorder=0)
        
            else:
                r_val, p_val = stats.pearsonr(x,y)
                pearson_dict[(xx,yy)] = (r_val, p_val)
                r2_val = r_val ** 2
    
                # Text box goes bottom-right for precision panels, top-left otherwise.
                if y_prefix == "Prec":
                    loc = (0.95, 0.05)
                    ha = "right"
                    va = "bottom"
                else:
                    loc = (0.05,0.95)
                    ha="left"
                    va="top"
                a.text(loc[0], loc[1], f"r = {r_val:.2f}\nr² = {r2_val:.2f}\np = {p_val:.2e}", 
                       transform=a.transAxes, fontsize=13, va=va,ha=ha, bbox=boxprops)
    
            a.set(ylabel=f"Median {pretty_met[y_prefix]}")
        
    
    ### MISC.
    
    # Shared legend taken from the top-left panel, entries in reverse order.
    handles, labels = axes[0][0].get_legend_handles_labels()
    fig.legend(handles[::-1], labels[::-1], loc='center left', 
               bbox_to_anchor=(1,0.5),framealpha=1,title=None,ncol=1,markerscale=1)
    
    for i, a in enumerate(axes.flatten()):
        a.legend().remove()
        a.set_box_aspect(1)
        a.set(xlabel=(f"{metric_suffix.split('_')[-1].capitalize()} {pretty_met[metric_prefix]}"))
        a.set(ylim=(-0.05,1.05))
        # Fixed x ticks when the predictor column name contains "Spear".
        if "Spear" in x1:
            pass
            a.xaxis.set_ticks(np.arange(0.75, 1, 0.05))
        # Panel letter; even panel indices are the left (N1) column.
        a.annotate(chr(ord('A')+i), xy=(-0.08, 1.08), xycoords="axes fraction", weight="bold", va='center',ha='center', fontsize=20)
        a.set_title(f"N={N1}" if i%2==0 else f"N={N2}", size=16)

    return fig
    
# Draw and save the bootstrap-predictor figure for the configuration above.
fig = bootstrap_plot(metric_prefix,metric_suffix,x1,x2,y1_suffix,y2_suffix, y_prefixes, N1=5, N2=10)

#fig.suptitle(f"25 Bootstrap trials | {len(cohorts)} Cohorts | {suffix}")
fig.tight_layout()
# With suffix == "" the file name contains two consecutive dots before "pdf".
figpath = f"../figures/boot.vs.{len(y_prefixes)}metrics.{metric_prefix}.{suffix}.pdf"
fig.savefig(figpath, bbox_inches="tight")
print(figpath)

## Sup. Figure: non-bootstrapped statistics

Here, we show that non-bootstrapped statistics such as logFC std and number of DEGs cannot predict the performance metrics as reliably as the bootstrapped Spearman correlation.

In [None]:
# Per-dataset, non-bootstrapped summary statistics for N=5 and N=10,
# stored as {N: {dataset label: value}}.
Ns = [5,10]
stds = {N: dict() for N in Ns}
kurts = {N: dict() for N in Ns}
degs = {N: dict() for N in Ns}

for N in Ns:
    for data in datasets:
        p = datasets[data]["outpath"]
        # Short label if one is defined, otherwise the raw key.
        name = prdata[data] if data in prdata else data
        ps = get_paramset(name)

        # logFC std
        pp = f"{p}/{data}_N{N}/all.logFC.none.deseq2.{ps}.feather"
        tab = open_table(pp)
        # Column-wise std, then the median over columns
        # (presumably one column per cohort — TODO confirm table layout).
        std = tab.std(axis=0).median()
        stds[N][name] = std

        # logFC kurtosis
        k = tab.kurtosis(axis=0).median()
        kurts[N][name] = k

        # Number of DEGS
        pp = f"{p}/{data}_N{N}/all.FDR.none.deseq2.{ps}.feather"
        tab = open_table(pp)
        # Per-column count of entries below 0.05, then the median over columns.
        deg = (tab<0.05).sum().median()
        degs[N][name] = deg

In [None]:
# Attach the non-bootstrapped statistics to dfm, aligned on the dataset label.
# The {label: value} dicts are wrapped in pd.Series explicitly: how a raw dict
# assigned to a DataFrame column is interpreted depends on the pandas version,
# whereas a Series is always aligned on its index.
for N in Ns:
    dfm[f"std_N{N}"] = pd.Series(stds[N], dtype=float)
    dfm[f"kurt_N{N}"] = pd.Series(kurts[N], dtype=float)
    dfm[f"deg_N{N}"] = pd.Series(degs[N], dtype=float)

y_prefixes = ['Prec', 'Rec', 'Rep']
metric_prefix = "std" # "std", "kurt", "deg"
metric_suffix = "median"

# Suffix selecting which dfm target columns are plotted.
suffix = "" # for DEGs
#suffix = "_KEGG"
#suffix = "_GO"

x1 = f"{metric_prefix}_N5"
x2 = f"{metric_prefix}_N10"

y1_suffix = f"N5{suffix}"
y2_suffix = f"N10{suffix}"

# Same layout as the bootstrap figure, with the chosen statistic as predictor.
# In linear-fit mode this also records the Pearson statistics in pearson_dict.
fig = bootstrap_plot(metric_prefix,metric_suffix,x1,x2,y1_suffix,y2_suffix, y_prefixes, N1=5, N2=10)

fig.tight_layout()
figpath = f"../figures/non-boot.vs.{len(y_prefixes)}metrics.{metric_prefix}.{suffix}.pdf"
fig.savefig(figpath, bbox_inches="tight")
print(figpath)

## Predictor comparisons

In [None]:
#pickler(pearson_dict, pearson_dict_fname)

# Long table of the recorded Pearson statistics: one row per
# (predictor column, metric column) key with columns "r" and "p".
pdf = pd.DataFrame(pearson_dict, index=["r","p"]).T.reset_index(names=["Predictor","Metric"])
# Cohort size = the digits following "_N" in the predictor column name.
pdf["N"] = pdf["Predictor"].str.split("_N", expand=True)[1].str.split("_", expand=True)[0].astype(int)
# Target level is derived from the metric column name before it is relabelled.
pdf["Target"] = pdf["Metric"].apply(lambda x: "GO" if "GO" in x else "KEGG" if "KEGG" in x else "DEG")
pdf["Metric"] = pdf["Metric"].apply(lambda x: "Precision" if "Prec" in x else "Recall" if "Rec" in x else "Replicability")
# NOTE(review): any predictor that is neither "Spear" nor "deg" (e.g. "kurt"
# or "KL") is labelled "logFC std" by this fallthrough.
pdf["Predictor"] = pdf["Predictor"].apply(lambda x: "Spearman" if "Spear" in x else "#DEG" if "deg" in x else "logFC std")
pdf["Metric"] = pdf["Metric"].replace({"Replicability":"Repl."})
pdf.head()

In [None]:
# 3 rows x 2 columns of panels sharing the y axis.
scale = 1.25
fig, axes = plt.subplots(3,2,figsize=(scale*5.8,scale*9),sharey=True,sharex=False,gridspec_kw={'wspace': 0.,'hspace':0.34})

def print_pval(pval):
    """Map a p-value to the significance marker drawn above a bar.

    Returns
    -------
    "ns"   if pval > 0.05, or if pval is NaN
    "*"    if 0.001  < pval <= 0.05
    "**"   if 0.0001 < pval <= 0.001
    "***"  if pval <= 0.0001

    A NaN p-value compares False against every threshold and previously fell
    through to "***"; it is now reported as not significant.
    """
    if np.isnan(pval) or pval > 0.05:
        return "ns"
    elif pval > 0.001:
        return "*"
    elif pval > 0.0001:
        return "**"
    else:
        return "***"

# One row per target level, one column per cohort size: bars of Pearson r per
# metric, coloured by predictor, with a significance marker above each bar.
for ax_row, target in zip(axes, ["DEG","GO","KEGG"]):
            
    for ax, N in zip(ax_row, [5,10]):

        pdff = pdf[(pdf["N"]==N)&(pdf["Target"]==target)]
        bars = sns.barplot(data=pdff, x="Metric", y="r", ax=ax,hue="Predictor",palette=npg)
        ax.set_title(f"N{N} {target}")

        # Tick labels repeated three times, once per expected predictor.
        x_labels = [t.get_text() for t in ax.get_xticklabels()]*3
        x_labels = np.array(x_labels).flatten()
        bars_sorted = bars.patches#sorted([b for b in bars.patches if b.get_width()>0], key = lambda x: x.get_x())
        # NOTE(review): this pairing assumes bars.patches, x_labels and the
        # row order of pdff["Predictor"] all line up (three predictors, rows
        # grouped by predictor in bar order) — verify against the seaborn
        # version in use, since a mismatch puts markers on the wrong bars.
        for bar,label,predictor in zip(bars_sorted, x_labels, pdff["Predictor"]):
            pval = pdff[(pdff["Metric"]==label) & (pdff["Predictor"]==predictor)]["p"].iloc[0]
            ax.text(
                bar.get_x() + bar.get_width() / 2,  # X position: Center of bar
                bar.get_height() + 0.02,            # Y position: Just above bar
                print_pval(pval),               # Significance marker for this bar
                ha='center', va='bottom', fontsize=10,zorder=99
            )
            
for i, a in enumerate(axes.flatten()):
    a.legend().remove()
    a.set_box_aspect(1)
    a.set_xticks(a.get_xticks())
    a.set(xlabel=None)
    #a.set_xticklabels(a.get_xticklabels(), rotation=5, ha='right')
    #a.get_xaxis().set_visible(False)
    a.set(ylim=(0,1.04))

    # Only the left column gets a y-axis label.
    if i%2==0:
        a.set(ylabel="Pearson correlation")
    a.annotate(chr(ord('A')+i), xy=(-0.04, 1.08), xycoords="axes fraction", weight="bold", va='center',ha='center', fontsize=14)

# Shared legend built from the top-left panel.
handles, labels = axes[0][0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', title=r"$\bf{Predictor}$",
           bbox_to_anchor=(0.5,0.98),framealpha=1,ncol=3,markerscale=1,fontsize=14)
    
fig.tight_layout()

figpath = f"../figures/sfig20_pred_metric_comparison.pdf"
fig.savefig(figpath, bbox_inches="tight")
print(figpath)

# Sup. Figure 1: Ground truth size

In [None]:
# Alternative input: "../data/multi/datasets_wilcox.txt"
datasetsfile = "../data/multi/datasets.txt"

# NOTE(review): pickle.load executes arbitrary code from the file; only use
# with files produced by this project.
with open(datasetsfile, "rb") as f:
    datasets = pickle.load(f)

# Remove every entry whose key contains "SBRCA"; iterate over a snapshot of
# the keys so the dict can be modified safely.
keys = list(datasets)
for data in keys:
    if "SBRCA" not in data:
        continue
    datasets.pop(data)
datasets.keys()

In [None]:
# Ground-truth DEG counts per dataset, DEA method, FDR and logFC setting.
FDRs, logFCs = [0.05], [0,1, "post hoc"]
DEAs = ["edgerqlf", "edgerlrt", "deseq2"]#, "wilcox"]

# Single-row frame with one MultiIndex column per combination; cells are
# filled one at a time below.
iterables = [datasets,[prdea[d] for d in DEAs],FDRs,logFCs]
multi_cols = pd.MultiIndex.from_product(iterables, names=["Data","DEA","FDR","logFC"])
gt = pd.DataFrame(columns=multi_cols)
for data in datasets:
    for dea in DEAs:
        for fdr in FDRs:
            for lfc in logFCs:
                col = (data,prdea[dea],fdr,lfc)
                if lfc == "post hoc":
                    # "post hoc" reads truth_stats[0][fdr][1]; the other
                    # settings read truth_stats[lfc][fdr][lfc].
                    # (presumably first key = tested threshold, third key =
                    # applied cutoff — TODO confirm)
                    gt.loc[0,col] = datasets[data]["truth_stats"][0][fdr][1][dea]
                elif dea != "wilcox":
                    gt.loc[0,col] = datasets[data]["truth_stats"][lfc][fdr][lfc][dea]
                        
                        
# Reshape to long format: one row per combination, value column "#DEG".
# Note: the loops above leave `dea` and `fdr` set to their last loop values.
gt=gt.unstack().reset_index(level=["Data","DEA","FDR","logFC"], drop=False)
gt.sort_values(by="DEA", inplace=True)
gt.rename(columns={0: "#DEG"}, inplace=True)

# Jaccard
# Same construction as the #DEG table, reading the "jaccard" entry of
# truth_stats (one value per dataset / FDR / logFC setting).
iterables = [datasets,FDRs,logFCs]
multi_cols = pd.MultiIndex.from_product(iterables, names=["Data","FDR","logFC"])
gtj = pd.DataFrame(columns=multi_cols)

for data in datasets:
    for fdr in FDRs:
        for lfc in logFCs:
            col = (data,fdr,lfc)
            if lfc == "post hoc":
                gtj.loc[0,col] = datasets[data]["truth_stats"][0][fdr][1]["jaccard"]
            else:
                gtj.loc[0,col] = datasets[data]["truth_stats"][lfc][fdr][lfc]["jaccard"]
                    
# Reshape to long format with value column "Jaccard".
gtj=gtj.unstack().reset_index(level=["Data","FDR","logFC"], drop=False)
gtj.rename(columns={0: "Jaccard"}, inplace=True)

# Values were filled cell by cell, so convert them to float explicitly.
gt["#DEG"] = gt["#DEG"].astype(float)
gtj["Jaccard"] = gtj["Jaccard"].astype(float)
# `df` is another name for `gt` (same object), used by the plotting cells.
df = gt
gtj.index = range(len(gtj))
df.index = range(len(df))
# Replace raw dataset keys by their short labels. The result is assigned back:
# `df["Data"].replace(..., inplace=True)` acts on a temporary under pandas
# Copy-on-Write and would leave the column unchanged.
# NOTE(review): gtj["Data"] keeps the raw keys — confirm that is intended.
df["Data"] = df["Data"].replace(prdata)

In [None]:
# Dataset labels ordered by ascending mean #DEG at logFC == 1.
order = df[df["logFC"]==1].groupby("Data")["#DEG"].mean().sort_values().index

def kf(x):
    """Sort key: position of `x` in `order`, or None when `x` is not in it."""
    try:
        return np.where(order==x)[0][0] # sort by data
    except IndexError:
        # Values absent from `order` (e.g. DEA names) get no rank.
        pass
        #return x # sort by DEA
    
# The key is applied to both sort columns; DEA values are not in `order`, so
# only "Data" contributes a rank.
df = df.sort_values(by=['Data',"DEA"], key=lambda col: col.map(kf))
# NOTE(review): gtj["Data"] is not relabelled with prdata in the visible code,
# so relabelled datasets will presumably not be found in `order` — verify.
gtj = gtj.sort_values(by='Data', key=lambda col: col.map(kf))

In [None]:
plt.rcParams['legend.title_fontsize'] = '15'

# Three stacked bar charts of ground-truth #DEG per dataset and DEA method.
fig, ax = plt.subplots(3,1,figsize=(7,12), sharey=False)
ax = ax.flatten()

# Note: this rebinds the notebook-level `palette` (read as a global by the
# plotting functions above) to three colours.
palette = list(np.array(npg)[[0,1,2]])
palette2 = list(np.array(npg)[[5,6,7]])

### ax[0]
dfa=df[df["logFC"]==0]
sns.barplot(data=dfa,x="Data",y="#DEG",hue="DEA", ax=ax[0],palette=palette)

### ax[1]
dfa=df[(df["logFC"]==1)]
sns.barplot(data=dfa,x="Data",y="#DEG",hue="DEA", ax=ax[1], palette=palette)

### ax[2]
dfa=df[(df["logFC"]=="post hoc")]
sns.barplot(data=dfa,x="Data",y="#DEG",hue="DEA", ax=ax[2], palette=palette)

# intersection
# Mark the "inter" entry of truth_stats on each dataset's bar group and
# collect the Jaccard values.
jaccards = []
prdata_inv = {v:k for k,v in prdata.items()}
for i,data in enumerate(order):
    # Note: the loop variable rebinds the notebook-level `logFC`.
    for logFC, a in zip([0,1,"post hoc"],ax[:3]):
        if logFC == "post hoc":
            # Post hoc uses truth_stats[0][0.05][1].
            logFC, logFC_test = 1, 0
        else:
            logFC_test = logFC
        # Map the short label back to the raw dataset key.
        d = data if data not in prdata_inv else prdata_inv[data]
        inter=datasets[d]["truth_stats"][logFC_test][0.05][logFC]["inter"]
        jaccard=datasets[d]["truth_stats"][logFC_test][0.05][logFC]["jaccard"]
        jaccards.append(jaccard)
        # Only the first marker on the first panel carries a legend label.
        a.scatter(i,inter,marker="_",zorder=99,color="black",s=300,label="Intersection" if i < 1 and logFC < 1 else "")
print(np.mean(jaccards))

for i, a in enumerate(ax):
    a.legend([],[], frameon=False)
    # All panels reuse the tick labels of the middle panel.
    a.set_xticklabels(ax[1].get_xticklabels(), rotation=45, ha='center')
    a.set(xlabel=None)
    a.annotate(chr(ord('A')+i), xy=(-0.08, 1.08), xycoords="axes fraction", weight="bold", va='center',ha='center', fontsize=20)
    
ax[0].set_title(r"$|\log_2\mathrm{FC}$|>0",size=16, pad=10)
ax[1].set_title(r"$|\log_2\mathrm{FC}$|>1",size=16, pad=10)
ax[2].set_title(r"$|\log_2\mathrm{FC}$|>1 (post hoc)",size=16, pad=10)


# Single legend on the middle panel, built from the top panel's handles.
handles, labels = ax[0].get_legend_handles_labels()
labels = [l + " Wald" if l == "DESeq2" else l for l in labels]
ax[1].legend(handles, labels, loc='upper left',framealpha=1,fontsize=14)
#ax[0].text(0.01,0.825,r"$|\log_2\mathrm{FC}$|>0", transform=ax[0].transAxes, ha="left")
#ax[1].text(0.01,0.825,r"$|\log_2\mathrm{FC}$|>1", transform=ax[1].transAxes, ha="left")


figpath = f"../figures/sfig1_ground_truth_size.pdf"
fig.tight_layout()
fig.savefig(figpath)
print(figpath)

# Unused figures

In [None]:
Ns = [3,7,15]

lib = "GO_Biological_Process_2023"
# (dataset key, N, "std"|"mean") -> statistic of the per-column counts below.
rd = dict()

for N in Ns:
    for data in datasets:
        if data == "SNF2": continue
        p = datasets[data]["outpath"]
        # `name` is not referenced again in this cell.
        name = prdata[data] if data in prdata else data

        # Enrichment FDR table: per-column count of entries below 0.05,
        # then the std and mean of those counts across columns.
        pp = f"{p}/{data}_N{N}/all.FDR.gseapy.logFC.{lib}.none.deseq2.p2.feather"
        tab = open_table(pp)
        rd[(data,N,"std")] = (tab<0.05).sum().std()
        rd[(data,N,"mean")] = (tab<0.05).sum().mean()

In [None]:
# Long table of the values collected in `rd`, then bars of mean / std per N.
d = pd.DataFrame(rd.values(), index=rd.keys(), columns=["Val"])
d = d.reset_index(names=["Data","N","metric"])
sns.barplot(data=d,x="N",y="Val",hue="metric")

In [None]:
Ns = [3,10]

# Load the logFC table of the first dataset at the first cohort size and leave
# `N`, `data` and `tab` defined for the following cells. This used to be a
# double loop aborted by `assert 0` after its first iteration, which left the
# same variables behind but raised AssertionError and so stopped "Run All".
N = Ns[0]
data = next(iter(datasets))
p = datasets[data]["outpath"]
name = prdata[data] if data in prdata else data
ps = get_paramset(name)

pp = f"{p}/{data}_N{N}/all.logFC.none.deseq2.{ps}.feather"
tab = open_table(pp)

In [None]:
# Average logFC of many small cohorts approximates ground truth logFC from big cohort

print(data)
# Single quotes inside the f-string: reusing the enclosing double quotes is a
# SyntaxError on Python versions before 3.12.
truth_path = f"{datasets[data]['outpath']}/{data}.deseq2.lfc1.csv"
truth_lfc = pd.read_csv(truth_path, index_col=0)["logFC"]
# Column "1" of the example table vs. the ground truth ...
plt.scatter(tab["1"], truth_lfc)
# ... and the row-wise mean over all columns vs. the ground truth.
plt.scatter(tab.mean(axis=1), truth_lfc)

## lfc inflation

In [None]:
Ns = [3,10]
stds = {N: dict() for N in Ns}
kurts = {N: dict() for N in Ns}
degs = {N: dict() for N in Ns}

# Dataset to inspect. It is matched against both the raw key and the short
# label: the previous test `data != "LMAB"` compared the raw key with a short
# label, which skips the dataset whenever the two differ (prdata maps
# "LUMAB" -> "LMAB").
target = "LMAB"

# Load the logFC table of the target dataset at the first cohort size that has
# it, then stop. `N`, `data` and `tab` stay defined for the following cells.
# The explicit breaks replace the former `assert 0`, which ended the loop by
# raising AssertionError and so stopped "Run All".
for N in Ns:
    for data in datasets:
        name = prdata[data] if data in prdata else data
        if target not in (data, name): continue
        p = datasets[data]["outpath"]
        ps = get_paramset(name)

        # logFC std
        pp = f"{p}/{data}_N{N}/all.logFC.none.deseq2.{ps}.feather"
        tab = open_table(pp)
        std = tab.std(axis=0).median()
        stds[N][name] = std

        # logFC kurtosis
        k = tab.kurtosis(axis=0).median()
        kurts[N][name] = k
        break
    else:
        # Target not found for this N: try the next cohort size.
        continue
    break

In [None]:
from misc import get_kl_div


print(data)
# Single quotes inside the f-string: reusing the enclosing double quotes is a
# SyntaxError on Python versions before 3.12.
truth_path = f"{datasets[data]['outpath']}/{data}.deseq2.lfc1.csv"
truth_lfc = pd.read_csv(truth_path, index_col=0)["logFC"]

# Grid passed to get_kl_div; identical for every column, so built once.
kl_grid = np.linspace(-5,5,50)
kl_divs = []
# Columns "1".."100" of `tab`: density in grey, plus the value returned by
# get_kl_div(truth_lfc, column, kl_grid).
for i in range(1,101):
    cohort_lfc = tab[str(i)]
    sns.kdeplot(cohort_lfc,color="grey",alpha=0.5)
    kl_divs.append(get_kl_div(truth_lfc, cohort_lfc, kl_grid))
    print(f"{i}\r", end="")
# Ground-truth density on top, in red.
sns.kdeplot(truth_lfc,color="red",alpha=1)
print(np.mean(kl_divs), np.std(kl_divs))

In [None]:
from misc import get_kl_div


print(data)
# Single quotes inside the f-string: reusing the enclosing double quotes is a
# SyntaxError on Python versions before 3.12.
truth_path = f"{datasets[data]['outpath']}/{data}.deseq2.lfc1.csv"
truth_lfc = pd.read_csv(truth_path, index_col=0)["logFC"]

# Grid passed to get_kl_div; identical for every column, so built once.
kl_grid = np.linspace(-5,5,50)
kl_divs = []
# Columns "1".."100" of `tab`: density in grey, plus the value returned by
# get_kl_div(truth_lfc, column, kl_grid).
for i in range(1,101):
    cohort_lfc = tab[str(i)]
    sns.kdeplot(cohort_lfc,color="grey",alpha=0.5)
    kl_divs.append(get_kl_div(truth_lfc, cohort_lfc, kl_grid))
    print(f"{i}\r", end="")
# Ground-truth density on top, in red.
sns.kdeplot(truth_lfc,color="red",alpha=1)
print(np.mean(kl_divs), np.std(kl_divs))

In [None]:
sns.set_style("whitegrid", {'axes.linewidth': 2, 'axes.edgecolor':'black'})

# NOTE(review): `met` is assigned three times here and then reused as the loop
# variable below, so after the loop it is always "deg".
met = "rep"
met = "mcc"
met = "prec"
#met = "deg"
#met = "rec"

suffix = "_method"

# 4 metric rows x 3 DEA columns; each row shares its y axis.
fig, ax = plt.subplots(4,3, figsize=(15,17), sharex=True, sharey="row")

deas = ["edgeR QLF","edgeR LRT","DESeq2 Wald"]
# Note: `logFC` and `fdr` are rebound here but not referenced again in this
# cell; the filters below use literals.
logFC = "0"
fdr = 0.05

df = combined_all[~combined_all["isSynthetic"]]
df = df[(df["logFC"]==1) & (df["lfc_mode"]=="formal")]
df = df[(df["FDR"]==0.05) & (df["Out"]=="None")]

#order_rep = np.array(["LUSC","COAD","KIRC","LUAD","BRCA","THCA","LIHC","PRAD"])
#order_rep = set(combined_all["Data"])
# Note: rebinds the notebook-level `order_rep` (and `palette` below), which the
# plotting functions above read as globals.
order_rep = df[(df["DEA"]=="edgeR LRT")&(df["N"]==3)].sort_values(by="median_prec_method")["Data"].values
def kf(x):
    # Position of x in order_rep; prints its inputs on every call and is not
    # called in this cell.
    print(order_rep,x)
    return np.where(order_rep==x)[0][0]

palette = jco[:len(order_rep)]

for j, dea in enumerate(deas):
    # edgeR column titles get a TREAT subscript.
    dea_print = rf"{dea}$_\mathrm{{TREAT}}$" if dea != "DESeq2 Wald" else dea
    ax[0][j].set_title(dea_print)
    
    for i, met in enumerate(["prec","rec","rep","deg"]):
        
        a = ax[i][j]
        c = df
        # Order rows by the dataset's position in order_rep.
        c = c.sort_values(by=['Data'], key=lambda col: col.map({k: i for i, k in enumerate(order_rep)}))
        cc = c[c["DEA"]==dea]
        if len(cc) < 1: 
            continue
    
        sns.lineplot(data=cc, x="N", y=f"median_{met}{suffix}", hue="Data", style="Data", markers=True, ax=a, palette=palette,lw=3,ms=14)
                
        prettymet = "MCC" if met.startswith("mcc") else ("Recall" if met == "rec" else ("Precision" if met.startswith("prec") else ("#DEG" if met == "deg" else "Replicability")))
        a.set_ylabel(f"Median {prettymet}")
        a.legend([],[], frameon=False)
        
# Shared legend above the grid, entries in reverse order.
handles, labels = ax[0][1].get_legend_handles_labels()
fig.legend(handles[::-1], labels[::-1], loc="upper center", ncol=8, bbox_to_anchor=(0.5, 1.05),markerscale=1,framealpha=1)

for a in ax[-1]: a.set_xlabel("Cohort Size N")

# `met` is "deg" at this point, so the prefix is always "sfig2"; the figure is
# not written because savefig is commented out.
prefix = "sfig2" if met == "deg" else "sfig3" if met == "rep" else "sfig4" if met.startswith("mcc") else "sfig5" if met == "rec" else "sfig6"
figpath = f"../figures/{prefix}_{met}_vs_N_data.pdf"
fig.tight_layout()
#fig.savefig(figpath)
print(figpath)

In [None]:
sns.set_style("whitegrid", {'axes.linewidth': 2, 'axes.edgecolor':'black'})

method = "DESeq2 Wald"
all_N_sub = [3,5,9,15]

# Single panel: median precision vs. median recall, coloured by cohort size.
fig, ax = plt.subplots(1,1, figsize=(5,5), sharey=True, sharex=True)

df = combined_all[~combined_all["isSynthetic"]]

df = df[(df["logFC"]==1) & (df["lfc_mode"]=="formal") & (df["DEA"]==method)]
df = df[(df["N"]).isin(all_N_sub)]
sns.scatterplot(data=df, x="median_rec", y="median_prec", hue="N", style="Data", ax=ax, palette=npg,s=200)

# One density contour per cohort size; no ax= is passed, so these draw on the
# current axes.
for i, N in enumerate(all_N_sub):
    sns.kdeplot(data=df[df["N"]==N], x="median_rec", y="median_prec",color=npg[i],alpha=0.3)

ax.legend(framealpha=1,title=None,ncol=2,markerscale=1,bbox_to_anchor=(1,1))
ax.set(xlim=(-0.05,1.05),ylim=(-0.05,1.05))

In [None]:
from adjustText import adjust_text

sns.set_style("whitegrid", {'axes.linewidth': 2, 'axes.edgecolor':'black'})

method = "DESeq2 Wald"
all_N_sub = [3,5,9,15]

# One panel per cohort size: median precision vs. median recall, with each
# point labelled by its dataset.
fig, ax = plt.subplots(1,4, figsize=(15,5), sharey=True, sharex=True)

df = combined_all[~combined_all["isSynthetic"]]

df = df[(df["logFC"]==1) & (df["lfc_mode"]=="formal") & (df["DEA"]==method)]

for i, N in enumerate(all_N_sub):
    df_N = df[df["N"]==N]
    # NOTE(review): `palette` is passed without `hue`; the colour presumably
    # comes from `color` alone — verify.
    sns.scatterplot(data=df_N, x="median_rec_method", y="median_prec_method", color=npg[i], style="Data", ax=ax[i], palette=npg,s=200)
    #sns.kdeplot(data=df_N, x="median_rec_method", y="median_prec_method",color=npg[i],alpha=0.3,ax=ax[i])

    ax[i].set(xlim=(-0.05,1.05),ylim=(-0.05,1.05))
    ax[i].legend([],[], frameon=False)
    ax[i].set_title(f"N={N}",color=npg[i], fontsize=20)
    ax[i].set(xlabel="Median Recall", ylabel="Median Precision")

    # Dataset labels next to the points, repositioned by adjust_text.
    texts = []
    for x, y, s in zip(df_N["median_rec_method"], df_N["median_prec_method"], df_N["Data"]):
      texts.append(ax[i].text(x, y, s, size=12))    
    adjust_text(texts, arrowprops=dict(arrowstyle="-", color='black', lw=0.5),ax=ax[i])
    
# Shared legend from the last panel, dropping its last two entries.
handles, labels = ax[-1].get_legend_handles_labels()
handles, labels = handles[:-2], labels[:-2]
fig.legend(handles[::-1], labels[::-1], loc="upper center", ncol=8, bbox_to_anchor=(0.51, 1.15),markerscale=1,framealpha=1)
fig.tight_layout()

In [None]:
import scipy.stats as stats
from scipy.optimize import curve_fit
from sklearn.metrics import auc

# One panel per DEA method: precision vs. recall with a fitted binormal curve.
scale=0.6
fig,ax=plt.subplots(1,3,figsize=(scale*20,scale*10))
axes=ax.flatten()

df_a = combined_all[~combined_all["isSynthetic"]]
df_a = df_a[(df_a["logFC"]==1) & (df_a["lfc_mode"]=="formal")]

methods = ["DESeq2 Wald","edgeR QLF","edgeR LRT"]
# Note: the loop variable rebinds `ax` from the axes array to a single Axes.
for ax, method in zip(axes, methods):

    df = df_a[df_a["DEA"]==method]
    sns.scatterplot(data=df, x="median_rec", y="median_prec", color="grey", style="Data", palette=npg,s=200,alpha=0.5,ax=ax)
    ax.set(xlim=(-0.05,1.05),ylim=(-0.05,1.05))
    x = np.linspace(0,1,100)
    

    def binormal(x, a, b):
        # Normal CDF of a linear function of the normal quantile of x.
        return stats.norm.cdf(a * stats.norm.ppf(x) + b)
    
    # Both parameters are constrained to be non-negative.
    params, _ = curve_fit(binormal, df['median_rec'], df['median_prec'], bounds=(0, np.inf))
    y_binormal = binormal(x, *params)
    sns.lineplot(x=x, y=y_binormal, color="red",lw=4,zorder=99,ax=ax)
    
    # params_data = [[]]
    # for data in set(df["Data"]):
    #     if data in ["LMAB","SNF2"]: continue
    #     df_d = df[df["Data"]==data]
    #     params_d, _ = curve_fit(binormal, df_d['median_rec'], df_d['median_prec'])
    #     params_data.append(params_d)
    #     y_binormal = binormal(x, *params_d)
    #     sns.lineplot(x=x, y=y_binormal,alpha=0.25,color="grey",ax=ax)
    print(method,params)
    ax.set(xlabel="Median Recall", ylabel="Median Precision")
    ax.set_aspect(1)
    # Area under the fitted curve over x in [0, 1].
    auc_value = auc(x,y_binormal)
    ax.set_title(f"{method} |logFC|>1 AUC: {auc_value:.2f}")
    ax.legend([],[], frameon=False)

# Only the middle panel shows a legend.
axes[1].legend(loc="lower right",ncol=2,prop={'size': 12})
#handles, labels = axes[-1].get_legend_handles_labels()
#fig.legend(handles[::-1], labels[::-1], loc="upper center", ncol=8, bbox_to_anchor=(0.51, 1),markerscale=1,framealpha=1)
fig.tight_layout()