# flox: Faster GroupBy aggregations with Xarray

**Authors** Deepak Cherian, Kevin Sampson, Matthew Rocklin

Significantly faster groupby calculations are now possible through a new-ish package in the Xarray/Dask/Pangeo ecosystem called [flox](https://flox.readthedocs.io/en/latest/). Practically, this means faster climatologies, faster resampling, faster histogramming, and faster compositing of array datasets.


## The National Water Model Dataset


To show off, we demonstrate county-wise aggregation of output from the National Water Model (NWM) available on the [AWS Public Data Registry](https://registry.opendata.aws/nwm-archive/).


Quoting [the NOAA page](https://water.noaa.gov/about/nwm):

> The National Water Model (NWM) is a hydrologic modelling framework that simulates observed and forecast streamflow over the entire continental United States (CONUS). The NWM simulates the water cycle with mathematical representations of the different processes and how they fit together. This complex representation of physical processes such as snowmelt and infiltration and movement of water through the soil layers varies significantly with changing elevations, soils, vegetation types and a host of other variables. Additionally, extreme variability in precipitation over short distances and times can cause the response on rivers and streams to change very quickly. Overall, the process is so complex that to simulate it with a mathematical model means that it needs a very high powered computer or supercomputer in order to run in the time frame needed to support decision makers when flooding is threatened.  

> All CONUS model configurations provide streamflow for 2.7 million river reaches and other hydrologic information on 1km and 250m grids.



## Problem description

We want to calculate county-level means for 3-hourly time series data on the 250m grid. This is a *GroupBy* problem.

GroupBy is the term for a very common analysis pattern, also called "split-apply-combine" ([Wickham, 2011](https://www.jstatsoft.org/article/view/v040i01)), wherein an analyst
- *Splits* a dataset into groups (e.g. counties),
- *Applies* a transformation to each group of data (here a reduction like `.mean`), and
- *Combines* the results of `apply` to form a new dataset.


For this problem we will split the dataset into counties, apply the `mean`, and then combine the results back.

With [Xarray](https://docs.xarray.dev/en/stable/user-guide/groupby.html), this would look like
```python
dataset.groupby(counties).mean()
```

However, Xarray's default algorithm is a simple for-loop over groups, which doesn't work well for large distributed problems.
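To make the cost of a loop over groups concrete, here is a minimal NumPy sketch of split-apply-combine. It is an illustration only, not Xarray's actual code: the function name `loop_groupby_mean` and the toy arrays are hypothetical.

```python
import numpy as np

# Hypothetical toy data: six grid cells, each tagged with a county ID.
values = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
county = np.array([10, 20, 10, 30, 20, 10])


def loop_groupby_mean(values, labels):
    """Split-apply-combine with one pass over the data per group."""
    groups = np.unique(labels)           # discover the groups
    means = np.empty(len(groups))
    for i, g in enumerate(groups):
        members = values[labels == g]    # split: select one group
        means[i] = members.mean()        # apply: reduce that group
    return groups, means                 # combine: one value per group


groups, means = loop_groupby_mean(values, county)
```

Each iteration builds a boolean mask over the whole array, so the work grows with the number of groups. With thousands of counties and a multi-terabyte array, that becomes the bottleneck.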

## Enter `flox`.

`flox` solves a long-standing problem in the Pangeo array computing ecosystem: computing GroupBy reductions efficiently. It implements a parallel groupby algorithm (using a tree reduction) that substantially improves the performance of groupby reductions with dask.
  - Specifically, `flox` speeds up [reduction methods](https://flox.readthedocs.io/en/latest/aggregations.html) like `groupby(...).mean()`, `groupby(...).max()`, etc., but not `groupby.map`.
  - `flox` also significantly speeds up groupby reductions with pure numpy arrays using optimized implementations in the [`numpy-groupies` package](https://github.com/ml31415/numpy-groupies).
  - `flox` allows more complicated groupby operations such as lazy grouping by a dask array, and grouping by multiple variables. Use `flox.xarray.xarray_reduce` for [these operations](https://flox.readthedocs.io/en/latest/xarray.html). Xarray currently only supports grouping by a single numpy variable.
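The tree-reduction idea can be sketched in plain NumPy. This is a simplified illustration under stated assumptions, not flox's implementation: the helper names (`chunk_reduce`, `combine`, `tree_combine`) are hypothetical, and group labels are assumed to already be integer codes `0..ngroups-1`. Each chunk is reduced to small per-group intermediates (sums and counts), and those are merged pairwise until one remains.

```python
import numpy as np


def chunk_reduce(values, codes, ngroups):
    """Per-chunk intermediates: group sums and group counts."""
    sums = np.bincount(codes, weights=values, minlength=ngroups)
    counts = np.bincount(codes, minlength=ngroups)
    return sums, counts


def combine(a, b):
    """Merge two intermediates; associative, so it can run as a tree."""
    return a[0] + b[0], a[1] + b[1]


def tree_combine(parts):
    """Pairwise-combine intermediates until a single one remains."""
    while len(parts) > 1:
        merged = [combine(parts[i], parts[i + 1]) for i in range(0, len(parts) - 1, 2)]
        if len(parts) % 2:
            merged.append(parts[-1])  # the odd one out moves up a level
        parts = merged
    return parts[0]


# Hypothetical data split into three "chunks".
values = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
codes = np.array([0, 1, 0, 2, 1, 0])
chunks = [slice(0, 2), slice(2, 4), slice(4, 6)]

parts = [chunk_reduce(values[s], codes[s], ngroups=3) for s in chunks]
sums, counts = tree_combine(parts)
tree_mean = sums / counts  # finalize: mean = sum / count
```

Because every intermediate has length `ngroups` regardless of chunk size, the data passed between workers stays small. In this sketch a group with no members would produce a division by zero, which a real implementation has to handle.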

See [here](https://flox.readthedocs.io/en/latest/intro.html) for short examples.

### How do I use it?

Run `mamba install flox` and `xarray>=2022.06.0` will use it by default for `.groupby`, `.groupby_bins`, and `.resample`!

A lot of effort was spent in ensuring backwards compatibility, so your workloads should only work better. Let us know if it [does not](https://github.com/pydata/xarray/issues).


## Setup cluster with Coiled

To demonstrate calculation on a cloud-available dataset, we will use [Coiled](https://coiled.io) to set up a dask cluster in AWS `us-east-1`.

```python
import coiled

cluster = coiled.Cluster(
    region="us-east-1",
    n_workers=10,
    tags={"project": "nwm"},
    scheduler_vm_types="r7g.large",
    worker_vm_types="c7g.xlarge",
)

client = cluster.get_client()

cluster.adapt(minimum=10, maximum=40)
```

```
2023-07-24 16:53:35,865 - distributed.deploy.adaptive - INFO - Adaptive scaling started: minimum=10 maximum=40
<coiled.cluster.CoiledAdaptive at 0x3fa8e3310>
```

## Setup

```python
%load_ext watermark

import flox  # make sure it's available
import fsspec
import numpy as np
import rioxarray
import xarray as xr

xr.set_options(
    display_expand_attrs=False,
    display_expand_coords=False,
    display_expand_data=True,
)

%watermark -iv
```

```
rioxarray: 0.14.1
geopandas: 0.13.2
xarray   : 2023.7.0
flox     : 0.7.2
hvplot   : 0.8.4
fsspec   : 2023.6.0
numpy    : 1.24.4
coiled   : 0.8.14
```



## Load NWM data

```python
ds = xr.open_zarr(
    fsspec.get_mapper("s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr", anon=True),
    consolidated=True,
)
ds
```

| | Array | Chunk |
|---|---|---|
| Bytes | 252.30 TiB | 209.35 MiB |
| Shape | (122479, 15360, 18432) | (224, 350, 350) |
| Dask graph | 1275604 chunks in 2 graph layers | |
| Data type | float64 numpy.ndarray | |


Each field in this dataset is big!

```python
ds.zwattablrt
```

| | Array | Chunk |
|---|---|---|
| Bytes | 252.30 TiB | 209.35 MiB |
| Shape | (122479, 15360, 18432) | (224, 350, 350) |
| Dask graph | 1275604 chunks in 2 graph layers | |
| Data type | float64 numpy.ndarray | |


We'll subset to two years (2001 and 2002) for demo purposes.

```python
# subset = ds.zwattablrt.sel(time=slice("2001-01-01", "2020-12-31"))
subset = ds.zwattablrt.sel(time=slice("2001-01-01", "2002-12-31"))
subset
```

| | Array | Chunk |
|---|---|---|
| Bytes | 12.03 TiB | 209.35 MiB |
| Shape | (5840, 15360, 18432) | (224, 350, 350) |
| Dask graph | 62964 chunks in 3 graph layers | |
| Data type | float64 numpy.ndarray | |
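As a sanity check, the numbers in this repr can be reproduced with a little arithmetic. The sketch below uses only the shape, chunk shape, and dtype shown above. The chunk count uses ceiling division per axis, which assumes the chunks line up with the start of each axis.

```python
import math

shape = (5840, 15360, 18432)  # (time, y, x) of the 2001-2002 subset
chunks = (224, 350, 350)      # chunk shape reported by dask
itemsize = 8                  # bytes per float64 value

n_steps = 8 * 365 * 2  # 3-hourly data: 8 steps per day, two non-leap years
total_tib = math.prod(shape) * itemsize / 2**40
chunk_mib = math.prod(chunks) * itemsize / 2**20
# Ceiling division per axis; assumes chunks are aligned with the axis start.
n_chunks = math.prod(math.ceil(s / c) for s, c in zip(shape, chunks))

print(n_steps, round(total_tib, 2), round(chunk_mib, 2), n_chunks)
# 5840 12.03 209.35 62964
```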


## Load county raster for grouping

A raster TIFF file identifying counties by a unique integer was created separately and saved. 

We load that using [rioxarray](https://corteva.github.io/rioxarray/html/rioxarray.html).

```python
import fsspec
import rioxarray

fs = fsspec.filesystem("s3", requester_pays=True)

counties = rioxarray.open_rasterio(
    fs.open("s3://nwm-250m-us-counties/Counties_on_250m_grid.tif"), chunks="auto"
).squeeze()

# remove any small floating point error in coordinate locations
_, counties_aligned = xr.align(ds, counties, join="override")

counties_aligned
```

| | Array | Chunk |
|---|---|---|
| Bytes | 1.05 GiB | 127.97 MiB |
| Shape | (15360, 18432) | (1820, 18432) |
| Dask graph | 9 chunks in 3 graph layers | |
| Data type | int32 numpy.ndarray | |


We'll need the unique county IDs later, so we calculate them now.

```python
county_id = np.unique(counties_aligned.data).compute()
county_id = county_id[county_id != 0]
print(f"There are {len(county_id)} counties!")
```

```
There are 3108 counties!
```


## GroupBy with flox

We could run the computation as

```python
subset.groupby(counties_aligned).mean()
```

This would use flox in the background.

However, it would also load `counties_aligned` into memory (an unfortunate Xarray implementation detail). That is not so bad here (only a gigabyte), but to avoid egress charges we'll instead go through `flox.xarray`, which allows you to lazily group by a dask array (here `counties_aligned`) as long as you pass in the expected group labels in `expected_groups`.

See [here](https://flox.readthedocs.io/en/latest/intro.html#with-dask) for more.
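One way to see why the expected labels matter: when the labels themselves are lazy, the number of groups (and so the output shape) is unknown until they are computed. Given the labels up front, every chunk can be reduced to a fixed-length array even if it contains only some of the groups. The NumPy sketch below illustrates that idea. It is not flox's implementation, and `labels_to_codes` and the sample arrays are hypothetical.

```python
import numpy as np

# Assumed to be known ahead of time and sorted (illustrative county IDs).
expected_groups = np.array([1105, 6021, 17127])


def labels_to_codes(labels, expected_groups):
    """Map raw labels to positions in expected_groups; -1 marks unknown labels."""
    idx = np.searchsorted(expected_groups, labels)
    idx = np.clip(idx, 0, len(expected_groups) - 1)
    return np.where(expected_groups[idx] == labels, idx, -1)


# One chunk that contains only two of the three counties, plus the value 0,
# which the notebook excludes from `county_id`.
chunk_labels = np.array([17127, 0, 6021, 17127])
chunk_values = np.array([1.0, 9.0, 2.0, 3.0])

codes = labels_to_codes(chunk_labels, expected_groups)
valid = codes >= 0
sums = np.bincount(
    codes[valid], weights=chunk_values[valid], minlength=len(expected_groups)
)
```

Here `sums` always has one slot per expected group, so intermediates from different chunks can be added together without first inspecting the labels.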

```python
import flox.xarray

county_mean = flox.xarray.xarray_reduce(
    subset,
    counties_aligned.rename("county"),
    func="mean",
    expected_groups=(county_id,),
)

county_mean
```

| | Array | Chunk |
|---|---|---|
| Bytes | 138.48 MiB | 5.31 MiB |
| Shape | (5840, 3108) | (224, 3108) |
| Dask graph | 27 chunks in 17 graph layers | |
| Data type | float64 numpy.ndarray | |


```python
county_mean.load()
```

The computation proceeds very nicely. 

We don't anticipate trouble scaling this computation up to the full dataset.

## Cleanup

```python
cluster.shutdown()
```

```
2023-07-24 17:07:14,719 - distributed.deploy.adaptive_core - INFO - Adaptive stop
2023-07-24 17:07:16,200 - distributed.deploy.adaptive_core - INFO - Adaptive stop
```


```python
county_mean.to_netcdf("mean_zwattablrt_nwm.nc")
```

```python
county_mean = xr.open_dataset("mean_zwattablrt_nwm.nc")
```

## Visualize yearly mean

Read the county shapefile and use the combination of state FIPS code and county FIPS code as a multi-index.

```python
import geopandas as gpd
import hvplot.pandas

counties = gpd.read_file(
    "https://www2.census.gov/geo/tiger/GENZ2022/shp/cb_2022_us_county_20m.zip"
).to_crs("EPSG:3395")
counties["STATEFP"] = counties.STATEFP.astype(int)
counties["COUNTYFP"] = counties.COUNTYFP.astype(int)
continental = counties.loc[~counties["STATEFP"].isin([2, 15, 72])]
continental = continental.set_index(["STATEFP", "COUNTYFP"])
```

Interpret `county` as a combination of state FIPS code and county FIPS code, and set the two as a multi-index.

```python
yearly_mean = county_mean.mean("time")
yearly_mean.coords["STATEFP"] = (yearly_mean.county // 1000).astype(int)
yearly_mean.coords["COUNTYFP"] = np.mod(yearly_mean.county, 1000).astype(int)
yearly_mean = yearly_mean.drop_vars("county").set_index(county=["STATEFP", "COUNTYFP"])
yearly_mean
```
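The integer arithmetic above assumes each raster `county` value is the five-digit GEOID read as an integer, i.e. `STATEFP * 1000 + COUNTYFP`. A small worked example, using a hypothetical helper `split_fips` and two GEOIDs from the census table:

```python
def split_fips(geoid):
    """Split a combined county code into (state FIPS, county FIPS)."""
    return divmod(geoid, 1000)


massac = split_fips(17127)  # (17, 127): Massac County, Illinois
glenn = split_fips(6021)    # (6, 21): Glenn County, California (GEOID "06021")
```

Leading zeros are lost when the codes are stored as integers, which is why the shapefile's `STATEFP` and `COUNTYFP` columns are also cast to `int` before the join.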

Join the county means to the county geometries on the shared multi-index.

```python
continental["zwattablrt"] = yearly_mean.to_dataframe()["zwattablrt"]
```

```python
continental
```

| STATEFP | COUNTYFP | COUNTYNS | AFFGEOID | GEOID | NAME | NAMELSAD | STUSPS | STATE_NAME | LSAD | ALAND | AWATER | geometry | zwattablrt |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 17 | 127 | 01784730 | 0500000US17127 | 17127 | Massac | Massac County | IL | Illinois | 06 | 614218330 | 12784614 | POLYGON ((-9899503.889 4455506.800, -9896867.1... | 1.327917 |
| 27 | 17 | 00659454 | 0500000US27017 | 27017 | Carlton | Carlton County | MN | Minnesota | 06 | 2230473967 | 36173451 | POLYGON ((-10359540.224 5872907.231, -10274975... | 0.989152 |
| 37 | 181 | 01008591 | 0500000US37181 | 37181 | Vance | Vance County | NC | North Carolina | 06 | 653701481 | 42190675 | POLYGON ((-8738332.957 4346219.160, -8733823.9... | 1.615570 |
| 47 | 79 | 01639755 | 0500000US47079 | 47079 | Henry | Henry County | TN | Tennessee | 06 | 1455320362 | 81582236 | POLYGON ((-9853595.905 4344424.296, -9853101.8... | 2.000000 |
| 6 | 21 | 00277275 | 0500000US06021 | 06021 | Glenn | Glenn County | CA | California | 06 | 3403160299 | 33693344 | POLYGON ((-13682478.726 4809490.889, -13586151... | 1.927344 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 51 | 167 | 01497573 | 0500000US51167 | 51167 | Russell | Russell County | VA | Virginia | 06 | 1226421378 | 7558481 | POLYGON ((-9173464.253 4396242.834, -9161437.6... | 1.353288 |
| 40 | 89 | 01101835 | 0500000US40089 | 40089 | McCurtain | McCurtain County | OK | Oklahoma | 06 | 4793496603 | 133942698 | POLYGON ((-10592775.814 4025907.915, -10592197... | 1.585679 |
| 1 | 105 | 00161579 | 0500000US01105 | 01105 | Perry | Perry County | AL | Alabama | 06 | 1863936201 | 10902207 | POLYGON ((-9743260.401 3826639.779, -9737551.6... | 1.079257 |
| 54 | 43 | 01550028 | 0500000US54043 | 54043 | Lincoln | Lincoln County | WV | West Virginia | 06 | 1132065764 | 4053564 | POLYGON ((-9149606.558 4559633.726, -9159736.8... | 1.181518 |


Plot

```python
continental.hvplot(c="zwattablrt")
```

```python
continental.hvplot(
    c="zwattablrt",
    cmap="turbo_r",
    # the subset above spans 2001-01-01 through 2002-12-31
    title="Average Water Table Depth in 2001-2002 by US County (meters)",
    xaxis=None,
    yaxis=None,
)
```

## Summary

[flox](https://flox.readthedocs.io) makes many large GroupBy problems tractable! Use it.

[flox](https://flox.readthedocs.io) also makes many small but more complicated (e.g. multiple variables) GroupBy problems tractable! Use it.

We [anticipate](https://github.com/pydata/xarray/issues/6610) upgrading Xarray's interface to enable more complicated GroupBy computations. In the meantime, use flox!