This notebook interpolates variables from gridded satellite datasets to the space/time locations of drifter observations.

Satellite and drifter data for the period 2010–2020 are:

- Geostrophic currents from DUACS (https://doi.org/10.48670/moi-00148),
- Stokes drift from WAVERYS/MFWAM (https://doi.org/10.48670/moi-00022),
- Wind stress and 10 m velocity from ERA5 (https://doi.org/10.48670/moi-00185),
- Drogued-SVP drifters from the GDP (https://doi.org/10.25921/x46c-3620).

In [None]:
import time

import dask
dask.config.set({"array.slicing.split_large_chunks": True})
import numpy as np
from tqdm import tqdm
import xarray as xr

# Interpolation

In [None]:
# Date stamps embedded in the names of the zarr stores opened below; they
# identify which files to read, not the period that is actually processed.
start_datetime_str, end_datetime_str = "1994-06-01", "2025-08-01"
# NOTE(review): not referenced by any visible cell -- the stores are opened
# with relative paths, so presumably the notebook runs from this directory.
output_directory = "/summer/meom/workdir/bertrava/data"

In [None]:
# Optional bounds of the period that is actually processed. The drifter
# dataset is only restricted in time when *both* bounds are set; leaving
# them as None keeps every observation.
start_datetime = end_datetime = None

In [None]:
# Open the drifter observations and, when both bounds are given, keep only
# the points inside the half-open window [start_datetime, end_datetime).
gdp_ds = xr.open_zarr(f"gdp-v2.01.1_{start_datetime_str}_{end_datetime_str}.zarr")
if start_datetime is not None and end_datetime is not None:
    # Evaluate the mask eagerly: boolean indexing needs concrete values.
    # Assumes `time` is 1-D along `points`, as the rest of the notebook does.
    in_window = ((gdp_ds.time >= start_datetime) & (gdp_ds.time < end_datetime)).values
    # Positional boolean selection rather than `where(..., drop=True)`:
    # `where` fills with NaN, which promotes integer variables to float and
    # broadcasts the mask onto variables that lack the `points` dimension.
    gdp_ds = gdp_ds.isel(points=in_window)

In [None]:
# From here on the bounds describe the period really covered by the drifter
# observations (overwriting any user-provided values). One day is added to
# the last observation time to form the upper bound.
_drifter_times = gdp_ds.time
start_datetime = _drifter_times.min().values
end_datetime = _drifter_times.max().values + np.timedelta64(1, "D")

In [None]:
# Every gridded product is cropped to the period covered by the drifters.
_period = slice(start_datetime, end_datetime)

# Geostrophic currents (file name indicates a daily, 0.125 deg product).
duacs_ds = xr.open_zarr(
    f"cmems_obs-sl_glo_phy-ssh_my_allsat-l4-duacs-0.125deg_P1D_{start_datetime_str}_{end_datetime_str}.zarr"
).sel(time=_period)

# Stokes drift (file name indicates a 3-hourly, 0.2 deg product).
waverys_ds = xr.open_zarr(
    f"cmems_mod_glo_wav_my_0.2deg_PT3H-i_{start_datetime_str}_{end_datetime_str}.zarr"
).sel(time=_period)

# Wind products: the 0.125 deg store is used after 2008-01-01 and the
# 0.25 deg store before. `era5_dss` lists the stores needed for the period,
# the post-2008 one first when both are required.
_era5_switch = np.datetime64("2008-01-01")
_ends_after_switch = end_datetime > _era5_switch
_starts_after_switch = start_datetime > _era5_switch

if _ends_after_switch and _starts_after_switch:
    # Period entirely after the switch: high-resolution store only.
    era5_ds = xr.open_zarr(
        f"cmems_obs-wind_glo_phy_my_l4_0.125deg_PT1H_{start_datetime_str}_{end_datetime_str}.zarr"
    ).sel(time=_period)
    era5_dss = [era5_ds]
elif _ends_after_switch:
    # Period straddling the switch: both stores are needed.
    era5_ds1 = xr.open_zarr(
        f"cmems_obs-wind_glo_phy_my_l4_0.125deg_PT1H_2008-01-01_{end_datetime_str}.zarr"
    ).sel(time=_period)
    era5_ds2 = xr.open_zarr(
        f"cmems_obs-wind_glo_phy_my_l4_0.25deg_PT1H_{start_datetime_str}_2008-01-01.zarr"
    ).sel(time=_period)
    era5_dss = [era5_ds1, era5_ds2]
else:
    # Period ending no later than the switch: low-resolution store only.
    era5_ds = xr.open_zarr(
        f"cmems_obs-wind_glo_phy_my_l4_0.25deg_PT1H_{start_datetime_str}_{end_datetime_str}.zarr"
    ).sel(time=_period)
    era5_dss = [era5_ds]

In [None]:
# Sort the observations by time so that each batch processed below spans a
# short period, which eases the sequential interpolation.
# NOTE(review): the default argsort is not stable; the order of points with
# identical times must stay the same between runs that share an output
# store, so do not change the sort kind while a run is being resumed.
_time_order = np.argsort(gdp_ds.time.values)
gdp_ds = gdp_ds.isel(points=_time_order)

# Relabel the points 0..N-1 in their new (time-sorted) order.
N = int(gdp_ds.sizes["points"])
gdp_ds = gdp_ds.assign_coords(points=np.arange(N, dtype=np.int32))

In [None]:
# Temporary zarr store receiving the interpolated variables, and the chunk
# length used along the `points` dimension.
out_store, chunk_size = "gdp_interp_tmp.zarr", 500_000

In [None]:
# Rechunk along `points` and drop the stored encoding of the dataset.
gdp_ds = gdp_ds.chunk(points=chunk_size)
gdp_ds = gdp_ds.drop_encoding()
# Uncomment for a first run, to write the drifter data to the output store:
# gdp_ds.to_zarr(out_store, consolidated=False)

In [None]:
# Interpolated variables grouped by the gridded dataset providing them.
# The wind variables take their metadata from the first ERA5 dataset.
_variables_by_source = (
    (duacs_ds, ("ugos", "vgos")),
    (waverys_ds, ("VSDX", "VSDY")),
    (
        era5_dss[0],
        ("eastward_stress", "northward_stress", "eastward_wind", "northward_wind"),
    ),
)

# Variable name -> source dataset, and the names in writing order.
vars_to_ds = {
    var_name: source_ds
    for source_ds, var_names in _variables_by_source
    for var_name in var_names
}
vars_to_write = list(vars_to_ds)

# NaN-filled float32 placeholders, one per variable, defined on `points`
# and carrying the attributes of the corresponding source variable.
template = xr.Dataset(
    data_vars={
        var_name: (
            ("points",),
            np.full(N, np.nan, dtype=np.float32),
            vars_to_ds[var_name][var_name].attrs,
        )
        for var_name in vars_to_write
    },
    coords={"points": gdp_ds.points},
)
template = template.chunk({"points": chunk_size})

# Uncomment for a first run, to add the empty variables to the output store:
# template.to_zarr(out_store, mode="a", consolidated=False)

In [None]:
# Margins (in degrees) added around the bounding box of each batch before
# cropping the gridded datasets.
buf_lat = 0.25
buf_lon = 0.25

batch_size = 50_000  # should fit in memory and aligned with `chunk_size`
start_i = 68350000  # for resuming interrupted runs (0 processes every point)


def _retry_on_oserror(action, failure_message, delay=1):
    """Call ``action()`` until it no longer raises ``OSError``; return its result.

    The number of attempts is deliberately unbounded. Each failure is
    reported together with the caught error, then the call is retried after
    ``delay`` seconds. Any other exception propagates.
    """
    while True:
        try:
            return action()
        except OSError as err:
            print(failure_message, repr(err))
            time.sleep(delay)


def _crop_to_batch(ds, time_bounds, time_pad, lat_bounds, lon_bounds):
    """Restrict ``ds`` to the padded space/time bounding box of a batch."""
    t_first, t_last = time_bounds
    ds = ds.sel(time=slice(t_first - time_pad, t_last + time_pad))
    ds = ds.sel(longitude=slice(*lon_bounds))
    return ds.sel(latitude=slice(*lat_bounds))


def _interp_at(ds, at_time, at_lat, at_lon, label):
    """Interpolate ``ds`` at the given points and load the result in memory."""
    return _retry_on_oserror(
        lambda: ds.interp(time=at_time, latitude=at_lat, longitude=at_lon).compute(),
        f"OSError during {label} interpolation",
    )


def _write_region(obs, region, label):
    """Write ``obs`` into the ``region`` slice of the ``out_store`` zarr store."""
    # The interpolation coordinates are not variables of the output region.
    obs = obs.drop_vars(["time", "latitude", "longitude"])
    _retry_on_oserror(
        lambda: obs.to_zarr(out_store, region=region, consolidated=False),
        f"OSError while writing {label} to_zarr",
    )


for i in tqdm(range(start_i, N, batch_size)):
    batch = gdp_ds.isel(points=slice(i, i + batch_size))

    batch_time = batch.time
    batch_lat = batch.lat
    batch_lon = batch.lon

    # matters for the last batch, which may hold fewer than `batch_size` points
    bsz = batch_time.shape[0]
    region = {"points": slice(i, i + bsz)}

    print("Processing batch:", i, "to", i + bsz, "with", bsz, "points")

    time_min = batch_time.min().values
    time_max = batch_time.max().values
    time_bounds = (time_min, time_max)
    lat_bounds = (
        float(np.nanmin(batch_lat)) - buf_lat,
        float(np.nanmax(batch_lat)) + buf_lat,
    )
    lon_bounds = (
        float(np.nanmin(batch_lon)) - buf_lon,
        float(np.nanmax(batch_lon)) + buf_lon,
    )

    # ------------------ DUACS (daily) ------------------
    print("DUACS interpolation...")
    ds = _crop_to_batch(duacs_ds, time_bounds, np.timedelta64(1, "D"), lat_bounds, lon_bounds)
    duacs_obs = _interp_at(ds[["ugos", "vgos"]], batch_time, batch_lat, batch_lon, "DUACS")
    _write_region(duacs_obs, region, "DUACS")

    del ds, duacs_obs

    # ------------------ WAVERYS (3-hourly) ------------------
    print("WAVERYS interpolation...")
    ds = _crop_to_batch(waverys_ds, time_bounds, np.timedelta64(3, "h"), lat_bounds, lon_bounds)

    # remove duplicate times (only the first occurrence is kept)
    ds = ds.isel(time=~ds.time.to_index().duplicated())

    waverys_obs = _interp_at(ds[["VSDX", "VSDY"]], batch_time, batch_lat, batch_lon, "WAVERYS")
    _write_region(waverys_obs, region, "WAVERYS")

    del ds, waverys_obs

    # ------------------ ERA5 (hourly) ------------------
    print("ERA5 interpolation...")

    is_after_2008 = batch_time.values >= np.datetime64("2008-01-01")
    is_before_2008 = ~is_after_2008

    if is_before_2008.all():
        # Batches lying entirely before 2008 get no ERA5 values: their wind
        # variables keep whatever the output store already holds.
        # NOTE(review): the original cell also had an unreachable branch
        # selecting `era5_dss[1]` for this case, immediately overridden by an
        # empty list -- confirm that skipping these batches is intended.
        continue

    if is_after_2008.all():
        pairs = [(is_after_2008, era5_dss[0])]
    else:
        # Batch straddling 2008-01-01: each subset of points is interpolated
        # from its own dataset (`era5_dss[1]` before, `era5_dss[0]` after).
        pairs = [(is_before_2008, era5_dss[1]), (is_after_2008, era5_dss[0])]

    era5_obs_all = []
    for mask, source_ds in pairs:
        ds = _crop_to_batch(source_ds, time_bounds, np.timedelta64(1, "h"), lat_bounds, lon_bounds)
        # NOTE(review): unlike DUACS/WAVERYS, every variable of the ERA5
        # dataset is interpolated and written -- confirm they all exist in
        # the output store.
        era5_obs_all.append(
            _interp_at(ds, batch_time[mask], batch_lat[mask], batch_lon[mask], "ERA5")
        )

    # Reassemble the subsets in `points` order before writing the region.
    era5_obs_all = xr.concat(era5_obs_all, dim="points").sortby("points")
    _write_region(era5_obs_all, region, "ERA5")

    del ds, era5_obs_all

In [None]:
# Read back the temporary store, discard every point with a missing value in
# any variable, and write the result to the final consolidated store.
gdp_ds_interp = xr.open_zarr(out_store)
_complete_points = gdp_ds_interp.dropna(dim="points", how="any")
# Dropping points leaves irregular chunks, hence the rechunking before writing.
_complete_points.chunk({"points": chunk_size}).to_zarr(
    f"gdp_interp_{start_datetime_str}_{end_datetime_str}.zarr",
    mode="w",
    consolidated=True,
)