<a href="https://colab.research.google.com/github/pachterlab/LSCHWCP_2023/blob/main/Notebooks/Supp_Fig_9/Supp_Fig_9ab/enrichment_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Enrichment analysis of predictive genes

In [None]:
# Install the third-party packages imported below (-q keeps pip output short)
!pip install -q gget anndata
import numpy as np
import pandas as pd
# Do not truncate columns when displaying DataFrames
pd.set_option('display.max_columns', None)
# pd.set_option('display.max_rows', None)
import anndata
import gget
from tqdm import tqdm
# Progress bar layout string; presumably intended for tqdm's `bar_format`
# argument (it is not referenced in the cells visible here) - TODO confirm
TQDM_BAR_FORMAT = (
    "{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]"
)
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import textwrap
# Render inline figures in 'retina' (high-resolution) format
%config InlineBackend.figure_format='retina'

### Perform enrichment analyses:

In [None]:
!wget https://raw.githubusercontent.com/pachterlab/LSCHWCP_2023/main/Notebooks/Supp_Fig_9/Supp_Fig_9ab/gene_weights.csv
gene_df = pd.read_csv("gene_weights.csv")
gene_df.head()

In [None]:
# Virus IDs covered by this notebook; each ID is used to build column names
# of gene_df (e.g. f"{vir} weight")
virs = [
    'u10',
    'u102540',
    'u11150',
    'u202260',
]

In [None]:
# Number of rows taken from the top of each virus' Ensembl ID column for the
# gene name lookup; also the row index at which the cut-off line is drawn in
# the weight histograms
n_genes = 200

Plot gene weight distributions:

In [None]:
# One histogram panel per virus. squeeze=False keeps `axs` two-dimensional for
# any number of panels, so `axs[0]` is always the single row of axes.
fig, axs = plt.subplots(figsize=(20, 5), ncols=len(virs), sharey=True, squeeze=False)
fontsize = 14

for col_idx, (ax, vir) in enumerate(zip(axs[0], virs)):
    weights = gene_df[f"{vir} weight"].values

    ax.hist(weights, bins=100, color="#003049")
    ax.set_title(vir, fontsize=fontsize+2)

    # Dashed horizontal grid lines only, drawn behind the bars
    ax.grid(True, which="both", color='lightgray', ls="--", lw=1)
    ax.set_axisbelow(True)
    ax.xaxis.grid(False)

    # Symmetric-log y-scale (linear around zero, logarithmic above)
    ax.set_yscale("symlog")
    ax.tick_params(axis="both", labelsize=fontsize)

    # The y-axis is shared, so only the leftmost panel gets a y-label
    if col_idx == 0:
        ax.set_ylabel("Frequency", fontsize=fontsize+2)
    ax.set_xlabel("Gene weight", fontsize=fontsize+2)

    # Show cut-off based on X number of top genes: vertical line at the weight
    # stored in row `n_genes`.
    # NOTE(review): this is only a top-N cut-off if the rows of gene_df are
    # sorted by descending weight for every virus - confirm against the CSV.
    ax.axvline(weights[n_genes], ls="--", color="grey", lw=3)

plt.subplots_adjust(wspace=0.1, hspace=0.2)

fig.savefig(
    "gene_weight_distributions.png", dpi=300, bbox_inches="tight", transparent=True
)

fig.show()

Get gene symbols from Ensembl and perform enrichment analysis:

In [None]:
# Enrichr database(s)
database = "Microbe_Perturbations_from_GEO_up"
database2 = "KEGG_2021_Human"

u10:

In [None]:
# Get gene symbols/names from Ensembl ID with gget info
u10_info_df = gget.info(gene_df["u10 Ensembl ID"][:n_genes].values, verbose=False)
u10_info_df.head()

In [None]:
# Perform enrichment analysis
u10_enrichr_df = gget.enrichr(u10_info_df["ensembl_gene_name"].dropna(axis=0).values, database=database, background=True, plot=True)

In [None]:
# Same u10 gene names queried against `database2`; the result is only displayed
gget.enrichr(
    u10_info_df["ensembl_gene_name"].dropna(axis=0).values,
    database=database2,
    background=True,
)

u11150:

In [None]:
# Look up gene annotations with gget info for the first n_genes Ensembl IDs
# listed for u11150
u11150_info_df = gget.info(
    gene_df["u11150 Ensembl ID"][:n_genes].values,
    verbose=False,
)
u11150_info_df.head()

In [None]:
# Enrichment analysis of the u11150 gene names against `database`;
# rows without a gene name are dropped before the query
u11150_enrichr_df = gget.enrichr(
    u11150_info_df["ensembl_gene_name"].dropna(axis=0).values,
    database=database,
    background=True,
    plot=True,
)

In [None]:
# Same u11150 gene names queried against `database2`; the result is only displayed
gget.enrichr(
    u11150_info_df["ensembl_gene_name"].dropna(axis=0).values,
    database=database2,
    background=True,
)

u202260:

In [None]:
# Look up gene annotations with gget info for the first n_genes Ensembl IDs
# listed for u202260
u202260_info_df = gget.info(
    gene_df["u202260 Ensembl ID"][:n_genes].values,
    verbose=False,
)
u202260_info_df.head()

In [None]:
# Enrichment analysis of the u202260 gene names against `database`;
# rows without a gene name are dropped before the query
u202260_enrichr_df = gget.enrichr(
    u202260_info_df["ensembl_gene_name"].dropna(axis=0).values,
    database=database,
    background=True,
    plot=True,
)

