Skip to content

Commit

Permalink
Added leading edge plot
Browse files Browse the repository at this point in the history
  • Loading branch information
PauBadiaM committed Jun 16, 2023
1 parent fc115e3 commit d768a29
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 3 deletions.
2 changes: 1 addition & 1 deletion decoupler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .omnip import get_ksn_omnipath # noqa: F401
from .plotting import plot_volcano, plot_violins, plot_barplot, plot_metrics_scatter, plot_metrics_boxplot # noqa: F401
from .plotting import plot_metrics_scatter_cols, plot_psbulk_samples, plot_filter_by_expr, plot_filter_by_prop # noqa: F401
from .plotting import plot_volcano_df, plot_targets # noqa: F401
from .plotting import plot_volcano_df, plot_targets, plot_running_score # noqa: F401
from .benchmark import benchmark, format_benchmark_inputs, get_performances # noqa: F401
from .utils_benchmark import get_toy_benchmark_data, show_metrics # noqa: F401
from .metrics import metric_auroc, metric_auprc, metric_mcauroc, metric_mcauprc # noqa: F401
170 changes: 169 additions & 1 deletion decoupler/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd
from anndata import AnnData

from .pre import extract, rename_net
from .pre import extract, rename_net, filt_min_n
from .utils_anndata import get_filterbyexpr_inputs, get_min_sample_size, get_cpm_cutoff, get_cpm


