Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/annbatch/abc/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def batch_size(self) -> int | None:
@abstractmethod
def shuffle(self) -> bool:
"""Whether data is shuffled.

If `batch_size` is provided and {attr}`annbatch.types.LoadRequest.splits` is not, in-memory loaded data will be shuffled or not based on this param.

Shuffling of on-disk data is up to the user (controlled by `chunks` parameter in {class}`annbatch.types.LoadRequest`).
Expand All @@ -52,6 +52,21 @@ def shuffle(self) -> bool:
True if data should be shuffled, False otherwise.
"""

@abstractmethod
def n_iters(self, n_obs: int) -> int:
    """Return the number of batches.

    Concrete samplers must implement this method.
    The loader's ``__len__`` delegates to it, passing its own number of observations.

    Parameters
    ----------
    n_obs
        The total number of observations available.

    Returns
    -------
    int
        The total number of batches this sampler will produce.
    """

def sample(self, n_obs: int) -> Iterator[LoadRequest]:
"""Sample load requests given the total number of observations.

Expand Down
2 changes: 1 addition & 1 deletion src/annbatch/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def __init__(
self._concat_strategy = concat_strategy

def __len__(self) -> int:
return self.n_obs
return self._batch_sampler.n_iters(self.n_obs)

@property
def _sp_module(self) -> ModuleType:
Expand Down
5 changes: 5 additions & 0 deletions src/annbatch/samplers/_chunk_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ def batch_size(self) -> int:
def shuffle(self) -> bool:
return self._shuffle

def n_iters(self, n_obs: int) -> int:
    """Return how many batches the masked range of observations yields.

    Parameters
    ----------
    n_obs
        Total number of observations; used as the upper bound when the mask has no (or a falsy) stop.

    Returns
    -------
    int
        The number of full batches when dropping the last partial batch,
        otherwise the batch count rounded up to include it.
    """
    # A missing (or falsy) bound falls back to the full range [0, n_obs).
    lower = self._mask.start or 0
    upper = self._mask.stop or n_obs
    span = upper - lower
    if self._drop_last:
        return span // self._batch_size
    return math.ceil(span / self._batch_size)

def validate(self, n_obs: int) -> None:
"""Validate the sampler configuration against the loader's n_obs.

Expand Down
30 changes: 30 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import math
from importlib.util import find_spec
from types import NoneType
from typing import TYPE_CHECKING, TypedDict
Expand Down Expand Up @@ -333,6 +334,32 @@ def test_drop_last(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path],
np.testing.assert_allclose(X, X_expected)


@pytest.mark.parametrize("drop_last", [True, False], ids=["drop", "kept"])
def test_len(
adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path],
drop_last: bool,
):
zarr_path = next(adata_with_zarr_path_same_var_space[1].glob("*.zarr"))
data = open_sparse(zarr_path)
n_obs = data["dataset"].shape[0]
batch_size = 32

loader = Loader(
shuffle=False,
batch_size=batch_size,
preload_to_gpu=False,
to_torch=False,
drop_last=drop_last,
)
loader.add_dataset(**data)

expected_len = n_obs // batch_size if drop_last else math.ceil(n_obs / batch_size)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put this as an explicit parametrized number. Also, I'm not sure why we're parametrizing on chunk size or number of chunks to preload. Can you remove that? It seems unrelated.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's supposed to also check the n_iters function of the ChunkSampler. This covers different edge cases. I can remove it if we only want to check the n_iters.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But https://github.com/scverse/annbatch/pull/144/changes#diff-102eacca529505a20b2ae79d187762c2125e058343b16a2452f981a82f7c17bcR100-R103 doesn't use those parameters - it's just batch_size and n_obs dependent along with the mask - so maybe parametrize by mask start/end?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, so this rather tests whether the ChunkSampler handles this correctly. Thinking about it, that isn't really the point of this test. I've removed it now 👍

assert len(loader) == expected_len
# Also verify len matches the actual number of yielded batches
actual_batches = sum(1 for _ in loader)
assert len(loader) == actual_batches


def test_bad_adata_X_hdf5(adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path]):
with h5py.File(next(adata_with_h5_path_different_var_space[1].glob("*.h5ad"))) as f:
data = ad.io.sparse_dataset(f["X"])
Expand Down Expand Up @@ -579,6 +606,9 @@ class FailOnSecondValidateSampler(Sampler):
def __init__(self) -> None:
    # Number of validate() calls so far; validate() increments it and checks for > 1.
    self._validate_count = 0

def n_iters(self, n_obs: int) -> int:
    """Return the batch count: `n_obs` divided by `batch_size`, rounded up."""
    per_batch = self.batch_size
    return math.ceil(n_obs / per_batch)

def validate(self, n_obs: int) -> None:
self._validate_count += 1
if self._validate_count > 1:
Expand Down
7 changes: 7 additions & 0 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import math

import numpy as np
import pytest

Expand Down Expand Up @@ -275,6 +277,11 @@ def batch_size(self) -> int | None:
def shuffle(self) -> bool | None:
return self._shuffle

def n_iters(self, n_obs: int) -> int:
    """Return the number of batches for `n_obs` observations.

    A batch size of ``None`` or ``0`` counts as a single iteration;
    otherwise the count is `n_obs` divided by the batch size, rounded up.
    """
    size = self._batch_size
    if size is not None and size != 0:
        return math.ceil(n_obs / size)
    return 1

def validate(self, n_obs: int) -> None:
    """Accept any `n_obs`; the test sampler performs no validation."""
Expand Down
Loading