# _Notebook 1: Where Things Can Go Wrong with `Dask`_
Part 1/3 of the [ETH Zurich UP](https://up.ethz.ch) Dask Workshop \
_Aaron Wienkers, 2024_

**Objectives of this Notebook**:
1. Understand how to choose the appropriate chunk size for your _particular_ computation and machine
2. Recognise the importance of matching chunks to the data on disk
3. Avoid common pitfalls like unnecessary rechunking or persisting large intermediate results
4. Appreciate that sometimes `dask` is just as much an art as it is a scientific tool, in particular when chunking


---

## Pre-requisites
Import libraries and datasets following Notebook 0.

In [1]:
import xarray as xr 
import numpy as np
import dask
from dask.distributed import Client, LocalCluster

import sys
import subprocess
import warnings
warnings.filterwarnings('ignore')

remote_node = subprocess.run(['hostname'], capture_output=True, text=True).stdout.strip().split('.')[0]
print('Hostname is', remote_node)

data_dir = '/scratch/b/b382615/dask_example_scratch/'
file_zarr = data_dir + 'example_data_chunks1.zarr'
file_zarr_big = data_dir + 'example_data_big_chunks1.zarr' 
file_zarr2 = data_dir + 'example_data_chunks2.zarr'
file_zarr3 = data_dir + 'example_data_chunks3.zarr'
file_netcdf = data_dir + 'example_mfdataset/'


---
---

## Problem #1: Exhausting Worker Memory

One of the most common issues when working with `dask` is exhausting the memory of the workers by using chunk sizes that are too large.

We will first set up our Local Dask Cluster to use **128 workers**. On DKRZ `Levante`, this is the number of physical cores on the compute node. \
Note that even though I have **512 GB** of memory on this node, each worker will only have access to **4 GB**.

In [2]:
client = Client(LocalCluster(n_workers=128))

Vibes suggested that a chunk size of 1 year in the `time` dimension was a good idea. 🤙 \
Notice that each chunk is now nearly 9 GiB... \
However, we won't have problems yet, because all of this is still _symbolic_.

In [3]:
sst_big_chunk = xr.open_zarr(file_zarr, chunks={'time': 365, 'lat': -1, 'lon': -1}).sst
sst_big_chunk

| | Array | Chunk |
|---|---|---|
| Bytes | 264.50 GiB | 8.81 GiB |
| Shape | (10957, 1800, 3600) | (365, 1800, 3600) |
| Dask graph | 31 chunks in 2 graph layers | |
| Data type | float32 numpy.ndarray | |


If we try to compute the evolution of the global SST, we will run into memory issues... eventually... \
For now, `xarray` is still optimistic, and even tells us the metadata of the result, once we compute it. It may even appear to be small enough to fit into memory --- after all it's only 43 kiB 🤔

In [4]:
sst_global = sst_big_chunk.mean(dim={'lat','lon'})
sst_global

| | Array | Chunk |
|---|---|---|
| Bytes | 42.80 kiB | 1.43 kiB |
| Shape | (10957,) | (365,) |
| Dask graph | 31 chunks in 4 graph layers | |
| Data type | float32 numpy.ndarray | |


Once we call `.load()` to trigger the computation, we get a `MemoryError`:

In [None]:
sst_global.load()

If your particular calculation (for some reason) requires this chunk size, then an alternative to shrinking the chunks is to reduce the number of workers, which increases the memory available to each worker.

In this example, if we show some restraint and only use 32 workers, then we would have ~16 GB per worker 👌 \
So let's try this...

In [None]:
client = Client(LocalCluster(n_workers=32))
sst_big_chunk = xr.open_zarr(file_zarr, chunks={'time': 365, 'lat': -1, 'lon': -1}).sst
sst_big_chunk.mean(dim={'lat','lon'}).load()

Still we get a `MemoryError` 🤔...

---
### $\implies$ Rule of Thumb #1 for Choosing Chunk Size:
A single chunk should take up at most **~10%** of the memory available to each worker. \
Workers don't just need to fit a single chunk in memory... They often also need to store intermediate results that are dependencies within the task graph.
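
As a rough illustration of this rule, the arithmetic below checks the 1-year chunk from this problem against a 10% budget. It is only a sketch: it assumes the node's 512 GB is shared evenly between workers and counted in GiB, and the helper names (`chunk_bytes`, `max_time_steps`) are made up here and are not part of `dask`.

```python
import math

GIB = 2**30
NODE_MEMORY = 512 * GIB      # assumption: the "512 GB" node, counted in GiB
ITEMSIZE = 4                 # float32

def chunk_bytes(chunk_shape, itemsize=ITEMSIZE):
    """Size in bytes of one chunk of the given shape."""
    return math.prod(chunk_shape) * itemsize

def max_time_steps(n_workers, lat=1800, lon=3600, fraction=0.10):
    """Largest chunk length in `time` within `fraction` of one worker's memory."""
    budget = fraction * NODE_MEMORY / n_workers
    return int(budget // chunk_bytes((1, lat, lon)))

year_chunk = chunk_bytes((365, 1800, 3600))
print(f"1-year chunk: {year_chunk / GIB:.2f} GiB")
for n_workers in (128, 32):
    per_worker = NODE_MEMORY / n_workers
    print(f"{n_workers} workers: {per_worker / GIB:.0f} GiB each, "
          f"chunk uses {year_chunk / per_worker:.0%} of it, "
          f"10% budget allows ~{max_time_steps(n_workers)} time steps per chunk")
```

Under these assumptions the 1-year chunk still occupies more than half of each worker's memory with 32 workers, which is consistent with the second `MemoryError` above.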

---
---

## Problem #2: Unrealised Parallelism

Even if our chunks fit in memory, it may still not be optimal. \
_Remember, the number of data chunks is closely related to the number of tasks that can be executed in parallel._

For example, here our `DataArray` is divided into only 4 chunks. \
So even though our `dask` cluster has 16 workers, 12 of them will remain idle...

In [2]:
client = Client(LocalCluster(n_workers=16, threads_per_worker=1))
sst_4_chunks = xr.open_zarr(file_zarr, chunks={'time': 92, 'lat': -1, 'lon': -1}).sst.isel(time=slice(0,365))
sst_4_chunks

| | Array | Chunk |
|---|---|---|
| Bytes | 8.81 GiB | 2.22 GiB |
| Shape | (365, 1800, 3600) | (92, 1800, 3600) |
| Dask graph | 4 chunks in 3 graph layers | |
| Data type | float32 numpy.ndarray | |


Look at `top` after submitting `.compute()` to confirm there are only 4 processes at ~100% CPU...

In [3]:
%%time
sst_4_chunks_var = sst_4_chunks.var(dim={'lat','lon'}).compute()

CPU times: user 1.65 s, sys: 391 ms, total: 2.04 s
Wall time: 19.2 s


---
### $\implies$ Rule of Thumb #2 for Choosing Chunk Size:
$n_\mathrm{chunks} \gtrsim n_\mathrm{workers}$ \
Ideally the number of chunks is some small multiple of the number of workers, to allow the task-based parallelism to be fully exploited. \
For example, `dask` may often transfer data between workers/disk for one task while finishing the FLOPs of another task.
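
A back-of-the-envelope way to see the idle workers is sketched below. It makes the simplifying assumption that each chunk is processed by a single task and each worker runs one task at a time (as with `threads_per_worker=1`); `parallelism_summary` is a hypothetical helper, not a `dask` function.

```python
import math

def parallelism_summary(n_chunks, n_workers):
    """Busy/idle workers and number of task 'waves', assuming one task per
    chunk and one task at a time per worker."""
    busy = min(n_chunks, n_workers)
    idle = n_workers - busy
    waves = math.ceil(n_chunks / n_workers)
    return {"busy": busy, "idle": idle, "waves": waves}

# The 4-chunk array above on a 16-worker cluster
print(parallelism_summary(n_chunks=4, n_workers=16))

# A small multiple of the worker count keeps every worker occupied
print(parallelism_summary(n_chunks=48, n_workers=16))
```

With 4 chunks, 12 of the 16 workers have nothing to do; with 48 chunks every worker is busy for three waves of tasks.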

---
---

## Problem #3: Overzealous Chunking

Just as we can have too few chunks, having too many is also not great...

Now let's chunk the same `DataArray` from `Problem #2` into ~12,000 chunks. \
We can even use the entire 128 cores on our node, so _surely_ we'll get parallelism to the max now ? 😎

In [4]:
client = Client(LocalCluster(n_workers=128))
sst_many_chunks = xr.open_zarr(file_zarr, chunks={'time': 1, 'lat': 500, 'lon': 500}).sst.isel(time=slice(0,365))
sst_many_chunks

| | Array | Chunk |
|---|---|---|
| Bytes | 8.81 GiB | 0.95 MiB |
| Shape | (365, 1800, 3600) | (1, 500, 500) |
| Dask graph | 11680 chunks in 3 graph layers | |
| Data type | float32 numpy.ndarray | |


If we look at `top` after submitting `.compute()`, it appears to be running in serial... 🤔 \
This is because the client and scheduler first need to optimise and schedule the task graph for all _~12,000_ chunks (each chunk spawns several tasks) !

In [5]:
%%time
# Store the result under a new name so that re-running this cell still works
sst_many_chunks_var = sst_many_chunks.var(dim={'lat','lon'}).compute()

CPU times: user 53.3 s, sys: 13.5 s, total: 1min 6s
Wall time: 1min 3s


Yet, even with 8x as many workers compared to `Problem #2`, this _same computation_ took more than 3x longer (63 s vs 19.2 s) ! 🙁

---
### $\implies$ Rule of Thumb #3 for Choosing Chunk Size:
$n_\mathrm{chunks} \lesssim 100\cdot n_\mathrm{workers}$, and the total $n_\mathrm{chunks} \lesssim 10{,}000$ \
Keeping the task graph to a "manageable" size ensures we spend time doing the _actual computation_ rather than managing the task graph. \
With too many chunks/tasks, the overhead of sending the tasks to workers can even outweigh the benefits of parallelism.

A good target chunksize in memory is often around 100 MB.
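
The chunk count follows directly from the array shape and the chunk shape, so both upper bounds can be checked before any data is opened. The sketch below does this arithmetic for the chunking used in this problem; `n_chunks` and `within_rule_3` are hypothetical helpers written for this example.

```python
import math

def n_chunks(shape, chunk_shape):
    """Number of chunks when `shape` is split into blocks of `chunk_shape`."""
    return math.prod(math.ceil(s / c) for s, c in zip(shape, chunk_shape))

def within_rule_3(shape, chunk_shape, n_workers, per_worker=100, total=10_000):
    """True if the chunk count respects both upper bounds of Rule of Thumb #3."""
    n = n_chunks(shape, chunk_shape)
    return n <= per_worker * n_workers and n <= total

shape = (365, 1800, 3600)   # (time, lat, lon)

# Chunking from this problem: ~1 MiB chunks
print(n_chunks(shape, (1, 500, 500)), within_rule_3(shape, (1, 500, 500), 128))

# Two full time slices per chunk: ~50 MB chunks
print(n_chunks(shape, (2, 1800, 3600)), within_rule_3(shape, (2, 1800, 3600), 128))
```

The first chunking gives 11,680 chunks and breaks the total limit; the second gives 183 chunks, which stays within both bounds while still exceeding the 128 workers.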

---
---

## Problem #4: Chunk Shape Mismatched to the Computation

Size is not everything when it comes to chunking... \
The shape of the chunks can also have a big impact, particularly when they don't match the access pattern of the computation.

Let's compare two different chunking strategies when computing a zonally-averaged SST variance. \
For a fair comparison both sets of chunks will be the same size in memory.

In [6]:
client = Client(LocalCluster(n_workers=64))

---
#### Chunking Strategy #1: Splitting Latitude

Here, SST has a single chunk in `dim='time'`, but each chunk contains only a single data point in `dim='lat'`.

In [7]:
sst_lat_chunks = xr.open_zarr(file_zarr3, chunks={'time': 1800, 'lat': 1}).isel(time=slice(0,1800)).sst.persist()
sst_lat_chunks

| | Array | Chunk |
|---|---|---|
| Bytes | 43.45 GiB | 24.72 MiB |
| Shape | (1800, 1800, 3600) | (1800, 1, 3600) |
| Dask graph | 1800 chunks in 1 graph layer | |
| Data type | float32 numpy.ndarray | |


In [8]:
%%time
sst_lat_chunks_mean_var = sst_lat_chunks.var(dim='time').mean(dim='lon').compute()

CPU times: user 21.1 s, sys: 12.2 s, total: 33.3 s
Wall time: 24.1 s


---
#### Chunking Strategy #2: Splitting Time

Now flip the chunking dimensions so that each chunk contains all data along `dim='lat'`, but only a single point in time.

In [9]:
sst_time_chunks = xr.open_zarr(file_zarr, chunks={'time': 1, 'lat': -1}).isel(time=slice(0,1800)).sst.persist()
sst_time_chunks

| | Array | Chunk |
|---|---|---|
| Bytes | 43.45 GiB | 24.72 MiB |
| Shape | (1800, 1800, 3600) | (1, 1800, 3600) |
| Dask graph | 1800 chunks in 1 graph layer | |
| Data type | float32 numpy.ndarray | |


In [None]:
%%time
sst_time_chunks_mean_var = sst_time_chunks.var(dim='time').mean(dim='lon').compute()

Why did we run out of memory ?! 🤔 \
The solution isn't to just request ever-larger memory resources... We can be smarter.

Here, we tried to compute the time-variance, which is a _global_ operation along the `time` dimension. \
This means that when we chunked along `time`, _every single chunk_ was needed to compute the time-variance at every point in latitude.


---
### $\implies$ Rule of Thumb #4a for Choosing Chunk Shape:
Consider memory locality for _your particular_ computation...
- In which dimensions is your computation _global_ vs _local_ ? \
(Or, more realistically, it's somewhere in-between; cf. Notebook 2 on _Advanced `dask`_.)
- Do you only consider a single time-slice ?  Without chunking in `time`, your workers will still need to read _all_ of the other time slices into memory before accessing your desired slice.
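
One way to judge whether a reduction is chunk-local is to count how many chunks lie along the reduced dimension. The sketch below does this for the two chunking strategies tried so far; `chunks_per_reduction` is a hypothetical helper and only counts chunks, it does not model how `dask` actually schedules the reduction.

```python
import math

def chunks_per_reduction(shape, chunk_shape, dims, reduce_dim):
    """Number of chunks that must be combined for each output block when
    reducing over `reduce_dim` (1 means the reduction is chunk-local)."""
    axis = dims.index(reduce_dim)
    return math.ceil(shape[axis] / chunk_shape[axis])

dims = ("time", "lat", "lon")
shape = (1800, 1800, 3600)

strategies = {
    "splitting latitude": (1800, 1, 3600),
    "splitting time": (1, 1800, 3600),
}
for name, chunk_shape in strategies.items():
    n = chunks_per_reduction(shape, chunk_shape, dims, "time")
    print(f"{name}: {n} chunk(s) feed each time-variance block")
```

When splitting latitude, each variance is computed from a single chunk; when splitting time, all 1800 chunks contribute to every output point.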

---
#### Chunking Strategy #3: Splitting Time, again...

Considering now that our workers will eventually need much of the data along the `time` dimension, we trade a few chunks in `time` for more chunks in `lat`. In this way, at least the worker computing `.var()` in the North Atlantic doesn't need to also touch the data for the Southern Ocean... 

In [2]:
client = Client(LocalCluster(n_workers=64))
sst_both_chunks = xr.open_zarr(file_zarr2, chunks={'time': 5, 'lat': 360}).isel(time=slice(0,1800)).sst.persist()
sst_both_chunks

| | Array | Chunk |
|---|---|---|
| Bytes | 43.45 GiB | 24.72 MiB |
| Shape | (1800, 1800, 3600) | (5, 360, 3600) |
| Dask graph | 1800 chunks in 1 graph layer | |
| Data type | float32 numpy.ndarray | |


In [3]:
%%time
sst_both_chunks_mean_var = sst_both_chunks.var(dim='time').mean(dim='lon').compute()

CPU times: user 14.2 s, sys: 3.46 s, total: 17.6 s
Wall time: 19.5 s


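Well at least it runs this time 😅, but it is still not the ideal layout: each time-variance now has to be combined from 360 chunks along `time`, whereas dividing our data only in latitude kept the whole reduction inside a single chunk...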

---
### $\implies$ Rule of Thumb #4b for Choosing Chunk Shape:
Consider data access patterns for _your particular_ computation...
- Work to allow vectorisation by the hardware & accelerated/threaded numerical libraries --- this inherent "parallelism" is often many times faster than task-based parallelism. 
- Minimise unnecessary worker-worker communication.

---
---

## Problem #5: Disregarding the Chunksize of Data on Disk

All data, whether read from a `zarr` store, `netCDF` file, or `intake` catalogue, has an on-disk encoding that implies a natural chunking. Let's explore the native chunking of two different datasets on disk:

In [4]:
client = Client(LocalCluster(n_workers=64))
sst_native_chunk1 = xr.open_zarr(file_zarr2, chunks={}).sst  # Specifying chunks={} will load with the native (disk) chunking
sst_native_chunk2 = xr.open_zarr(file_zarr, chunks={}).sst

In [5]:
print(f"sst_native_chunk1 chunksize on disk: (time, lat, lon) = {sst_native_chunk1.data.chunksize}")
print(f"sst_native_chunk2 chunksize on disk: (time, lat, lon) = {sst_native_chunk2.data.chunksize}")

sst_native_chunk1 chunksize on disk: (time, lat, lon) = (20, 36, 3600)
sst_native_chunk2 chunksize on disk: (time, lat, lon) = (1, 1800, 3600)


After running into Problem #4, we've learned that it's probably best to make more chunks in `lat` when computing the time-variance. \
Keeping the equivalent `dask` chunk size, now let's compare how the chunk size on disk affects our computation:

In [6]:
chunk_size = {'time': 20, 'lat': 36}

Here, we match the native chunksize of the data on disk, even though it may not be the optimal chunking for this computation (cf. Problem #4):

In [7]:
%%time
sst_chunk1 = xr.open_zarr(file_zarr2, chunks=chunk_size).isel(time=slice(0,900)).sst
sst_chunk1_mean_var = sst_chunk1.var(dim='time').mean(dim='lon').compute()

CPU times: user 11.3 s, sys: 3.84 s, total: 15.1 s
Wall time: 12.7 s


Compare this now to the equivalent computation & chunksize, but retrieving the data from a `zarr` store with a different native chunksize:

In [8]:
%%time
sst_chunk2 = xr.open_zarr(file_zarr, chunks=chunk_size).isel(time=slice(0,900)).sst
sst_chunk2_mean_var = sst_chunk2.var(dim='time').mean(dim='lon').compute()

CPU times: user 26.7 s, sys: 10.8 s, total: 37.5 s
Wall time: 30.5 s


---
### $\implies$ Rule of Thumb #5 for Choosing Chunk Size:
Work with, not against, the natural chunking of the data on disk. \
(This should also be a consideration when saving simulation or intermediate analysis data to disk.)

If the data format is beyond your control, then often (if possible) it's better to work with the inherent chunking of the data on disk, rather than trying to optimise for your particular computation.

Adjusting the chunksize _at data read-time_ is more efficient, **provided that the chunksizes are _integer multiples_ of the native chunksize**. \
i.e. prefer `xr.open_zarr(file_zarr, chunks={'time':10, 'lat': -1})` \
rather than `xr.open_zarr(file_zarr, chunks={}).chunk({'time':10, 'lat': -1})`
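
To get a feel for the cost of working against the on-disk chunking, the sketch below estimates how much data must be read to fill one `dask` chunk. It assumes that a disk chunk is always read and decoded as a whole and that the `dask` chunk starts on a disk-chunk boundary; `read_amplification` is a hypothetical helper, and the two disk chunk shapes are the ones printed in Problem #5.

```python
import math

def read_amplification(dask_chunk, disk_chunk):
    """Ratio of elements read from disk to elements actually used by one dask
    chunk, assuming whole disk chunks are read and chunk origins coincide."""
    touched = 1   # number of disk chunks overlapped by one dask chunk
    for d, n in zip(dask_chunk, disk_chunk):
        touched *= math.ceil(d / n)
    elements_read = touched * math.prod(disk_chunk)
    return elements_read / math.prod(dask_chunk)

dask_chunk = (20, 36, 3600)   # (time, lat, lon), as in chunk_size above

print(read_amplification(dask_chunk, disk_chunk=(20, 36, 3600)))   # file_zarr2
print(read_amplification(dask_chunk, disk_chunk=(1, 1800, 3600)))  # file_zarr
```

In the mismatched case each `dask` chunk has to pull in 20 full-latitude slices to use a 36-row band of each, a factor of 50 in this estimate. The measured wall times above differ by less than that, so treat the factor only as an indication of wasted reads.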


---
---

## Problem #6: Unnecessary Rechunking

Rechunking is an expensive operation that should be avoided whenever possible. \
Not only is it time-consuming, it is also memory-intensive (and can lead to memory issues). \
Here is a quick example demonstrating how unnecessary rechunking can slow down our computation.

In [2]:
client = Client(LocalCluster(n_workers=128))
sst = xr.open_zarr(file_zarr2, chunks={}).isel(time=slice(0,4000)).sst
sst

| | Array | Chunk |
|---|---|---|
| Bytes | 96.56 GiB | 9.89 MiB |
| Shape | (4000, 1800, 3600) | (20, 36, 3600) |
| Dask graph | 10000 chunks in 3 graph layers | |
| Data type | float32 numpy.ndarray | |


In [3]:
%%time
sst_mean = sst.mean().compute()

CPU times: user 35.8 s, sys: 7.63 s, total: 43.4 s
Wall time: 44.5 s


These original chunks look a bit small... so maybe rechunking will help reduce the task-graph overhead ? 🤷‍♂️

In [4]:
%%time
sst_rechunked = sst.chunk({'time': 50, 'lat': 200}).persist()
sst_rechunked_mean = sst_rechunked.mean().compute()

CPU times: user 32.3 s, sys: 6.45 s, total: 38.7 s
Wall time: 39.3 s


The thought was good, but the execution was poor... It barely helped (39.3 s against 44.5 s), and it could've been much worse memory-wise. \
Let's try again, except we will remember to decide on our chunking strategy _before_ loading the data...

In [5]:
%%time
sst_chunk_initial = xr.open_zarr(file_zarr2, chunks={'time':40, 'lat': 216}).isel(time=slice(0,4000)).sst
sst_chunk_initial_mean = sst_chunk_initial.mean().compute()

CPU times: user 8.43 s, sys: 2.76 s, total: 11.2 s
Wall time: 11.8 s
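
The three chunkings tried in this problem can be compared on paper before running anything. The sketch below tabulates the chunk count, chunk size, and whether each chunk length is an integer multiple of the native on-disk chunk `(20, 36, 3600)`; `describe` is a hypothetical helper, and `lon` is assumed to stay unsplit in all three cases.

```python
import math

MIB = 2**20
shape = (4000, 1800, 3600)   # (time, lat, lon)
native = (20, 36, 3600)      # on-disk chunks of file_zarr2, as shown above

def describe(chunk_shape, itemsize=4):
    """Chunk count, chunk size in MiB, and whether every chunk length is an
    integer multiple of the native on-disk chunk length."""
    count = math.prod(math.ceil(s / c) for s, c in zip(shape, chunk_shape))
    size_mib = math.prod(chunk_shape) * itemsize / MIB
    aligned = all(c % n == 0 for c, n in zip(chunk_shape, native))
    return count, round(size_mib, 2), aligned

candidates = {
    "native": (20, 36, 3600),
    "rechunked after opening": (50, 200, 3600),
    "chosen at read-time": (40, 216, 3600),
}
for name, chunk_shape in candidates.items():
    print(name, describe(chunk_shape))
```

Only the read-time choice combines a modest chunk count (900) with chunks that line up with the data on disk; the `(50, 200)` rechunking cuts across the native chunks in both `time` and `lat`.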


---
---

## Problem #7: Constraining the Task Graph by Persisting Intermediate Results

Persisting intermediate results with `.persist()` can be useful for avoiding repeated computations. \
Similarly to `.load()` and `.compute()`, `.persist()` prompts the 'lazy' computation to get going (thereby collapsing the task graph layers). \
However, `.persist()` does not destroy your `dask`-backed DataArray !  Instead, it keeps the _chunked results_ in each worker's memory. 

Misusing `.persist()` can result in:
- Exhausting worker memory (if many or large intermediate results are persisted)
- Slowing down the computation by constraining the task graph, and any realisable parallelism

Let's consider a maybe convoluted example:
1. Calculate the temporal high-pass of $u$ and $\tau_x$
2. Correlate the high-pass of $\tilde{\tau_x} \tilde{u}$ and $\tilde{u}^2$ --- a sort of high-pass wind-work and zonal EKE...
3. See how `.persist()`ing each intermediate result affects the computation time

**NOTE**:  This example takes ~15 minutes to compute... 

In [2]:
client = Client(LocalCluster(n_workers=128))

In [3]:
ds = xr.open_zarr(file_zarr2, chunks={}).isel(time=slice(0,2000))

In [4]:
u_roll = ds.u.rolling(time=30).mean()
tau_x_roll = ds.tau_x.rolling(time=30).mean()

In [5]:
u_high_pass = ds.u - u_roll
tau_x_high_pass = ds.tau_x - tau_x_roll

In [6]:
tau_x_u_high_pass = tau_x_high_pass * u_high_pass
u_high_pass_squared = u_high_pass**2

In [7]:
%%time
correlation = xr.corr(u_high_pass_squared, tau_x_u_high_pass, dim='time').compute()

CPU times: user 6min 3s, sys: 1min, total: 7min 3s
Wall time: 7min 8s


In [8]:
%%time
u_roll_persist = ds.u.rolling(time=30).mean().persist()
tau_x_roll_persist = ds.tau_x.rolling(time=30).mean().persist()

u_high_pass_persist = (ds.u - u_roll_persist).persist()
tau_x_high_pass_persist = (ds.tau_x - tau_x_roll_persist).persist()
tau_x_u_high_pass_persist = (tau_x_high_pass_persist * u_high_pass_persist).persist()
u_high_pass_squared_persist = (u_high_pass_persist**2).persist()
correlation = xr.corr(u_high_pass_squared_persist, tau_x_u_high_pass_persist, dim='time').compute()

CPU times: user 7min 27s, sys: 1min 7s, total: 8min 35s
Wall time: 8min 40s
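
In this run, persisting every intermediate result was slower (8 min 40 s against 7 min 8 s), and each persisted array also stays in worker memory. The sketch below gives a rough estimate of that memory cost. It assumes that `u` and `tau_x` share the `(time, lat, lon)` grid and `float32` dtype of `sst`, which the notebook does not state, and `persisted_gib` is a hypothetical helper.

```python
import math

GIB = 2**30

def persisted_gib(shape, n_arrays, itemsize=4):
    """Total memory (GiB) held on the cluster by `n_arrays` persisted arrays."""
    return n_arrays * math.prod(shape) * itemsize / GIB

shape = (2000, 1800, 3600)   # assumption: same grid and float32 dtype as `sst`
one = persisted_gib(shape, n_arrays=1)
six = persisted_gib(shape, n_arrays=6)   # six `.persist()` calls in the cell above

print(f"one intermediate: {one:.1f} GiB, six intermediates: {six:.1f} GiB")
print(f"share of a 512 GiB node: {six / 512:.0%}")
```

Under these assumptions the six persisted intermediates would hold more than half of the node's memory, before counting any working space for the correlation itself.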


---
---

## Conclusion

If you've made it this far and navigated these numerous potential pitfalls, then you're well on your way to mastering `dask` ! 🌟

Remember that effectively leveraging `dask` is as much an art as it is a science... \
It requires understanding your data, algorithm, and computational resources, and often takes a bit of artistic experimentation.

Now, for some more advanced topics, continue on to Notebook 2:   `2_advanced_dask.ipynb`