diff --git a/src/spatialdata_plot/_logging.py b/src/spatialdata_plot/_logging.py index be1cf5f7..364cba27 100644 --- a/src/spatialdata_plot/_logging.py +++ b/src/spatialdata_plot/_logging.py @@ -1,6 +1,13 @@ # from https://github.com/scverse/spatialdata/blob/main/src/spatialdata/_logging.py import logging +import re +from collections.abc import Iterator +from contextlib import contextmanager +from typing import TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + from _pytest.logging import LogCaptureFixture def _setup_logger() -> "logging.Logger": @@ -21,3 +28,44 @@ def _setup_logger() -> "logging.Logger": logger = _setup_logger() + + +@contextmanager +def logger_warns( + caplog: "LogCaptureFixture", + logger: logging.Logger, + match: str | None = None, + level: int = logging.WARNING, +) -> Iterator[None]: + """ + Context manager similar to pytest.warns, but for logging.Logger. + + Usage: + with logger_warns(caplog, logger, match="Found 1 NaN"): + call_code_that_logs() + """ + # Store initial record count to only check new records + initial_record_count = len(caplog.records) + + # Add caplog's handler directly to the logger to capture logs even if propagate=False + handler = caplog.handler + logger.addHandler(handler) + original_level = logger.level + logger.setLevel(level) + + # Use caplog.at_level to ensure proper capture setup + with caplog.at_level(level, logger=logger.name): + try: + yield + finally: + logger.removeHandler(handler) + logger.setLevel(original_level) + + # Only check records that were added during this context + records = [r for r in caplog.records[initial_record_count:] if r.levelno >= level] + + if match is not None: + pattern = re.compile(match) + if not any(pattern.search(r.getMessage()) for r in records): + msgs = [r.getMessage() for r in records] + raise AssertionError(f"Did not find log matching {match!r} in records: {msgs!r}") diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 396efba9..9d409ff5 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from collections import abc from copy import copy @@ -49,7 +48,6 @@ _get_extent_and_range_for_datashader_canvas, _get_linear_colormap, _hex_no_alpha, - _is_coercable_to_float, _map_color_seg, _maybe_set_colors, _mpl_ax_contains_elements, @@ -94,20 +92,7 @@ def _render_shapes( ) sdata_filt[table_name] = table = joined_table - if ( - col_for_color is not None - and table_name is not None - and col_for_color in sdata_filt[table_name].obs.columns - and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O" - and not _is_coercable_to_float(color_col) - ): - warnings.warn( - f"Converting copy of '{col_for_color}' column to categorical dtype for categorical plotting. 
" - f"Consider converting before plotting.", - UserWarning, - stacklevel=2, - ) - sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category") + shapes = sdata_filt[element] # get color vector (categorical or continuous) color_source_vector, color_vector, _ = _set_color_source_vec( @@ -121,6 +106,7 @@ def _render_shapes( cmap_params=render_params.cmap_params, table_name=table_name, table_layer=table_layer, + coordinate_system=coordinate_system, ) values_are_categorical = color_source_vector is not None @@ -144,12 +130,25 @@ def _render_shapes( # continuous case: leave NaNs as NaNs; utils maps them to na_color during draw if color_source_vector is None and not values_are_categorical: - color_vector = np.asarray(color_vector, dtype=float) - if np.isnan(color_vector).any(): - nan_count = int(np.isnan(color_vector).sum()) - msg = f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'." - warnings.warn(msg, UserWarning, stacklevel=2) - logger.warning(msg) + _series = color_vector if isinstance(color_vector, pd.Series) else pd.Series(color_vector) + + try: + color_vector = np.asarray(_series, dtype=float) + except (TypeError, ValueError): + nan_count = int(_series.isna().sum()) + if nan_count: + logger.warning( + f"Found {nan_count} NaN values in color data. " + "These observations will be colored with the 'na_color'." + ) + color_vector = _series.to_numpy() + else: + if np.isnan(color_vector).any(): + nan_count = int(np.isnan(color_vector).sum()) + logger.warning( + f"Found {nan_count} NaN values in color data. " + "These observations will be colored with the 'na_color'." + ) # Using dict.fromkeys here since set returns in arbitrary order # remove the color of NaN values, else it might be assigned to a category @@ -476,10 +475,33 @@ def _render_shapes( if not values_are_categorical: vmin = render_params.cmap_params.norm.vmin vmax = render_params.cmap_params.norm.vmax - if vmin is None: - vmin = float(np.nanmin(color_vector)) - if vmax is None: - vmax = float(np.nanmax(color_vector)) + if vmin is None or vmax is None: + # Extract numeric values only (filter out strings and other non-numeric types) + if isinstance(color_vector, np.ndarray): + if np.issubdtype(color_vector.dtype, np.number): + # Already numeric, can use directly + numeric_values = color_vector + else: + # Mixed types - extract only numeric values using pandas + numeric_values = pd.to_numeric(color_vector, errors="coerce") + numeric_values = numeric_values[np.isfinite(numeric_values)] + if len(numeric_values) > 0: + if vmin is None: + vmin = float(np.nanmin(numeric_values)) + if vmax is None: + vmax = float(np.nanmax(numeric_values)) + else: + # No numeric values found, use defaults + if vmin is None: + vmin = 0.0 + if vmax is None: + vmax = 1.0 + else: + # Not a numpy array, use defaults + if vmin is None: + vmin = 0.0 + if vmax is None: + vmax = 1.0 _cax.set_clim(vmin=vmin, vmax=vmax) if ( @@ -541,11 +563,9 @@ def _render_points( coords = ["x", "y"] if table_name is not None and col_for_color not in points.columns: - warnings.warn( + logger.warning( f"Annotating points with {col_for_color} which is stored in the table `{table_name}`. " - f"To improve performance, it is advisable to store point annotations directly in the .parquet file.", - UserWarning, - stacklevel=2, + f"To improve performance, it is advisable to store point annotations directly in the .parquet file." 
) if col_for_color is None or ( @@ -553,19 +573,6 @@ def _render_points( and (col_for_color in sdata_filt[table_name].obs.columns or col_for_color in sdata_filt[table_name].var_names) ): points = points[coords].compute() - if ( - col_for_color - and col_for_color in sdata_filt[table_name].obs.columns - and (color_col := sdata_filt[table_name].obs[col_for_color]).dtype == "O" - and not _is_coercable_to_float(color_col) - ): - warnings.warn( - f"Converting copy of '{col_for_color}' column to categorical dtype for categorical " - f"plotting. Consider converting before plotting.", - UserWarning, - stacklevel=2, - ) - sdata_filt[table_name].obs[col_for_color] = sdata_filt[table_name].obs[col_for_color].astype("category") else: coords += [col_for_color] points = points[coords].compute() @@ -683,6 +690,7 @@ def _render_points( alpha=render_params.alpha, table_name=table_name, render_type="points", + coordinate_system=coordinate_system, ) if added_color_from_table and col_for_color is not None: @@ -1219,6 +1227,7 @@ def _render_labels( cmap_params=render_params.cmap_params, table_name=table_name, table_layer=table_layer, + coordinate_system=coordinate_system, ) # rasterize could have removed labels from label diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 95df77c0..49aa9fbf 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -2,7 +2,6 @@ import math import os -import warnings from collections import OrderedDict from collections.abc import Iterable, Mapping, Sequence from copy import copy @@ -46,7 +45,7 @@ from matplotlib_scalebar.scalebar import ScaleBar from numpy.ma.core import MaskedArray from numpy.random import default_rng -from pandas.api.types import CategoricalDtype +from pandas.api.types import CategoricalDtype, is_bool_dtype, is_numeric_dtype, is_string_dtype from pandas.core.arrays.categorical import Categorical from scanpy import settings from scanpy.plotting._tools.scatterplots import _add_categorical_legend @@ -424,11 +423,10 @@ def _as_rgba_array(x: Any) -> np.ndarray: used_norm = colors.Normalize(vmin=vmin, vmax=vmax, clip=False) fill_c[is_num] = cmap(used_norm(num[is_num])) - # non-numeric entries as explicit colors - # treat missing values as na_color, and only convert valid color-like entries - non_numeric_mask = (~is_num) & c_series.notna() - if non_numeric_mask.any(): - fill_c[non_numeric_mask] = ColorConverter().to_rgba_array(c_series[non_numeric_mask].tolist()) + # non-numeric, non-NaN entries as explicit colors + non_numeric_color_mask = (~is_num) & c_series.notna().to_numpy() + if non_numeric_color_mask.any(): + fill_c[non_numeric_color_mask] = ColorConverter().to_rgba_array(c_series[non_numeric_color_mask].tolist()) # Case C: single color or list of color-like specs (strings or tuples) else: @@ -796,6 +794,64 @@ def _get_colors_for_categorical_obs( return palette[:len_cat] # type: ignore[return-value] +def _format_element_name(element_name: list[str] | str | None) -> str: + if isinstance(element_name, str): + return element_name + if isinstance(element_name, list) and len(element_name) > 0: + return ", ".join(element_name) + return "" + + +def _infer_color_data_kind( + series: pd.Series, + value_to_plot: str, + element_name: list[str] | str | None, + table_name: str | None, + warn_on_object_to_categorical: bool = False, +) -> tuple[Literal["numeric", "categorical"], pd.Series | pd.Categorical]: + element_label = _format_element_name(element_name) + + if isinstance(series.dtype, 
pd.CategoricalDtype): + return "categorical", pd.Categorical(series) + + if is_bool_dtype(series.dtype): + return "numeric", series.astype(float) + + if is_numeric_dtype(series.dtype): + return "numeric", pd.to_numeric(series, errors="coerce") + + if is_string_dtype(series.dtype) or series.dtype == object: + non_na = series[~pd.isna(series)] + if len(non_na) == 0: + return "numeric", pd.to_numeric(series, errors="coerce") + + numeric_like = pd.to_numeric(non_na, errors="coerce") + has_numeric = numeric_like.notna().any() + has_non_numeric = numeric_like.isna().any() + + if has_numeric and has_non_numeric: + invalid_examples = non_na[numeric_like.isna()].astype(str).unique()[:3] + location = f" in table '{table_name}'" if table_name is not None else "" + raise TypeError( + f"Column '{value_to_plot}' for element '{element_label}'{location} contains both numeric and " + f"non-numeric values (e.g. {', '.join(invalid_examples)}). " + "Please ensure that the column stores consistent data." + ) + + if has_numeric: + return "numeric", pd.to_numeric(series, errors="coerce") + + if warn_on_object_to_categorical: + logger.warning( + f"Converting copy of '{value_to_plot}' column to categorical dtype for categorical plotting. " + "Consider converting before plotting." + ) + + return "categorical", pd.Categorical(series) + + return "numeric", pd.to_numeric(series, errors="coerce") + + def _set_color_source_vec( sdata: sd.SpatialData, element: SpatialElement | None, @@ -809,6 +865,7 @@ def _set_color_source_vec( table_name: str | None = None, table_layer: str | None = None, render_type: Literal["points"] | None = None, + coordinate_system: str | None = None, ) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]: if value_to_plot is None and element is not None: color = np.full(len(element), na_color.get_hex_with_alpha()) @@ -827,7 +884,7 @@ def _set_color_source_vec( f"Color key '{value_to_plot}' for element '{element_name}' been found in multiple locations: {origins}." ) - if len(origins) == 1: + if len(origins) == 1 and value_to_plot is not None: color_source_vector = get_values( value_key=value_to_plot, sdata=sdata, @@ -836,33 +893,35 @@ def _set_color_source_vec( table_layer=table_layer, )[value_to_plot] - # numerical vs. categorical case - if color_source_vector is not None and not isinstance(color_source_vector.dtype, pd.CategoricalDtype): - is_numeric_like = pd.api.types.is_numeric_dtype(color_source_vector.dtype) - is_object_series = isinstance(color_source_vector, pd.Series) and color_source_vector.dtype == "O" + color_series = ( + color_source_vector if isinstance(color_source_vector, pd.Series) else pd.Series(color_source_vector) + ) - # If it's an object-typed series but not coercible to float, treat as categorical - if is_object_series and not _is_coercable_to_float(color_source_vector): - color_source_vector = pd.Categorical(color_source_vector) - else: - is_numeric_like = True + kind, processed = _infer_color_data_kind( + series=color_series, + value_to_plot=value_to_plot, + element_name=element_name, + table_name=table_name, + warn_on_object_to_categorical=table_name is not None, + ) - # Continuous case: return early - if is_numeric_like: - if ( - not isinstance(element, GeoDataFrame) - and isinstance(palette, list) - and palette[0] is not None - or isinstance(element, GeoDataFrame) - and isinstance(palette, list) - ): - logger.warning( - "Ignoring categorical palette which is given for a continuous variable. " - "Consider using `cmap` to pass a ColorMap." 
- ) - return None, color_source_vector, False + if kind == "numeric": + numeric_vector = processed + if ( + not isinstance(element, GeoDataFrame) + and isinstance(palette, list) + and palette[0] is not None + or isinstance(element, GeoDataFrame) + and isinstance(palette, list) + ): + logger.warning( + "Ignoring categorical palette which is given for a continuous variable. " + "Consider using `cmap` to pass a ColorMap." + ) + return None, numeric_vector, False - color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series` + assert isinstance(processed, pd.Categorical) + color_source_vector = processed # convert, e.g., `pd.Series` # Use the provided table_name parameter, fall back to only one present table_to_use: str | None @@ -941,11 +1000,14 @@ def _set_color_source_vec( return color_source_vector, color_vector, True - logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not been found, using default colors.") - # Fallback: color everything with na_color; use element length when table is unknown - n_obs = len(element) if element is not None else (sdata[table_name].n_obs if table_name in sdata.tables else 0) - color = np.full(n_obs, na_color.get_hex_with_alpha()) - return color, color, False + if table_name is None: + raise KeyError( + f"Unable to locate color key '{value_to_plot}' for element '{element_name}'. " + "Please ensure the key exists in a table annotating this element." + ) + raise KeyError( + f"Unable to locate color key '{value_to_plot}' in table '{table_name}' for element '{element_name}'." + ) def _map_color_seg( @@ -2369,35 +2431,39 @@ def _validate_col_for_column_table( table_name: str | None, labels: bool = False, ) -> tuple[str | None, str | None]: + if col_for_color is None: + return None, None + if not labels and col_for_color in sdata[element_name].columns: table_name = None elif table_name is not None: tables = get_element_annotators(sdata, element_name) - if table_name not in tables or ( - col_for_color not in sdata[table_name].obs.columns and col_for_color not in sdata[table_name].var_names - ): - warnings.warn( - f"Table '{table_name}' does not annotate element '{element_name}'.", - UserWarning, - stacklevel=2, + if table_name not in tables: + raise KeyError(f"Table '{table_name}' does not annotate element '{element_name}'.") + if col_for_color not in sdata[table_name].obs.columns and col_for_color not in sdata[table_name].var_names: + raise KeyError( + f"Column '{col_for_color}' not found in obs/var of table '{table_name}' for element '{element_name}'." ) - table_name = None - col_for_color = None else: tables = get_element_annotators(sdata, element_name) - for table_name in tables.copy(): - if col_for_color not in sdata[table_name].obs.columns and col_for_color not in sdata[table_name].var_names: - tables.remove(table_name) if len(tables) == 0: - col_for_color = None - elif len(tables) >= 1: - table_name = next(iter(tables)) - if len(tables) > 1: - warnings.warn( - f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.", - UserWarning, - stacklevel=2, - ) + raise KeyError( + f"Element '{element_name}' has no annotating tables. " + f"Cannot use column '{col_for_color}' for coloring. " + "Please ensure the element is annotated by at least one table." 
+ ) + # Now check which tables contain the column + for annotates in tables.copy(): + if col_for_color not in sdata[annotates].obs.columns and col_for_color not in sdata[annotates].var_names: + tables.remove(annotates) + if len(tables) == 0: + raise KeyError( + f"Unable to locate color key '{col_for_color}' for element '{element_name}'. " + "Please ensure the key exists in a table annotating this element." + ) + table_name = next(iter(tables)) + if len(tables) > 1: + logger.warning(f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.") return col_for_color, table_name @@ -2520,11 +2586,6 @@ def _get_wanted_render_elements( raise ValueError(f"Unknown element type {element_type}") -def _is_coercable_to_float(series: pd.Series) -> bool: - numeric_series = pd.to_numeric(series, errors="coerce") - return not numeric_series.isnull().any() - - def _ax_show_and_transform( array: MaskedArray[tuple[int, ...], Any] | npt.NDArray[Any], trans_data: CompositeGenericTransform, @@ -2961,11 +3022,7 @@ def _multipolygon_to_square(multipolygon: shapely.MultiPolygon) -> tuple[shapely else: non_point_count += 1 if non_point_count > 0: - warnings.warn( - "visium_hex supports Points best. Non-Point geometries will use regular hex conversion.", - UserWarning, - stacklevel=2, - ) + logger.warning("visium_hex supports Points best. Non-Point geometries will use regular hex conversion.") if len(point_centers) >= 2: centers = np.array(point_centers, dtype=float) # pairwise min distance diff --git a/tests/_images/Points_can_annotate_points_with_table_and_groups.png b/tests/_images/Points_can_annotate_points_with_table_and_groups.png index 4941bfd2..b03781a4 100644 Binary files a/tests/_images/Points_can_annotate_points_with_table_and_groups.png and b/tests/_images/Points_can_annotate_points_with_table_and_groups.png differ diff --git a/tests/_images/Shapes_can_handle_nan_values_in_color_data.png b/tests/_images/Shapes_can_handle_nan_values_in_color_data.png new file mode 100644 index 00000000..46d9dc44 Binary files /dev/null and b/tests/_images/Shapes_can_handle_nan_values_in_color_data.png differ diff --git a/tests/_images/Shapes_colorbar_normalization_with_nan_values.png b/tests/_images/Shapes_colorbar_normalization_with_nan_values.png new file mode 100644 index 00000000..2e224d5b Binary files /dev/null and b/tests/_images/Shapes_colorbar_normalization_with_nan_values.png differ diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index baa2ce29..a585d4eb 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -360,7 +360,7 @@ def test_plot_respects_custom_colors_from_uns_with_groups_and_palette( ).pl.show() -def test_warns_when_table_does_not_annotate_element(sdata_blobs: SpatialData): +def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData): # Work on an independent copy since we mutate tables sdata_blobs_local = deepcopy(sdata_blobs) @@ -371,12 +371,13 @@ def test_warns_when_table_does_not_annotate_element(sdata_blobs: SpatialData): sdata_blobs_local["other_table"] = other_table # Rendering "blobs_labels" with a table that annotates "blobs_multiscale_labels" - # should raise a warning and fall back to using no table. - with pytest.warns(UserWarning, match="does not annotate element"): - ( - sdata_blobs_local.pl.render_labels( - "blobs_labels", - color="channel_0_sum", - table_name="other_table", - ).pl.show() - ) + # should now raise to alert the user about the mismatch. 
+ with pytest.raises( + KeyError, + match="Table 'other_table' does not annotate element 'blobs_labels'", + ): + sdata_blobs_local.pl.render_labels( + "blobs_labels", + color="channel_0_sum", + table_name="other_table", + ).pl.show() diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index 704a7d6b..34b63e94 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -190,6 +190,24 @@ def test_plot_datashader_can_color_by_category(self, sdata_blobs: SpatialData): method="datashader", ).pl.show() + def test_render_points_missing_color_column_raises_key_error(self, sdata_blobs: SpatialData) -> None: + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_points"] * sdata_blobs["table"].n_obs) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_points" + with pytest.raises(KeyError, match="does_not_exist"): + sdata_blobs.pl.render_points(element="blobs_points", color="does_not_exist") + + def test_render_points_missing_region_for_table_raises_key_error(self, sdata_blobs: SpatialData) -> None: + blob = deepcopy(sdata_blobs) + blob["table"].obs["region"] = pd.Categorical(["blobs_points"] * blob["table"].n_obs) + blob["table"].uns["spatialdata_attrs"]["region"] = "blobs_points" + blob["table"].obs["table_value"] = np.arange(blob["table"].n_obs) + other_table = blob["table"].copy() + other_table.obs["region"] = pd.Categorical(["other"] * other_table.n_obs) + other_table.uns["spatialdata_attrs"]["region"] = "other" + blob["other_table"] = other_table + with pytest.raises(KeyError, match="does not annotate element"): + blob.pl.render_points(element="blobs_points", color="table_value", table_name="other_table") + def test_plot_datashader_colors_from_table_obs(self, sdata_blobs: SpatialData): n_obs = len(sdata_blobs["blobs_points"]) obs = pd.DataFrame( @@ -505,7 +523,7 @@ def test_plot_can_annotate_points_with_table_layer(self, sdata_blobs: SpatialDat sdata_blobs.pl.render_points("blobs_points", color="feature0", size=10, table_layer="normalized").pl.show() -def test_warns_when_table_does_not_annotate_element(sdata_blobs: SpatialData): +def test_raises_when_table_does_not_annotate_element(sdata_blobs: SpatialData): # Work on an independent copy since we mutate tables sdata_blobs_local = deepcopy(sdata_blobs) @@ -516,15 +534,16 @@ def test_warns_when_table_does_not_annotate_element(sdata_blobs: SpatialData): sdata_blobs_local["other_table"] = other_table # Rendering "blobs_points" with a table that annotates "blobs_labels" - # should raise a warning and fall back to using no table. - with pytest.warns(UserWarning, match="does not annotate element"): - ( - sdata_blobs_local.pl.render_points( - "blobs_points", - color="channel_0_sum", - table_name="other_table", - ).pl.show() - ) + # should now raise to alert the user about the mismatch. 
+    with pytest.raises(
+        KeyError,
+        match="Table 'other_table' does not annotate element 'blobs_points'",
+    ):
+        sdata_blobs_local.pl.render_points(
+            "blobs_points",
+            color="channel_0_sum",
+            table_name="other_table",
+        ).pl.show()
 
 
 def test_datashader_colors_points_from_table_obs(sdata_blobs: SpatialData):
diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py
index f4bd2308..48d8dd2d 100644
--- a/tests/pl/test_render_shapes.py
+++ b/tests/pl/test_render_shapes.py
@@ -17,6 +17,7 @@ from spatialdata.transformations._utils import _set_transformations
 import spatialdata_plot  # noqa: F401
+from spatialdata_plot._logging import logger, logger_warns
 from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over, get_standard_RNG
 
 sc.pl.set_rcParams_defaults()
@@ -194,6 +195,37 @@ def test_plot_colorbar_can_be_normalised(self, sdata_blobs: SpatialData):
         norm = Normalize(vmin=0, vmax=5, clip=True)
         sdata_blobs.pl.render_shapes("blobs_polygons", color="cluster", groups=["c1"], norm=norm).pl.show()
 
+    def test_render_shapes_raises_when_color_key_missing(self, sdata_blobs_shapes_annotated: SpatialData):
+        missing_col = "__non_existent_column__"
+        with pytest.raises(KeyError, match=f"Unable to locate color key '{missing_col}'"):
+            sdata_blobs_shapes_annotated.pl.render_shapes(
+                element="blobs_polygons",
+                color=missing_col,
+            ).pl.show()
+
+    def test_render_shapes_raises_for_invalid_table_name(self, sdata_blobs_shapes_annotated: SpatialData):
+        table = sdata_blobs_shapes_annotated["table"]
+        table.obs["region"] = pd.Categorical(["blobs_polygons"] * table.n_obs)
+        table.uns["spatialdata_attrs"]["region"] = "blobs_polygons"
+        table.obs["valid_col"] = np.arange(table.n_obs)
+
+        with pytest.raises(KeyError, match="Table 'not_a_table' does not annotate element 'blobs_polygons'"):
+            sdata_blobs_shapes_annotated.pl.render_shapes(
+                element="blobs_polygons", color="valid_col", table_name="not_a_table"
+            )
+
+    def test_render_shapes_raises_for_missing_column_in_table(self, sdata_blobs_shapes_annotated: SpatialData):
+        table = sdata_blobs_shapes_annotated["table"]
+        table.obs["region"] = pd.Categorical(["blobs_polygons"] * table.n_obs)
+        table.uns["spatialdata_attrs"]["region"] = "blobs_polygons"
+
+        with pytest.raises(
+            KeyError, match="Column 'not_a_column' not found in obs/var of table 'table' for element 'blobs_polygons'"
+        ):
+            sdata_blobs_shapes_annotated.pl.render_shapes(
+                element="blobs_polygons", color="not_a_column", table_name="table"
+            )
+
     def test_plot_can_plot_shapes_after_spatial_query(self, sdata_blobs: SpatialData):
         # subset to only shapes, should be unnecessary after rasterizeation of multiscale images is included
         blob = SpatialData.init_from_elements(
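The new tests above pin down the stricter contract introduced by this change: an unknown color key, a table that does not annotate the element, or a column missing from the named table now raises a `KeyError` instead of silently falling back to `na_color`. As a rough user-facing illustration (not part of the patch; the dataset, element, and column names below are examples only), code that previously relied on the silent fallback would now handle the error explicitly:

```python
import spatialdata_plot  # noqa: F401  # registers the .pl accessor
from spatialdata.datasets import blobs  # small demo dataset shipped with spatialdata

sdata = blobs()
try:
    # "not_a_column" is deliberately missing, mirroring the tests above.
    sdata.pl.render_shapes(element="blobs_polygons", color="not_a_column", table_name="table").pl.show()
except KeyError as err:
    # With this change the mismatch surfaces immediately instead of plotting everything in na_color.
    print(f"Plotting aborted: {err}")
```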
sdata_blobs["table"].obs["category"] = categories - sdata_blobs["table"].uns["category_colors"] = ["red", "green", "blue"] # purple, green, yellow + sdata_blobs["table"].uns["category_colors"] = ["red", "green", "blue"] sdata_blobs.pl.render_shapes(shapes_name, color="category", table_name="table").pl.show() @@ -717,39 +749,61 @@ def test_plot_datashader_can_render_shapes_with_colored_double_outline(self, sda method="datashader", ).pl.show() - -def test_warns_when_table_does_not_annotate_element(sdata_blobs: SpatialData): - # Work on an independent copy since we mutate tables - sdata_blobs_local = deepcopy(sdata_blobs) - - # Create a table that annotates a DIFFERENT element than the one we will render - other_table = sdata_blobs_local["table"].copy() - other_table.obs["region"] = pd.Categorical(["blobs_points"] * other_table.n_obs) # Different region - other_table.uns["spatialdata_attrs"]["region"] = "blobs_points" - sdata_blobs_local["other_table"] = other_table - - # Rendering "blobs_circles" with a table that annotates "blobs_points" - # should raise a warning and fall back to using no table. - with pytest.warns(UserWarning, match="does not annotate element"): - ( + def test_raises_when_table_does_not_annotate_element(self, sdata_blobs: SpatialData): + # Work on an independent copy since we mutate tables + sdata_blobs_local = deepcopy(sdata_blobs) + + # Create a table that annotates a DIFFERENT element than the one we will render + other_table = sdata_blobs_local["table"].copy() + other_table.obs["region"] = pd.Categorical(["blobs_points"] * other_table.n_obs) # Different region + other_table.uns["spatialdata_attrs"]["region"] = "blobs_points" + sdata_blobs_local["other_table"] = other_table + + # Rendering "blobs_circles" with a table that annotates "blobs_points" + # should now raise to alert the user about the mismatch. 
+ with pytest.raises( + KeyError, + match="Table 'other_table' does not annotate element 'blobs_circles'", + ): sdata_blobs_local.pl.render_shapes( "blobs_circles", color="channel_0_sum", table_name="other_table", ).pl.show() - ) + + def test_raises_when_element_has_no_annotating_tables(self, sdata_blobs: SpatialData): + """Test that rendering an element with no annotating tables raises a clear error.""" + # Work on an independent copy since we mutate tables + sdata_blobs_local = deepcopy(sdata_blobs) + + # Change the region to something else so it no longer annotates "blobs_circles" + table = sdata_blobs_local["table"].copy() + table.obs["region"] = pd.Categorical(["blobs_points"] * table.n_obs) + table.uns["spatialdata_attrs"]["region"] = "blobs_points" + sdata_blobs_local["table"] = table + + # Now "blobs_circles" should have no annotating tables + # Trying to render it with a color column should raise an error + with pytest.raises( + KeyError, + match="Element 'blobs_circles' has no annotating tables", + ): + sdata_blobs_local.pl.render_shapes( + "blobs_circles", + color="channel_0_sum", + ).pl.show() -def test_plot_can_handle_nan_values_in_color_data(sdata_blobs: SpatialData): - """Test that NaN values in color data are handled gracefully.""" +def test_plot_can_handle_nan_values_in_color_data(sdata_blobs: SpatialData, caplog): + """Test that NaN values in color data are handled gracefully and logged.""" sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles" # Add color column with NaN values sdata_blobs.shapes["blobs_circles"]["color_with_nan"] = [1.0, 2.0, np.nan, 4.0, 5.0] - # Test that rendering works with NaN values and issues warning - with pytest.warns(UserWarning, match="Found 1 NaN values in color data"): + # Expect a logger warning about NaN values + with logger_warns(caplog, logger, match="Found 1 NaN values in color data"): sdata_blobs.pl.render_shapes(element="blobs_circles", color="color_with_nan", na_color="red").pl.show() @@ -760,7 +814,7 @@ def test_plot_colorbar_normalization_with_nan_values(sdata_blobs: SpatialData): sdata_blobs.shapes["blobs_polygons"]["color_with_nan"] = [1.0, 2.0, np.nan, 4.0, 5.0] - # Test colorbar with NaN values - should use nanmin/nanmax + # Test colorbar with NaN values - should use nanmin/nanmax under the hood and not crash sdata_blobs.pl.render_shapes(element="blobs_polygons", color="color_with_nan", na_color="gray").pl.show() @@ -772,10 +826,12 @@ def test_plot_can_handle_non_numeric_radius_values(sdata_blobs: SpatialData): def test_plot_can_handle_mixed_numeric_and_color_data(sdata_blobs: SpatialData): - """Test handling of mixed numeric and color-like data.""" + """Test that mixed numeric and color-like data raises a clear error.""" sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs) sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles" sdata_blobs.shapes["blobs_circles"]["mixed_data"] = [1.0, 2.0, np.nan, "red", 5.0] - sdata_blobs.pl.render_shapes(element="blobs_circles", color="mixed_data", na_color="gray").pl.show() + # Mixed numeric / non-numeric values should raise a TypeError + with pytest.raises(TypeError, match="contains both numeric and non-numeric values"): + sdata_blobs.pl.render_shapes(element="blobs_circles", color="mixed_data", na_color="gray").pl.show()
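Since several warnings were migrated from `warnings.warn` to `logger.warning`, tests now assert on log records through the new `logger_warns` helper rather than `pytest.warns`. A minimal usage sketch follows (illustrative only; the test name and logged message are made up, but the call matches the signature of the helper added in `_logging.py` above):

```python
import logging

import pytest

from spatialdata_plot._logging import logger, logger_warns


def test_emits_expected_log(caplog: pytest.LogCaptureFixture) -> None:
    # logger_warns passes only if, while the block runs, a record at `level` or above
    # whose message matches the regex is captured; otherwise it raises AssertionError.
    with logger_warns(caplog, logger, match="example warning", level=logging.WARNING):
        logger.warning("this is an example warning emitted during plotting")
```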