<img src="http://xarray.pydata.org/en/stable/_static/dataset-diagram-logo.png" align="right" width="30%">

# Grouped Computations with Xarray

In this lesson, we discuss how to do scientific computations with defined "groups" of data
within our xarray objects. Our learning goals are as follows:

- Perform "split / apply / combine" workflows in Xarray using `groupby`,
  including
  - reductions within groups
  - transformations on groups
- Use the `resample`, `rolling`, and `coarsen` methods to aggregate and smooth data.


In [None]:
import numpy as np
import xarray as xr
from matplotlib import pyplot as plt

## Example Dataset

First we load a dataset. We will use the
[NOAA Extended Reconstructed Sea Surface Temperature (ERSST) v5](https://www.ncdc.noaa.gov/data-access/marineocean-data/extended-reconstructed-sea-surface-temperature-ersst-v5)
product, a widely used and trusted gridded compilation of historical data
going back to 1854.

Since the data is provided via an
[OPeNDAP](https://en.wikipedia.org/wiki/OPeNDAP) server, we can load it directly
without downloading anything:


In [None]:
### NOTE: If hundreds of people connect to this server at once and download the same dataset,
###       things might not go so well! We recommend using the Google Cloud copy below instead.

# url = "http://www.esrl.noaa.gov/psd/thredds/dodsC/Datasets/noaa.ersst.v5/sst.mnmean.nc"
# # drop an unnecessary variable which complicates some operations
# ds = xr.open_dataset(url, drop_variables=["time_bnds"])
# # will take a minute or two to complete
# ds = ds.sel(time=slice("1960", "2018")).load()
# ds

In [None]:
import gcsfs

fs = gcsfs.GCSFileSystem(token="anon")
ds = xr.open_zarr(
    fs.get_mapper("gs://pangeo-noaa-ncei/noaa.ersst.v5.zarr"), consolidated=True
).load()
ds

Let's do some basic visualizations of the data, just to make sure it looks
reasonable.


In [None]:
ds.sst[0].plot(vmin=-2, vmax=30)

## Groupby

Xarray copies Pandas' very useful groupby functionality, enabling the "split /
apply / combine" workflow on xarray DataArrays and Datasets.
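
Before turning to real data, here is a minimal sketch of the three steps on a hypothetical four-element array. The `letter` coordinate and the values are made up purely for illustration: the data are split by the unique values of `letter`, `sum` is applied to each group, and the per-group results are combined into a new array indexed by `letter`.

```python
import xarray as xr

# Hypothetical toy data: four values labelled by a non-dimension coordinate "letter"
da = xr.DataArray(
    [1.0, 2.0, 3.0, 4.0],
    dims="x",
    coords={"letter": ("x", ["a", "b", "a", "b"])},
)

# split on the unique values of "letter", apply sum() to each group,
# and combine the results into a new DataArray with a "letter" dimension
totals = da.groupby("letter").sum()
print(totals["letter"].values.tolist())  # ['a', 'b']
print(totals.values.tolist())  # [4.0, 6.0]
```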

To provide a physically motivated example, let's examine a timeseries of SST at
a single point.


In [None]:
ds.sst.sel(lon=300, lat=50).plot()

As we can see from the plot, the timeseries at any one point is totally
dominated by the seasonal cycle. We would like to remove this seasonal cycle
(called the "climatology") in order to better see the long-term variations in
temperature. We can accomplish this using **groupby**.

Before moving forward, we note that xarray correctly parsed the time index,
resulting in a Pandas datetime index on the time dimension.


In [None]:
ds.time

The syntax of Xarray's groupby is almost identical to that of Pandas.


In [None]:
?ds.groupby

### Split Step

The most important argument is `group`: this defines the unique values we will
use to "split" the data for grouped analysis. We can pass either a DataArray or the
name of a variable in the dataset. Let's first use a DataArray. Just like with
Pandas, we can use the time index to extract specific components of dates and
times. Xarray uses a special syntax for this, `.dt`, called the
`DatetimeAccessor`.
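
To see what the accessor returns without downloading anything, the sketch below builds a hypothetical 14-month time axis with pandas (the `"MS"` month-start frequency and the values are assumptions for illustration). Each `.dt` component is itself a DataArray along `time`, which is why it can be passed straight to `groupby`.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical monthly series: 14 month-start timestamps beginning January 2000
time = pd.date_range("2000-01-01", periods=14, freq="MS")
da = xr.DataArray(np.arange(14.0), dims="time", coords={"time": time})

# .dt extracts datetime components as new DataArrays along "time"
months = da.time.dt.month
years = da.time.dt.year
print(months.values.tolist())  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2]
print(years.values.tolist())  # twelve 2000s followed by two 2001s

# passing such an array to groupby yields one group per unique month
n_groups = sum(1 for _ in da.groupby(months))
print(n_groups)  # 12
```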


In [None]:
ds.time.dt

In [None]:
ds.time.dt.month

In [None]:
ds.time.dt.year

We can use these arrays in a groupby operation:


In [None]:
gb = ds.groupby(ds.time.dt.month)
gb

Xarray also offers a more concise syntax when the variable you're grouping on is
already present in the dataset: the string `"time.month"` refers to the month
component of the `time` coordinate. This is equivalent to the previous cell:


In [None]:
gb = ds.groupby("time.month")
gb

Now that the data are split, we can manually iterate over the groups. The
iterator returns the key (group name) and the value (the actual dataset
corresponding to that group) for each group.
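
As a small illustration of what the iterator yields, the sketch below uses a hypothetical three-year monthly dataset whose values are simply 0 to 35. Every calendar month should appear once as a key, and each sub-dataset should hold that month's three timestamps.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical dataset: three years of monthly values (0, 1, ..., 35)
time = pd.date_range("2000-01-01", periods=36, freq="MS")
ds = xr.Dataset({"sst": ("time", np.arange(36.0))}, coords={"time": time})

# each iteration yields (group key, sub-Dataset containing only that group's times)
sizes = {}
january = None
for month, group_ds in ds.groupby("time.month"):
    sizes[int(month)] = group_ds.sizes["time"]
    if month == 1:
        january = group_ds.sst.values.tolist()

print(len(sizes), sizes[1])  # 12 3
print(january)  # [0.0, 12.0, 24.0]
```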


In [None]:
for group_name, group_ds in gb:
    # stop iterating after the first loop
    break
print(group_name)
group_ds

### Apply & Combine

Now that we have groups defined, it's time to "apply" a calculation to the
group. Like in Pandas, these calculations can either be:

- _aggregation_: reduces the size of the group
- _transformation_: preserves the group's full size

At the end of the apply step, xarray will automatically combine the aggregated
/ transformed groups back into a single object.
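
The difference between the two kinds of calculation shows up in the shape of the combined result. This sketch uses hypothetical random data at two grid points and the `.map` method described next; the lambdas are illustrative stand-ins for real analysis functions.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical data: three years of monthly values at two grid points
time = pd.date_range("2000-01-01", periods=36, freq="MS")
rng = np.random.default_rng(0)
da = xr.DataArray(
    rng.normal(size=(36, 2)), dims=("time", "x"), coords={"time": time}
)
gb = da.groupby("time.month")

# aggregation: "time" collapses to one value per month
agg = gb.map(lambda g: g.mean(dim="time"))
# transformation: every input value survives, so the result keeps all 36 times
trans = gb.map(lambda g: g - g.mean(dim="time"))

print(agg.sizes["month"], agg.sizes["x"])  # 12 2
print(trans.sizes["time"], trans.sizes["x"])  # 36 2
```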

The most fundamental way to apply is with the `.map` method.


In [None]:
?gb.map

#### Aggregations

`.map` accepts as its argument a function that expects and returns xarray
objects. We define a custom function. This function takes a single argument, the
group dataset, and returns a new dataset to be combined:


In [None]:
def time_mean(a):
    return a.mean(dim="time")


gb.map(time_mean)

Like Pandas, xarray's groupby object has many built-in aggregation operations
(e.g. `mean`, `min`, `max`, `std`, etc.):


In [None]:
# this does the same thing as the previous cell
ds_mm = gb.mean(dim="time")
ds_mm
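
To check the claim that the built-in method matches the custom `.map` call, the sketch below runs both on a hypothetical three-year monthly dataset (values 0 to 35) and compares them. The January climatology there is the mean of 0, 12 and 24.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical dataset: three years of monthly values (0, 1, ..., 35)
time = pd.date_range("2000-01-01", periods=36, freq="MS")
ds = xr.Dataset({"sst": ("time", np.arange(36.0))}, coords={"time": time})


def time_mean(a):
    return a.mean(dim="time")


gb = ds.groupby("time.month")
via_map = gb.map(time_mean)
via_builtin = gb.mean(dim="time")

# both routes give the same 12-month climatology
np.testing.assert_allclose(via_map.sst.values, via_builtin.sst.values)
print(via_builtin.sst.sel(month=1).item())  # 12.0
```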

So we did what we wanted to do: calculate the climatology at every point in the
dataset. Let's look at the data a bit.

_Climatology at a specific point in the North Atlantic_


In [None]:
ds_mm.sst.sel(lon=300, lat=50).plot()

_Zonal Mean Climatology_


In [None]:
ds_mm.sst.mean(dim="lon").plot.contourf(x="month", levels=12, vmin=-2, vmax=30)

_Difference between January and July Climatology_


In [None]:
(ds_mm.sst.sel(month=1) - ds_mm.sst.sel(month=7)).plot(vmax=10)

#### Transformations

Now we want to _remove_ this climatology from the dataset, to examine the
residual, called the _anomaly_, which is the interesting part from a climate
perspective. Removing the seasonal climatology is a perfect example of a
transformation: it operates over a group, but doesn't change the size of the
dataset. Here is one way to code it:


In [None]:
def remove_time_mean(x):
    return x - x.mean(dim="time")


ds_anom = ds.groupby("time.month").map(remove_time_mean)
ds_anom

Xarray makes these sorts of transformations easy by supporting _groupby
arithmetic_. This concept is most easily explained with an example:


In [None]:
gb = ds.groupby("time.month")
ds_anom = gb - gb.mean(dim="time")
ds_anom
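
A quick sanity check on hypothetical data (three years of monthly values 0 to 35) suggests the `.map` version and the groupby-arithmetic version agree, and that each calendar month's anomalies average to zero by construction. The `reshape(3, 12)` trick only works because this synthetic series is exactly three complete years.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical dataset: three years of monthly values (0, 1, ..., 35)
time = pd.date_range("2000-01-01", periods=36, freq="MS")
ds = xr.Dataset({"sst": ("time", np.arange(36.0))}, coords={"time": time})


def remove_time_mean(x):
    return x - x.mean(dim="time")


gb = ds.groupby("time.month")
anom_map = gb.map(remove_time_mean)
anom_arith = gb - gb.mean(dim="time")

# the two approaches agree value for value along the original time axis
np.testing.assert_allclose(anom_map.sst.values, anom_arith.sst.values)

# each calendar month's anomalies average to zero
# (reshape works here because the series is exactly 3 years x 12 months)
monthly_means = anom_arith.sst.values.reshape(3, 12).mean(axis=0)
print(np.allclose(monthly_means, 0.0))  # True
```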

Now we can view the climate signal without the overwhelming influence of the
seasonal cycle.

_Timeseries at a single point in the North Atlantic_


In [None]:
ds_anom.sst.sel(lon=300, lat=50).plot()

_Difference between Jan. 1 2018 and Jan. 1 1960_


In [None]:
(ds_anom.sel(time="2018-01-01") - ds_anom.sel(time="1960-01-01")).sst.plot()

## Groupby-Related: Resample, Rolling, Coarsen

`resample` in xarray works nearly identically to Pandas. It is effectively a
groupby operation over time bins, and uses the same basic syntax. It can only be
applied to dimensions indexed by time. Here we compute five-year means.
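
On a small hypothetical series the binning is easy to verify by hand. This sketch uses the `"YS"` (year-start) alias so that each bin is one calendar year of 12 months. Frequency strings follow pandas' offset aliases, and recent pandas releases have renamed some of them (for example, year-end `"Y"` became `"YE"`), so an alias like the `"5Y"` used below may emit a deprecation warning depending on your pandas version.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical series: three years of monthly values (0, 1, ..., 35)
time = pd.date_range("2000-01-01", periods=36, freq="MS")
da = xr.DataArray(np.arange(36.0), dims="time", coords={"time": time})

# "YS" (year start) bins the series by calendar year; each bin holds 12 months here
annual = da.resample(time="YS").mean()
print(annual.values.tolist())  # [5.5, 17.5, 29.5]
print(annual.time.dt.year.values.tolist())  # [2000, 2001, 2002]
```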


In [None]:
resample_obj = ds_anom.resample(time="5Y")
resample_obj

In [None]:
ds_anom_resample = resample_obj.mean(dim="time")
ds_anom_resample

In [None]:
ds_anom.sst.sel(lon=300, lat=50).plot()
ds_anom_resample.sst.sel(lon=300, lat=50).plot(marker="o")

<div class="alert alert-info">
    <strong>Note:</strong> <code>resample</code> only works with proper datetime indexes.
</div>

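`rolling` is also similar to Pandas, but can be applied along any dimension. It
works with logical coordinates: the window is a fixed number of consecutive
items along the dimension, not a span of coordinate values.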
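
A minimal sketch on a hypothetical six-element array along a plain dimension `x` shows both points: rolling is not tied to time, and the window counts items. With the default `min_periods` (equal to the window size), windows that run off either end produce NaN; `min_periods=1` averages whatever items are available instead.

```python
import numpy as np
import xarray as xr

# a plain integer-indexed dimension "x": rolling is not limited to time
da = xr.DataArray(np.arange(6.0), dims="x")

# centred window of 3 consecutive items; windows that run off either end are NaN
smoothed = da.rolling(x=3, center=True).mean()
print(smoothed.values.tolist())  # [nan, 1.0, 2.0, 3.0, 4.0, nan]

# min_periods=1 keeps the edges by averaging whatever items are available
edges_kept = da.rolling(x=3, center=True, min_periods=1).mean()
print(edges_kept.values.tolist())  # [0.5, 1.0, 2.0, 3.0, 4.0, 4.5]
```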


In [None]:
ds_anom_rolling = ds_anom.rolling(time=12, center=True).mean()
ds_anom_rolling

In [None]:
ds_anom.sst.sel(lon=300, lat=50).plot(label="monthly anom")
ds_anom_resample.sst.sel(lon=300, lat=50).plot(marker="o", label="5 year resample")
ds_anom_rolling.sst.sel(lon=300, lat=50).plot(label="12 month rolling mean")
plt.legend()

`coarsen` does something similar to `resample`, but without being aware of time.
It operates on logical coordinates only, reducing non-overlapping blocks of a
fixed number of items, and it can work on multiple dimensions at once.
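
The sketch below coarsens a hypothetical 4 x 4 grid in two dimensions at once, averaging non-overlapping 2 x 2 blocks. By default the coordinate values are averaged as well, so the new `lat`/`lon` labels sit at the block centres.

```python
import numpy as np
import xarray as xr

# Hypothetical 4 x 4 grid with coordinate values on both dimensions
da = xr.DataArray(
    np.arange(16.0).reshape(4, 4),
    dims=("lat", "lon"),
    coords={"lat": [-3.0, -1.0, 1.0, 3.0], "lon": [0.0, 2.0, 4.0, 6.0]},
)

# average non-overlapping 2 x 2 blocks; coordinates are averaged too by default
coarse = da.coarsen(lat=2, lon=2).mean()
print(coarse.values.tolist())  # [[2.5, 4.5], [10.5, 12.5]]
print(coarse.lat.values.tolist(), coarse.lon.values.tolist())  # [-2.0, 2.0] [1.0, 5.0]
```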


In [None]:
ds_anom_coarsen_time = ds_anom.coarsen(time=12).mean()

ds_anom_rolling.sst.sel(lon=300, lat=50).plot(label="12 month rolling mean")
ds_anom_coarsen_time.sst.sel(lon=300, lat=50).plot(marker="^", label="12 item coarsen")
plt.legend()

In [None]:
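# We expect an error here: by default, coarsen requires each dimension's length
# to be a multiple of its window, and the length of lat is not a multiple of 4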
ds_anom_coarsen_space = ds_anom.coarsen(lon=4, lat=4).mean()

In [None]:
# drop the last latitude row so that the lat length divides evenly by 4
ds_anom_coarsen_space = ds_anom.isel(lat=slice(0, -1)).coarsen(lon=4, lat=4).mean()
ds_anom_coarsen_space
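
Instead of slicing by hand, `coarsen` also accepts a `boundary` argument. The sketch below reproduces the error on a hypothetical length-5 dimension and then uses `boundary="trim"`, which (with the default `side="left"`) drops the leftover items at the end. When exactly one item is left over, as the slicing above assumes, this should drop the same trailing row.

```python
import numpy as np
import xarray as xr

# Hypothetical dimension of length 5, which a window of 2 does not divide evenly
da = xr.DataArray(np.arange(5.0), dims="lat")

raised = False
try:
    da.coarsen(lat=2).mean()  # default boundary="exact" refuses the leftover item
except ValueError as err:
    raised = True
    print("ValueError:", err)

# boundary="trim" drops the leftover item at the end instead
trimmed = da.coarsen(lat=2, boundary="trim").mean()
print(trimmed.values.tolist())  # [0.5, 2.5]
```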

In [None]:
ds_anom_coarsen_space.sst.isel(time=0).plot()

## Exercise

Load the following "basin mask" dataset, and use it to take a weighted average
of SST in each ocean basin. Figure out which ocean basins are the warmest and
coldest.

**Hint:** you will first need to align this dataset with the SST dataset. Use
what you learned in the "indexing and alignment" lesson.
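
One possible shape of the weighted-average step is sketched below on a made-up 2 x 3 grid with a made-up integer mask; it is not the real basin dataset, and it assumes the mask is already aligned with the SST grid. The cos(latitude) area weighting is an assumption of this sketch (a common approximation for a regular lat/lon grid). The mask is passed as a 2-D DataArray to `groupby`, and the weighted mean per basin is sum(w * sst) / sum(w).

```python
import numpy as np
import xarray as xr

# Hypothetical 2 x 3 SST field and an integer basin mask on the same grid
coords = {"lat": [0.0, 60.0], "lon": [0.0, 1.0, 2.0]}
sst = xr.DataArray(
    [[20.0, 22.0, 24.0], [26.0, 28.0, 30.0]], dims=("lat", "lon"), coords=coords
)
basin = xr.DataArray(
    [[1, 1, 2], [1, 2, 2]], dims=("lat", "lon"), coords=coords, name="basin"
)

# assumption: grid-cell area on a regular lat/lon grid scales with cos(latitude)
weights = np.cos(np.deg2rad(sst.lat)).broadcast_like(sst)
# mask the weights wherever SST is missing (e.g. land) so those cells are skipped
weights = weights.where(sst.notnull())

# weighted mean per basin = sum(w * sst) / sum(w), each summed within a basin
basin_mean = (sst * weights).groupby(basin).sum() / weights.groupby(basin).sum()
print(basin_mean["basin"].values.tolist())  # [1, 2]
print(basin_mean.round(2).values.tolist())  # [22.0, 26.5]
```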


In [None]:
basin = xr.open_dataset(
    "http://iridl.ldeo.columbia.edu/SOURCES/.NOAA/.NODC/.WOA09/.Masks/.basin/dods"
)
basin