Merged
18 commits
85ed17d
perf: session-scope blobs fixtures to cut setup time by ~10s
LucaMarconato May 5, 2026
9f1378d
perf: use blobs(length=128) in test_vectorize to cut to_circles cost
LucaMarconato May 5, 2026
149bfdf
chore: remove benchmark CSVs from repo
LucaMarconato May 5, 2026
73f4387
chore: remove benchmark stats csv
LucaMarconato May 5, 2026
7726556
perf: eliminate double-validation in __setitem__ and use fast fixture…
LucaMarconato May 5, 2026
ac97136
perf: add validate=False option to get_model, clarify setitem validat…
LucaMarconato May 5, 2026
acb25b1
perf: skip re-validation when building SpatialData from already-valid…
LucaMarconato May 5, 2026
4113706
Revert "perf: skip re-validation when building SpatialData from alrea…
LucaMarconato May 6, 2026
89a2202
perf: vectorize label centroid computation, 30x speedup
LucaMarconato May 6, 2026
b9d9b57
perf: session-scope blobs fixtures in concatenate and get_attrs tests
LucaMarconato May 6, 2026
b5b09e3
perf: reduce categories and partitions in test_categories_on_partitio…
LucaMarconato May 6, 2026
031ba56
perf: shrink test_categories_on_partitioned_dataframe to N=10, local …
LucaMarconato May 6, 2026
51674b9
ci: parallelize test suite with pytest-xdist worksteal
LucaMarconato May 6, 2026
4da1bc5
refactor: clean up _get_centroids_for_labels docstring and clarify in…
LucaMarconato May 6, 2026
61553c7
refactor: minor code swap for improved clarity
LucaMarconato May 6, 2026
0448312
perf: speed up test suite — subset sdata, disable numba JIT, promote …
LucaMarconato May 6, 2026
b66a41b
fix+perf: faster test suite — subset fixtures, disable numba JIT, fix…
LucaMarconato May 6, 2026
58b64c8
refactor: update comments and refine test_consolidated_metadata subset
LucaMarconato May 6, 2026
5 changes: 1 addition & 4 deletions .github/workflows/test.yaml
@@ -20,13 +20,10 @@ jobs:
  matrix:
    include:
      - {os: windows-latest, python: "3.11", dask-version: "2025.12.0", name: "min dask"}
-     - {os: windows-latest, python: "3.13", dask-version: "latest"}
      - {os: windows-latest, python: "3.14", dask-version: "latest"}
      - {os: ubuntu-latest, python: "3.11", dask-version: "latest"}
-     - {os: ubuntu-latest, python: "3.13", dask-version: "latest"}
      - {os: ubuntu-latest, python: "3.14", dask-version: "latest"}
      - {os: macos-latest, python: "3.11", dask-version: "latest"}
-     - {os: macos-latest, python: "3.13", prerelease: "allow", name: "prerelease"}
      - {os: macos-latest, python: "3.14", prerelease: "allow", name: "prerelease"}
env:
  OS: ${{ matrix.os }}
@@ -62,7 +59,7 @@ jobs:
    PLATFORM: ${{ matrix.os }}
    DISPLAY: :42
  run: |
-    uv run pytest --cov --color=yes --cov-report=xml
+    uv run pytest --cov --color=yes --cov-report=xml -n auto --dist worksteal
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
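Note on the run step above: `-n auto` (from pytest-xdist, added to the test extras in pyproject.toml below) starts one worker per available CPU, and `--dist worksteal` lets idle workers steal queued tests from busy ones, which pays off when test durations vary widely. The same parallel run can be reproduced locally with `uv run pytest -n auto --dist worksteal`.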
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -69,6 +69,7 @@ test = [
"pytest",
"pytest-cov",
"pytest-mock",
"pytest-xdist",
"torch",
]
docs = [
@@ -102,6 +103,7 @@ strict = true
addopts = [
    "--import-mode=importlib", # allow using test files with same name
    "-s", # print output from tests
+   "-p no:napari", # napari registers a pytest plugin via its entry point; disable it here since spatialdata tests don't need it
]
# These are all markers coming from xarray, dask or anndata. Added here to silence warnings.
markers = [
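For context, `-p no:NAME` tells pytest not to auto-load a plugin registered through an installed package's entry point; the same switch works ad hoc on the command line, e.g. `pytest -p no:napari`.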
25 changes: 3 additions & 22 deletions src/spatialdata/_core/_elements.py
@@ -21,7 +21,6 @@
    PointsModel,
    ShapesModel,
    TableModel,
-    get_axes_names,
    get_model,
)

@@ -50,6 +49,7 @@ def _check_key(key: str, element_keys: Iterable[str], shared_keys: set[str | None]
        raise KeyError(*e.args) from e

    def __setitem__(self, key: str, value: T) -> None:
+        # note that each __setitem__ in the subclasses calls get_model(), which performs data validation
        self._add_shared_key(key)
        super().__setitem__(key, value)

@@ -72,15 +72,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
        schema = get_model(value)
        if schema not in (Image2DModel, Image3DModel):
            raise TypeError(f"Unknown element type with schema: {schema!r}.")
-        ndim = len(get_axes_names(value))
-        if ndim == 3:
-            Image2DModel.validate(value)
-            super().__setitem__(key, value)
-        elif ndim == 4:
-            Image3DModel.validate(value)
-            super().__setitem__(key, value)
-        else:
-            NotImplementedError("TODO: implement for ndim > 4.")
+        super().__setitem__(key, value)


class Labels(Elements[DataArray | DataTree]):
@@ -89,15 +81,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None:
        schema = get_model(value)
        if schema not in (Labels2DModel, Labels3DModel):
            raise TypeError(f"Unknown element type with schema: {schema!r}.")
-        ndim = len(get_axes_names(value))
-        if ndim == 2:
-            Labels2DModel.validate(value)
-            super().__setitem__(key, value)
-        elif ndim == 3:
-            Labels3DModel.validate(value)
-            super().__setitem__(key, value)
-        else:
-            NotImplementedError("TODO: implement for ndim > 3.")
+        super().__setitem__(key, value)


class Shapes(Elements[GeoDataFrame]):
@@ -106,7 +90,6 @@ def __setitem__(self, key: str, value: GeoDataFrame) -> None:
        schema = get_model(value)
        if schema != ShapesModel:
            raise TypeError(f"Unknown element type with schema: {schema!r}.")
-        ShapesModel.validate(value)
        super().__setitem__(key, value)


@@ -116,7 +99,6 @@ def __setitem__(self, key: str, value: DaskDataFrame) -> None:
        schema = get_model(value)
        if schema != PointsModel:
            raise TypeError(f"Unknown element type with schema: {schema!r}.")
-        PointsModel.validate(value)
        super().__setitem__(key, value)


@@ -126,5 +108,4 @@ def __setitem__(self, key: str, value: AnnData) -> None:
        schema = get_model(value)
        if schema != TableModel:
            raise TypeError(f"Unknown element type with schema: {schema!r}.")
-        TableModel.validate(value)
        super().__setitem__(key, value)
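The deleted branches above were redundant rather than load-bearing: `get_model` already dispatches on the element's dimensionality and validates it, so each `__setitem__` validated twice (and the deleted `ndim` fallbacks constructed a `NotImplementedError` without raising it). A minimal sketch of the resulting single-validation flow, using the library's public model API but an illustrative function name:

import numpy as np
from spatialdata.models import Image2DModel, Image3DModel, get_model

def set_image(container: dict, key: str, value) -> None:
    schema = get_model(value)  # validates once and returns the matching model class
    if schema not in (Image2DModel, Image3DModel):
        raise TypeError(f"Unknown element type with schema: {schema!r}.")
    container[key] = value  # store without a second schema.validate(value) call

set_image({}, "img", Image2DModel.parse(np.zeros((1, 8, 8))))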
59 changes: 22 additions & 37 deletions src/spatialdata/_core/centroids.py
@@ -1,9 +1,8 @@
from __future__ import annotations

-from collections import defaultdict
from functools import singledispatch

-import dask.array as da
+import numpy as np
import pandas as pd
import xarray as xr
from dask.dataframe import DataFrame as DaskDataFrame
@@ -56,40 +55,29 @@ def get_centroids(
raise ValueError(f"The object type {type(e)} is not supported.")


-def _get_centroids_for_axis(xdata: xr.DataArray, axis: str) -> pd.DataFrame:
+def _get_centroids_for_labels(xdata: xr.DataArray) -> pd.DataFrame:
    """
-    Compute the component "axis" of the centroid of each label as a weighted average of the xarray coordinates.
+    Compute centroids for all labels in a DataArray in a single O(n_voxels) pass.

-    Parameters
-    ----------
-    xdata
-        The xarray DataArray containing the labels.
-    axis
-        The axis for which the centroids are computed.
-
-    Returns
-    -------
-    pd.DataFrame
-        A DataFrame containing one column, named after "axis", with the centroids of the labels along that axis.
-        The index of the DataFrame is the collection of label values, sorted in ascending order.
+    Works for any number of spatial dimensions (2D and 3D labels).
    """
-    centroids: dict[int, float] = defaultdict(float)
-    for i in xdata[axis]:
-        portion = xdata.sel(**{axis: i}).data
-        u = da.unique(portion, return_counts=True)
-        labels_values = u[0].compute()
-        counts = u[1].compute()
-        for j in range(len(labels_values)):
-            label_value = labels_values[j]
-            count = counts[j]
-            centroids[label_value] += count * i.values.item()
-
-    all_labels_values, all_labels_counts = da.unique(xdata.data, return_counts=True)
-    all_labels = dict(zip(all_labels_values.compute(), all_labels_counts.compute(), strict=True))
-    for label_value in centroids:
-        centroids[label_value] /= all_labels[label_value]
-    centroids = dict(sorted(centroids.items(), key=lambda x: x[0]))
-    return pd.DataFrame({axis: centroids.values()}, index=list(centroids.keys()))
+    arr = xdata.data.compute()
+    axes = list(xdata.dims)
+
+    # Map label values to a contiguous range for bincount efficiency.
+    label_ids, inverse = np.unique(arr, return_inverse=True)
+    flat_inverse = inverse.ravel()
+    counts = np.bincount(flat_inverse)  # per-label pixel counts
+
+    # indexing="ij" (matrix convention) ensures the i-th grid varies along the i-th
+    # dimension of the output, correctly aligning with xdata.dims for any number of axes.
+    coord_grids = np.meshgrid(*[xdata[ax].values for ax in axes], indexing="ij")
+    data: dict[str, np.ndarray] = {}
+    for ax, grid in zip(axes, coord_grids, strict=True):
+        coord_sums = np.bincount(flat_inverse, weights=grid.ravel().astype(float))
+        data[ax] = coord_sums / counts  # counts > 0 by construction (unique guarantees this)
+
+    return pd.DataFrame(data, index=label_ids)


@get_centroids.register(DataArray)
@@ -109,10 +97,7 @@ def _(
    assert len(e["scale0"]) == 1
    e = next(iter(e["scale0"].values()))

-    dfs = []
-    for axis in get_axes_names(e):
-        dfs.append(_get_centroids_for_axis(e, axis))
-    df = pd.concat(dfs, axis=1)
+    df = _get_centroids_for_labels(e)
    if not return_background and 0 in df.index:
        df = df.drop(index=0)  # drop the background label
    t = get_transformation(e, coordinate_system)
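The replacement computes all centroids in one pass: `np.unique(..., return_inverse=True)` relabels the image with contiguous ids, then one `np.bincount` per axis accumulates coordinate sums that are divided by the per-label pixel counts. A self-contained toy illustration of the technique (toy array and axis names, not the library API):

import numpy as np

arr = np.array([[0, 1, 1],
                [0, 2, 1]])  # 2x3 label image; coordinates are row/column indices
label_ids, inverse = np.unique(arr, return_inverse=True)  # labels -> 0..k-1
flat = inverse.ravel()
counts = np.bincount(flat)  # pixels per label: [2, 3, 1]
grids = np.meshgrid(np.arange(2), np.arange(3), indexing="ij")
for name, grid in zip(("y", "x"), grids):
    sums = np.bincount(flat, weights=grid.ravel().astype(float))
    print(name, sums / counts)  # per-label centroid coordinate, ordered by label id
# y-centroids: [0.5, 1/3, 1.0]; x-centroids: [0.0, 5/3, 1.0]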
8 changes: 6 additions & 2 deletions src/spatialdata/models/models.py
@@ -1250,14 +1250,17 @@ def parse(

def get_model(
    e: SpatialElement,
+    validate: bool = True,
) -> Schema_t:
    """
-    Get the model for the given element.
+    Get the model for the given element. Validate using the model if `validate` is `True`.

    Parameters
    ----------
    e
        The element.
+    validate
+        Whether to validate the element using the model.

    Returns
    -------
@@ -1268,7 +1271,8 @@ def _validate_and_return(
    schema: Schema_t,
    e: SpatialElement,
) -> Schema_t:
-    schema.validate(e)
+    if validate:
+        schema.validate(e)
    return schema

if isinstance(e, DataArray | DataTree):
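A brief usage sketch of the new flag: `Model.parse` already validates on construction, so a caller holding a freshly parsed element can skip the second pass.

import numpy as np
from spatialdata.models import Image2DModel, get_model

element = Image2DModel.parse(np.zeros((1, 8, 8)))  # parse() validates on construction
schema = get_model(element)                    # default: dispatches and validates again
schema = get_model(element, validate=False)    # dispatches only, skipping re-validation
assert schema is Image2DModel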
104 changes: 85 additions & 19 deletions tests/conftest.py
@@ -1,6 +1,13 @@
from __future__ import annotations

-from collections.abc import Sequence
+import os
+
+# Disable numba JIT for the test suite (the test data is small so initializing the JIT is slower than using plain
+# Python)
+os.environ.setdefault("NUMBA_DISABLE_JIT", "1")
+
+import copy as _copy
+from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Any

@@ -33,6 +40,25 @@
    TableModel,
)

+
+def _fast_deepcopy_sdata(sd: SpatialData) -> SpatialData:
+    """
+    Fast deepcopy for SpatialData objects in tests.
+
+    Uses copy.deepcopy (which skips model re-validation) and manually restores
+    the attrs that copy.deepcopy loses for DaskDataFrame (issue #503) and
+    GeoDataFrame (issue #286).
+    """
+    points_attrs = {k: _copy.deepcopy(v._attrs) for k, v in sd.points.items()}
+    shapes_attrs = {k: _copy.deepcopy(v.attrs) for k, v in sd.shapes.items()}
+    sd_copy = _copy.deepcopy(sd)
+    for k, attrs in points_attrs.items():
+        sd_copy.points[k]._attrs = attrs
+    for k, attrs in shapes_attrs.items():
+        sd_copy.shapes[k].attrs = attrs
+    return sd_copy
+
+
SEED = 0
RNG = default_rng(seed=SEED)

@@ -41,16 +67,26 @@
POINT_PATH = Path(__file__).parent / "data/points.json"


-@pytest.fixture()
-def images() -> SpatialData:
+@pytest.fixture(scope="session")
+def _images_session() -> SpatialData:
    return SpatialData(images=_get_images())


@pytest.fixture()
-def labels() -> SpatialData:
+def images(_images_session: SpatialData) -> SpatialData:
+    return _fast_deepcopy_sdata(_images_session)
+
+
+@pytest.fixture(scope="session")
+def _labels_session() -> SpatialData:
    return SpatialData(labels=_get_labels())


+@pytest.fixture()
+def labels(_labels_session: SpatialData) -> SpatialData:
+    return _fast_deepcopy_sdata(_labels_session)
+
+
@pytest.fixture()
def shapes() -> SpatialData:
    return SpatialData(shapes=_get_shapes())
@@ -87,8 +123,8 @@ def tables() -> list[AnnData]:
    return _tables


-@pytest.fixture()
-def full_sdata() -> SpatialData:
+@pytest.fixture(scope="session")
+def _full_sdata_session() -> SpatialData:
    return SpatialData(
        images=_get_images(),
        labels=_get_labels(),
@@ -98,6 +134,22 @@
    )


+@pytest.fixture()
+def full_sdata(_full_sdata_session: SpatialData) -> SpatialData:
+    return _fast_deepcopy_sdata(_full_sdata_session)
+
+
+@pytest.fixture(scope="session")
+def _sdata_full_session() -> SpatialData:
+    return SpatialData(
+        images=_get_images(),
+        labels=_get_labels(),
+        shapes=_get_shapes(),
+        points=_get_points(),
+        tables=_get_tables(region="labels2d"),
+    )
+
+
@pytest.fixture(
# params=["labels"]
params=["full", "empty"]
Expand All @@ -110,15 +162,9 @@ def full_sdata() -> SpatialData:
]
# + ["empty_" + x for x in ["table"]] # TODO: empty table not supported yet
)
-def sdata(request) -> SpatialData:
+def sdata(request, _sdata_full_session: SpatialData) -> SpatialData:
    if request.param == "full":
-        return SpatialData(
-            images=_get_images(),
-            labels=_get_labels(),
-            shapes=_get_shapes(),
-            points=_get_points(),
-            tables=_get_tables(region="labels2d"),
-        )
+        return _fast_deepcopy_sdata(_sdata_full_session)
    if request.param == "empty":
        return SpatialData()
    return request.getfixturevalue(request.param)
@@ -304,18 +350,38 @@ def _get_new_table(spatial_element: None | str | Sequence[str], instance_id: None
return TableModel.parse(adata=adata, spatial_element=spatial_element, instance_id=instance_id)


-@pytest.fixture()
-def labels_blobs() -> ArrayLike:
-    """Create a 2D labels."""
+@pytest.fixture(scope="session")
+def _labels_blobs_session() -> ArrayLike:
    return BlobsDataset()._labels_blobs()


@pytest.fixture()
-def sdata_blobs() -> SpatialData:
+def labels_blobs(_labels_blobs_session: ArrayLike) -> ArrayLike:
+    """Create a 2D labels."""
+    return deepcopy(_labels_blobs_session)
+
+
+@pytest.fixture(scope="session")
+def _sdata_blobs_session() -> SpatialData:
    from spatialdata.datasets import blobs

-    return deepcopy(blobs(256, 300, 3))
+    return blobs(256, 300, 3)
+
+
+@pytest.fixture()
+def sdata_blobs(_sdata_blobs_session: SpatialData) -> SpatialData:
+    """Create a 2D labels."""
+    return _fast_deepcopy_sdata(_sdata_blobs_session)
+
+
+@pytest.fixture()
+def blobs_factory(_sdata_blobs_session: SpatialData) -> Callable[[], SpatialData]:
+    """Return a factory that creates cheap fresh copies of the session-scoped blobs dataset."""
+
+    def _make() -> SpatialData:
+        return _fast_deepcopy_sdata(_sdata_blobs_session)
+
+    return _make


def _make_points(coordinates: np.ndarray) -> DaskDataFrame:
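The fixture pattern above generalizes beyond spatialdata: build an expensive object once per session, then hand each test an isolated copy so in-place mutations cannot leak between tests. A generic sketch of the idiom, with hypothetical names (`build_big_dataset` stands in for any slow constructor):

import copy

import pytest

@pytest.fixture(scope="session")
def _big_dataset_session():
    # Hypothetical expensive builder; runs once per session (once per worker under pytest-xdist).
    return build_big_dataset()

@pytest.fixture()
def big_dataset(_big_dataset_session):
    # Cheap per-test copy: tests may mutate it freely without affecting each other.
    return copy.deepcopy(_big_dataset_session)

The `_fast_deepcopy_sdata` wrapper exists because, per its docstring, a plain `copy.deepcopy` loses the attrs of DaskDataFrame (issue #503) and GeoDataFrame (issue #286) elements, so those are stashed and restored around the copy.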
8 changes: 8 additions & 0 deletions tests/core/operations/conftest.py
@@ -0,0 +1,8 @@
+from __future__ import annotations
+
+import os
+
+# Disable numba JIT compilation for rasterize tests. Datashader (used by rasterize) triggers
+# numba JIT on first call, costing ~1.4s per worker. Python-mode gives identical results for
+# the small test data here — unlike real data, there is no throughput advantage from JIT.
+os.environ.setdefault("NUMBA_DISABLE_JIT", "1")
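The ordering is the important part: numba reads its `NUMBA_*` environment configuration when it is first imported, and pytest imports a directory's conftest.py before the test modules in it, so the `setdefault` runs early enough (and, being `setdefault`, it still respects a value exported externally). A minimal self-contained sketch of the same technique:

import os

os.environ.setdefault("NUMBA_DISABLE_JIT", "1")  # must happen before numba is imported

import numba  # numba picks up NUMBA_DISABLE_JIT at import time

@numba.njit
def add(a: int, b: int) -> int:
    return a + b

print(add(1, 2))  # runs as plain Python: same result, no JIT compilation latency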