In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc
import pandas as pd
from matplotlib_venn import venn2
# The working directory is switched to the benchmarking folder before the
# star import below. NOTE(review): presumably this is what makes the local
# `metrics` module importable; the absolute path ties the notebook to this
# particular workspace layout.
os.chdir('/workspace/Benchmarking')
# NOTE(review): the star import hides where names come from. The later cells
# call `compute_deg` and `compute_pathway_overlap`, which are not defined or
# imported by name anywhere in view, so they presumably come from `metrics`.
# Consider importing them explicitly.
from metrics import *

## Compute DEGs and PPS

In [None]:
# Run the DEG analysis on the measured RNA and on each model's predicted RNA.
# The calls are made in the same order as the unpacking targets:
# true data, scButterfly, BABEL, Polarbear.
deg_true_test, deg_sb, deg_babel, deg_polar = (
    compute_deg(rna_data)
    for rna_data in (true_RNA, pred_RNA_sb, pred_RNA_babel, pred_RNA_polar)
)

# Element [1] of every result is kept under a short alias; the volcano-plot
# cell reads these frames. Element [0] is used by the pathway cells.
df1, df2, df3, df4 = (
    deg_result[1]
    for deg_result in (deg_true_test, deg_sb, deg_babel, deg_polar)
)

Plot the DEG results as volcano plots

In [None]:
# Thresholds and colours shared by all four volcano panels.
LFC_THRESHOLD = 1        # cut-off applied to 'logfoldchanges' (both signs)
PADJ_THRESHOLD = 0.001   # cut-off applied to 'pvals_adj'
COLOR_UP = '#fcad45'     # above +LFC_THRESHOLD and below PADJ_THRESHOLD
COLOR_DOWN = '#60b1dd'   # below -LFC_THRESHOLD and below PADJ_THRESHOLD
COLOR_NEUTRAL = 'grey'   # everything else


def draw_volcano(ax, deg_table, title):
    """Draw a single volcano panel on ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes that receives the scatter plot, threshold lines and labels.
    deg_table : pandas.DataFrame
        Table with the columns 'logfoldchanges', 'pvals_adj' and
        'neg_log10_padj'. A 'color' column is (re)written in place, exactly
        as the previous copy-pasted code did.
    title : str
        Panel title.
    """
    log_fc = deg_table['logfoldchanges']
    significant = deg_table['pvals_adj'] < PADJ_THRESHOLD

    # Start from the neutral colour, then overwrite the significant rows.
    deg_table['color'] = COLOR_NEUTRAL
    deg_table.loc[(log_fc > LFC_THRESHOLD) & significant, 'color'] = COLOR_UP
    deg_table.loc[(log_fc < -LFC_THRESHOLD) & significant, 'color'] = COLOR_DOWN

    ax.scatter(log_fc, deg_table['neg_log10_padj'], s=10, alpha=0.7,
               c=deg_table['color'], edgecolor='none')

    # Dashed horizontal line: significance cut-off; dotted vertical lines:
    # fold-change cut-offs.
    ax.axhline(-np.log10(PADJ_THRESHOLD), color='green', linestyle='--')
    ax.axvline(LFC_THRESHOLD, color='green', linestyle='dotted')
    ax.axvline(-LFC_THRESHOLD, color='green', linestyle='dotted')

    ax.set_xlim(-10, 10)
    ax.set_title(title, fontsize=16)
    ax.set_xlabel('log\u2082(FC)', fontsize=14)
    ax.set_ylabel('-log\u2081\u2080(p-value)', fontsize=14)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axs = axes.flatten()

# Row-major panel order: top-left, top-right, bottom-left, bottom-right.
volcano_panels = [
    (df1, 'True RNA data'),
    (df2, 'Predicted scButterfly'),
    (df3, 'Predicted BABEL'),
    (df4, 'Predicted Polarbear'),
]

for axis, (deg_table, panel_title) in zip(axs, volcano_panels):
    draw_volcano(axis, deg_table, panel_title)

plt.tight_layout()
plt.savefig('/workspace/Benchmarking/data_all/figures/deg_lymphoma.svg', format='svg', bbox_inches='tight')
plt.show()

Infer enriched pathways and compute PPS 

In [None]:
# Compare each model's element-[0] DEG result against the one obtained from
# the true RNA data. Call order is Polarbear, BABEL, scButterfly.
true_enrichment = deg_true_test[0]
pathways_polar = compute_pathway_overlap(true_enrichment, deg_polar[0])
pathways_babel = compute_pathway_overlap(true_enrichment, deg_babel[0])
pathways_sb = compute_pathway_overlap(true_enrichment, deg_sb[0])

