Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/spatialdata_plot/pl/_datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def _ds_shade_continuous(
na_color_hex: str,
spread_px: int | None = None,
ds_reduction: _DsReduction | None = None,
how: str = "linear",
) -> tuple[Any, Any | None, tuple[Any, Any] | None]:
"""Shade a continuous datashader aggregate, optionally applying spread and NaN coloring.

Expand Down Expand Up @@ -255,6 +256,7 @@ def _ds_shade_continuous(
min_alpha=_convert_alpha_to_datashader_range(alpha),
span=color_span,
clip=norm.clip,
how=how,
)
shaded = _apply_user_alpha(shaded, alpha)

Expand All @@ -278,6 +280,8 @@ def _ds_shade_categorical(
color_vector: Any,
alpha: float,
spread_px: int | None = None,
how: str = "linear",
density: bool = False,
) -> Any:
"""Shade a categorical or no-color datashader aggregate."""
ds_cmap = None
Expand All @@ -286,12 +290,20 @@ def _ds_shade_categorical(
if isinstance(ds_cmap, str) and ds_cmap[0] == "#":
ds_cmap = _hex_no_alpha(ds_cmap)

# The default min_alpha (~254) is a near-full-opacity floor — right for scatter
# plots, but it collapses the count-driven alpha range and makes categorical
# density read as a flat hue cloud. Drop the floor under density so per-pixel
# alpha can actually encode count. A small non-zero floor (~15%) keeps the
# sparse edges visible under density_how="linear" instead of vanishing.
min_alpha = 40.0 if density else _convert_alpha_to_datashader_range(alpha)

agg_to_shade = ds.tf.spread(agg, px=spread_px) if spread_px is not None else agg
shaded = _datashader_map_aggregate_to_color(
agg_to_shade,
cmap=ds_cmap,
color_key=color_key,
min_alpha=_convert_alpha_to_datashader_range(alpha),
min_alpha=min_alpha,
how=how,
)
return _apply_user_alpha(shaded, alpha)

Expand Down
36 changes: 36 additions & 0 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,8 @@ def render_points(
colorbar: bool | str | None = "auto",
colorbar_params: dict[str, object] | None = None,
datashader_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None,
density: bool = False,
density_how: Literal["linear", "log", "cbrt", "eq_hist"] = "linear",
transfunc: Callable[[float], float] | None = None,
) -> sd.SpatialData:
"""
Expand Down Expand Up @@ -455,13 +457,38 @@ def render_points(
in another column of ``var``. Mimics scanpy's ``gene_symbols`` parameter.
datashader_reduction : Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, optional
Reduction method for datashader when coloring by continuous values. When ``None``, defaults to ``"sum"``.
density : bool, default False
Render the points as a 2-D count density via datashader instead of plotting individual markers.
When ``True``, ``method`` is forced to ``"datashader"`` (passing ``method="matplotlib"`` raises).
Density supports ``color=None`` (plain density) or a categorical ``color`` column (per-category
density via :func:`datashader.by`). A continuous ``color`` column or a literal color value
(e.g. ``"red"``) raises an error. Under ``density=True`` the following parameters are ignored
(with a warning if explicitly set): ``size``, ``transfunc``, ``norm.vmin/vmax``, and
``datashader_reduction``.
density_how : Literal["linear", "log", "cbrt", "eq_hist"], default "linear"
How datashader maps aggregated counts to color intensity. ``"linear"`` (default) keeps the
colorbar axis as a count; ``"log"`` and ``"cbrt"`` compress dynamic range; ``"eq_hist"``
equalizes the histogram (rank-based, surfaces the most structure but the colorbar axis is
no longer a count). Ignored when ``density=False``.
transfunc : Callable[[float], float] | None, optional
Optional transformation applied to the continuous color vector before normalization and colormap mapping.

Returns
-------
sd.SpatialData
A copy of the SpatialData object with the rendering parameters stored in its plotting tree.

Examples
--------
Plain density of all transcripts:

>>> sdata.pl.render_points("transcripts", density=True).pl.show()

Per-gene density with a categorical palette:

>>> sdata.pl.render_points(
... "transcripts", color="gene", groups=["Gad1", "Slc17a7"], palette="tab20", density=True
... ).pl.show()
"""
params_dict = _validate_points_render_params(
self._sdata,
Expand All @@ -480,6 +507,10 @@ def render_points(
colorbar=colorbar,
colorbar_params=colorbar_params,
gene_symbols=gene_symbols,
density=density,
density_how=density_how,
transfunc=transfunc,
method=method,
)

if method is not None:
Expand All @@ -488,6 +519,9 @@ def render_points(
if method not in ["matplotlib", "datashader"]:
raise ValueError("Parameter 'method' must be either 'matplotlib' or 'datashader'.")

if density and method is None:
method = "datashader"

sdata = self._copy()
sdata = _verify_plotting_tree(sdata)
n_steps = len(sdata.plotting_tree.keys())
Expand Down Expand Up @@ -515,6 +549,8 @@ def render_points(
ds_reduction=param_values["ds_reduction"],
colorbar=param_values["colorbar"],
colorbar_params=param_values["colorbar_params"],
density=density,
density_how=density_how,
)
n_steps += 1

Expand Down
60 changes: 57 additions & 3 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,44 @@ def _warn_groups_ignored_continuous(
)


def _is_categorical_like_dtype(dtype: Any) -> bool:
return (
isinstance(dtype, pd.CategoricalDtype)
or pd.api.types.is_object_dtype(dtype)
or pd.api.types.is_string_dtype(dtype)
)


def _reject_continuous_color_under_density(
sdata_filt: sd.SpatialData,
element: str,
col_for_color: str | None,
color_source_vector: Any,
color_vector: Any,
) -> None:
"""Raise before any materialization if density+continuous-color was requested.

``color_source_vector`` is only populated by ``_set_color_source_vec`` for the categorical
branch, so a non-None value is sufficient to accept the call. Otherwise we read the dtype
from the dask source (points element column) or the pre-computed color vector — neither
forces a ``.compute()``.
"""
if col_for_color is None or color_source_vector is not None:
return
points_columns = sdata_filt.points[element].columns
if col_for_color in points_columns:
dtype = sdata_filt.points[element][col_for_color].dtype
else:
dtype = getattr(color_vector, "dtype", None)
if dtype is None or _is_categorical_like_dtype(dtype):
return
raise ValueError(
f"density=True is only supported with no color or a categorical color column; "
f"got continuous column {col_for_color!r}. To color a density plot by a continuous "
f"variable, set density=False and use method='datashader' with datashader_reduction=."
)


def _warn_missing_groups(
groups: str | list[str],
color_source_vector: pd.Categorical,
Expand Down Expand Up @@ -950,7 +988,10 @@ def _render_points(

method = render_params.method

if method is None:
if render_params.density:
method = "datashader"
_reject_continuous_color_under_density(sdata_filt, element, col_for_color, color_source_vector, color_vector)
elif method is None:
method = "datashader" if n_points > 10000 else "matplotlib"

_default_reduction: _DsReduction = "sum"
Expand All @@ -960,7 +1001,11 @@ def _render_points(

# NOTE: s in matplotlib is in units of points**2
# use dpi/100 as a factor for cases where dpi!=100
px = int(np.round(np.sqrt(render_params.size) * (fig_params.fig.dpi / 100)))
# Under density, spreading would smear the count signal across pixels and
# distort apparent density at sparse edges, so disable it unconditionally.
px: int | None = (
None if render_params.density else int(np.round(np.sqrt(render_params.size) * (fig_params.fig.dpi / 100)))
)

# Apply transformations and materialize to pandas immediately so
# datashader aggregates without dask scheduler overhead. See #379.
Expand Down Expand Up @@ -1045,14 +1090,22 @@ def _render_points(
):
color_vector = np.asarray([_hex_no_alpha(c) for c in color_vector])

shade_how = render_params.density_how if render_params.density else "linear"
# Plain density (no color column) must use the user-facing cmap as a sequential
# gradient over counts; the categorical path collapses to a single color and only
# modulates alpha, which renders as a flat hue regardless of density.
plain_density = render_params.density and col_for_color is None

nan_shaded = None
if color_by_categorical or col_for_color is None:
if not plain_density and (color_by_categorical or col_for_color is None):
shaded = _ds_shade_categorical(
agg,
color_key,
color_vector,
render_params.alpha,
spread_px=px,
how=shade_how,
density=render_params.density,
)
else:
shaded, nan_shaded, reduction_bounds = _ds_shade_continuous(
Expand All @@ -1066,6 +1119,7 @@ def _render_points(
na_color_hex,
spread_px=px,
ds_reduction=render_params.ds_reduction,
how=shade_how,
)

_render_ds_image(
Expand Down
2 changes: 2 additions & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ class PointsRenderParams:
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
colorbar: bool | str | None = "auto"
colorbar_params: dict[str, object] | None = None
density: bool = False
density_how: Literal["linear", "log", "cbrt", "eq_hist"] = "linear"


@dataclass
Expand Down
62 changes: 60 additions & 2 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import math
import os
import warnings
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Callable, Iterable, Mapping, Sequence
from copy import copy
from functools import partial
from pathlib import Path
Expand Down Expand Up @@ -2820,7 +2821,17 @@ def _validate_points_render_params(
colorbar: bool | str | None,
colorbar_params: dict[str, object] | None,
gene_symbols: str | None = None,
density: bool = False,
density_how: Literal["linear", "log", "cbrt", "eq_hist"] = "linear",
transfunc: Callable[[float], float] | None = None,
method: str | None = None,
) -> dict[str, dict[str, Any]]:
if not isinstance(density, bool):
raise TypeError("Parameter 'density' must be a bool.")
allowed_how = ("linear", "log", "cbrt", "eq_hist")
if density_how not in allowed_how:
raise ValueError(f"Parameter 'density_how' must be one of {allowed_how}; got {density_how!r}.")

param_dict: dict[str, Any] = {
"sdata": sdata,
"element": element,
Expand All @@ -2840,6 +2851,47 @@ def _validate_points_render_params(
}
param_dict = _type_check_params(param_dict, "points")

if density:
if method == "matplotlib":
raise ValueError(
"density=True requires the datashader backend; got method='matplotlib'. "
"Either drop method= or set method='datashader'."
)
# Literal color (resolved into param_dict["color"] as a Color instance, with
# col_for_color set to None) is ambiguous with density: it could mean a
# single-hue cmap or a one-entry palette. Force the user to choose.
if param_dict["color"] is not None and param_dict["col_for_color"] is None:
raise ValueError(
"density=True with a literal color is ambiguous. Pass cmap= to recolor the "
"density, or palette= to assign a categorical color, but not color=<literal>."
)
# Warn-and-ignore: these parameters do not interact meaningfully with a
# count-based density and are silently dropped to keep the API consistent.
if size != 1.0:
warnings.warn(
"size is ignored when density=True; spreading would distort the count signal.",
UserWarning,
stacklevel=3,
)
if transfunc is not None:
warnings.warn(
"transfunc is ignored when density=True (no continuous color vector to transform).",
UserWarning,
stacklevel=3,
)
if isinstance(norm, Normalize) and (norm.vmin is not None or norm.vmax is not None):
warnings.warn(
"norm.vmin/vmax are ignored when density=True; use density_how= to control intensity mapping.",
UserWarning,
stacklevel=3,
)
if ds_reduction is not None:
warnings.warn(
"datashader_reduction is ignored when density=True; counts are forced.",
UserWarning,
stacklevel=3,
)

element_params: dict[str, dict[str, Any]] = {}
for el in param_dict["element"]:
# ensure that the element exists in the SpatialData object
Expand Down Expand Up @@ -3715,11 +3767,17 @@ def _datashader_map_aggregate_to_color(
min_alpha: float = 40,
span: None | list[float] = None,
clip: bool = True,
how: str = "linear",
) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]:
"""ds.tf.shade() part, ensuring correct clipping behavior.

If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results.
This ensures the correct clipping behavior, because else datashader would always automatically clip.

``how`` controls the count-to-color mapping passed to :func:`datashader.transfer_functions.shade`
(``"linear"`` by default; ``"log"``/``"cbrt"``/``"eq_hist"`` compress dynamic range). The split-shade
branch used for ``norm.clip=False`` always uses ``"linear"`` since per-segment shading would otherwise
interact poorly with rank-based mappings.
"""
if not clip and isinstance(cmap, Colormap) and span is not None:
# in case we use datashader together with a Normalize object where clip=False
Expand Down Expand Up @@ -3768,7 +3826,7 @@ def _datashader_map_aggregate_to_color(
color_key=color_key,
min_alpha=min_alpha,
span=span,
how="linear",
how=how,
)
return _apply_cmap_alpha_to_datashader_result(result, agg, cmap, span)

Expand Down
Binary file added tests/_images/Points_density_categorical.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Points_density_how_eq_hist.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/Points_density_plain.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,34 @@ def sdata_blobs() -> SpatialData:
return blobs()


@pytest.fixture()
def sdata_dense_points() -> SpatialData:
"""Dense (~20k) multi-cluster points dataset for density-rendering visual tests.

The blobs fixture is too sparse (~200 points across 500x500) for density to render
meaningfully without spreading; this fixture provides a Gaussian-cluster cloud with
a categorical ``gene`` column so the per-category density branch is exercised too.
"""
rng = get_standard_RNG()
n_per_cluster = 20000
centers = [(120, 120), (380, 150), (250, 380)]
genes = ["gene_a", "gene_b", "gene_c"]
xs, ys, gs = [], [], []
for (cx, cy), gene in zip(centers, genes, strict=True):
xs.append(rng.normal(loc=cx, scale=18, size=n_per_cluster))
ys.append(rng.normal(loc=cy, scale=18, size=n_per_cluster))
gs.extend([gene] * n_per_cluster)
df = pd.DataFrame(
{
"x": np.concatenate(xs).clip(0, 500),
"y": np.concatenate(ys).clip(0, 500),
"gene": pd.Categorical(gs, categories=genes),
}
)
points = PointsModel.parse(df)
return SpatialData(points={"dense_points": points})


@pytest.fixture()
def sdata_blobs_str() -> SpatialData:
return blobs(n_channels=5, c_coords=["c1", "c2", "c3", "c4", "c5"])
Expand Down
Loading
Loading