# Enrichment analysis

In [None]:
# Interactive matplotlib figures inside the notebook (requires the ipympl package).
%matplotlib widget
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Project-local helper from enrichment.py.
from enrichment import get_ora


# Apply seaborn's default theme to every figure drawn below.
sns.set_theme()

First, we load the DESeq2 results as follows:

In [None]:
DESEQ2_RESULTS = "../Output/DESeq2/dds_results.csv"

# Read strain IDs as strings so labels such as "083.2" are not parsed as numbers.
df_lfc = pd.read_csv(DESEQ2_RESULTS, dtype={"strain": str})
df_lfc

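Next, we use the `get_ora()` function from `enrichment.py` to find enriched GO terms
among the significantly up- or down-regulated genes (adjusted p-value < 0.05 and |log2 fold change| > 1), separately for each strain and comparison: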

In [None]:
# Cutoffs for calling a gene significantly up- or down-regulated.
PADJ_CUTOFF = 0.05
LFC_CUTOFF = 1

significant_genes = df_lfc.query("padj < @PADJ_CUTOFF and abs(log2fc) > @LFC_CUTOFF")

# Run get_ora on the significant gene IDs of each (strain, comparison) group.
ora_per_group = significant_genes.groupby(["strain", "comparison"])["gene_id"].apply(get_ora)

# Flatten the group keys back into columns. NOTE(review): this assumes get_ora
# returns a DataFrame with an unnamed index, which reset_index exposes as the
# auto-named "level_2" column -- TODO confirm against enrichment.py.
df_go_terms = ora_per_group.reset_index().drop(columns="level_2")

We can plot the enrichment results for a specific strain and comparison, keeping only GO terms with FDR < 0.1, as shown below:

In [None]:
strain = "083.2"
comparison = "CvsP"
# Only GO terms below this FDR are plotted.
fdr_cutoff = 0.1

# Terms for the selected strain/comparison, most significant first (top bar).
df_go_selected = (
    df_go_terms
    .query("strain == @strain and comparison == @comparison and fdr < @fdr_cutoff")
    .sort_values("fdr")
)

fig, ax = plt.subplots(figsize=(10, 7))
# Leave room on the left for long GO term names.
fig.subplots_adjust(left=0.7)
ax.set(title=(
    "GO terms for significantly up- or down-regulated \n"
    f"genes in strain {strain} in {comparison} comparison"
))

if df_go_selected.empty:
    # Nothing to draw means no legend either, and sns.move_legend would raise.
    ax.text(
        0.5, 0.5, f"No GO terms with FDR < {fdr_cutoff}",
        ha="center", va="center", transform=ax.transAxes,
    )
else:
    sns.barplot(
        df_go_selected,
        y="term",
        x="number_in_list",
        hue="fdr",
        palette="mako",
        # Each term has a single fdr value, so bars must not be split/narrowed
        # per hue level (older seaborn dodges whenever hue is given).
        dodge=False,
        ax=ax,
    )
    sns.move_legend(ax, loc="center left", bbox_to_anchor=(1.0, 0.5))

In [None]:
# Integer code for each combination of comparisons in which a GO term can be
# significant for a given strain. Keys are comparison names joined with ", "
# in alphabetical order, matching how the pivot below builds them.
comp_map = {
    "CvsP": 1,
    "MvsP": 2,
    "CvsM": 3,
    "CvsP, MvsP": 4,
    "CvsM, CvsP": 5,
    "CvsM, MvsP": 6,
    "CvsM, CvsP, MvsP": 7
}

# Keep GO terms that reach FDR < 0.01 in more than one strain (per comparison).
# num_strains is the number of strains in which the term is below 0.01 for that
# comparison; it is only computed for rows that are themselves below 0.01.
# NOTE(review): because of that, the "fdr < 0.05" condition never removes a
# row -- confirm whether a looser per-strain cutoff was intended.
comp_strings = (
    df_go_terms
    .query("fdr < 0.01")
    .assign(num_strains=lambda d: d.groupby(["comparison", "term"]).strain.transform("size"))
    .query("fdr < 0.05 and num_strains > 1")
    # One cell per (term, strain): the comparisons in which the term passed,
    # sorted so the joined string matches a comp_map key regardless of row order.
    .pivot_table(
        index="term",
        columns="strain",
        values="comparison",
        aggfunc=lambda comps: ", ".join(sorted(comps)),
    )
)

# Fail loudly on a combination comp_map does not encode, instead of letting a
# raw string reach the heatmap.
unmapped = {v for v in comp_strings.to_numpy().ravel() if pd.notna(v)} - set(comp_map)
assert not unmapped, f"comparison combinations missing from comp_map: {sorted(unmapped)}"

# Numeric codes (NaN where the term is not shared) for the heatmap.
df_go_heatmap = comp_strings.apply(lambda col: col.map(comp_map)).astype(float)

In [None]:
# Readable names for each single comparison; combination labels reuse them.
comp_labels = {"CvsP": "CZA vs PAR", "MvsP": "MEM vs PAR", "CvsM": "CZA vs MEM"}

# One discrete colour per comp_map code, including multi-comparison codes.
# Codes must be the consecutive integers 1..n so that vmin/vmax at +-0.5 put
# each code in the middle of its own colour band.
codes = sorted(comp_map.values())
assert codes == list(range(1, len(codes) + 1)), "comp_map codes must be 1..n"
code_to_comps = {code: comps for comps, code in comp_map.items()}
tick_labels = [
    " +\n".join(comp_labels[comp] for comp in code_to_comps[code].split(", "))
    for code in codes
]

fig, ax = plt.subplots(figsize=(10, 14))
# Leave room on the left for long GO term names.
fig.subplots_adjust(left=0.7)
# Wide, thin axis across the top of the figure for the colour key.
cbar_ax = fig.add_axes([0.05, 0.95, 0.9, 0.01])

sns.heatmap(
    df_go_heatmap.astype(float),
    cmap=sns.color_palette("deep", len(codes)),
    linecolor="black",
    linewidths=0.3,
    vmin=codes[0] - 0.5,
    vmax=codes[-1] + 0.5,
    cbar=True,
    cbar_ax=cbar_ax,
    cbar_kws=dict(orientation="horizontal"),
    ax=ax
)

colorbar = ax.collections[0].colorbar
colorbar.set_ticks(codes)
colorbar.set_ticklabels(tick_labels)
colorbar.ax.tick_params(labelsize=8)