From f4adf722eec2e094848413f06cd81ec4cdfe4bd8 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Sat, 15 Nov 2025 15:42:16 +0100 Subject: [PATCH 1/6] added test when user has invalid input --- src/spatialdata_plot/pl/utils.py | 98 +++++++++++++++++++++++++++++++- tests/pl/test_render_shapes.py | 44 ++++++++++++++ 2 files changed, 141 insertions(+), 1 deletion(-) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index cf21d212..c87949d4 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -66,7 +66,7 @@ ) from spatialdata._core.query.relational_query import _locate_value from spatialdata._types import ArrayLike -from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement +from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement, get_table_keys from spatialdata.transformations.operations import get_transformation from spatialdata.transformations.transformations import Scale from xarray import DataArray, DataTree @@ -794,6 +794,95 @@ def _get_colors_for_categorical_obs( return palette[:len_cat] # type: ignore[return-value] +def _format_element_names(element_name: list[str] | str | None) -> str: + if element_name is None: + return "the requested element" + if isinstance(element_name, str): + return f"'{element_name}'" + return ", ".join(f"'{name}'" for name in element_name) + + +def _preview_values(values: Sequence[Any], limit: int = 5) -> str: + values = list(values) + preview = ", ".join(map(str, values[:limit])) + if len(values) > limit: + preview += ", ..." + return preview + + +def _ensure_one_to_one_mapping( + sdata: SpatialData, + element: SpatialElement | None, + element_name: list[str] | str | None, + table_name: str | None, +) -> None: + if table_name is None or element_name is None: + return + + _validate_shape_index_uniqueness(element, element_name, table_name) + + table = sdata.get(table_name, None) + if table is None: + return + + _validate_table_instance_uniqueness(table, element_name, table_name) + + +def _validate_shape_index_uniqueness( + element: SpatialElement | None, + element_name: list[str] | str | None, + table_name: str, +) -> None: + if not isinstance(element, GeoDataFrame): + return + + duplicates = element.index[element.index.duplicated(keep=False)] + if duplicates.empty: + return + + element_label = _format_element_names(element_name) + preview = _preview_values(pd.Index(duplicates).unique()) + raise ValueError( + f"{element_label} contains duplicate index values ({preview}) while table '{table_name}' " + "requires a one-to-one mapping between shapes and annotations. " + "Please ensure each spatial element has a unique index." + ) + + +def _validate_table_instance_uniqueness( + table: AnnData, + element_name: list[str] | str | None, + table_name: str, +) -> None: + try: + _, region_key, instance_key = get_table_keys(table) + except (AttributeError, KeyError, ValueError): + return + + if instance_key is None or instance_key not in table.obs.columns: + return + + obs = table.obs + if region_key is not None and region_key in obs.columns and element_name is not None: + element_names = [element_name] if isinstance(element_name, str) else list(element_name) + obs = obs[obs[region_key].isin(element_names)] + + if obs.empty: + return + + duplicates_mask = obs[instance_key].duplicated(keep=False) + if not duplicates_mask.any(): + return + + element_label = _format_element_names(element_name) + preview = _preview_values(obs.loc[duplicates_mask, instance_key].astype(str).unique()) + raise ValueError( + f"Table '{table_name}' contains duplicate '{instance_key}' values for {element_label}: {preview}. " + "Each observation must annotate a single spatial element. Please deduplicate the table or subset it " + "before plotting." + ) + + def _set_color_source_vec( sdata: sd.SpatialData, element: SpatialElement | None, @@ -826,6 +915,13 @@ def _set_color_source_vec( ) if len(origins) == 1: + if table_name is not None and value_to_plot is not None: + _ensure_one_to_one_mapping( + sdata=sdata, + element=element, + element_name=element_name, + table_name=table_name, + ) color_source_vector = get_values( value_key=value_to_plot, sdata=sdata, diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 8fbc5d4c..8885a5e4 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -154,6 +154,50 @@ def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData): norm = Normalize(vmin=0, vmax=5, clip=True) sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups=["c1"], norm=norm).pl.show() + def test_render_shapes_duplicate_shape_indices_error(self, sdata_blobs: SpatialData): + element = "blobs_polygons" + shapes = sdata_blobs.shapes[element] + n_shapes = len(shapes) + rng = get_standard_RNG() + adata = AnnData(rng.normal(size=(n_shapes, 3))) + adata.obs["annotation"] = rng.choice(["a", "b"], size=n_shapes) + adata.obs["instance_id"] = [f"id_{i}" for i in range(n_shapes)] + table = TableModel.parse(adata=adata, region=element, instance_key="instance_id") + sdata_blobs["table"] = table + instance_key = table.uns["spatialdata_attrs"]["instance_key"] + shapes.index = table.obs[instance_key].tolist() + duplicated_index = shapes.index.to_list() + duplicated_index[1] = duplicated_index[0] + shapes.index = duplicated_index + + with pytest.raises(ValueError, match="duplicate index values"): + sdata_blobs.pl.render_shapes( + element=element, + color="annotation", + table_name="table", + ).pl.show() + + def test_render_shapes_duplicate_table_rows_error(self, sdata_blobs: SpatialData): + element = "blobs_polygons" + shapes = sdata_blobs.shapes[element] + n_shapes = len(shapes) + rng = get_standard_RNG() + adata = AnnData(rng.normal(size=(n_shapes, 3))) + adata.obs["annotation"] = rng.choice(["a", "b"], size=n_shapes) + adata.obs["instance_id"] = [f"id_{i}" for i in range(n_shapes)] + table = TableModel.parse(adata=adata, region=element, instance_key="instance_id") + instance_key = table.uns["spatialdata_attrs"]["instance_key"] + table.obs.at[table.obs.index[1], instance_key] = table.obs.at[table.obs.index[0], instance_key] + sdata_blobs["table"] = table + shapes.index = table.obs[instance_key].tolist() + + with pytest.raises(ValueError, match="duplicate 'instance"): + sdata_blobs.pl.render_shapes( + element=element, + color="annotation", + table_name="table", + ).pl.show() + def test_plot_can_plot_shapes_after_spatial_query(self, sdata_blobs: SpatialData): # subset to only shapes, should be unnecessary after rasterizeation of multiscale images is included blob = SpatialData.init_from_elements( From 72f122abe42ae92d67b10553b698e2503548ee82 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Sat, 15 Nov 2025 15:48:05 +0100 Subject: [PATCH 2/6] modified tests --- tests/pl/test_render_shapes.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 8885a5e4..b12139ac 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -162,7 +162,8 @@ def test_render_shapes_duplicate_shape_indices_error(self, sdata_blobs: SpatialD adata = AnnData(rng.normal(size=(n_shapes, 3))) adata.obs["annotation"] = rng.choice(["a", "b"], size=n_shapes) adata.obs["instance_id"] = [f"id_{i}" for i in range(n_shapes)] - table = TableModel.parse(adata=adata, region=element, instance_key="instance_id") + adata.obs["region"] = pd.Categorical([element] * n_shapes) + table = TableModel.parse(adata=adata, region=element, region_key="region", instance_key="instance_id") sdata_blobs["table"] = table instance_key = table.uns["spatialdata_attrs"]["instance_key"] shapes.index = table.obs[instance_key].tolist() @@ -185,7 +186,8 @@ def test_render_shapes_duplicate_table_rows_error(self, sdata_blobs: SpatialData adata = AnnData(rng.normal(size=(n_shapes, 3))) adata.obs["annotation"] = rng.choice(["a", "b"], size=n_shapes) adata.obs["instance_id"] = [f"id_{i}" for i in range(n_shapes)] - table = TableModel.parse(adata=adata, region=element, instance_key="instance_id") + adata.obs["region"] = pd.Categorical([element] * n_shapes) + table = TableModel.parse(adata=adata, region=element, region_key="region", instance_key="instance_id") instance_key = table.uns["spatialdata_attrs"]["instance_key"] table.obs.at[table.obs.index[1], instance_key] = table.obs.at[table.obs.index[0], instance_key] sdata_blobs["table"] = table From 355d3d85657328e246a50a9e31e38d92b86c5d4b Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Sat, 15 Nov 2025 15:54:36 +0100 Subject: [PATCH 3/6] adjusted test --- tests/pl/test_render_shapes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index b12139ac..7dd34c6e 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -189,9 +189,10 @@ def test_render_shapes_duplicate_table_rows_error(self, sdata_blobs: SpatialData adata.obs["region"] = pd.Categorical([element] * n_shapes) table = TableModel.parse(adata=adata, region=element, region_key="region", instance_key="instance_id") instance_key = table.uns["spatialdata_attrs"]["instance_key"] + original_instance_ids = table.obs[instance_key].tolist() table.obs.at[table.obs.index[1], instance_key] = table.obs.at[table.obs.index[0], instance_key] sdata_blobs["table"] = table - shapes.index = table.obs[instance_key].tolist() + shapes.index = original_instance_ids with pytest.raises(ValueError, match="duplicate 'instance"): sdata_blobs.pl.render_shapes( From c211b5c57dca49d6df163c4fda4190d8abe7ca91 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Sat, 15 Nov 2025 16:00:55 +0100 Subject: [PATCH 4/6] adjusted test --- tests/pl/test_render_shapes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 7dd34c6e..a35fbd80 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -183,16 +183,16 @@ def test_render_shapes_duplicate_table_rows_error(self, sdata_blobs: SpatialData shapes = sdata_blobs.shapes[element] n_shapes = len(shapes) rng = get_standard_RNG() + shape_ids = [f"shape_{i}" for i in range(n_shapes)] + shapes.index = shape_ids adata = AnnData(rng.normal(size=(n_shapes, 3))) adata.obs["annotation"] = rng.choice(["a", "b"], size=n_shapes) - adata.obs["instance_id"] = [f"id_{i}" for i in range(n_shapes)] + adata.obs["instance_id"] = shape_ids adata.obs["region"] = pd.Categorical([element] * n_shapes) table = TableModel.parse(adata=adata, region=element, region_key="region", instance_key="instance_id") instance_key = table.uns["spatialdata_attrs"]["instance_key"] - original_instance_ids = table.obs[instance_key].tolist() table.obs.at[table.obs.index[1], instance_key] = table.obs.at[table.obs.index[0], instance_key] sdata_blobs["table"] = table - shapes.index = original_instance_ids with pytest.raises(ValueError, match="duplicate 'instance"): sdata_blobs.pl.render_shapes( From 4d953c56b8adeebd69a215039c43e2c3f70e8637 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Sat, 15 Nov 2025 16:12:40 +0100 Subject: [PATCH 5/6] adjusted test --- src/spatialdata_plot/pl/utils.py | 3 +-- tests/pl/test_render_shapes.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index c87949d4..d8e34378 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -819,13 +819,12 @@ def _ensure_one_to_one_mapping( if table_name is None or element_name is None: return - _validate_shape_index_uniqueness(element, element_name, table_name) - table = sdata.get(table_name, None) if table is None: return _validate_table_instance_uniqueness(table, element_name, table_name) + _validate_shape_index_uniqueness(element, element_name, table_name) def _validate_shape_index_uniqueness( diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index a35fbd80..7a2e4891 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -156,7 +156,7 @@ def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData): def test_render_shapes_duplicate_shape_indices_error(self, sdata_blobs: SpatialData): element = "blobs_polygons" - shapes = sdata_blobs.shapes[element] + shapes = sdata_blobs.shapes[element].copy() n_shapes = len(shapes) rng = get_standard_RNG() adata = AnnData(rng.normal(size=(n_shapes, 3))) @@ -185,6 +185,7 @@ def test_render_shapes_duplicate_table_rows_error(self, sdata_blobs: SpatialData rng = get_standard_RNG() shape_ids = [f"shape_{i}" for i in range(n_shapes)] shapes.index = shape_ids + sdata_blobs.shapes[element] = shapes adata = AnnData(rng.normal(size=(n_shapes, 3))) adata.obs["annotation"] = rng.choice(["a", "b"], size=n_shapes) adata.obs["instance_id"] = shape_ids From f08f8f5ce82151e79f98ff85d4179ea491f91500 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Sat, 15 Nov 2025 16:19:21 +0100 Subject: [PATCH 6/6] adjusted test --- tests/pl/test_render_shapes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 7a2e4891..cc65673f 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -170,6 +170,7 @@ def test_render_shapes_duplicate_shape_indices_error(self, sdata_blobs: SpatialD duplicated_index = shapes.index.to_list() duplicated_index[1] = duplicated_index[0] shapes.index = duplicated_index + sdata_blobs.shapes[element] = shapes with pytest.raises(ValueError, match="duplicate index values"): sdata_blobs.pl.render_shapes(