# Lazy evaluation on Dask arrays


If you are unfamiliar with Dask, read
[Parallel computing with Dask](http://xarray.pydata.org/en/stable/dask.html) in
the Xarray documentation first. The current version only supports Dask arrays on a
single machine. Support for [Dask.distributed](https://distributed.dask.org) is
on the roadmap.

xESMF's Dask support is mostly for
[lazy evaluation](https://en.wikipedia.org/wiki/Lazy_evaluation) and
[out-of-core computing](https://en.wikipedia.org/wiki/External_memory_algorithm),
which allow processing large volumes of data with limited memory. You might also
get a moderate speed-up on a multi-core machine by
[choosing proper chunk sizes](http://xarray.pydata.org/en/stable/dask.html#chunking-and-performance),
but that generally won't speed up your entire pipeline much, because the
read-regrid-write pipeline is severely I/O limited (see
[this issue](https://github.com/pangeo-data/pangeo/issues/334) for more
discussion). On a single machine, disk bandwidth is typically limited to
~500 MB/s, and you cannot process data faster than that rate. If you need a much
higher data processing rate, you should turn to parallel file systems on HPC
clusters or distributed storage on public cloud platforms. Please refer to the
[Pangeo project](http://pangeo.io/) for more information.
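
To make the I/O limit concrete, here is a rough back-of-the-envelope sketch. It assumes the ~500 MB/s figure quoted above and a hypothetical job that reads and writes 100 GB each; the function name `min_io_seconds` is made up for illustration.

```python
# Rough lower bound on wall time for an I/O-bound read-regrid-write pipeline.
# Assumption: a single disk sustaining ~500 MB/s, shared by reads and writes.

DISK_BANDWIDTH_MB_PER_S = 500  # assumed sustained throughput


def min_io_seconds(read_gb, write_gb, bandwidth_mb_per_s=DISK_BANDWIDTH_MB_PER_S):
    """Return the minimum time (in seconds) needed just to move the data."""
    total_mb = (read_gb + write_gb) * 1000  # decimal units: 1 GB = 1000 MB
    return total_mb / bandwidth_mb_per_s


# Hypothetical job: read 100 GB of input, write 100 GB of regridded output.
seconds = min_io_seconds(read_gb=100, write_gb=100)
print(f"at least {seconds:.0f} s ({seconds / 60:.1f} min) spent on I/O")
```

No amount of CPU parallelism can push the pipeline below this floor, which is why faster storage matters more than more cores here.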


In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import dask.array as da  # dask.array must be installed; it backs the chunked arrays used below
import xarray as xr
import xesmf as xe

## A simple example


### Prepare input data


In [2]:
ds = xr.tutorial.open_dataset("air_temperature", chunks={"time": 500})
ds

| | Array | Chunk |
|---|---|---|
| Bytes | 14.76 MiB | 2.53 MiB |
| Shape | (2920, 25, 53) | (500, 25, 53) |
| Dask graph | 6 chunks in 2 graph layers | |
| Data type | float32 numpy.ndarray | |


In [3]:
ds.chunks

Frozen({'time': (500, 500, 500, 500, 500, 420), 'lat': (25,), 'lon': (53,)})
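
The chunk layout reported above follows from simple division: the `time` dimension of length 2920 is cut into pieces of 500, with a shorter final piece. The sketch below reproduces that arithmetic in plain Python; `chunk_sizes` is a hypothetical helper written for this illustration, not a Dask or xarray function.

```python
from __future__ import annotations


def chunk_sizes(dim_length: int, chunk: int) -> tuple[int, ...]:
    """Split a dimension into equal chunks plus a shorter remainder chunk."""
    full, remainder = divmod(dim_length, chunk)
    return (chunk,) * full + ((remainder,) if remainder else ())


time_chunks = chunk_sizes(2920, 500)
print(time_chunks)  # (500, 500, 500, 500, 500, 420)

# Memory held by one full chunk of float32 data (4 bytes per value).
chunk_bytes = 500 * 25 * 53 * 4
print(f"{chunk_bytes / 2**20:.2f} MiB")  # 2.53 MiB
```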

In [4]:
ds["air"].data

| | Array | Chunk |
|---|---|---|
| Bytes | 14.76 MiB | 2.53 MiB |
| Shape | (2920, 25, 53) | (500, 25, 53) |
| Dask graph | 6 chunks in 2 graph layers | |
| Data type | float32 numpy.ndarray | |


### Build regridder


In [5]:
ds_out = xr.Dataset(
    {
        "lat": (["lat"], np.arange(16, 75, 1.0)),
        "lon": (["lon"], np.arange(200, 330, 1.5)),
    }
)

regridder = xe.Regridder(ds, ds_out, "bilinear")
regridder

xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_25x53_59x87.nc 
Reuse pre-computed weights? False 
Input grid shape:           (25, 53) 
Output grid shape:          (59, 87) 
Periodic in longitude?      False

### Apply to xarray Dataset/DataArray


In [6]:
# only build the dask graph; actual computation happens later when calling compute()
%time ds_out = regridder(ds)
ds_out

CPU times: user 2.06 s, sys: 19.8 ms, total: 2.08 s
Wall time: 2.09 s


| | Array | Chunk |
|---|---|---|
| Bytes | 57.18 MiB | 2.53 MiB |
| Shape | (2920, 59, 87) | (500, 25, 53) |
| Dask graph | 36 chunks in 8 graph layers | |
| Data type | float32 numpy.ndarray | |


In [7]:
ds_out["air"].data  # chunks are preserved

| | Array | Chunk |
|---|---|---|
| Bytes | 57.18 MiB | 2.53 MiB |
| Shape | (2920, 59, 87) | (500, 25, 53) |
| Dask graph | 36 chunks in 8 graph layers | |
| Data type | float32 numpy.ndarray | |


In [8]:
%time result = ds_out['air'].compute()  # actually applies regridding

CPU times: user 1.03 s, sys: 97.3 ms, total: 1.12 s
Wall time: 755 ms


In [9]:
type(result.data), result.data.shape

(numpy.ndarray, (2920, 59, 87))
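
Why can the computation run one time chunk at a time? The following is a minimal NumPy sketch, assuming that regridding amounts to multiplying each time slice by a fixed weight matrix. A small dense random matrix stands in for the real (sparse) weights, and all sizes and names are made up for illustration; this is not xESMF's own code.

```python
import numpy as np

rng = np.random.default_rng(0)

n_time, n_in, n_out = 12, 6, 4                  # toy sizes; real grids are far larger
weights = rng.random((n_out, n_in))             # stand-in for the regridding weights
weights /= weights.sum(axis=1, keepdims=True)   # each output cell is a weighted mean

data = rng.random((n_time, n_in))               # (time, flattened input grid)


def regrid(block):
    """Apply the weights to every time step in `block`."""
    return block @ weights.T                    # (time, n_in) -> (time, n_out)


# Process the whole array at once ...
full = regrid(data)

# ... or one time chunk at a time, holding only that chunk in memory.
time_blocks = np.array_split(data, 3, axis=0)
chunked = np.concatenate([regrid(block) for block in time_blocks], axis=0)

# Both routes give the same result because time steps are independent.
assert np.allclose(full, chunked)
```

Because each time step is regridded independently, only one chunk needs to be in memory at once, which is what makes out-of-core processing possible.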

## Spatial chunks


Dask support also includes chunking over horizontal/core dimensions (`lat`,
`lon`, or `x`, `y`).


In [10]:
# xESMF will take DataArrays that are chunked along the horizontal/core dimensions
ds_spatial = ds.chunk({"lat": 25, "lon": 25, "time": -1})
ds_spatial

| | Array | Chunk |
|---|---|---|
| Bytes | 14.76 MiB | 6.96 MiB |
| Shape | (2920, 25, 53) | (2920, 25, 25) |
| Dask graph | 3 chunks in 3 graph layers | |
| Data type | float32 numpy.ndarray | |


By default, the output DataArray has the same chunk sizes as the input along
the spatial dimensions.


In [11]:
ds_spatial_out = regridder(ds_spatial)  # Regridding ds_spatial
ds_spatial_out["air"].data

| | Array | Chunk |
|---|---|---|
| Bytes | 57.18 MiB | 6.96 MiB |
| Shape | (2920, 59, 87) | (2920, 25, 25) |
| Dask graph | 12 chunks in 10 graph layers | |
| Data type | float32 numpy.ndarray | |


To choose different chunk sizes for the output, pass the `output_chunks`
argument when calling the `regridder`:


In [12]:
ds_spatial_out = regridder(ds_spatial, output_chunks={"lat": 10, "lon": 10})
ds_spatial_out["air"].data

| | Array | Chunk |
|---|---|---|
| Bytes | 57.18 MiB | 1.11 MiB |
| Shape | (2920, 59, 87) | (2920, 10, 10) |
| Dask graph | 54 chunks in 10 graph layers | |
| Data type | float32 numpy.ndarray | |
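
The chunk counts and per-chunk sizes shown in the outputs of this section can be checked by hand. The sketch below does the arithmetic in plain Python; `n_chunks` and `chunk_mib` are hypothetical helpers written for this illustration, and a 4-byte `float32` item size is assumed.

```python
from math import ceil, prod


def n_chunks(shape, chunks):
    """Number of chunks needed to tile `shape` with blocks of size `chunks`."""
    return prod(ceil(n / c) for n, c in zip(shape, chunks))


def chunk_mib(chunks, itemsize=4):
    """Size in MiB of one full chunk (float32 by default)."""
    return prod(chunks) * itemsize / 2**20


out_shape = (2920, 59, 87)  # regridded (time, lat, lon)

for chunks in [(500, 25, 53), (2920, 25, 25), (2920, 10, 10)]:
    print(chunks, n_chunks(out_shape, chunks), f"{chunk_mib(chunks):.2f} MiB")
```

Smaller spatial chunks lower the memory needed per task but increase the number of tasks Dask has to schedule, so the best `output_chunks` value depends on the machine and the data.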


# Parallel weight generation with Dask


Dask can also be used to build the regridder and compute its weights in
parallel. To do so, xESMF uses the chunks on the destination grid and computes
subsets of weights on each chunk in parallel.
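
The idea of computing weights per destination chunk can be illustrated with a toy one-dimensional case. This is only a sketch under stated assumptions: it uses linear interpolation on a 1D grid, a `ThreadPoolExecutor` stands in for Dask, and the helper `weights_for_chunk` is hypothetical. It is not how xESMF generates its weights internally.

```python
from concurrent.futures import ThreadPoolExecutor


def weights_for_chunk(src, dst_chunk):
    """Linear-interpolation weights for one chunk of destination points.

    Returns one row per destination point; each row maps a source index to
    its weight. Points are assumed to lie within [src[0], src[-1]].
    """
    rows = []
    for x in dst_chunk:
        # Index of the source interval [src[i], src[i + 1]] that contains x.
        i = max(j for j in range(len(src) - 1) if src[j] <= x)
        frac = (x - src[i]) / (src[i + 1] - src[i])
        rows.append({i: 1.0 - frac, i + 1: frac})
    return rows


src = [0.0, 10.0, 20.0, 30.0]                 # source grid
dst = [2.5, 5.0, 12.0, 18.0, 21.0, 29.0]      # destination grid
dst_chunks = [dst[0:2], dst[2:4], dst[4:6]]   # destination grid split into chunks

# Each destination chunk only needs the source grid, so chunks are independent.
with ThreadPoolExecutor() as pool:
    per_chunk = list(pool.map(lambda chunk: weights_for_chunk(src, chunk), dst_chunks))

# Stacking the per-chunk rows gives the same weights as a single serial pass.
parallel_rows = [row for rows in per_chunk for row in rows]
serial_rows = weights_for_chunk(src, dst)
assert parallel_rows == serial_rows
```

The key property is that each row of the weight matrix depends only on one destination point, so the rows can be computed in any order and concatenated afterwards.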


## Parallel weight generation example


### Prepare input data


In [2]:
ds = xr.tutorial.open_dataset("air_temperature", chunks={"time": 500})
ds

| | Array | Chunk |
|---|---|---|
| Bytes | 14.76 MiB | 2.53 MiB |
| Shape | (2920, 25, 53) | (500, 25, 53) |
| Dask graph | 6 chunks in 2 graph layers | |
| Data type | float32 numpy.ndarray | |


### Prepare output dataset and chunk it


In [4]:
ds_out = xr.tutorial.open_dataset('rasm')
ds_out = ds_out.chunk({'y':50,'x':50})
ds_out.chunks

Frozen({'time': (36,), 'y': (50, 50, 50, 50, 5), 'x': (50, 50, 50, 50, 50, 25)})

### Create regridder, generating the weights in parallel


In [5]:
para_regridder = xe.Regridder(ds, ds_out, "bilinear", parallel=True)
para_regridder



xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_25x53_205x275.nc 
Reuse pre-computed weights? False 
Input grid shape:           (25, 53) 
Output grid shape:          (205, 275) 
Periodic in longitude?      False

Building the Regridder with `parallel=True` together with either `reuse_weights=True` or a `weights` argument that is not `None` produces a warning. In both cases the weights already exist, so the regridder is built without the parallel weight generation step.

### Using a mask to chunk an empty Dataset

If the destination grid has no data variables and contains only 1D lat/lon coordinates, xarray's `.chunk()` method has nothing to chunk, so the dataset stays unchunked:

In [10]:
ds_out = xr.Dataset(
    {
        "lat": (["lat"], np.arange(16, 75, 1.0), {"units": "degrees_north"}),
        "lon": (["lon"], np.arange(200, 330, 1.5), {"units": "degrees_east"}),
    }
)
ds_out


In [11]:
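ds_out = ds_out.chunk({"lat": 25, "lon": 25})  # .chunk() returns a new Dataset, so assign it
ds_out.chunks  # still empty: there are no data variables to chunk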

Frozen({})

To deal with this issue, we can create a `mask` and add it to `ds_out`. Using a boolean mask keeps `ds_out` from being bloated by data, and setting the mask to `True` everywhere does not affect regridding.

In [8]:
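# Chunked boolean mask, True everywhere, matching the (lat, lon) grid
mask = da.ones((ds_out.lat.size, ds_out.lon.size), dtype=bool, chunks=(25, 25))
ds_out["mask"] = (("lat", "lon"), mask)  # name the dimensions explicitly

# Now we check the chunks of ds_out
ds_out.chunks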

Frozen({'lat': (25, 25, 9), 'lon': (25, 25, 25, 12)})