# Strategies for climatology calculations

This notebook is motivated by
[this post](https://discourse.pangeo.io/t/understanding-optimal-zarr-chunking-scheme-for-a-climatology/2335)
on the Pangeo discourse forum.


In [None]:
import dask.array
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

import flox
import flox.xarray

Let's first create an example Xarray DataArray representing the OISST dataset,
with chunk sizes matching those in the post.


In [None]:
oisst = xr.DataArray(
    dask.array.ones((14532, 720, 1440), chunks=(20, -1, -1)),
    dims=("time", "lat", "lon"),
    coords={"time": pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")},
    name="sst",
)
oisst
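For a sense of scale, the chunk count and per-chunk memory of this array can be worked out by hand. This is a back-of-the-envelope sketch that assumes float64 data (8 bytes per element), the default dtype of `dask.array.ones`.

```python
import math

# Shape and time-chunking of the synthetic OISST array defined above.
n_time, n_lat, n_lon = 14532, 720, 1440
time_chunk = 20
itemsize = 8  # bytes per float64 element (assumed: dask.array.ones default dtype)

n_chunks = math.ceil(n_time / time_chunk)
last_chunk = n_time - (n_chunks - 1) * time_chunk
chunk_bytes = time_chunk * n_lat * n_lon * itemsize

print(n_chunks, last_chunk)  # 727 12
print(f"{chunk_bytes / 1e6:.1f} MB per full chunk")  # 165.9 MB per full chunk
```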

## map-reduce

The default
[method="map-reduce"](https://flox.readthedocs.io/en/latest/implementation.html#method-map-reduce)
doesn't work well here: the output gathers all days of the year into a single,
very large chunk.

For this to work well, we'd want smaller chunks in space and bigger chunks in
time.
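To make the map-reduce idea concrete, here is a minimal pure-Python sketch of a grouped mean computed that way: each chunk is first reduced to per-group partial sums and counts, the partials are then combined a few at a time in a tree, and the mean is formed at the end. It illustrates the general technique, not flox's implementation; `blockwise_sum_count`, `combine`, `map_reduce_mean` and the `split_every` fan-in parameter are hypothetical names used only here. Note that the final combined result holds every group at once, which is why the output ends up in a single chunk.

```python
from collections import defaultdict


def blockwise_sum_count(values, labels):
    """'Map' step: reduce one chunk to per-group [sum, count] partials."""
    partial = defaultdict(lambda: [0.0, 0])
    for v, g in zip(values, labels):
        partial[g][0] += v
        partial[g][1] += 1
    return dict(partial)


def combine(a, b):
    """'Reduce' step: merge two sets of partials by adding sums and counts."""
    out = {g: list(sc) for g, sc in a.items()}
    for g, (s, c) in b.items():
        if g in out:
            out[g][0] += s
            out[g][1] += c
        else:
            out[g] = [s, c]
    return out


def map_reduce_mean(values, labels, chunksize, split_every=2):
    # Map: one partial result per chunk.
    partials = [
        blockwise_sum_count(values[i : i + chunksize], labels[i : i + chunksize])
        for i in range(0, len(values), chunksize)
    ]
    # Tree reduction: combine `split_every` partials at a time until one remains.
    while len(partials) > 1:
        merged = []
        for j in range(0, len(partials), split_every):
            acc = partials[j]
            for p in partials[j + 1 : j + split_every]:
                acc = combine(acc, p)
            merged.append(acc)
        partials = merged
    # Finalize: mean = sum / count for every group, all in one result.
    return {g: s / c for g, (s, c) in sorted(partials[0].items())}
```

For example, `map_reduce_mean([float(i) for i in range(12)], [i % 3 for i in range(12)], chunksize=5)` returns `{0: 4.5, 1: 5.5, 2: 6.5}`.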


In [None]:
flox.xarray.xarray_reduce(
    oisst,
    oisst.time.dt.dayofyear,
    func="mean",
    method="map-reduce",
)

## Rechunking for map-reduce

We can split each chunk along the spatial dimensions (here only `lon`, into
chunks of 120, while `lat` stays whole) so that the output chunk sizes are more
reasonable.


In [None]:
flox.xarray.xarray_reduce(
    oisst.chunk({"lat": -1, "lon": 120}),
    oisst.time.dt.dayofyear,
    func="mean",
    method="map-reduce",
)
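The effect of this rechunking on the output can be estimated by hand. This sketch assumes float64 output and 366 day-of-year groups (day 366 only occurs in leap years); it counts chunks, not tasks.

```python
# Back-of-the-envelope output chunk sizes for the day-of-year mean.
# Assumes float64 output (8 bytes per element) and 366 day-of-year groups.
n_groups, n_lat, n_lon, itemsize = 366, 720, 1440, 8
n_input_chunks_time = 727  # ceil(14532 / 20)

# Plain map-reduce: all groups and the full spatial grid land in one output chunk.
single_output_chunk = n_groups * n_lat * n_lon * itemsize

# After .chunk({"lat": -1, "lon": 120}): one output chunk per lon chunk.
lon_chunk = 120
n_lon_chunks = n_lon // lon_chunk
output_chunk_after = n_groups * n_lat * lon_chunk * itemsize
n_input_chunks_after = n_input_chunks_time * n_lon_chunks

print(f"before: 1 output chunk of {single_output_chunk / 1e9:.2f} GB")  # 3.04 GB
print(f"after: {n_lon_chunks} output chunks of {output_chunk_after / 1e6:.0f} MB")  # 253 MB
print(f"input chunks: {n_input_chunks_time} -> {n_input_chunks_after}")  # 727 -> 8724
```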

But what if we didn't want to rechunk the dataset so drastically (note the 10x
increase in tasks)? For that, let's try `method="cohorts"`.

## method=cohorts

We can take advantage of patterns in the groups (here, "day of year").
Specifically:

1. The groups repeat at an approximately periodic interval, 365 or 366 days.
2. The chunk size 20 is smaller than the period of 365 or 366. This means that
   to construct the mean for days 1-20, we just need to use the chunks that
   contain days 1-20.

This strategy is implemented as
[method="cohorts"](https://flox.readthedocs.io/en/latest/implementation.html#method-cohorts).
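The two conditions above can be illustrated with a small pure-Python sketch that groups together labels occurring in exactly the same set of chunks. `find_cohorts` is a hypothetical helper, simpler than `flox.core.find_group_cohorts` (it does no merging of overlapping cohorts), and the toy integer labels stand in for day-of-year.

```python
from collections import defaultdict


def find_cohorts(labels, chunksize):
    """Group labels that occur in exactly the same set of chunks.

    Simplified sketch of the cohorts idea: no merging of overlapping cohorts.
    """
    chunks_per_label = defaultdict(set)
    for i, label in enumerate(labels):
        chunks_per_label[label].add(i // chunksize)
    cohorts = defaultdict(list)
    for label, chunk_set in chunks_per_label.items():
        cohorts[frozenset(chunk_set)].append(label)
    return sorted(sorted(members) for members in cohorts.values())


# Period 6, chunk size 3: chunk boundaries line up with the period.
in_phase = find_cohorts(list(range(6)) * 4, chunksize=3)
# Period 5, chunk size 3: chunk boundaries drift relative to the period.
out_of_phase = find_cohorts(list(range(5)) * 3, chunksize=3)
```

With a period of 6 and chunks of 3, `in_phase` is `[[0, 1, 2], [3, 4, 5]]`, two useful cohorts. With a period of 5 and chunks of 3, `out_of_phase` is `[[0], [1], [2], [3], [4]]`: one cohort per label.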


In [None]:
flox.xarray.xarray_reduce(
    oisst,
    oisst.time.dt.dayofyear,
    func="mean",
    method="cohorts",
)

By default, `method="cohorts"` doesn't work well for this problem because the
period isn't regular (365 vs 366 days) and isn't divisible by the chunk size. So
the groups end up being "out of phase" (for a visual illustration
[click here](https://flox.readthedocs.io/en/latest/implementation.html#method-cohorts)).
Now we have the opposite problem: the chunk sizes on the output are too small.
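The "out of phase" behavior follows from modular arithmetic: because 365 % 20 == 5 and 366 % 20 == 6, the position of each 1 September within its 20-day chunk drifts every year instead of repeating. A standard-library sketch, assuming the same daily time axis starting 1981-09-01 as the synthetic dataset:

```python
from datetime import date

# Offset of each 1 September within its 20-day time chunk, for the first six
# years of the synthetic OISST time axis (which starts on 1981-09-01).
start = date(1981, 9, 1)
time_chunk = 20
offsets = [(date(year, 9, 1) - start).days % time_chunk for year in range(1981, 1987)]
print(offsets)  # [0, 5, 10, 16, 1, 6]
```

The jump from 10 to 16 comes from the leap day in February 1984.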

Looking more closely, we can see that the cohorts `flox` has detected are not
really cohorts: each cohort is a single group label. We've replicated Xarray's
current strategy, which flox calls
["split-reduce"](https://flox.readthedocs.io/en/latest/implementation.html#method-split-reduce-xarray-s-current-groupby-strategy).


In [None]:
flox.core.find_group_cohorts(
    labels=oisst.time.dt.dayofyear.data,
    chunks=(oisst.chunksizes["time"],),
).values()

## Rechunking data for cohorts

Can we fix the "out of phase" problem by rechunking along time?

First, let's see where the current chunk boundaries are:


In [None]:
array = oisst.data
labels = oisst.time.dt.dayofyear.data
axis = oisst.get_axis_num("time")
oldchunks = array.chunks[axis]
oldbreaks = np.insert(np.cumsum(oldchunks), 0, 0)
labels_at_breaks = labels[oldbreaks[:-1]]
labels_at_breaks

Now we'll use the convenience function `rechunk_for_cohorts` to rechunk the
`oisst` dataset along time. We'll ask it to rechunk so that a new chunk starts at
each of these labels:

```
[244, 264, 284, 304, 324, 344, 364,  19,  39,  59,  79,  99, 119,
 139, 159, 179, 199, 219, 239]
```

These are the labels at the chunk boundaries in the first year of data. We are
forcing that chunking pattern to repeat as much as possible. We also tell the
function to ignore any existing chunk boundaries.
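Assuming the same daily time axis as the synthetic dataset (starting 1981-09-01, chunks of 20 days), the list above can be reproduced with the standard library; `tm_yday` is the 1-based day of year, the same convention as `time.dt.dayofyear`.

```python
from datetime import date, timedelta

# Day-of-year at the start of each 20-day chunk in the first year of data.
start = date(1981, 9, 1)
breaks = [(start + timedelta(days=20 * k)).timetuple().tm_yday for k in range(19)]
print(breaks)
# [244, 264, 284, 304, 324, 344, 364, 19, 39, 59, 79, 99, 119, 139, 159, 179, 199, 219, 239]

# The 20th chunk would start on day 259, not 244, so the pattern drifts unless forced.
next_start = (start + timedelta(days=20 * 19)).timetuple().tm_yday
```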


In [None]:
rechunked = flox.xarray.rechunk_for_cohorts(
    oisst,
    dim="time",
    labels=oisst.time.dt.dayofyear,
    force_new_chunk_at=[
        244,
        264,
        284,
        304,
        324,
        344,
        364,
        19,
        39,
        59,
        79,
        99,
        119,
        139,
        159,
        179,
        199,
        219,
        239,
    ],
    ignore_old_chunks=True,
)
rechunked

We see that the chunks are mostly 20 elements long in time, with some variation.


In [None]:
plt.plot(rechunked.chunksizes["time"], marker="x", ls="none")
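A simplified sketch of where that variation comes from: start a new chunk at every listed day-of-year and measure the resulting chunk lengths. This is not flox's implementation; `chunks_from_forced_breaks` is a hypothetical helper that ignores old chunks and enforces no target chunk size, whereas `rechunk_for_cohorts` may also split or merge chunks, so its exact sizes can differ.

```python
from collections import Counter
from datetime import date, timedelta


def chunks_from_forced_breaks(labels, force_new_chunk_at):
    """Return chunk sizes that start a new chunk at every label in `force_new_chunk_at`.

    Simplified sketch: old chunks are ignored and no target chunk size is enforced.
    """
    force = set(force_new_chunk_at)
    sizes = []
    for i, label in enumerate(labels):
        # Open a new chunk at the very first element or at a forced label.
        if i == 0 or label in force:
            sizes.append(0)
        sizes[-1] += 1
    return sizes


# Toy example: labels repeating 1..5, forcing breaks at labels 1 and 4.
toy = chunks_from_forced_breaks([1, 2, 3, 4, 5] * 2, [1, 4])  # [3, 2, 3, 2]

# Day-of-year labels for the synthetic OISST time axis (daily from 1981-09-01).
start = date(1981, 9, 1)
doy = [(start + timedelta(days=i)).timetuple().tm_yday for i in range(14532)]
force = [244, 264, 284, 304, 324, 344, 364, 19, 39, 59,
         79, 99, 119, 139, 159, 179, 199, 219, 239]
sizes = chunks_from_forced_breaks(doy, force)
print(Counter(sizes).most_common(3))
```

In this simplified version most chunks are 20 days long; the jump from day 239 to day 244 leaves a 5-day chunk each year, and the span from day 364 to day 19 is 21 days when it crosses the end of a leap year.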

And now our cohorts contain more than one group.


In [None]:
flox.core.find_group_cohorts(
    labels=rechunked.time.dt.dayofyear.data,
    chunks=(rechunked.chunksizes["time"],),
).values()

Now the groupby reduction **looks OK** in terms of number of tasks, but remember
that rechunking to get to this point involves some communication overhead.


In [None]:
flox.xarray.xarray_reduce(rechunked, rechunked.time.dt.dayofyear, func="mean", method="cohorts")

## How about other climatologies?

Let's try a monthly climatology.


In [None]:
flox.xarray.xarray_reduce(oisst, oisst.time.dt.month, func="mean", method="map-reduce")

This looks great. Why?

It's because each chunk (20 days) is shorter than any month (28 to 31 days).
`flox` initially applies the groupby-reduction blockwise. With a chunk size of
20, each chunk contains at most 2 groups, so the initial blockwise reduction is
quite effective: at least a 10x reduction in size, from 20 elements in time to
at most 2 elements in time.
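A quick standard-library check of that claim, assuming the same daily time axis as the synthetic dataset: since every month has at least 28 days, a 20-day chunk can touch at most two months.

```python
from datetime import date, timedelta

# Month label of every daily time step on the synthetic OISST axis.
start = date(1981, 9, 1)
n_time, time_chunk = 14532, 20
months = [(start + timedelta(days=i)).month for i in range(n_time)]

# Number of distinct months (groups) inside each 20-day chunk: this is how many
# elements along time survive the initial blockwise reduction of that chunk.
groups_per_chunk = [
    len(set(months[i : i + time_chunk])) for i in range(0, n_time, time_chunk)
]
reduction = n_time / sum(groups_per_chunk)
print(max(groups_per_chunk), round(reduction, 1))
```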

For this kind of problem, `"map-reduce"` works quite well.
