In [None]:
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path

In [None]:
# Path templates for the per-comparison GSEA result tables. "{0}" is filled in
# with a comparison name by get_gsea_sig_df; GO tables come in BP/CC/MF variants.
fn_formats = [
    *(f"./GSEA/go/tables/table_{{0}}_{ontology}.tsv" for ontology in ("BP", "CC", "MF")),
    "./GSEA/kegg/tables/table_{0}.tsv",
    "./GSEA/reactome/tables/table_{0}.tsv",
]

# Keyword groups for picking pathways by theme. get_gsea_sig_df joins a group
# with "|" into a regex and searches the lowercased term description, so the
# entries are lowercase stems that match whole word families.
immune = ["immu", "interferon", "leukocyte", "lymphocyte"]                # immune response
dna = ["chromo", "chroma", "dna", "double strand", "recombination"]       # chromatin / DNA
metabolism = ["cataboli", "metabol", "fatty", "oxidat", "respirat"]       # metabolic processes
apop = ["apopt", "p53", "pi3k", "death", "autopha", "tnf"]                # cell death / survival
inflam = ["inflam", "nf-kappa", "il-17", "il-10"]                         # inflammation
mito = ["mitochondri"]                                                    # mitochondria

def _fdr_label(q):
    """Return the FDR bin label for a single q-value.

    The cut-offs are 0.001, 0.01, 0.05, 0.1 and 0.2. The label is ``"< s"``
    for the smallest cut-off ``s`` with ``q <= s``, so a q-value exactly equal
    to a cut-off lands in that cut-off's bin. It is ``"> 0.2"`` when *q*
    exceeds every cut-off or is NaN (every NaN comparison is False).
    """
    for cutoff in (0.001, 0.01, 0.05, 0.1, 0.2):
        if q <= cutoff:
            return f"< {cutoff}"
    return "> 0.2"


def get_gsea_sig_df(input_formats,
                    comp,
                    sel_by=("2_1", "4_1", "4_2", "5_2"),
                    keywords=None):
    """Collect GSEA terms significant in selected comparisons, filtered by keyword.

    For each comparison name ``c`` in *comp*, every template in
    *input_formats* is formatted with ``c`` and read as a tab-separated table.
    The tables are stacked into one DataFrame per comparison. The tables must
    provide ``ID``, ``Description`` and ``qvalues`` columns.

    Parameters
    ----------
    input_formats : iterable of str
        Path templates containing a ``{0}`` placeholder for the comparison.
    comp : iterable of str
        Comparison names to load and include in the result.
    sel_by : container of str, default ("2_1", "4_1", "4_2", "5_2")
        Comparisons whose terms with ``qvalues < 0.2`` define the union of
        term IDs that is kept for *all* comparisons in *comp*.
    keywords : iterable of str, optional
        Regex alternatives joined with ``"|"`` and searched in the lowercased
        ``Description``. Defaults to the module-level ``metabolism`` list.
        Pass lowercase keywords: the pattern itself is not lowercased.

    Returns
    -------
    pandas.DataFrame
        The rows of every comparison whose ID is in the kept set and whose
        description matches a keyword and is shorter than 50 characters. It
        has three added columns:

        - ``comparison``: the comparison name.
        - ``significant``: ``qvalues <= 0.2``.
        - ``FDR``: the bin label produced by ``_fdr_label``.
    """
    if keywords is None:
        keywords = metabolism

    all_sig = set()
    df_dic = {}
    for c in comp:
        df = pd.concat([pd.read_csv(inp.format(c), sep='\t') for inp in input_formats], axis=0)
        if c in sel_by:
            all_sig |= set(df.loc[df["qvalues"] < 0.2, "ID"])
        df_dic[c] = df

    # .assign builds new frames instead of writing into filtered slices,
    # which would trigger pandas' SettingWithCopyWarning.
    data = pd.concat(
        [frame[frame["ID"].isin(all_sig)].assign(comparison=c)
         for c, frame in df_dic.items()],
        axis=0,
    )

    description = data["Description"]
    # na=False drops rows with a missing description. Otherwise the mask would
    # contain NaN, and pandas refuses to index with it.
    keep = description.str.lower().str.contains("|".join(keywords), regex=True, na=False)
    keep &= description.str.len() < 50
    data = data[keep]

    # NOTE(review): the ID selection above uses a strict "< 0.2" while
    # "significant" uses "<= 0.2". A term with q == 0.2 is flagged significant
    # without having qualified for selection. Kept as-is; confirm intent.
    return data.assign(
        significant=data["qvalues"] <= 0.2,
        FDR=data["qvalues"].map(_fdr_label),
    )

def get_bubble_plot(data,
                    ax):
    """Draw a GSEA bubble plot of term descriptions against comparisons onto *ax*.

    *data* needs the columns ``comparison`` (x), ``Description`` (y), ``FDR``
    (bubble size), ``NES`` (colour) and ``significant`` (marker). Bubble sizes
    follow the FDR bin labels from "< 0.001" through "> 0.2".
    Significant rows are drawn as circles and the rest as crosses.

    The colormap depends on the sign of the NES values. In every case a larger
    magnitude gets a more saturated colour:

    - no positive NES: light-to-dark blue, darkest at the most negative value.
    - no negative NES: light-to-dark red, darkest at the most positive value.
    - both signs: ``coolwarm``, centred on NES = 0 so the hue always
      reflects the sign.

    Returns the Axes returned by ``sns.scatterplot``.
    """
    split = [0.001, 0.01, 0.05, 0.1, 0.2]
    nes = data["NES"]
    hue_norm = None

    if not (nes > 0).any():
        # "light:b" runs light -> blue. Reverse it so the lowest (most
        # negative) NES is darkest.
        cmap = sns.color_palette("light:b", as_cmap=True).reversed()
    elif not (nes < 0).any():
        # "light:r" already runs light -> red, so the highest NES is darkest.
        # (Reversing it would make the strongest enrichment the palest.)
        cmap = sns.color_palette("light:r", as_cmap=True)
    else:
        cmap = sns.color_palette("coolwarm", as_cmap=True)
        # Symmetric limits put NES = 0 at the neutral midpoint of the diverging map.
        vmax = nes.abs().max()
        hue_norm = (-vmax, vmax)

    return sns.scatterplot(data=data, x="comparison", y="Description", size="FDR", hue="NES",
                           style="significant",
                           size_order=[f"< {s}" for s in split] + ["> 0.2"],
                           markers={True: "o", False: "X"},
                           ax=ax,
                           palette=cmap,
                           hue_norm=hue_norm,
                           sizes=(40, 300))
    
    