This notebook compares the enrichment results of multiple datasets (projects) and produces combined ("meta") summary figures.

# Imports and configuration

In [3]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle

# Make the workflow helpers importable: `../scripts` itself, and its parent so
# that `scripts.plots` resolves as a package import.
_cwd = os.getcwd()
scripts_dir = os.path.abspath(os.path.join(_cwd, "../scripts"))
workflows_dir = os.path.abspath(os.path.join(_cwd, ".."))
for _extra_path in (scripts_dir, workflows_dir):
    sys.path.append(_extra_path)

from scripts.plots import npg_palette

# Global styling and pandas options shared by every figure below.
npg = npg_palette()
sns.set_style("darkgrid")
palette = sns.color_palette()
pd.options.mode.copy_on_write = True

# Display labels for tool and ranking-metric identifiers used in plots.
pretty_print = dict(
    string="STRING",
    gseapy="GSEApy",
    clusterProfiler="ClusterProfiler",
    neg_signed_logpval="signed logPValue",
    logFC="logFC",
    s2n="Signal-to-noise",
)

In [28]:
# Which set of projects to compare: "manuscript" or "MET".
selected = "manuscript"

max_depth = 9
metrics = ['logFC', 'neg_signed_logpval', 's2n']
libraries = ["KEGG", "GO_STRING_human.gmt"]
tools = ["clusterProfiler", "gseapy", "string"]
# Short library key -> library name/file. Every "GO_*" library is keyed as "GO";
# these short keys are what the loaded per-project summary dicts are indexed by.
lib_names = {"GO" if lib.startswith("GO_") else os.path.splitext(os.path.basename(lib))[0]: lib for lib in libraries}

if selected == "manuscript":

    project_names = ["KIRC.QLF", "PRAD.QLF",
                    "LMAB.QLF", "BSLA.QLF",
                    "CHK2.QLF",
                    #"Chiara.QLF.KO_WT", "Chiara.QLF.SA_WT", "Chiara.QLF.SD_WT",
                    #"met.Exc7_DL.P90.p19rc", "met.Inh_Sncg.P14.p19rc",
                    "GIPF.QLF", "GATB.QLF", "HSPL.QLF"]

    pretty_datanames = {p: p for p in project_names}
    # Legacy display names; not used by the cells below.
    pretty_datanames_old = {"BRCA.QLF": "TCGA.BRCA.N-T",
                        "THCA.QLF": "TCGA.THCA.N-T",
                        "KIRC.QLF": "TCGA.KIRC.N-T",
                        "LIHC.QLF": "TCGA.LIHC.N-T",
                        "met.Exc7_DL.P90.p19rc": "sn.Exc7.P90.WT-SA",
                        "met.Inh_Sncg.P14.p19rc": "sn.Inh.Sncg.P14.WT-SA",
                        "Chiara.QLF.KO_WT": "Ser1016.WT-KO",
                        "Chiara.QLF.SA_WT": "Ser1016.WT-SA",
                        "Chiara.QLF.SD_WT": "Ser1016.WT-SD",
                        "Carmen.paired.QLF": "CHK2.WT-KO"}

    meta_savepath = "../../results/meta"
    meta_project_name = "meta"

elif selected == "MET":

    project_names = [
        "met.Astrocytes1.P14",
        "met.Astrocytes2.P14",
        "met.COPs.P14",
        "met.Endothelial.P14",
        "met.Exc1_SL.P14",
        "met.Exc2_ML.P14",
        "met.Exc3_ML.P14",
        "met.Exc4_ML.P14",
        "met.Exc5_ML.P14",
        "met.Exc6_ML.P14",
        "met.Exc7_DL.P14",
        "met.Exc8_DL.P14",
        "met.Exc9_DL.P14",
        "met.Inh_Lamp5.P14",
        "met.Inh_Meis2.P14",
        "met.Inh_Pvalb.P14",
        "met.Inh_Sst.P14",
        "met.Inh_Vip.P14",
        "met.MFO.P14",
        "met.Microglia.P14",
        "met.OPC.P14",

        "met.Astrocytes1.P90",
        "met.Endothelial.P90",
        "met.Exc1_SL.P90",
        "met.Exc2_ML.P90",
        "met.Exc3_ML.P90",
        "met.Exc6_ML.P90",
        "met.Exc7_DL.P90",
        "met.Exc8_DL.P90",
        "met.Exc9_DL.P90",
        "met.Inh_Lamp5.P90",
        "met.Inh_Meis2.P90",
        "met.Inh_Pvalb.P90",
        "met.Inh_Sst.P90",
        "met.Inh_Vip.P90",
        "met.MFO.P90",
        "met.Microglia.P90",
        "met.OPC.P90"
    ]

    meta_savepath = "../../results/meta.met"
    meta_project_name = "meta.met"

    pretty_datanames = {p: p.split("met.")[1] for p in project_names}

