diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index eb0fc300..07b44a44 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -1354,6 +1354,8 @@ def _render_images( ) _ax_show_and_transform(stacked, trans_data, ax, **show_kwargs) + if render_params.channels_as_legend: + logger.warning("channels_as_legend is not supported for true RGB images and will be ignored.") return # 1) Image has only 1 channel @@ -1386,7 +1388,13 @@ def _render_images( is_continuous=True, auto_condition=n_channels == 1, ) - if wants_colorbar and legend_params.colorbar and colorbar_requests is not None: + if render_params.channels_as_legend and channel_legend_entries is not None: + # Sample at 0.75 (upper quarter) for a vivid, non-extreme representative color; + # consistent with the multi-channel composite path below. + _collect_channel_legend_entries( + [channels[0]], [matplotlib.colors.to_hex(cmap(0.75))], channel_legend_entries + ) + elif wants_colorbar and legend_params.colorbar and colorbar_requests is not None: sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm) colorbar_requests.append( ColorbarSpec( diff --git a/tests/_images/ChannelsAsCategories_channels_as_legend_sequential_single_channels.png b/tests/_images/ChannelsAsCategories_channels_as_legend_sequential_single_channels.png new file mode 100644 index 00000000..2b6b25c1 Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_legend_sequential_single_channels.png differ diff --git a/tests/_images/ChannelsAsCategories_channels_as_legend_single_channel.png b/tests/_images/ChannelsAsCategories_channels_as_legend_single_channel.png new file mode 100644 index 00000000..27a4341b Binary files /dev/null and b/tests/_images/ChannelsAsCategories_channels_as_legend_single_channel.png differ diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py index 7bfef71e..4a6bbe9c 100644 --- a/tests/pl/test_render_images.py +++ b/tests/pl/test_render_images.py @@ -10,6 +10,7 @@ from spatialdata.models import Image2DModel import spatialdata_plot # noqa: F401 +from spatialdata_plot._logging import logger, logger_warns from spatialdata_plot.pl.render import _is_rgb_image from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over @@ -545,13 +546,49 @@ def test_plot_channels_as_legend_legend_lower_right(self, sdata_blobs: SpatialDa legend_loc="lower right" ) + def test_plot_channels_as_legend_single_channel(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_images(element="blobs_image", channel=0, channels_as_legend=True).pl.show() + + def test_plot_channels_as_legend_sequential_single_channels(self, sdata_blobs_str: SpatialData): + ( + sdata_blobs_str.pl.render_images( + element="blobs_image", + channel="c1", + palette=["cyan"], + alpha=0.5, + channels_as_legend=True, + ) + .pl.render_images( + element="blobs_image", + channel="c2", + palette=["magenta"], + alpha=0.5, + channels_as_legend=True, + ) + .pl.show() + ) + class TestChannelsAsCategoriesNonVisual: """Non-visual tests for channels_as_legend edge cases.""" - def test_channels_as_legend_ignored_for_single_channel(self, sdata_blobs: SpatialData): + def test_channels_as_legend_single_channel_shows_legend_no_colorbar(self, sdata_blobs: SpatialData): fig, ax = plt.subplots() sdata_blobs.pl.render_images(element="blobs_image", channel=0, channels_as_legend=True).pl.show(ax=ax) + legend = ax.get_legend() + assert legend is not None + assert "0" in [t.get_text() for t in legend.get_texts()] + assert len(fig.axes) == 1 # no colorbar inset axes + plt.close("all") + + def test_channels_as_legend_rgb_warns_and_no_legend(self, caplog): + data = np.zeros((3, 50, 50), dtype=np.float64) + data[0], data[1], data[2] = 0.8, 0.2, 0.1 + img = Image2DModel.parse(data, dims=("c", "y", "x"), c_coords=["r", "g", "b"]) + sdata = SpatialData(images={"img": img}) + fig, ax = plt.subplots() + with logger_warns(caplog, logger, match="not supported for true RGB"): + sdata.pl.render_images("img", channels_as_legend=True).pl.show(ax=ax) assert ax.get_legend() is None plt.close("all")