diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 9448dbb8..1e41ac5d 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -854,6 +854,7 @@ def render_labels( scale=param_values["scale"], table_name=param_values["table_name"], table_layer=param_values["table_layer"], + transfunc=kwargs.get("transfunc"), zorder=n_steps, colorbar=param_values["colorbar"], colorbar_params=param_values["colorbar_params"], diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 6574730d..852e7a86 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -402,7 +402,7 @@ def _render_shapes( sdata_filt[element] = shapes # color_source_vector is None when the values aren't categorical - if values_are_categorical and render_params.transfunc is not None: + if not values_are_categorical and render_params.transfunc is not None: color_vector = render_params.transfunc(color_vector) norm = copy(render_params.cmap_params.norm) @@ -1702,6 +1702,10 @@ def _render_labels( if isinstance(color_vector.dtype, pd.CategoricalDtype): color_vector = color_vector.remove_unused_categories() + # color_source_vector is None when the values aren't categorical + if color_source_vector is None and render_params.transfunc is not None: + color_vector = render_params.transfunc(color_vector) + def _draw_labels( seg_erosionpx: int | None, seg_boundaries: bool, diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index 16f81578..e53c07e0 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -304,6 +304,7 @@ class LabelsRenderParams: scale: str | None = None table_name: str | None = None table_layer: str | None = None + transfunc: Callable[[float], float] | None = None zorder: int = 0 colorbar: bool | str | None = "auto" colorbar_params: dict[str, object] | None = None diff --git a/tests/_images/Labels_transfunc_applied_to_continuous_labels.png b/tests/_images/Labels_transfunc_applied_to_continuous_labels.png new file mode 100644 index 00000000..ca9cb22a Binary files /dev/null and b/tests/_images/Labels_transfunc_applied_to_continuous_labels.png differ diff --git a/tests/_images/Shapes_transfunc_applied_to_continuous_shapes.png b/tests/_images/Shapes_transfunc_applied_to_continuous_shapes.png new file mode 100644 index 00000000..d19711f7 Binary files /dev/null and b/tests/_images/Shapes_transfunc_applied_to_continuous_shapes.png differ diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index a19a3100..9a8057c1 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -295,6 +295,11 @@ def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData): cmap=_viridis_with_under_over(), ).pl.show() + def test_plot_transfunc_applied_to_continuous_labels(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", transfunc=lambda x: x * 100).pl.show( + title="transfunc: x * 100" + ) + def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialData): sdata_blobs["table"].layers["normalized"] = get_standard_RNG().random(sdata_blobs["table"].X.shape) sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show() @@ -454,3 +459,17 @@ def test_groups_warns_when_no_groups_match_labels(sdata_blobs: SpatialData, capl sdata_blobs.pl.render_labels( labels_name, color="cat", groups=["nonexistent"], table_name="label_table", na_color=None ).pl.show() + + +def test_transfunc_is_applied_for_continuous_labels(sdata_blobs: SpatialData): + called = [] + + def track(x): + called.append(True) + return x + + fig, ax = plt.subplots() + sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", transfunc=track).pl.show(ax=ax) + plt.close(fig) + + assert called, "transfunc was not called for continuous labels data" diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index 51d48cbc..db3df973 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -732,6 +732,11 @@ def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs_shapes_annotated element="blobs_polygons", color="value", norm=Normalize(2, 4, clip=False), cmap=_viridis_with_under_over() ).pl.show() + def test_plot_transfunc_applied_to_continuous_shapes(self, sdata_blobs_shapes_annotated: SpatialData): + sdata_blobs_shapes_annotated.pl.render_shapes( + element="blobs_polygons", color="value", transfunc=lambda x: x * 100 + ).pl.show(title="transfunc: x * 100") + def test_plot_datashader_can_color_with_norm_and_clipping(self, sdata_blobs_shapes_annotated: SpatialData): sdata_blobs_shapes_annotated.pl.render_shapes( element="blobs_polygons", @@ -1310,3 +1315,17 @@ def test_datashader_na_color_nan_overlay(sdata_blobs: SpatialData, na_color: str f"Expected {expected_images} image(s), got {len(ax.get_images())} for na_color={na_color!r}" ) plt.close(fig) + + +def test_transfunc_is_applied_for_continuous_shapes(sdata_blobs_shapes_annotated: SpatialData): + called = [] + + def track(x): + called.append(True) + return x + + fig, ax = plt.subplots() + sdata_blobs_shapes_annotated.pl.render_shapes("blobs_polygons", color="value", transfunc=track).pl.show(ax=ax) + plt.close(fig) + + assert called, "transfunc was not called for continuous shapes data"