# _Notebook 2: Advanced `Dask` for Climate Science Data Analysis_
Part 2/3 of the [ETH Zurich UP](https://up.ethz.ch) Dask Workshop \
_Aaron Wienkers, 2024_

**Objectives of this Notebook**:
1. Learn how to use the `Dask` Dashboard for debugging and optimising performance
2. Understand how to design task-based parallel algorithms compatible with `dask`-`xarray`
3. Explore advanced features like `xr.apply_ufunc` and specialised chunking with `flox`
4. Scale up computations using a `dask` SLURM cluster

---
---

## Pre-requisites
Import libraries and datasets following Notebook 0.

In [1]:
import xarray as xr
import numpy as np
import dask
from dask.distributed import Client, LocalCluster

import sys
import subprocess
import re
import warnings
warnings.filterwarnings('ignore')

data_dir = '/scratch/b/b382615/dask_example_scratch/'
file_zarr = data_dir + 'example_data_chunks1.zarr'
file_zarr2 = data_dir + 'example_data_chunks2.zarr'
file_zarr3 = data_dir + 'example_data_chunks3.zarr'

In [2]:
client = Client(LocalCluster(n_workers=32, threads_per_worker=2))

---
---

## Performance Debugging with the `Dask` Dashboard

The `Dask` Dashboard is an invaluable tool for understanding the performance of your `dask` computations and identifying potential issues. \
Combined with the 7 Common `Dask` Pitfalls covered in `1_where_things_can_go_wrong.ipynb`, it gives you most of what you need to debug and optimise your code.

### Setting Up the Dashboard

If you use JupyterHub through a browser, then this line may be necessary so that the dashboard link is routed through the JupyterHub proxy and is reachable from your local machine:
<small>
```python
dask.config.set({'distributed.dashboard.link': '{JUPYTERHUB_SERVICE_PREFIX}/proxy/{port}/status'})
```
</small>

If you instead use VS Code Remote Explorer (i.e. through SSH Tunnel), then run the following commands and forward the corresponding port via VS Code: \
(N.B.: This has been tested on DKRZ Levante & CSCS Alps, but may require modification on other systems.)

In [3]:
remote_node = subprocess.run(['hostname'], capture_output=True, text=True).stdout.strip().split('.')[0]
port = re.search(r':(\d+)/', client.dashboard_link).group(1)
print('Hostname is', remote_node)
print(f"Forward Port = {remote_node}:{port}")
print(f"Dashboard Link: {client.dashboard_link}")

Hostname is l40183
Forward Port = l40183:8787
Dashboard Link: http://127.0.0.1:8787/status


Now, let's redo a problem from Notebook 1 and see what `dask` is doing under the hood:

In [4]:
ds = xr.open_zarr(file_zarr2, chunks={'time': 200}).isel(time=slice(0,2000))

In [5]:
tau_x = ds.tau_x
f = 2.0 * 7.292e-5 * np.sin(np.deg2rad(tau_x.lat))
f_masked = xr.where(np.abs(tau_x.lat) < 5.0, np.nan, f)
M_y = -tau_x / (1020.0 * f_masked)
M_y_mean = M_y.mean(dim={'lon','time'}).compute()
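Read as equations, the cell above appears to evaluate the meridional Ekman transport driven by the zonal wind stress. The symbol names below do not appear in the notebook; they are inferred from the constants in the code ($\Omega = 7.292\times10^{-5}\,\mathrm{s^{-1}}$, $\rho = 1020\,\mathrm{kg\,m^{-3}}$), so treat the physical interpretation as an assumption:

$$
f(\varphi) = 2\,\Omega\,\sin\varphi, \qquad
M_y = -\frac{\tau_x}{\rho\,f(\varphi)} \quad \text{for } |\varphi| \ge 5^\circ ,
$$

where $\varphi$ is latitude. Points with $|\varphi| < 5^\circ$ are set to `NaN` because $f \to 0$ there, and $M_y$ is finally averaged over `lon` and `time`.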

---
### Interpreting the Task Stream on the `Dask`board

A good-looking Task Stream should have:
- Balanced workload across workers
- Minimal idle time (empty space)
- Limited (or at least interleaved) data transfer between workers

Warnings to look out for:
- <span style="color: red;">Red Bars</span> dominating the Task Stream: Indicates excessive worker-worker communication
- <span style="color: orange;">Orange Bars</span>: Indicates I/O (i.e. spilling intermediate results to disk)

Let's try to improve this code snippet by analysing the `Dask` Dashboard:

In [None]:
ds = xr.open_zarr(file_zarr2, chunks={'time': 1, 'lat': 100, 'lon': -1}).isel(time=slice(0,365))
sst_rolling = ds.sst.rolling(time=15).mean().persist()
correlation = xr.corr(sst_rolling, ds.tau_x, dim='time').compute()

_A few things to consider / investigate:_ \
(Inspired by `1_where_things_can_go_wrong.ipynb`)
1. Are there too many/few tasks ?
2. How can I reduce inter-worker communication ?
3. Am I unnecessarily constraining the task graph ?
4. Are my chunks working with the data on disk ?
5. Are my chunks appropriate for the operations ?
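For the first question, it can help to count chunks and estimate chunk sizes before running anything. The following is a minimal sketch on a synthetic, scaled-down array (the shape and the helper `describe` are hypothetical, not part of the workshop data), comparing a one-time-step-per-chunk layout with a coarser one:

```python
import numpy as np
import dask.array as da

# Hypothetical, scaled-down stand-in for one variable with dims (time, lat, lon)
shape = (365, 180, 360)

def describe(chunks):
    """Return (number of chunks, size of the largest chunk in MiB) for a chunking."""
    arr = da.zeros(shape, chunks=chunks, dtype="float32")
    chunk_mib = np.prod(arr.chunksize) * arr.dtype.itemsize / 2**20
    return arr.npartitions, float(chunk_mib)

n_fine, mib_fine = describe((1, 100, -1))        # one chunk per time-step
n_coarse, mib_coarse = describe((200, 36, -1))   # many time-steps per chunk

print(f"fine:   {n_fine} chunks of at most {mib_fine:.2f} MiB")
print(f"coarse: {n_coarse} chunks of at most {mib_coarse:.2f} MiB")
```

Every chunk becomes at least one task per operation, so the fine layout produces far more (and far smaller) tasks for the scheduler to manage.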

In [6]:
ds = xr.open_zarr(file_zarr2, chunks={'time': 1, 'lat': 100, 'lon': -1}).isel(time=slice(0,365))
ds.u

| | Array | Chunk |
|---|---|---|
| Bytes | 8.81 GiB | 1.37 MiB |
| Shape | (365, 1800, 3600) | (1, 100, 3600) |
| Dask graph | 6570 chunks in 3 graph layers | 6570 chunks in 3 graph layers |
| Data type | float32 numpy.ndarray | float32 numpy.ndarray |


In [7]:
%%time
sst_rolling = ds.sst.rolling(time=15).mean().persist()
correlation = xr.corr(sst_rolling, ds.tau_x, dim='time').compute()

CPU times: user 1min 53s, sys: 6.59 s, total: 1min 59s
Wall time: 2min 1s


#### A (near-)optimal solution: chunks aligned with the data on disk, and no `.persist()`

In [8]:
# See what the native chunksize is...
xr.open_zarr(file_zarr2, chunks={}).u.data.chunksize

(20, 36, 3600)

In [9]:
ds = xr.open_zarr(file_zarr2, chunks={'time': 200, 'lat': 36, 'lon': -1}).isel(time=slice(0,365))
ds.u

| | Array | Chunk |
|---|---|---|
| Bytes | 8.81 GiB | 98.88 MiB |
| Shape | (365, 1800, 3600) | (200, 36, 3600) |
| Dask graph | 100 chunks in 3 graph layers | 100 chunks in 3 graph layers |
| Data type | float32 numpy.ndarray | float32 numpy.ndarray |


In [10]:
%%time
sst_rolling = ds.sst.rolling(time=15).mean()
correlation = xr.corr(sst_rolling, ds.tau_x, dim='time').compute()

CPU times: user 9.19 s, sys: 1 s, total: 10.2 s
Wall time: 13.7 s


---
---

## Designing Task-Based Parallel Algorithms for `Dask`

Implementing _task-based parallelism_ often requires rethinking your original _serial_ algorithm (just as distributed/MPI programming requires reworking serial code).  \
Here are some thoughts to keep in mind while designing your analysis so that `dask` can execute it in parallel:

- Using `xarray` best practices will take you very far. For example, 
    - Utilise `xarray` built-in functions wherever possible (!)
    - Don't use direct array indexing, `[]`. Instead, use `xarray` selection methods like `.sel()` or `.isel()`, which help maintain chunking.
    - Avoid indexing with chunked arrays. This is often a valid time to `.compute()` an intermediate result.
    - N.B.: Many global `numpy` functions will destroy `dask`-backed DataArrays, so always prefer the `xarray`-native variants.
- Some global operations may not be compatible with chunked arrays. Instead, think about how you can first reduce (contract) the data along the other (un-chunked) dimensions, and only then perform the global operation on the reduced result.
    * _This often may require some clever deviations in your algorithm._ 🤔
- Think in terms of independent tasks rather than a single linear sequence of global steps. 
    * _Independent_ here means the operation can be performed using only the data within a single chunk (i.e. no cross-chunk dependencies).
    * If this is not possible, then it may be time to revisit your initial decision w.r.t. chunking.

Let's rewrite the following serial algorithms, which currently work on a subset of the data, so that we can scale up our analysis.

---
### Example 1:

In [11]:
ds = xr.open_zarr(file_zarr2, chunks={}).isel(time=slice(0,100))
sst_np = ds.sst.compute().values  # This is now a `numpy` array

In [12]:
sst_threshold_np = np.percentile(sst_np, 95, axis=0)
area_above_threshold_np = np.sum(sst_np > sst_threshold_np, axis=(1, 2))

Using `xarray` & `dask`:

In [13]:
ds = xr.open_zarr(file_zarr2, chunks={'time':-1, 'lat':36}).isel(time=slice(0,1000))
ds.sst

| | Array | Chunk |
|---|---|---|
| Bytes | 24.14 GiB | 494.38 MiB |
| Shape | (1000, 1800, 3600) | (1000, 36, 3600) |
| Dask graph | 50 chunks in 3 graph layers | 50 chunks in 3 graph layers |
| Data type | float32 numpy.ndarray | float32 numpy.ndarray |


In [14]:
sst_threshold = ds.sst.quantile(0.95, dim=('time'))  # Don't be tempted to use `np.percentile` 🤨

In [15]:
area_above_threshold = xr.where(ds.sst > sst_threshold, 1, 0).sum(dim=('lat', 'lon')).compute()

N.B.: Avoid using the `DataArray` method, i.e. `ds.sst.where()`, here; prefer the function `xr.where()` as above... \
N.N.B.: Using `drop=True` is also often a recipe for disaster when working with `dask` arrays.
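The notebook does not spell out the reasons, but one easily verified difference between the two forms is shown in this small sketch (the toy array `counts` is hypothetical): the function form takes both replacement values, whereas the method form fills unselected cells with `NaN` and therefore promotes integer data to floating point.

```python
import numpy as np
import xarray as xr

# Hypothetical toy integer field
counts = xr.DataArray(np.arange(6).reshape(2, 3), dims=("lat", "lon"))

# Function form: both replacement values are given, so the result stays integer
flags = xr.where(counts > 2, 1, 0)

# Method form: unselected cells become NaN, which promotes integers to floats
masked = counts.where(counts > 2)

print(flags.dtype, int(flags.sum()))
print(masked.dtype, int(masked.isnull().sum()))
```

With `drop=True`, the shape of the output additionally depends on the values of the condition, so it cannot be known until the condition has been computed.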

---
### Example 2:

We would like to calculate the mean SST corresponding to each colour in each time-step... (Inspiration from Emma 🤔)

In [16]:
# Take a small subset so that our serial algorithm can finish today...
sst = xr.open_zarr(file_zarr, chunks={}).sst.isel(time=slice(0,100)).sel(lat=slice(0,90)).sel(lon=slice(0,180))
sst_np = sst.compute().data

# Make up some random integer data labels and zero out half of them
colours_np = np.random.randint(0, 256, size=sst_np.shape)
colours_np[colours_np > 128] = 0

Using maybe sub-pythonic `numpy`... but this is just a prototype algorithm so 🤷‍♂️

In [17]:
colours_max = colours_np.max()
colour_sst_mean_np = np.zeros([sst_np.shape[0], colours_max + 1], dtype=float)  # Means are floats (and may be NaN)

In [18]:
%%time
for t in range(sst_np.shape[0]):
    for l in range(colours_max+1):
        mask_colour_l = (colours_np[t] == l)
        colour_l_count = mask_colour_l.sum()
        colour_sst_mean_np[t,l] = sst_np[t,mask_colour_l].sum() / colour_l_count

CPU times: user 53.5 s, sys: 1.91 s, total: 55.4 s
Wall time: 47.3 s


Using `xarray` & `dask`:

In [19]:
# Ensure `colours` is a `DataArray`, and that it is chunked with intention
colours = xr.DataArray(colours_np, dims=('time', 'lat', 'lon')).chunk(sst.chunks)
colours_max = int(colours.max().compute())  # A small scalar: a valid time to `.compute()`
colours

| | Array | Chunk |
|---|---|---|
| Bytes | 1.21 GiB | 12.36 MiB |
| Shape | (100, 900, 1800) | (1, 900, 1800) |
| Dask graph | 100 chunks in 1 graph layer | 100 chunks in 1 graph layer |
| Data type | int64 numpy.ndarray | int64 numpy.ndarray |


In [20]:
colour_ids = xr.DataArray(np.arange(colours_max + 1), dims=('colour'))

In [21]:
colour_sst_sum = xr.where(colours == colour_ids, sst, 0.0).sum(dim=('lat', 'lon'))
colour_counts = xr.where(colours == colour_ids, True, False).sum(dim=('lat', 'lon'))
colour_counts

| | Array | Chunk |
|---|---|---|
| Bytes | 100.78 kiB | 1.01 kiB |
| Shape | (100, 129) | (1, 129) |
| Dask graph | 100 chunks in 8 graph layers | 100 chunks in 8 graph layers |
| Data type | int64 numpy.ndarray | int64 numpy.ndarray |


In [23]:
%%time
# "Mean" (ignoring, for sake of brevity, grid-area variations...):
colour_sst_mean = (colour_sst_sum / colour_counts).compute()

CPU times: user 3.01 s, sys: 3.38 s, total: 6.39 s
Wall time: 9.8 s
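The key step above is the comparison `colours == colour_ids`: because the two arrays have different dimension names, `xarray` broadcasts them against each other, producing one boolean mask per colour. A tiny sketch with made-up labels (the names `labels` and `label_ids` are hypothetical) shows the resulting dimensions:

```python
import numpy as np
import xarray as xr

# Hypothetical toy labels on a (time, lat, lon) grid
labels = xr.DataArray(
    np.array([[[0, 1], [2, 1]]]), dims=("time", "lat", "lon")
)
label_ids = xr.DataArray(np.arange(3), dims="colour")

# Comparing arrays with different dimension names broadcasts over their union
mask = labels == label_ids
print(mask.dims)

# Summing the mask over space counts the cells carrying each label
counts = mask.sum(dim=("lat", "lon"))
print(counts.values)
```

The broadcast mask has an extra `colour` axis, so it is much larger than the inputs; with `dask` it is only ever materialised one chunk at a time.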


See below for an alternative (better) solution...

---
---

## Advanced Features

### `xr.apply_ufunc`

When your desired operation can't be directly/easily/efficiently translated to `xarray` functions, then `xr.apply_ufunc` comes to the rescue. \
The "applied" function receives plain _`numpy` arrays_ whose trailing dimensions are those listed in `input_core_dims`; chunks along the remaining dimensions become separate tasks in the `dask` task graph and are distributed to the workers.

Let's see how this works for `Example 2` above:

In [24]:
## This is the function that will be applied to each chunk
#    _sst_np will be a `numpy` array with shape (lat, lon)
#    (equiv. _colours_np & _colour_ids)

def chunked_function(_sst_np, _colours_np, _colour_ids):  
    # Put on your best `numpy`, etc vectorised behaviour in here...
    
    colour_count = (_colours_np == _colour_ids[:, None, None]).sum(axis=(1, 2))
    colour_sst_sum = np.where(_colours_np[None,:,:] == _colour_ids[:,None,None], 
                                 _sst_np[None,:,:], 0.0).sum(axis=(1, 2))
    colour_sst_mean_np = colour_sst_sum / colour_count
    return colour_sst_mean_np
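As one possible alternative for the body of such a function, `np.bincount` can accumulate per-label sums and counts in a single pass, without building a `(colour, lat, lon)` intermediate. This is only a sketch under assumptions: the helper name `label_means` and the toy inputs are hypothetical, and labels are assumed to be non-negative integers.

```python
import numpy as np

def label_means(values, labels, n_labels):
    """Mean of `values` for each integer label in [0, n_labels)."""
    flat_labels = labels.ravel()
    sums = np.bincount(flat_labels, weights=values.ravel(), minlength=n_labels)
    counts = np.bincount(flat_labels, minlength=n_labels)
    # Labels that never occur get NaN instead of a division by zero
    return np.where(counts > 0, sums / np.maximum(counts, 1), np.nan)

rng = np.random.default_rng(0)
values = rng.random((4, 5))                  # stand-in for one (lat, lon) slice
labels = rng.integers(0, 3, size=(4, 5))     # labels 0, 1 or 2
means = label_means(values, labels, n_labels=4)   # label 3 never occurs
print(means)
```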

Importantly, we _won't_ mention `time` in any of the "core" dimensions specified in `xr.apply_ufunc` ! \
Our `chunked_function()` doesn't care what `time` even is...  This allows `xr.apply_ufunc` to turn `time` chunks into tasks 👌

In [25]:
%%time
colour_sst_mean = xr.apply_ufunc(
        chunked_function, 
        sst,                    # The first  argument in `chunked_function`.  (Must be a DataArray!)
        colours,                # The second argument.  (Must _also_ be a DataArray!)
        colour_ids,             # The third  argument. (Still a DataArray.)
        input_core_dims=[['lat', 'lon'],['lat', 'lon'],['colour']],
        output_core_dims=[['colour']],
        vectorize=True,
        dask='parallelized'
)
colour_sst_mean = colour_sst_mean.compute()

CPU times: user 2.03 s, sys: 3.19 s, total: 5.21 s
Wall time: 7.53 s


N.B.: Setting `vectorize=True` is a convenience that simply loops over the non-core dimensions (here `time`) on each worker, so that your function only ever sees the precise lower-dimensional sub-array specified by `input_core_dims`. If your `numpy` algorithm is inherently vectorised (in the above case, over the additional `time` dimension), then performance gains can be had by setting `vectorize=False`.
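A minimal sketch of the `vectorize=False` variant, using a toy field and a hypothetical `spatial_mean` function rather than the workshop data: the function now receives whole blocks, with the core dimensions moved to the last axes, and must handle the leading axis itself.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
# Hypothetical toy field, chunked along `time` only
field = xr.DataArray(
    rng.random((6, 4, 5)), dims=("time", "lat", "lon")
).chunk({"time": 2})

def spatial_mean(block):
    # Without `vectorize`, `block` keeps its leading (time) axis;
    # the core dimensions arrive as the last two axes.
    return block.mean(axis=(-2, -1))

result = xr.apply_ufunc(
    spatial_mean,
    field,
    input_core_dims=[["lat", "lon"]],
    vectorize=False,
    dask="parallelized",
    output_dtypes=[field.dtype],
).compute()
print(result.dims, result.shape)
```

Indexing the core axes from the end (`-2`, `-1`) keeps the function correct regardless of how many leading dimensions a block carries.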

---
### Advanced Chunking with `flox`

If rechunking is absolutely necessary, and your required chunking structure is a bit obscure, then `flox` provides advanced but very useful chunking capabilities. _Still remember: Always define chunks when reading the data !_ \
cf. [`flox` Documentation](https://flox.readthedocs.io/en/latest/)

Here is a quick, simple example for climatological operations: \
We use `flox` to (effectively for free) group the data by `dayofyear`.

As a consequence, subsequent (climatological) operations across `dayofyear` will:
- Have more favourable stride access patterns, and
- Find the data already distributed so as to minimise inter-worker communication 🤙

In [26]:
import flox.xarray

In [27]:
sst = xr.open_zarr(file_zarr, chunks={}).sst.isel(time=slice(0,1097))
sst_dayofyear_chunk = flox.xarray.rechunk_for_cohorts(
    sst, 'time',
    labels=sst.time.dt.dayofyear,
    force_new_chunk_at=1,
    chunksize=8,
    ignore_old_chunks=True,
)
sst_dayofyear_chunk

| | Array | Chunk |
|---|---|---|
| Bytes | 26.48 GiB | 197.75 MiB |
| Shape | (1097, 1800, 3600) | (8, 1800, 3600) |
| Dask graph | 139 chunks in 4 graph layers | 139 chunks in 4 graph layers |
| Data type | float32 numpy.ndarray | float32 numpy.ndarray |


N.B.: If we do this rechunking _before_ any of the data has been 'eagerly' loaded from disk, then this is very efficient ! \
See also, `flox.rechunk_for_blockwise()`, e.g. for chunking by months (annoyingly having different chunk-sizes...)
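To build intuition for what `force_new_chunk_at=1` and `chunksize=8` ask for, here is a simplified, pure-`numpy` illustration of the idea. It is not `flox`'s implementation (which has additional rules), the helper `cohort_chunks` is hypothetical, and the daily time axis starting on 1 January is an assumption:

```python
import numpy as np

def cohort_chunks(labels, force_new_chunk_at, chunksize):
    """Chunk lengths that restart at a given label and otherwise cap at `chunksize`."""
    chunks, current = [], 0
    for label in labels:
        # Close the running chunk when it is full or when the special label appears
        if current > 0 and (label == force_new_chunk_at or current == chunksize):
            chunks.append(current)
            current = 0
        current += 1
    if current:
        chunks.append(current)
    return tuple(chunks)

# Hypothetical daily time axis with 1097 steps, starting on 1 January
dates = np.arange("2001-01-01", "2004-01-03", dtype="datetime64[D]")
dayofyear = (dates - dates.astype("datetime64[Y]")).astype(int) + 1

chunks = cohort_chunks(dayofyear, force_new_chunk_at=1, chunksize=8)
print(len(chunks), max(chunks), sum(chunks))
```

Because every year restarts on a chunk boundary, a given `dayofyear` sits at the same position within the chunk pattern each year, which is what makes the subsequent `groupby` cheap.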

In [28]:
# Group by dayofyear and compute the mean
sst_climatology = sst_dayofyear_chunk.groupby('time.dayofyear').mean()

---
---

## Scaling Up with `Dask` SLURM Clusters

When your algorithm is optimised but the computation still takes too long, it's time to scale up with a `Dask` SLURM cluster.

Similar to how we previously made a `LocalCluster`, we just define our distributed SLURM cluster and then connect a client to it.

In [None]:
from dask_jobqueue import SLURMCluster

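clusterDistributed = SLURMCluster(
                        cores=32,              # Cores per SLURM job
                        processes=64,          # Worker processes per SLURM job
                        memory="256GB",        # Memory per SLURM job
                        walltime="01:00:00",   # Wall-time limit per SLURM job
                        interface="ib0",       # Network interface for worker communication
)

clusterDistributed.scale(128)   # Request 128 workers in total
clientDistributed = Client(clusterDistributed)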
remote_node = subprocess.run(['hostname'], capture_output=True, text=True
                            ).stdout.strip().split('.')[0]
port = re.search(r':(\d+)/', clientDistributed.dashboard_link).group(1)
print(f"Forward Port = {remote_node}:{port}")
print(f"Dashboard Link: {clientDistributed.dashboard_link}")

This sets up a `Dask` cluster that spans multiple SLURM jobs, allowing you to harness the power of distributed memory parallelism on multiple nodes for your `dask` computations.
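As a rough back-of-the-envelope check of what the settings above request, the arithmetic can be sketched as below. Both rules used here are stated as assumptions about `dask_jobqueue` (that `.scale(n)` counts workers, and that a job's memory is split evenly between its worker processes); check them against the version installed on your system.

```python
import math

# Values copied from the `SLURMCluster(...)` call above
processes_per_job = 64
memory_per_job_gb = 256
requested_workers = 128      # the argument of `.scale(...)`

# Assumption: `.scale(n)` counts workers, and each job starts `processes` workers
n_jobs = math.ceil(requested_workers / processes_per_job)
# Assumption: a job's memory is shared evenly between its worker processes
memory_per_worker_gb = memory_per_job_gb / processes_per_job

print(f"{n_jobs} SLURM jobs, {memory_per_worker_gb:g} GB per worker")
```

The memory per worker is the number to compare against your chunk sizes when deciding whether a computation will fit without spilling.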

In [None]:
# Now just do your same computation exactly as before !
...

---
---

## Conclusion

Congratulations on making it to the end of this `dask` workshop ! \
Happy computing! 😃