Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 174 additions & 4 deletions src/spatialdata_plot/pl/basic.py

Large diffs are not rendered by default.

132 changes: 117 additions & 15 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from spatialdata_plot._logging import logger
from spatialdata_plot.pl.render_params import (
Color,
ColorbarSpec,
FigParams,
ImageRenderParams,
LabelsRenderParams,
Expand Down Expand Up @@ -61,6 +62,55 @@
_Normalize = Normalize | abc.Sequence[Normalize]


def _split_colorbar_params(params: dict[str, object] | None) -> tuple[dict[str, object], dict[str, object], str | None]:
"""Split colorbar params into layout hints, Matplotlib kwargs, and label override."""
layout: dict[str, object] = {}
cbar_kwargs: dict[str, object] = {}
label_override: str | None = None
for key, value in (params or {}).items():
key_lower = key.lower()
if key_lower in {"loc", "location"}:
layout["location"] = value
elif key_lower == "width" or key_lower == "fraction":
layout["fraction"] = value
elif key_lower == "pad":
layout["pad"] = value
elif key_lower == "label":
label_override = None if value is None else str(value)
else:
cbar_kwargs[key] = value
return layout, cbar_kwargs, label_override


def _resolve_colorbar_label(
colorbar_params: dict[str, object] | None, fallback: str | None, *, is_default_channel_name: bool = False
) -> str | None:
"""Pick a colorbar label from params or fall back to provided value."""
_, _, label = _split_colorbar_params(colorbar_params)
if label is not None:
return label
if is_default_channel_name:
return None
return fallback


def _should_request_colorbar(
colorbar: bool | str | None,
*,
has_mappable: bool,
is_continuous: bool,
auto_condition: bool = True,
) -> bool:
"""Resolve colorbar setting to a final boolean request."""
if not has_mappable or not is_continuous:
return False
if colorbar is True:
return True
if colorbar in {False, None}:
return False
return bool(auto_condition)


def _render_shapes(
sdata: sd.SpatialData,
render_params: ShapesRenderParams,
Expand All @@ -69,6 +119,7 @@ def _render_shapes(
fig_params: FigParams,
scalebar_params: ScalebarParams,
legend_params: LegendParams,
colorbar_requests: list[ColorbarSpec] | None = None,
) -> None:
element = render_params.element
col_for_color = render_params.col_for_color
Expand All @@ -80,7 +131,8 @@ def _render_shapes(
filter_tables=bool(render_params.table_name),
)

if (table_name := render_params.table_name) is None:
table_name = render_params.table_name
if table_name is None:
table = None
shapes = sdata_filt[element]
else:
Expand Down Expand Up @@ -159,16 +211,13 @@ def _render_shapes(
else:
palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()]))

if (
has_valid_color = (
len(set(color_vector)) != 1
or list(set(color_vector))[0] != render_params.cmap_params.na_color.get_hex_with_alpha()
):
)
if has_valid_color and color_source_vector is not None and col_for_color is not None:
# necessary in case different shapes elements are annotated with one table
if color_source_vector is not None and col_for_color is not None:
color_source_vector = color_source_vector.remove_unused_categories()

# False if user specified color-like with 'color' parameter
colorbar = False if col_for_color is None else legend_params.colorbar
color_source_vector = color_source_vector.remove_unused_categories()

# Apply the transformation to the PatchCollection's paths
trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)
Expand Down Expand Up @@ -515,8 +564,11 @@ def _render_shapes(
if color_source_vector is not None and render_params.col_for_color is not None:
color_source_vector = color_source_vector.remove_unused_categories()

# False if user specified color-like with 'color' parameter
colorbar = False if render_params.col_for_color is None else legend_params.colorbar
wants_colorbar = _should_request_colorbar(
render_params.colorbar,
has_mappable=cax is not None,
is_continuous=render_params.col_for_color is not None and color_source_vector is None,
)

_ = _decorate_axs(
ax=ax,
Expand All @@ -534,7 +586,13 @@ def _render_shapes(
legend_loc=legend_params.legend_loc,
legend_fontoutline=legend_params.legend_fontoutline,
na_in_legend=legend_params.na_in_legend,
colorbar=colorbar,
colorbar=wants_colorbar and legend_params.colorbar,
colorbar_params=render_params.colorbar_params,
colorbar_requests=colorbar_requests,
colorbar_label=_resolve_colorbar_label(
render_params.colorbar_params,
col_for_color if isinstance(col_for_color, str) else None,
),
scalebar_dx=scalebar_params.scalebar_dx,
scalebar_units=scalebar_params.scalebar_units,
)
Expand All @@ -548,6 +606,7 @@ def _render_points(
fig_params: FigParams,
scalebar_params: ScalebarParams,
legend_params: LegendParams,
colorbar_requests: list[ColorbarSpec] | None = None,
) -> None:
element = render_params.element
col_for_color = render_params.col_for_color
Expand Down Expand Up @@ -894,6 +953,12 @@ def _render_points(
else:
palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()]))

wants_colorbar = _should_request_colorbar(
render_params.colorbar,
has_mappable=cax is not None,
is_continuous=col_for_color is not None and color_source_vector is None,
)