else:
    # Fail here rather than with a NameError on project_names further down.
    raise ValueError(f"Unknown selection {selected!r}; expected 'manuscript' or 'MET'")

# Load each project's pickled summary dict. Missing projects are reported and
# skipped. NOTE: pickle.load can execute arbitrary code, so only load result
# files produced by this pipeline.
meta_dict = {}
for project in project_names:
    summary_path = f"../../results/{project}/combined/syn.summary_dict.{project}.txt"
    try:
        with open(summary_path, "rb") as f:
            meta_dict[project] = pickle.load(f)
    except FileNotFoundError:
        print(f"Project not found: {project}")

if not meta_dict:
    # Every later cell needs at least one project; stop with a clear message.
    raise FileNotFoundError(f"No project results could be loaded for selection {selected!r}")

In [29]:
# Concatenate every loaded project's per-library results into one frame per library:
#   meta_summary_dict[lib]["summary_df"] - term rows, index suffixed with ".<project>"
#   meta_summary_dict[lib]["depth_df"]   - depth rows tagged with a "Project" column
# The per-project frames stored in meta_dict are NOT modified, so this cell is
# idempotent and later cells that re-plot single projects see the original data.
meta_summary_dict = {}

for lib in lib_names:
    summary_frames = []
    depth_frames = []
    for project in meta_dict:
        summary_p = meta_dict[project][lib]["summary_df"]
        # Suffix term IDs with the project so rows stay unique after concatenation
        # (needed for venn). set_axis returns a new frame instead of relabelling
        # the shared one in place.
        summary_frames.append(summary_p.set_axis(summary_p.index + "." + project))

        depth_p = meta_dict[project][lib]["depth_df"]
        # Tag rows with the project's display name (the raw ID if it has none).
        depth_frames.append(depth_p.assign(Project=pretty_datanames.get(project, project)))

    meta_summary_dict[lib] = {
        "summary_df": pd.concat(summary_frames),
        "depth_df": pd.concat(depth_frames),
    }

# Figures

In [None]:

from scripts.plots import make_bar_plots

# Bar plots over all loaded projects combined (meta_summary_dict).
make_bar_plots(
    summary_dict=meta_summary_dict,
    figpath=meta_savepath,
    project_name=meta_project_name,
    lib_names=lib_names,
    pretty_print=pretty_print,
    qval=0.05,
    palette=npg_palette(),
    max_depth=max_depth,
    ext="png",
)


In [None]:
from scripts.plots import make_bar_plots


def get_depth_ratio(summary_dict, lib_name):
    """Ratio of term counts at the deepest observed depth to the shallowest.

    Args:
        summary_dict: mapping lib -> {"depth_df": frame with a "Depth" column, ...}.
        lib_name: library key to read.

    Returns:
        count(max depth) / count(min depth), or NaN when there are no terms.
    """
    vc = summary_dict[lib_name]["depth_df"]["Depth"].value_counts().sort_index()
    if vc.empty:
        return float("nan")
    return vc.iloc[-1] / vc.iloc[0]


