diff --git a/anndata/_core/index.py b/anndata/_core/index.py
index 15d840945..bdf1bfe27 100644
--- a/anndata/_core/index.py
+++ b/anndata/_core/index.py
@@ -46,7 +46,7 @@ def _normalize_index(
     | np.ndarray
     | pd.Index,
     index: pd.Index,
-) -> slice | int | np.ndarray:  # ndarray of int
+) -> slice | int | np.ndarray:  # ndarray of int or bool
     if not isinstance(index, pd.RangeIndex):
         assert (
             index.dtype != float and index.dtype != int
@@ -92,8 +92,7 @@ def name_idx(i):
                     f"dimension. Boolean index has shape {indexer.shape} while "
                     f"AnnData index has shape {index.shape}."
                 )
-            positions = np.where(indexer)[0]
-            return positions  # np.ndarray[int]
+            return indexer
         else:  # indexer should be string array
            positions = index.get_indexer(indexer)
            if np.any(positions < 0):
@@ -164,7 +163,10 @@ def _subset_dask(a: DaskArray, subset_idx: Index):
 def _subset_spmatrix(a: spmatrix, subset_idx: Index):
     # Correcting for indexing behaviour of sparse.spmatrix
     if len(subset_idx) > 1 and all(isinstance(x, cabc.Iterable) for x in subset_idx):
-        subset_idx = (subset_idx[0].reshape(-1, 1), *subset_idx[1:])
+        first_idx = subset_idx[0]
+        if issubclass(first_idx.dtype.type, np.bool_):
+            first_idx = np.where(first_idx)[0]
+        subset_idx = (first_idx.reshape(-1, 1), *subset_idx[1:])
     return a[subset_idx]
 
 
@@ -188,7 +190,9 @@ def _subset_dataset(d, subset_idx):
     ordered = list(subset_idx)
     rev_order = [slice(None) for _ in range(len(subset_idx))]
     for axis, axis_idx in enumerate(ordered.copy()):
-        if isinstance(axis_idx, np.ndarray) and axis_idx.dtype.type != bool:
+        if isinstance(axis_idx, np.ndarray):
+            if axis_idx.dtype == bool:
+                axis_idx = np.where(axis_idx)[0]
             order = np.argsort(axis_idx)
             ordered[axis] = axis_idx[order]
             rev_order[axis] = np.argsort(order)
diff --git a/anndata/_core/views.py b/anndata/_core/views.py
index 36faf5fbe..ce86a27ee 100644
--- a/anndata/_core/views.py
+++ b/anndata/_core/views.py
@@ -395,6 +395,10 @@ def _resolve_idx(old, new, l):
 
 @_resolve_idx.register(np.ndarray)
 def _resolve_idx_ndarray(old, new, l):
+    if is_bool_dtype(old) and is_bool_dtype(new):
+        mask_new = np.zeros_like(old)
+        mask_new[np.flatnonzero(old)[new]] = True
+        return mask_new
     if is_bool_dtype(old):
         old = np.where(old)[0]
     return old[new]
diff --git a/anndata/tests/test_backed_sparse.py b/anndata/tests/test_backed_sparse.py
index 05efa747c..7ce6860d1 100644
--- a/anndata/tests/test_backed_sparse.py
+++ b/anndata/tests/test_backed_sparse.py
@@ -138,12 +138,12 @@ def make_alternating_mask(size: int, step: int) -> np.ndarray:
 
 # non-random indices, with alternating one false and n true
 make_alternating_mask_5 = partial(make_alternating_mask, step=5)
-make_alternating_mask_10 = partial(make_alternating_mask, step=10)
+make_alternating_mask_15 = partial(make_alternating_mask, step=15)
 
 
 def make_one_group_mask(size: int) -> np.ndarray:
     one_group_mask = np.zeros(size, dtype=bool)
-    one_group_mask[size // 4 : size // 2] = True
+    one_group_mask[1 : size // 2] = True
     return one_group_mask
 
 
@@ -158,12 +158,12 @@ def make_one_elem_mask(size: int) -> np.ndarray:
     "make_bool_mask,should_trigger_optimization",
     [
         (make_randomized_mask, None),
-        (make_alternating_mask_10, True),
+        (make_alternating_mask_15, True),
         (make_alternating_mask_5, False),
         (make_one_group_mask, True),
         (make_one_elem_mask, False),
     ],
-    ids=["randomized", "alternating_10", "alternating_5", "one_group", "one_elem"],
+    ids=["randomized", "alternating_15", "alternating_5", "one_group", "one_elem"],
 )
 def test_consecutive_bool(
     mocker: MockerFixture,
@@ -188,6 +188,7 @@ def test_consecutive_bool(
     mask = make_bool_mask(csr_disk.shape[0])
 
     # indexing needs to be on `X` directly to trigger the optimization.
+    # `_normalize_indices`, which is used by `AnnData`, converts bools to ints with `np.where`
     from anndata._core import sparse_dataset
 
@@ -202,6 +203,30 @@ def test_consecutive_bool(
     assert (
         spy.call_count == 2 if should_trigger_optimization else not spy.call_count
     )
+    assert_equal(csr_disk[mask, :], csr_disk[np.where(mask)])
+    if should_trigger_optimization is not None:
+        assert (
+            spy.call_count == 3 if should_trigger_optimization else not spy.call_count
+        )
+    subset = csc_disk[:, mask]
+    assert_equal(subset, csc_disk[:, np.where(mask)[0]])
+    if should_trigger_optimization is not None:
+        assert (
+            spy.call_count == 4 if should_trigger_optimization else not spy.call_count
+        )
+    if should_trigger_optimization is not None and not csc_disk.isbacked:
+        size = subset.shape[1]
+        if should_trigger_optimization:
+            subset_subset_mask = np.ones(size).astype("bool")
+            subset_subset_mask[size // 2] = False
+        else:
+            subset_subset_mask = make_one_elem_mask(size)
+        assert_equal(
+            subset[:, subset_subset_mask], subset[:, np.where(subset_subset_mask)[0]]
+        )
+        assert (
+            spy.call_count == 5 if should_trigger_optimization else not spy.call_count
+        ), f"Actual count: {spy.call_count}"
 
 
 @pytest.mark.parametrize(
diff --git a/anndata/tests/test_views.py b/anndata/tests/test_views.py
index dca77265d..7ac4cfefc 100644
--- a/anndata/tests/test_views.py
+++ b/anndata/tests/test_views.py
@@ -551,6 +551,30 @@ def test_double_index(subset_func, subset_func2):
     assert np.all(v1.var == v2.var)
 
 
+def test_view_different_type_indices(matrix_type):
+    orig = gen_adata((30, 30), X_type=matrix_type)
+    boolean_array_mask = np.random.randint(0, 2, 30).astype("bool")
+    boolean_list_mask = boolean_array_mask.tolist()
+    integer_array_mask = np.where(boolean_array_mask)[0]
+    integer_list_mask = integer_array_mask.tolist()
+
+    assert_equal(orig[integer_array_mask, :], orig[boolean_array_mask, :])
+    assert_equal(orig[integer_list_mask, :], orig[boolean_list_mask, :])
+    assert_equal(orig[integer_list_mask, :], orig[integer_array_mask, :])
+    assert_equal(orig[:, integer_array_mask], orig[:, boolean_array_mask])
+    assert_equal(orig[:, integer_list_mask], orig[:, boolean_list_mask])
+    assert_equal(orig[:, integer_list_mask], orig[:, integer_array_mask])
+    # check that X element is same independent of access
+    assert_equal(orig[:, integer_list_mask].X, orig.X[:, integer_list_mask])
+    assert_equal(orig[:, boolean_list_mask].X, orig.X[:, boolean_list_mask])
+    assert_equal(orig[:, integer_array_mask].X, orig.X[:, integer_array_mask])
+    assert_equal(orig[:, integer_list_mask].X, orig.X[:, integer_list_mask])
+    assert_equal(orig[integer_list_mask, :].X, orig.X[integer_list_mask, :])
+    assert_equal(orig[boolean_list_mask, :].X, orig.X[boolean_list_mask, :])
+    assert_equal(orig[integer_array_mask, :].X, orig.X[integer_array_mask, :])
+    assert_equal(orig[integer_list_mask, :].X, orig.X[integer_list_mask, :])
+
+
 def test_view_retains_ndarray_subclass():
     adata = ad.AnnData(np.zeros((10, 10)))
     adata.obsm["foo"] = np.zeros((10, 5)).view(NDArraySubclass)
diff --git a/benchmarks/benchmarks/sparse_dataset.py b/benchmarks/benchmarks/sparse_dataset.py
index ba2e4a71c..05daf0e81 100644
--- a/benchmarks/benchmarks/sparse_dataset.py
+++ b/benchmarks/benchmarks/sparse_dataset.py
@@ -4,6 +4,7 @@
 import zarr
 from scipy import sparse
 
+from anndata import AnnData
 from anndata.experimental import sparse_dataset, write_elem
 
 
@@ -38,9 +39,16 @@ def setup(self, shape, slice):
         g = zarr.group()
         write_elem(g, "X", X)
         self.x = sparse_dataset(g["X"])
+        self.adata = AnnData(self.x)
 
     def time_getitem(self, shape, slice):
         self.x[self.slice]
 
     def peakmem_getitem(self, shape, slice):
         self.x[self.slice]
+
+    def time_getitem_adata(self, shape, slice):
+        self.adata[self.slice]
+
+    def peakmem_getitem_adata(self, shape, slice):
+        self.adata[self.slice]
diff --git a/docs/release-notes/0.10.6.md b/docs/release-notes/0.10.6.md
index 1ebde67f3..e8618118a 100644
--- a/docs/release-notes/0.10.6.md
+++ b/docs/release-notes/0.10.6.md
@@ -5,6 +5,7 @@
 * Defer import of zarr in test helpers, as scanpy CI job relies on them {pr}`1343` {user}`ilan-gold`
 * Writing a dataframe with non-unique column names now throws an error, instead of silently overwriting {pr}`1335` {user}`ivirshup`
+* Bring optimization from {pr}`1233` to indexing on the whole `AnnData` object, not just the sparse dataset itself {pr}`1365` {user}`ilan-gold`
 * Fix mean slice length checking to use improved performance when indexing backed sparse matrices with boolean masks along their major axis {pr}`1366` {user}`ilan-gold`
 
 ```{rubric} Documentation
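
For reviewers, a small standalone sketch (not part of the patch) of the bool-on-bool branch added to `_resolve_idx_ndarray` in `anndata/_core/views.py`: composing the boolean mask that produced a view (`old`) with a further boolean mask over that view's rows (`new`) should yield a single combined boolean mask over the original axis, consistent with the integer-based composition used for the other cases. The example arrays below are illustrative only.

```python
import numpy as np

# `old` is the mask that produced the current view over 5 rows;
# `new` masks the 3 rows that the view kept (len(new) == old.sum()).
old = np.array([True, False, True, True, False])
new = np.array([True, False, True])

# Combined mask over the original axis, as in the added branch:
mask_new = np.zeros_like(old)
mask_new[np.flatnonzero(old)[new]] = True

# Equivalent integer-based composition (the pre-existing code path):
assert np.array_equal(np.flatnonzero(mask_new), np.flatnonzero(old)[new])
print(mask_new)  # [ True False False  True False]
```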