Skip to content

Commit

Permalink
(chore): move test + add X tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Feb 15, 2024
1 parent d80a16f commit 8d685d2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
15 changes: 0 additions & 15 deletions anndata/tests/test_inplace_subset.py
Expand Up @@ -68,21 +68,6 @@ def test_inplace_subset_obs(matrix_type, subset_func):
assert_equal(from_view.layers[k], modified.layers[k], exact=True)


def test_subset_different_types(matrix_type):
orig = gen_adata((30, 30), X_type=matrix_type)
boolean_array_mask = np.random.randint(0, 2, 30).astype("bool")
boolean_list_mask = boolean_array_mask.tolist()
integer_array_mask = np.where(boolean_array_mask)[0]
integer_list_mask = integer_array_mask.tolist()

assert_equal(orig[integer_array_mask, :], orig[boolean_array_mask, :])
assert_equal(orig[integer_list_mask, :], orig[boolean_list_mask, :])
assert_equal(orig[integer_list_mask, :], orig[integer_array_mask, :])
assert_equal(orig[:, integer_array_mask], orig[:, boolean_array_mask])
assert_equal(orig[:, integer_list_mask], orig[:, boolean_list_mask])
assert_equal(orig[:, integer_list_mask], orig[:, integer_array_mask])


@pytest.mark.parametrize("dim", ("obs", "var"))
def test_inplace_subset_no_X(subset_func, dim):
orig = gen_adata((30, 30))
Expand Down
24 changes: 24 additions & 0 deletions anndata/tests/test_views.py
Expand Up @@ -551,6 +551,30 @@ def test_double_index(subset_func, subset_func2):
assert np.all(v1.var == v2.var)


def test_view_different_type_indices(matrix_type):
orig = gen_adata((30, 30), X_type=matrix_type)
boolean_array_mask = np.random.randint(0, 2, 30).astype("bool")
boolean_list_mask = boolean_array_mask.tolist()
integer_array_mask = np.where(boolean_array_mask)[0]
integer_list_mask = integer_array_mask.tolist()

assert_equal(orig[integer_array_mask, :], orig[boolean_array_mask, :])
assert_equal(orig[integer_list_mask, :], orig[boolean_list_mask, :])
assert_equal(orig[integer_list_mask, :], orig[integer_array_mask, :])
assert_equal(orig[:, integer_array_mask], orig[:, boolean_array_mask])
assert_equal(orig[:, integer_list_mask], orig[:, boolean_list_mask])
assert_equal(orig[:, integer_list_mask], orig[:, integer_array_mask])
# check that X element is same independent of access
assert_equal(orig[:, integer_list_mask].X, orig.X[:, integer_list_mask])
assert_equal(orig[:, boolean_list_mask].X, orig.X[:, boolean_list_mask])
assert_equal(orig[:, integer_array_mask].X, orig.X[:, integer_array_mask])
assert_equal(orig[:, integer_list_mask].X, orig.X[:, integer_list_mask])
assert_equal(orig[integer_list_mask, :].X, orig.X[integer_list_mask, :])
assert_equal(orig[boolean_list_mask, :].X, orig.X[boolean_list_mask, :])
assert_equal(orig[integer_array_mask, :].X, orig.X[integer_array_mask, :])
assert_equal(orig[integer_list_mask, :].X, orig.X[integer_list_mask, :])


def test_view_retains_ndarray_subclass():
adata = ad.AnnData(np.zeros((10, 10)))
adata.obsm["foo"] = np.zeros((10, 5)).view(NDArraySubclass)
Expand Down

0 comments on commit 8d685d2

Please sign in to comment.