Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/spatialdata_plot/pl/_datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,8 @@ def _render_ds_image(
shaded: Any,
factor: float,
zorder: int,
extent: list[float] | None,
x_min: float = 0.0,
y_min: float = 0.0,
nan_result: Any | None = None,
) -> Any:
"""Render a shaded datashader image onto matplotlib axes, with optional NaN overlay.
Expand All @@ -316,10 +317,10 @@ def _render_ds_image(
it again would apply transparency twice (see #367).
"""
if nan_result is not None:
rgba_nan, trans_nan = _create_image_from_datashader_result(nan_result, factor, ax)
_ax_show_and_transform(rgba_nan, trans_nan, ax, zorder=zorder, extent=extent)
rgba_image, trans_data = _create_image_from_datashader_result(shaded, factor, ax)
return _ax_show_and_transform(rgba_image, trans_data, ax, zorder=zorder, extent=extent)
rgba_nan, trans_nan = _create_image_from_datashader_result(nan_result, factor, ax, x_min, y_min)
_ax_show_and_transform(rgba_nan, trans_nan, ax, zorder=zorder)
rgba_image, trans_data = _create_image_from_datashader_result(shaded, factor, ax, x_min, y_min)
return _ax_show_and_transform(rgba_image, trans_data, ax, zorder=zorder)


def _render_ds_outlines(
Expand All @@ -329,7 +330,8 @@ def _render_ds_outlines(
fig_params: FigParams,
ax: matplotlib.axes.SubplotBase,
factor: float,
extent: list[float],
x_min: float = 0.0,
y_min: float = 0.0,
) -> None:
"""Aggregate, shade, and render shape outlines (outer and inner) with datashader."""
ds_lw_factor = fig_params.fig.dpi / 72
Expand Down Expand Up @@ -357,8 +359,8 @@ def _render_ds_outlines(
how="linear",
)
shaded = _apply_user_alpha(shaded, alpha)
rgba, trans = _create_image_from_datashader_result(shaded, factor, ax)
_ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder, extent=extent)
rgba, trans = _create_image_from_datashader_result(shaded, factor, ax, x_min, y_min)
_ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder)


def _build_ds_colorbar(
Expand Down
24 changes: 19 additions & 5 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,13 @@ def _render_shapes(
)
)

if len(transformed_element) == 0:
# Nothing to rasterize (e.g., a bounding_box_query that matched no
# shapes). Skip the datashader pipeline.
return

plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas(
transformed_element, "global", ax, fig_params
transformed_element, "global", fig_params
)

cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext)
Expand Down Expand Up @@ -599,15 +604,17 @@ def _render_shapes(
fig_params,
ax,
factor,
x_ext + y_ext,
x_min=x_ext[0],
y_min=y_ext[0],
)

_cax = _render_ds_image(
ax,
shaded,
factor,
render_params.zorder,
x_ext + y_ext,
x_min=x_ext[0],
y_min=y_ext[0],
nan_result=nan_shaded,
)

Expand Down Expand Up @@ -929,8 +936,14 @@ def _render_points(
transformations={coordinate_system: Identity()},
).compute()

if len(transformed_element) == 0:
# Nothing to rasterize (e.g., a bounding_box_query that matched no
# points). Skip the datashader pipeline; rendering proceeds with
# any other elements on the axes.
return

plot_width, plot_height, x_ext, y_ext, factor = _datashader_canvas_from_dataframe(
transformed_element, ax, fig_params
transformed_element, fig_params
)

# use datashader for the visualization of points
Expand Down Expand Up @@ -1026,7 +1039,8 @@ def _render_points(
shaded,
factor,
render_params.zorder,
x_ext + y_ext,
x_min=x_ext[0],
y_min=y_ext[0],
nan_result=nan_shaded,
)

Expand Down
98 changes: 37 additions & 61 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@
from spatialdata._types import ArrayLike
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement, get_table_keys
from spatialdata.transformations.operations import get_transformation
from spatialdata.transformations.transformations import Scale
from spatialdata.transformations.transformations import Scale, Translation
from spatialdata.transformations.transformations import Sequence as TransformSequence
from xarray import DataArray, DataTree

