Skip to content

Commit

Permalink
decouple streamplot
Browse files Browse the repository at this point in the history
  • Loading branch information
VolkerBergen committed Dec 7, 2018
1 parent 5329105 commit ede811b
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 8 deletions.
3 changes: 2 additions & 1 deletion scvelo/plotting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .scatter import scatter
from .velocity import velocity
from .velocity_embedding import velocity_embedding
from .velocity_embedding_grid import velocity_embedding_grid, velocity_embedding_stream
from .velocity_embedding_grid import velocity_embedding_grid
from .velocity_embedding_stream import velocity_embedding_stream
from .velocity_graph import velocity_graph
from .utils import hist
from scanpy.api.pl import paga, paga_compare, rank_genes_groups
13 changes: 7 additions & 6 deletions scvelo/plotting/velocity_embedding_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def compute_velocity_on_grid(X_emb, V_emb, density=None, smooth=None, n_neighbor

@doc_params(scatter=doc_scatter)
def velocity_embedding_grid(adata, basis=None, vkey='velocity', density=None, smooth=None, min_mass=None, arrow_size=None,
arrow_length=None, scale=None, autoscale=True, n_neighbors=None, X=None, V=None,
X_grid=None, V_grid=None, principal_curve=False, color=None, use_raw=None, layer=None,
color_map=None, colorbar=False, palette=None, size=None, alpha=.2, perc=None,
arrow_length=None, arrow_color=None, scale=None, autoscale=True, n_neighbors=None,
X=None, V=None, X_grid=None, V_grid=None, principal_curve=False, color=None, use_raw=None,
layer=None, color_map=None, colorbar=False, palette=None, size=None, alpha=.2, perc=None,
sort_order=True, groups=None, components=None, projection='2d', legend_loc='none',
legend_fontsize=None, legend_fontweight=None, right_margin=None, left_margin=None,
xlabel=None, ylabel=None, title=None, fontsize=None, figsize=None, dpi=None, frameon=None,
Expand Down Expand Up @@ -149,8 +149,9 @@ def velocity_embedding_grid(adata, basis=None, vkey='velocity', density=None, sm

hl, hw, hal = default_arrow(arrow_size)
scale = 1 / arrow_length if arrow_length is not None else scale if scale is not None else 1
quiver_kwargs = {"scale": scale, "angles": 'xy', "scale_units": 'xy', "width": .001, "color": 'black',
"edgecolors": 'k', "headlength": hl/2, "headwidth": hw/2, "headaxislength": hal/2, "linewidth": .2}
quiver_kwargs = {"scale": scale, "angles": 'xy', "scale_units": 'xy', "width": .001,
"color": 'grey' if arrow_color is None else arrow_color, "edgecolors": 'k',
"headlength": hl/2, "headwidth": hw/2, "headaxislength": hal/2, "linewidth": .2}
quiver_kwargs.update(kwargs)
pl.quiver(X_grid[:, 0], X_grid[:, 1], V_grid[:, 0], V_grid[:, 1], **quiver_kwargs, zorder=3)

Expand All @@ -172,7 +173,7 @@ def velocity_embedding_grid(adata, basis=None, vkey='velocity', density=None, sm
@doc_params(scatter=doc_scatter)
def velocity_embedding_stream(adata, basis=None, vkey='velocity', density=None, smooth=None, linewidth=None,
n_neighbors=None, X=None, V=None, X_grid=None, V_grid=None, color=None, use_raw=None,
layer=None, color_map=None, colorbar=False, palette=None, size=None, alpha=.2, perc=None,
layer=None, color_map=None, colorbar=False, palette=None, size=None, alpha=.1, perc=None,
sort_order=True, groups=None, components=None, projection='2d', legend_loc='none',
legend_fontsize=None, legend_fontweight=None, right_margin=None, left_margin=None,
xlabel=None, ylabel=None, title=None, fontsize=None, figsize=None, dpi=None, frameon=None,
Expand Down
108 changes: 108 additions & 0 deletions scvelo/plotting/velocity_embedding_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from ..tools.velocity_embedding import velocity_embedding
from .utils import default_basis, default_size, get_components, savefig, make_unique_list
from .velocity_embedding_grid import compute_velocity_on_grid
from .scatter import scatter
from .docs import doc_scatter, doc_params

from matplotlib import rcParams
import matplotlib.pyplot as pl
import numpy as np


@doc_params(scatter=doc_scatter)
def velocity_embedding_stream(adata, basis=None, vkey='velocity', density=None, smooth=None, linewidth=None,
n_neighbors=None, X=None, V=None, X_grid=None, V_grid=None, color=None, use_raw=None,
layer=None, color_map=None, colorbar=False, palette=None, size=None, alpha=.1, perc=None,
sort_order=True, groups=None, components=None, legend_loc='on data',
legend_fontsize=None, legend_fontweight=None, right_margin=None, left_margin=None,
xlabel=None, ylabel=None, title=None, fontsize=None, figsize=None, dpi=None, frameon=None,
show=True, save=None, ax=None, ncols=None, **kwargs):
"""\
Stream plot of velocities on the embedding.
Arguments
---------
adata: :class:`~anndata.AnnData`
Annotated data matrix.
x: `str`, `np.ndarray` or `None` (default: `None`)
x coordinate
y: `str`, `np.ndarray` or `None` (default: `None`)
y coordinate
vkey: `str` or `None` (default: `None`)
Key for annotations of observations/cells or variables/genes.
density: `float` (default: 1)
Amount of velocities to show - 0 none to 1 all
smooth: `float` (default: 0.5)
Multiplication factor for scale in Gaussian kernel around grid point.
linewidth: `float` (default: 1)
Line width for streamplot.
n_neighbors: `int` (default: None)
Number of neighbors to consider around grid point.
X: `np.ndarray` (default: None)
Embedding grid point coordinates
V: `np.ndarray` (default: None)
Embedding grid velocity coordinates
{scatter}
Returns
-------
`matplotlib.Axis` if `show==False`
"""
basis = default_basis(adata) if basis is None else basis
colors, layers, vkeys = make_unique_list(color, allow_array=True), make_unique_list(layer), make_unique_list(vkey)
for key in vkeys:
if key + '_' + basis not in adata.obsm_keys() and V is None:
velocity_embedding(adata, basis=basis, vkey=key)
color, layer, vkey = colors[0], layers[0], vkeys[0]

if X_grid is None or V_grid is None:
X_emb = adata.obsm['X_' + basis][:, get_components(components, basis)] if X is None else X[:, :2]
V_emb = adata.obsm[vkey + '_' + basis][:, get_components(components, basis)] if V is None else V[:, :2]
X_grid, V_grid = compute_velocity_on_grid(X_emb=X_emb, V_emb=V_emb, density=1, smooth=smooth,
n_neighbors=n_neighbors, autoscale=False, adjust_for_stream=True)
lengths = np.sqrt((V_grid ** 2).sum(0))
linewidth = 1 if linewidth is None else linewidth
linewidth *= 2 * lengths / lengths[~np.isnan(lengths)].max()

scatter_kwargs = {"basis": basis, "perc": perc, "use_raw": use_raw, "sort_order": sort_order, "alpha": alpha,
"components": components, "legend_loc": legend_loc, "groups": groups,
"legend_fontsize": legend_fontsize, "legend_fontweight": legend_fontweight, "palette": palette,
"color_map": color_map, "frameon": frameon, "title": title, "xlabel": xlabel, "ylabel": ylabel,
"right_margin": right_margin, "left_margin": left_margin, "colorbar": colorbar, "dpi": dpi,
"fontsize": fontsize, "show": False, "save": None}

multikey = colors if len(colors) > 1 else layers if len(layers) > 1 else vkeys if len(vkeys) > 1 else None
if multikey is not None:
ncols = len(multikey) if ncols is None else min(len(multikey), ncols)
nrows = int(np.ceil(len(multikey) / ncols))
figsize = rcParams['figure.figsize'] if figsize is None else figsize
for i, gs in enumerate(
pl.GridSpec(nrows, ncols, pl.figure(None, (figsize[0] * ncols, figsize[1] * nrows), dpi=dpi))):
if i < len(multikey):
velocity_embedding_stream(adata, density=density, size=size, smooth=smooth, n_neighbors=n_neighbors,
linewidth=linewidth, ax=pl.subplot(gs),
color=colors[i] if len(colors) > 1 else color,
layer=layers[i] if len(layers) > 1 else layer,
vkey=vkeys[i] if len(vkeys) > 1 else vkey,
X_grid=None if len(vkeys) > 1 else X_grid,
V_grid=None if len(vkeys) > 1 else V_grid, **scatter_kwargs, **kwargs)
if isinstance(save, str): savefig('' if basis is None else basis, dpi=dpi, save=save, show=show)
if show:
pl.show()
else:
return ax

else:
ax = pl.figure(None, figsize, dpi=dpi).gca() if ax is None else ax

density = 1 if density is None else density
stream_kwargs = {"linewidth": linewidth, "density": 2 * density}
stream_kwargs.update(kwargs)
pl.streamplot(X_grid[0], X_grid[1], V_grid[0], V_grid[1], color='grey', zorder=3, **stream_kwargs)

size = 4 * default_size(adata) if size is None else size
ax = scatter(adata, layer=layer, color=color, size=size, ax=ax, zorder=0, **scatter_kwargs)

if isinstance(save, str): savefig('' if basis is None else basis, dpi=dpi, save=save, show=show)
if show: pl.show()
else: return ax
1 change: 0 additions & 1 deletion scvelo/plotting/velocity_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .. import settings
from ..tools.transition_matrix import transition_matrix
from ..preprocessing.neighbors import get_connectivities
from .utils import savefig, default_basis
from .scatter import scatter
from .docs import doc_scatter, doc_params
Expand Down

0 comments on commit ede811b

Please sign in to comment.