Skip to content

Commit

Permalink
(feat): allow boolean indices to pass down to X (#1365)
Browse files Browse the repository at this point in the history
* (feat): `_normalize_index` can now return a boolean index

* (chore): add sparse test

* (chore): release note

* (chore): add test

* (chore): add access benchmarking

* (fix): counting optimizations

* (chore): move test + add `X` tests

* (fix): make sure optimization propagates to views of views
  • Loading branch information
ilan-gold committed Feb 23, 2024
1 parent 3f6a217 commit a478647
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 9 deletions.
14 changes: 9 additions & 5 deletions anndata/_core/index.py
Expand Up @@ -46,7 +46,7 @@ def _normalize_index(
| np.ndarray
| pd.Index,
index: pd.Index,
) -> slice | int | np.ndarray: # ndarray of int
) -> slice | int | np.ndarray: # ndarray of int or bool
if not isinstance(index, pd.RangeIndex):
assert (
index.dtype != float and index.dtype != int
Expand Down Expand Up @@ -92,8 +92,7 @@ def name_idx(i):
f"dimension. Boolean index has shape {indexer.shape} while "
f"AnnData index has shape {index.shape}."
)
positions = np.where(indexer)[0]
return positions # np.ndarray[int]
return indexer
else: # indexer should be string array
positions = index.get_indexer(indexer)
if np.any(positions < 0):
Expand Down Expand Up @@ -164,7 +163,10 @@ def _subset_dask(a: DaskArray, subset_idx: Index):
def _subset_spmatrix(a: spmatrix, subset_idx: Index):
# Correcting for indexing behaviour of sparse.spmatrix
if len(subset_idx) > 1 and all(isinstance(x, cabc.Iterable) for x in subset_idx):
subset_idx = (subset_idx[0].reshape(-1, 1), *subset_idx[1:])
first_idx = subset_idx[0]
if issubclass(first_idx.dtype.type, np.bool_):
first_idx = np.where(first_idx)[0]
subset_idx = (first_idx.reshape(-1, 1), *subset_idx[1:])
return a[subset_idx]


Expand All @@ -188,7 +190,9 @@ def _subset_dataset(d, subset_idx):
ordered = list(subset_idx)
rev_order = [slice(None) for _ in range(len(subset_idx))]
for axis, axis_idx in enumerate(ordered.copy()):
if isinstance(axis_idx, np.ndarray) and axis_idx.dtype.type != bool:
if isinstance(axis_idx, np.ndarray):
if axis_idx.dtype == bool:
axis_idx = np.where(axis_idx)[0]
order = np.argsort(axis_idx)
ordered[axis] = axis_idx[order]
rev_order[axis] = np.argsort(order)
Expand Down
4 changes: 4 additions & 0 deletions anndata/_core/views.py
Expand Up @@ -395,6 +395,10 @@ def _resolve_idx(old, new, l):

@_resolve_idx.register(np.ndarray)
def _resolve_idx_ndarray(old, new, l):
if is_bool_dtype(old) and is_bool_dtype(new):
mask_new = np.zeros_like(old)
mask_new[np.flatnonzero(old)[new]] = True
return mask_new
if is_bool_dtype(old):
old = np.where(old)[0]
return old[new]
Expand Down
33 changes: 29 additions & 4 deletions anndata/tests/test_backed_sparse.py
Expand Up @@ -138,12 +138,12 @@ def make_alternating_mask(size: int, step: int) -> np.ndarray:

# non-random indices, with alternating one false and n true
make_alternating_mask_5 = partial(make_alternating_mask, step=5)
make_alternating_mask_10 = partial(make_alternating_mask, step=10)
make_alternating_mask_15 = partial(make_alternating_mask, step=15)


