Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def render_shapes(
method: str | None = None,
table_name: str | None = None,
table_layer: str | None = None,
shape: Literal["circle", "hex", "square"] | None = None,
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None,
**kwargs: Any,
) -> sd.SpatialData:
"""
Expand Down Expand Up @@ -243,9 +243,11 @@ def render_shapes(
table_layer: str | None
Layer of the table to use for coloring if `color` is in :attr:`sdata.table.var_names`. If None, the data in
:attr:`sdata.table.X` is used for coloring.
shape: Literal["circle", "hex", "square"] | None
shape: Literal["circle", "hex", "visium_hex", "square"] | None
If None (default), the shapes are rendered as they are. Else, if either of "circle", "hex" or "square" is
specified, the shapes are converted to a circle/hexagon/square before rendering.
specified, the shapes are converted to a circle/hexagon/square before rendering. If "visium_hex" is
specified, the shapes are assumed to be Visium spots and the size of the hexagons is adjusted to be adjacent
to each other.

**kwargs : Any
Additional arguments for customization. This can include:
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ class ShapesRenderParams:
zorder: int = 0
table_name: str | None = None
table_layer: str | None = None
shape: Literal["circle", "hex", "square"] | None = None
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None


Expand Down
74 changes: 68 additions & 6 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,7 +1802,9 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
if (norm := param_dict.get("norm")) is not None:
if element_type in {"images", "labels"} and not isinstance(norm, Normalize):
raise TypeError("Parameter 'norm' must be of type Normalize.")
if element_type in ["shapes", "points"] and not isinstance(norm, bool | Normalize):
if element_type in {"shapes", "points"} and not isinstance(
norm, bool | Normalize
):
raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.")

if (scale := param_dict.get("scale")) is not None:
Expand All @@ -1821,11 +1823,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
raise ValueError("Parameter 'size' must be a positive number.")

if element_type == "shapes" and (shape := param_dict.get("shape")) is not None:
valid_shapes = {"circle", "hex", "visium_hex", "square"}
if not isinstance(shape, str):
raise TypeError("Parameter 'shape' must be a String from ['circle', 'hex', 'square'] if not None.")
if shape not in ["circle", "hex", "square"]:
raise TypeError(
f"Parameter 'shape' must be a String from {valid_shapes} if not None."
)
if shape not in valid_shapes:
raise ValueError(
f"'{shape}' is not supported for 'shape', please choose from[None, 'circle', 'hex', 'square']."
f"'{shape}' is not supported for 'shape', please choose from {valid_shapes}."
)

table_name = param_dict.get("table_name")
Expand Down Expand Up @@ -2040,7 +2045,7 @@ def _validate_shape_render_params(
scale: float | int,
table_name: str | None,
table_layer: str | None,
shape: Literal["circle", "hex", "square"] | None,
shape: Literal["circle", "hex", "visium_hex", "square"] | None,
method: str | None,
ds_reduction: str | None,
) -> dict[str, dict[str, Any]]:
Expand Down Expand Up @@ -2647,9 +2652,10 @@ def _convert_shapes(

# define individual conversion methods
def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
# Create hexagon with point at top (30° offset from standard orientation)
vertices = [
(center.x + radius * math.cos(math.radians(angle)), center.y + radius * math.sin(math.radians(angle)))
for angle in range(0, 360, 60)
for angle in range(30, 390, 60) # Start at 30° and go every 60°
]
return shapely.Polygon(vertices), None

Expand Down Expand Up @@ -2718,6 +2724,62 @@ def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely
"Polygon": _polygon_to_hexagon,
"Multipolygon": _multipolygon_to_hexagon,
}
elif target_shape == "visium_hex":
# For visium_hex, we only support Points and warn for other geometry types
point_centers = []
non_point_count = 0

for i in range(shapes.shape[0]):
if shapes["geometry"][i].type == "Point":
point_centers.append((shapes["geometry"][i].x, shapes["geometry"][i].y))
else:
non_point_count += 1

if non_point_count > 0:
warnings.warn(
f"visium_hex conversion only supports Point geometries. Found {non_point_count} non-Point geometries "
f"that will be converted using regular hex conversion. Consider using shape='hex' for mixed geometry types.",
UserWarning,
stacklevel=2,
)

if len(point_centers) < 2:
# If we have fewer than 2 points, fall back to regular hex conversion
conversion_methods = {
"Point": _circle_to_hexagon,
"Polygon": _polygon_to_hexagon,
"Multipolygon": _multipolygon_to_hexagon,
}
else:
# Calculate typical spacing between point centers
centers_array = np.array(point_centers)
distances = []
for i in range(len(point_centers)):
for j in range(i + 1, len(point_centers)):
dist = np.linalg.norm(centers_array[i] - centers_array[j])
distances.append(dist)

# Use min dist of closest neighbors as the side length for radius calc
side_length = np.min(distances)
hex_radius = (side_length * 2.0 / math.sqrt(3)) / 2.0

# Create conversion methods
def _circle_to_visium_hex(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
return _circle_to_hexagon(center, hex_radius)

def _polygon_to_visium_hex(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
# Fall back to regular hex conversion for non-points
return _polygon_to_hexagon(polygon)

def _multipolygon_to_visium_hex(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
# Fall back to regular hex conversion for non-points
return _multipolygon_to_hexagon(multipolygon)

conversion_methods = {
"Point": _circle_to_visium_hex,
"Polygon": _polygon_to_visium_hex,
"Multipolygon": _multipolygon_to_visium_hex,
}
else:
conversion_methods = {
"Point": _circle_to_square,
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from abc import ABC, ABCMeta
from collections.abc import Callable
from functools import wraps
Expand Down Expand Up @@ -525,3 +526,29 @@ def _get_sdata_with_multiple_images(share_coordinate_system: str = "all"):
return sdata

return _get_sdata_with_multiple_images


@pytest.fixture
def sdata_hexagonal_grid_spots():
"""Create a hexagonal grid of points for testing visium_hex functionality."""
from shapely.geometry import Point
from spatialdata.models import ShapesModel

spacing = 10.0
n_rows, n_cols = 4, 4

points = []
for i, j in itertools.product(range(n_rows), range(n_cols)):
# Offset every second row by half the spacing for proper hexagonal packing
x = j * spacing + (i % 2) * spacing / 2
y = i * spacing * 0.866 # sqrt(3)/2 for proper hexagonal spacing
points.append(Point(x, y))

# Create GeoDataFrame with radius column
gdf = GeoDataFrame(geometry=points)
gdf["radius"] = 2.0 # Small radius for original circles

# Use ShapesModel.parse() to create a properly validated GeoDataFrame
shapes_gdf = ShapesModel.parse(gdf)

return SpatialData(shapes={"spots": shapes_gdf})
38 changes: 28 additions & 10 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_plot_can_render_circles_with_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1).pl.show()

def test_plot_can_render_circles_with_colored_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", outline_color="red").pl.show()
sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1, outline_color="red").pl.show()

def test_plot_can_render_polygons(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons").pl.show()
Expand All @@ -49,13 +49,17 @@ def test_plot_can_render_polygons_with_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_alpha=1).pl.show()

def test_plot_can_render_polygons_with_str_colored_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_color="red").pl.show()
sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_alpha=1, outline_color="red").pl.show()

def test_plot_can_render_polygons_with_rgb_colored_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_color=(0.0, 0.0, 1.0, 1.0)).pl.show()
sdata_blobs.pl.render_shapes(
element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 0.0, 1.0, 1.0)
).pl.show()

def test_plot_can_render_polygons_with_rgba_colored_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_polygons", outline_color=(0.0, 1.0, 0.0, 1.0)).pl.show()
sdata_blobs.pl.render_shapes(
element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 1.0, 0.0, 1.0)
).pl.show()

def test_plot_can_render_empty_geometry(self, sdata_blobs: SpatialData):
sdata_blobs.shapes["blobs_circles"].at[0, "geometry"] = gpd.points_from_xy([None], [None])[0]
Expand All @@ -65,7 +69,7 @@ def test_plot_can_render_circles_with_default_outline_width(self, sdata_blobs: S
sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1).pl.show()

def test_plot_can_render_circles_with_specified_outline_width(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", outline_width=3.0).pl.show()
sdata_blobs.pl.render_shapes(element="blobs_circles", outline_alpha=1, outline_width=3.0).pl.show()

def test_plot_can_render_multipolygons(self):
def _make_multi():
Expand Down Expand Up @@ -402,19 +406,23 @@ def test_plot_datashader_can_render_with_diff_alpha_outline(self, sdata_blobs: S
sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_alpha=0.5).pl.show()

def test_plot_datashader_can_render_with_diff_width_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_width=5.0).pl.show()
sdata_blobs.pl.render_shapes(
method="datashader", element="blobs_polygons", outline_alpha=1.0, outline_width=5.0
).pl.show()

def test_plot_datashader_can_render_with_colored_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(method="datashader", element="blobs_polygons", outline_color="red").pl.show()
sdata_blobs.pl.render_shapes(
method="datashader", element="blobs_polygons", outline_alpha=1, outline_color="red"
).pl.show()

def test_plot_datashader_can_render_with_rgb_colored_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(
method="datashader", element="blobs_polygons", outline_color=(0.0, 0.0, 1.0)
method="datashader", element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 0.0, 1.0)
).pl.show()

def test_plot_datashader_can_render_with_rgba_colored_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(
method="datashader", element="blobs_polygons", outline_color=(0.0, 1.0, 0.0, 1.0)
method="datashader", element="blobs_polygons", outline_alpha=1, outline_color=(0.0, 1.0, 0.0, 1.0)
).pl.show()

def test_plot_can_set_clims_clip(self, sdata_blobs: SpatialData):
Expand Down Expand Up @@ -593,6 +601,12 @@ def test_plot_can_render_multipolygons_to_square(self, sdata_blobs: SpatialData)
def test_plot_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle").pl.show()

def test_plot_visium_hex_hexagonal_grid(self, sdata_hexagonal_grid_spots: SpatialData):
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")

sdata_hexagonal_grid_spots.pl.render_shapes(element="spots", shape="circle").pl.show(ax=axs[0])
sdata_hexagonal_grid_spots.pl.render_shapes(element="spots", shape="visium_hex").pl.show(ax=axs[1])

def test_plot_datashader_can_render_circles_to_hex(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", shape="hex", method="datashader").pl.show()

Expand All @@ -616,6 +630,7 @@ def test_plot_datashader_can_render_multipolygons_to_square(self, sdata_blobs: S

def test_plot_datashader_can_render_multipolygons_to_circle(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_multipolygons", shape="circle", method="datashader").pl.show()

def test_plot_can_render_shapes_with_double_outline(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes("blobs_circles", outline_width=(10.0, 5.0)).pl.show()

Expand All @@ -631,7 +646,10 @@ def test_plot_can_render_double_outline_with_diff_alpha(self, sdata_blobs: Spati

def test_plot_outline_alpha_takes_precedence(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(
element="blobs_circles", outline_color=("#ff660033", "#33aa0066"), outline_width=(20, 10), outline_alpha=1.0
element="blobs_circles",
outline_color=("#ff660033", "#33aa0066"),
outline_width=(20, 10),
outline_alpha=(1.0, 1.0),
).pl.show()

def test_plot_datashader_can_render_shapes_with_double_outline(self, sdata_blobs: SpatialData):
Expand Down