2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning][].
### Fixed
- Formatted progress bar descriptions to be more readable.
- {class}`annbatch.DatasetCollection` now accepts a `rng` argument to the {meth}`annbatch.DatasetCollection.add_adatas` method.
- The ``sparse_chunk_size``, ``sparse_shard_size``, ``dense_chunk_size``, and ``dense_shard_size`` parameters of {func}`annbatch.write_sharded` have been replaced by ``n_obs_per_chunk`` (number of observations per chunk, automatically converted to element counts for sparse arrays) and ``shard_size`` (number of observations per shard or a size string). The corresponding parameters in {meth}`annbatch.DatasetCollection.add_adatas` are ``n_obs_per_chunk`` and ``zarr_shard_size``.
- `zarr_shard_size` in {meth}`annbatch.DatasetCollection.add_adatas` and `shard_size` in {func}`annbatch.write_sharded` now accept a human-readable size string (e.g. ``'1GB'``, ``'512MB'``) in addition to an integer number of observations. When a string is provided, the observation count is derived independently for each array element from its uncompressed bytes-per-row so that every shard stays close to the target size.


## [0.0.8]
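For orientation on the changelog entries above (the ``n_obs_per_chunk`` / ``shard_size`` rework), a minimal usage sketch of {func}`annbatch.write_sharded`; the file paths and parameter values are illustrative only and are not taken from this PR.

```python
import anndata as ad
import zarr

from annbatch import write_sharded

# Hypothetical input/output locations, for illustration only.
adata = ad.read_h5ad("example.h5ad")
group = zarr.open_group("example.zarr", mode="w")

write_sharded(
    group,
    adata,
    n_obs_per_chunk=64,  # observations per chunk; first-axis chunk size for dense arrays
    shard_size="512MB",  # or an integer number of observations per shard
)
```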
1 change: 1 addition & 0 deletions pyproject.toml
@@ -25,6 +25,7 @@ classifiers = [
dependencies = [
"anndata[lazy]>=0.12.9",
"dask>=2025.9",
"humanfriendly>=10",
"pandas>=2.2.2",
"scipy>1.15",
# for debug logging (referenced from the issue template)
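The `humanfriendly` dependency added above supplies the size-string parsing used by the new ``shard_size`` / ``zarr_shard_size`` parameters. A quick sketch of the behaviour this PR relies on (``binary=True`` interprets suffixes as powers of 1024):

```python
from humanfriendly import parse_size

# With binary=True, '1GB' is read as 1 GiB and '512MB' as 512 MiB.
assert parse_size("1GB", binary=True) == 1024**3
assert parse_size("512MB", binary=True) == 512 * 1024**2
```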
178 changes: 106 additions & 72 deletions src/annbatch/io.py
@@ -15,8 +15,10 @@
import pandas as pd
import scipy.sparse as sp
import zarr
from anndata._core.sparse_dataset import BaseCompressedSparseDataset
from anndata.experimental.backed import Dataset2D
from dask.array.core import Array as DaskArray
from humanfriendly import parse_size
from tqdm.auto import tqdm
from zarr.codecs import BloscCodec, BloscShuffle

@@ -59,14 +61,43 @@ def _round_down(num: int, divisor: int):
return num - (num % divisor)


def _shard_size_param_to_n_obs(shard_size: int | str, elem) -> int:
"""Convert `shard_size` to a number of observations given the size of an element from the anndata object.

If `shard_size` is already an int, it is interpreted as a number of observations. When it is a
size string, the target byte budget is divided by the element's
uncompressed bytes per observation row.
"""
if isinstance(shard_size, int):
return shard_size
target_bytes = parse_size(shard_size, binary=True)

def _cs_bytes(x) -> int:
return int(x.data.nbytes + x.indptr.nbytes + x.indices.nbytes)

n_obs = elem.shape[0] if hasattr(elem, "shape") else len(elem)
if n_obs == 0:
return 1

if isinstance(elem, h5py.Dataset):
total_bytes = int(np.array(elem.shape).prod() * elem.dtype.itemsize)
elif isinstance(elem, BaseCompressedSparseDataset):
total_bytes = _cs_bytes(elem._to_backed())
elif sp.issparse(elem):
total_bytes = _cs_bytes(elem)
else:
total_bytes = elem.__sizeof__()

bytes_per_row = total_bytes / n_obs
return max(1, int(target_bytes / bytes_per_row)) if bytes_per_row > 0 else 1


def write_sharded(
group: zarr.Group,
adata: ad.AnnData,
*,
sparse_chunk_size: int = 32768,
sparse_shard_size: int = 134_217_728,
dense_chunk_size: int = 1024,
dense_shard_size: int = 4194304,
n_obs_per_chunk: int = 64,
shard_size: int | str = 2_097_152,
compressors: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),),
key: str | None = None,
):
@@ -78,14 +109,14 @@
The destination group, must be zarr v3
adata
The source anndata object
sparse_chunk_size
Chunk size of `indices` and `data` inside a shard.
sparse_shard_size
Shard size i.e., number of elements in a single sparse `data` or `indices` file.
dense_chunk_size
Number of obs elements per dense chunk along the first axis
dense_shard_size
Number of obs elements per dense shard along the first axis
n_obs_per_chunk
Number of observations per chunk. For dense arrays this directly sets the first-axis chunk size.
For sparse arrays it is converted to element counts using the average number of non-zero elements per row of the matrix being written.
shard_size
Number of observations per shard, or a size string (e.g. ``'1GB'``, ``'512MB'``).
If a size string is provided, the observation count is derived independently for each array element from its uncompressed bytes-per-row so that every shard stays close to the target size.
For dense arrays the resolved count directly sets the first-axis shard size.
For sparse arrays it is converted to element counts using the average number of non-zero elements per row of the matrix being written.
compressors
The compressors to pass to `zarr`.
key
@@ -107,25 +138,36 @@ def callback(
if iospec.encoding_type in {"array"} and (
any(n in store.name for n in {"obsm", "layers", "obsp"}) or "X" == elem_name
):
obs_per_shard = _shard_size_param_to_n_obs(shard_size, elem)
# Get either the desired size or the next multiple down to ensure divisibility of chunks and shards
shard_size = min(dense_shard_size, _round_down(elem.shape[0], dense_chunk_size))
chunk_size = min(dense_chunk_size, _round_down(elem.shape[0], dense_chunk_size))
# If the shape is smaller than the computed chunk size or rounding produced a zero-size chunk, raise an error
if elem.shape[0] < chunk_size or chunk_size == 0:
dense_chunk = min(n_obs_per_chunk, _round_down(elem.shape[0], n_obs_per_chunk))
if elem.shape[0] < dense_chunk or dense_chunk == 0:
raise ValueError(
f"Choose a dense shard obs {dense_shard_size} and chunk obs {dense_chunk_size} with non-zero size less than the number of observations {elem.shape[0]}"
f"Choose a shard obs {shard_size} and chunk obs {n_obs_per_chunk} with non-zero size less than the number of observations {elem.shape[0]}"
)
dense_shard = min(obs_per_shard, _round_down(elem.shape[0], n_obs_per_chunk))
dense_shard = max(dense_chunk, _round_down(dense_shard, dense_chunk))
dataset_kwargs = {
**dataset_kwargs,
"shards": (shard_size,) + elem.shape[1:], # only shard over 1st dim
"chunks": (chunk_size,) + elem.shape[1:], # only chunk over 1st dim
"shards": (dense_shard,) + elem.shape[1:], # only shard over 1st dim
"chunks": (dense_chunk,) + elem.shape[1:], # only chunk over 1st dim
"compressors": compressors,
}
elif iospec.encoding_type in {"csr_matrix", "csc_matrix"}:
obs_per_shard = _shard_size_param_to_n_obs(shard_size, elem)
nnz = elem.nnz
if elem.shape[0] == 0:
raise ValueError(f"Cannot write sharded sparse matrix {elem_name!r} with 0 observations.")
avg_nnz_per_obs = nnz / elem.shape[0]
sparse_chunk = max(1, int(n_obs_per_chunk * avg_nnz_per_obs))
sparse_chunk = min(sparse_chunk, nnz) if nnz > 0 else sparse_chunk
sparse_shard = max(1, int(obs_per_shard * avg_nnz_per_obs))
sparse_shard = min(sparse_shard, nnz) if nnz > 0 else sparse_shard
sparse_shard = max(sparse_chunk, _round_down(sparse_shard, sparse_chunk))
dataset_kwargs = {
**dataset_kwargs,
"shards": (sparse_shard_size,),
"chunks": (sparse_chunk_size,),
"shards": (sparse_shard,),
"chunks": (sparse_chunk,),
"compressors": compressors,
}
write_func(store, elem_name, elem, dataset_kwargs=dataset_kwargs)
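A minimal sketch of the arithmetic performed by `_shard_size_param_to_n_obs` and the sparse branch above, assuming an in-memory CSR matrix; the matrix shape, density, and the ``n_obs_per_chunk=64`` default are illustrative, and the backed/dense element cases handled by the helper are not covered here.

```python
import numpy as np
import scipy.sparse as sp
from humanfriendly import parse_size

# Illustrative matrix: 10_000 observations x 2_000 variables, ~10% non-zero.
x = sp.random(10_000, 2_000, density=0.1, format="csr",
              dtype=np.float32, random_state=np.random.default_rng(0))

# Uncompressed bytes per observation row (data + indices + indptr).
total_bytes = x.data.nbytes + x.indices.nbytes + x.indptr.nbytes
bytes_per_row = total_bytes / x.shape[0]

# A '1GB' shard_size resolves to roughly this many observations per shard ...
obs_per_shard = max(1, int(parse_size("1GB", binary=True) / bytes_per_row))

# ... which the sparse branch converts to stored-element counts,
# rounding the shard down to a multiple of the chunk.
avg_nnz_per_obs = x.nnz / x.shape[0]
sparse_chunk = max(1, int(64 * avg_nnz_per_obs))
sparse_shard = max(1, int(obs_per_shard * avg_nnz_per_obs))
sparse_shard = max(sparse_chunk, sparse_shard - sparse_shard % sparse_chunk)
print(obs_per_shard, sparse_chunk, sparse_shard)
```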
@@ -401,10 +443,8 @@ def add_adatas(
*,
load_adata: Callable[[zarr.Group | h5py.Group | PathLike[str] | str], ad.AnnData] = _default_load_adata,
var_subset: Iterable[str] | None = None,
zarr_sparse_chunk_size: int = 32768,
zarr_sparse_shard_size: int = 134_217_728,
zarr_dense_chunk_size: int = 1024,
zarr_dense_shard_size: int = 4_194_304,
n_obs_per_chunk: int = 64,
zarr_shard_size: int | str = "1GB",
zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),),
h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip",
n_obs_per_dataset: int = 2_097_152,
@@ -435,22 +475,22 @@
var_subset
Subset of gene names to include in the store. If None, all genes are included.
Genes are subset based on the `var_names` attribute of the concatenated AnnData object.
zarr_sparse_chunk_size
Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store.
zarr_sparse_shard_size
Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store.
zarr_dense_chunk_size
Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array.
zarr_dense_shard_size
Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array.
n_obs_per_chunk
Number of observations per zarr chunk. For dense arrays this is used directly as the first-axis chunk size.
For sparse arrays it is converted to element counts using the average number of non-zero elements per row of the matrix being written.
zarr_shard_size
Number of observations per zarr shard, or a size string (e.g. ``'1GB'``).
If a size string is provided, the number of observations per zarr shard is estimated automatically.
For sparse arrays, the number of observations is converted to element counts using the average number of non-zero elements per row of the matrix being written.
zarr_compressor
Compressors to use to compress the data in the zarr store.
h5ad_compressor
Compressors to use to compress the data in the h5ad store. See anndata.write_h5ad.
n_obs_per_dataset
Number of observations to load into memory at once for shuffling / pre-processing.
The higher this number, the more memory is used, but the better the shuffling.
This corresponds to the size of the shards created.
This corresponds to the size of the dataset-level shards created.
Only applicable when adding datasets for the first time, otherwise ignored.
shuffle
Whether to shuffle the data before writing it to the store.
@@ -484,16 +524,16 @@ def add_adatas(
...)
"""
if shuffle_chunk_size > n_obs_per_dataset:
raise ValueError("Cannot have a large slice size than observations per dataset")
raise ValueError(
"Cannot have a larger slice size than observations per dataset. Reduce `shuffle_chunk_size` or increase `n_obs_per_dataset`."
)
if rng is None:
rng = np.random.default_rng()
shared_kwargs = {
"adata_paths": adata_paths,
"load_adata": load_adata,
"zarr_sparse_chunk_size": zarr_sparse_chunk_size,
"zarr_sparse_shard_size": zarr_sparse_shard_size,
"zarr_dense_chunk_size": zarr_dense_chunk_size,
"zarr_dense_shard_size": zarr_dense_shard_size,
"n_obs_per_chunk": n_obs_per_chunk,
"zarr_shard_size": zarr_shard_size,
"zarr_compressor": zarr_compressor,
"h5ad_compressor": h5ad_compressor,
"shuffle_chunk_size": shuffle_chunk_size,
@@ -512,10 +552,8 @@ def _create_collection(
adata_paths: Iterable[PathLike[str]] | Iterable[str],
load_adata: Callable[[PathLike[str] | str], ad.AnnData] = _default_load_adata,
var_subset: Iterable[str] | None = None,
zarr_sparse_chunk_size: int = 32768,
zarr_sparse_shard_size: int = 134_217_728,
zarr_dense_chunk_size: int = 1024,
zarr_dense_shard_size: int = 4_194_304,
n_obs_per_chunk: int = 64,
zarr_shard_size: int | str = "1GB",
zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),),
h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip",
n_obs_per_dataset: int = 2_097_152,
@@ -545,14 +583,13 @@
Subset of gene names to include in the store. If None, all genes are included.
Genes are subset based on the `var_names` attribute of the concatenated AnnData object.
Only applicable when adding datasets for the first time, otherwise ignored and the incoming data's var space is subsetted to that of the existing collection.
zarr_sparse_chunk_size
Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store.
zarr_sparse_shard_size
Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store.
zarr_dense_chunk_size
Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array.
zarr_dense_shard_size
Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array.
n_obs_per_chunk
Number of observations per zarr chunk. For dense arrays this is used directly as the first-axis chunk size.
For sparse arrays it is converted to element counts using the average number of non-zero elements per row of the matrix being written.
zarr_shard_size
Number of observations per zarr shard, or a size string (e.g. ``'1GB'``).
If a size string is provided, the number of observations per zarr shard is estimated automatically.
For sparse arrays, the number of observations is converted to element counts using the average number of non-zero elements per row of the matrix being written.
zarr_compressor
Compressors to use to compress the data in the zarr store.
h5ad_compressor
@@ -572,6 +609,11 @@
"""
if not self.is_empty:
raise RuntimeError("Cannot create a collection at a location that already has a shuffled collection")
if shuffle_chunk_size > n_obs_per_dataset:
raise ValueError(
"Cannot have a larger slice size than observations per dataset. Reduce `shuffle_chunk_size` or increase `n_obs_per_dataset`."
)

_check_for_mismatched_keys(adata_paths, load_adata=load_adata)
adata_concat = _lazy_load_adata(adata_paths, load_adata=load_adata)
adata_concat.obs_names_make_unique()
@@ -600,10 +642,8 @@ def _create_collection(
write_sharded(
self._group,
adata_chunk,
sparse_chunk_size=zarr_sparse_chunk_size,
sparse_shard_size=zarr_sparse_shard_size,
dense_chunk_size=min(adata_chunk.shape[0], zarr_dense_chunk_size),
dense_shard_size=min(adata_chunk.shape[0], zarr_dense_shard_size),
n_obs_per_chunk=n_obs_per_chunk,
shard_size=zarr_shard_size,
compressors=zarr_compressor,
key=f"{DATASET_PREFIX}_{i}",
)
@@ -621,10 +661,8 @@ def _add_to_collection(
*,
adata_paths: Iterable[PathLike[str]] | Iterable[str],
load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.read_h5ad,
zarr_sparse_chunk_size: int = 32768,
zarr_sparse_shard_size: int = 134_217_728,
zarr_dense_chunk_size: int = 1024,
zarr_dense_shard_size: int = 4_194_304,
n_obs_per_chunk: int = 64,
zarr_shard_size: int | str = "1GB",
zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),),
h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip",
shuffle_chunk_size: int = 1000,
@@ -646,14 +684,13 @@
If you only need a subset of the input anndata files' elems (e.g., only `X` and `obs`), you can provide a custom function here to speed up loading and harmonize your data.
The input to the function is a path to an anndata file, and the output is an anndata object.
If the input data is too large to fit into memory, you should use :func:`anndata.experimental.read_lazy` instead.
zarr_sparse_chunk_size
Size of the chunks to use for the `indices` and `data` of a sparse matrix in the zarr store.
zarr_sparse_shard_size
Size of the shards to use for the `indices` and `data` of a sparse matrix in the zarr store.
zarr_dense_chunk_size
Number of observations per dense zarr chunk i.e., sharding is only done along the first axis of the array.
zarr_dense_shard_size
Number of observations per dense zarr shard i.e., chunking is only done along the first axis of the array.
n_obs_per_chunk
Number of observations per zarr chunk. For dense arrays this is used directly as the first-axis chunk size.
For sparse arrays it is converted to element counts using the average number of non-zero elements per row of the matrix being written.
zarr_shard_size
Number of observations per zarr shard, or a size string (e.g. ``'1GB'``).
If a size string is provided, the number of observations per zarr shard is estimated automatically.
For sparse arrays, the number of observations is converted to element counts using the average number of non-zero elements per row of the matrix being written.
zarr_compressor
Compressors to use to compress the data in the zarr store.
should_sparsify_output_in_memory
@@ -668,7 +705,6 @@
raise ValueError("Store is empty. Please run `DatasetCollection.add_adatas` first.")
# Check for mismatched keys among the inputs.
_check_for_mismatched_keys(adata_paths, load_adata=load_adata)

adata_concat = _lazy_load_adata(adata_paths, load_adata=load_adata)
if math.ceil(adata_concat.shape[0] / shuffle_chunk_size) < len(self._dataset_keys):
raise ValueError(
@@ -703,10 +739,8 @@
write_sharded(
self._group,
adata,
sparse_chunk_size=zarr_sparse_chunk_size,
sparse_shard_size=zarr_sparse_shard_size,
dense_chunk_size=min(adata.shape[0], zarr_dense_chunk_size),
dense_shard_size=min(adata.shape[0], zarr_dense_shard_size),
n_obs_per_chunk=n_obs_per_chunk,
shard_size=zarr_shard_size,
compressors=zarr_compressor,
key=dataset,
)
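Finally, a sketch of the collection-level API with the renamed parameters; constructing a `DatasetCollection` is not part of this diff, so an existing instance and hypothetical input paths are assumed.

```python
import numpy as np

from annbatch import DatasetCollection

collection: DatasetCollection = ...  # constructed elsewhere; not shown in this PR

collection.add_adatas(
    adata_paths=["batch_1.h5ad", "batch_2.h5ad"],  # hypothetical inputs
    n_obs_per_chunk=64,            # first-axis chunk size for dense arrays
    zarr_shard_size="1GB",         # size string; per-element obs counts derived automatically
    rng=np.random.default_rng(0),  # rng argument added per the CHANGELOG entry above
)
```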