# Interpolate a gridded field onto the drifter locations
- using the ERA5 datasets ("northward_wind_at_10_metres", "eastward_wind_at_10_metres") on AWS
- using the GDP datasets on AWS

In [1]:
import numpy as np
import dask
import s3fs
import xarray as xr
from tqdm import tqdm

In [2]:
# Module-level S3 filesystem handle, created with `anon=True` and the listings cache disabled.
# NOTE(review): no visible cell reads `s3_fs` (`s3open` builds its own filesystem) — confirm it is still needed.
s3_fs = s3fs.S3FileSystem(anon=True, use_listings_cache=False)

# *Lazy*-loading ERA5 Zarr archive
- based on this [example](https://nbviewer.org/github/awslabs/amazon-asdi/blob/main/examples/dask/notebooks/era5_zarr.ipynb)

In [3]:
def fix_accum_var_dims(ds, var):
    """Keep only ``var`` in ``ds`` and rename its time dimension to ``valid_time_end_utc``.

    Some variables like precip have extra time bounds variables; selecting the
    single variable drops them so the result can be merged with other variables.

    Parameters
    ----------
    ds : dataset-like
        Object supporting ``ds[[var]]`` selection and ``ds.rename(mapping)``.
    var : str
        Name of the variable to keep.

    Returns
    -------
    The selected dataset, with ``time0``/``time1`` renamed to
    ``valid_time_end_utc`` for the known variables. For any other variable a
    warning is printed and the selection is returned without renaming.
    """
    # Original time-dimension name of every variable seen so far.
    time_dim_by_var = {
        "air_temperature_at_2_metres": "time0",
        "dew_point_temperature_at_2_metres": "time0",
        "air_pressure_at_mean_sea_level": "time0",
        "northward_wind_at_10_metres": "time0",
        "eastward_wind_at_10_metres": "time0",
        "precipitation_amount_1hour_Accumulation": "time1",
        "integral_wrt_time_of_surface_direct_downwelling_shortwave_flux_in_air_1hour_Accumulation": "time1",
    }

    # Select variable of interest (drops dims that are not linked to current variable)
    ds = ds[[var]]

    time_dim = time_dim_by_var.get(var)
    if time_dim is None:
        print(
            "Warning, Haven't seen {var} variable yet! Time renaming might not work.".format(
                var=var
            )
        )
        return ds

    return ds.rename({time_dim: "valid_time_end_utc"})


@dask.delayed
def s3open(path):
    """Build an ``s3fs.S3Map`` for ``path`` on a newly constructed S3 filesystem.

    The function is wrapped by ``dask.delayed``; ``open_era5_range`` passes the
    wrapped results straight to ``xr.open_mfdataset``.
    """
    filesystem_options = dict(
        anon=True,
        default_fill_cache=False,
        config_kwargs={"max_pool_connections": 20},
    )
    store_fs = s3fs.S3FileSystem(**filesystem_options)
    return s3fs.S3Map(path, s3=store_fs)


def open_era5_range(start_year, end_year, variables):
    """Opens ERA5 monthly Zarr files in S3, given a start and end year (all months loaded) and a list of variables.

    Parameters
    ----------
    start_year, end_year : int
        First and last year to open, both inclusive; every month "01".."12"
        of each year is requested.
    variables : list of str
        ERA5 variable names; each is substituted into the ``{var}.zarr`` path
        and printed as it is processed.

    Returns
    -------
    The result of ``xr.merge`` over the per-variable datasets, each reduced to
    its own variable and time-renamed by ``fix_accum_var_dims``.
    """
    file_pattern = "era5-pds/zarr/{year}/{month}/data/{var}.zarr/"

    # Variables stored along "time1"; every other variable is concatenated
    # along "time0". This must list the same names that `fix_accum_var_dims`
    # renames from "time1", otherwise those variables are concatenated along
    # a dimension they do not have.
    time1_variables = (
        "precipitation_amount_1hour_Accumulation",
        "integral_wrt_time_of_surface_direct_downwelling_shortwave_flux_in_air_1hour_Accumulation",
    )

    years = range(start_year, end_year + 1)
    months = [f"{month:02d}" for month in range(1, 13)]

    datasets = []
    for var in variables:
        print(var)

        # One (delayed) store per year/month, in chronological order
        files_mapper = [
            s3open(file_pattern.format(year=year, month=month, var=var))
            for year in years
            for month in months
        ]

        # Look up correct time dimension by variable name
        concat_dim = "time1" if var in time1_variables else "time0"

        # Lazy load
        ds = xr.open_mfdataset(
            files_mapper,
            engine="zarr",
            concat_dim=concat_dim,
            combine="nested",
            coords="minimal",
            compat="override",
            parallel=True,
        )

        # Fix dimension names so all variables share one time dimension before merging
        datasets.append(fix_accum_var_dims(ds, var))

    return xr.merge(datasets)

NameError: name 'dask' is not defined

In [4]:
%%time

# Open the full 1979-2020 range, restricted to the two 10 m wind components
ds_era = open_era5_range(
    start_year=1979,
    end_year=2020,
    variables=[  # only keep the required variables
        "northward_wind_at_10_metres",
        "eastward_wind_at_10_metres",
    ],
)

northward_wind_at_10_metres
eastward_wind_at_10_metres
CPU times: user 24.9 s, sys: 3.18 s, total: 28 s
Wall time: 1min 29s


In [5]:
# Display the dataset summary (dimensions, chunking, size) of the lazily opened ERA5 archive
ds_era

Unnamed: 0,Array,Chunk
Bytes,1.39 TiB,31.93 MiB
Shape,"(368184, 721, 1440)","(372, 150, 150)"
Dask graph,50400 chunks in 1009 graph layers,50400 chunks in 1009 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.39 TiB 31.93 MiB Shape (368184, 721, 1440) (372, 150, 150) Dask graph 50400 chunks in 1009 graph layers Data type float32 numpy.ndarray",1440  721  368184,

Unnamed: 0,Array,Chunk
Bytes,1.39 TiB,31.93 MiB
Shape,"(368184, 721, 1440)","(372, 150, 150)"
Dask graph,50400 chunks in 1009 graph layers,50400 chunks in 1009 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.39 TiB,31.93 MiB
Shape,"(368184, 721, 1440)","(372, 150, 150)"
Dask graph,50400 chunks in 1009 graph layers,50400 chunks in 1009 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.39 TiB 31.93 MiB Shape (368184, 721, 1440) (372, 150, 150) Dask graph 50400 chunks in 1009 graph layers Data type float32 numpy.ndarray",1440  721  368184,

Unnamed: 0,Array,Chunk
Bytes,1.39 TiB,31.93 MiB
Shape,"(368184, 721, 1440)","(372, 150, 150)"
Dask graph,50400 chunks in 1009 graph layers,50400 chunks in 1009 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


The total dataset is pretty large, but it is not currently loaded in memory. Once we start the interpolation process, the required "chunks" (or parts of the Dataset) will be downloaded from AWS and loaded into memory automatically.

In [6]:
# Total size of the dataset in terabytes (bytes / 1e12)
print(f"ds size in TB {ds_era.nbytes / 1e12:0.2f}\n")

ds size in TB 3.06



# The GDP dataset is also available on AWS ([link](https://registry.opendata.aws/noaa-oar-hourly-gdp/))

In [7]:
# Base URL and file name of the hourly GDP NetCDF file on AWS
url_path = "https://noaa-oar-hourly-gdp-pds.s3.us-east-1.amazonaws.com/latest/"
file = "gdp_v2.00.nc"

# NOTE(review): the "#mode=bytes" suffix presumably enables remote byte-range
# reads of the NetCDF file over HTTP — confirm against the netCDF documentation.
ds = xr.open_dataset(f"{url_path}{file}#mode=bytes")

In [8]:
# Display the summary of the GDP dataset opened above
ds

We create a dataset that will receive the interpolated data. This requires downloading the coordinates of the GDP dataset (`ids`, `time`, `lat`, `lon`), so it can take some time (~3 min with a 200 Mbit/s connection).

In [9]:
# Number of observations; `Dataset.sizes` is the dimension-name -> length mapping.
n_obs = ds.sizes["obs"]

# Every variable is given the explicit "obs" dimension. A bare 1-D array is
# otherwise placed on a dimension named after the variable itself, which turns
# `u10`/`v10` into dimension coordinates instead of data variables and leaves
# each coordinate on its own unrelated dimension.
# NOTE(review): assumes `ids`, `lon`, `lat` and `time` are all 1-D along "obs"
# (the later cells slice `lon`/`lat`/`time` with observation offsets) — confirm for `ids`.
ds_int = xr.Dataset(
    data_vars=dict(
        u10=("obs", np.zeros(n_obs)),
        v10=("obs", np.zeros(n_obs)),
    ),
    coords=dict(
        ids=("obs", ds.ids.values),
        longitude=("obs", ds.lon.values),
        latitude=("obs", ds.lat.values),
        time=("obs", ds.time.values),
    ),
)

## Interpolating the 10m eastward and northward winds
- we use the Xarray [Advanced Interpolation](https://docs.xarray.dev/en/stable/user-guide/interpolation.html#advanced-interpolation) methodology. Otherwise, interpolating using `n` longitudes and `n` latitudes will result in an `n x n` interpolated DataArray. The idea is simple: we define a new dimension `z` along the trajectory described by the `lon` and `lat` coordinates, and the interpolation is then performed along this new dimension.

### a single trajectory

In [10]:
# Start offset of every trajectory in the flat observation arrays: a leading 0
# followed by the running total of `rowsize`, so trajectory i spans
# [traj_idx[i], traj_idx[i + 1]).
traj_idx = np.cumsum(np.insert(ds["rowsize"].values, 0, 0))

In [11]:
%%time
# Interpolate the ERA5 winds along the first trajectory only
i = 0
s_i = slice(traj_idx[i], traj_idx[i + 1])

# All three sample coordinates share the dimension "z", so `interp` returns one
# value per observation rather than a lon x lat x time cube.
lon_i, lat_i, time_i = (
    xr.DataArray(samples, dims="z")
    for samples in (
        ds.lon[s_i].values % 360,  # ERA5 is 0-360
        ds.lat[s_i].values,
        ds.time[s_i].values,
    )
)

era_int = ds_era.interp(
    lon=lon_i,
    lat=lat_i,
    valid_time_end_utc=time_i,
).compute()

# Store the interpolated winds at this trajectory's positions in the output arrays
ds_int["u10"].values[s_i] = era_int["eastward_wind_at_10_metres"]
ds_int["v10"].values[s_i] = era_int["northward_wind_at_10_metres"]

CPU times: user 2.79 s, sys: 811 ms, total: 3.6 s
Wall time: 12 s


In [12]:
# Display the interpolated winds computed for the single trajectory above
era_int

In [13]:
# validate the interpolated data coordinates matches the trajectory
np.testing.assert_allclose(ds.lon[s_i] % 360, era_int.lon)
np.testing.assert_allclose(ds.lat[s_i] % 360, era_int.lat)
assert len(ds.lon[s_i]) == era_int.dims["z"]

### looping a few trajectories (10s to 3min (!) per trajectory)

Ideas to speed up the process:
- run this directly on AWS
- if running on a cluster or locally, it might be better to perform the loop in parallel so more calculations can be performed while downloading

In [15]:
%%time

# Interpolate the first 100 trajectories (or all of them if there are fewer;
# `traj_idx` holds one more entry than there are trajectories).
for i in tqdm(range(min(100, len(traj_idx) - 1))):
    s_i = slice(traj_idx[i], traj_idx[i + 1])
    lon_i = xr.DataArray(ds.lon[s_i].values % 360, dims="z")  # ERA5 is 0-360
    lat_i = xr.DataArray(ds.lat[s_i].values, dims="z")
    time_i = xr.DataArray(ds.time[s_i].values, dims="z")
    # Compute both wind components in a single pass, as in the single-trajectory
    # cell; without `.compute()` each assignment below triggers its own
    # evaluation of the lazy interpolation.
    era_int = ds_era.interp(
        lon=lon_i, lat=lat_i, valid_time_end_utc=time_i
    ).compute()
    ds_int["u10"].values[s_i] = era_int["eastward_wind_at_10_metres"]
    ds_int["v10"].values[s_i] = era_int["northward_wind_at_10_metres"]

 13%|█▎        | 13/100 [08:34<57:26, 39.61s/it]  


KeyboardInterrupt: 