Expand Down Expand Up @@ -1115,3 +1115,171 @@ def plot_filter_by_prop(adata, min_prop=0.2, min_smpls=2, cmap='viridis', figsiz
return fig
else:
raise ValueError(msg)


def compute_es_per_rank(m, snet, rnks, set_msk):
    """
    Compute the running enrichment score (ES) along a ranking.

    Walking the features in the order given by ``rnks``, the score increases by
    ``|m[i]| / sum(|m[set_msk]|)`` when feature ``i`` belongs to the set and
    decreases by a constant penalty otherwise.

    Parameters
    ----------
    m : ndarray
        Feature level statistics, one value per feature.
    snet : DataFrame or ndarray
        Network rows of the set of interest. Only its number of rows is used.
    rnks : ndarray
        Order in which the features are visited (indices into ``m``).
    set_msk : ndarray
        Boolean mask, True for features that belong to the set.

    Returns
    -------
    es : ndarray
        Running enrichment score, where ``es[i]`` is the score right after
        visiting feature ``i``.
    """

    m = np.asarray(m)
    rnks = np.asarray(rnks)
    set_msk = np.asarray(set_msk, dtype=bool)

    # Get descending penalty, shared equally by the features outside the set
    n_c = rnks.size
    n_set = snet.shape[0]
    n_miss = n_c - n_set
    # If every feature is in the set there are no misses, so no penalty is
    # ever applied (this avoids a division by zero)
    dec = 1.0 / n_miss if n_miss != 0 else 0.0

    # Compute norm
    abs_m = np.abs(m)
    sum_set = np.sum(abs_m[set_msk])

    # Per feature step: normalized weight for hits, constant penalty for misses
    steps = np.full(n_c, -dec, dtype=float)
    steps[set_msk] = abs_m[set_msk] / sum_set

    # Compute ES by accumulating the steps in visiting order
    es = np.zeros(n_c)
    es[rnks] = np.cumsum(steps[rnks])

    return es


def plot_running_score(df, names, values, net, set_name, source='source', target='target', cmap='RdBu_r', figsize=(5, 5), dpi=100):
    """
    Plot the running score of GSEA.

    The figure has four stacked panels sharing the x axis (rank): the running enrichment score, the positions of the
    features of the set, a color strip of the ranked values and the ranked values themselves.

    Parameters
    ----------
    df : DataFrame
        Dataframe containing feature names and feature values.
    names : str
        Column name in df with feature names.
    values : str
        Column name in df with feature values.
    net : DataFrame
        Network in long format.
    set_name : str
        Name of the feature set to plot.
    source : str
        Column name in net with source nodes.
    target : str
        Column name in net with target nodes.
    cmap : str
        Colormap to use.
    figsize : tuple
        Figure size.
    dpi : int
        DPI resolution of figure.

    Returns
    -------
    fig : Figure
        Returns Figure object.
    le_c : ndarray
        List of leading edge features. If ES is positive, these are the top ranked features. If ES is negative, these are
        the bottom ranked features.
    """

    # Load plotting packages
    plt = check_if_matplotlib()
    mpl = check_if_matplotlib(return_mpl=True)

    # Define color norm
    class MidpointNormalize(mpl.colors.Normalize):
        """Normalization that maps vmin, vcenter and vmax to 0, 0.5 and 1 by linear interpolation."""

        def __init__(self, vmin=None, vmax=None, vcenter=None, clip=False):
            self.vcenter = vcenter
            super().__init__(vmin, vmax, clip)

        def __call__(self, value, clip=None):
            x, y = [self.vmin, self.vcenter, self.vmax], [0, 0.5, 1]
            return np.ma.masked_array(np.interp(value, x, y))

    # Extract feature level stats and names from df
    c = df[names].values.astype('U')
    m = df[values].values

    # Remove empty values (non finite or exactly zero)
    msk = np.isfinite(m) & (m != 0.)
    c = c[msk]
    m = m[msk]

    # Transform net
    net = rename_net(net, source=source, target=target, weight=None)
    # NOTE(review): presumably restricts net to targets present in c - confirm against filt_min_n
    net = filt_min_n(c, net, min_n=0)
    snet = net[net['source'] == set_name]

    # Sort features from highest to lowest value
    idx = np.argsort(-m)
    m = m[idx]
    c = c[idx]

    # Get ranks
    rnks = np.arange(c.size)

    # Get msk
    set_msk = np.isin(c, snet['target'])

    # Compute es
    es = compute_es_per_rank(m, snet, rnks, set_msk)

    # Get leading edge features: set members ranked before the peak if it is
    # positive, or after it otherwise
    abs_max_idx = np.argmax(np.abs(es))
    es_max = es[abs_max_idx]
    sign = np.sign(es_max)
    set_rnks = rnks[set_msk]
    if sign > 0:
        le_c = c[set_rnks[set_rnks <= abs_max_idx]]
    else:
        le_c = c[set_rnks[set_rnks >= abs_max_idx]]

    # Plot, honoring the figsize and dpi requested by the caller
    fig, axes = plt.subplots(
        4, 1,
        gridspec_kw={'height_ratios': [4, 0.5, 0.5, 2]},
        figsize=figsize,
        sharex=True,
        dpi=dpi,
    )
    axes = axes.ravel()

    # Plot random walk
    ax = axes[0]
    ax.margins(0.)
    ax.plot(rnks, es, color='#88c544', linewidth=2)
    ax.axvline(rnks[abs_max_idx], linestyle='--', color='#88c544')
    ax.axhline(0, linestyle='--', color='#88c544')
    ax.set_ylabel('Enrichment Score')
    ax.set_title(set_name)

    # Plot gset mask
    ax = axes[1]
    ax.margins(0.)
    ax.set_yticklabels([])
    ax.set_yticks([])
    ax.vlines(rnks[set_msk], 0, 1, linewidth=0.5, color='#88c544')

    # Plot color bar
    ax = axes[2]
    ax.margins(0.)
    ax.set_yticklabels([])
    ax.set_yticks([])
    # The percentile of a single value is that value, so these are the min and max of m
    vmin = np.percentile(np.min(m), 2)
    vmax = np.percentile(np.max(m), 98)
    midnorm = MidpointNormalize(vmin=vmin, vcenter=0, vmax=vmax)
    ax.pcolormesh(
        m[np.newaxis, :],
        rasterized=True,
        norm=midnorm,
        cmap=cmap,
    )
    ax.set_xlim(0, rnks.size-1)  # Remove extreme to the right

    # Plot ranks
    ax = axes[3]
    ax.margins(0.)
    ax.fill_between(rnks, y1=m, y2=0, color="#C9D3DB")
    # Mark the rank where the sorted values stop being positive
    non_zero_rnks = rnks[m > 0]
    if non_zero_rnks.size == 0:
        zero_rnk = rnks[-1]
    else:
        zero_rnk = non_zero_rnks[-1] + 1
    ax.axvline(zero_rnk, linestyle='--', color="#C9D3DB")
    ax.set_xlabel('Rank')
    ax.set_ylabel('Ranked metric')

    # Remove spaces
    fig.subplots_adjust(wspace=0, hspace=0)

    return fig, le_c
31 changes: 30 additions & 1 deletion decoupler/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from numpy.testing import assert_almost_equal
from anndata import AnnData
from ..plotting import check_if_matplotlib, check_if_seaborn, save_plot, set_limits, plot_volcano, plot_violins
from ..plotting import plot_barplot, build_msks, write_labels, plot_metrics_scatter, plot_metrics_scatter_cols
from ..plotting import plot_metrics_boxplot, plot_psbulk_samples, plot_filter_by_expr, plot_filter_by_prop, plot_volcano_df
from ..plotting import plot_targets
from ..plotting import plot_targets, compute_es_per_rank, plot_running_score


def test_check_if_matplotlib():
Expand Down Expand Up @@ -292,3 +293,31 @@ def test_plot_filter_by_prop():
del adata.layers['psbulk_props']
with pytest.raises(ValueError):
plot_filter_by_prop(adata, min_prop=0.2, min_smpls=2)


def test_compute_es_per_rank():
    """Check the running enrichment score on a small, already sorted ranking."""
    # Six ranked statistics, the first three flagged as members of the set
    stats = np.array([9, 6, 5, 2, 1, 0])
    positions = np.arange(stats.size)
    in_set = np.array([True, True, True, False, False, False])

    # Stand-in for the set network with three rows
    set_net = np.zeros((3, 1))

    expected = np.array([0.45, 0.75, 1., 6.66666667e-01, 3.33333333e-01, 0.])
    observed = compute_es_per_rank(stats, set_net, positions, in_set)
    assert_almost_equal(observed, expected)


def test_plot_running_score():
    """Check that the leading edge only holds features of the set, for both signs of the values."""
    df = pd.DataFrame({
        'genes': ['G1', 'G2', 'G3', 'G4'],
        'values': [7., 1., 1., 1.],
    })
    net = pd.DataFrame({
        'source': ['T1', 'T1', 'T2', 'T2'],
        'target': ['G1', 'G2', 'G3', 'G4'],
        'weight': [1, 2, -3, 4],
    })
    set_members = np.array(['G1', 'G2'])

    # First pass uses the values as given, second pass uses them negated
    for negate in (False, True):
        if negate:
            df['values'] = -df['values']
        fig, le = plot_running_score(
            df,
            names='genes',
            values='values',
            net=net,
            set_name='T1',
            source='source',
            target='target',
            figsize=(5, 5),
            dpi=100,
        )
        assert np.all(np.isin(le, set_members))

0 comments on commit d768a29

Please sign in to comment.