
Optimize writing performance for MultiscaleSpatialImage #577

Open
LucaMarconato opened this issue Jun 11, 2024 · 1 comment

@LucaMarconato
Member

As observed by @ArneDefauw, unnecessary loading operations are performed when write_multiscale() is called on the list of lazy tensors derived from to_multiscale().

Optimizing the order in which the data is computed and written to disk, so as to avoid loading the same chunks more than once, would probably lead to a drastic performance improvement, potentially up to 10-fold.
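A minimal sketch of one possible approach, assuming plain dask + zarr (write_pyramid and the scale0/scale1/... Zarr layout are hypothetical, and mean pooling is a stand-in reduction, not necessarily what spatialdata uses): write the full-resolution level to disk first, then derive each coarser level from the previously written one, so the upstream lazy graph runs exactly once per chunk.

import dask.array as da
import numpy as np

def write_pyramid(arr, store_path, scale_factors=(2, 2, 2, 2)):
    # Write scale 0: the only place where the (possibly expensive) lazy
    # graph of `arr` is computed.
    da.to_zarr(arr, store_path, component="scale0", overwrite=True)
    previous = da.from_zarr(store_path, component="scale0")
    for i, factor in enumerate(scale_factors, start=1):
        # Downscale from the previously *written* level (assumes (c, y, x)
        # axes; mean pooling over y and x).
        down = da.coarsen(np.mean, previous, {1: factor, 2: factor}, trim_excess=True)
        da.to_zarr(down, store_path, component=f"scale{i}", overwrite=True)
        previous = da.from_zarr(store_path, component=f"scale{i}")

With this ordering, each chunk of the source array is computed once (for scale 0), and every further level only reads the already materialized previous level from disk.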

@ArneDefauw
Contributor

I include a minimal example to reproduce the observed behaviour.

If arr.persist() is called, the code completes in ~10 s; if the arr.persist() line is commented out, it completes in ~50 s (i.e. 10 s per scale: _some_function is called 5 times, once per scale level). As an alternative to .persist(), writing the array to a Zarr store and then loading it back evidently 'solves' the problem in a similar way.


import os
import tempfile
import time

import dask.array as da
import numpy as np
import spatialdata
from spatialdata.datasets import blobs

sdata = blobs()

start = time.time()

with tempfile.TemporaryDirectory() as temp_dir:
    sdata.write(os.path.join(temp_dir, "sdata_blobs_dummy.zarr"))

    def _some_function(arr):
        # Stand-in for an expensive per-chunk computation.
        arr = arr * 2
        time.sleep(10)
        return arr

    arr = sdata["blobs_image"].data

    arr = da.map_blocks(_some_function, arr, dtype=float, meta=np.array((), dtype=float))

    # Materialize the mapped array once; without this line, every scale
    # level produced during writing re-executes _some_function on the
    # same chunks.
    arr = arr.persist()

    # Alternative to persist: write to an intermediate Zarr store and load it back.
    # dask_zarr_path = os.path.join(temp_dir, "dask_array.zarr")
    # arr.to_zarr(dask_zarr_path, overwrite=True)
    # arr = da.from_zarr(dask_zarr_path)

    se = spatialdata.models.Image2DModel.parse(
        arr,
        scale_factors=[2, 2, 2, 2],
    )

    sdata["blobs_image_processed"] = se

    sdata.write_element("blobs_image_processed")


print(time.time() - start)
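A small standalone check of the recomputation, as a sketch assuming plain dask with no spatialdata involved (the array shape, chunking, and the five slicing-based "scales" are arbitrary):

import itertools

import dask.array as da

counter = itertools.count()

def _expensive(block):
    next(counter)  # record one execution per chunk (atomic enough for a demo)
    return block * 2

src = da.map_blocks(_expensive, da.ones((4, 64, 64), chunks=(4, 32, 32)), dtype=float)

# Compute five coarser views independently, as a pyramid writer might:
# each .compute() re-runs _expensive on all 4 chunks.
for factor in (1, 2, 4, 8, 16):
    src[:, ::factor, ::factor].compute()
print("executions without persist:", next(counter))  # 20 = 5 scales x 4 chunks

counter = itertools.count()
src = da.map_blocks(_expensive, da.ones((4, 64, 64), chunks=(4, 32, 32)), dtype=float).persist()
for factor in (1, 2, 4, 8, 16):
    src[:, ::factor, ::factor].compute()
print("executions with persist:", next(counter))  # 4 = once per chunk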
