diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index a948077f..883f51ce 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -170,7 +170,7 @@ def render_shapes( method: str | None = None, table_name: str | None = None, table_layer: str | None = None, - shape: Literal["circle", "hex", "square"] | None = None, + shape: Literal["circle", "hex", "visium_hex", "square"] | None = None, **kwargs: Any, ) -> sd.SpatialData: """ @@ -243,9 +243,11 @@ def render_shapes( table_layer: str | None Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in :attr:`sdata.table.X` is used for coloring. - shape: Literal["circle", "hex", "square"] | None + shape: Literal["circle", "hex", "visium_hex", "square"] | None If None (default), the shapes are rendered as they are. Else, if either of "circle", "hex" or "square" is - specified, the shapes are converted to a circle/hexagon/square before rendering. + specified, the shapes are converted to a circle/hexagon/square before rendering. If "visium_hex" is + specified, the shapes are assumed to be Visium spots and the size of the hexagons is adjusted to be adjacent + to each other. **kwargs : Any Additional arguments for customization. This can include: diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index 15812c0c..a8382753 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -211,7 +211,7 @@ class ShapesRenderParams: zorder: int = 0 table_name: str | None = None table_layer: str | None = None - shape: Literal["circle", "hex", "square"] | None = None + shape: Literal["circle", "hex", "visium_hex", "square"] | None = None ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index ff1f8daa..82b84973 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -1802,7 +1802,9 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if (norm := param_dict.get("norm")) is not None: if element_type in {"images", "labels"} and not isinstance(norm, Normalize): raise TypeError("Parameter 'norm' must be of type Normalize.") - if element_type in ["shapes", "points"] and not isinstance(norm, bool | Normalize): + if element_type in {"shapes", "points"} and not isinstance( + norm, bool | Normalize + ): raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.") if (scale := param_dict.get("scale")) is not None: @@ -1821,11 +1823,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st raise ValueError("Parameter 'size' must be a positive number.") if element_type == "shapes" and (shape := param_dict.get("shape")) is not None: + valid_shapes = {"circle", "hex", "visium_hex", "square"} if not isinstance(shape, str): - raise TypeError("Parameter 'shape' must be a String from ['circle', 'hex', 'square'] if not None.") - if shape not in ["circle", "hex", "square"]: + raise TypeError( + f"Parameter 'shape' must be a String from {valid_shapes} if not None." + ) + if shape not in valid_shapes: raise ValueError( - f"'{shape}' is not supported for 'shape', please choose from[None, 'circle', 'hex', 'square']." + f"'{shape}' is not supported for 'shape', please choose from {valid_shapes}." ) table_name = param_dict.get("table_name") @@ -2040,7 +2045,7 @@ def _validate_shape_render_params( scale: float | int, table_name: str | None, table_layer: str | None, - shape: Literal["circle", "hex", "square"] | None, + shape: Literal["circle", "hex", "visium_hex", "square"] | None, method: str | None, ds_reduction: str | None, ) -> dict[str, dict[str, Any]]: @@ -2647,9 +2652,10 @@ def _convert_shapes( # define individual conversion methods def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]: + # Create hexagon with point at top (30° offset from standard orientation) vertices = [ (center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle))) - for angle in range(0, 360, 60) + for angle in range(30, 390, 60) # Start at 30° and go every 60° ] return shapely.Polygon(vertices), None @@ -2718,6 +2724,62 @@ def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely "Polygon": _polygon_to_hexagon, "Multipolygon": _multipolygon_to_hexagon, } + elif target_shape == "visium_hex": + # For visium_hex, we only support Points and warn for other geometry types + point_centers = [] + non_point_count = 0 + + for i in range(shapes.shape[0]): + if shapes["geometry"][i].type == "Point": + point_centers.append((shapes["geometry"][i].x, shapes["geometry"][i].y)) + else: + non_point_count += 1 + + if non_point_count > 0: + warnings.warn( + f"visium_hex conversion only supports Point geometries. Found {non_point_count} non-Point geometries " + f"that will be converted using regular hex conversion. Consider using shape='hex' for mixed geometry types.", + UserWarning, + stacklevel=2, + ) + + if len(point_centers) < 2: + # If we have fewer than 2 points, fall back to regular hex conversion + conversion_methods = { + "Point": _circle_to_hexagon, + "Polygon": _polygon_to_hexagon, + "Multipolygon": _multipolygon_to_hexagon, + } + else: + # Calculate typical spacing between point centers + centers_array = np.array(point_centers) + distances = [] + for i in range(len(point_centers)): + for j in range(i + 1, len(point_centers)): + dist = np.linalg.norm(centers_array[i] - centers_array[j]) + distances.append(dist) + + # Use min dist of closest neighbors as the side length for radius calc + side_length = np.min(distances) + hex_radius = (side_length * 2.0 / math.sqrt(3)) / 2.0 + + # Create conversion methods + def _circle_to_visium_hex(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]: + return _circle_to_hexagon(center, hex_radius) + + def _polygon_to_visium_hex(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]: + # Fall back to regular hex conversion for non-points + return _polygon_to_hexagon(polygon) + + def _multipolygon_to_visium_hex(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]: + # Fall back to regular hex conversion for non-points + return _multipolygon_to_hexagon(multipolygon) + + conversion_methods = { + "Point": _circle_to_visium_hex, + "Polygon": _polygon_to_visium_hex, + "Multipolygon": _multipolygon_to_visium_hex, + } else: conversion_methods = { "Point": _circle_to_square, diff --git a/tests/_images/Shapes_visium_hex_hexagonal_grid.png b/tests/_images/Shapes_visium_hex_hexagonal_grid.png new file mode 100644 index 00000000..b6c59369 Binary files /dev/null and b/tests/_images/Shapes_visium_hex_hexagonal_grid.png differ diff --git a/tests/conftest.py b/tests/conftest.py index 896eaeb7..61c66573 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import itertools from abc import ABC, ABCMeta from collections.abc import Callable from functools import wraps @@ -525,3 +526,29 @@ def _get_sdata_with_multiple_images(share_coordinate_system: str = "all"): return sdata return _get_sdata_with_multiple_images + + +@pytest.fixture +def sdata_hexagonal_grid_spots(): + """Create a hexagonal grid of points for testing visium_hex functionality.""" + from shapely.geometry import Point + from spatialdata.models import ShapesModel + + spacing = 10.0 + n_rows, n_cols = 4, 4 + + points = [] + for i, j in itertools.product(range(n_rows), range(n_cols)): + # Offset every second row by half the spacing for proper hexagonal packing + x = j * spacing + (i % 2) * spacing / 2 + y = i * spacing * 0.866 # sqrt(3)/2 for proper hexagonal spacing + points.append(Point(x, y)) + + # Create GeoDataFrame with radius column + gdf = GeoDataFrame(geometry=points) + gdf["radius"] = 2.0 # Small radius for original circles + + # Use ShapesModel.parse() to create a properly validated GeoDataFrame + shapes_gdf = ShapesModel.parse(gdf) + + return SpatialData(shapes={"spots": shapes_gdf}) \ No newline at end of file diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index fb676386..1a6dc3bb 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -40,7 +40,7 @@ def test_plot_can_render_circles_with_outline(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1).pl.show() def test_plot_can_render_circles_with_colored_outline(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes(element="blobs_circles", outline_color="red").pl.show() + sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1, outline_color="red").pl.show() def test_plot_can_render_polygons(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_polygons").pl.show() @@ -49,13 +49,17 @@ def test_plot_can_render_polygons_with_outline(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_alpha=1).pl.show() def test_plot_can_render_polygons_with_str_colored_outline(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_color="red").pl.show() + sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_alpha=1, outline_color="red").pl.show() def test_plot_can_render_polygons_with_rgb_colored_outline(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_color=(0.0, 0.0, 1.0, 1.0)).pl.show() + sdata_blobs.pl.render_shapes( + element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 0.0, 1.0, 1.0) + ).pl.show() def test_plot_can_render_polygons_with_rgba_colored_outline(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_color=(0.0, 1.0, 0.0, 1.0)).pl.show() + sdata_blobs.pl.render_shapes( + element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 1.0, 0.0, 1.0) + ).pl.show() def test_plot_can_render_empty_geometry(self, sdata_blobs: SpatialData): sdata_blobs.shapes["blobs_circles"].at[0, "geometry"] = gpd.points_from_xy([None], [None])[0] @@ -65,7 +69,7 @@ def test_plot_can_render_circles_with_default_outline_width(self, sdata_blobs: S sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1).pl.show() def test_plot_can_render_circles_with_specified_outline_width(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes(element="blobs_circles", outline_width=3.0).pl.show() + sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1, outline_width=3.0).pl.show() def test_plot_can_render_multipolygons(self): def _make_multi(): @@ -402,19 +406,23 @@ def test_plot_datashader_can_render_with_diff_alpha_outline(self, sdata_blobs: S sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_alpha=0.5).pl.show() def test_plot_datashader_can_render_with_diff_width_outline(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_width=5.0).pl.show() + sdata_blobs.pl.render_shapes( + method="datashader", element="blobs_polygons", outline_alpha=1.0, outline_width=5.0 + ).pl.show() def test_plot_datashader_can_render_with_colored_outline(self, sdata_blobs: SpatialData): - sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_color="red").pl.show() + sdata_blobs.pl.render_shapes( + method="datashader", element="blobs_polygons", outline_alpha=1, outline_color="red" + ).pl.show() def test_plot_datashader_can_render_with_rgb_colored_outline(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes( - method="datashader", element="blobs_polygons", outline_color=(0.0, 0.0, 1.0) + method="datashader", element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 0.0, 1.0) ).pl.show() def test_plot_datashader_can_render_with_rgba_colored_outline(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes( - method="datashader", element="blobs_polygons", outline_color=(0.0, 1.0, 0.0, 1.0) + method="datashader", element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 1.0, 0.0, 1.0) ).pl.show() def test_plot_can_set_clims_clip(self, sdata_blobs: SpatialData): @@ -593,6 +601,12 @@ def test_plot_can_render_multipolygons_to_square(self, sdata_blobs: SpatialData) def test_plot_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle").pl.show() + def test_plot_visium_hex_hexagonal_grid(self, sdata_hexagonal_grid_spots: SpatialData): + _, axs = plt.subplots(nrows=1, ncols=2, layout="tight") + + sdata_hexagonal_grid_spots.pl.render_shapes(element="spots", shape="circle").pl.show(ax=axs[0]) + sdata_hexagonal_grid_spots.pl.render_shapes(element="spots", shape="visium_hex").pl.show(ax=axs[1]) + def test_plot_datashader_can_render_circles_to_hex(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_circles", shape="hex", method="datashader").pl.show() @@ -616,6 +630,7 @@ def test_plot_datashader_can_render_multipolygons_to_square(self, sdata_blobs: S def test_plot_datashader_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle", method="datashader").pl.show() + def test_plot_can_render_shapes_with_double_outline(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes("blobs_circles", outline_width=(10.0, 5.0)).pl.show() @@ -631,7 +646,10 @@ def test_plot_can_render_double_outline_with_diff_alpha(self, sdata_blobs: Spati def test_plot_outline_alpha_takes_precedence(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes( - element="blobs_circles", outline_color=("#ff660033", "#33aa0066"), outline_width=(20, 10), outline_alpha=1.0 + element="blobs_circles", + outline_color=("#ff660033", "#33aa0066"), + outline_width=(20, 10), + outline_alpha=(1.0, 1.0), ).pl.show() def test_plot_datashader_can_render_shapes_with_double_outline(self, sdata_blobs: SpatialData):