# Large Raster Zonal Statistics

"Zonal statistics" spans a wide range of problems.

This one is inspired by [this issue](https://github.com/xarray-contrib/flox/issues/428), where a raster of cell areas is summed over 6 different groupers. Each array involved has a global extent on a 30m grid with shape 560_000 x 1_440_000 and chunk size 10_000 x 10_000. Three of the groupers, `tcl_year`, `drivers`, and `tcd_thresholds`, have a small number of group labels (23, 5, and 7 respectively).

The last 3 groupers are [GADM](https://gadm.org/) level 0, 1, and 2 administrative area polygons rasterized to this grid, with 248, 86, and 854 unique labels respectively (arrays `adm0`, `adm1`, and `adm2`). These correspond to country-level, state-level, and county-level administrative boundaries.
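To get a feel for the scale, the arithmetic below works out the number of chunks per array and the memory footprint, assuming 4-byte `float32` values (the dtype used in the example dataset below).

```python
# Grid and chunk sizes quoted above
ny, nx = 560_000, 1_440_000
cy, cx = 10_000, 10_000

# Number of chunks along each dimension (the sizes divide evenly here)
nchunks_y, nchunks_x = ny // cy, nx // cx
nchunks = nchunks_y * nchunks_x

# Memory footprint, assuming 4-byte float32 values
itemsize = 4
bytes_per_chunk = cy * cx * itemsize
bytes_per_array = ny * nx * itemsize

print(f"{nchunks_y} x {nchunks_x} = {nchunks} chunks per array")
# -> 56 x 144 = 8064 chunks per array
print(f"{bytes_per_chunk / 1e6:.0f} MB per chunk, {bytes_per_array / 1e12:.2f} TB per array")
# -> 400 MB per chunk, 3.23 TB per array
```

With seven such arrays, the inputs alone run to tens of terabytes, so everything has to stay lazy and chunked.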

## Example dataset

Here is a representative version of the dataset (in terms of size and chunk sizes).

```python
import dask.array
import numpy as np
import xarray as xr

from flox.xarray import xarray_reduce

sizes = {"y": 560_000, "x": 1_440_000}
chunksizes = {"y": 10_000, "x": 10_000}
dims = ("y", "x")
shape = tuple(sizes[d] for d in dims)
chunks = tuple(chunksizes[d] for d in dims)

# Every variable is a lazy, constant-valued dask array: only the shapes and
# chunk sizes matter for this example, not the actual values.
ds = xr.Dataset(
    {
        "areas": (dims, dask.array.ones(shape, chunks=chunks, dtype=np.float32)),
        "tcl_year": (
            dims,
            1 + dask.array.zeros(shape, chunks=chunks, dtype=np.float32),
        ),
        "drivers": (dims, 2 + dask.array.zeros(shape, chunks=chunks, dtype=np.float32)),
        "tcd_thresholds": (
            dims,
            3 + dask.array.zeros(shape, chunks=chunks, dtype=np.float32),
        ),
        "adm0": (dims, 4 + dask.array.ones(shape, chunks=chunks, dtype=np.float32)),
        "adm1": (dims, 5 + dask.array.zeros(shape, chunks=chunks, dtype=np.float32)),
        "adm2": (dims, 6 + dask.array.zeros(shape, chunks=chunks, dtype=np.float32)),
    }
)
ds
```

## Zonal Statistics

Next, define the grouper arrays and the expected group labels for each of them.

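```python
by = (ds.tcl_year, ds.drivers, ds.tcd_thresholds, ds.adm0, ds.adm1, ds.adm2)
# One array of expected labels per grouper, in the same order as `by`
expected_groups = (
    np.arange(23),  # tcl_year
    np.arange(1, 6),  # drivers
    np.arange(1, 8),  # tcd_thresholds
    np.arange(248),  # adm0
    np.arange(86),  # adm1
    np.arange(854),  # adm2
)
```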

```python
result = xarray_reduce(
    ds.areas,
    *by,
    expected_groups=expected_groups,
    func="sum",
)
result
```

Formulating the three admin levels as orthogonal dimensions is quite wasteful: not all countries have 86 states or 854 counties per state. The total number of GADM geometries for levels 0, 1, and 2 is ~48,000, which is much smaller than the 248 x 86 x 854 = 18_214_112 possible (`adm0`, `adm1`, `adm2`) label combinations. The full output grid, 23 x 5 x 7 x 248 x 86 x 854 = 14_662_360_160 elements, is larger still.

We end up with one humongous 56GB chunk that is mostly empty: only ~48,000 / (248 x 86 x 854) ≈ 0.26% of the admin label combinations correspond to real geometries.
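The sizes involved can be checked with a little arithmetic. This sketch assumes a 4-byte `float32` result and uses the approximate count of ~48,000 GADM geometries; the occupied fraction is computed relative to the 248 x 86 x 854 grid of admin label combinations, since every combination of the other three groupers can in principle occur within a geometry.

```python
import math

# Number of expected labels per grouper, in the same order as `by`
nlabels = {
    "tcl_year": 23,
    "drivers": 5,
    "tcd_thresholds": 7,
    "adm0": 248,
    "adm1": 86,
    "adm2": 854,
}
ngeometries = 48_000  # approximate count of GADM level 0-2 geometries

full_grid = math.prod(nlabels.values())
admin_grid = nlabels["adm0"] * nlabels["adm1"] * nlabels["adm2"]

# Fraction of (adm0, adm1, adm2) combinations that correspond to real geometries
occupied = ngeometries / admin_grid

# Size of one dense result on the full grid, assuming 4-byte float32 values
dense_bytes = full_grid * 4

print(f"full grid: {full_grid:_} cells, admin grid: {admin_grid:_} cells")
# -> full grid: 14_662_360_160 cells, admin grid: 18_214_112 cells
print(f"occupied fraction ~ {occupied:.2%}")
# -> occupied fraction ~ 0.26%
print(f"dense float32 result ~ {dense_bytes / 2**30:.1f} GiB")
# -> dense float32 result ~ 54.6 GiB
```

The exact byte count depends on the output dtype, but it lands in the same tens-of-gigabytes range as the chunk size quoted above.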

## We can do better using a sparse array

Since the results are very sparse, we can instruct flox to construct _sparse_ arrays of intermediate results instead of dense arrays on the full 23 x 5 x 7 x 248 x 86 x 854 output grid. This is controlled with a `ReindexStrategy`:

```python
from flox import ReindexArrayType, ReindexStrategy

ReindexStrategy(
    # do not reindex to the full output grid at the blockwise aggregation stage
    blockwise=False,
    # when combining intermediate results after blockwise aggregation, reindex to the
    # common grid using a sparse.COO array type
    array_type=ReindexArrayType.SPARSE_COO,
)
```

```python
from flox import ReindexArrayType, ReindexStrategy

result = xarray_reduce(
    ds.areas,
    *by,
    expected_groups=expected_groups,
    func="sum",
    reindex=ReindexStrategy(
        blockwise=False,
        array_type=ReindexArrayType.SPARSE_COO,
    ),
    fill_value=0,
)
result
```

The output is a sparse array (see the **Data type** section of the repr)! Note that the size of this array cannot be estimated without computing it, since the number of non-empty groups is only known once the reduction has run.

The computation runs smoothly with low memory usage.

## Why

To understand why you might do this, here is how flox runs reductions. In the images below, the `areas` array on the left has five 2D chunks. Each color represents a group and each square represents a value of the array; clearly, each chunk contains a different set of groups.


### reindex = True

<img src="../_images/new-map-reduce-reindex-True-annotated.svg" width=100%>

First, the grouped reduction is run on each chunk independently, and the results are constructed as _dense_ arrays on the full 23 x 5 x 7 x 248 x 86 x 854 output grid. This means that every chunk's intermediate result balloons to the size of the full ~56GB output, so this method cannot work well here.
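A toy sketch of this blockwise step, using NumPy's `np.bincount` as a stand-in for flox's grouped sum (flox's own kernels differ) and a single, already integer-coded grouper. With several groupers, flox roughly combines them into one set of integer codes first, so the "full grid" here plays the role of the 23 x 5 x 7 x 248 x 86 x 854 grid. The chunk values and label codes are made up.

```python
import numpy as np

# Toy stand-ins for one chunk of `areas` and its group codes.
# The full label grid has 10 groups, but this chunk only contains 3 of them.
ngroups = 10
areas_chunk = np.ones((4, 4), dtype=np.float32)
labels_chunk = np.array(
    [
        [0, 0, 3, 3],
        [0, 0, 3, 3],
        [7, 7, 7, 7],
        [7, 7, 7, 7],
    ]
)

# reindex=True style: the blockwise result is a dense vector over *all* groups,
# so its size is set by `ngroups`, not by how many groups the chunk contains.
dense = np.bincount(labels_chunk.ravel(), weights=areas_chunk.ravel(), minlength=ngroups)
print(dense.tolist())
# -> [4.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 8.0, 0.0, 0.0]
```

Scaled up to ~14.7 billion groups, every one of the thousands of chunks would produce such a mostly-zero vector.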

### reindex = False with sparse intermediates

<img src="../_images/new-map-reduce-reindex-False-annotated.svg" width=100%>

First, the grouped reduction is run on each chunk independently. Conceptually, the result of this step is an array with differently sized chunks, since each chunk contains a different set of groups.

Next, results from neighbouring blocks are aligned to a common grid of group labels (this step is termed "reindexing"), concatenated, and reduced again. At this stage, we instruct flox to construct a _sparse array_ during reindexing; otherwise we would eventually end up constructing _dense_ reindexed arrays of shape 23 x 5 x 7 x 248 x 86 x 854.
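The idea can be sketched with plain NumPy arrays. This is not flox's implementation: flox reindexes intermediates into `sparse.COO` arrays, whereas this toy version mimics the same idea by keeping only (label, value) pairs for the groups that are present, the way a COO array stores only coordinates and values. The function names `blockwise_sum` and `combine_sparse` are hypothetical.

```python
import numpy as np


def blockwise_sum(values, labels):
    """Grouped sum over one chunk, returning only the groups present in it."""
    present, inverse = np.unique(labels.ravel(), return_inverse=True)
    sums = np.bincount(inverse.ravel(), weights=values.ravel())
    return present, sums  # lengths vary from chunk to chunk


def combine_sparse(results):
    """Combine (labels, sums) pairs COO-style: concatenate, then sum duplicates."""
    coords = np.concatenate([labels for labels, _ in results])
    data = np.concatenate([sums for _, sums in results])
    unique_coords, inverse = np.unique(coords, return_inverse=True)
    return unique_coords, np.bincount(inverse.ravel(), weights=data)


# Two toy chunks that share group 3 but otherwise contain different groups
chunk_a = (np.ones(6), np.array([0, 0, 3, 3, 3, 7]))
chunk_b = (np.ones(5), np.array([3, 5, 5, 9, 11]))

results = [blockwise_sum(v, lab) for v, lab in (chunk_a, chunk_b)]
labels, sums = combine_sparse(results)
print(labels.tolist(), sums.tolist())
# -> [0, 3, 5, 7, 9, 11] [2.0, 4.0, 2.0, 1.0, 1.0, 1.0]
```

Memory here scales with the number of groups actually present, never with the size of the full label grid.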


## Can we do better?

Yes. 

1. Using the reindexing machinery to convert intermediates to sparse is a little bit hacky. A better option would be to aggregate directly to sparse arrays, potentially using a new `engine="sparse"` ([issue](https://github.com/xarray-contrib/flox/issues/346)).
2. The total number of GADM geometries for levels 0, 1, and 2 is ~48,000. A much more sensible solution would be to allow grouping by these _geometries_ directly. This would allow us to be smart about the reduction, by exploiting the ideas underlying the [`method="cohorts"` strategy](../implementation.md#method-cohorts).
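Point 2 can be illustrated outside flox. The sketch below is a hypothetical user-side workaround, not a flox feature: it treats each observed (`adm0`, `adm1`, `adm2`) triple as one "geometry" and assigns it a single integer code, so the output size scales with the number of geometries rather than with 248 x 86 x 854. The pixel values are made up, and running `np.unique` over the full rasters would itself be expensive.

```python
import numpy as np

# Toy rasterized admin codes for a handful of pixels (hypothetical values)
adm0 = np.array([5, 5, 5, 5, 12, 12])
adm1 = np.array([0, 0, 1, 1, 0, 0])
adm2 = np.array([3, 3, 0, 2, 7, 7])
areas = np.ones(6, dtype=np.float32)

# Each distinct (adm0, adm1, adm2) row becomes one geometry with its own code.
triples = np.stack([adm0, adm1, adm2], axis=1)
geometries, geometry_id = np.unique(triples, axis=0, return_inverse=True)
geometry_id = geometry_id.reshape(-1)  # inverse shape differs across NumPy versions

# Grouped sum over geometries: one output value per observed triple
totals = np.bincount(geometry_id, weights=areas, minlength=len(geometries))

print(geometries.tolist())  # -> [[5, 0, 3], [5, 1, 0], [5, 1, 2], [12, 0, 7]]
print(totals.tolist())  # -> [2.0, 1.0, 1.0, 2.0]
```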

Regardless, the ability to do such reindexing allows flox to scale to much larger grouper arrays than previously possible.