_ = _decorate_axs(
ax=ax,
cax=cax,
Expand All @@ -910,7 +975,13 @@ def _render_points(
legend_loc=legend_params.legend_loc,
legend_fontoutline=legend_params.legend_fontoutline,
na_in_legend=legend_params.na_in_legend,
colorbar=legend_params.colorbar,
colorbar=wants_colorbar and legend_params.colorbar,
colorbar_params=render_params.colorbar_params,
colorbar_requests=colorbar_requests,
colorbar_label=_resolve_colorbar_label(
render_params.colorbar_params,
col_for_color if isinstance(col_for_color, str) else None,
),
scalebar_dx=scalebar_params.scalebar_dx,
scalebar_units=scalebar_params.scalebar_units,
)
Expand All @@ -925,6 +996,7 @@ def _render_images(
scalebar_params: ScalebarParams,
legend_params: LegendParams,
rasterize: bool,
colorbar_requests: list[ColorbarSpec] | None = None,
) -> None:
sdata_filt = sdata.filter_by_coordinate_system(
coordinate_system=coordinate_system,
Expand Down Expand Up @@ -1003,9 +1075,26 @@ def _render_images(
norm=render_params.cmap_params.norm,
)

if legend_params.colorbar:
wants_colorbar = _should_request_colorbar(
render_params.colorbar,
has_mappable=n_channels == 1,
is_continuous=True,
auto_condition=n_channels == 1,
)
if wants_colorbar and legend_params.colorbar and colorbar_requests is not None:
sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm)
fig_params.fig.colorbar(sm, ax=ax)
colorbar_requests.append(
ColorbarSpec(
ax=ax,
mappable=sm,
params=render_params.colorbar_params,
label=_resolve_colorbar_label(
render_params.colorbar_params,
str(channels[0]),
is_default_channel_name=isinstance(channels[0], (int, np.integer)),
),
)
)

# 2) Image has any number of channels but 1
else:
Expand Down Expand Up @@ -1165,6 +1254,7 @@ def _render_labels(
scalebar_params: ScalebarParams,
legend_params: LegendParams,
rasterize: bool,
colorbar_requests: list[ColorbarSpec] | None = None,
) -> None:
element = render_params.element
table_name = render_params.table_name
Expand Down Expand Up @@ -1310,6 +1400,12 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
else:
raise ValueError("Parameters 'fill_alpha' and 'outline_alpha' cannot both be 0.")

colorbar_requested = _should_request_colorbar(
render_params.colorbar,
has_mappable=cax is not None,
is_continuous=color is not None and color_source_vector is None and not categorical,
)

_ = _decorate_axs(
ax=ax,
cax=cax,
Expand All @@ -1326,7 +1422,13 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
legend_loc=legend_params.legend_loc,
legend_fontoutline=legend_params.legend_fontoutline,
na_in_legend=(legend_params.na_in_legend if groups is None else len(groups) == len(set(color_vector))),
colorbar=legend_params.colorbar,
colorbar=colorbar_requested and legend_params.colorbar,
colorbar_params=render_params.colorbar_params,
colorbar_requests=colorbar_requests,
colorbar_label=_resolve_colorbar_label(
render_params.colorbar_params,
color if isinstance(color, str) else None,
),
scalebar_dx=scalebar_params.scalebar_dx,
scalebar_units=scalebar_params.scalebar_units,
# scalebar_kwargs=scalebar_params.scalebar_kwargs,
Expand Down
25 changes: 25 additions & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Colormap, ListedColormap, Normalize, rgb2hex, to_hex
from matplotlib.figure import Figure

Expand Down Expand Up @@ -183,6 +184,22 @@ class LegendParams:
colorbar: bool = True


@dataclass
class ColorbarSpec:
"""Data required to create a colorbar."""

ax: Axes
mappable: ScalarMappable
params: dict[str, object] | None = None
label: str | None = None
alpha: float | None = None


CBAR_DEFAULT_LOCATION = "right"
CBAR_DEFAULT_FRACTION = 0.075
CBAR_DEFAULT_PAD = 0.015


@dataclass
class ScalebarParams:
"""Scalebar params."""
Expand Down Expand Up @@ -213,6 +230,8 @@ class ShapesRenderParams:
table_layer: str | None = None
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
colorbar: bool | str | None = "auto"
colorbar_params: dict[str, object] | None = None


@dataclass
Expand All @@ -233,6 +252,8 @@ class PointsRenderParams:
table_name: str | None = None
table_layer: str | None = None
ds_reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None = None
colorbar: bool | str | None = "auto"
colorbar_params: dict[str, object] | None = None


@dataclass
Expand All @@ -247,6 +268,8 @@ class ImageRenderParams:
percentiles_for_norm: tuple[float | None, float | None] = (None, None)
scale: str | None = None
zorder: int = 0
colorbar: bool | str | None = "auto"
colorbar_params: dict[str, object] | None = None


@dataclass
Expand All @@ -267,3 +290,5 @@ class LabelsRenderParams:
table_name: str | None = None
table_layer: str | None = None
zorder: int = 0
colorbar: bool | str | None = "auto"
colorbar_params: dict[str, object] | None = None
Loading