diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 085d9b31..236eaf37 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -60,6 +60,7 @@ from spatialdata_plot.pl.utils import ( _RENDER_CMD_TO_CS_FLAG, _draw_scalebar, + _expand_color_panels, _get_cs_contents, _get_elements_to_be_rendered, _get_valid_cs, @@ -169,7 +170,12 @@ def _copy( shapes=self._sdata.shapes if shapes is None else shapes, tables=self._sdata.tables if tables is None else tables, ) - sdata.plotting_tree = self._sdata.plotting_tree if hasattr(self._sdata, "plotting_tree") else OrderedDict() + # Shallow-copy the plotting tree so appending a render step to one chain does not mutate a + # sibling chain branched off the same object (the RenderParams values stay shared; `show()` + # deep-copies them before use). + sdata.plotting_tree = ( + OrderedDict(self._sdata.plotting_tree) if hasattr(self._sdata, "plotting_tree") else OrderedDict() + ) sdata._source_sdata = getattr(self._sdata, "_source_sdata", self._sdata) return sdata @@ -309,7 +315,7 @@ def annotate( def render_shapes( self, element: str | None = None, - color: ColorLike | None = None, + color: ColorLike | list[str] | None = None, *, fill_alpha: float | int | None = None, groups: list[str] | str | None = None, @@ -345,7 +351,7 @@ def render_shapes( element : str | None, optional The name of the shapes element to render. If `None`, all shapes elements in the `SpatialData` object will be used. - color : ColorLike | None, optional + color : ColorLike | list[str] | None, optional Can either be color-like (name of a color as string, e.g. "red", hex representation, e.g. "#000000" or "#000000ff", or an RGB(A) array as a tuple or list containing 3-4 floats within [0, 1]. If an alpha value is indicated, the value of `fill_alpha` takes precedence if given) or a string representing a key in @@ -353,6 +359,11 @@ def render_shapes( `element` is `None`, if possible the color will be broadcasted to all elements. For this, the table in which the color key is found must annotate the respective element (region must be set to the specific element). If the color column is found in multiple locations, please provide the table_name to be used for the elements. + A **list of column/key names** (e.g. ``["gene1", "gene2"]``) produces one panel per key, like + ``scanpy``'s ``color=[...]``. ``palette``/``cmap``/``norm``/``groups`` are applied to every panel, + each panel auto-scales independently, and ``show(ncols=...)`` controls the grid width. Multi-panel + color requires a single coordinate system and only one ``render_*`` call in the chain may pass a list + (other calls use a scalar color and are drawn into every panel as a shared background). fill_alpha : float | int | None, optional Alpha value for the fill of shapes. By default, it is set to 1.0 or, if a color is given that implies an alpha, that value is used for `fill_alpha`. If an alpha channel is present in a cmap passed by the user, @@ -441,68 +452,76 @@ def render_shapes( sd.SpatialData A copy of the SpatialData object with the rendering parameters stored in its plotting tree. """ - params_dict = _validate_shape_render_params( + panel_param_dicts = _expand_color_panels( self._sdata, - element=element, - fill_alpha=fill_alpha, - groups=groups, - palette=palette, - color=color, - na_color=na_color, - outline_alpha=outline_alpha, - outline_color=outline_color, - outline_width=outline_width, - cmap=cmap, - norm=norm, - scale=scale, - table_name=table_name, - table_layer=table_layer, - shape=shape, - method=method, - ds_reduction=datashader_reduction, - colorbar=colorbar, - colorbar_params=colorbar_params, - gene_symbols=gene_symbols, + color, + "render_shapes", + lambda color_value: _validate_shape_render_params( + self._sdata, + element=element, + fill_alpha=fill_alpha, + groups=groups, + palette=palette, + color=color_value, + na_color=na_color, + outline_alpha=outline_alpha, + outline_color=outline_color, + outline_width=outline_width, + cmap=cmap, + norm=norm, + scale=scale, + table_name=table_name, + table_layer=table_layer, + shape=shape, + method=method, + ds_reduction=datashader_reduction, + colorbar=colorbar, + colorbar_params=colorbar_params, + gene_symbols=gene_symbols, + ), ) sdata = self._copy() sdata = _verify_plotting_tree(sdata) + n_steps = len(sdata.plotting_tree.keys()) - for element, param_values in params_dict.items(): - final_outline_alpha, outline_params = _set_outline( - params_dict[element]["outline_alpha"], - params_dict[element]["outline_width"], - params_dict[element]["outline_color"], - ) - cmap_params = _prepare_cmap_norm( - cmap=cmap, - norm=norm, - na_color=params_dict[element]["na_color"], # type: ignore[arg-type] - ) - sdata.plotting_tree[f"{n_steps + 1}_render_shapes"] = ShapesRenderParams( - element=element, - color=param_values["color"], - col_for_color=param_values["col_for_color"], - col_for_outline_color=param_values["col_for_outline_color"], - outline_table_name=param_values["outline_table_name"], - groups=param_values["groups"], - scale=param_values["scale"], - outline_params=outline_params, - cmap_params=cmap_params, - palette=param_values["palette"], - outline_alpha=final_outline_alpha, - fill_alpha=param_values["fill_alpha"], - transfunc=transfunc, - table_name=param_values["table_name"], - table_layer=param_values["table_layer"], - shape=param_values["shape"], - zorder=n_steps, - method=param_values["method"], - ds_reduction=param_values["ds_reduction"], - colorbar=param_values["colorbar"], - colorbar_params=param_values["colorbar_params"], - ) - n_steps += 1 + for panel_key, params_dict in panel_param_dicts: + for element, param_values in params_dict.items(): + final_outline_alpha, outline_params = _set_outline( + params_dict[element]["outline_alpha"], + params_dict[element]["outline_width"], + params_dict[element]["outline_color"], + ) + cmap_params = _prepare_cmap_norm( + cmap=cmap, + norm=norm, + na_color=params_dict[element]["na_color"], # type: ignore[arg-type] + ) + sdata.plotting_tree[f"{n_steps + 1}_render_shapes"] = ShapesRenderParams( + element=element, + color=param_values["color"], + col_for_color=param_values["col_for_color"], + col_for_outline_color=param_values["col_for_outline_color"], + outline_table_name=param_values["outline_table_name"], + groups=param_values["groups"], + scale=param_values["scale"], + outline_params=outline_params, + cmap_params=cmap_params, + palette=param_values["palette"], + outline_alpha=final_outline_alpha, + fill_alpha=param_values["fill_alpha"], + transfunc=transfunc, + table_name=param_values["table_name"], + table_layer=param_values["table_layer"], + shape=param_values["shape"], + zorder=n_steps, + method=param_values["method"], + ds_reduction=param_values["ds_reduction"], + colorbar=param_values["colorbar"], + colorbar_params=param_values["colorbar_params"], + panel_key=panel_key, + ) + n_steps += 1 return sdata @@ -920,7 +939,7 @@ def render_images( def render_labels( self, element: str | None = None, - color: ColorLike | None = None, + color: ColorLike | list[str] | None = None, *, groups: list[str] | str | None = None, contour_px: int | None = 3, @@ -953,13 +972,18 @@ def render_labels( element : str | None The name of the labels element to render. If `None`, all label elements in the `SpatialData` object will be used and all parameters will be broadcasted if possible. - color : ColorLike | None + color : ColorLike | list[str] | None Can either be color-like (name of a color as string, e.g. "red", hex representation, e.g. "#000000" or "#000000ff", or an RGB(A) array as a tuple or list containing 3-4 floats within [0, 1]. If an alpha value is indicated, the value of `fill_alpha` takes precedence if given) or a string representing a key in :attr:`sdata.table.obs` or in the index of :attr:`sdata.table.var`. The latter can be used to color by categorical or continuous variables. If the color column is found in multiple locations, please provide the table_name to be used for the element if you would like a specific table to be used. + A **list of column/key names** (e.g. ``["gene1", "gene2"]``) produces one panel per key, like + ``scanpy``'s ``color=[...]``. ``palette``/``cmap``/``norm``/``groups`` are applied to every panel, + each panel auto-scales independently, and ``show(ncols=...)`` controls the grid width. Multi-panel + color requires a single coordinate system and only one ``render_*`` call in the chain may pass a list + (other calls use a scalar color and are drawn into every panel as a shared background). groups : list[str] | str | None When using `color` and the key represents discrete labels, `groups` can be used to show only a subset of them. By default, non-matching labels are hidden. To show non-matching labels, set ``na_color`` explicitly. @@ -1024,59 +1048,66 @@ def render_labels( sd.SpatialData A copy of the SpatialData object with the rendering parameters stored in its plotting tree. """ - params_dict = _validate_label_render_params( + panel_param_dicts = _expand_color_panels( self._sdata, - element=element, - cmap=cmap, - color=color, - contour_px=contour_px, - fill_alpha=fill_alpha, - groups=groups, - na_color=na_color, - norm=norm, - outline_alpha=outline_alpha, - outline_color=outline_color, - palette=palette, - scale=scale, - colorbar=colorbar, - colorbar_params=colorbar_params, - table_name=table_name, - table_layer=table_layer, - gene_symbols=gene_symbols, + color, + "render_labels", + lambda color_value: _validate_label_render_params( + self._sdata, + element=element, + cmap=cmap, + color=color_value, + contour_px=contour_px, + fill_alpha=fill_alpha, + groups=groups, + na_color=na_color, + norm=norm, + outline_alpha=outline_alpha, + outline_color=outline_color, + palette=palette, + scale=scale, + colorbar=colorbar, + colorbar_params=colorbar_params, + table_name=table_name, + table_layer=table_layer, + gene_symbols=gene_symbols, + ), ) sdata = self._copy() sdata = _verify_plotting_tree(sdata) n_steps = len(sdata.plotting_tree.keys()) - for element, param_values in params_dict.items(): - cmap_params = _prepare_cmap_norm( - cmap=cmap, - norm=norm, - na_color=param_values["na_color"], # type: ignore[arg-type] - ) - sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams( - element=element, - color=param_values["color"], - col_for_color=param_values["col_for_color"], - col_for_outline_color=param_values["col_for_outline_color"], - outline_table_name=param_values["outline_table_name"], - groups=param_values["groups"], - contour_px=param_values["contour_px"], - cmap_params=cmap_params, - palette=param_values["palette"], - outline_alpha=param_values["outline_alpha"], - outline_color=param_values["outline_color"], - fill_alpha=param_values["fill_alpha"], - scale=param_values["scale"], - table_name=param_values["table_name"], - table_layer=param_values["table_layer"], - transfunc=transfunc, - zorder=n_steps, - colorbar=param_values["colorbar"], - colorbar_params=param_values["colorbar_params"], - ) - n_steps += 1 + for panel_key, params_dict in panel_param_dicts: + for element, param_values in params_dict.items(): + cmap_params = _prepare_cmap_norm( + cmap=cmap, + norm=norm, + na_color=param_values["na_color"], # type: ignore[arg-type] + ) + sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams( + element=element, + color=param_values["color"], + col_for_color=param_values["col_for_color"], + col_for_outline_color=param_values["col_for_outline_color"], + outline_table_name=param_values["outline_table_name"], + groups=param_values["groups"], + contour_px=param_values["contour_px"], + cmap_params=cmap_params, + palette=param_values["palette"], + outline_alpha=param_values["outline_alpha"], + outline_color=param_values["outline_color"], + fill_alpha=param_values["fill_alpha"], + scale=param_values["scale"], + table_name=param_values["table_name"], + table_layer=param_values["table_layer"], + transfunc=transfunc, + zorder=n_steps, + colorbar=param_values["colorbar"], + colorbar_params=param_values["colorbar_params"], + panel_key=panel_key, + ) + n_steps += 1 return sdata def render_graph( @@ -1284,7 +1315,8 @@ def show( hspace : float, default 0.25 Vertical spacing between panels (passed to :class:`matplotlib.gridspec.GridSpec`). ncols : int, default 4 - Number of columns in the multi-panel grid. + Number of columns in the multi-panel grid. Panels are created one per coordinate system, or, + when a ``render_*`` call was given a list of color keys, one per key (scanpy-style ``color=[...]``). frameon : bool | None Whether to draw the axes frame. If ``None``, the frame is hidden automatically for multi-panel plots. figsize : tuple[float, float] | None @@ -1297,7 +1329,8 @@ def show( Pass axes created from your figure via ``ax`` instead. title : list[str] | str | None Title(s) for the plot. A single string is applied to all panels; a list must match the number - of coordinate systems. If ``None``, each panel is titled with its coordinate system name. + of panels. If ``None``, each panel is titled with its coordinate system name, or, in multi-panel + color mode, with its color key. pad_extent : int | float, default 0 Padding added around the computed spatial extent on all sides. ax : list[Axes] | Axes | None @@ -1457,9 +1490,9 @@ def show( # When CS was auto-detected and ax is provided, keep only CS that have # element types for ALL render commands (workaround for upstream #176). - if ax is not None: + if ax is not None and cs_was_auto: n_ax = 1 if isinstance(ax, Axes) else len(ax) - if cs_was_auto and len(coordinate_systems) > n_ax: + if len(coordinate_systems) > n_ax: required_flags = [_RENDER_CMD_TO_CS_FLAG[cmd] for cmd in cmds if cmd in _RENDER_CMD_TO_CS_FLAG] strict_cs = [ cs_name @@ -1469,10 +1502,32 @@ def show( if strict_cs: coordinate_systems = strict_cs - if len(coordinate_systems) != n_ax: + # Determine the panel layout. Panels are normally one per coordinate system, but when a + # render_* call passed a list of color keys we instead lay out one panel per key within a + # single coordinate system (scanpy-style `color=[...]`). Render entries tagged with a + # `panel_key` belong to that key's panel; untagged entries are shared across all panels. + panel_keys: list[str] = [] + for _cmd, _params in render_cmds: + pkey = getattr(_params, "panel_key", None) + if pkey is not None and pkey not in panel_keys: + panel_keys.append(pkey) + if panel_keys: + if len(coordinate_systems) != 1: + raise ValueError( + "A list of color keys (multi-panel plotting) requires exactly one coordinate system, " + f"but {len(coordinate_systems)} were selected: {coordinate_systems}. " + "Pass `coordinate_systems=` to choose a single one." + ) + panels: list[tuple[str, str | None]] = [(coordinate_systems[0], key) for key in panel_keys] + else: + panels = [(cs, None) for cs in coordinate_systems] + num_panels = len(panels) + + if ax is not None: + n_ax = 1 if isinstance(ax, Axes) else len(ax) + if num_panels != n_ax: msg = ( - f"Mismatch between number of matplotlib axes objects ({n_ax}) " - f"and number of coordinate systems ({len(coordinate_systems)})." + f"Mismatch between number of matplotlib axes objects ({n_ax}) and number of panels ({num_panels})." ) if cs_was_auto: msg += ( @@ -1484,7 +1539,7 @@ def show( # set up canvas fig_params, scalebar_params_obj = _prepare_params_plot( - num_panels=len(coordinate_systems), + num_panels=num_panels, figsize=figsize, dpi=dpi, fig=fig, @@ -1611,7 +1666,7 @@ def _draw_colorbar( # go through tree - for i, cs in enumerate(coordinate_systems): + for i, (cs, panel_key) in enumerate(panels): sdata = self._copy() cs_row = cs_index.loc[cs] has_images = cs_row["has_images"] @@ -1630,6 +1685,11 @@ def _draw_colorbar( wanted_elements: list[str] = [] for cmd, params in render_cmds: + # Skip render entries that belong to a different color panel. Entries with no + # `panel_key` (None) are shared and drawn into every panel (e.g. a background image). + cmd_panel_key = getattr(params, "panel_key", None) + if panel_key is not None and cmd_panel_key is not None and cmd_panel_key != panel_key: + continue # We create a copy here as the wanted elements can change from one cs to another. params_copy = deepcopy(params) if cmd == "render_images" and has_images: @@ -1739,14 +1799,14 @@ def _draw_colorbar( ) if title is None: - t = cs + t = panel_key if panel_key is not None else cs elif len(title) == 1: t = title[0] else: try: t = title[i] except IndexError as e: - raise IndexError("The number of titles must match the number of coordinate systems.") from e + raise IndexError("The number of titles must match the number of panels.") from e ax.set_title(t) ax.set_aspect("equal") if fig_params.frameon is False: diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index a8a7c928..acf03e3a 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -255,6 +255,9 @@ class ShapesRenderParams: ds_reduction: _DsReduction | None = None colorbar: bool | str | None = "auto" colorbar_params: dict[str, object] | None = None + # Multi-panel color: when set, this render entry belongs to the panel identified by this + # color key. ``None`` means the entry is shared across every panel (e.g. a background layer). + panel_key: str | None = None @dataclass @@ -324,6 +327,9 @@ class LabelsRenderParams: zorder: int = 0 colorbar: bool | str | None = "auto" colorbar_params: dict[str, object] | None = None + # Multi-panel color: when set, this render entry belongs to the panel identified by this + # color key. ``None`` means the entry is shared across every panel (e.g. a background layer). + panel_key: str | None = None @dataclass diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 37bd795d..dd56addf 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -3031,6 +3031,78 @@ def _ensure_table_and_layer_exist_in_sdata( return param_dict +def _resolve_color_panels(color: Any) -> tuple[Any, list[str] | None]: + """Split a ``color`` argument into a scalar color and an optional multi-panel key list. + + Returns ``(scalar_color, panel_keys)``. When ``panel_keys`` is ``None`` the call is a + normal single-color render and ``scalar_color`` is the (unchanged) color to use. When + ``panel_keys`` is a list, the render must be expanded into one panel per key. + + A list of all-strings is treated as multi-panel keys; a length-1 list normalizes to a + scalar color; an all-numeric list stays a single RGB(A) color. Empty, duplicate, or + mixed str/number lists raise ``ValueError``. + """ + if not isinstance(color, list): + return color, None + if all(isinstance(c, str) for c in color): + if len(color) == 0: + raise ValueError("`color` was given an empty list; provide at least one column/key name.") + if len(color) != len(set(color)): + dups = sorted({c for c in color if color.count(c) > 1}) + raise ValueError(f"`color` contains duplicate keys {dups}; each multi-panel key must be unique.") + if len(color) == 1: + return color[0], None + return None, list(color) + if any(isinstance(c, str) for c in color): + raise ValueError( + "`color` list must be either all column/key names (str) for a multi-panel plot, " + "or 3-4 floats for a single RGB(A) color, not a mix of both." + ) + return color, None + + +def _expand_color_panels( + sdata: SpatialData, + color: Any, + render_fn_name: str, + validate: Callable[[Any], dict[str, Any]], +) -> list[tuple[str | None, dict[str, Any]]]: + """Resolve ``color`` into validated per-panel render params for the multi-panel ``color=[...]`` feature. + + ``validate`` is a callback that runs the render function's own parameter validation for a single + color value and returns its per-element ``params_dict``. Returns a list of ``(panel_key, params_dict)`` + pairs: a single ``(None, params_dict)`` for the scalar case, or one entry per key for a key list. + + Enforces that only one ``render_*`` call per figure may pass a color list, and aggregates per-key + validation errors into a single message. Used by ``render_shapes`` and ``render_labels``. + """ + color, panel_keys = _resolve_color_panels(color) + if panel_keys is not None and any( + getattr(params, "panel_key", None) is not None for params in getattr(sdata, "plotting_tree", {}).values() + ): + raise ValueError( + "Only one `render_*` call may use a list of color keys per figure. Other chained render " + "calls must use a single (scalar) color; they are drawn into every panel as a shared layer." + ) + + color_specs = [(None, color)] if panel_keys is None else [(key, key) for key in panel_keys] + panel_param_dicts: list[tuple[str | None, dict[str, Any]]] = [] + key_errors: dict[str, str] = {} + for panel_key, color_value in color_specs: + try: + params_dict = validate(color_value) + except (KeyError, ValueError) as e: + if panel_keys is None: + raise + key_errors[panel_key] = str(e) # type: ignore[index] + continue + panel_param_dicts.append((panel_key, params_dict)) + if key_errors: + details = "\n".join(f" - {key!r}: {msg}" for key, msg in key_errors.items()) + raise ValueError(f"Invalid color key(s) for multi-panel `{render_fn_name}`:\n{details}") + return panel_param_dicts + + def _validate_label_render_params( sdata: sd.SpatialData, element: str | None, diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index 2024bd24..1babf756 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -595,3 +595,13 @@ def test_labels_outline_color_groups_filter_aligns(sdata_blobs: SpatialData): outline_color="stage", ).pl.show(ax=ax) plt.close(fig) + + +def test_render_labels_color_list_creates_one_panel_per_key(sdata_blobs: SpatialData): + """A list of color keys produces one panel per key, titled by the key (#611).""" + # the default blobs table annotates blobs_labels with channel_*_sum vars + axs = sdata_blobs.pl.render_labels("blobs_labels", color=["channel_0_sum", "channel_1_sum"]).pl.show(return_ax=True) + assert isinstance(axs, list) + assert len(axs) == 2 + assert [ax.get_title() for ax in axs] == ["channel_0_sum", "channel_1_sum"] + plt.close("all") diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 1254bbcb..f51cbf36 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -1590,3 +1590,104 @@ def test_outline_color_cross_table(sdata_blobs: SpatialData): outline_color="stage", ).pl.show(ax=ax) plt.close(fig) + + +def _add_score_columns(sdata: SpatialData, element: str = "blobs_circles") -> SpatialData: + """Add three continuous columns to a shapes element so it can be colored by key.""" + gdf = sdata.shapes[element] + rng = get_standard_RNG() + for name in ("scoreA", "scoreB", "scoreC"): + gdf[name] = rng.random(len(gdf)) + return sdata + + +def test_render_shapes_color_list_panel_structure(sdata_blobs: SpatialData): + """A color list yields one panel per key, titled by the key and wrapped by ncols (#611).""" + _add_score_columns(sdata_blobs) + axs = sdata_blobs.pl.render_shapes("blobs_circles", color=["scoreA", "scoreB", "scoreC"]).pl.show( + ncols=2, return_ax=True + ) + assert [ax.get_title() for ax in axs] == ["scoreA", "scoreB", "scoreC"] + # 3 panels at ncols=2 -> a 2x2 grid + assert axs[0].get_subplotspec().get_gridspec().get_geometry() == (2, 2) + plt.close("all") + + +def test_render_shapes_color_list_per_panel_legends(sdata_blobs: SpatialData): + """Each categorical panel gets its own legend with only its own categories (#611).""" + gdf = sdata_blobs.shapes["blobs_circles"] # 5 circles + gdf["catA"] = pd.Categorical(["a", "b", "a", "b", "a"]) + gdf["catB"] = pd.Categorical(["x", "y", "z", "x", "y"]) + axs = sdata_blobs.pl.render_shapes("blobs_circles", color=["catA", "catB"]).pl.show(return_ax=True) + legends = [ax.get_legend() for ax in axs] + assert all(leg is not None for leg in legends) + assert {t.get_text() for t in legends[0].get_texts()} == {"a", "b"} + assert {t.get_text() for t in legends[1].get_texts()} == {"x", "y", "z"} + plt.close("all") + + +def test_render_shapes_color_list_shares_scalar_background(sdata_blobs: SpatialData): + """A scalar-colored render call is drawn into every color panel as a shared background (#611).""" + _add_score_columns(sdata_blobs) + axs = ( + sdata_blobs.pl.render_images("blobs_image") + .pl.render_shapes("blobs_circles", color=["scoreA", "scoreB"]) + .pl.show(return_ax=True) + ) + assert len(axs) == 2 + assert all(len(ax.get_images()) >= 1 for ax in axs) + plt.close("all") + + +@pytest.mark.parametrize("color", [["scoreA"], [1.0, 0.0, 0.0]], ids=["length-1-key-list", "rgb-float-list"]) +def test_render_shapes_color_scalar_forms_stay_single_panel(sdata_blobs: SpatialData, color): + """A length-1 key list normalizes to a scalar, and an RGB(A) float list stays one color (#611).""" + _add_score_columns(sdata_blobs) + ax = sdata_blobs.pl.render_shapes("blobs_circles", color=color).pl.show(return_ax=True) + assert isinstance(ax, plt.Axes) + plt.close("all") + + +@pytest.mark.parametrize( + ("make_chain", "match"), + [ + (lambda s: s.pl.render_shapes("blobs_circles", color=[]), "empty list"), + (lambda s: s.pl.render_shapes("blobs_circles", color=["scoreA", "scoreA"]), "duplicate keys"), + (lambda s: s.pl.render_shapes("blobs_circles", color=["scoreA", 0.5]), "all column/key names"), + (lambda s: s.pl.render_shapes("blobs_circles", color=["scoreA", "nope"]), "Invalid color key"), + ( + lambda s: s.pl.render_shapes("blobs_circles", color=["scoreA", "scoreB"]).pl.render_shapes( + "blobs_circles", color=["scoreA", "scoreC"] + ), + "Only one `render_\\*` call", + ), + ], + ids=["empty", "duplicate", "mixed", "bad-key", "two-lists"], +) +def test_render_shapes_color_list_invalid_raises(sdata_blobs: SpatialData, make_chain, match): + """All multi-panel color misuse raises a clear ValueError before any drawing (#611).""" + _add_score_columns(sdata_blobs) + with pytest.raises(ValueError, match=match): + make_chain(sdata_blobs) + + +def test_render_shapes_color_list_branches_are_independent(sdata_blobs: SpatialData): + """Branching a render chain must not leak color-panel entries across branches (#611). + + Appending a render step to one branch must not mutate the shared plotting tree of the + base object, otherwise a sibling branch wrongly trips the 'only one color list' guard. + """ + _add_score_columns(sdata_blobs) + base = sdata_blobs.pl.render_images("blobs_image") + base_steps = len(base.plotting_tree) + + branch1 = base.pl.render_shapes("blobs_circles", color=["scoreA", "scoreB"]) + # the base chain must be untouched by building branch1 + assert len(base.plotting_tree) == base_steps + + # an independent second branch off the same base must not raise the multi-list guard + axs = base.pl.render_shapes("blobs_circles", color=["scoreA", "scoreC"]).pl.show(return_ax=True) + assert len(axs) == 2 + # branch1 is still usable and unaffected + assert len(branch1.plotting_tree) == base_steps + 2 + plt.close("all")