# Strategies for climatology calculations

This notebook is motivated by
[this post](https://discourse.pangeo.io/t/understanding-optimal-zarr-chunking-scheme-for-a-climatology/2335)
on the Pangeo discourse forum.


In [None]:
import dask.array
import pandas as pd
import xarray as xr

import flox
import flox.xarray

Let's first create an example Xarray DataArray representing the OISST `sst` variable,
with chunk sizes matching those in the post.


In [None]:
oisst = xr.DataArray(
    dask.array.ones((14532, 720, 1440), chunks=(20, -1, -1)),
    dims=("time", "lat", "lon"),
    coords={"time": pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")},
    name="sst",
)
oisst

To account for Feb-29 being present in some years, we'll construct a time vector of "mmm-dd" strings to group by.

For more options, see https://strftime.org/

In [None]:
day = oisst.time.dt.strftime("%h-%d").rename("day")
day
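Here is a quick pandas-only check of why "mmm-dd" labels are preferable to a plain day-of-year number. It uses the same date range as `oisst`. The sketch uses `%b` (abbreviated month name), which the C `strftime` specification defines as equivalent to the `%h` used above. With "mmm-dd" labels, Feb-29 gets its own group and every other date keeps the same label in every year. With `dayofyear`, Mar-01 is day 60 in common years but day 61 in leap years.

```python
import pandas as pd

# Same daily time axis as the oisst example
time = pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")

# "mmm-dd" labels: 365 regular days plus Feb-29
labels = time.strftime("%b-%d")
print(labels.nunique())  # 366

# dayofyear shifts every date after Feb-28 by one in leap years
mar1 = time[(time.month == 3) & (time.day == 1)]
print(sorted(set(mar1.dayofyear.tolist())))  # [60, 61]
```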

## map-reduce

The default
[method="map-reduce"](https://flox.readthedocs.io/en/latest/implementation.html#method-map-reduce)
doesn't work well here. Because each input chunk spans the full `lat`/`lon` grid,
all 366 day-of-year groups end up aggregated into a single ~3 GB output chunk.

For this to work well, we'd want smaller chunks in space and bigger chunks in
time.
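The ~3 GB figure follows from the output shape. Every input chunk spans the full 720 × 1440 grid, so the map-reduce result holds all 366 day-of-year groups for the whole grid in one chunk. Here is a back-of-the-envelope check, assuming float64 output (the default dtype of `dask.array.ones`):

```python
# Size of the map-reduce output when all day-of-year groups land in one chunk
n_groups = 366          # unique "mmm-dd" labels, including Feb-29
nlat, nlon = 720, 1440  # full spatial extent: input chunks are not split in space
itemsize = 8            # bytes per float64 value

output_bytes = n_groups * nlat * nlon * itemsize
print(f"{output_bytes / 1e9:.2f} GB")  # 3.04 GB
```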


In [None]:
flox.xarray.xarray_reduce(
    oisst,
    day,
    func="mean",
    method="map-reduce",
)

## Rechunking for map-reduce

We can split each chunk along the spatial dimensions (here only `lon`, into
pieces of 120) to make sure the output chunk sizes are more reasonable:


In [None]:
flox.xarray.xarray_reduce(
    oisst.chunk({"lat": -1, "lon": 120}),
    day,
    func="mean",
    method="map-reduce",
)
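For comparison, here is the same arithmetic after rechunking with `{"lat": -1, "lon": 120}`. Splitting `lon` into 12 pieces means each output chunk now covers 1/12 of the grid. The sketch again assumes float64 values:

```python
import math

nlat, nlon, itemsize = 720, 1440, 8
n_groups = 366
lon_chunk = 120  # from .chunk({"lat": -1, "lon": 120})

n_lon_chunks = math.ceil(nlon / lon_chunk)
input_chunk_bytes = 20 * nlat * lon_chunk * itemsize      # 20 days per input chunk
output_chunk_bytes = n_groups * nlat * lon_chunk * itemsize

print(n_lon_chunks)                          # 12
print(f"{input_chunk_bytes / 1e6:.1f} MB")   # 13.8 MB
print(f"{output_chunk_bytes / 1e6:.0f} MB")  # 253 MB
```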

But what if we didn't want to rechunk the dataset so drastically (note the ~10x
increase in tasks)? For that, let's try `method="cohorts"`.

## method=cohorts

We can take advantage of patterns in the groups, here "day of year".
Specifically:

1. The groups repeat at an approximately periodic interval of 365 or 366 days.
2. The chunk size of 20 is smaller than the period of 365 or 366. This means that
   to construct the mean for days 1-20, we only need the chunks that
   contain days 1-20.

This strategy is implemented as
[method="cohorts"](https://flox.readthedocs.io/en/latest/implementation.html#method-cohorts).
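Here is a minimal sketch of the idea in point 2. It assumes a simplified calendar of exactly 365 days per year (no leap years), three years of data, and chunks of 20 along time. The code records which chunks contain each day-of-year label and then collects the chunks needed for days 1-20. It illustrates the principle only and is not flox's implementation.

```python
from collections import defaultdict

# Hypothetical simplified setup: 3 years of 365 days, chunk size 20
period, nyears, chunksize = 365, 3, 20
labels = [i % period for i in range(period * nyears)]  # day-of-year label per time step

# For each group (day of year), which chunks contain it?
chunks_per_group = defaultdict(set)
for i, label in enumerate(labels):
    chunks_per_group[label].add(i // chunksize)

nchunks = -(-len(labels) // chunksize)  # ceiling division
# Chunks needed to reduce days 1-20 (labels 0-19)
needed = sorted(set().union(*(chunks_per_group[g] for g in range(20))))
print(needed, "of", nchunks, "chunks")  # [0, 18, 19, 36, 37] of 55 chunks
```

Only 5 of the 55 chunks are needed for these 20 groups.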


In [None]:
flox.xarray.xarray_reduce(
    oisst,
    day,
    func="mean",
    method="cohorts",
)

With the original chunking, `method="cohorts"` doesn't work well for this problem because the period
isn't regular (365 vs 366 days) and the period isn't divisible by the chunk size. So
the groups end up "out of phase" (for a visual illustration
[click here](https://flox.readthedocs.io/en/latest/implementation.html#method-cohorts)).
Now we have the opposite problem: the output chunk sizes are too small.
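To see the "out of phase" effect directly, here is a pandas/NumPy sketch. It uses the same date range as `oisst` and the original chunk size of 20, and computes where each Jan-01 falls inside its chunk. If that offset were constant, each day of year would always sit at the same position within a chunk. Instead, it advances by 5 (365 mod 20) each year, or by 6 after a leap year.

```python
import numpy as np
import pandas as pd

time = pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")
chunksize = 20

# Integer position of every Jan-01 along the time axis
jan1_idx = np.flatnonzero((time.month == 1) & (time.day == 1))
# Offset of Jan-01 inside its 20-day chunk
offsets = (jan1_idx % chunksize).tolist()
print(offsets[:6])  # [2, 7, 12, 18, 3, 8]
```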

Let's inspect the cohorts:

In [None]:
# integer codes for each "day"
codes, _ = pd.factorize(day.data)
preferred_method, cohorts = flox.core.find_group_cohorts(
    labels=codes,
    chunks=(oisst.chunksizes["time"],),
)
print(len(cohorts))

Looking more closely, we can see many cohorts with a single entry. 

In [None]:
cohorts.values()

## Rechunking data for cohorts

Can we fix the "out of phase" problem by rechunking along time?

First, let's look at the current chunk sizes along `time`, which determine where the chunk boundaries are:

In [None]:
oisst.chunksizes["time"][:10]
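Chunk sizes translate into boundaries via a cumulative sum. The sketch below rebuilds the uniform 20-day chunking from the shapes in the `oisst` example and prints the first few chunk start dates. It assumes all chunks are 20 long except a shorter final one, which is how `dask.array` splits an axis for `chunks=(20, -1, -1)`.

```python
import numpy as np
import pandas as pd

time = pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")
n = len(time)
# 20-day chunks plus a shorter remainder chunk at the end
chunksizes = (20,) * (n // 20) + ((n % 20,) if n % 20 else ())

# Start index of each chunk = running sum of the preceding chunk sizes
starts = np.cumsum((0,) + chunksizes[:-1])
first_starts = time[starts[:4]].strftime("%Y-%m-%d").tolist()
print(first_starts)  # ['1981-09-01', '1981-09-21', '1981-10-11', '1981-10-31']
```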

We'll rechunk so that each calendar month is a single chunk. This is not too different from the current chunk size of 20 (months have 28-31 days), but it helps with the periodicity problem: chunk boundaries now fall on the same calendar days every year.

In [None]:
# number of days in each calendar month ("MS" = month-start bins)
newchunks = xr.ones_like(day).astype(int).resample(time="MS").count()

In [None]:
rechunked = oisst.chunk(time=tuple(newchunks.data))
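Here is a pandas-only sketch of what `newchunks` contains, assuming the same daily date range. Each value is the number of days in one calendar month and becomes a chunk size along `time`. The record starts on Sep-01 and ends mid-June, so the final chunk (June 2021) is only 14 days long.

```python
import pandas as pd

time = pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")

# Count the days falling in each calendar month ("MS" = month-start bins)
counts = pd.Series(1, index=time).resample("MS").count()
newchunks = tuple(int(n) for n in counts)

print(newchunks[:6])                   # (30, 31, 30, 31, 31, 28)
print(newchunks[-1])                   # 14
print(len(newchunks), sum(newchunks))  # 478 14532
```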

And now our cohorts contain more than one group:


In [None]:
preferred_method, new_cohorts = flox.core.find_group_cohorts(
    labels=codes,
    chunks=(rechunked.chunksizes["time"],),
)
# one cohort per month!
len(new_cohorts)

In [None]:
new_cohorts.values()
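The following naive re-implementation of the cohort idea is a sketch only, not flox's algorithm. It treats each calendar month as one chunk, records which chunks contain each "mmm-dd" label, and groups labels that appear in exactly the same set of chunks.

```python
from collections import defaultdict

import pandas as pd

time = pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")
labels = time.strftime("%b-%d")
chunks = time.to_period("M")  # with monthly chunks, a chunk is identified by its month

# Which chunks contain each "mmm-dd" label?
chunks_per_label = defaultdict(set)
for label, chunk in zip(labels, chunks):
    chunks_per_label[label].add(chunk)

# Labels that live in exactly the same set of chunks form one cohort
cohorts = defaultdict(list)
for label, chunk_set in chunks_per_label.items():
    cohorts[frozenset(chunk_set)].append(label)

print(len(cohorts))  # 14
```

This naive version finds 14 cohorts rather than 12. Feb-29 only occurs in leap-year Februaries, and Jun-15 through Jun-30 are missing from June 2021, so their chunk sets are strict subsets of the February and June ones. Merging these subset cohorts into their parents gives one cohort per month.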

Now the groupby reduction **looks OK** in terms of number of tasks, but remember
that rechunking to get to this point involves some communication overhead.


In [None]:
flox.xarray.xarray_reduce(rechunked, day, func="mean", method="cohorts")

## How about other climatologies?

Let's try a monthly climatology:


In [None]:
flox.xarray.xarray_reduce(oisst, oisst.time.dt.month, func="mean")

This looks great. Why?

It's because each chunk (size 20) is smaller than the number of days in any
month (28-31). `flox` initially applies the groupby-reduction blockwise. With a chunk
size of 20, each chunk contains at most 2 groups, so the initial
blockwise reduction is quite effective: at least a 10x reduction in size, from
20 elements in time to at most 2.
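The "at most 2 groups per chunk" claim can be checked directly. The sketch below assumes the same daily date range and uniform 20-day chunks, counts the distinct months in each chunk, and estimates the size of the 12-month result assuming float64 output.

```python
import numpy as np
import pandas as pd

time = pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")
chunksize = 20
months = time.month.to_numpy()

# Distinct months inside each 20-day chunk (the last chunk is shorter)
groups_per_chunk = [
    len(np.unique(months[start : start + chunksize]))
    for start in range(0, len(months), chunksize)
]
print(max(groups_per_chunk))  # 2

# Final output: 12 monthly means over the full 720 x 1440 grid, float64
output_bytes = 12 * 720 * 1440 * 8
print(f"{output_bytes / 1e6:.1f} MB")  # 99.5 MB
```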

For this kind of problem, `"map-reduce"` works quite well.
