Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/spatialdata_plot/pl/_datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from matplotlib.colors import Normalize

from spatialdata_plot._logging import logger
from spatialdata_plot.pl.render_params import Color, FigParams, ShapesRenderParams
from spatialdata_plot.pl.render_params import Color, FigParams, ShapesRenderParams, _DsReduction
from spatialdata_plot.pl.utils import (
_ax_show_and_transform,
_convert_alpha_to_datashader_range,
Expand All @@ -32,8 +32,6 @@
# Type aliases and constants
# ---------------------------------------------------------------------------

_DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"]

# Sentinel category name used in datashader categorical paths to represent
# missing (NaN) values. Must not collide with realistic user category names.
_DS_NAN_CATEGORY = "ds_nan"
Expand Down
45 changes: 41 additions & 4 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import Callable, Sequence
from copy import deepcopy
from pathlib import Path
from typing import Any, Literal, cast
from typing import Any, Literal, cast, get_args

import matplotlib
import matplotlib.pyplot as plt
Expand All @@ -29,7 +29,7 @@
from xarray import DataArray, DataTree

from spatialdata_plot._accessor import register_spatial_data_accessor
from spatialdata_plot._logging import _log_context
from spatialdata_plot._logging import _log_context, logger
from spatialdata_plot.pl.render import (
_draw_channel_legend,
_render_graph,
Expand All @@ -52,8 +52,10 @@
LegendParams,
PointsRenderParams,
ShapesRenderParams,
_DsReduction,
_FontSize,
_FontWeight,
_ImageDsReduction,
)
from spatialdata_plot.pl.utils import (
_RENDER_CMD_TO_CS_FLAG,
Expand Down Expand Up @@ -194,7 +196,7 @@ def render_shapes(
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None,
colorbar: bool | str | None = "auto",
colorbar_params: dict[str, object] | None = None,
datashader_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None,
datashader_reduction: _DsReduction | None = None,
transfunc: Callable[[float], float] | None = None,
) -> sd.SpatialData:
"""
Expand Down Expand Up @@ -384,7 +386,7 @@ def render_points(
gene_symbols: str | None = None,
colorbar: bool | str | None = "auto",
colorbar_params: dict[str, object] | None = None,
datashader_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None,
datashader_reduction: _DsReduction | None = None,
transfunc: Callable[[float], float] | None = None,
) -> sd.SpatialData:
"""
Expand Down Expand Up @@ -536,6 +538,8 @@ def render_images(
colorbar: bool | str | None = "auto",
colorbar_params: dict[str, object] | None = None,
channels_as_legend: bool = False,
method: Literal["matplotlib", "datashader"] | None = None,
datashader_reduction: _ImageDsReduction | None = None,
) -> sd.SpatialData:
"""
Render image elements in SpatialData.
Expand Down Expand Up @@ -616,6 +620,21 @@ def render_images(
Ignored for single-channel and RGB(A) images. When multiple
``render_images`` calls use this flag on the same axes, all
channel entries are combined into a single legend.
method : str | None, optional
Whether to use ``'matplotlib'`` (default) or ``'datashader'`` for
the downsampling step. When ``'datashader'`` is selected, the
rasterization-to-canvas step uses
:meth:`datashader.Canvas.raster` with ``datashader_reduction`` as the
downsample method (default ``'max'``), and ``imshow`` is rendered
with ``interpolation='nearest'`` so the chosen reduction is not
re-smoothed at display time. Useful for very sparse images
(mostly zeros) where mean aggregation collapses the signal —
``method='datashader'`` with ``datashader_reduction='max'`` preserves the
rare non-zero pixels (``plt.spy``-style).
datashader_reduction : {"max", "min", "mean", "mode", "first", "last", "var", "std"} | None, optional
Downsample reduction used by the datashader path. Defaults to
``'max'`` when ``method='datashader'``. Ignored otherwise (a
warning is emitted if set without ``method='datashader'``).

Notes
-----
Expand All @@ -634,6 +653,22 @@ def render_images(
"""
if grayscale and palette is not None:
raise ValueError("Cannot combine grayscale=True with palette.")

if method is not None and not isinstance(method, str):
raise TypeError("Parameter 'method' must be a string.")
if method is not None and method not in ("matplotlib", "datashader"):
raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.")
_valid_image_reductions = get_args(_ImageDsReduction)
if datashader_reduction is not None and not isinstance(datashader_reduction, str):
raise TypeError("Parameter 'datashader_reduction' must be a string.")
if datashader_reduction is not None and datashader_reduction not in _valid_image_reductions:
raise ValueError(
f"Parameter 'datashader_reduction' must be one of {_valid_image_reductions}, "
f"got {datashader_reduction!r}."
)
if datashader_reduction is not None and method != "datashader":
logger.warning("Parameter 'datashader_reduction' has no effect unless method='datashader'; ignoring.")

params_dict = _validate_image_render_params(
self._sdata,
element=element,
Expand Down Expand Up @@ -699,6 +734,8 @@ def render_images(
transfunc=transfunc,
grayscale=grayscale,
channels_as_legend=channels_as_legend,
method=method,
ds_reduction=datashader_reduction,
)
n_steps += 1

Expand Down
33 changes: 30 additions & 3 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
_ds_aggregate,
_ds_shade_categorical,
_ds_shade_continuous,
_DsReduction,
_render_ds_image,
_render_ds_outlines,
)
Expand All @@ -55,6 +54,7 @@
LegendParams,
PointsRenderParams,
ShapesRenderParams,
_DsReduction,
)
from spatialdata_plot.pl.utils import (
_ax_show_and_transform,
Expand All @@ -73,6 +73,7 @@
_prepare_cmap_norm,
_prepare_transformation,
_rasterize_if_necessary,
_rasterize_if_necessary_datashader,
_set_color_source_vec,
_validate_polygons,
)
Expand Down Expand Up @@ -1279,7 +1280,24 @@ def _render_images(
scale=scale,
)
# rasterize spatial image if necessary to speed up performance
if rasterize:
use_datashader = render_params.method == "datashader"
if use_datashader:
downsample_method = render_params.ds_reduction or "max"
logger.info(
f"Using 'datashader' backend with '{downsample_method}' as downsample method. "
"Depending on the reduction, the value range of the plot might change. "
"Set method to 'matplotlib' to disable this behaviour."
)
img = _rasterize_if_necessary_datashader(
image=img,
dpi=fig_params.fig.dpi,
width=fig_params.fig.get_size_inches()[0],
height=fig_params.fig.get_size_inches()[1],
coordinate_system=coordinate_system,
extent=extent,
downsample_method=downsample_method,
)
elif rasterize:
img = _rasterize_if_necessary(
image=img,
dpi=fig_params.fig.dpi,
Expand Down Expand Up @@ -1389,6 +1407,10 @@ def _render_images(
"Consider using 'palette' instead."
)

# Force nearest-neighbor at display time when the datashader reduction picked
# a non-mean aggregation; otherwise imshow's default interpolation would smear it.
_interp = "nearest" if use_datashader else None

# Detect RGB(A) images by channel names — skip when user overrides with palette/cmap
is_rgb, has_alpha = _is_rgb_image(channels)
has_explicit_cmap = (
Expand Down Expand Up @@ -1430,7 +1452,7 @@ def _render_images(
render_params.alpha,
)

_ax_show_and_transform(stacked, trans_data, ax, **show_kwargs)
_ax_show_and_transform(stacked, trans_data, ax, interpolation=_interp, **show_kwargs)
if render_params.channels_as_legend:
logger.warning("channels_as_legend is not supported for true RGB images and will be ignored.")
return
Expand All @@ -1457,6 +1479,7 @@ def _render_images(
cmap=cmap,
zorder=render_params.zorder,
norm=render_params.cmap_params.norm,
interpolation=_interp,
)

wants_colorbar = _should_request_colorbar(
Expand Down Expand Up @@ -1549,6 +1572,7 @@ def _render_images(
ax,
render_params.alpha,
zorder=render_params.zorder,
interpolation=_interp,
)

# 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
Expand Down Expand Up @@ -1613,6 +1637,7 @@ def _render_images(
ax,
render_params.alpha,
zorder=render_params.zorder,
interpolation=_interp,
)

# 2C) palette set; also covers `palette + norm=list` since synthesized
Expand All @@ -1633,6 +1658,7 @@ def _render_images(
ax,
render_params.alpha,
zorder=render_params.zorder,
interpolation=_interp,
)

elif palette is None and got_multiple_cmaps:
Expand All @@ -1654,6 +1680,7 @@ def _render_images(
ax,
render_params.alpha,
zorder=render_params.zorder,
interpolation=_interp,
)

# Collect channel legend entries (single point for all multi-channel paths)
Expand Down
8 changes: 6 additions & 2 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

_FontWeight = Literal["light", "normal", "medium", "semibold", "bold", "heavy", "black"]
_FontSize = Literal["xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large"]
_DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"]
_ImageDsReduction = Literal["max", "min", "mean", "mode", "first", "last", "var", "std"]

# replace with
# from spatialdata._types import ColorLike
Expand Down Expand Up @@ -243,7 +245,7 @@ class ShapesRenderParams:
table_name: str | None = None
table_layer: str | None = None
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
ds_reduction: _DsReduction | None = None
colorbar: bool | str | None = "auto"
colorbar_params: dict[str, object] | None = None

Expand All @@ -265,7 +267,7 @@ class PointsRenderParams:
zorder: int = 0
table_name: str | None = None
table_layer: str | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
ds_reduction: _DsReduction | None = None
colorbar: bool | str | None = "auto"
colorbar_params: dict[str, object] | None = None

Expand All @@ -286,6 +288,8 @@ class ImageRenderParams:
transfunc: Callable[[np.ndarray], np.ndarray] | list[Callable[[np.ndarray], np.ndarray]] | None = None
grayscale: bool = False
channels_as_legend: bool = False
method: Literal["matplotlib", "datashader"] | None = None
ds_reduction: _ImageDsReduction | None = None


@dataclass
Expand Down
63 changes: 61 additions & 2 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
PointsRenderParams,
ScalebarParams,
ShapesRenderParams,
_DsReduction,
_FontSize,
_FontWeight,
)
Expand Down Expand Up @@ -2048,6 +2049,61 @@ def _rasterize_if_necessary(
return image


def _rasterize_if_necessary_datashader(
image: DataArray,
dpi: float,
width: float,
height: float,
coordinate_system: str,
extent: dict[str, tuple[float, float]],
downsample_method: str,
) -> DataArray:
"""Downsample to canvas resolution with a configurable datashader reduction.

Used by ``render_images(method='datashader')`` so sparse images (mostly
zeros, rare non-zero pixels) survive the downsample step instead of
being averaged away by the default mean aggregation.
"""
has_c_dim = len(image.shape) == 3
y_dims, x_dims = (image.shape[1], image.shape[2]) if has_c_dim else image.shape

target_y_dims = int(dpi * height)
target_x_dims = int(dpi * width)

if y_dims <= target_y_dims and x_dims <= target_x_dims:
return image

# spatialdata.rasterize is invoked solely to inherit the output coords and
# spatial transformation; its mean-aggregated values are overwritten below.
# TODO: this wastes a full per-channel resample pass. A future refactor can
# construct the target DataArray + transformation directly once spatialdata
# exposes a public geometry-only helper.
world_x = float(extent["x"][1]) - float(extent["x"][0])
world_y = float(extent["y"][1]) - float(extent["y"][0])
target_unit_to_pixels = min(target_y_dims / world_y, target_x_dims / world_x)
base = rasterize(
image,
("y", "x"),
[extent["y"][0], extent["x"][0]],
[extent["y"][1], extent["x"][1]],
coordinate_system,
target_unit_to_pixels=target_unit_to_pixels,
)

out_y, out_x = (base.shape[1], base.shape[2]) if has_c_dim else base.shape
# Materialize once: per-chunk reductions across channels would otherwise
# trigger repeated dask graph evaluations on the same source array.
src = image.compute() if hasattr(image.data, "compute") else image
cvs = ds.Canvas(
plot_width=out_x,
plot_height=out_y,
x_range=(float(extent["x"][0]), float(extent["x"][1])),
y_range=(float(extent["y"][0]), float(extent["y"][1])),
)
base.values = np.asarray(cvs.raster(src, downsample_method=downsample_method).values).astype(base.dtype, copy=False)
return base


def _multiscale_to_spatial_image(
multiscale_image: DataTree,
dpi: float,
Expand Down Expand Up @@ -3385,6 +3441,7 @@ def _ax_show_and_transform(
cmap: ListedColormap | LinearSegmentedColormap | None = None,
zorder: int = 0,
norm: Normalize | None = None,
interpolation: str | None = None,
) -> matplotlib.image.AxesImage:
# ``extent`` uses mpl's pixel-grid convention; world placement happens via
# ``set_transform(trans_data)`` afterwards.
Expand All @@ -3396,6 +3453,8 @@ def _ax_show_and_transform(
imshow_kwargs["alpha"] = alpha
else:
imshow_kwargs["cmap"] = cmap
if interpolation is not None:
imshow_kwargs["interpolation"] = interpolation
im = ax.imshow(array, **imshow_kwargs)
im.set_transform(trans_data)
return im
Expand Down Expand Up @@ -3508,7 +3567,7 @@ def _create_image_from_datashader_result(


def _datashader_aggregate_with_function(
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
reduction: _DsReduction | None,
cvs: Canvas,
spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame,
col_for_color: str | None,
Expand Down Expand Up @@ -3572,7 +3631,7 @@ def _datashader_aggregate_with_function(


def _datshader_get_how_kw_for_spread(
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
reduction: _DsReduction | None,
) -> str:
# Get the best input for the how argument of ds.tf.spread(), needed for numerical values
reduction = reduction or "sum"
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading