# flox: Faster GroupBy aggregations with Xarray

**Authors** Deepak Cherian, Kevin Sampson, Matthew Rocklin

Significantly faster groupby calculations are now possible through a new-ish package in the Xarray/Dask/Pangeo ecosystem called [flox](https://flox.readthedocs.io/en/latest/). Practically, this means faster climatologies, faster resampling, faster histogramming, and faster compositing of array datasets.


## The National Water Model Dataset


To show off, we demonstrate county-wise aggregation of output from the National Water Model (NWM) available on the [AWS Public Data Registry](https://registry.opendata.aws/nwm-archive/).


Quoting [the NOAA page](https://water.noaa.gov/about/nwm):

> The National Water Model (NWM) is a hydrologic modelling framework that simulates observed and forecast streamflow over the entire continental United States (CONUS). The NWM simulates the water cycle with mathematical representations of the different processes and how they fit together. This complex representation of physical processes such as snowmelt and infiltration and movement of water through the soil layers varies significantly with changing elevations, soils, vegetation types and a host of other variables. Additionally, extreme variability in precipitation over short distances and times can cause the response on rivers and streams to change very quickly. Overall, the process is so complex that to simulate it with a mathematical model means that it needs a very high powered computer or supercomputer in order to run in the time frame needed to support decision makers when flooding is threatened.  

> All CONUS model configurations provide streamflow for 2.7 million river reaches and other hydrologic information on 1km and 250m grids.



## Problem description

We want to calculate county-level means for 3-hourly time series data on the 250m grid. This is a *GroupBy* problem.

GroupBy is a term for a very common analysis pattern, often called "split-apply-combine" ([Wickham, 2011](https://www.jstatsoft.org/article/view/v040i01)), wherein an analyst
- *Splits* a dataset into groups (e.g. counties),
- *Applies* a transformation to each group of data (here a reduction like `.mean`),
- *Combines* the results of `apply` to form a new dataset.


For this problem we will split the dataset into counties, apply the `mean`, and then combine the results back.
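
As a minimal sketch of the pattern, the three steps can be written out explicitly with NumPy on a toy array. The values and labels below are invented for illustration and are not taken from the NWM data.

```python
import numpy as np

# Hypothetical toy data: six grid cells and the county label of each cell.
values = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
labels = np.array([10, 20, 10, 30, 20, 10])

# Split: find the unique groups; each group is selected with a boolean mask.
groups = np.unique(labels)

# Apply: reduce the cells that belong to each group.
means = [float(values[labels == g].mean()) for g in groups]

# Combine: collect the per-group results into a single mapping.
result = dict(zip(groups.tolist(), means))
print(result)
```

The explicit loop over groups is easy to read, but it scans the full array once per group, which is part of why this approach struggles when there are thousands of groups.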

With [Xarray](https://docs.xarray.dev/en/stable/user-guide/groupby.html), this would look like
```python
dataset.groupby(counties).mean()
```

However, Xarray's default algorithm is a simple for-loop over groups and doesn't work very well for large distributed problems.

## Enter `flox`.

`flox` solves a long-standing problem in the Pangeo array computing ecosystem: computing GroupBy reductions efficiently. It implements a parallel groupby algorithm (using a tree reduction) to substantially improve performance of groupby reductions with dask.
  - Specifically, `flox` speeds up [reduction methods](https://flox.readthedocs.io/en/latest/aggregations.html) like `groupby(...).mean()`, `groupby(...).max()`, etc., but not `groupby.map`.
  - `flox` also significantly speeds up groupby reductions with pure numpy arrays using optimized implementations in the [`numpy-groupies` package](https://github.com/ml31415/numpy-groupies).
  - `flox` allows more complicated groupby operations such as lazy grouping by a dask array, and grouping by multiple variables. Use `flox.xarray.xarray_reduce` for [these operations](https://flox.readthedocs.io/en/latest/xarray.html). Xarray currently only supports grouping by a single numpy variable.
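
To give a feel for what a tree reduction means for a grouped mean, here is a simplified NumPy sketch. It illustrates the general idea only and is not flox's actual implementation; the helper names `chunk_partial` and `combine` are hypothetical, and the labels are assumed to be small non-negative integers.

```python
import numpy as np


def chunk_partial(values, labels, ngroups):
    """Per-chunk intermediates: the sum and count of values in each group."""
    sums = np.bincount(labels, weights=values, minlength=ngroups)
    counts = np.bincount(labels, minlength=ngroups)
    return sums, counts


def combine(a, b):
    """Merge two intermediates. Addition is associative, so any tree order works."""
    return a[0] + b[0], a[1] + b[1]


values = np.arange(12, dtype=float)
labels = np.array([0, 1, 2, 0, 1, 2, 2, 1, 0, 0, 1, 2])
ngroups = 3

# Each of the four pieces stands in for one dask chunk.
partials = [
    chunk_partial(v, lab, ngroups)
    for v, lab in zip(np.split(values, 4), np.split(labels, 4))
]

# Combine the intermediates pairwise, as a tree.
left = combine(partials[0], partials[1])
right = combine(partials[2], partials[3])
sums, counts = combine(left, right)

# Finalize: the division happens only once, at the root of the tree.
tree_mean = sums / counts

# Reference result computed directly on the full arrays.
direct = np.array([values[labels == g].mean() for g in range(ngroups)])
```

Because every chunk produces intermediates of the same fixed size (one entry per group), the combine step never needs to see the original data again.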

See [here](https://flox.readthedocs.io/en/latest/intro.html) for short examples.

### How do I use it?

Run `mamba install flox`, and `xarray>=2022.06.0` will use it by default for `.groupby`, `.groupby_bins`, and `.resample`!

A lot of effort was spent in ensuring backwards compatibility, so your workloads should only work better. Let us know if they [do not](https://github.com/pydata/xarray/issues).


## Setup cluster with Coiled

To demonstrate calculation on a cloud-available dataset, we will use [Coiled](https://coiled.io) to set up a dask cluster in AWS `us-east-1`.

In [None]:
import coiled

cluster = coiled.Cluster(
    region="us-east-1",
    n_workers=40,
)

client = cluster.get_client()

## Setup

In [None]:
%load_ext watermark

import flox  # make sure it's available
import fsspec
import numpy as np
import rioxarray
import xarray as xr

xr.set_options(
    display_expand_attrs=False,
    display_expand_coords=False,
    display_expand_data=True,
)

%watermark -iv

## Load NWM data

In [None]:
ds = xr.open_zarr(
    fsspec.get_mapper("s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr", anon=True),
    consolidated=True,
)
ds

Each field in this dataset is big!

In [None]:
ds.zwattablrt

We'll subset to a single year for demo purposes.

In [None]:
subset = ds.zwattablrt.sel(time=slice("2001-01-01", "2001-12-31"))
subset

## Load county raster for grouping

A raster TIFF file identifying counties by a unique integer was created separately and saved. 

We load that using [rioxarray](https://corteva.github.io/rioxarray/html/rioxarray.html).

In [None]:
import fsspec
import rioxarray

fs = fsspec.filesystem("s3", requester_pays=True)

counties = rioxarray.open_rasterio(
    fs.open("s3://nwm-250m-us-counties/Counties_on_250m_grid.tif"), chunks="auto"
).squeeze()
counties

# remove any small floating point error in coordinate locations
_, counties_aligned = xr.align(ds, counties, join="override")

counties_aligned

We'll need the unique county IDs later, so we calculate them now.

In [None]:
county_id = np.unique(counties_aligned.data).compute()
county_id = county_id[county_id != 0]
print(f"There are {len(county_id)} counties!")

## GroupBy with flox

We could run the computation as

```python
subset.groupby(counties_aligned).mean()
```

This would use flox in the background.

However, it would also load `counties_aligned` into memory (an unfortunate Xarray implementation detail), which is not so bad (only a gig). To avoid egress charges, we'll instead go through `flox.xarray`, which allows you to lazily group by a dask array (here `counties_aligned`) as long as you pass in the expected group labels in `expected_groups`.

See [here](https://flox.readthedocs.io/en/latest/intro.html#with-dask) for more.
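
One way to see why the labels have to be supplied up front: when the group IDs are known in advance, every chunk can map its labels to positions in the same fixed-size output without inspecting the rest of the label array. The NumPy sketch below illustrates that idea with made-up IDs. `label_positions` is a hypothetical helper written for this example, not part of flox, and it assumes `expected_groups` is sorted.

```python
import numpy as np


def label_positions(labels, expected_groups):
    """Map each label to its index in sorted `expected_groups`, or -1 if absent."""
    idx = np.searchsorted(expected_groups, labels)
    idx = np.clip(idx, 0, len(expected_groups) - 1)
    return np.where(expected_groups[idx] == labels, idx, -1)


expected_groups = np.array([1001, 1003, 8013])  # made-up county IDs, sorted
chunk = np.array([8013, 0, 1001, 8013, 1003])  # 0 marks "no county"

pos = label_positions(chunk, expected_groups)

# Count the cells per expected group, skipping labels that are not expected.
valid = pos >= 0
counts = np.bincount(pos[valid], minlength=len(expected_groups))
```

Labels that are not in `expected_groups` (here the `0` fill value) are simply dropped, which mirrors how the zero ID was excluded from `county_id` earlier.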

In [None]:
import flox.xarray

county_mean = flox.xarray.xarray_reduce(
    subset,
    counties_aligned.rename("county"),
    func="mean",
    expected_groups=(county_id,),
)

county_mean

In [None]:
county_mean.load()

The computation proceeds very nicely. 

We don't anticipate trouble scaling this computation up to the full dataset.

## Cleanup

In [None]:
cluster.shutdown()

## Visualize yearly mean

Read the county shapefile and use the combination of state FIPS code and county FIPS code as a multi-index.

In [None]:
import geopandas as gpd
import hvplot.pandas

counties = gpd.read_file(
    "https://www2.census.gov/geo/tiger/GENZ2022/shp/cb_2022_us_county_20m.zip"
).to_crs("EPSG:3395")
counties["STATEFP"] = counties.STATEFP.astype(int)
counties["COUNTYFP"] = counties.COUNTYFP.astype(int)
continental = counties.loc[~counties["STATEFP"].isin([2, 15, 72])]
continental = continental.set_index(["STATEFP", "COUNTYFP"])

Interpret `county` as a combination of the state FIPS code and the county FIPS code, then set a multi-index.
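
As a quick illustration, assuming the raster stores each county as `STATEFP * 1000 + COUNTYFP` (which is what the integer division and modulo in the next cell imply), a few IDs decode as follows. The IDs are chosen only for illustration.

```python
import numpy as np

county = np.array([1001, 8013, 48201])  # illustrative five-digit county IDs

# Integer-divide for the state code; the remainder is the county code.
statefp, countyfp = np.divmod(county, 1000)

for cid, s, c in zip(county.tolist(), statefp.tolist(), countyfp.tolist()):
    print(f"{cid} -> STATEFP={s}, COUNTYFP={c}")
```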

In [None]:
yearly_mean = county_mean.mean("time")
yearly_mean.coords["STATEFP"] = (yearly_mean.county // 1000).astype(int)
yearly_mean.coords["COUNTYFP"] = np.mod(yearly_mean.county, 1000).astype(int)
yearly_mean = yearly_mean.drop_vars("county").set_index(county=["STATEFP", "COUNTYFP"])
yearly_mean

Join the yearly means onto the county geometries.

In [None]:
continental["zwattablrt"] = yearly_mean.to_dataframe()["zwattablrt"]

Plot the result.

In [None]:
continental.hvplot(c="zwattablrt")

## Summary

[flox](https://flox.readthedocs.io) makes many large Groupby problems tractable! Use it.


[flox](https://flox.readthedocs.io) also makes many small but more complicated (e.g. multiple variables) Groupby problems tractable! Use it.

We [anticipate](https://github.com/pydata/xarray/issues/6610) upgrading Xarray's interface to enable more complicated GroupBy computations. In the meantime, use flox!