Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 158 additions & 1 deletion src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
import warnings
from collections import OrderedDict
from collections.abc import Callable
from collections.abc import Callable, Sequence
from copy import deepcopy
from pathlib import Path
from typing import Any, Literal, cast
Expand All @@ -25,12 +25,14 @@
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from spatialdata import get_extent
from spatialdata._utils import _deprecation_alias
from spatialdata.transformations.operations import get_transformation
from xarray import DataArray, DataTree

from spatialdata_plot._accessor import register_spatial_data_accessor
from spatialdata_plot._logging import _log_context
from spatialdata_plot.pl.render import (
_draw_channel_legend,
_render_graph,
_render_images,
_render_labels,
_render_points,
Expand All @@ -44,6 +46,7 @@
ChannelLegendEntry,
CmapParams,
ColorbarSpec,
GraphRenderParams,
ImageRenderParams,
LabelsRenderParams,
LegendParams,
Expand All @@ -64,6 +67,7 @@
_prepare_cmap_norm,
_prepare_params_plot,
_set_outline,
_validate_graph_render_params,
_validate_image_render_params,
_validate_label_render_params,
_validate_points_render_params,
Expand Down Expand Up @@ -856,6 +860,143 @@ def render_labels(
n_steps += 1
return sdata

def render_graph(
self,
element: str | None = None,
color: ColorLike | None = None,
*,
connectivity_key: str = "spatial",
obsp_key: str | None = None,
palette: dict[str, str] | list[str] | str | None = None,
na_color: ColorLike | None = "default",
cmap: Colormap | str | None = None,
norm: Normalize | None = None,
groups: list[str] | str | None = None,
group_key: str | None = None,
edge_width: float | Literal["weight"] = 1.0,
edge_alpha: float | Literal["weight"] = 1.0,
weight_key: str | None = None,
linestyle: str | Sequence[str] = "solid",
rasterize: bool = True,
include_self_loops: bool = False,
colorbar: bool | str | None = "auto",
colorbar_params: dict[str, object] | None = None,
table_name: str | None = None,
) -> sd.SpatialData:
"""Render spatial graph edges between observations.

Draws edges from a connectivity matrix in ``table.obsp`` using
centroid coordinates of the linked spatial element.

Parameters
----------
element : str | None
Name of the shapes/points/labels element the graph connects.
Auto-resolved from the table if omitted.
color : ColorLike | None
A color-like value applied to every edge, or the name of a
``table.obs`` column. Categorical columns colour same-category
edges by the shared value and cross-category edges by
``na_color``. Continuous columns colour edges by the mean of
their endpoint values. Defaults to grey when unset.
connectivity_key : str, default "spatial"
``table.obsp`` key. Tries ``key`` first, then ``f"{key}_connectivities"``.
obsp_key : str | None
``table.obsp`` matrix used as per-edge scalar; coloured via
``cmap``/``norm``. Mutually exclusive with ``color``.
palette : dict[str, str] | list[str] | str | None
Palette for categorical obs coloring. Same as :meth:`render_shapes`.
na_color : ColorLike | None, default "default"
Colour for cross-category edges. ``None`` makes them transparent.
cmap : Colormap | str | None
Colormap for continuous edge coloring.
norm : Normalize | None
Pass ``Normalize(vmin=..., vmax=...)`` to clamp the colormap range.
groups : list[str] | str | None
Show only edges where **both** endpoints fall in these groups.
Requires ``group_key``.
group_key : str | None
``table.obs`` column used for group filtering.
edge_width : float | Literal["weight"], default 1.0
Line width. Pass ``"weight"`` to scale by ``weight_key`` values
into ``[0.5, 3.0]``.
edge_alpha : float | Literal["weight"], default 1.0
Transparency. Pass ``"weight"`` to scale into ``[0.2, 1.0]``.
weight_key : str | None
``table.obsp`` matrix providing per-edge weights. Defaults to
``connectivity_key`` when omitted.
linestyle : str | Sequence[str], default "solid"
``LineCollection`` linestyle (scalar or per-edge).
rasterize : bool, default True
Rasterize the edge collection. Set ``False`` for vector output.
include_self_loops : bool, default False
Render diagonal entries of the connectivity matrix as circles.
colorbar : bool | str | None, default "auto"
Whether to draw a colorbar for continuous edge coloring
(``obsp_key`` or a continuous obs column). ``"auto"`` draws it
when a mappable is present; ``True``/``False`` force it on/off.
colorbar_params : dict[str, object] | None
Optional matplotlib colorbar kwargs and layout hints
(e.g. ``{"loc": "right", "fraction": 0.05, "label": "..."}``).
table_name : str | None
Table containing the graph. Auto-discovered if omitted.

Returns
-------
sd.SpatialData
Copy with rendering parameters stored in the plotting tree.

Notes
-----
Chaining with ``render_shapes``/``render_points`` on the same
categorical column shares the legend; no dedicated edge legend is drawn.
"""
params = _validate_graph_render_params(
self._sdata,
element=element,
connectivity_key=connectivity_key,
obsp_key=obsp_key,
weight_key=weight_key,
palette=palette,
na_color=na_color,
cmap=cmap,
norm=norm,
table_name=table_name,
color=color,
edge_width=edge_width,
edge_alpha=edge_alpha,
groups=groups,
group_key=group_key,
)

sdata = self._copy()
sdata = _verify_plotting_tree(sdata)
n_steps = len(sdata.plotting_tree.keys())
sdata.plotting_tree[f"{n_steps + 1}_render_graph"] = GraphRenderParams(
element=params["element"],
connectivity_obsp_key=params["connectivity_obsp_key"],
table_name=params["table_name"],
color=params["color"],
obs_col=params["obs_col"],
obsp_key=params["obsp_key"],
cmap_params=params["cmap_params"],
palette_map=params["palette_map"],
na_color=params["na_color"],
color_source=params["color_source"],
groups=params["groups"],
group_key=params["group_key"],
edge_width=params["edge_width"],
edge_alpha=params["edge_alpha"],
weight_key=params["weight_key"],
linestyle=linestyle,
rasterize=rasterize,
include_self_loops=include_self_loops,
zorder=n_steps,
colorbar=colorbar,
colorbar_params=colorbar_params,
)
return sdata

def show(
self,
coordinate_systems: list[str] | str | None = None,
Expand Down Expand Up @@ -1020,6 +1161,7 @@ def show(
"render_shapes",
"render_labels",
"render_points",
"render_graph",
]

# prepare rendering params
Expand Down Expand Up @@ -1340,6 +1482,21 @@ def _draw_colorbar(
rasterize=rasterize,
)

elif cmd == "render_graph":
graph_element = params_copy.element
element_in_cs = graph_element in sdata and cs in set(
get_transformation(sdata[graph_element], get_all=True).keys()
)
if element_in_cs:
_render_graph(
sdata=sdata,
render_params=params_copy,
coordinate_system=cs,
ax=ax,
legend_params=legend_params_obj,
colorbar_requests=axis_colorbar_requests,
)

if title is None:
t = cs
elif len(title) == 1:
Expand Down
Loading
Loading