# All three results go on a single line, separated by one space.
overlap_summary = [
    f'Polarbear: {pathways_polar}',
    f'BABEL: {pathways_babel}',
    f'scButterfly: {pathways_sb}',
]
print(' '.join(overlap_summary))

Plot a bar plot of the pathways and Venn diagrams of the gene overlap

In [None]:
# Pathway terms found both in the true-data result and in each model's result.
shared_terms_babel = set(deg_true_test[0]['Term']).intersection(deg_babel[0]['Term'])
shared_terms_sb = set(deg_true_test[0]['Term']).intersection(deg_sb[0]['Term'])
shared_terms_polar = set(deg_true_test[0]['Term']).intersection(deg_polar[0]['Term'])
all_terms = shared_terms_babel.union(shared_terms_sb).union(shared_terms_polar)

# NOTE(review): `all_terms` also contains terms shared only with Polarbear,
# but no Polarbear column is built below, so such terms show up as rows with
# two zero-length bars. Confirm whether that is intended.

# Iterate the terms in sorted order: set iteration order for strings can
# differ between kernel sessions, which would otherwise reshuffle tied bars.
dict_data = {
    term: {
        'scButterfly': pathways_sb[1][term]['Jaccard'] if term in pathways_sb[1] else 0.0,
        'BABEL': pathways_babel[1][term]['Jaccard'] if term in pathways_babel[1] else 0.0,
    }
    for term in sorted(all_terms, key=str)
}

# Build the frame with explicit column names so that an empty `all_terms`
# yields an empty, correctly-shaped table instead of a length-mismatch error.
df = pd.DataFrame(
    [(term, scores['scButterfly'], scores['BABEL']) for term, scores in dict_data.items()],
    columns=["Pathway", "scButterfly", "Babel"],
).astype({"scButterfly": float, "Babel": float})

df = df.fillna(0)

# Order the bars by the larger of the two Jaccard values; the stable sort
# keeps the sorted term order among ties.
df["Max"] = df[["scButterfly", "Babel"]].abs().max(axis=1)
df = df.sort_values("Max", ascending=True, kind="stable").drop(columns="Max")

fig, ax = plt.subplots(figsize=(8, 6))

# scButterfly bars are negated so they extend left of zero; Babel bars
# extend right.
# NOTE(review): the PPS values in the legend labels are hard-coded and will
# go stale if the overlap results change.
ax.barh(df["Pathway"], -df["scButterfly"], color="#b8e0b0", label="scButterfly (PPS = 0.001)")
ax.barh(df["Pathway"], df["Babel"], color="#e19c56", label="Babel (PPS = 0.024)")

ax.axvline(0, color="black", linewidth=0.5)
ax.set_xlabel("Jaccard Index of Gene Overlap")
ax.legend()

ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
def pathway_genes(enrichment, term):
    """Return the genes listed for ``term`` as a set.

    Parameters
    ----------
    enrichment : pandas.DataFrame
        Table with a 'Term' column and a 'Genes' column whose entries are
        ';'-separated gene names.
    term : str
        Value to look up in the 'Term' column; the first matching row is used.

    Raises
    ------
    IndexError
        If no row has the requested term. The exception type is the same as
        before, but the message now names the missing term.
    """
    matching_genes = enrichment.loc[enrichment['Term'] == term, 'Genes']
    if matching_genes.empty:
        raise IndexError(f"Term {term!r} is not present in the enrichment table")
    return set(matching_genes.values[0].split(';'))


# One Venn diagram per (pathway term, predicted result, output file).
venn_panels = [
    ('Estrogen Response Early', deg_babel[0],
     '/workspace/Benchmarking/data_all/figures/venn_estrogen.svg'),
    ('E2F Targets', deg_sb[0],
     '/workspace/Benchmarking/data_all/figures/venn_e2f.svg'),
]

for term, predicted_enrichment, figure_path in venn_panels:
    genes_true = pathway_genes(deg_true_test[0], term)
    genes_pred = pathway_genes(predicted_enrichment, term)

    # Left set: genes from the true data; right set: genes from the prediction.
    venn2([genes_true, genes_pred], set_colors=('#60b1dd', '#ef973f'), alpha=0.7)
    plt.savefig(figure_path, format='svg', bbox_inches='tight')
    plt.show()