diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 5fa70833..954afe40 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -65,7 +65,7 @@ ) from spatialdata._core.query.relational_query import _locate_value from spatialdata._types import ArrayLike -from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement +from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement, get_table_keys from spatialdata.transformations.operations import get_transformation from spatialdata.transformations.transformations import Scale from xarray import DataArray, DataTree @@ -794,6 +794,14 @@ def _get_colors_for_categorical_obs( return palette[:len_cat] # type: ignore[return-value] +def _format_element_names(element_name: list[str] | str | None) -> str: + if element_name is None: + return "the requested element" + if isinstance(element_name, str): + return f"'{element_name}'" + return ", ".join(f"'{name}'" for name in element_name) + + def _format_element_name(element_name: list[str] | str | None) -> str: if isinstance(element_name, str): return element_name @@ -802,6 +810,86 @@ def _format_element_name(element_name: list[str] | str | None) -> str: return "" +def _preview_values(values: Sequence[Any], limit: int = 5) -> str: + values = list(values) + preview = ", ".join(map(str, values[:limit])) + if len(values) > limit: + preview += ", ..." + return preview + + +def _ensure_one_to_one_mapping( + sdata: SpatialData, + element: SpatialElement | None, + element_name: list[str] | str | None, + table_name: str | None, +) -> None: + if table_name is None or element_name is None: + return + + table = sdata.get(table_name, None) + if table is None: + return + + _validate_table_instance_uniqueness(table, element_name, table_name) + _validate_shape_index_uniqueness(element, element_name, table_name) + + +def _validate_shape_index_uniqueness( + element: SpatialElement | None, + element_name: list[str] | str | None, + table_name: str, +) -> None: + if not isinstance(element, GeoDataFrame): + return + + duplicates = element.index[element.index.duplicated(keep=False)] + if duplicates.empty: + return + + element_label = _format_element_names(element_name) + preview = _preview_values(pd.Index(duplicates).unique()) + raise ValueError( + f"{element_label} contains duplicate index values ({preview}) while table '{table_name}' " + "requires a one-to-one mapping between shapes and annotations. " + "Please ensure each spatial element has a unique index." + ) + + +def _validate_table_instance_uniqueness( + table: AnnData, + element_name: list[str] | str | None, + table_name: str, +) -> None: + try: + _, region_key, instance_key = get_table_keys(table) + except (AttributeError, KeyError, ValueError): + return + + if instance_key is None or instance_key not in table.obs.columns: + return + + obs = table.obs + if region_key is not None and region_key in obs.columns and element_name is not None: + element_names = [element_name] if isinstance(element_name, str) else list(element_name) + obs = obs[obs[region_key].isin(element_names)] + + if obs.empty: + return + + duplicates_mask = obs[instance_key].duplicated(keep=False) + if not duplicates_mask.any(): + return + + element_label = _format_element_names(element_name) + preview = _preview_values(obs.loc[duplicates_mask, instance_key].astype(str).unique()) + raise ValueError( + f"Table '{table_name}' contains duplicate '{instance_key}' values for {element_label}: {preview}. " + "Each observation must annotate a single spatial element. Please deduplicate the table or subset it " + "before plotting." + ) + + def _infer_color_data_kind( series: pd.Series, value_to_plot: str, @@ -885,6 +973,13 @@ def _set_color_source_vec( ) if len(origins) == 1 and value_to_plot is not None: + if table_name is not None: + _ensure_one_to_one_mapping( + sdata=sdata, + element=element, + element_name=element_name, + table_name=table_name, + ) color_source_vector = get_values( value_key=value_to_plot, sdata=sdata, diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 6afd8e5a..1acf3d3c 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -209,6 +209,55 @@ def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData): norm = Normalize(vmin=0, vmax=5, clip=True) sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups=["c1"], norm=norm).pl.show() + def test_render_shapes_duplicate_shape_indices_error(self, sdata_blobs: SpatialData): + element = "blobs_polygons" + shapes = sdata_blobs.shapes[element].copy() + n_shapes = len(shapes) + rng = get_standard_RNG() + adata = AnnData(rng.normal(size=(n_shapes, 3))) + adata.obs["annotation"] = rng.choice(["a", "b"], size=n_shapes) + adata.obs["instance_id"] = [f"id_{i}" for i in range(n_shapes)] + adata.obs["region"] = pd.Categorical([element] * n_shapes) + table = TableModel.parse(adata=adata, region=element, region_key="region", instance_key="instance_id") + sdata_blobs["table"] = table + instance_key = table.uns["spatialdata_attrs"]["instance_key"] + shapes.index = table.obs[instance_key].tolist() + duplicated_index = shapes.index.to_list() + duplicated_index[1] = duplicated_index[0] + shapes.index = duplicated_index + sdata_blobs.shapes[element] = shapes + + with pytest.raises(ValueError, match="duplicate index values"): + sdata_blobs.pl.render_shapes( + element=element, + color="annotation", + table_name="table", + ).pl.show() + + def test_render_shapes_duplicate_table_rows_error(self, sdata_blobs: SpatialData): + element = "blobs_polygons" + shapes = sdata_blobs.shapes[element] + n_shapes = len(shapes) + rng = get_standard_RNG() + shape_ids = [f"shape_{i}" for i in range(n_shapes)] + shapes.index = shape_ids + sdata_blobs.shapes[element] = shapes + adata = AnnData(rng.normal(size=(n_shapes, 3))) + adata.obs["annotation"] = rng.choice(["a", "b"], size=n_shapes) + adata.obs["instance_id"] = shape_ids + adata.obs["region"] = pd.Categorical([element] * n_shapes) + table = TableModel.parse(adata=adata, region=element, region_key="region", instance_key="instance_id") + instance_key = table.uns["spatialdata_attrs"]["instance_key"] + table.obs.at[table.obs.index[1], instance_key] = table.obs.at[table.obs.index[0], instance_key] + sdata_blobs["table"] = table + + with pytest.raises(ValueError, match="duplicate 'instance"): + sdata_blobs.pl.render_shapes( + element=element, + color="annotation", + table_name="table", + ).pl.show() + def test_render_shapes_raises_when_color_key_missing(self, sdata_blobs_shapes_annotated: SpatialData): missing_col = "__non_existent_column__" with pytest.raises(KeyError, match=f"Unable to locate color key '{missing_col}'"):