In [None]:
# Same u202260 gene names queried against `database2`; the result is only displayed
gget.enrichr(
    u202260_info_df["ensembl_gene_name"].dropna(axis=0).values,
    database=database2,
    background=True,
)

u102540:

In [None]:
# Look up gene annotations with gget info for the first n_genes Ensembl IDs
# listed for u102540
u102540_info_df = gget.info(
    gene_df["u102540 Ensembl ID"][:n_genes].values,
    verbose=False,
)
u102540_info_df.head()

In [None]:
# Enrichment analysis of the u102540 gene names against `database`;
# rows without a gene name are dropped before the query
u102540_enrichr_df = gget.enrichr(
    u102540_info_df["ensembl_gene_name"].dropna(axis=0).values,
    database=database,
    background=True,
    plot=True,
)

In [None]:
# Same u102540 gene names queried against `database2`; the result is only displayed
gget.enrichr(
    u102540_info_df["ensembl_gene_name"].dropna(axis=0).values,
    database=database2,
    background=True,
)

### Combine enrichment results into one plot

In [None]:
# Lift the row limit so DataFrames are displayed without truncation from here on
pd.options.display.max_rows = None

In [None]:
# Combined figure: one panel per virus showing, for the first `n_paths` rows of
# its enrichment result, the number of overlapping genes (bars, bottom x-axis)
# and -log10 of the adjusted P value (dots, top x-axis).

# Number of paths to plot
n_paths = 15
dfs = [
    u10_enrichr_df[:n_paths],
    u102540_enrichr_df[:n_paths],
    u11150_enrichr_df[:n_paths],
    u202260_enrichr_df[:n_paths]
]
# Panel titles, in the same order as `dfs`
vir_names = ["u10", "u102540", "u11150", "u202260"]

# One panel per result table. squeeze=False keeps `axs` two-dimensional for any
# number of panels, so `axs[0]` is always the single row of axes.
fig, axs = plt.subplots(figsize=(40, 15), ncols=len(dfs), squeeze=False)
fontsize = 21
barcolor = "#003049"
p_val_color = "tab:red"  # orange: #f77f00

for ax1, df, vir in zip(axs[0], dfs, vir_names):
    overlapping_genes = df["overlapping_genes"].values
    path_names = df["path_name"].values
    adj_p_values = df["adj_p_val"].values

    # Bar length: number of entries in each row's overlapping-gene collection
    gene_counts = [len(gene_list) for gene_list in overlapping_genes]
    # y coordinates shared by the bars, their tick labels and the P value dots
    y_pos = np.arange(len(gene_counts))

    # Wrap pathway labels (at most two lines; longer labels end in "...")
    labels = [
        textwrap.fill(
            label,
            width=20,
            break_long_words=False,
            max_lines=2,
            placeholder="...",
        )
        for label in path_names
    ]

    # Plot barplot
    ax1.barh(y_pos, gene_counts, color=barcolor, align="center")

    ax1.set_yticks(y_pos, labels, linespacing=0.85, fontsize=fontsize)
    # Put the first row of the table at the top of the panel
    ax1.invert_yaxis()

    # Set x-limits (same fixed range in every panel)
    ax1.set_xlim(left=0, right=13.5)

    # Add adj. P value secondary x-axis
    ax2 = ax1.twiny()
    ax2.scatter(-np.log10(adj_p_values), y_pos, color=p_val_color, s=150, edgecolor="white", lw=0.5)
    # Change label and color of p-value axis
    ax2.set_xlabel(
        "$-log_{10}$(adjusted P value)", fontsize=fontsize+2, color=p_val_color
    )
    ax2.spines["top"].set_color(p_val_color)
    ax2.spines["top"].set_linewidth(2)
    ax2.tick_params(axis="x", colors=p_val_color, labelsize=fontsize)

    # Set x2-limits (same fixed range in every panel)
    ax2.set_xlim(left=0, right=5)

    # Add alpha=0.05 p-value cutoff, with its label placed just left of the line
    ax2.axvline(-np.log10(0.05), color=p_val_color, ls="--", lw=3.5)
    ax2.text(
        -np.log10(0.05) - 0.4,
        -0.3,
        "p = 0.05",
        ha="left",
        va="top",
        rotation="vertical",
        fontweight="bold",
        color=p_val_color,
        fontsize=fontsize,
    )

    # Set label and color of count axis
    ax1.set_xlabel(
        "Number of overlapping genes",
        color=barcolor,
        fontsize=fontsize+5,
    )
    ax1.tick_params(axis="x", labelsize=fontsize+3, colors=barcolor)
    # Set bottom x axis to keep only integers since counts cannot be floats
    ax1.xaxis.set_major_locator(MaxNLocator(integer=True))
    # Change fontsize of y-tick labels
    ax1.tick_params(axis="y", labelsize=fontsize)

    # Set title
    ax1.set_title(
        f"{vir}", fontsize=fontsize + 10, pad=10
    )

    # Set axis margins
    ax1.margins(y=0, x=0)

    # Dashed vertical grid lines only, drawn behind the bars
    ax1.grid(True, which="both", color="lightgrey", ls="--", lw=1)
    ax1.set_axisbelow(True)
    ax1.yaxis.grid(False)

plt.subplots_adjust(wspace=0.67)

# Save figure
fig.savefig(
    "predictive_genes_enrichment.png", dpi=300, bbox_inches="tight", transparent=True
)

fig.show()