diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py
index 5cd6688e..eec55481 100644
--- a/src/spatialdata_plot/pl/render.py
+++ b/src/spatialdata_plot/pl/render.py
@@ -58,6 +58,7 @@
 )
 from spatialdata_plot.pl.utils import (
     _ax_show_and_transform,
+    _check_obs_var_shadow,
     _convert_shapes,
     _datashader_canvas_from_dataframe,
     _decorate_axs,
@@ -370,6 +371,8 @@ def _render_shapes(
     groups = render_params.groups
     table_layer = render_params.table_layer
 
+    _check_obs_var_shadow(sdata, element, col_for_color, render_params.table_name)
+
     sdata_filt = sdata.filter_by_coordinate_system(
         coordinate_system=coordinate_system,
         filter_tables=bool(render_params.table_name),
@@ -766,6 +769,8 @@ def _render_points(
     groups = render_params.groups
     palette = render_params.palette
 
+    _check_obs_var_shadow(sdata, element, col_for_color, table_name)
+
     if isinstance(groups, str):
         groups = [groups]
 
@@ -1687,6 +1692,8 @@ def _render_labels(
     groups = render_params.groups
     scale = render_params.scale
 
+    _check_obs_var_shadow(sdata, element, col_for_color, table_name)
+
     sdata_filt = sdata.filter_by_coordinate_system(
         coordinate_system=coordinate_system,
         filter_tables=bool(table_name),
diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py
index 35844c6a..c6ae3350 100644
--- a/src/spatialdata_plot/pl/utils.py
+++ b/src/spatialdata_plot/pl/utils.py
@@ -104,6 +104,46 @@
 }
 
 
+def _check_obs_var_shadow(
+    sdata: SpatialData | None,
+    element_name: str | None,
+    value_to_plot: str | None,
+    table_name: str | None,
+) -> None:
+    """Reject a color key that is both an ``obs`` column and a ``var`` name of the table.
+
+    Raises ``ValueError`` when ``value_to_plot`` is found in ``table.obs.columns``
+    as well as in ``table.var_names``. Upstream ``_get_table_origins`` uses an
+    ``elif`` chain, so such a key is silently resolved to ``obs``, masking the
+    user's likely intent of plotting gene expression; failing here surfaces the
+    clash before any value fetch.
+
+    Nothing is checked (the function just returns) when any argument is ``None``,
+    when ``table_name`` is not in ``sdata.tables``, or when ``table_name`` is not
+    among ``get_element_annotators(sdata, element_name)``.
+    """
+    if sdata is None or element_name is None:
+        return
+    if value_to_plot is None or table_name is None:
+        return
+    if table_name not in sdata.tables:
+        return
+    annotating_tables = get_element_annotators(sdata, element_name)
+    if table_name not in annotating_tables:
+        return
+
+    table = sdata.tables[table_name]
+    if value_to_plot not in table.obs.columns or value_to_plot not in table.var_names:
+        return
+
+    message = (
+        f"`color={value_to_plot!r}` is ambiguous: it exists in both "
+        f"`table[{table_name!r}].obs.columns` and `table[{table_name!r}].var_names`. "
+        "Rename one of them (or drop the obs column) so the intended source is unambiguous."
+    )
+    raise ValueError(message)
+
+
 def _gate_palette_and_groups(
     element_params: dict[str, Any],
     param_dict: dict[str, Any],
diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py
index 0a277f05..ab279d45 100644
--- a/tests/pl/test_utils.py
+++ b/tests/pl/test_utils.py
@@ -422,6 +422,30 @@ def test_color_column_collision_on_annotating_table_raises():
         sdata.pl.render_shapes("s", color="#ffa500")
 
 
+def test_color_key_obs_var_shadow_raises():
+    # Regression test for #621: "GeneA" is an obs column and a var name of the
+    # same annotating table, so coloring by it must be rejected as ambiguous.
+    coords = pd.DataFrame({"x": [1.0, 2.0], "y": [1.0, 2.0]})
+    obs_frame = pd.DataFrame(
+        {"instance_id": [0, 1], "region": ["pts"] * 2, "GeneA": [0.9, 0.6]},
+        index=["0", "1"],
+    )
+    adata = AnnData(X=np.zeros((2, 1)), obs=obs_frame, var=pd.DataFrame(index=["GeneA"]))
+    annotating_table = TableModel.parse(
+        adata,
+        region=["pts"],
+        region_key="region",
+        instance_key="instance_id",
+    )
+    sdata = SpatialData(
+        points={"pts": PointsModel.parse(coords)},
+        tables={"t": annotating_table},
+    )
+
+    with pytest.raises(ValueError, match=r"'GeneA'.*ambiguous.*obs\.columns.*var_names"):
+        sdata.pl.render_points("pts", color="GeneA", table_name="t").pl.show()
+
+
 def test_explicit_table_name_honored_when_element_has_same_column():
     # regression test for #620: explicit table_name= must not be silently
     # discarded when the element has a same-named column with different values.