From ccb9a4ffd0c4ec71f75aea042fc6a81a4de7ea70 Mon Sep 17 00:00:00 2001 From: Andreas Eisenbarth Date: Sat, 5 Oct 2024 14:58:25 +0200 Subject: [PATCH 1/3] Add performance test for issue 577 --- pyproject.toml | 1 + tests/io/test_pyramids_performance.py | 87 +++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 tests/io/test_pyramids_performance.py diff --git a/pyproject.toml b/pyproject.toml index 12ae2a3fb..3a075286c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ docs = [ test = [ "pytest", "pytest-cov", + "pytest-mock", ] torch = [ "torch" diff --git a/tests/io/test_pyramids_performance.py b/tests/io/test_pyramids_performance.py new file mode 100644 index 000000000..0f6c21ec0 --- /dev/null +++ b/tests/io/test_pyramids_performance.py @@ -0,0 +1,87 @@ +from pathlib import Path +from typing import TYPE_CHECKING, Union + +import dask +import dask.array +import numpy as np +import pytest +import xarray as xr +import zarr +from datatree import DataTree + +from spatialdata import SpatialData +from spatialdata._io import write_image +from spatialdata._io.format import CurrentRasterFormat +from spatialdata.models import Image2DModel + +if TYPE_CHECKING: + import _pytest.fixtures + + +@pytest.fixture +def sdata_with_image(request: "_pytest.fixtures.SubRequest", tmp_path: Path) -> SpatialData: + params = request.param if request.param is not None else {} + width = params.get("width", 2048) + chunksize = params.get("chunk_size", 1024) + scale_factors = params.get("scale_factors", (2,)) + # Create a disk-backed Dask array for scale 0. + npg = np.random.default_rng(0) + array = npg.integers(low=0, high=2**16, size=(1, width, width)) + array_path = tmp_path / "image.zarr" + dask.array.from_array(array).rechunk(chunksize).to_zarr(array_path) + array_backed = dask.array.from_zarr(array_path) + # Create an in-memory SpatialData with disk-backed scale 0. + image = Image2DModel.parse(array_backed, dims=("c", "y", "x"), scale_factors=scale_factors, chunks=chunksize) + return SpatialData(images={"image": image}) + + +def count_chunks(array: Union[xr.DataArray, xr.Dataset, DataTree]) -> int: + if isinstance(array, DataTree): + array = array.ds + # From `chunksizes`, we get only the number of chunks per axis. + # By multiplying them, we get the total number of chunks in 2D/3D. + return np.prod([len(chunk_sizes) for chunk_sizes in array.chunksizes.values()]) + + +@pytest.mark.parametrize( + ("sdata_with_image",), + [ + ({"width": 32, "chunk_size": 16, "scale_factors": (2,)},), + ({"width": 64, "chunk_size": 16, "scale_factors": (2, 2)},), + ({"width": 128, "chunk_size": 16, "scale_factors": (2, 2, 2)},), + ({"width": 256, "chunk_size": 16, "scale_factors": (2, 2, 2, 2)},), + ], + indirect=["sdata_with_image"], +) +def test_write_image_multiscale_performance(sdata_with_image: SpatialData, tmp_path: Path, mocker): + # Writing multiscale images with several pyramid levels should be efficient. + # Specifically, it should not read the input image more often than necessary + # (see issue https://github.com/scverse/spatialdata/issues/577). + # Instead of measuring the time (which would have high variation if not using big datasets), + # we watch the number of read and write accesses and compare to the theoretical number. + zarr_chunk_write_spy = mocker.spy(zarr.core.Array, "__setitem__") + zarr_chunk_read_spy = mocker.spy(zarr.core.Array, "__getitem__") + + image_name, image = next(iter(sdata_with_image.images.items())) + element_type_group = zarr.group(store=tmp_path / "sdata.zarr", path="/images") + + write_image( + image=image, + group=element_type_group, + name=image_name, + format=CurrentRasterFormat(), + ) + + # The number of chunks of scale level 0 + num_chunks_scale0 = count_chunks(image.scale0 if isinstance(image, DataTree) else image) + # The total number of chunks of all scale levels + num_chunks_all_scales = ( + sum(count_chunks(pyramid) for pyramid in image.children.values()) + if isinstance(image, DataTree) + else count_chunks(image) + ) + + actual_num_chunk_writes = zarr_chunk_write_spy.call_count + actual_num_chunk_reads = zarr_chunk_read_spy.call_count + assert actual_num_chunk_writes == num_chunks_all_scales + assert actual_num_chunk_reads == num_chunks_scale0 From 1be0159aee3af67cb311c0b1415e48c025895e75 Mon Sep 17 00:00:00 2001 From: Andreas Eisenbarth Date: Sat, 5 Oct 2024 14:30:28 +0200 Subject: [PATCH 2/3] Compute all pyramid levels at once --- src/spatialdata/_io/io_raster.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 6dda410c6..3e3c8bc87 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import Any, Literal, Optional, Union +import dask.array as da import numpy as np import zarr from datatree import DataTree @@ -195,7 +196,7 @@ def _get_group_for_writing_transformations() -> zarr.Group: # coords = iterate_pyramid_levels(raster_data, "coords") parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format) storage_options = [{"chunks": chunk} for chunk in chunks] - write_multi_scale_ngff( + dask_delayed = write_multi_scale_ngff( pyramid=data, group=group_data, fmt=format, @@ -203,7 +204,10 @@ def _get_group_for_writing_transformations() -> zarr.Group: coordinate_transformations=None, storage_options=storage_options, **metadata, + compute=False, ) + # Compute all pyramid levels at once to allow Dask to optimize the computational graph. + da.compute(*dask_delayed) assert transformations is not None overwrite_coordinate_transformations_raster( group=_get_group_for_writing_transformations(), transformations=transformations, axes=tuple(input_axes) From f167e54e35fe2271fc2915acc06535fd3adf77f5 Mon Sep 17 00:00:00 2001 From: Andreas Eisenbarth Date: Sat, 5 Oct 2024 15:19:21 +0200 Subject: [PATCH 3/3] Add change log entry --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 26d8a5141..771cc0d29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning][]. - Added `shortest_path` parameter to `get_transformation_between_coordinate_systems` - Added `get_pyramid_levels()` utils API - Improved ergonomics of `concatenate()` when element names are non-unique #720 +- Improved performance of writing images with multiscales #577 ## [0.2.3] - 2024-09-25