From 910a7d4fcf56ff8910d3a79ea7710c96f789cc49 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 16 Mar 2026 14:59:53 +0000 Subject: [PATCH 01/11] perf: use pinned memory with `preload_to_gpu` --- src/annbatch/loader.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 973a47d9..614f0db4 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -584,9 +584,14 @@ async def _fetch_data_dense(self, dataset: ZarrArray, slices: list[slice]) -> np for s in slices ] ) + buffer_prototype = zarr.core.buffer.default_buffer_prototype() + kwargs = dict(prototype=buffer_prototype) + if self._preload_to_gpu: + import cupyx as cpx + kwargs["out"] = buffer_prototype.nd_buffer(cpx.empty_pinned(indexer.shape, dataset.dtype)) res = cast( "np.ndarray", - await dataset._async_array._get_selection(indexer, prototype=zarr.core.buffer.default_buffer_prototype()), + await dataset._async_array._get_selection(indexer, **kwargs), ) return res @@ -661,9 +666,17 @@ async def _fetch_data_sparse( for l in indptr_limits ] ) + def get_kwargs(z: zarr.Array) -> dict: + kwargs = dict(prototype=zarr.core.buffer.default_buffer_prototype()) + if self._preload_to_gpu: + import cupyx as cpx + kwargs["out"] = buffer_prototype.nd_buffer(cpx.empty_pinned(indexer.shape, z.dtype)) + return kwargs data_np, indices_np = await asyncio.gather( - data._get_selection(indexer, prototype=zarr.core.buffer.default_buffer_prototype()), - indices._get_selection(indexer, prototype=zarr.core.buffer.default_buffer_prototype()), + **( + z._get_selection(indexer, **get_kwargs()) + for z in [data, indices] + ) ) gaps = (s1.start - s0.stop for s0, s1 in pairwise(indptr_limits)) offsets = accumulate(chain([indptr_limits[0].start], gaps)) From ec43fa755a7723337c865afba6c3bcf12b73af7c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Mar 2026 15:02:58 +0000 Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/annbatch/loader.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 614f0db4..b2f52583 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -585,9 +585,10 @@ async def _fetch_data_dense(self, dataset: ZarrArray, slices: list[slice]) -> np ] ) buffer_prototype = zarr.core.buffer.default_buffer_prototype() - kwargs = dict(prototype=buffer_prototype) + kwargs = {"prototype": buffer_prototype} if self._preload_to_gpu: import cupyx as cpx + kwargs["out"] = buffer_prototype.nd_buffer(cpx.empty_pinned(indexer.shape, dataset.dtype)) res = cast( "np.ndarray", @@ -666,17 +667,17 @@ async def _fetch_data_sparse( for l in indptr_limits ] ) + def get_kwargs(z: zarr.Array) -> dict: - kwargs = dict(prototype=zarr.core.buffer.default_buffer_prototype()) + kwargs = {"prototype": zarr.core.buffer.default_buffer_prototype()} if self._preload_to_gpu: import cupyx as cpx + kwargs["out"] = buffer_prototype.nd_buffer(cpx.empty_pinned(indexer.shape, z.dtype)) return kwargs + data_np, indices_np = await asyncio.gather( - **( - z._get_selection(indexer, **get_kwargs()) - for z in [data, indices] - ) + **(z._get_selection(indexer, **get_kwargs()) for z in [data, indices]) ) gaps = (s1.start - s0.stop for s0, s1 in pairwise(indptr_limits)) offsets = accumulate(chain([indptr_limits[0].start], gaps)) From e113e4c78e4ce6caf52febb3a0f026e7c97830db Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 16 Mar 2026 15:16:30 +0000 Subject: [PATCH 03/11] fix: sparse --- src/annbatch/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index b2f52583..cfa8428a 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -677,7 +677,7 @@ def get_kwargs(z: zarr.Array) -> dict: return kwargs data_np, indices_np = await asyncio.gather( - **(z._get_selection(indexer, **get_kwargs()) for z in [data, indices]) + *(z._get_selection(indexer, **get_kwargs(z)) for z in [data, indices]) ) gaps = (s1.start - s0.stop for s0, s1 in pairwise(indptr_limits)) offsets = accumulate(chain([indptr_limits[0].start], gaps)) From e99c4b856334c425025d7943696c27e93d998ebc Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 16 Mar 2026 15:17:37 +0000 Subject: [PATCH 04/11] fix: what? --- src/annbatch/loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index cfa8428a..201cc212 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -669,7 +669,8 @@ async def _fetch_data_sparse( ) def get_kwargs(z: zarr.Array) -> dict: - kwargs = {"prototype": zarr.core.buffer.default_buffer_prototype()} + buffer_prototype = zarr.core.buffer.default_buffer_prototype() + kwargs = {"prototype": buffer_prototype} if self._preload_to_gpu: import cupyx as cpx From 888de84b7060f99b5f6ccdd959e575b061175d8d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 16 Mar 2026 15:20:36 +0000 Subject: [PATCH 05/11] fix: pandas docs? --- docs/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index be3b1bb6..9218e629 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -100,7 +100,7 @@ "scipy": ("https://docs.scipy.org/doc/scipy", None), "cupy": ("https://docs.cupy.dev/en/stable/", None), "zarrs": ("https://zarrs-python.readthedocs.io/en/latest/", None), - "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/version/2.3", None), "h5py": ("https://docs.h5py.org/en/latest", None), } From ca39cca30c8362a72d301b694025d784f25873cd Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Tue, 17 Mar 2026 14:33:05 +0100 Subject: [PATCH 06/11] refactor: kwarg function --- src/annbatch/loader.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index a56de068..d269c80d 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -60,7 +60,6 @@ def _cupy_dtype(dtype: np.dtype) -> np.dtype: return np.dtype("float32") return np.dtype("float64") - class Loader[ BackingArray: BackingArray_T, InputInMemoryArray: InputInMemoryArray_T, @@ -554,6 +553,16 @@ def _slices_to_slices_with_array_index( dataset_index_to_slices_sorted[k] = dataset_index_to_slices[k] return dataset_index_to_slices_sorted + + def _get_kwargs_for_zarr_fetching(z: zarr.Array) -> dict: + buffer_prototype = zarr.core.buffer.default_buffer_prototype() + kwargs = {"prototype": buffer_prototype} + if self._preload_to_gpu: + import cupyx as cpx + + kwargs["out"] = buffer_prototype.nd_buffer(cpx.empty_pinned(indexer.shape, z.dtype)) + return kwargs + @singledispatchmethod async def _fetch_data(self, dataset: ZarrArray | CSRDatasetElems, slices: list[slice]) -> InputInMemoryArray: """Fetch data from an on-disk store. @@ -588,15 +597,9 @@ async def _fetch_data_dense(self, dataset: ZarrArray, slices: list[slice]) -> np for s in slices ] ) - buffer_prototype = zarr.core.buffer.default_buffer_prototype() - kwargs = {"prototype": buffer_prototype} - if self._preload_to_gpu: - import cupyx as cpx - - kwargs["out"] = buffer_prototype.nd_buffer(cpx.empty_pinned(indexer.shape, dataset.dtype)) res = cast( "np.ndarray", - await dataset._async_array._get_selection(indexer, **kwargs), + await dataset._async_array._get_selection(indexer, **self._get_kwargs_for_zarr_fetching(dataset)), ) return res @@ -672,17 +675,8 @@ async def _fetch_data_sparse( ] ) - def get_kwargs(z: zarr.Array) -> dict: - buffer_prototype = zarr.core.buffer.default_buffer_prototype() - kwargs = {"prototype": buffer_prototype} - if self._preload_to_gpu: - import cupyx as cpx - - kwargs["out"] = buffer_prototype.nd_buffer(cpx.empty_pinned(indexer.shape, z.dtype)) - return kwargs - data_np, indices_np = await asyncio.gather( - *(z._get_selection(indexer, **get_kwargs(z)) for z in [data, indices]) + *(z._get_selection(indexer, **self._get_kwargs_for_zarr_fetching(z)) for z in [data, indices]) ) gaps = (s1.start - s0.stop for s0, s1 in pairwise(indptr_limits)) offsets = accumulate(chain([indptr_limits[0].start], gaps)) From 3c4bb8556e33f6c5ca47ed146a78603c6595aa42 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:33:15 +0000 Subject: [PATCH 07/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/annbatch/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index d269c80d..e9add3c3 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -60,6 +60,7 @@ def _cupy_dtype(dtype: np.dtype) -> np.dtype: return np.dtype("float32") return np.dtype("float64") + class Loader[ BackingArray: BackingArray_T, InputInMemoryArray: InputInMemoryArray_T, @@ -553,7 +554,6 @@ def _slices_to_slices_with_array_index( dataset_index_to_slices_sorted[k] = dataset_index_to_slices[k] return dataset_index_to_slices_sorted - def _get_kwargs_for_zarr_fetching(z: zarr.Array) -> dict: buffer_prototype = zarr.core.buffer.default_buffer_prototype() kwargs = {"prototype": buffer_prototype} From 8f1309fc4937ff56959dea608646974707d67619 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Tue, 17 Mar 2026 14:34:00 +0100 Subject: [PATCH 08/11] fix: self --- src/annbatch/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index e9add3c3..cb0cdd26 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -554,7 +554,7 @@ def _slices_to_slices_with_array_index( dataset_index_to_slices_sorted[k] = dataset_index_to_slices[k] return dataset_index_to_slices_sorted - def _get_kwargs_for_zarr_fetching(z: zarr.Array) -> dict: + def _get_kwargs_for_zarr_fetching(self, z: zarr.Array) -> dict: buffer_prototype = zarr.core.buffer.default_buffer_prototype() kwargs = {"prototype": buffer_prototype} if self._preload_to_gpu: From 3b9e64e6b04b927bd4fa7bc213296900672c3b7e Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Tue, 17 Mar 2026 14:35:08 +0100 Subject: [PATCH 09/11] fix: indexer shape --- src/annbatch/loader.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index cb0cdd26..bd067952 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -554,13 +554,13 @@ def _slices_to_slices_with_array_index( dataset_index_to_slices_sorted[k] = dataset_index_to_slices[k] return dataset_index_to_slices_sorted - def _get_kwargs_for_zarr_fetching(self, z: zarr.Array) -> dict: + def _get_kwargs_for_zarr_fetching(self, z: zarr.Array, indexer_shape: tuple[int]) -> dict: buffer_prototype = zarr.core.buffer.default_buffer_prototype() kwargs = {"prototype": buffer_prototype} if self._preload_to_gpu: import cupyx as cpx - kwargs["out"] = buffer_prototype.nd_buffer(cpx.empty_pinned(indexer.shape, z.dtype)) + kwargs["out"] = buffer_prototype.nd_buffer(cpx.empty_pinned(indexer_shape, z.dtype)) return kwargs @singledispatchmethod @@ -599,7 +599,7 @@ async def _fetch_data_dense(self, dataset: ZarrArray, slices: list[slice]) -> np ) res = cast( "np.ndarray", - await dataset._async_array._get_selection(indexer, **self._get_kwargs_for_zarr_fetching(dataset)), + await dataset._async_array._get_selection(indexer, **self._get_kwargs_for_zarr_fetching(dataset, indexer.shape)), ) return res @@ -676,7 +676,7 @@ async def _fetch_data_sparse( ) data_np, indices_np = await asyncio.gather( - *(z._get_selection(indexer, **self._get_kwargs_for_zarr_fetching(z)) for z in [data, indices]) + *(z._get_selection(indexer, **self._get_kwargs_for_zarr_fetching(z, indexer.shape)) for z in [data, indices]) ) gaps = (s1.start - s0.stop for s0, s1 in pairwise(indptr_limits)) offsets = accumulate(chain([indptr_limits[0].start], gaps)) From 51aab410f47fab15b1129e9476f11bd438799d2e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:35:15 +0000 Subject: [PATCH 10/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/annbatch/loader.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index bd067952..20656718 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -599,7 +599,9 @@ async def _fetch_data_dense(self, dataset: ZarrArray, slices: list[slice]) -> np ) res = cast( "np.ndarray", - await dataset._async_array._get_selection(indexer, **self._get_kwargs_for_zarr_fetching(dataset, indexer.shape)), + await dataset._async_array._get_selection( + indexer, **self._get_kwargs_for_zarr_fetching(dataset, indexer.shape) + ), ) return res @@ -676,7 +678,10 @@ async def _fetch_data_sparse( ) data_np, indices_np = await asyncio.gather( - *(z._get_selection(indexer, **self._get_kwargs_for_zarr_fetching(z, indexer.shape)) for z in [data, indices]) + *( + z._get_selection(indexer, **self._get_kwargs_for_zarr_fetching(z, indexer.shape)) + for z in [data, indices] + ) ) gaps = (s1.start - s0.stop for s0, s1 in pairwise(indptr_limits)) offsets = accumulate(chain([indptr_limits[0].start], gaps)) From ad2a72198b1ada67632bb2e380cb1b8b2e4cf1f9 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Tue, 17 Mar 2026 14:36:30 +0100 Subject: [PATCH 11/11] fix: typing --- src/annbatch/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 20656718..4ed3abbe 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -554,7 +554,7 @@ def _slices_to_slices_with_array_index( dataset_index_to_slices_sorted[k] = dataset_index_to_slices[k] return dataset_index_to_slices_sorted - def _get_kwargs_for_zarr_fetching(self, z: zarr.Array, indexer_shape: tuple[int]) -> dict: + def _get_kwargs_for_zarr_fetching(self, z: zarr.Array, indexer_shape: tuple[int, ...]) -> dict: buffer_prototype = zarr.core.buffer.default_buffer_prototype() kwargs = {"prototype": buffer_prototype} if self._preload_to_gpu: