<img src="http://xarray.pydata.org/en/stable/_static/dataset-diagram-logo.png" align="right" width="30%">

# Xarray and Dask

## Table of contents

1. Reading and writing data with Dask and Xarray
2. Parallel/streaming/lazy computation using dask.array with Xarray
3. Automatic parallelization with apply_ufunc and map_blocks


In [None]:
import numpy as np
import xarray as xr

Let's set up a `LocalCluster` using `dask.distributed`.

In [None]:
from dask.distributed import Client

client = Client()
client

<p>&#128070;</p> Click the Dashboard link above, or click the "Search" button in the dashboard.

Let's test that the dashboard is working.

In [None]:
import dask.array

dask.array.ones((1000, 4), chunks=(2, 1)).compute()  # should see activity in dashboard

## Reading and writing data with Dask and Xarray


In [None]:
ds = xr.tutorial.open_dataset(
    "air_temperature",
    chunks={
        "lat": 25,
        "lon": 25,
        "time": -1,
    },  # this tells xarray to open the dataset as a dask array
)
ds

The repr for the `air` DataArray includes the dask array repr, with its chunk layout.

In [None]:
ds.air
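
The effect of the `chunks` mapping can be checked on a small synthetic array. The sketch below is only an illustration: the array shape is made up to resemble the tutorial data, and it uses `.chunk()` on an in-memory object rather than `chunks=` at open time.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the tutorial dataset (shape chosen for illustration).
da = xr.DataArray(
    np.zeros((10, 25, 53)),
    dims=("time", "lat", "lon"),
    name="air",
)

# -1 means "one chunk spanning the whole dimension".
chunked = da.chunk({"time": -1, "lat": 25, "lon": 25})

# Block sizes along each dimension; the last "lon" block holds the remainder.
print(chunked.chunks)  # ((10,), (25,), (25, 25, 3))
```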

## Parallel/streaming/lazy computation using dask.array with Xarray

Xarray seamlessly wraps dask, so all computation is deferred until explicitly requested.

In [None]:
mean = ds.air.mean("time")  # no activity on dashboard
mean  # contains a dask array

This is true for all xarray operations, including slicing.

In [None]:
timeseries = (
    ds.air.rolling(time=5).mean().isel(lon=1, lat=20)
)  # no activity on dashboard
timeseries  # contains dask array

### Getting concrete values from dask arrays

At some point, you will want to actually get concrete values from dask.

There are two ways to compute values on dask arrays. These concrete values are usually NumPy arrays, but could be, for example, a `pydata/sparse` array.

1. `.compute()` returns a new xarray object holding the computed values and leaves the original object unchanged.
2. `.load()` replaces the dask array in the xarray object with a NumPy array, in place. This is equivalent to `ds = ds.compute()`.

In [None]:
computed = mean.compute()  # activity on dashboard
computed  # has real numpy values

Note that `mean` still contains a dask array.

In [None]:
mean

But if we call `.load()`, `mean` will now contain a NumPy array.

In [None]:
mean.load()

In [None]:
mean

**Tip:** `.persist()` computes the values and keeps them in distributed RAM, while the result stays a dask-backed object. This is useful if you will be repeatedly using a dataset for computation. You'll see a persistent task on the dashboard.
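
As a rough sketch of that pattern, the example below persists an intermediate result and then reuses it. The array and variable names are made up for illustration. Without a distributed client, the chunks are simply held in local memory.

```python
import dask
import dask.array
import xarray as xr

# Small synthetic dask-backed array (illustrative only).
da = xr.DataArray(dask.array.ones((100, 100), chunks=(50, 50)), dims=("x", "y"))

anomaly = da - da.mean("x")    # lazy: nothing has been computed yet
persisted = anomaly.persist()  # chunks are computed and kept in memory

# The persisted object is still dask-backed, so later operations stay lazy
# but start from the already-computed chunks.
print(dask.is_dask_collection(persisted))  # True
total = persisted.sum().compute()
```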

### `.values` vs `.data`

There are two ways to pull out the underlying data in an xarray object.

1. `.values` always returns a NumPy array. For dask-backed xarray objects, this means the data is always computed first.
2. `.data` returns the underlying array as is, which is a Dask array for dask-backed objects.
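
The difference can be seen on a pair of small arrays. This is a minimal sketch with made-up data, not the tutorial dataset.

```python
import dask.array
import numpy as np
import xarray as xr

lazy = xr.DataArray(dask.array.arange(6, chunks=3), dims=("x",))
eager = xr.DataArray(np.arange(6), dims=("x",))

# .data hands back whatever array the object wraps.
print(isinstance(lazy.data, dask.array.Array))  # True
print(isinstance(eager.data, np.ndarray))       # True

# .values always converts to NumPy, which triggers a compute for dask data.
print(isinstance(lazy.values, np.ndarray))      # True
```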

**Exercise**: Try extracting a dask array from `ds.air`

### Xarray data structures are first-class dask collections

This means you can do things like `dask.compute(xarray_object)`, `dask.visualize(xarray_object)`, `dask.persist(xarray_object)`.
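
For instance, several lazy xarray results can be handed to `dask.compute` in one call. The sketch below uses a small synthetic array, with names chosen only for illustration.

```python
import dask
import dask.array
import xarray as xr

da = xr.DataArray(dask.array.ones((4, 6), chunks=(2, 3)), dims=("x", "y"))

# Dask recognizes the xarray object as one of its collections.
print(dask.is_dask_collection(da))  # True

# Compute two lazy results together; xarray objects come back.
total, mean = dask.compute(da.sum(), da.mean())
print(type(total).__name__, float(total), float(mean))
```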

**Exercise** Visualize the task graph for `mean`

## Automatic parallelization with apply_ufunc and map_blocks


Almost all of xarray’s built-in operations work on Dask arrays. If you want to use a function that isn’t wrapped by xarray, and have it applied in parallel on each block of your xarray object, you have three options:

1. Extract Dask arrays from xarray objects (`.data`) and use Dask directly.

2. Use `apply_ufunc()` to apply functions that consume and return NumPy arrays.

3. Use `map_blocks()`, `Dataset.map_blocks()` or `DataArray.map_blocks()` to apply functions that consume and return xarray objects.

Which method you use ultimately depends on what the function you're wrapping expects and the level of convenience you desire. 
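
The first option can be sketched as follows: pull out the dask array, apply a function block by block with dask's own `map_blocks`, and wrap the result back into an xarray object. The function `halve` and the input array are hypothetical stand-ins.

```python
import dask
import dask.array
import numpy as np
import xarray as xr

da = xr.DataArray(dask.array.ones((4, 6), chunks=(2, 3)), dims=("x", "y"))


def halve(block):
    # Hypothetical NumPy-only function, applied to one block at a time.
    return np.asarray(block) * 0.5


# Work on the raw dask array, then re-attach dims and coords via copy().
result_data = da.data.map_blocks(halve, dtype=da.dtype)
result = da.copy(data=result_data)

print(result.dims, dask.is_dask_collection(result))
```

This keeps full control over the dask call, at the cost of handling dimension order and metadata by hand.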

### `map_blocks`

`map_blocks` is inspired by the `dask.array` function of the same name and lets you map a function on blocks of the xarray object (including Datasets!). 

At *compute* time, your function will receive an xarray object with concrete (computed) values along with all metadata and should return an xarray object.

Here is an example:

In [None]:
def lat_mean(obj):
    return obj.mean("lat")  # use xarray's convenient API here


ds.map_blocks(lat_mean)  # this is lazy!

In [None]:
ds.map_blocks(lat_mean).identical(
    ds.mean("lat")
)  # this will calculate values and will return True if the computation works as expected

**Exercise** Try applying the following function with `map_blocks`. Specify `scale` as an argument and `offset` as a kwarg.

The docstring should help: https://xarray.pydata.org/en/stable/generated/xarray.map_blocks.html

```
def lat_mean_scaled(obj, scale, offset):
    return obj.mean("lat") * scale + offset
```

#### More advanced functions

`map_blocks` needs to know *exactly* what the returned object looks like. This means that more complicated functions can be challenging. For these advanced use cases, `map_blocks` accepts a `template` kwarg that describes the expected result. See https://xarray.pydata.org/en/latest/dask.html#map-blocks for more details.
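
A minimal sketch of `template`, assuming a synthetic array and a hypothetical function `first_step` that changes the length of the chunked dimension. The template is explicitly rechunked so that it has one output block per input block.

```python
import dask.array
import numpy as np
import xarray as xr

da = xr.DataArray(
    dask.array.ones((30, 4), chunks=(10, 4)),
    dims=("time", "x"),
    coords={"time": np.arange(30)},
)


def first_step(block):
    # Keep only the first time step of each block, so the output
    # is shorter than the input along "time".
    return block.isel(time=[0])


# Describe the expected result: one time step per input block,
# with one chunk per block along "time".
template = da.isel(time=[0, 10, 20]).chunk({"time": 1})

result = da.map_blocks(first_step, template=template)  # still lazy
print(result.sizes["time"])
```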

### `apply_ufunc`

`apply_ufunc` is a more advanced wrapper that is designed to apply functions that expect and return NumPy (or other) arrays. For example, this would include all of SciPy's API. Since `apply_ufunc` operates on lower-level NumPy or Dask objects, it skips the overhead of using Xarray objects, making it a good choice for performance-critical functions.

`apply_ufunc` can be a little tricky to get right since it operates at a lower level than `map_blocks`. On the other hand, Xarray uses `apply_ufunc` internally to implement much of its API, meaning that it is quite powerful!

### A simple example

Simple functions that act independently on each value should work without any additional arguments. However, `dask` handling needs to be explicitly enabled, so the following call raises a `ValueError` on dask-backed data.

In [None]:
squared_error = lambda x, y: (x - y) ** 2

xr.apply_ufunc(squared_error, ds.air, 1)

Since `squared_error` can handle dask arrays without computing them, we specify `dask="allowed"`.

In [None]:
sqer = xr.apply_ufunc(squared_error, ds.air, 1, dask="allowed")
sqer  # dask array!

### A more complicated example with a dask-aware function

For more complex operations that consider some array values collectively, it’s important to understand the idea of “core dimensions” from NumPy’s generalized ufuncs. Core dimensions are defined as dimensions that should not be broadcast over. Usually, they correspond to the fundamental dimensions over which an operation is defined, e.g., the summed axis in `np.sum`. A good clue that core dimensions are needed is the presence of an `axis` argument on the corresponding NumPy function.

With `apply_ufunc`, core dimensions are recognized by name, and then moved to the last dimension of any input arguments before applying the given function. This means that for functions that accept an `axis` argument, you usually need to set `axis=-1`.
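
The reordering can be observed directly on NumPy-backed data. In this sketch, the hypothetical function `mean_last_axis` records the shape it receives.

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(18.0).reshape(6, 3), dims=("time", "x"))

seen_shapes = []


def mean_last_axis(arr):
    seen_shapes.append(arr.shape)  # record what apply_ufunc passes in
    return arr.mean(axis=-1)


result = xr.apply_ufunc(mean_last_axis, da, input_core_dims=[["time"]])

print(seen_shapes)  # [(3, 6)]: "time" (length 6) was moved to the end
print(result.dims)  # ('x',)
```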

Let's use `dask.array.mean` as an example of a function that can handle dask arrays and uses an `axis` kwarg.

In [None]:
def time_mean(da):
    return xr.apply_ufunc(
        dask.array.mean,
        da,
        input_core_dims=[["time"]],
        dask="allowed",
        kwargs={"axis": -1},  # core dimensions are moved to the end
    )


time_mean(ds.air)

In [None]:
ds.air.mean("time").identical(time_mean(ds.air))

### Automatically parallelizing dask-unaware functions

A very useful `apply_ufunc` feature is the ability to apply arbitrary functions in parallel to each block. This ability can be activated using `dask="parallelized"`. Again, xarray needs a lot of extra metadata, so depending on the function, extra arguments such as `output_dtypes` and `output_sizes` may be necessary (in recent xarray versions, `output_sizes` is passed inside `dask_gufunc_kwargs`).
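
A minimal sketch of `dask="parallelized"`, assuming a hypothetical NumPy-only function `numpy_only_range` and a synthetic array. The core dimension is kept in a single chunk so each block sees the full `time` axis.

```python
import dask
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(48.0).reshape(8, 6), dims=("time", "x")).chunk(
    {"time": -1, "x": 2}  # the core dimension "time" stays in one chunk
)


def numpy_only_range(arr):
    # Stand-in for a function that only understands NumPy arrays.
    arr = np.asarray(arr)
    return arr.max(axis=-1) - arr.min(axis=-1)


result = xr.apply_ufunc(
    numpy_only_range,
    da,
    input_core_dims=[["time"]],  # "time" is moved to the last axis
    dask="parallelized",
    output_dtypes=[da.dtype],
)

print(dask.is_dask_collection(result))  # True: the result is still lazy
```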

We will use `scipy.integrate.trapezoid` as an example of a function that cannot handle dask arrays and requires a core dimension. (The older `trapz` alias has been removed from recent SciPy releases.)

In [None]:
import scipy.integrate
import scipy as sp

sp.integrate.trapezoid(ds.air.data)  # does NOT return a dask array

**Exercise** Use `apply_ufunc` to apply `sp.integrate.trapezoid` along the `time` axis so that you get a dask array returned. You will need to specify `dask="parallelized"` and `output_dtypes` (a list of `dtypes` per returned variable).

## More

1. https://xarray.pydata.org/en/stable/examples/apply_ufunc_vectorize_1d.html
2. https://docs.dask.org/en/latest/array-best-practices.html