Commit: rewrite api
selmanozleyen committed Jul 20, 2023
1 parent bfebddc commit c0408bb
Showing 2 changed files with 55 additions and 50 deletions.
89 changes: 47 additions & 42 deletions anndata/experimental/merge.py
@@ -1,42 +1,43 @@
import os
import shutil
from functools import singledispatch
from pathlib import Path
from typing import (
Any,
Callable,
Collection,
Iterable,
Literal,
Mapping,
Optional,
Sequence,
Union,
Literal,
Iterable,
Mapping,
MutableMapping,
)
import pandas as pd
import typing
from pathlib import Path

import numpy as np
from scipy.sparse import csr_matrix, csc_matrix
import pandas as pd
from scipy.sparse import csc_matrix, csr_matrix

from .._io.specs import read_elem, write_elem
from .._core.file_backing import to_memory
from .._core.merge import (
_resolve_dim,
MissingVal,
Reindexer,
StrategiesLiteral,
_resolve_dim,
concat_arrays,
gen_inner_reindexers,
gen_reindexer,
intersect_keys,
unify_dtypes,
merge_indices,
merge_dataframes,
merge_indices,
resolve_merge_strategy,
gen_reindexer,
gen_inner_reindexers,
concat_arrays,
MissingVal,
Reindexer,
unify_dtypes,
)
from .._core.sparse_dataset import SparseDataset
from .._core.file_backing import to_memory
from .._io.specs import read_elem, write_elem
from ..compat import H5Array, H5Group, ZarrArray, ZarrGroup
from . import read_dispatched
from ..compat import ZarrGroup, ZarrArray, H5Group, H5Array


SPARSE_MATRIX = {"csc_matrix", "csr_matrix"}

@@ -67,18 +68,18 @@ def _indices_equal(indices: Iterable[pd.Index]) -> bool:
def _gen_slice_to_append(
datasets: Sequence[SparseDataset],
reindexers,
max_loaded_sparse_elems: int,
max_loaded_elems: int,
axis=0,
fill_value=None,
):
for ds, ri in zip(datasets, reindexers):
n_slices = ds.shape[axis] * ds.shape[1 - axis] // max_loaded_sparse_elems
n_slices = ds.shape[axis] * ds.shape[1 - axis] // max_loaded_elems
if n_slices < 2:
yield (csr_matrix, csc_matrix)[axis](
ri(to_memory(ds), axis=1 - axis, fill_value=fill_value)
)
else:
slice_size = max_loaded_sparse_elems // ds.shape[1 - axis]
slice_size = max_loaded_elems // ds.shape[1 - axis]
if slice_size == 0:
slice_size = 1
rem_slices = ds.shape[axis]
@@ -197,7 +198,7 @@ def write_concat_sparse(
datasets: Sequence[SparseDataset],
output_group: Union[ZarrGroup, H5Group],
output_path: Union[ZarrGroup, H5Group],
max_loaded_sparse_elems: int,
max_loaded_elems: int,
axis: Literal[0, 1] = 0,
reindexers: Reindexer = None,
fill_value=None,
@@ -209,7 +210,7 @@ def write_concat_sparse(
datasets (Sequence[SparseDataset]): A sequence of SparseDataset objects to be concatenated.
output_group (Union[ZarrGroup, H5Group]): The output group where the concatenated dataset will be written.
output_path (Union[ZarrGroup, H5Group]): The output path where the concatenated dataset will be written.
max_loaded_sparse_elems (int): The maximum number of sparse elements to load at once.
max_loaded_elems (int): The maximum number of sparse elements to load at once.
axis (Literal[0, 1], optional): The axis along which the datasets should be concatenated.
Defaults to 0.
reindexers (Reindexer, optional): A reindexer object that defines the reindexing operation to be applied.
@@ -221,7 +222,7 @@ def write_concat_sparse(
elems = iter(datasets)
else:
elems = _gen_slice_to_append(
datasets, reindexers, max_loaded_sparse_elems, axis, fill_value
datasets, reindexers, max_loaded_elems, axis, fill_value
)
init_elem = next(elems)
write_elem(output_group, output_path, init_elem)
@@ -237,7 +238,7 @@ def _write_concat_mappings(
output_group: Union[ZarrGroup, H5Group],
keys,
path,
max_loaded_sparse_elems,
max_loaded_elems,
axis=0,
index=None,
reindexers=None,
@@ -263,15 +264,15 @@
index=index,
reindexers=reindexers,
fill_value=fill_value,
max_loaded_sparse_elems=max_loaded_sparse_elems,
max_loaded_elems=max_loaded_elems,
)


def _write_concat_arrays(
arrays: Sequence[Union[ZarrArray, H5Array, SparseDataset]],
output_group,
output_path,
max_loaded_sparse_elems,
max_loaded_elems,
axis=0,
reindexers=None,
fill_value=None,
@@ -297,7 +298,7 @@ def _write_concat_arrays(
arrays,
output_group,
output_path,
max_loaded_sparse_elems,
max_loaded_elems,
axis,
reindexers,
fill_value,
@@ -316,7 +317,7 @@ def _write_concat_sequence(
arrays: Sequence[Union[pd.DataFrame, SparseDataset, H5Array, ZarrArray]],
output_group,
output_path,
max_loaded_sparse_elems,
max_loaded_elems,
axis=0,
index=None,
reindexers=None,
@@ -354,7 +355,7 @@ def _write_concat_sequence(
arrays,
output_group,
output_path,
max_loaded_sparse_elems,
max_loaded_elems,
axis,
reindexers,
fill_value,
@@ -399,22 +400,26 @@ def _write_dim_annot(groups, output_group, dim, concat_indices, label, label_col
def concat_on_disk(
in_files: Union[
Collection[Union[str, os.PathLike]],
typing.MutableMapping,
MutableMapping[str, Union[str, os.PathLike]],
],
out_file: Union[str, os.PathLike],
overwrite: bool = False,
max_loaded_sparse_elems: int = 100_000_000,
*,
overwrite: bool = False,
max_loaded_elems: int = 100_000_000,
axis: Literal[0, 1] = 0,
join: Literal["inner", "outer"] = "inner",
merge: Union[StrategiesLiteral, Callable, None] = None,
uns_merge: Union[StrategiesLiteral, Callable, None] = None,
merge: Union[
StrategiesLiteral, Callable[[Collection[Mapping]], Mapping], None
] = None,
uns_merge: Union[
StrategiesLiteral, Callable[[Collection[Mapping]], Mapping], None
] = None,
label: Optional[str] = None,
keys: Optional[Collection] = None,
keys: Optional[Collection[str]] = None,
index_unique: Optional[str] = None,
fill_value: Optional[Any] = None,
pairwise: bool = False,
):
) -> None:
"""Concatenates multiple AnnData objects along a specified axis using their
corresponding stores or paths, and writes the resulting AnnData object
to a target location on disk.
@@ -427,7 +432,7 @@ def concat_on_disk(
the `concat` function.
To adjust the maximum amount of data loaded in memory; for sparse
arrays use the max_loaded_sparse_elems argument; for dense arrays
arrays use the max_loaded_elems argument; for dense arrays
see the Dask documentation, as the Dask concatenation function is used
to concatenate dense arrays in this function
Params
@@ -441,7 +446,7 @@ def concat_on_disk(
overwrite
If `False` while a file already exists it will raise an error,
otherwise it will overwrite.
max_loaded_sparse_elems
max_loaded_elems
The maximum number of elements to load in memory when concatenating
sparse arrays. Note that this number also includes the empty entries.
Set to 100m by default meaning roughly 400mb will be loaded
@@ -590,7 +595,7 @@ def concat_on_disk(
axis=axis,
reindexers=reindexers,
fill_value=fill_value,
max_loaded_sparse_elems=max_loaded_sparse_elems,
max_loaded_elems=max_loaded_elems,
)

# Write Layers and {dim}m
@@ -610,7 +615,7 @@ def concat_on_disk(
output_group,
intersect_keys(maps),
m,
max_loaded_sparse_elems=max_loaded_sparse_elems,
max_loaded_elems=max_loaded_elems,
axis=m_axis,
index=m_index,
reindexers=m_reindexers,
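
With this commit, everything after `out_file` in `concat_on_disk` becomes keyword-only and the chunking limit is renamed from `max_loaded_sparse_elems` to `max_loaded_elems`. A minimal usage sketch of the new signature follows; the file names and the 10M-element limit are illustrative only, and the `anndata.experimental` import path is assumed from this module's location rather than stated in the diff.

from anndata.experimental import concat_on_disk

# Hypothetical backed AnnData stores to concatenate along axis 0 (obs).
in_files = ["batch_a.h5ad", "batch_b.h5ad"]

concat_on_disk(
    in_files,
    "combined.h5ad",
    overwrite=True,               # keyword-only after this change
    max_loaded_elems=10_000_000,  # was max_loaded_sparse_elems
    axis=0,
    join="outer",
    label="batch",
)

As the docstring above notes, this limit only bounds the sparse concatenation path; dense arrays are concatenated through Dask and are tuned via the Dask configuration instead.
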
16 changes: 8 additions & 8 deletions anndata/tests/test_concatenate_disk.py
@@ -53,7 +53,7 @@ def file_format(request):

# trying with 10 should be slow but will guarantee that the feature is being used
@pytest.fixture(params=[10, 100_000_000])
def max_loaded_sparse_elems(request):
def max_loaded_elems(request):
return request.param


@@ -79,15 +79,15 @@ def _adatas_to_paths(adatas, tmp_path, file_format):


def assert_eq_concat_on_disk(
adatas, tmp_path, file_format, max_loaded_sparse_elems=None, *args, **kwargs
adatas, tmp_path, file_format, max_loaded_elems=None, *args, **kwargs
):
# create one from the concat function
res1 = concat(adatas, *args, **kwargs)
# create one from the on disk concat function
paths = _adatas_to_paths(adatas, tmp_path, file_format)
out_name = tmp_path / ("out." + file_format)
if max_loaded_sparse_elems is not None:
kwargs["max_loaded_sparse_elems"] = max_loaded_sparse_elems
if max_loaded_elems is not None:
kwargs["max_loaded_elems"] = max_loaded_elems
concat_on_disk(paths, out_name, *args, **kwargs)
res2 = read_elem(as_group(out_name))
assert_equal(res1, res2, exact=False)
@@ -105,7 +105,7 @@ def get_array_type(array_type, axis):


def test_anndatas_without_reindex(
axis, array_type, join_type, tmp_path, max_loaded_sparse_elems, file_format
axis, array_type, join_type, tmp_path, max_loaded_elems, file_format
):
N = 50
M = 50
@@ -130,14 +130,14 @@ def test_anndatas_without_reindex(
adatas,
tmp_path,
file_format,
max_loaded_sparse_elems,
max_loaded_elems,
axis=axis,
join=join_type,
)


def test_anndatas_with_reindex(
axis, array_type, join_type, tmp_path, file_format, max_loaded_sparse_elems
axis, array_type, join_type, tmp_path, file_format, max_loaded_elems
):
N = 50
M = 50
@@ -171,7 +171,7 @@ def test_anndatas_with_reindex(
adatas,
tmp_path,
file_format,
max_loaded_sparse_elems,
max_loaded_elems,
axis=axis,
join=join_type,
)
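
The renamed `max_loaded_elems` fixture parametrizes these tests with both 10 and 100_000_000; as the fixture comment says, the tiny value forces `_gen_slice_to_append` in merge.py to stream the 50x50 test matrices slice by slice rather than loading them whole. A standalone sketch of that slicing arithmetic, not the library code itself and using an illustrative scipy matrix:

from scipy import sparse

def iter_row_slices(mat, max_loaded_elems):
    # Mirrors the axis=0 arithmetic in _gen_slice_to_append: if the whole
    # array (empty entries included) fits under max_loaded_elems, yield it
    # once; otherwise yield blocks of at most slice_size rows.
    n_slices = mat.shape[0] * mat.shape[1] // max_loaded_elems
    if n_slices < 2:
        yield mat
        return
    slice_size = max(1, max_loaded_elems // mat.shape[1])
    for start in range(0, mat.shape[0], slice_size):
        yield mat[start : start + slice_size]

mat = sparse.random(50, 50, density=0.1, format="csr")
print(len(list(iter_row_slices(mat, max_loaded_elems=10))))
# -> 50: 10 // 50 == 0, so the slice size falls back to a single row
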
