-
Notifications
You must be signed in to change notification settings - Fork 579
/
__init__.py
106 lines (85 loc) · 3.53 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
This file contains helper functions for the scanpy test suite.
"""
from itertools import permutations
import scanpy as sc
import numpy as np
import warnings
from anndata.tests.helpers import asarray, assert_equal
# TODO: Report more context on the fields being compared on error
# TODO: Allow specifying paths to ignore on comparison
###########################
# Representation choice
###########################
# These functions can be used to check that functions correctly use arguments like `layer`, `obsm`, etc.
def check_rep_mutation(func, X, *, fields=("layer", "obsm"), **kwargs):
    """Check that only the array meant to be modified is modified.

    ``func`` is called as ``func(adata, copy=True, **kwargs)`` once on ``X``
    and once per entry of ``fields`` (passing e.g. ``layer="layer"``), and
    must return the processed copy.

    Asserts that:

    * processing a field gives the same array as processing ``X``;
    * processing ``X`` leaves every field untouched, and processing a field
      leaves ``X`` untouched;
    * processing one field leaves every other field untouched.

    Parameters
    ----------
    func
        Function under test; must accept ``copy=True`` and the keyword
        named by each entry of ``fields``.
    X
        Input data, placed in ``.X`` and in every representation in
        ``fields``. It is not modified by this helper.
    fields
        Representation keywords to check. Each one is used both as the
        keyword name and as the key it selects.
    **kwargs
        Extra arguments forwarded to every ``func`` call.
    """
    adata = sc.AnnData(X=X.copy())

    for field in fields:
        # Each representation gets its own copy. Otherwise all slots (and the
        # caller's ``X``) would share one buffer, and an in-place write to one
        # of them would silently show up in the others.
        sc.get._set_obs_rep(adata, X.copy(), **{field: field})
    # Independent snapshot of the input used as the "unmodified" baseline.
    # Copied because ``asarray`` may hand back its argument itself for dense
    # input, which would alias ``X``.
    X_array = asarray(X).copy()

    adata_X = func(adata, copy=True, **kwargs)
    adatas_proc = {
        field: func(adata, copy=True, **{field: field}, **kwargs) for field in fields
    }

    # Modified fields: processing a field must match processing X.
    for field in fields:
        result_array = asarray(
            sc.get._get_obs_rep(adatas_proc[field], **{field: field})
        )
        np.testing.assert_array_equal(asarray(adata_X.X), result_array)

    # Unmodified fields: X untouched when processing a field, and vice versa.
    for field in fields:
        np.testing.assert_array_equal(X_array, asarray(adatas_proc[field].X))
        np.testing.assert_array_equal(
            X_array, asarray(sc.get._get_obs_rep(adata_X, **{field: field}))
        )
    # Processing field_a must not touch field_b.
    for field_a, field_b in permutations(fields, 2):
        result_array = asarray(
            sc.get._get_obs_rep(adatas_proc[field_a], **{field_b: field_b})
        )
        np.testing.assert_array_equal(X_array, result_array)
def check_rep_results(func, X, *, fields=("layer", "obsm"), **kwargs):
    """Check that computing on any representation mutates the AnnData consistently.

    ``func`` is applied in place, as ``func(adata, **kwargs)``, to an object
    holding ``X`` in ``.X``. It is also applied once per entry of ``fields``
    to an object holding ``X`` only in that representation, selected by e.g.
    ``layer="layer"``.

    The input representation of each processed object is then reset to
    zeros. After that, all objects are compared with ``assert_equal``, so
    every other output written by ``func`` must be identical regardless of
    which representation was used.

    Parameters
    ----------
    func
        In-place function under test; must accept the keyword named by each
        entry of ``fields``.
    X
        Input data; only copies of it are stored.
    fields
        Representation keywords to check. A tuple is used instead of a list
        so the default is immutable; any iterable of names works.
    **kwargs
        Extra arguments forwarded to every ``func`` call.
    """
    # Gen data: every slot starts as zeros. Note that only ``layers["layer"]``
    # and ``obsm["obsm"]`` placeholders are created here.
    empty_X = np.zeros(shape=X.shape, dtype=X.dtype)
    adata = sc.AnnData(
        X=empty_X.copy(),
        layers={"layer": empty_X.copy()},
        obsm={"obsm": empty_X.copy()},
    )
    adata_X = adata.copy()
    adata_X.X = X.copy()
    adatas_proc = {}
    for field in fields:
        cur = adata.copy()
        sc.get._set_obs_rep(cur, X.copy(), **{field: field})
        adatas_proc[field] = cur

    # Apply function (in place)
    func(adata_X, **kwargs)
    for field in fields:
        func(adatas_proc[field], **{field: field}, **kwargs)

    # Reset the representation each object was computed on, so the comparison
    # below only sees what func wrote elsewhere.
    adata_X.X = empty_X.copy()
    for field in fields:
        sc.get._set_obs_rep(adatas_proc[field], empty_X.copy(), **{field: field})

    for field_a, field_b in permutations(fields, 2):
        assert_equal(adatas_proc[field_a], adatas_proc[field_b])
    for field in fields:
        assert_equal(adata_X, adatas_proc[field])
def _check_check_values_warnings(function, adata, expected_warning, kwargs=None):
    """
    Runs `function` on `adata` with provided arguments `kwargs` twice:
    once with `check_values=True` and once with `check_values=False`.
    Checks that the `expected_warning` is only raised when `check_values=True`.

    Parameters
    ----------
    function
        Callable invoked as ``function(adata.copy(), **kwargs, check_values=...)``.
    adata
        Object passed to ``function`` via ``.copy()`` on each call, so the
        caller's instance is never handed to ``function`` directly.
    expected_warning
        Text compared for equality against ``message.args[0]`` of each
        recorded warning.
    kwargs
        Extra keyword arguments forwarded to ``function``. ``None`` (the
        default) means none; a fresh dict is used instead of a shared
        mutable default.

    Raises
    ------
    AssertionError
        If ``expected_warning`` is recorded with ``check_values=False``, or
        is not recorded with ``check_values=True``.
    """
    if kwargs is None:
        kwargs = {}

    def _recorded_messages(check_values):
        # "always" makes recording independent of the ambient filter state:
        # an "ignore"/"once" filter would otherwise hide the warning, and an
        # "error" filter would turn it into an exception.
        with warnings.catch_warnings(record=True) as record:
            warnings.simplefilter("always")
            function(adata.copy(), **kwargs, check_values=check_values)
        return [w.message.args[0] for w in record]

    # expecting no occurrence of the warning
    assert expected_warning not in _recorded_messages(check_values=False)
    # expecting at least one occurrence of the warning
    assert expected_warning in _recorded_messages(check_values=True)