diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index c6e611a82..b79c72f33 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -20,13 +20,10 @@ jobs: matrix: include: - {os: windows-latest, python: "3.11", dask-version: "2025.12.0", name: "min dask"} - - {os: windows-latest, python: "3.13", dask-version: "latest"} - {os: windows-latest, python: "3.14", dask-version: "latest"} - {os: ubuntu-latest, python: "3.11", dask-version: "latest"} - - {os: ubuntu-latest, python: "3.13", dask-version: "latest"} - {os: ubuntu-latest, python: "3.14", dask-version: "latest"} - {os: macos-latest, python: "3.11", dask-version: "latest"} - - {os: macos-latest, python: "3.13", prerelease: "allow", name: "prerelease"} - {os: macos-latest, python: "3.14", prerelease: "allow", name: "prerelease"} env: OS: ${{ matrix.os }} @@ -62,7 +59,7 @@ jobs: PLATFORM: ${{ matrix.os }} DISPLAY: :42 run: | - uv run pytest --cov --color=yes --cov-report=xml + uv run pytest --cov --color=yes --cov-report=xml -n auto --dist worksteal - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: diff --git a/pyproject.toml b/pyproject.toml index 07ec8140b..794b05021 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ test = [ "pytest", "pytest-cov", "pytest-mock", + "pytest-xdist", "torch", ] docs = [ @@ -102,6 +103,7 @@ strict = true addopts = [ "--import-mode=importlib", # allow using test files with same name "-s", # print output from tests + "-p no:napari", # napari registers a pytest plugin via its entry point; disable it here since spatialdata tests don't need it ] # These are all markers coming from xarray, dask or anndata. Added here to silence warnings. markers = [ diff --git a/src/spatialdata/_core/_elements.py b/src/spatialdata/_core/_elements.py index a3267a701..846b8b056 100644 --- a/src/spatialdata/_core/_elements.py +++ b/src/spatialdata/_core/_elements.py @@ -21,7 +21,6 @@ PointsModel, ShapesModel, TableModel, - get_axes_names, get_model, ) @@ -50,6 +49,7 @@ def _check_key(key: str, element_keys: Iterable[str], shared_keys: set[str | Non raise KeyError(*e.args) from e def __setitem__(self, key: str, value: T) -> None: + # note that each __setitem__ in the subclasses calls get_model(), which performs data validation self._add_shared_key(key) super().__setitem__(key, value) @@ -72,15 +72,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None: schema = get_model(value) if schema not in (Image2DModel, Image3DModel): raise TypeError(f"Unknown element type with schema: {schema!r}.") - ndim = len(get_axes_names(value)) - if ndim == 3: - Image2DModel.validate(value) - super().__setitem__(key, value) - elif ndim == 4: - Image3DModel.validate(value) - super().__setitem__(key, value) - else: - NotImplementedError("TODO: implement for ndim > 4.") + super().__setitem__(key, value) class Labels(Elements[DataArray | DataTree]): @@ -89,15 +81,7 @@ def __setitem__(self, key: str, value: Raster_T) -> None: schema = get_model(value) if schema not in (Labels2DModel, Labels3DModel): raise TypeError(f"Unknown element type with schema: {schema!r}.") - ndim = len(get_axes_names(value)) - if ndim == 2: - Labels2DModel.validate(value) - super().__setitem__(key, value) - elif ndim == 3: - Labels3DModel.validate(value) - super().__setitem__(key, value) - else: - NotImplementedError("TODO: implement for ndim > 3.") + super().__setitem__(key, value) class Shapes(Elements[GeoDataFrame]): @@ -106,7 +90,6 @@ def __setitem__(self, key: str, value: GeoDataFrame) -> None: schema = get_model(value) if schema != ShapesModel: raise TypeError(f"Unknown element type with schema: {schema!r}.") - ShapesModel.validate(value) super().__setitem__(key, value) @@ -116,7 +99,6 @@ def __setitem__(self, key: str, value: DaskDataFrame) -> None: schema = get_model(value) if schema != PointsModel: raise TypeError(f"Unknown element type with schema: {schema!r}.") - PointsModel.validate(value) super().__setitem__(key, value) @@ -126,5 +108,4 @@ def __setitem__(self, key: str, value: AnnData) -> None: schema = get_model(value) if schema != TableModel: raise TypeError(f"Unknown element type with schema: {schema!r}.") - TableModel.validate(value) super().__setitem__(key, value) diff --git a/src/spatialdata/_core/centroids.py b/src/spatialdata/_core/centroids.py index 24ea616c6..cde583e15 100644 --- a/src/spatialdata/_core/centroids.py +++ b/src/spatialdata/_core/centroids.py @@ -1,9 +1,8 @@ from __future__ import annotations -from collections import defaultdict from functools import singledispatch -import dask.array as da +import numpy as np import pandas as pd import xarray as xr from dask.dataframe import DataFrame as DaskDataFrame @@ -56,40 +55,29 @@ def get_centroids( raise ValueError(f"The object type {type(e)} is not supported.") -def _get_centroids_for_axis(xdata: xr.DataArray, axis: str) -> pd.DataFrame: +def _get_centroids_for_labels(xdata: xr.DataArray) -> pd.DataFrame: """ - Compute the component "axis" of the centroid of each label as a weighted average of the xarray coordinates. + Compute centroids for all labels in a DataArray in a single O(n_voxels) pass. - Parameters - ---------- - xdata - The xarray DataArray containing the labels. - axis - The axis for which the centroids are computed. - - Returns - ------- - pd.DataFrame - A DataFrame containing one column, named after "axis", with the centroids of the labels along that axis. - The index of the DataFrame is the collection of label values, sorted in ascending order. + Works for any number of spatial dimensions (2D and 3D labels). """ - centroids: dict[int, float] = defaultdict(float) - for i in xdata[axis]: - portion = xdata.sel(**{axis: i}).data - u = da.unique(portion, return_counts=True) - labels_values = u[0].compute() - counts = u[1].compute() - for j in range(len(labels_values)): - label_value = labels_values[j] - count = counts[j] - centroids[label_value] += count * i.values.item() - - all_labels_values, all_labels_counts = da.unique(xdata.data, return_counts=True) - all_labels = dict(zip(all_labels_values.compute(), all_labels_counts.compute(), strict=True)) - for label_value in centroids: - centroids[label_value] /= all_labels[label_value] - centroids = dict(sorted(centroids.items(), key=lambda x: x[0])) - return pd.DataFrame({axis: centroids.values()}, index=list(centroids.keys())) + arr = xdata.data.compute() + axes = list(xdata.dims) + + # Map label values to a contiguous range for bincount efficiency. + label_ids, inverse = np.unique(arr, return_inverse=True) + flat_inverse = inverse.ravel() + counts = np.bincount(flat_inverse) # per-label pixel counts + + # indexing="ij" (matrix convention) ensures the i-th grid varies along the i-th + # dimension of the output, correctly aligning with xdata.dims for any number of axes. + coord_grids = np.meshgrid(*[xdata[ax].values for ax in axes], indexing="ij") + data: dict[str, np.ndarray] = {} + for ax, grid in zip(axes, coord_grids, strict=True): + coord_sums = np.bincount(flat_inverse, weights=grid.ravel().astype(float)) + data[ax] = coord_sums / counts # counts > 0 by construction (unique guarantees this) + + return pd.DataFrame(data, index=label_ids) @get_centroids.register(DataArray) @@ -109,10 +97,7 @@ def _( assert len(e["scale0"]) == 1 e = next(iter(e["scale0"].values())) - dfs = [] - for axis in get_axes_names(e): - dfs.append(_get_centroids_for_axis(e, axis)) - df = pd.concat(dfs, axis=1) + df = _get_centroids_for_labels(e) if not return_background and 0 in df.index: df = df.drop(index=0) # drop the background label t = get_transformation(e, coordinate_system) diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 2bfcf88ce..0b6c32013 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -1250,14 +1250,17 @@ def parse( def get_model( e: SpatialElement, + validate: bool = True, ) -> Schema_t: """ - Get the model for the given element. + Get the model for the given element. Validate using the model if `validate` is `True`. Parameters ---------- e The element. + validate + Whether to validate the element using the model. Returns ------- @@ -1268,7 +1271,8 @@ def _validate_and_return( schema: Schema_t, e: SpatialElement, ) -> Schema_t: - schema.validate(e) + if validate: + schema.validate(e) return schema if isinstance(e, DataArray | DataTree): diff --git a/tests/conftest.py b/tests/conftest.py index c97939129..3bd5425d2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,13 @@ from __future__ import annotations -from collections.abc import Sequence +import os + +# Disable numba JIT for the test suite (the test data is small so initializing the JIT is slower than using plain +# Python) +os.environ.setdefault("NUMBA_DISABLE_JIT", "1") + +import copy as _copy +from collections.abc import Callable, Sequence from pathlib import Path from typing import Any @@ -33,6 +40,25 @@ TableModel, ) + +def _fast_deepcopy_sdata(sd: SpatialData) -> SpatialData: + """ + Fast deepcopy for SpatialData objects in tests. + + Uses copy.deepcopy (which skips model re-validation) and manually restores + the attrs that copy.deepcopy loses for DaskDataFrame (issue #503) and + GeoDataFrame (issue #286). + """ + points_attrs = {k: _copy.deepcopy(v._attrs) for k, v in sd.points.items()} + shapes_attrs = {k: _copy.deepcopy(v.attrs) for k, v in sd.shapes.items()} + sd_copy = _copy.deepcopy(sd) + for k, attrs in points_attrs.items(): + sd_copy.points[k]._attrs = attrs + for k, attrs in shapes_attrs.items(): + sd_copy.shapes[k].attrs = attrs + return sd_copy + + SEED = 0 RNG = default_rng(seed=SEED) @@ -41,16 +67,26 @@ POINT_PATH = Path(__file__).parent / "data/points.json" -@pytest.fixture() -def images() -> SpatialData: +@pytest.fixture(scope="session") +def _images_session() -> SpatialData: return SpatialData(images=_get_images()) @pytest.fixture() -def labels() -> SpatialData: +def images(_images_session: SpatialData) -> SpatialData: + return _fast_deepcopy_sdata(_images_session) + + +@pytest.fixture(scope="session") +def _labels_session() -> SpatialData: return SpatialData(labels=_get_labels()) +@pytest.fixture() +def labels(_labels_session: SpatialData) -> SpatialData: + return _fast_deepcopy_sdata(_labels_session) + + @pytest.fixture() def shapes() -> SpatialData: return SpatialData(shapes=_get_shapes()) @@ -87,8 +123,8 @@ def tables() -> list[AnnData]: return _tables -@pytest.fixture() -def full_sdata() -> SpatialData: +@pytest.fixture(scope="session") +def _full_sdata_session() -> SpatialData: return SpatialData( images=_get_images(), labels=_get_labels(), @@ -98,6 +134,22 @@ def full_sdata() -> SpatialData: ) +@pytest.fixture() +def full_sdata(_full_sdata_session: SpatialData) -> SpatialData: + return _fast_deepcopy_sdata(_full_sdata_session) + + +@pytest.fixture(scope="session") +def _sdata_full_session() -> SpatialData: + return SpatialData( + images=_get_images(), + labels=_get_labels(), + shapes=_get_shapes(), + points=_get_points(), + tables=_get_tables(region="labels2d"), + ) + + @pytest.fixture( # params=["labels"] params=["full", "empty"] @@ -110,15 +162,9 @@ def full_sdata() -> SpatialData: ] # + ["empty_" + x for x in ["table"]] # TODO: empty table not supported yet ) -def sdata(request) -> SpatialData: +def sdata(request, _sdata_full_session: SpatialData) -> SpatialData: if request.param == "full": - return SpatialData( - images=_get_images(), - labels=_get_labels(), - shapes=_get_shapes(), - points=_get_points(), - tables=_get_tables(region="labels2d"), - ) + return _fast_deepcopy_sdata(_sdata_full_session) if request.param == "empty": return SpatialData() return request.getfixturevalue(request.param) @@ -304,18 +350,38 @@ def _get_new_table(spatial_element: None | str | Sequence[str], instance_id: Non return TableModel.parse(adata=adata, spatial_element=spatial_element, instance_id=instance_id) -@pytest.fixture() -def labels_blobs() -> ArrayLike: - """Create a 2D labels.""" +@pytest.fixture(scope="session") +def _labels_blobs_session() -> ArrayLike: return BlobsDataset()._labels_blobs() @pytest.fixture() -def sdata_blobs() -> SpatialData: +def labels_blobs(_labels_blobs_session: ArrayLike) -> ArrayLike: """Create a 2D labels.""" + return deepcopy(_labels_blobs_session) + + +@pytest.fixture(scope="session") +def _sdata_blobs_session() -> SpatialData: from spatialdata.datasets import blobs - return deepcopy(blobs(256, 300, 3)) + return blobs(256, 300, 3) + + +@pytest.fixture() +def sdata_blobs(_sdata_blobs_session: SpatialData) -> SpatialData: + """Create a 2D labels.""" + return _fast_deepcopy_sdata(_sdata_blobs_session) + + +@pytest.fixture() +def blobs_factory(_sdata_blobs_session: SpatialData) -> Callable[[], SpatialData]: + """Return a factory that creates cheap fresh copies of the session-scoped blobs dataset.""" + + def _make() -> SpatialData: + return _fast_deepcopy_sdata(_sdata_blobs_session) + + return _make def _make_points(coordinates: np.ndarray) -> DaskDataFrame: diff --git a/tests/core/operations/conftest.py b/tests/core/operations/conftest.py new file mode 100644 index 000000000..c6241996b --- /dev/null +++ b/tests/core/operations/conftest.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +import os + +# Disable numba JIT compilation for rasterize tests. Datashader (used by rasterize) triggers +# numba JIT on first call, costing ~1.4s per worker. Python-mode gives identical results for +# the small test data here — unlike real data, there is no throughput advantage from JIT. +os.environ.setdefault("NUMBA_DISABLE_JIT", "1") diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 68b538e0a..a61ae34e2 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -13,7 +13,6 @@ from spatialdata._core.operations._utils import transform_to_data_extent from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike -from spatialdata.datasets import blobs from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel, get_table_keys from spatialdata.testing import assert_elements_dict_are_identical, assert_spatial_data_objects_are_identical from spatialdata.transformations.operations import get_transformation, set_transformation @@ -369,9 +368,9 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None: @pytest.mark.parametrize("concatenate_tables", [True, False]) @pytest.mark.parametrize("obs_names_make_unique", [True, False]) -def test_concatenate_sdatas_from_iterable(concatenate_tables: bool, obs_names_make_unique: bool) -> None: - sdata0 = blobs() - sdata1 = blobs() +def test_concatenate_sdatas_from_iterable(blobs_factory, concatenate_tables: bool, obs_names_make_unique: bool) -> None: + sdata0 = blobs_factory() + sdata1 = blobs_factory() sdatas = {"sample0": sdata0, "sample1": sdata1} with pytest.raises(KeyError, match="Images must have unique names across the SpatialData objects"): @@ -435,8 +434,8 @@ def _get_table_and_poly(i: int) -> tuple[AnnData, GeoDataFrame]: assert len(sdata["table"]) == 4 * len(poly0_a) -def test_concatenate_sdatas_single_item() -> None: - sdata = blobs() +def test_concatenate_sdatas_single_item(sdata_blobs) -> None: + sdata = sdata_blobs def _n_elements(sdata: SpatialData) -> int: return len([0 for _, _, _ in sdata.gen_elements()]) @@ -450,9 +449,9 @@ def _n_elements(sdata: SpatialData) -> int: @pytest.mark.parametrize("merge_coordinate_systems_on_name", [True, False]) -def test_concatenate_merge_coordinate_systems_on_name(merge_coordinate_systems_on_name): - blob1 = blobs() - blob2 = blobs() +def test_concatenate_merge_coordinate_systems_on_name(blobs_factory, merge_coordinate_systems_on_name): + blob1 = blobs_factory() + blob2 = blobs_factory() if merge_coordinate_systems_on_name: with pytest.raises( @@ -529,9 +528,9 @@ def test_del_item(full_sdata: SpatialData) -> None: _ = full_sdata["not_present"] -def test_no_shared_transformations() -> None: +def test_no_shared_transformations(sdata_blobs) -> None: """Test transformation dictionary copy for transformations not to be shared.""" - sdata = blobs() + sdata = sdata_blobs element_name = "blobs_image" test_space = "test" set_transformation(sdata.images[element_name], Identity(), to_coordinate_system=test_space) @@ -626,7 +625,7 @@ def test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning: sdata = transform_to_data_extent( full_sdata, "global", - target_width=1000, + target_width=100, maintain_positioning=maintain_positioning, ) diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index 59a380f7a..a4e7077f8 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -146,7 +146,7 @@ def test_transform_raster(full_sdata: SpatialData, element_type: str, multiscale assert element_type == "labels" sdata = SpatialData(labels={k: v for k, v in full_sdata.labels.items() if isinstance(v, datatype)}) - affine = _get_affine(small_translation=False) + affine = _get_affine(small_translation=True) _postpone_transformation( sdata, from_coordinate_system="global", to_coordinate_system="transformed", transformation=affine diff --git a/tests/core/operations/test_vectorize.py b/tests/core/operations/test_vectorize.py index ae83f6c95..503b34e7a 100644 --- a/tests/core/operations/test_vectorize.py +++ b/tests/core/operations/test_vectorize.py @@ -19,7 +19,7 @@ from spatialdata.testing import assert_elements_are_identical # each of the tests operates on different elements, hence we can initialize the data once without conflicts -sdata = blobs() +sdata = blobs(length=128) # conversion from labels @@ -29,9 +29,9 @@ def test_labels_2d_to_circles(is_multiscale: bool) -> None: element = sdata[key] new_circles = to_circles(element) - assert np.isclose(new_circles.loc[1].geometry.x, 330.59258152354386) - assert np.isclose(new_circles.loc[1].geometry.y, 78.85026897788404) - assert np.isclose(new_circles.loc[1].radius, 69.229993) + assert np.isclose(new_circles.loc[1].geometry.x, 66.33699870633895) + assert np.isclose(new_circles.loc[1].geometry.y, 94.86610608020699) + assert np.isclose(new_circles.loc[1].radius, 15.686094) assert 7 not in new_circles.index @@ -83,10 +83,10 @@ def test_polygons_to_circles() -> None: data = { "geometry": [ - Point(315.8120722406787, 220.18894606643332), - Point(270.1386975678398, 417.8747936281634), + Point(78.95301806016967, 55.04723651660833), + Point(67.53467439195995, 104.46869840704085), ], - "radius": [16.608781, 17.541365], + "radius": [4.152195, 4.385341], } expected = ShapesModel.parse(GeoDataFrame(data, geometry="geometry")) @@ -99,10 +99,10 @@ def test_multipolygons_to_circles() -> None: data = { "geometry": [ - Point(340.37951022629096, 250.76310705786318), - Point(337.1680699150594, 316.39984581697314), + Point(85.09487755657274, 62.690776764465795), + Point(84.23037752020095, 79.09996145424327), ], - "radius": [23.488363, 19.059285], + "radius": [5.872091, 4.736710], } expected = ShapesModel.parse(GeoDataFrame(data, geometry="geometry")) assert_elements_are_identical(new_circles, expected) diff --git a/tests/core/test_get_attrs.py b/tests/core/test_get_attrs.py index a6b467fa9..b36444a33 100644 --- a/tests/core/test_get_attrs.py +++ b/tests/core/test_get_attrs.py @@ -3,14 +3,11 @@ import pandas as pd import pytest -from spatialdata.datasets import blobs - @pytest.fixture -def sdata_attrs(): - sdata = blobs() - sdata.attrs["test"] = {"a": {"b": 12}, "c": 8} - return sdata +def sdata_attrs(sdata_blobs): + sdata_blobs.attrs["test"] = {"a": {"b": 12}, "c": 8} + return sdata_blobs def test_get_attrs_as_is(sdata_attrs): @@ -70,7 +67,6 @@ def test_non_string_sep(sdata_attrs): sdata_attrs.get_attrs(key="test", sep=123) -def test_empty_attrs(): - sdata = blobs() +def test_empty_attrs(sdata_blobs): with pytest.raises(KeyError, match="was not found in sdata.attrs."): - sdata.get_attrs(key="test") + sdata_blobs.get_attrs(key="test") diff --git a/tests/dataloader/conftest.py b/tests/dataloader/conftest.py new file mode 100644 index 000000000..0df842af8 --- /dev/null +++ b/tests/dataloader/conftest.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +import os + +# Disable numba JIT compilation for dataloader tests. Datashader (used by rasterize) triggers +# numba JIT on first call, costing ~1.4s per worker. Python-mode gives identical results for +# the small test data here — unlike real data, there is no throughput advantage from JIT. +os.environ.setdefault("NUMBA_DISABLE_JIT", "1") diff --git a/tests/io/test_format.py b/tests/io/test_format.py index 3ef1319c3..906f0f1cc 100644 --- a/tests/io/test_format.py +++ b/tests/io/test_format.py @@ -237,6 +237,7 @@ def test_container_v1_to_v2(self, full_sdata): assert sdata_read_v2.has_consolidated_metadata() def test_channel_names_raster_images_v1_to_v2_to_v3(self, images): + images = images.subset(["image2d", "image2d_multiscale"]) with tempfile.TemporaryDirectory() as tmpdir: f1 = Path(tmpdir) / "data1.zarr" f2 = Path(tmpdir) / "data2.zarr" diff --git a/tests/io/test_metadata.py b/tests/io/test_metadata.py index dd0ae704f..c07f026c0 100644 --- a/tests/io/test_metadata.py +++ b/tests/io/test_metadata.py @@ -33,6 +33,7 @@ def test_save_transformations(full_sdata): @pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles"]) def test_validate_can_write_metadata_on_element(full_sdata, element_name): + full_sdata = full_sdata.subset([element_name]) with tempfile.TemporaryDirectory() as tmp_dir: # element not present in the SpatialData object with pytest.raises( @@ -60,6 +61,9 @@ def test_validate_can_write_metadata_on_element(full_sdata, element_name): @pytest.mark.parametrize("element_name", ["image2d", "labels2d", "points_0", "circles"]) def test_save_transformations_incremental(element_name, full_sdata, caplog): """test io for transformations""" + # image2d is the dask-backed anchor for `assert not sdata2.is_self_contained()`: circles + # has no dask backing files and would make that assertion pass trivially. + full_sdata = full_sdata.subset([element_name, "image2d"]) with tempfile.TemporaryDirectory() as tmp_dir: f0 = os.path.join(tmp_dir, f"{element_name}0.zarr") full_sdata.write(f0) @@ -148,6 +152,7 @@ def test_save_channel_names_incremental(images: SpatialData, write: str) -> None # test io for consolidated metadata def test_consolidated_metadata(full_sdata: SpatialData) -> None: + full_sdata = full_sdata.subset(["labels3d_multiscale_numpy", "image2d", "circles", "points_0", "table"]) with tempfile.TemporaryDirectory() as tmp_dir: f0 = os.path.join(tmp_dir, "data0.zarr") full_sdata.write(f0) diff --git a/tests/io/test_partial_read.py b/tests/io/test_partial_read.py index 28460046e..93003390b 100644 --- a/tests/io/test_partial_read.py +++ b/tests/io/test_partial_read.py @@ -51,7 +51,7 @@ def pytest_warns_multiple( yield -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def test_case(request: _pytest.fixtures.SubRequest): """ Fixture that helps to use fixtures as arguments in parametrize. @@ -84,7 +84,7 @@ def session_tmp_path(request: _pytest.fixtures.SubRequest) -> Path: return Path(directory) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_corrupted_elem_types_zgroup(session_tmp_path: Path) -> PartialReadTestCase: # Zarr v2 sdata = blobs() @@ -106,7 +106,7 @@ def sdata_with_corrupted_elem_types_zgroup(session_tmp_path: Path) -> PartialRea ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_corrupted_elem_types_zarr_json(session_tmp_path: Path) -> PartialReadTestCase: # Zarr v3 sdata = blobs() @@ -128,7 +128,7 @@ def sdata_with_corrupted_elem_types_zarr_json(session_tmp_path: Path) -> Partial ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_corrupted_zarr_json_elements(session_tmp_path: Path) -> PartialReadTestCase: # Zarr v3 # zarr.json is a zero-byte file, aborted during write, or contains invalid JSON syntax @@ -152,7 +152,7 @@ def sdata_with_corrupted_zarr_json_elements(session_tmp_path: Path) -> PartialRe ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_corrupted_zattrs_elements(session_tmp_path: Path) -> PartialReadTestCase: # Zarr v2 # .zattrs is a zero-byte file, aborted during write, or contains invalid JSON syntax @@ -176,7 +176,7 @@ def sdata_with_corrupted_zattrs_elements(session_tmp_path: Path) -> PartialReadT ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_corrupted_image_chunks_zarrv3(session_tmp_path: Path) -> PartialReadTestCase: # images/blobs_image/0 is a zero-byte file or aborted during write sdata = blobs() @@ -198,7 +198,7 @@ def sdata_with_corrupted_image_chunks_zarrv3(session_tmp_path: Path) -> PartialR ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_corrupted_image_chunks_zarrv2(session_tmp_path: Path) -> PartialReadTestCase: # images/blobs_image/0 is a zero-byte file or aborted during write sdata = blobs() @@ -219,7 +219,7 @@ def sdata_with_corrupted_image_chunks_zarrv2(session_tmp_path: Path) -> PartialR ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_corrupted_parquet_zarrv3(session_tmp_path: Path) -> PartialReadTestCase: # points/blobs_points/0 is a zero-byte file or aborted during write sdata = blobs() @@ -243,7 +243,7 @@ def sdata_with_corrupted_parquet_zarrv3(session_tmp_path: Path) -> PartialReadTe ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_corrupted_parquet_zarrv2(session_tmp_path: Path) -> PartialReadTestCase: # points/blobs_points/0 is a zero-byte file or aborted during write sdata = blobs() @@ -267,7 +267,7 @@ def sdata_with_corrupted_parquet_zarrv2(session_tmp_path: Path) -> PartialReadTe ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_missing_zarr_json_element(session_tmp_path: Path) -> PartialReadTestCase: # zarr.json is missing sdata = blobs() @@ -286,7 +286,7 @@ def sdata_with_missing_zarr_json_element(session_tmp_path: Path) -> PartialReadT ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_missing_zattrs_element(session_tmp_path: Path) -> PartialReadTestCase: # Zarrv2 # .zattrs is missing @@ -306,7 +306,7 @@ def sdata_with_missing_zattrs_element(session_tmp_path: Path) -> PartialReadTest ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_missing_image_chunks_zarrv3( session_tmp_path: Path, ) -> PartialReadTestCase: @@ -328,7 +328,7 @@ def sdata_with_missing_image_chunks_zarrv3( ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_missing_image_chunks_zarrv2( session_tmp_path: Path, ) -> PartialReadTestCase: @@ -352,7 +352,7 @@ def sdata_with_missing_image_chunks_zarrv2( ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_invalid_zattrs_element_violating_spec(session_tmp_path: Path) -> PartialReadTestCase: # Zarr v2 # .zattrs contains readable JSON which is not valid for SpatialData/NGFF specs @@ -375,7 +375,7 @@ def sdata_with_invalid_zattrs_element_violating_spec(session_tmp_path: Path) -> ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_invalid_zarr_json_element_violating_spec(session_tmp_path: Path) -> PartialReadTestCase: # zarr.json contains readable JSON which is not valid for SpatialData/NGFF specs # for example due to a missing/misspelled/renamed key @@ -427,12 +427,12 @@ def _create_sdata_with_table_region_not_found(session_tmp_path: Path, zarr_versi ) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_table_region_not_found_zarrv3(session_tmp_path: Path) -> PartialReadTestCase: return _create_sdata_with_table_region_not_found(session_tmp_path, zarr_version=3) -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def sdata_with_table_region_not_found_zarrv2(session_tmp_path: Path) -> PartialReadTestCase: return _create_sdata_with_table_region_not_found(session_tmp_path, zarr_version=2) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 209a43046..df1cbbb63 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -329,11 +329,7 @@ def test_incremental_io_on_disk( In particular the complex "dask-backed" case for workaround 1 could be simplified once """ - tmpdir = Path(tmp_path) / "incremental_io.zarr" - sdata = SpatialData() - sdata.write(tmpdir, sdata_formats=sdata_container_format) - - for name in [ + _elements = [ "image2d", "image3d_multiscale_xarray", "labels2d", @@ -341,7 +337,15 @@ def test_incremental_io_on_disk( "points_0", "multipoly", "table", - ]: + ] + # Reduce to only the elements under test so the fixture deepcopy stays small. + full_sdata = full_sdata.subset(_elements) + + tmpdir = Path(tmp_path) / "incremental_io.zarr" + sdata = SpatialData() + sdata.write(tmpdir, sdata_formats=sdata_container_format) + + for name in _elements: sdata[name] = full_sdata[name] sdata.write_element(name, sdata_formats=sdata_container_format) if dask_backed: @@ -514,25 +518,25 @@ def test_overwrite_fails_when_no_zarr_store_but_dask_backed_data( ): full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) - def test_overwrite_fails_when_zarr_store_present( - self, full_sdata, sdata_container_format: SpatialDataContainerFormatType - ): + def test_overwrite_fails_when_zarr_store_present(self, sdata_container_format: SpatialDataContainerFormatType): # addressing https://github.com/scverse/spatialdata/issues/137 with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") - full_sdata.write(f, sdata_formats=sdata_container_format) + # An empty store is enough to trigger both exceptions. + sdata = SpatialData() + sdata.write(f, sdata_formats=sdata_container_format) with pytest.raises( ValueError, match="The Zarr store already exists. Use `overwrite=True` to try overwriting the store.", ): - full_sdata.write(f, sdata_formats=sdata_container_format) + sdata.write(f, sdata_formats=sdata_container_format) with pytest.raises( ValueError, match=r"Details: the target path either contains, coincides or is contained in the current Zarr store", ): - full_sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) + sdata.write(f, overwrite=True, sdata_formats=sdata_container_format) # support for overwriting backed sdata has been temporarily removed # with tempfile.TemporaryDirectory() as tmpdir: @@ -745,6 +749,10 @@ def test_single_scale_image_roundtrip_stays_dataarray(tmp_path: Path) -> None: @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: + # image2d/labels2d/points_0 are used explicitly in the combined-element block below; + # circles covers the "else: assert self_contained" branch (non-dask-backed elements). + full_sdata = full_sdata.subset(["image2d", "labels2d", "points_0", "circles"]) + # data only in-memory, so the SpatialData object and all its elements are self-contained assert full_sdata.is_self_contained() description = full_sdata.elements_are_self_contained() @@ -849,6 +857,8 @@ def test_symmetric_difference_with_zarr_store( @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_change_path_of_subset(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: """A subset SpatialData object has not Zarr path associated, show that we can reassign the path""" + # points_0_1 is the extra element that stays only on disk, satisfying the only_on_disk > 0 assertion. + full_sdata = full_sdata.subset(["image2d", "labels2d", "points_0", "circles", "table", "points_0_1"]) with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") full_sdata.write(f, sdata_formats=sdata_container_format) @@ -952,6 +962,9 @@ def test_delete_element_from_disk( element_name: str, sdata_container_format: SpatialDataContainerFormatType, ) -> None: + # Reduce to only the element under test plus one extra to keep writes fast. + full_sdata = full_sdata.subset([element_name, "points_0_1"]) + # can't delete an element for a SpatialData object without associated Zarr store with pytest.raises(ValueError, match="The SpatialData object is not backed by a Zarr store."): full_sdata.delete_element_from_disk("image2d") @@ -1007,6 +1020,8 @@ def test_element_already_on_disk_different_type( # Attempting to perform and IO operation will trigger an error. # The checks assessed in this test will not be needed anymore after # https://github.com/scverse/spatialdata/issues/504 is addressed + # Only the single element under test needs to be on disk to create the type-mismatch state. + full_sdata = full_sdata.subset([element_name]) with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") full_sdata.write(f, sdata_formats=sdata_container_format) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 22e63cc1a..887466e09 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -924,9 +924,10 @@ def test_warning_on_large_chunks(): def test_categories_on_partitioned_dataframe(sdata_blobs: SpatialData): + rng = default_rng(seed=0) df = sdata_blobs["blobs_points"].compute() - df["genes"] = RNG.choice([f"gene_{i}" for i in range(200)], len(df)) - N_PARTITIONS = 200 + df["genes"] = rng.choice([f"gene_{i}" for i in range(10)], len(df)) + N_PARTITIONS = 10 ddf = dd.from_pandas(df, npartitions=N_PARTITIONS) ddf["genes"] = ddf["genes"].astype("category")