/
test_array_function.py
95 lines (72 loc) · 2.8 KB
/
test_array_function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import pytest
np = pytest.importorskip('numpy', minversion='1.16')
import os
import dask.array as da
from dask.array.utils import assert_eq
# Skip configuration shared by the tests in this module: they only run when
# the environment variable named below is set to exactly "1".
# NOTE(review): newer NumPy releases are understood to enable the
# __array_function__ protocol without this variable -- confirm whether the
# skip condition should take the NumPy version into account.
env_name = "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION"
missing_arrfunc_reason = "{} undefined or disabled".format(env_name)
# An unset variable yields None from .get(), which also differs from "1".
missing_arrfunc_cond = os.environ.get(env_name) != "1"
@pytest.mark.skipif(missing_arrfunc_cond, reason=missing_arrfunc_reason)
@pytest.mark.parametrize(
    'func',
    [
        lambda a: np.concatenate([a, a, a]),
        lambda a: np.cov(a, a),
        lambda a: np.dot(a, a),
        lambda a: np.dstack(a),
        lambda a: np.flip(a, axis=0),
        lambda a: np.hstack(a),
        lambda a: np.matmul(a, a),
        lambda a: np.mean(a),
        lambda a: np.stack([a, a]),
        lambda a: np.sum(a),
        lambda a: np.var(a),
        lambda a: np.vstack(a),
        # For the FFT cases a dask input is first rechunked with chunk sizes
        # equal to its full shape; a NumPy input is passed through untouched.
        lambda a: np.fft.fft(
            a if not isinstance(a, da.Array) else a.rechunk(a.shape)
        ),
        lambda a: np.fft.fft2(
            a if not isinstance(a, da.Array) else a.rechunk(a.shape)
        ),
        lambda a: np.linalg.norm(a),
    ],
)
def test_array_function_dask(func):
    """Check that a NumPy call on a dask array gives a matching dask array.

    ``func`` is applied to a random 100x100 NumPy array and to the same data
    wrapped as a dask array with 50x50 chunks. The dask-side result must be
    a ``da.Array`` and must compare equal (via ``assert_eq``) to the
    NumPy-side result.
    """
    dense = np.random.random((100, 100))
    chunked = da.from_array(dense, chunks=(50, 50))

    expected = func(dense)
    actual = func(chunked)

    assert isinstance(actual, da.Array)
    assert_eq(actual, expected)
@pytest.mark.skipif(missing_arrfunc_cond, reason=missing_arrfunc_reason)
@pytest.mark.parametrize(
    'func',
    [
        lambda a: np.min_scalar_type(a),
        lambda a: np.linalg.det(a),
        lambda a: np.linalg.eigvals(a),
    ],
)
def test_array_notimpl_function_dask(func):
    """Check that each parametrized NumPy call raises ``TypeError`` on dask.

    ``func`` is invoked on a dask array (50x50 chunks) built from a random
    100x100 NumPy array; the call itself is expected to raise.
    """
    dense = np.random.random((100, 100))
    chunked = da.from_array(dense, chunks=(50, 50))

    with pytest.raises(TypeError):
        func(chunked)
@pytest.mark.skipif(missing_arrfunc_cond, reason=missing_arrfunc_reason)
def test_array_function_sparse_transpose():
    """Compare ``np.transpose`` on a dask array and on its ``sparse.COO`` twin.

    Skipped when the ``sparse`` package cannot be imported.
    """
    sparse = pytest.importorskip('sparse')

    dense_blocks = da.random.random((500, 500), chunks=(100, 100))
    dense_blocks[dense_blocks < 0.9] = 0  # zero out every entry below 0.9
    # Same data, with every block converted through sparse.COO.
    sparse_blocks = dense_blocks.map_blocks(sparse.COO)

    assert_eq(np.transpose(dense_blocks), np.transpose(sparse_blocks))
@pytest.mark.skipif(missing_arrfunc_cond, reason=missing_arrfunc_reason)
@pytest.mark.xfail(
    reason="requires sparse support for __array_function__",
    strict=False,
)
def test_array_function_sparse_tensordot():
    """Compare ``np.tensordot`` on NumPy operands and on ``sparse.COO`` ones.

    Marked as a non-strict expected failure; skipped when the ``sparse``
    package cannot be imported.
    """
    sparse = pytest.importorskip('sparse')

    # Two random operands with every entry below 0.9 zeroed out.
    lhs = np.random.random((2, 3, 4))
    lhs[lhs < 0.9] = 0
    rhs = np.random.random((4, 3, 2))
    rhs[rhs < 0.9] = 0

    lhs_coo = sparse.COO(lhs)
    rhs_coo = sparse.COO(rhs)

    dense_result = np.tensordot(lhs, rhs, axes=(2, 0))
    sparse_result = np.tensordot(lhs_coo, rhs_coo, axes=(2, 0))
    assert_eq(dense_result, sparse_result.todense())
@pytest.mark.skipif(missing_arrfunc_cond, reason=missing_arrfunc_reason)
def test_array_function_cupy_svd():
    """Compare ``np.linalg.svd`` with ``da.linalg.svd`` on a cupy-backed array.

    Skipped when the ``cupy`` package cannot be imported.
    """
    cupy = pytest.importorskip('cupy')

    device_data = cupy.random.random((500, 100))
    chunked = da.from_array(device_data, chunks=(100, 100), asarray=False)

    u_ref, s_ref, v_ref = da.linalg.svd(chunked)
    u_np, s_np, v_np = np.linalg.svd(chunked)

    # Each factor from the NumPy entry point must match the dask one.
    for got, want in ((u_np, u_ref), (s_np, s_ref), (v_np, v_ref)):
        assert_eq(got, want)