-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
test_distributed.py
70 lines (55 loc) · 2.2 KB
/
test_distributed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import sys
import pytest
import xarray as xr
from xarray.core.pycompat import suppress
distributed = pytest.importorskip('distributed')
da = pytest.importorskip('dask.array')
import dask
from distributed.utils_test import cluster, loop, gen_cluster
from distributed.client import futures_of, wait
from xarray.tests.test_backends import create_tmp_file, ON_WINDOWS
from xarray.tests.test_dataset import create_test_data
from . import assert_allclose, has_scipy, has_netCDF4, has_h5netcdf
# netCDF backends to parametrize over, kept in a fixed order
# (scipy, netcdf4, h5netcdf). A backend is listed only when its
# ``has_*`` flag imported from the test package is truthy.
ENGINES = [
    engine_name
    for engine_name, is_available in (
        ('scipy', has_scipy),
        ('netcdf4', has_netCDF4),
        ('h5netcdf', has_h5netcdf),
    )
    if is_available
]
@pytest.mark.xfail(sys.platform == 'win32',
                   reason='https://github.com/pydata/xarray/issues/1738')
@pytest.mark.parametrize('engine', ENGINES)
def test_dask_distributed_integration_test(loop, engine):
    """Write a dataset to netCDF and read it back under a distributed client.

    Runs once per entry in ``ENGINES``. The file is reopened with
    ``chunks=3``, checked to be backed by a dask array, computed, and then
    compared against the dataset that was originally written.
    """
    with cluster() as (scheduler, _workers):
        address = scheduler['address']
        # The client is only needed for its lifetime, so it is not bound.
        # NOTE(review): presumably it registers itself as the scheduler
        # used by ``compute()`` below -- confirm against distributed docs.
        with distributed.Client(address, loop=loop):
            expected = create_test_data()
            with create_tmp_file(
                    allow_cleanup_failure=ON_WINDOWS) as path:
                expected.to_netcdf(path, engine=engine)
                reopened = xr.open_dataset(path, chunks=3, engine=engine)
                with reopened as restored:
                    lazy_values = restored.var1.data
                    assert isinstance(lazy_values, da.Array)
                    actual = restored.compute()
                    assert_allclose(expected, actual)
def _distributed_too_old(version, newest_broken=(1, 19, 3)):
    """Return True when ``version`` is ``newest_broken`` or an older release.

    Versions are compared as tuples of integers rather than as strings, so
    that e.g. ``'1.9.0'`` is correctly recognised as older than ``'1.19.3'``
    (a plain string comparison orders them the other way round).

    Parameters
    ----------
    version : str
        Version string such as ``'1.19.3'`` or ``'1.19.3+5.gabcdef0'``.
    newest_broken : tuple of int
        Most recent release number that must still be skipped.

    Returns
    -------
    bool
        True if the test should be skipped for ``version``.
    """
    import re  # local import keeps this block self-contained

    parsed = re.match(r'(\d+(?:\.\d+)*)(.*)', version)
    if parsed is None:
        # No leading numeric release to compare against; do not skip on a
        # guess (the old string comparison also ran the test for such
        # values, e.g. 'unknown').
        return False

    release = tuple(int(part) for part in parsed.group(1).split('.'))
    if release != newest_broken:
        return release < newest_broken

    # Same release number. A '+local' suffix marks a build made after that
    # release and is treated as newer; an exact match or a pre-release
    # suffix (rc, dev, ...) is not.
    return not parsed.group(2).startswith('+')


@pytest.mark.skipif(_distributed_too_old(distributed.__version__),
                    reason='Need recent distributed version to clean up get')
@gen_cluster(client=True, timeout=None)
def test_async(c, s, a, b):
    """Check chunking, persisting and computing a dataset on a test cluster.

    The arguments are supplied by ``gen_cluster(client=True)``; ``c`` is
    used as the client and ``s`` as the scheduler. ``a`` and ``b`` are not
    used (presumably the two workers -- confirm against distributed docs).
    """
    x = create_test_data()
    assert not dask.is_dask_collection(x)

    # Chunking makes the dataset, and each of its variables, lazy.
    y = x.chunk({'dim2': 4}) + 10
    assert dask.is_dask_collection(y)
    assert dask.is_dask_collection(y.var1)
    assert dask.is_dask_collection(y.var2)

    # Persisting keeps the result a dask collection but with a smaller
    # graph that now holds futures.
    z = y.persist()
    assert str(z)

    assert dask.is_dask_collection(z)
    assert dask.is_dask_collection(z.var1)
    assert dask.is_dask_collection(z.var2)
    assert len(y.__dask_graph__()) > len(z.__dask_graph__())

    assert not futures_of(y)
    assert futures_of(z)

    # Yielding the future hands control back to the gen_cluster coroutine
    # runner until the computed dataset is available.
    future = c.compute(z)
    w = yield future
    assert not dask.is_dask_collection(w)
    assert_allclose(x + 10, w)

    # The scheduler should have recorded some task state by now.
    assert s.task_state