from spatialdata_plot._logging import logger
Expand Down Expand Up @@ -2029,7 +2030,12 @@ def _rasterize_if_necessary(

if do_rasterization:
logger.info("Rasterizing image for faster rendering.")
target_unit_to_pixels = min(target_y_dims / y_dims, target_x_dims / x_dims)
# ``rasterize`` interprets ``target_unit_to_pixels`` in world units, not
# intrinsic pixels. Dividing by world extent keeps the result correct
# for any transformation (translation, scale, etc.).
world_x = float(extent["x"][1]) - float(extent["x"][0])
world_y = float(extent["y"][1]) - float(extent["y"][0])
target_unit_to_pixels = min(target_y_dims / world_y, target_x_dims / world_x)
image = rasterize(
image,
("y", "x"),
Expand Down Expand Up @@ -3378,42 +3384,20 @@ def _ax_show_and_transform(
alpha: float | None = None,
cmap: ListedColormap | LinearSegmentedColormap | None = None,
zorder: int = 0,
extent: list[float] | None = None,
norm: Normalize | None = None,
) -> matplotlib.image.AxesImage:
# default extent in mpl:
image_extent = [-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5]
if extent is not None:
# make sure extent is [x_min, x_max, y_min, y_max]
if extent[3] < extent[2]:
extent[2], extent[3] = extent[3], extent[2]
if extent[0] < 0:
x_factor = array.shape[1] / (extent[1] - extent[0])
image_extent[0] = image_extent[0] + (extent[0] * x_factor)
image_extent[1] = image_extent[1] + (extent[0] * x_factor)
if extent[2] < 0:
y_factor = array.shape[0] / (extent[3] - extent[2])
image_extent[2] = image_extent[2] + (extent[2] * y_factor)
image_extent[3] = image_extent[3] + (extent[2] * y_factor)

# ``extent`` uses mpl's pixel-grid convention; world placement happens via
# ``set_transform(trans_data)`` afterwards.
image_extent = (-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5)
# ``alpha`` is applied only when no cmap is set, so RGBA arrays already
# carrying per-pixel alpha (e.g. datashader output) are not double-attenuated.
imshow_kwargs: dict[str, Any] = {"zorder": zorder, "extent": image_extent, "norm": norm}
if not cmap and alpha is not None:
im = ax.imshow(
array,
alpha=alpha,
zorder=zorder,
extent=tuple(image_extent),
norm=norm,
)
im.set_transform(trans_data)
imshow_kwargs["alpha"] = alpha
else:
im = ax.imshow(
array,
cmap=cmap,
zorder=zorder,
extent=tuple(image_extent),
norm=norm,
)
im.set_transform(trans_data)
imshow_kwargs["cmap"] = cmap
im = ax.imshow(array, **imshow_kwargs)
im.set_transform(trans_data)
return im


Expand Down Expand Up @@ -3442,30 +3426,12 @@ def set_zero_in_cmap_to_transparent(cmap: Colormap | str, steps: int | None = No
def _compute_datashader_canvas_params(
x_ext: list[Any],
y_ext: list[Any],
ax: Axes,
fig_params: FigParams,
) -> tuple[Any, Any, list[Any], list[Any], Any]:
"""Compute datashader canvas dimensions from spatial extents.

Shared logic used by both the dask-based and pandas-based entry points.
"""
previous_xlim = ax.get_xlim()
previous_ylim = ax.get_ylim()
# increase range if sth larger was rendered on the axis before
if _mpl_ax_contains_elements(ax):
x_ext = [min(x_ext[0], previous_xlim[0]), max(x_ext[1], previous_xlim[1])]
y_ext = (
[
min(y_ext[0], previous_ylim[1]),
max(y_ext[1], previous_ylim[0]),
]
if ax.yaxis_inverted()
else [
min(y_ext[0], previous_ylim[0]),
max(y_ext[1], previous_ylim[1]),
]
)

# Compute canvas size in pixels, capped at the figure's display resolution.
# Using np.max ensures the canvas never exceeds display pixels on either axis,
# preventing pixel-based operations (spread, line_width) from being downscaled
Expand All @@ -3485,42 +3451,52 @@ def _compute_datashader_canvas_params(
def _get_extent_and_range_for_datashader_canvas(
spatial_element: SpatialElement,
coordinate_system: str,
ax: Axes,
fig_params: FigParams,
) -> tuple[Any, Any, list[Any], list[Any], Any]:
extent = get_extent(spatial_element, coordinate_system=coordinate_system)
x_ext = [min(0, extent["x"][0]), extent["x"][1]]
y_ext = [min(0, extent["y"][0]), extent["y"][1]]
return _compute_datashader_canvas_params(x_ext, y_ext, ax, fig_params)
x_ext = [float(extent["x"][0]), float(extent["x"][1])]
y_ext = [float(extent["y"][0]), float(extent["y"][1])]
return _compute_datashader_canvas_params(x_ext, y_ext, fig_params)


def _datashader_canvas_from_dataframe(
df: pd.DataFrame,
ax: Axes,
fig_params: FigParams,
) -> tuple[Any, Any, list[Any], list[Any], Any]:
"""Compute datashader canvas params directly from a pandas DataFrame.

Avoids the overhead of ``get_extent()`` (which requires a dask-backed
SpatialElement) by reading min/max from the already-materialised data.
"""
x_ext = [min(0, float(df["x"].min())), float(df["x"].max())]
y_ext = [min(0, float(df["y"].min())), float(df["y"].max())]
return _compute_datashader_canvas_params(x_ext, y_ext, ax, fig_params)
if len(df) == 0:
# Empty input (e.g., a bounding_box_query with no overlap) — caller
# should short-circuit; return zero-sized canvas params as a sentinel.
return 0, 0, [0.0, 0.0], [0.0, 0.0], 1.0
x_ext = [float(df["x"].min()), float(df["x"].max())]
y_ext = [float(df["y"].min()), float(df["y"].max())]
return _compute_datashader_canvas_params(x_ext, y_ext, fig_params)


def _create_image_from_datashader_result(
ds_result: ds.transfer_functions.Image | np.ndarray[Any, np.dtype[np.uint8]],
factor: float,
ax: Axes,
x_min: float = 0.0,
y_min: float = 0.0,
) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.Transform]:
# create SpatialImage from datashader output to get it back to original size
rgba_image_data = ds_result.copy() if isinstance(ds_result, np.ndarray) else ds_result.to_numpy().base
rgba_image_data = np.transpose(rgba_image_data, (2, 0, 1))
transformation: Scale | TransformSequence = Scale([1, factor, factor], ("c", "y", "x"))
if x_min != 0.0 or y_min != 0.0:
# Canvas pixel (0, 0) corresponds to world (x_min, y_min). Without this
# translation the rgba would render at the world origin instead of at
# the element's actual position.
transformation = TransformSequence([transformation, Translation([x_min, y_min], ("x", "y"))])
rgba_image = Image2DModel.parse(
rgba_image_data,
dims=("c", "y", "x"),
transformations={"global": Scale([1, factor, factor], ("c", "y", "x"))},
transformations={"global": transformation},
)

_, trans_data = _prepare_transformation(rgba_image, "global", ax)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Shapes_datashader_can_render_with_outline.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
101 changes: 101 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,3 +1031,104 @@ def test_render_points_disjoint_instance_ids_clear_error():
sdata.pl.render_points("pts", color="cat", table_name="t").pl.show(ax=ax)
finally:
plt.close(fig)


def _make_offset_points_sdata(offset: tuple[float, float] = (10000.0, 18000.0), n: int = 100) -> SpatialData:
rng = np.random.default_rng(0)
df = pd.DataFrame(
{
"x": rng.uniform(0, 200, size=n),
"y": rng.uniform(0, 200, size=n),
}
)
pts = PointsModel.parse(df, transformations={"global": Translation(list(offset), axes=("x", "y"))})
return SpatialData(points={"pts": pts})


def test_datashader_canvas_preserves_resolution_under_bbox_query():
# regression test for #668: bounding_box_query on a translated SpatialData
# must not collapse the datashader canvas resolution. The factor (world
# units per canvas pixel) at offset must match the no-offset baseline.
from spatialdata import bounding_box_query

from spatialdata_plot.pl.render_params import FigParams
from spatialdata_plot.pl.utils import _datashader_canvas_from_dataframe

baseline_sdata = _make_offset_points_sdata(offset=(0.0, 0.0))
offset_sdata = _make_offset_points_sdata(offset=(10000.0, 18000.0))

baseline_crop = bounding_box_query(
baseline_sdata["pts"],
axes=("x", "y"),
min_coordinate=[50, 50],
max_coordinate=[150, 150],
target_coordinate_system="global",
)
offset_crop = bounding_box_query(
offset_sdata["pts"],
axes=("x", "y"),
min_coordinate=[10050, 18050],
max_coordinate=[10150, 18150],
target_coordinate_system="global",
)

fig = plt.figure(figsize=(6, 6), dpi=100)
fig_params = FigParams(fig=fig, ax=fig.gca(), num_panels=1)
try:
# Compare canvas params on the post-transform world-coord frame, which
# is what _render_points feeds to _datashader_canvas_from_dataframe.
base_df = baseline_crop.compute()
offset_df = offset_crop.compute()
offset_df_world = offset_df.copy()
offset_df_world["x"] = offset_df_world["x"] + 10000.0
offset_df_world["y"] = offset_df_world["y"] + 18000.0

_, _, _, _, factor_baseline = _datashader_canvas_from_dataframe(base_df, fig_params)
_, _, _, _, factor_offset = _datashader_canvas_from_dataframe(offset_df_world, fig_params)

# Without the fix factor_offset would be ~60x factor_baseline.
assert abs(factor_offset - factor_baseline) < 1e-6, (
f"datashader factor leaked offset: baseline={factor_baseline}, offset={factor_offset}"
)
finally:
plt.close(fig)


def test_render_points_datashader_under_bbox_query_does_not_crash():
# regression test for #668: rendering a bbox_query result with datashader
# must produce a figure (previously: resolution collapse → effectively blank).
from spatialdata import bounding_box_query

sdata = _make_offset_points_sdata()
cropped = bounding_box_query(
sdata["pts"],
axes=("x", "y"),
min_coordinate=[10050, 18050],
max_coordinate=[10150, 18150],
target_coordinate_system="global",
)
cropped_sdata = SpatialData(points={"pts": cropped})

fig, ax = plt.subplots()
try:
cropped_sdata.pl.render_points("pts", method="datashader").pl.show(ax=ax)
finally:
plt.close(fig)


def test_datashader_canvas_from_empty_dataframe_does_not_crash():
# regression test for #668: _datashader_canvas_from_dataframe used to
# crash with ``ValueError: cannot convert float NaN to integer`` when fed
# an empty DataFrame (NaN min()/max() → int cast). The helper now returns
# a zero-sized sentinel so callers can short-circuit cleanly.
from spatialdata_plot.pl.render_params import FigParams
from spatialdata_plot.pl.utils import _datashader_canvas_from_dataframe

empty_df = pd.DataFrame({"x": pd.Series(dtype=float), "y": pd.Series(dtype=float)})
fig = plt.figure(figsize=(6, 6), dpi=100)
fig_params = FigParams(fig=fig, ax=fig.gca(), num_panels=1)
try:
plot_width, plot_height, _, _, _ = _datashader_canvas_from_dataframe(empty_df, fig_params)
assert plot_width == 0 and plot_height == 0
finally:
plt.close(fig)
Loading
Loading