Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/spatialdata_plot/pl/_datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@
# ---------------------------------------------------------------------------


def _apply_user_alpha(result: ds.tf.Image | np.ndarray, alpha: float) -> ds.tf.Image | np.ndarray:
"""Scale the alpha channel of a datashader shade result by ``alpha``.

``ds.tf.shade(min_alpha=...)`` is a floor, not a scale, so user alpha
must be applied post-hoc. See #617.
"""
if alpha >= 1.0 or result is None:
return result
arr = result if isinstance(result, np.ndarray) else result.to_numpy().base
if arr is None or arr.ndim != 3 or arr.shape[-1] != 4:
return result
arr[..., 3] = (arr[..., 3].astype(np.float32) * alpha).astype(np.uint8)
return result


def _coerce_categorical_source(series: pd.Series | dd.Series) -> pd.Categorical:
"""Return a ``pd.Categorical`` from a pandas or dask Series."""
if isinstance(series, dd.Series):
Expand Down Expand Up @@ -241,6 +256,7 @@ def _ds_shade_continuous(
span=color_span,
clip=norm.clip,
)
shaded = _apply_user_alpha(shaded, alpha)

nan_shaded = None
if nan_agg is not None:
Expand All @@ -251,6 +267,7 @@ def _ds_shade_continuous(
# only shapes (no spread) pass min_alpha for NaN shading
shade_kwargs["min_alpha"] = _convert_alpha_to_datashader_range(alpha)
nan_shaded = ds.tf.shade(nan_agg, **shade_kwargs)
nan_shaded = _apply_user_alpha(nan_shaded, alpha)

return shaded, nan_shaded, reduction_bounds

Expand All @@ -270,12 +287,13 @@ def _ds_shade_categorical(
ds_cmap = _hex_no_alpha(ds_cmap)

agg_to_shade = ds.tf.spread(agg, px=spread_px) if spread_px is not None else agg
return _datashader_map_aggregate_to_color(
shaded = _datashader_map_aggregate_to_color(
agg_to_shade,
cmap=ds_cmap,
color_key=color_key,
min_alpha=_convert_alpha_to_datashader_range(alpha),
)
return _apply_user_alpha(shaded, alpha)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -338,6 +356,7 @@ def _render_ds_outlines(
min_alpha=_convert_alpha_to_datashader_range(alpha),
how="linear",
)
shaded = _apply_user_alpha(shaded, alpha)
rgba, trans = _create_image_from_datashader_result(shaded, factor, ax)
_ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder, extent=extent)

Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ def _set_outline(
"""
# A) User doesn't want to see outlines
if (
(outline_alpha and outline_alpha == 0.0)
outline_alpha == 0.0
or (isinstance(outline_alpha, tuple) and np.all(np.array(outline_alpha) == 0.0))
or not (outline_alpha or outline_width or outline_color)
):
Expand Down
Binary file modified tests/_images/Points_datashader_continuous_color.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
62 changes: 62 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,68 @@ def test_datashader_alpha_not_applied_twice(sdata_blobs: SpatialData):
plt.close(fig)


@pytest.mark.parametrize(
("fill_alpha", "expected_max"),
[
(0.0, 0),
(0.3, 76),
(0.5, 127),
(1.0, 255),
],
)
def test_datashader_respects_fill_alpha(sdata_blobs: SpatialData, fill_alpha: float, expected_max: int):
"""fill_alpha must scale the rendered alpha channel linearly on the datashader path (#617)."""
fig, ax = plt.subplots()
sdata_blobs.pl.render_shapes(
element="blobs_polygons",
method="datashader",
fill_alpha=fill_alpha,
).pl.show(ax=ax)
fig.canvas.draw()

axes_images = [c for c in ax.get_children() if isinstance(c, matplotlib.image.AxesImage)]
assert axes_images
rgba = axes_images[0].get_array()
assert rgba.ndim == 3 and rgba.shape[-1] == 4
assert int(rgba[..., 3].max()) == expected_max
plt.close(fig)


@pytest.mark.parametrize(
("outline_alpha", "expected_max"),
[
(0.0, None),
(0.3, 76),
(0.5, 127),
(1.0, 255),
],
)
def test_datashader_respects_outline_alpha(sdata_blobs: SpatialData, outline_alpha: float, expected_max: int | None):
"""outline_alpha must scale the outline image's alpha; alpha=0 must skip rendering entirely (#617)."""
fig, ax = plt.subplots()
sdata_blobs.pl.render_shapes(
element="blobs_polygons",
method="datashader",
fill_alpha=1.0,
outline_alpha=outline_alpha,
outline_color="red",
).pl.show(ax=ax)
fig.canvas.draw()

axes_images = [c for c in ax.get_children() if isinstance(c, matplotlib.image.AxesImage)]
outline_imgs = [
img
for img in axes_images
if (arr := img.get_array()).ndim == 3 and arr.shape[-1] == 4 and arr[..., 0].max() > arr[..., 1].max()
]
if expected_max is None:
assert not outline_imgs
else:
assert outline_imgs
assert int(outline_imgs[0].get_array()[..., 3].max()) == expected_max
plt.close(fig)


def test_render_shapes_color_with_conflicting_index_name():
"""render_shapes(color=...) must not crash when obs.index.name matches an existing column.

Expand Down
18 changes: 18 additions & 0 deletions tests/pl/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from spatialdata import SpatialData

import spatialdata_plot
from spatialdata_plot.pl.render_params import Color
from spatialdata_plot.pl.utils import (
_apply_cmap_alpha_to_datashader_result,
_datashader_map_aggregate_to_color,
_get_subplots,
_set_outline,
set_zero_in_cmap_to_transparent,
)
from tests.conftest import DPI, PlotTester, PlotTesterMeta
Expand Down Expand Up @@ -164,6 +166,22 @@ def test_is_color_like(color_result: tuple[ColorLike, bool]):
assert spatialdata_plot.pl.utils._is_color_like(color) == result


@pytest.mark.parametrize(
("outline_alpha", "outline_color", "expected"),
[
(0.0, Color("#ff0000"), (0.0, 0.0)),
(0, Color("#ff0000"), (0.0, 0.0)),
((0.0, 0.0), Color("#ff0000"), (0.0, 0.0)),
(0.5, Color("#ff0000"), (0.5, 0.0)),
(1.0, Color("#ff0000"), (1.0, 0.0)),
],
)
def test_set_outline_respects_zero_alpha(outline_alpha, outline_color, expected):
"""outline_alpha=0 must yield (0.0, 0.0) even when outline_color is set (#617 follow-up)."""
alpha, _ = _set_outline(outline_alpha=outline_alpha, outline_width=None, outline_color=outline_color)
assert alpha == expected


class TestCmapAlphaDatashader:
"""Regression tests for #376: set_zero_in_cmap_to_transparent with datashader."""

Expand Down
Loading