def make_one_group_mask(size: int) -> np.ndarray:
one_group_mask = np.zeros(size, dtype=bool)
one_group_mask[size // 4 : size // 2] = True
one_group_mask[1 : size // 2] = True
return one_group_mask


Expand All @@ -158,12 +158,12 @@ def make_one_elem_mask(size: int) -> np.ndarray:
"make_bool_mask,should_trigger_optimization",
[
(make_randomized_mask, None),
(make_alternating_mask_10, True),
(make_alternating_mask_15, True),
(make_alternating_mask_5, False),
(make_one_group_mask, True),
(make_one_elem_mask, False),
],
ids=["randomized", "alternating_10", "alternating_5", "one_group", "one_elem"],
ids=["randomized", "alternating_15", "alternating_5", "one_group", "one_elem"],
)
def test_consecutive_bool(
mocker: MockerFixture,
Expand All @@ -188,6 +188,7 @@ def test_consecutive_bool(
mask = make_bool_mask(csr_disk.shape[0])

# indexing needs to be on `X` directly to trigger the optimization.

# `_normalize_indices`, which is used by `AnnData`, converts bools to ints with `np.where`
from anndata._core import sparse_dataset

Expand All @@ -202,6 +203,30 @@ def test_consecutive_bool(
assert (
spy.call_count == 2 if should_trigger_optimization else not spy.call_count
)
assert_equal(csr_disk[mask, :], csr_disk[np.where(mask)])
if should_trigger_optimization is not None:
assert (
spy.call_count == 3 if should_trigger_optimization else not spy.call_count
)
subset = csc_disk[:, mask]
assert_equal(subset, csc_disk[:, np.where(mask)[0]])
if should_trigger_optimization is not None:
assert (
spy.call_count == 4 if should_trigger_optimization else not spy.call_count
)
if should_trigger_optimization is not None and not csc_disk.isbacked:
size = subset.shape[1]
if should_trigger_optimization:
subset_subset_mask = np.ones(size).astype("bool")
subset_subset_mask[size // 2] = False
else:
subset_subset_mask = make_one_elem_mask(size)
assert_equal(
subset[:, subset_subset_mask], subset[:, np.where(subset_subset_mask)[0]]
)
assert (
spy.call_count == 5 if should_trigger_optimization else not spy.call_count
), f"Actual count: {spy.call_count}"


@pytest.mark.parametrize(
Expand Down
24 changes: 24 additions & 0 deletions anndata/tests/test_views.py
Expand Up @@ -551,6 +551,30 @@ def test_double_index(subset_func, subset_func2):
assert np.all(v1.var == v2.var)


def test_view_different_type_indices(matrix_type):
orig = gen_adata((30, 30), X_type=matrix_type)
boolean_array_mask = np.random.randint(0, 2, 30).astype("bool")
boolean_list_mask = boolean_array_mask.tolist()
integer_array_mask = np.where(boolean_array_mask)[0]
integer_list_mask = integer_array_mask.tolist()

assert_equal(orig[integer_array_mask, :], orig[boolean_array_mask, :])
assert_equal(orig[integer_list_mask, :], orig[boolean_list_mask, :])
assert_equal(orig[integer_list_mask, :], orig[integer_array_mask, :])
assert_equal(orig[:, integer_array_mask], orig[:, boolean_array_mask])
assert_equal(orig[:, integer_list_mask], orig[:, boolean_list_mask])
assert_equal(orig[:, integer_list_mask], orig[:, integer_array_mask])
# check that X element is same independent of access
assert_equal(orig[:, integer_list_mask].X, orig.X[:, integer_list_mask])
assert_equal(orig[:, boolean_list_mask].X, orig.X[:, boolean_list_mask])
assert_equal(orig[:, integer_array_mask].X, orig.X[:, integer_array_mask])
assert_equal(orig[:, integer_list_mask].X, orig.X[:, integer_list_mask])
assert_equal(orig[integer_list_mask, :].X, orig.X[integer_list_mask, :])
assert_equal(orig[boolean_list_mask, :].X, orig.X[boolean_list_mask, :])
assert_equal(orig[integer_array_mask, :].X, orig.X[integer_array_mask, :])
assert_equal(orig[integer_list_mask, :].X, orig.X[integer_list_mask, :])


def test_view_retains_ndarray_subclass():
adata = ad.AnnData(np.zeros((10, 10)))
adata.obsm["foo"] = np.zeros((10, 5)).view(NDArraySubclass)
Expand Down
8 changes: 8 additions & 0 deletions benchmarks/benchmarks/sparse_dataset.py
Expand Up @@ -4,6 +4,7 @@
import zarr
from scipy import sparse

from anndata import AnnData
from anndata.experimental import sparse_dataset, write_elem


Expand Down Expand Up @@ -38,9 +39,16 @@ def setup(self, shape, slice):
g = zarr.group()
write_elem(g, "X", X)
self.x = sparse_dataset(g["X"])
self.adata = AnnData(self.x)

def time_getitem(self, shape, slice):
self.x[self.slice]

def peakmem_getitem(self, shape, slice):
self.x[self.slice]

def time_getitem_adata(self, shape, slice):
self.adata[self.slice]

def peakmem_getitem_adata(self, shape, slice):
self.adata[self.slice]
1 change: 1 addition & 0 deletions docs/release-notes/0.10.6.md
Expand Up @@ -5,6 +5,7 @@

* Defer import of zarr in test helpers, as scanpy CI job relies on them {pr}`1343` {user}`ilan-gold`
* Writing a dataframe with non-unique column names now throws an error, instead of silently overwriting {pr}`1335` {user}`ivirshup`
* Bring optimization from {pr}`1233` to indexing on the whole `AnnData` object, not just the sparse dataset itself {pr}`1365` {user}`ilan-gold`
* Fix mean slice length checking to use improved performance when indexing backed sparse matrices with boolean masks along their major axis {pr}`1366` {user}`ilan-gold`

```{rubric} Documentation
Expand Down

0 comments on commit a478647

Please sign in to comment.