diff --git a/src/spatialdata_plot/pl/_datashader.py b/src/spatialdata_plot/pl/_datashader.py index b5cf0532..1d6415ed 100644 --- a/src/spatialdata_plot/pl/_datashader.py +++ b/src/spatialdata_plot/pl/_datashader.py @@ -306,7 +306,8 @@ def _render_ds_image( shaded: Any, factor: float, zorder: int, - extent: list[float] | None, + x_min: float = 0.0, + y_min: float = 0.0, nan_result: Any | None = None, ) -> Any: """Render a shaded datashader image onto matplotlib axes, with optional NaN overlay. @@ -316,10 +317,10 @@ def _render_ds_image( it again would apply transparency twice (see #367). """ if nan_result is not None: - rgba_nan, trans_nan = _create_image_from_datashader_result(nan_result, factor, ax) - _ax_show_and_transform(rgba_nan, trans_nan, ax, zorder=zorder, extent=extent) - rgba_image, trans_data = _create_image_from_datashader_result(shaded, factor, ax) - return _ax_show_and_transform(rgba_image, trans_data, ax, zorder=zorder, extent=extent) + rgba_nan, trans_nan = _create_image_from_datashader_result(nan_result, factor, ax, x_min, y_min) + _ax_show_and_transform(rgba_nan, trans_nan, ax, zorder=zorder) + rgba_image, trans_data = _create_image_from_datashader_result(shaded, factor, ax, x_min, y_min) + return _ax_show_and_transform(rgba_image, trans_data, ax, zorder=zorder) def _render_ds_outlines( @@ -329,7 +330,8 @@ def _render_ds_outlines( fig_params: FigParams, ax: matplotlib.axes.SubplotBase, factor: float, - extent: list[float], + x_min: float = 0.0, + y_min: float = 0.0, ) -> None: """Aggregate, shade, and render shape outlines (outer and inner) with datashader.""" ds_lw_factor = fig_params.fig.dpi / 72 @@ -357,8 +359,8 @@ def _render_ds_outlines( how="linear", ) shaded = _apply_user_alpha(shaded, alpha) - rgba, trans = _create_image_from_datashader_result(shaded, factor, ax) - _ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder, extent=extent) + rgba, trans = _create_image_from_datashader_result(shaded, factor, ax, x_min, y_min) + _ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder) def _build_ds_colorbar( diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 9e2465a7..0b43b54a 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -517,8 +517,13 @@ def _render_shapes( ) ) + if len(transformed_element) == 0: + # Nothing to rasterize (e.g., a bounding_box_query that matched no + # shapes). Skip the datashader pipeline. + return + plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas( - transformed_element, "global", ax, fig_params + transformed_element, "global", fig_params ) cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext) @@ -599,7 +604,8 @@ def _render_shapes( fig_params, ax, factor, - x_ext + y_ext, + x_min=x_ext[0], + y_min=y_ext[0], ) _cax = _render_ds_image( @@ -607,7 +613,8 @@ def _render_shapes( shaded, factor, render_params.zorder, - x_ext + y_ext, + x_min=x_ext[0], + y_min=y_ext[0], nan_result=nan_shaded, ) @@ -929,8 +936,14 @@ def _render_points( transformations={coordinate_system: Identity()}, ).compute() + if len(transformed_element) == 0: + # Nothing to rasterize (e.g., a bounding_box_query that matched no + # points). Skip the datashader pipeline; rendering proceeds with + # any other elements on the axes. + return + plot_width, plot_height, x_ext, y_ext, factor = _datashader_canvas_from_dataframe( - transformed_element, ax, fig_params + transformed_element, fig_params ) # use datashader for the visualization of points @@ -1026,7 +1039,8 @@ def _render_points( shaded, factor, render_params.zorder, - x_ext + y_ext, + x_min=x_ext[0], + y_min=y_ext[0], nan_result=nan_shaded, ) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 968fbe53..35844c6a 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -66,7 +66,8 @@ from spatialdata._types import ArrayLike from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement, get_table_keys from spatialdata.transformations.operations import get_transformation -from spatialdata.transformations.transformations import Scale +from spatialdata.transformations.transformations import Scale, Translation +from spatialdata.transformations.transformations import Sequence as TransformSequence from xarray import DataArray, DataTree from spatialdata_plot._logging import logger @@ -2029,7 +2030,12 @@ def _rasterize_if_necessary( if do_rasterization: logger.info("Rasterizing image for faster rendering.") - target_unit_to_pixels = min(target_y_dims / y_dims, target_x_dims / x_dims) + # ``rasterize`` interprets ``target_unit_to_pixels`` in world units, not + # intrinsic pixels. Dividing by world extent keeps the result correct + # for any transformation (translation, scale, etc.). + world_x = float(extent["x"][1]) - float(extent["x"][0]) + world_y = float(extent["y"][1]) - float(extent["y"][0]) + target_unit_to_pixels = min(target_y_dims / world_y, target_x_dims / world_x) image = rasterize( image, ("y", "x"), @@ -3378,42 +3384,20 @@ def _ax_show_and_transform( alpha: float | None = None, cmap: ListedColormap | LinearSegmentedColormap | None = None, zorder: int = 0, - extent: list[float] | None = None, norm: Normalize | None = None, ) -> matplotlib.image.AxesImage: - # default extent in mpl: - image_extent = [-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5] - if extent is not None: - # make sure extent is [x_min, x_max, y_min, y_max] - if extent[3] < extent[2]: - extent[2], extent[3] = extent[3], extent[2] - if extent[0] < 0: - x_factor = array.shape[1] / (extent[1] - extent[0]) - image_extent[0] = image_extent[0] + (extent[0] * x_factor) - image_extent[1] = image_extent[1] + (extent[0] * x_factor) - if extent[2] < 0: - y_factor = array.shape[0] / (extent[3] - extent[2]) - image_extent[2] = image_extent[2] + (extent[2] * y_factor) - image_extent[3] = image_extent[3] + (extent[2] * y_factor) - + # ``extent`` uses mpl's pixel-grid convention; world placement happens via + # ``set_transform(trans_data)`` afterwards. + image_extent = (-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5) + # ``alpha`` is applied only when no cmap is set, so RGBA arrays already + # carrying per-pixel alpha (e.g. datashader output) are not double-attenuated. + imshow_kwargs: dict[str, Any] = {"zorder": zorder, "extent": image_extent, "norm": norm} if not cmap and alpha is not None: - im = ax.imshow( - array, - alpha=alpha, - zorder=zorder, - extent=tuple(image_extent), - norm=norm, - ) - im.set_transform(trans_data) + imshow_kwargs["alpha"] = alpha else: - im = ax.imshow( - array, - cmap=cmap, - zorder=zorder, - extent=tuple(image_extent), - norm=norm, - ) - im.set_transform(trans_data) + imshow_kwargs["cmap"] = cmap + im = ax.imshow(array, **imshow_kwargs) + im.set_transform(trans_data) return im @@ -3442,30 +3426,12 @@ def set_zero_in_cmap_to_transparent(cmap: Colormap | str, steps: int | None = No def _compute_datashader_canvas_params( x_ext: list[Any], y_ext: list[Any], - ax: Axes, fig_params: FigParams, ) -> tuple[Any, Any, list[Any], list[Any], Any]: """Compute datashader canvas dimensions from spatial extents. Shared logic used by both the dask-based and pandas-based entry points. """ - previous_xlim = ax.get_xlim() - previous_ylim = ax.get_ylim() - # increase range if sth larger was rendered on the axis before - if _mpl_ax_contains_elements(ax): - x_ext = [min(x_ext[0], previous_xlim[0]), max(x_ext[1], previous_xlim[1])] - y_ext = ( - [ - min(y_ext[0], previous_ylim[1]), - max(y_ext[1], previous_ylim[0]), - ] - if ax.yaxis_inverted() - else [ - min(y_ext[0], previous_ylim[0]), - max(y_ext[1], previous_ylim[1]), - ] - ) - # Compute canvas size in pixels, capped at the figure's display resolution. # Using np.max ensures the canvas never exceeds display pixels on either axis, # preventing pixel-based operations (spread, line_width) from being downscaled @@ -3485,18 +3451,16 @@ def _compute_datashader_canvas_params( def _get_extent_and_range_for_datashader_canvas( spatial_element: SpatialElement, coordinate_system: str, - ax: Axes, fig_params: FigParams, ) -> tuple[Any, Any, list[Any], list[Any], Any]: extent = get_extent(spatial_element, coordinate_system=coordinate_system) - x_ext = [min(0, extent["x"][0]), extent["x"][1]] - y_ext = [min(0, extent["y"][0]), extent["y"][1]] - return _compute_datashader_canvas_params(x_ext, y_ext, ax, fig_params) + x_ext = [float(extent["x"][0]), float(extent["x"][1])] + y_ext = [float(extent["y"][0]), float(extent["y"][1])] + return _compute_datashader_canvas_params(x_ext, y_ext, fig_params) def _datashader_canvas_from_dataframe( df: pd.DataFrame, - ax: Axes, fig_params: FigParams, ) -> tuple[Any, Any, list[Any], list[Any], Any]: """Compute datashader canvas params directly from a pandas DataFrame. @@ -3504,23 +3468,35 @@ def _datashader_canvas_from_dataframe( Avoids the overhead of ``get_extent()`` (which requires a dask-backed SpatialElement) by reading min/max from the already-materialised data. """ - x_ext = [min(0, float(df["x"].min())), float(df["x"].max())] - y_ext = [min(0, float(df["y"].min())), float(df["y"].max())] - return _compute_datashader_canvas_params(x_ext, y_ext, ax, fig_params) + if len(df) == 0: + # Empty input (e.g., a bounding_box_query with no overlap) — caller + # should short-circuit; return zero-sized canvas params as a sentinel. + return 0, 0, [0.0, 0.0], [0.0, 0.0], 1.0 + x_ext = [float(df["x"].min()), float(df["x"].max())] + y_ext = [float(df["y"].min()), float(df["y"].max())] + return _compute_datashader_canvas_params(x_ext, y_ext, fig_params) def _create_image_from_datashader_result( ds_result: ds.transfer_functions.Image | np.ndarray[Any, np.dtype[np.uint8]], factor: float, ax: Axes, + x_min: float = 0.0, + y_min: float = 0.0, ) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.Transform]: # create SpatialImage from datashader output to get it back to original size rgba_image_data = ds_result.copy() if isinstance(ds_result, np.ndarray) else ds_result.to_numpy().base rgba_image_data = np.transpose(rgba_image_data, (2, 0, 1)) + transformation: Scale | TransformSequence = Scale([1, factor, factor], ("c", "y", "x")) + if x_min != 0.0 or y_min != 0.0: + # Canvas pixel (0, 0) corresponds to world (x_min, y_min). Without this + # translation the rgba would render at the world origin instead of at + # the element's actual position. + transformation = TransformSequence([transformation, Translation([x_min, y_min], ("x", "y"))]) rgba_image = Image2DModel.parse( rgba_image_data, dims=("c", "y", "x"), - transformations={"global": Scale([1, factor, factor], ("c", "y", "x"))}, + transformations={"global": transformation}, ) _, trans_data = _prepare_transformation(rgba_image, "global", ax) diff --git a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_scale.png b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_scale.png index 617859f9..300915a3 100644 Binary files a/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_scale.png and b/tests/_images/NotebooksTransformations_can_render_transformations_raccoon_scale.png differ diff --git a/tests/_images/Shapes_datashader_can_render_shapes_with_colored_double_outline.png b/tests/_images/Shapes_datashader_can_render_shapes_with_colored_double_outline.png index d5468723..efbca31d 100644 Binary files a/tests/_images/Shapes_datashader_can_render_shapes_with_colored_double_outline.png and b/tests/_images/Shapes_datashader_can_render_shapes_with_colored_double_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_render_shapes_with_double_outline.png b/tests/_images/Shapes_datashader_can_render_shapes_with_double_outline.png index 6bb56592..92f3096f 100644 Binary files a/tests/_images/Shapes_datashader_can_render_shapes_with_double_outline.png and b/tests/_images/Shapes_datashader_can_render_shapes_with_double_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_render_with_colored_outline.png b/tests/_images/Shapes_datashader_can_render_with_colored_outline.png index 7ea92c52..819c7eb2 100644 Binary files a/tests/_images/Shapes_datashader_can_render_with_colored_outline.png and b/tests/_images/Shapes_datashader_can_render_with_colored_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_render_with_diff_width_outline.png b/tests/_images/Shapes_datashader_can_render_with_diff_width_outline.png index e39c24f1..5673358e 100644 Binary files a/tests/_images/Shapes_datashader_can_render_with_diff_width_outline.png and b/tests/_images/Shapes_datashader_can_render_with_diff_width_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_render_with_outline.png b/tests/_images/Shapes_datashader_can_render_with_outline.png index eaac5b07..061da1b7 100644 Binary files a/tests/_images/Shapes_datashader_can_render_with_outline.png and b/tests/_images/Shapes_datashader_can_render_with_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_render_with_rgb_colored_outline.png b/tests/_images/Shapes_datashader_can_render_with_rgb_colored_outline.png index 3ea2f152..f083a8b5 100644 Binary files a/tests/_images/Shapes_datashader_can_render_with_rgb_colored_outline.png and b/tests/_images/Shapes_datashader_can_render_with_rgb_colored_outline.png differ diff --git a/tests/_images/Shapes_datashader_can_render_with_rgba_colored_outline.png b/tests/_images/Shapes_datashader_can_render_with_rgba_colored_outline.png index cef9f0eb..99b0a14f 100644 Binary files a/tests/_images/Shapes_datashader_can_render_with_rgba_colored_outline.png and b/tests/_images/Shapes_datashader_can_render_with_rgba_colored_outline.png differ diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index dcc4267c..78399916 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -1031,3 +1031,104 @@ def test_render_points_disjoint_instance_ids_clear_error(): sdata.pl.render_points("pts", color="cat", table_name="t").pl.show(ax=ax) finally: plt.close(fig) + + +def _make_offset_points_sdata(offset: tuple[float, float] = (10000.0, 18000.0), n: int = 100) -> SpatialData: + rng = np.random.default_rng(0) + df = pd.DataFrame( + { + "x": rng.uniform(0, 200, size=n), + "y": rng.uniform(0, 200, size=n), + } + ) + pts = PointsModel.parse(df, transformations={"global": Translation(list(offset), axes=("x", "y"))}) + return SpatialData(points={"pts": pts}) + + +def test_datashader_canvas_preserves_resolution_under_bbox_query(): + # regression test for #668: bounding_box_query on a translated SpatialData + # must not collapse the datashader canvas resolution. The factor (world + # units per canvas pixel) at offset must match the no-offset baseline. + from spatialdata import bounding_box_query + + from spatialdata_plot.pl.render_params import FigParams + from spatialdata_plot.pl.utils import _datashader_canvas_from_dataframe + + baseline_sdata = _make_offset_points_sdata(offset=(0.0, 0.0)) + offset_sdata = _make_offset_points_sdata(offset=(10000.0, 18000.0)) + + baseline_crop = bounding_box_query( + baseline_sdata["pts"], + axes=("x", "y"), + min_coordinate=[50, 50], + max_coordinate=[150, 150], + target_coordinate_system="global", + ) + offset_crop = bounding_box_query( + offset_sdata["pts"], + axes=("x", "y"), + min_coordinate=[10050, 18050], + max_coordinate=[10150, 18150], + target_coordinate_system="global", + ) + + fig = plt.figure(figsize=(6, 6), dpi=100) + fig_params = FigParams(fig=fig, ax=fig.gca(), num_panels=1) + try: + # Compare canvas params on the post-transform world-coord frame, which + # is what _render_points feeds to _datashader_canvas_from_dataframe. + base_df = baseline_crop.compute() + offset_df = offset_crop.compute() + offset_df_world = offset_df.copy() + offset_df_world["x"] = offset_df_world["x"] + 10000.0 + offset_df_world["y"] = offset_df_world["y"] + 18000.0 + + _, _, _, _, factor_baseline = _datashader_canvas_from_dataframe(base_df, fig_params) + _, _, _, _, factor_offset = _datashader_canvas_from_dataframe(offset_df_world, fig_params) + + # Without the fix factor_offset would be ~60x factor_baseline. + assert abs(factor_offset - factor_baseline) < 1e-6, ( + f"datashader factor leaked offset: baseline={factor_baseline}, offset={factor_offset}" + ) + finally: + plt.close(fig) + + +def test_render_points_datashader_under_bbox_query_does_not_crash(): + # regression test for #668: rendering a bbox_query result with datashader + # must produce a figure (previously: resolution collapse → effectively blank). + from spatialdata import bounding_box_query + + sdata = _make_offset_points_sdata() + cropped = bounding_box_query( + sdata["pts"], + axes=("x", "y"), + min_coordinate=[10050, 18050], + max_coordinate=[10150, 18150], + target_coordinate_system="global", + ) + cropped_sdata = SpatialData(points={"pts": cropped}) + + fig, ax = plt.subplots() + try: + cropped_sdata.pl.render_points("pts", method="datashader").pl.show(ax=ax) + finally: + plt.close(fig) + + +def test_datashader_canvas_from_empty_dataframe_does_not_crash(): + # regression test for #668: _datashader_canvas_from_dataframe used to + # crash with ``ValueError: cannot convert float NaN to integer`` when fed + # an empty DataFrame (NaN min()/max() → int cast). The helper now returns + # a zero-sized sentinel so callers can short-circuit cleanly. + from spatialdata_plot.pl.render_params import FigParams + from spatialdata_plot.pl.utils import _datashader_canvas_from_dataframe + + empty_df = pd.DataFrame({"x": pd.Series(dtype=float), "y": pd.Series(dtype=float)}) + fig = plt.figure(figsize=(6, 6), dpi=100) + fig_params = FigParams(fig=fig, ax=fig.gca(), num_panels=1) + try: + plot_width, plot_height, _, _, _ = _datashader_canvas_from_dataframe(empty_df, fig_params) + assert plot_width == 0 and plot_height == 0 + finally: + plt.close(fig) diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index a401fef8..f77f0ed8 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -1446,3 +1446,30 @@ def track(x): plt.close(fig) assert called, "transfunc was not called for continuous shapes data" + + +def test_render_shapes_datashader_under_bbox_query_does_not_crash(): + # regression test for #668: bounding_box_query on a translated SpatialData + # used to collapse the shapes datashader canvas resolution and produce a + # broken figure. + from spatialdata import bounding_box_query + + rng = np.random.default_rng(0) + geom = [Point(rng.uniform(0, 200), rng.uniform(0, 200)).buffer(5) for _ in range(20)] + gdf = gpd.GeoDataFrame({"geometry": geom}) + shapes = ShapesModel.parse(gdf, transformations={"global": Translation([10000.0, 18000.0], axes=("x", "y"))}) + + cropped = bounding_box_query( + shapes, + axes=("x", "y"), + min_coordinate=[10050, 18050], + max_coordinate=[10150, 18150], + target_coordinate_system="global", + ) + cropped_sdata = SpatialData(shapes={"shapes": cropped}) + + fig, ax = plt.subplots() + try: + cropped_sdata.pl.render_shapes("shapes", method="datashader").pl.show(ax=ax) + finally: + plt.close(fig) diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index 7f342ff7..0a277f05 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -458,3 +458,48 @@ def test_explicit_table_name_honored_when_element_has_same_column(): sdata.pl.render_shapes("s1", color="cat").pl.show(ax=ax) assert sorted(t.get_text() for t in ax.get_legend().get_texts()) == ["X", "Y"] plt.close(fig) + + +def test_rasterize_target_unit_to_pixels_uses_world_extent(monkeypatch): + # regression test for #668: _rasterize_if_necessary must compute + # target_unit_to_pixels in world units (target_px / world_unit), not in + # intrinsic pixels (target_px / source_px). For Scale=0.5 the two differ + # by a factor of 2. + from spatialdata.models import Image2DModel + from spatialdata.transformations import Scale, Sequence, Translation, set_transformation + + from spatialdata_plot.pl import utils as plut + + arr = np.zeros((3, 8000, 8000), dtype=np.uint8) # large enough to trigger rasterization + img = Image2DModel.parse(arr, dims=("c", "y", "x")) + set_transformation( + img, + Sequence([Scale([0.5, 0.5], axes=("x", "y")), Translation([100.0, 200.0], axes=("x", "y"))]), + to_coordinate_system="global", + ) + + captured: dict[str, float] = {} + + def fake_rasterize(*args, **kwargs): + captured["target_unit_to_pixels"] = kwargs["target_unit_to_pixels"] + return img # return is unused for the assertion + + monkeypatch.setattr(plut, "rasterize", fake_rasterize) + + # Source intrinsic is 8000 px; Scale=0.5 → world extent 4000 wu. + # Display target = dpi*width = 100*6 = 600 px. + # Expected (world-unit basis): 600 / 4000 = 0.15 + # Old (intrinsic basis): 600 / 8000 = 0.075 (wrong by 2x) + plut._rasterize_if_necessary( + image=img, + dpi=100, + width=6, + height=6, + coordinate_system="global", + extent={"x": (100.0, 4100.0), "y": (200.0, 4200.0)}, + ) + + assert "target_unit_to_pixels" in captured, "rasterize was not called" + assert captured["target_unit_to_pixels"] == pytest.approx(0.15, rel=1e-6), ( + f"Expected world-unit basis (0.15), got {captured['target_unit_to_pixels']}" + )