# Analyzing the National Water Model with Xarray, Dask, and Coiled


This example was originally adapted from [this notebook](https://github.com/dcherian/dask-demo/blob/main/nwm-aws.ipynb), authored by Deepak Cherian, Kevin Sampson, and Matthew Rocklin.

Datasets with high spatio-temporal resolution can get very large very quickly, vastly exceeding the resources you may have on a laptop. Dask integrates with Xarray to support parallel computing, and Coiled lets you scale this to the cloud.

## The National Water Model Dataset

In this example, we'll perform a county-wise aggregation of output from the National Water Model (NWM) available on the [AWS Open Data Registry](https://registry.opendata.aws/nwm-archive/).


The National Water Model (NWM) is a highly complex hydrological modeling framework that simulates observed and forecasted streamflow across the entire continental US at a very fine spatial and temporal scale. It’s used by private and public organizations across the US to inform decision making around water management and to predict when and where flooding will occur. You can [read more from the Office of Water Prediction](https://water.noaa.gov/about/nwm).

## Problem description

We’ll calculate the mean depth to soil saturation for each US county: 

- Years: 1979-2020
- Temporal resolution: 3-hourly land surface output
- Spatial resolution: 250 m grid (~820 feet, over half a lap on a track, about 2.5 soccer fields)
- 277 terabytes!
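The 277 TB figure can be sanity-checked from the array shape that xarray reports when the data is opened below: one `float64` variable (8 bytes per value) over 122,479 time steps on a 15,360 × 18,432 grid. A quick sketch of that arithmetic, assuming the dimension order is (time, y, x):

```python
# Shape of one NWM retrospective variable as reported by xarray.
# Assumption: dimension order is (time, y, x).
n_time, n_y, n_x = 122_479, 15_360, 18_432
bytes_per_value = 8  # float64

total_bytes = n_time * n_y * n_x * bytes_per_value
print(f"{total_bytes / 1e12:.1f} TB")     # decimal terabytes
print(f"{total_bytes / 2**40:.2f} TiB")   # binary tebibytes, the unit xarray displays

# Spatial extent implied by 250 m grid spacing
print(f"{n_x * 250 / 1000:.0f} km x {n_y * 250 / 1000:.0f} km")
```

The two size lines print 277.4 TB and 252.30 TiB, the same quantity in decimal and binary units.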

We'll use a few tools to help make this happen:
- `dask` + `coiled` to process the dataset in parallel in the cloud
- `xarray` + `flox` to work with the multi-dimensional Zarr dataset and aggregate the 250 m grid to county-level means

## Start a Coiled cluster

To demonstrate calculation on a cloud-available dataset, we will use [Coiled](https://coiled.io) to set up a dask cluster in AWS `us-east-1`.

In [14]:
import coiled

cluster = coiled.Cluster(
    region="us-east-1",
    n_workers=10,
    tags={"project": "nwm"},
    scheduler_vm_types="r6in.xlarge",
    worker_vm_types="r6in.2xlarge"
)

client = cluster.get_client()

cluster.adapt(minimum=10, maximum=50)


2023-08-03 12:22:27,831 - distributed.deploy.adaptive - INFO - Adaptive scaling started: minimum=10 maximum=50


<coiled.cluster.CoiledAdaptive at 0x15bcac350>

## Setup

In [15]:
%load_ext watermark

import flox  # make sure it's available
import fsspec
import numpy as np
import rioxarray
import xarray as xr

xr.set_options(
    display_expand_attrs=False,
    display_expand_coords=False,
    display_expand_data=True,
)

%watermark -iv

The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark
rioxarray: 0.14.1
xarray   : 2023.7.0
numpy    : 1.24.4
coiled   : 0.9.0
flox     : 0.7.2
fsspec   : 2023.6.0



## Load NWM data

In [16]:
ds = xr.open_zarr(
    fsspec.get_mapper("s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr", anon=True),
    consolidated=True,
    chunks={"time": 896, "x": 350, "y": 350}
)
ds

| | Array | Chunk |
|---|---|---|
| Bytes | 252.30 TiB | 837.40 MiB |
| Shape | (122479, 15360, 18432) | (896, 350, 350) |
| Dask graph | 319484 chunks in 2 graph layers | |
| Data type | float64 numpy.ndarray | float64 numpy.ndarray |
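The chunk size and chunk count in this table follow directly from the array shape and the `chunks` argument passed to `open_zarr`. A small helper, `chunk_stats` (hypothetical, not part of xarray or dask), reproduces them:

```python
import math


def chunk_stats(shape, chunks, itemsize=8):
    """Return (MiB per full chunk, number of chunks) for a regularly chunked array."""
    chunk_bytes = math.prod(chunks) * itemsize
    # Each dimension needs ceil(size / chunk) chunks; the last one may be partial
    n_chunks = math.prod(math.ceil(s / c) for s, c in zip(shape, chunks))
    return chunk_bytes / 2**20, n_chunks


mib, n = chunk_stats((122_479, 15_360, 18_432), (896, 350, 350))
print(f"{mib:.2f} MiB per chunk, {n} chunks")
```

This gives 837.40 MiB per chunk and 319,484 chunks (137 × 44 × 53). Applied to the one-year subset, shape `(2928, 15360, 18432)`, it gives 9,328 chunks.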


Each field in this dataset is big!

In [17]:
ds.zwattablrt

| | Array | Chunk |
|---|---|---|
| Bytes | 252.30 TiB | 837.40 MiB |
| Shape | (122479, 15360, 18432) | (896, 350, 350) |
| Dask graph | 319484 chunks in 2 graph layers | |
| Data type | float64 numpy.ndarray | float64 numpy.ndarray |


We'll subset to a single year (2020) for demo purposes:

In [18]:
subset = ds.zwattablrt.sel(time=slice("2020-01-01", "2020-12-31"))
subset

| | Array | Chunk |
|---|---|---|
| Bytes | 6.03 TiB | 837.40 MiB |
| Shape | (2928, 15360, 18432) | (896, 350, 350) |
| Dask graph | 9328 chunks in 3 graph layers | |
| Data type | float64 numpy.ndarray | float64 numpy.ndarray |
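The 2,928 time steps are what a leap year of 3-hourly output should contain, assuming one timestamp every 3 hours starting at midnight on January 1. A quick NumPy check, which also shows why the subset spans four `time` chunks of length 896 (the last one partial):

```python
import math

import numpy as np

# 3-hourly timestamps covering calendar year 2020 (a leap year: 366 days * 8 steps/day)
steps = np.arange(
    np.datetime64("2020-01-01T00"),
    np.datetime64("2021-01-01T00"),
    np.timedelta64(3, "h"),
)
n_time_chunks = math.ceil(len(steps) / 896)  # chunk length used in open_zarr
print(len(steps), n_time_chunks)  # 2928 4
```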


## Load county raster for grouping

A raster TIFF file identifying counties by a unique integer was created separately and saved to a requester-pays S3 bucket.

We load it using [rioxarray](https://corteva.github.io/rioxarray/html/rioxarray.html).

In [19]:
import fsspec
import rioxarray

fs = fsspec.filesystem("s3", requester_pays=True)

counties = rioxarray.open_rasterio(
    fs.open("s3://nwm-250m-us-counties/Counties_on_250m_grid.tif"), chunks="auto"
).squeeze()

# remove any small floating point error in coordinate locations
_, counties_aligned = xr.align(subset, counties, join="override")

counties_aligned

| | Array | Chunk |
|---|---|---|
| Bytes | 1.05 GiB | 127.97 MiB |
| Shape | (15360, 18432) | (1820, 18432) |
| Dask graph | 9 chunks in 3 graph layers | |
| Data type | int32 numpy.ndarray | int32 numpy.ndarray |
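`join="override"` skips label matching and reuses the first object's indexes on the others, which is why it can paper over tiny floating point differences in coordinates (it requires equal sizes along the shared dimensions). A toy illustration with made-up coordinates, contrasting it with an exact `join="inner"`:

```python
import numpy as np
import xarray as xr

# Two 1-D grids of the same size whose x labels differ by tiny float offsets (made-up values)
a = xr.DataArray(np.zeros(3), dims="x", coords={"x": [0.0, 250.0, 500.0]})
b = xr.DataArray(np.arange(3), dims="x", coords={"x": [1e-9, 250.0 + 1e-9, 500.0]})

# Exact label matching keeps only the one label that is identical in both
_, b_inner = xr.align(a, b, join="inner")
print(b_inner.sizes["x"])  # 1

# join="override" copies a's x index onto b position by position
_, b_override = xr.align(a, b, join="override")
print(b_override.x.values)  # same labels as a.x
```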


We'll need the unique county IDs later, so we calculate them now.

In [20]:
county_id = np.unique(counties_aligned.data).compute()
county_id = county_id[county_id != 0]
print(f"There are {len(county_id)} counties!")

There are 3108 counties!
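Calling `np.unique` on a dask array works because dask implements NumPy's `__array_function__` protocol and dispatches to `dask.array.unique`, so the result stays lazy until `.compute()`. A small sketch on a toy chunked raster, assuming (as the filtering step above suggests) that 0 marks pixels outside any county:

```python
import dask.array as da
import numpy as np

# Toy chunked label raster; 0 is assumed to mean "no county"
labels = da.from_array(
    np.array([[0, 6037, 6037], [36061, 0, 6037]], dtype="int32"),
    chunks=(1, 3),
)

ids = np.unique(labels)     # lazy dask array, dispatched to dask.array.unique
ids = ids.compute()         # concrete, sorted NumPy array
ids = ids[ids != 0]         # drop the background label
print(ids.tolist())  # [6037, 36061]
```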


## GroupBy with flox

We could run the computation as

```python
subset.groupby(counties_aligned).mean()
```

This would use flox in the background.

However, it would also load `counties_aligned` into memory (an unfortunate Xarray implementation detail). That is not so bad here (only about 1 GiB), but to avoid egress charges we'll instead go through `flox.xarray`, which lets you lazily group by a dask array (here `counties_aligned`) as long as you pass in the expected group labels via `expected_groups`.

See the [flox documentation](https://flox.readthedocs.io/en/latest/intro.html#with-dask) for more.
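For intuition, here is a NumPy-only sketch of what a grouped mean with `expected_groups` computes for a single time step: each expected group gets one output value, and pixels whose label is not an expected group (such as the 0 background) are left out. This is a simplified illustration on toy data, not how flox actually implements the reduction on chunked arrays:

```python
import numpy as np

# Toy county raster (0 = no county) and one time step of a field on the same grid
labels = np.array([[0, 6037, 6037],
                   [36061, 36061, 6037]])
field = np.array([[9.0, 1.0, 2.0],
                  [4.0, 6.0, 3.0]])
expected_groups = np.array([6037, 36061])  # must be sorted for searchsorted

# Map each pixel's label to its position in expected_groups
flat_labels = labels.ravel()
pos = np.clip(np.searchsorted(expected_groups, flat_labels), 0, len(expected_groups) - 1)
valid = expected_groups[pos] == flat_labels  # False for labels not in expected_groups

# Per-group sums and counts, then the mean
sums = np.bincount(pos[valid], weights=field.ravel()[valid], minlength=len(expected_groups))
counts = np.bincount(pos[valid], minlength=len(expected_groups))
means = sums / counts
print(dict(zip(expected_groups.tolist(), means.tolist())))  # {6037: 2.0, 36061: 5.0}
```

The value 9.0 sits on a 0 pixel and does not affect either mean.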

In [21]:
import flox.xarray

county_mean = flox.xarray.xarray_reduce(
    subset,
    counties_aligned.rename("county"),
    func="mean",
    expected_groups=(county_id,),
)

county_mean

| | Array | Chunk |
|---|---|---|
| Bytes | 69.43 MiB | 21.25 MiB |
| Shape | (2928, 3108) | (896, 3108) |
| Dask graph | 4 chunks in 17 graph layers | |
| Data type | float64 numpy.ndarray | float64 numpy.ndarray |


In [22]:
county_mean.load()

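The computation proceeds smoothly on the cluster.

The full record is roughly 42 times larger than this one-year subset (252 TiB vs. 6 TiB), and each `time` chunk is reduced independently, so we don't anticipate trouble scaling this computation up to the full dataset.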

## Cleanup

In [23]:
cluster.shutdown()

2023-08-03 12:36:51,023 - distributed.deploy.adaptive_core - INFO - Adaptive stop
2023-08-03 12:36:51,583 - distributed.deploy.adaptive_core - INFO - Adaptive stop


## Visualize yearly mean

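Read the county shapefile from the US Census Bureau, keep only the continental US (dropping Alaska, Hawaii, and Puerto Rico, state FIPS codes 2, 15, and 72), and index it by state and county FIPS code as a MultiIndex.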

In [35]:
import geopandas as gpd
import hvplot.pandas

counties = gpd.read_file(
    "https://www2.census.gov/geo/tiger/GENZ2022/shp/cb_2022_us_county_20m.zip"
).to_crs("EPSG:3395")
counties["STATEFP"] = counties.STATEFP.astype(int)
counties["COUNTYFP"] = counties.COUNTYFP.astype(int)
continental = counties[~counties["STATEFP"].isin([2, 15, 72])].set_index(["STATEFP", "COUNTYFP"])

Interpret each `county` ID as a combined FIPS code (state FIPS code × 1000 + county FIPS code), split it into those two parts, and set a matching MultiIndex.

In [28]:
yearly_mean = county_mean.mean("time")
yearly_mean.coords["STATEFP"] = (yearly_mean.county // 1000).astype(int)
yearly_mean.coords["COUNTYFP"] = np.mod(yearly_mean.county, 1000).astype(int)
yearly_mean = yearly_mean.drop_vars("county").set_index(county=["STATEFP", "COUNTYFP"])
yearly_mean
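The `// 1000` and modulo-1000 split above assumes each raster ID is a standard county FIPS code, i.e. the 2-digit state code followed by the 3-digit county code. A tiny sketch of that split using `divmod` (the helper name `split_fips` is hypothetical):

```python
def split_fips(fips: int) -> tuple[int, int]:
    """Split a combined county FIPS code (state * 1000 + county) into (state, county)."""
    return divmod(fips, 1000)


# Los Angeles County, CA is 06037; New York County, NY is 36061
for fips in (6037, 36061):
    state, county = split_fips(fips)
    print(f"{fips:05d} -> STATEFP={state:02d}, COUNTYFP={county:03d}")
```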

Join the yearly means onto the county geometries by their shared (STATEFP, COUNTYFP) index.

In [31]:
continental["zwattablrt"] = yearly_mean.to_dataframe()["zwattablrt"]
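Assigning a Series to a (Geo)DataFrame column aligns on index labels, not on position, so the matching (STATEFP, COUNTYFP) MultiIndex on both sides is what makes this join work; counties without a computed mean get NaN. A toy pandas illustration with made-up values:

```python
import numpy as np
import pandas as pd

names = ["STATEFP", "COUNTYFP"]

# Stand-in for the county table: LA County, New York County, Harris County (TX)
geoms = pd.DataFrame(
    {"NAME": ["Los Angeles", "New York", "Harris"]},
    index=pd.MultiIndex.from_tuples([(6, 37), (36, 61), (48, 201)], names=names),
)

# Made-up means in a different order, with no entry for Harris County
means = pd.Series(
    [1.5, 2.5],
    index=pd.MultiIndex.from_tuples([(36, 61), (6, 37)], names=names),
    name="zwattablrt",
)

geoms["zwattablrt"] = means  # aligned by (STATEFP, COUNTYFP) labels
print(geoms)
```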

Plot

In [34]:
continental.hvplot(
    c="zwattablrt",
    cmap='turbo_r',
    title="Mean Depth to Soil Saturation in 2020 by US County (meters)",
    xaxis=None,
    yaxis=None
)