From 9b6717e34ae068cc101c8d76ddbebc7475a8140c Mon Sep 17 00:00:00 2001 From: Felix Fischer Date: Tue, 10 Feb 2026 16:03:34 +0100 Subject: [PATCH 1/3] Fix issue where loader.__len__ returns wrong len --- src/annbatch/abc/sampler.py | 17 ++++++++++++++++- src/annbatch/loader.py | 2 +- src/annbatch/samplers/_chunk_sampler.py | 5 +++++ tests/test_dataset.py | 4 ++++ tests/test_sampler.py | 7 +++++++ 5 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/annbatch/abc/sampler.py b/src/annbatch/abc/sampler.py index 590b70d4..7ee82b02 100644 --- a/src/annbatch/abc/sampler.py +++ b/src/annbatch/abc/sampler.py @@ -41,7 +41,7 @@ def batch_size(self) -> int | None: @abstractmethod def shuffle(self) -> bool: """Whether data is shuffled. - + If `batch_size` is provided and {attr}`annbatch.types.LoadRequest.splits` is not, in-memory loaded data will be shuffled or not based on this param. Shuffling of on-disk data is up to the user (controlled by `chunks` parameter in {class}`annbatch.types.LoadRequest`). @@ -52,6 +52,21 @@ def shuffle(self) -> bool: True if data should be shuffled, False otherwise. """ + @abstractmethod + def n_iters(self, n_obs: int) -> int: + """Return the number of batches. + + Parameters + ---------- + n_obs + The total number of observations available. + + Returns + ------- + int + The total number of batches this sampler will produce. + """ + def sample(self, n_obs: int) -> Iterator[LoadRequest]: """Sample load requests given the total number of observations. diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 167b8641..631fe0aa 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -212,7 +212,7 @@ def __init__( self._concat_strategy = concat_strategy def __len__(self) -> int: - return self.n_obs + return self._batch_sampler.n_iters(self.n_obs) @property def _sp_module(self) -> ModuleType: diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index a785e0bb..bb2ade56 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -97,6 +97,11 @@ def batch_size(self) -> int: def shuffle(self) -> bool: return self._shuffle + def n_iters(self, n_obs: int) -> int: + start, stop = self._mask.start or 0, self._mask.stop or n_obs + total_obs = stop - start + return total_obs // self._batch_size if self._drop_last else math.ceil(total_obs / self._batch_size) + def validate(self, n_obs: int) -> None: """Validate the sampler configuration against the loader's n_obs. diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 299c7d5d..e367268b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math from importlib.util import find_spec from types import NoneType from typing import TYPE_CHECKING, TypedDict @@ -517,6 +518,9 @@ class FailOnSecondValidateSampler(Sampler): def __init__(self): self._validate_count = 0 + def n_iters(self, n_obs: int) -> int: + return math.ceil(n_obs / self.batch_size) + def validate(self, n_obs: int) -> None: self._validate_count += 1 if self._validate_count > 1: diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 05299cc0..a25fff83 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -2,6 +2,8 @@ from __future__ import annotations +import math + import numpy as np import pytest @@ -275,6 +277,11 @@ def batch_size(self) -> int | None: def shuffle(self) -> bool | None: return self._shuffle + def n_iters(self, n_obs: int) -> int: + if self._batch_size is None or self._batch_size == 0: + return 1 + return math.ceil(n_obs / self._batch_size) + def validate(self, n_obs: int) -> None: """No validation needed for test sampler.""" pass From f73b6e27b1ee04daa29213ea5484a0d3f88d24e6 Mon Sep 17 00:00:00 2001 From: Felix Fischer Date: Tue, 10 Feb 2026 16:27:22 +0100 Subject: [PATCH 2/3] Add test to check for correctness of loader.__len__ --- tests/test_dataset.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index e367268b..0b9e3f5a 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -323,6 +323,45 @@ def test_drop_last(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], np.testing.assert_allclose(X, X_expected) +@pytest.mark.parametrize("drop_last", [True, False], ids=["drop", "kept"]) +@pytest.mark.parametrize( + ("chunk_size", "preload_nchunks", "batch_size"), + [ + (10, 3, 10), # batch_size divides evenly + (14, 3, 21), # batch_size does not divide evenly + (10, 5, 25), # larger preload + ], +) +def test_len( + adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], + drop_last: bool, + chunk_size: int, + preload_nchunks: int, + batch_size: int, +): + zarr_path = next(adata_with_zarr_path_same_var_space[1].glob("*.zarr")) + data = open_sparse(zarr_path) + n_obs = data["dataset"].shape[0] + + loader = Loader( + shuffle=False, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + preload_to_gpu=False, + to_torch=False, + drop_last=drop_last, + ) + loader.add_dataset(**data) + + expected_len = n_obs // batch_size if drop_last else math.ceil(n_obs / batch_size) + assert len(loader) == expected_len + + # Also verify len matches the actual number of yielded batches + actual_batches = sum(1 for _ in loader) + assert len(loader) == actual_batches + + def test_bad_adata_X_hdf5(adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path]): with h5py.File(next(adata_with_h5_path_different_var_space[1].glob("*.h5ad"))) as f: data = ad.io.sparse_dataset(f["X"]) From 7ad4f0dac115a270c76b68d203767a9c434154f3 Mon Sep 17 00:00:00 2001 From: Felix Fischer Date: Tue, 10 Feb 2026 18:39:48 +0100 Subject: [PATCH 3/3] Removed unnecessary test parametrization --- tests/test_dataset.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 0b9e3f5a..b19b2a07 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -324,29 +324,17 @@ def test_drop_last(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], @pytest.mark.parametrize("drop_last", [True, False], ids=["drop", "kept"]) -@pytest.mark.parametrize( - ("chunk_size", "preload_nchunks", "batch_size"), - [ - (10, 3, 10), # batch_size divides evenly - (14, 3, 21), # batch_size does not divide evenly - (10, 5, 25), # larger preload - ], -) def test_len( adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], drop_last: bool, - chunk_size: int, - preload_nchunks: int, - batch_size: int, ): zarr_path = next(adata_with_zarr_path_same_var_space[1].glob("*.zarr")) data = open_sparse(zarr_path) n_obs = data["dataset"].shape[0] + batch_size = 32 loader = Loader( shuffle=False, - chunk_size=chunk_size, - preload_nchunks=preload_nchunks, batch_size=batch_size, preload_to_gpu=False, to_torch=False, @@ -356,7 +344,6 @@ def test_len( expected_len = n_obs // batch_size if drop_last else math.ceil(n_obs / batch_size) assert len(loader) == expected_len - # Also verify len matches the actual number of yielded batches actual_batches = sum(1 for _ in loader) assert len(loader) == actual_batches