def pick_project(meta_summary_dict, project, lib_names, pretty_names=None):
    """Select one project's rows from the combined per-library frames.

    Args:
        meta_summary_dict: lib -> {"summary_df", "depth_df"} combined over projects;
            summary_df index entries end in ".<project>", depth_df has a "Project"
            column (which may hold the project's display name).
        project: raw project ID.
        lib_names: library keys to select.
        pretty_names: optional raw ID -> display name mapping, so depth rows
            tagged with the display name are matched as well.

    Returns:
        lib -> {"summary_df": ..., "depth_df": ...} restricted to `project`.
    """
    display_name = (pretty_names or {}).get(project, project)
    d = {}
    for lib in lib_names:
        ms = meta_summary_dict[lib]["summary_df"]
        md = meta_summary_dict[lib]["depth_df"]
        d[lib] = {
            # Match the full ".<project>" suffix so that e.g. "XKIRC.QLF" rows
            # are not picked up for project "KIRC.QLF".
            "summary_df": ms.loc[ms.index.str.endswith("." + project)],
            "depth_df": md[md["Project"].isin([project, display_name])],
        }
    return d


plot = False

# Only projects that were actually loaded; missing ones were reported on load.
for project in meta_dict:
    d = pick_project(meta_summary_dict, project, lib_names, pretty_datanames)
    r_kegg = get_depth_ratio(d, "KEGG")
    r_go = get_depth_ratio(d, "GO")
    print(project, f"{r_kegg:.2f} {r_go:.2f}")

    if plot:
        make_bar_plots(summary_dict=d,
                        figpath=None,
                        project_name=meta_project_name,
                        lib_names=lib_names,
                        pretty_print=pretty_print,
                        qval=0.05,
                        palette=npg_palette(),
                        max_depth=max_depth,
                        ext="png")

In [None]:
from scripts.plots import make_venn_plots

# Venn diagrams over the configured tools and ranking metrics, all projects combined.
make_venn_plots(
    summary_dict=meta_summary_dict,
    figpath=meta_savepath,
    project_name=meta_project_name,
    lib_names=lib_names,
    metrics=metrics,
    tools=tools,
    pretty_print=pretty_print,
    qval=0.05,
    ext="png",
)

In [None]:
from scripts.plots import make_upset_plots

# UpSet plots, all projects combined.
make_upset_plots(
    summary_dict=meta_summary_dict,
    lib_names=lib_names,
    figpath=meta_savepath,
    project_name=meta_project_name,
    pretty_print=pretty_print,
    ext="png",
)

In [None]:
sns.set(font_scale=1.2)
sns.set_style("whitegrid")
sns.set_style("ticks")

# Summary figure, one column per library in lib_names:
#   top row    - number of combined terms per dataset, with depth bands marked
#   bottom row - distribution of enrichment depths per dataset
fig, axes = plt.subplots(2, 2, figsize=(12+len(project_names)//3, 10))
axes = axes.flatten()
# Despine this figure explicitly. Calling sns.despine() before the figure exists
# acts on (and creates) a stray empty figure.
sns.despine(fig=fig)

### Number of terms per dataset

for ax, lib in zip(axes[:2], lib_names):
    # Map to display names on a copy; do not mutate meta_summary_dict.
    depth_df = meta_summary_dict[lib]["depth_df"].replace({"Project": pretty_datanames})

    counts = depth_df.groupby(["Project"]).count()["Configurations"].sort_values(ascending=False)
    counts = pd.DataFrame(counts)
    counts["hue"] = counts.index

    if ax is axes[0]:
        # Project order (largest first) is fixed by the first panel; the other
        # panels reuse it as hue_order / order.
        hue_order = {counts.iloc[i]["hue"]: npg[i % len(npg)] for i in range(len(counts))}

    b = sns.barplot(data=counts, y="Configurations", x="hue", ax=ax, hue="hue",
                    hue_order=hue_order, palette=npg[:len(project_names)])
    ax.set(title=lib, xlabel=None, ylabel="Combined Terms")

    for container in b.containers:
        b.bar_label(container)

    # if ax == axes[1]:
    #     ax.set_ylim(0,2999)

    # if ax == axes[0]:
    #     ax.set_ylim(0,199)

    # Mark cumulative depth boundaries on each bar with short horizontal ticks;
    # label the depth bands ("d1", "d2", ...) on the first bar only.
    depth_counts = depth_df.groupby(["Project"])["Depth"].value_counts()
    deepest = depth_counts.index.get_level_values("Depth").max()  # same for every project
    for j, p in enumerate(counts.index):
        cumulative = 0
        prev_cumulative = 0
        for i in range(1, 1 + deepest):
            cumulative += depth_counts.get((p, i), 0)  # depth i absent for p -> 0
            if i < deepest:
                ax.scatter(j, cumulative, marker="_", color="black", s=550, alpha=0.5)
            if j == 0:
                ax.text(s=f"d{i}", x=j, y=prev_cumulative + 0.5 * (cumulative - prev_cumulative),
                        ha="center", va="center", fontsize=10)
            prev_cumulative = cumulative

### Depth distribution per dataset

for ax, lib in zip(axes[2:], lib_names):
    # Same display-name mapping as the top row so `order` matches the data.
    depth_df = meta_summary_dict[lib]["depth_df"].replace({"Project": pretty_datanames})
    #sns.barplot(data=depth_df, x="Project", y="Depth", hue="Project", errorbar="sd", ax=ax, order=hue_order.keys(), hue_order=hue_order, palette=npg[:len(project_names)])
    sns.boxplot(data=depth_df, x="Project", y="Depth", hue="Project", ax=ax,
                order=hue_order.keys(), hue_order=hue_order, palette=npg[:len(project_names)])
    ax.set(ylabel="Enrichment Depth", title=lib)

    if ax is axes[3]:
        # Share the depth axis between the two libraries.
        ax.set_ylim(axes[2].get_ylim())

for i in range(len(axes)):
    # Panel letters A, B, C, D.
    axes[i].annotate(chr(ord('A')+i), xy=(-0.08, 1.04), xycoords="axes fraction", weight="bold", va='center', ha='center', fontsize=18)
    axes[i].set_xticks(axes[i].get_xticks(), axes[i].get_xticklabels(), rotation=30, ha='right')
    axes[i].set(xlabel=None)

    # annotate KI
    # axes[i].axvline(len(pretty_datanames)-4.5,0,axes[i].get_ylim()[1],ls="--",color="black",alpha=0.7)
    #axes[i].annotate("KI", xy=(0,0), xytext=(len(pretty_datanames)-3, 0.8*axes[i].get_ylim()[1] ), xycoords="data",zorder=99)

    # if i == 0:
    #     axes[i].annotate("KI", xy=(0,0), xytext=(0.78,0.8), xycoords="axes fraction",zorder=99)
    # elif i == 1:
    #     axes[i].annotate("KI", xy=(0,0), xytext=(0.78,0.735), xycoords="axes fraction",zorder=99)
    # elif i == 2:
    #     axes[i].annotate("KI", xy=(0,0), xytext=(0.78,0.84), xycoords="axes fraction",zorder=99)
    # else:
    #     axes[i].annotate("KI", xy=(0,0), xytext=(0.78,0.84), xycoords="axes fraction",zorder=99)

fig.tight_layout()
os.makedirs(meta_savepath, exist_ok=True)
fig.savefig(f"{meta_savepath}/bars.data.meta.png")

# Work in progress

In [None]:
# Library key in meta_summary_dict (keyed like lib_names: every GO_* library is "GO").
golib = "GO"

# Derive per-term columns on a copy so meta_summary_dict is not mutated.
d = meta_summary_dict[golib]["depth_df"].assign(
    n_genes=lambda frame: frame["Genes"].str.split(";").str.len(),
    logFDR=lambda frame: -np.log10(frame["Combined FDR"]),
)

# (x column, y column, x label, y label) for each violin panel.
violin_panels = [
    ("Direction", "Depth", "Direction", "Robustness"),
    ("ONTOLOGY", "Depth", "Domain", "Robustness"),
    ("Enrichr", "Depth", "Enrichr", "Robustness"),
    ("Depth", "n_genes", "Robustness", "# Genes in Term"),
    ("Depth", "logFDR", "Robustness", "-log10 FDR"),
    ("Enrichr", "n_genes", "Enrichr", "# Genes in Term"),
]

fig, ax = plt.subplots(2, 3, figsize=(15, 10))
ax = ax.flatten()
for panel_ax, (x_col, y_col, x_label, y_label) in zip(ax, violin_panels):
    sns.violinplot(data=d, x=x_col, y=y_col, ax=panel_ax)
    panel_ax.set(xlabel=x_label, ylabel=y_label)

fig.suptitle(golib)
fig.tight_layout()
fig.savefig(f"{meta_savepath}/violin.meta.png")

## Inter-project term heatmaps

In [None]:
depth_cutoff = 6
# Library key in the per-project summary dicts (all GO_* libraries are keyed "GO").
lib = "GO"

projects = list(meta_dict)
#projects = [p for p in projects if "P14" in p]


def _significant_terms(depth_df, cutoff):
    """Return the set of "<term>_<direction>" labels for terms deeper than `cutoff`.

    An empty depth frame yields an empty set.
    """
    if len(depth_df) < 1:
        return set()
    deep = depth_df[depth_df["Depth"] > cutoff]
    return set(deep.index.astype(str) + "_" + deep["Direction"])


# Build each project's term set once, not once per (p1, p2) pair.
term_sets = {p: _significant_terms(meta_dict[p][lib]["depth_df"], depth_cutoff) for p in projects}

inter_mat = pd.DataFrame(index=projects, columns=projects, dtype=float)
union_mat = pd.DataFrame(index=projects, columns=projects, dtype=float)
jacc_mat = pd.DataFrame(index=projects, columns=projects, dtype=float)

for p1, terms1 in term_sets.items():
    for p2, terms2 in term_sets.items():
        n_inter = len(terms1 & terms2)
        n_union = len(terms1 | terms2)
        inter_mat.loc[p1, p2] = n_inter
        union_mat.loc[p1, p2] = n_union
        # Two empty sets are defined to have Jaccard 0.
        jacc_mat.loc[p1, p2] = (n_inter / n_union) if n_union else 0

In [None]:
# Figure grows with the number of projects (0.6 in per row, 1.2:1 aspect).
fig_h = int(len(jacc_mat) * 0.6)
fig_w = int(1.2 * fig_h)

fig, ax = plt.subplots(1, 1, figsize=(fig_w, fig_h))


def _count_label(value):
    """Integer count as text, or a bullet for zero."""
    return '•' if value == 0 else f'{int(value)}'


# Cell text is "intersection\nunion"; cells with an empty intersection show only a bullet.
annot_text = (inter_mat.map(_count_label) + "\n" + union_mat.map(_count_label)).map(
    lambda text: '•' if text.startswith("•") else text
)

sns.heatmap(jacc_mat, annot=annot_text.values, ax=ax, fmt="")
fig.axes[1].set(title="Jaccard", xlabel='', ylabel='')  # axes[1] is the colorbar

ax.set_xticklabels([tick.get_text().replace('met.', '') for tick in ax.get_xticklabels()])
ax.set_yticklabels([tick.get_text().replace('met.', '') for tick in ax.get_yticklabels()])

fig.suptitle(f"Intersection/Union of significant terms with same direction\n{lib} | Depth>{depth_cutoff}")

fig.tight_layout()
fig.savefig(f"{meta_savepath}/heat.meta.depth{depth_cutoff}.png")

## Inter-configuration heatmaps

In [None]:
project = "met.Inh_Meis2.P14"
# Library key in the per-project summary dicts (all GO_* libraries are keyed "GO").
lib = "GO"

if project not in meta_dict:
    # The default project exists only in the MET selection; fall back to the
    # first loaded project so the notebook still runs top to bottom.
    fallback_project = next(iter(meta_dict))
    print(f"{project} not loaded for selection {selected!r}; using {fallback_project} instead")
    project = fallback_project

# One row per (term, configuration): "Configurations" holds ";"-separated names.
df = meta_dict[project][lib]["depth_df"]
df = df.assign(Configurations=df["Configurations"].str.split(";")).explode("Configurations")
df["Configurations"] = df["Configurations"].str.strip()
d1 = df[["Configurations", "Direction"]]
configurations = sorted(set(df["Configurations"]))

# Build each configuration's "<term>_<direction>" set once, not once per pair.
term_labels = d1.index.astype(str) + "_" + d1["Direction"]
term_sets = {c: set() for c in configurations}
for config, label in zip(d1["Configurations"], term_labels):
    term_sets[config].add(label)

inter_mat = pd.DataFrame(index=configurations, columns=configurations, dtype=float)
union_mat = pd.DataFrame(index=configurations, columns=configurations, dtype=float)
jacc_mat = pd.DataFrame(index=configurations, columns=configurations, dtype=float)

for c1, terms1 in term_sets.items():
    for c2, terms2 in term_sets.items():
        n_inter = len(terms1 & terms2)
        n_union = len(terms1 | terms2)
        inter_mat.loc[c1, c2] = n_inter
        union_mat.loc[c1, c2] = n_union
        # Two empty sets are defined to have Jaccard 0.
        jacc_mat.loc[c1, c2] = (n_inter / n_union) if n_union else 0

In [None]:
import colorcet as cc
#cmap = cc.diverging_rainbow_bgymr_45_85_c67

# Figure grows with the number of configurations.
fig_h = int(len(jacc_mat) * 0.6 * 1.5)
fig_w = int(1.2 * fig_h)

fig, ax = plt.subplots(1, 1, figsize=(fig_w, fig_h))

# clean=True draws the heatmap without the per-cell "intersection\nunion" text.
clean = True
if clean:
    annot = False
else:
    mask1 = inter_mat.map(lambda x: '•' if x == 0 else f'{int(x)}')
    mask2 = union_mat.map(lambda x: '•' if x == 0 else f'{int(x)}')
    mask3 = mask1 + "\n" + mask2
    mask3 = mask3.map(lambda x: '•' if x.startswith("•") else x)
    annot = mask3.values

sns.heatmap(jacc_mat, ax=ax, fmt="", annot=annot, vmin=0)  # , cmap=cmap)
fig.axes[1].set(title="Jaccard", xlabel='', ylabel='')  # axes[1] is the colorbar


def _tick_text(tick):
    """Shorten a tick label by dropping the 'met.' and 'neg_' prefixes."""
    return tick.get_text().replace('met.', '').replace('neg_', '')


ax.set_xticks(ax.get_xticks(), [_tick_text(t) for t in ax.get_xticklabels()], rotation=45, ha='right')
ax.set_yticks(ax.get_yticks(), [_tick_text(t) for t in ax.get_yticklabels()])

#fig.suptitle(f"Intersection/Union of significant terms with same direction\n{lib}")
fig.tight_layout()

suffix = ".clean" if clean else ""
fig.savefig(f"{meta_savepath}/heat.meta.configs{suffix}.png")

## Bulk plots

Quickly re-plot the lollipop figures for a selected subset of projects.

In [None]:
from scripts.plots import make_lollipop_plots

# Projects to re-plot with split_by_subontology=True. A project is selected
# when any entry is a substring of its name. All other projects are skipped.
to_split = ["met.Astrocytes2.P14", "met.Inh_Meis2.P14", "met.Exc3_ML.P90"]

for project in meta_dict:
    if not any(s in project for s in to_split):
        continue

    summary_dict_p = meta_dict[project]
    figpath_p = os.path.join("../../", f"results/{project}/figures")

    make_lollipop_plots(summary_dict_p,
                        lib_names,
                        figpath_p,
                        project,
                        top_terms=30,
                        qval=0.05,
                        depth_cutoff=6,
                        max_depth=max_depth,
                        x_val="NegSignedlogFDR",
                        hue_subontology=True,
                        split_by_subontology=True,
                        ext="png")

In [None]:
import importlib
import scripts.plots

# Pick up edits to scripts/plots.py without restarting the kernel. Names bound
# earlier via `from scripts.plots import ...` keep the old definitions until the
# cell that imports them is re-run.
importlib.reload(scripts.plots)