In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib widget

## Load data

### Drifters

In [2]:
import numpy as np
import xarray as xr

from sealagrangiax.sealagrangiax.input import Preprocessing
from sealagrangiax.sealagrangiax.input.drifter_preprocessors import gdp as gdp_preproc

In [3]:
# Open the drifter store with the zarr engine, then display the resulting dataset.
gdp_store = "../../data/gdp1h.zarr"
ds = xr.open_dataset(gdp_store, engine="zarr")
ds

In [4]:
# Latitude / longitude bounds shared by the deployment and the observation location filters.
bbox = {"min_lat": 25, "max_lat": 50, "min_lon": 100, "max_lon": 169.9}

# The pipeline is bound to its own name instead of re-using `gdp_preproc`: re-binding the imported
# preprocessor module to the pipeline object made this cell fail when run a second time
# (`gdp_preproc.DeployBbox` would then be looked up on the `Preprocessing` instance).
gdp_pipeline = Preprocessing([
    gdp_preproc.DeployBbox(**bbox),      # subset on drifter deploy location
    gdp_preproc.GPSLocationType(),       # subset to GPS location type
    gdp_preproc.SVPBuoyTypes(),          # subset to SVP buoy types
    gdp_preproc.DeployDate(),            # subset on deployment date
    gdp_preproc.Time(),                  # subset on observation date
    gdp_preproc.LocationBbox(**bbox),    # subset on observation location
    gdp_preproc.Drogued(),               # subset to drogued observations
    gdp_preproc.FiniteValue(),           # subset to finite value observations
    gdp_preproc.Outlier(),               # subset to plausible value observations
    # Open-ocean filter, currently disabled:
    # gdp_preproc.OpenOcean(coastline_file="../../data/coastline.json"),
    gdp_preproc.Chunk(n_days=9, dt=np.timedelta64(1, "h"), to_ragged=False),  # equally sampled chunks
])

# Apply the whole pipeline to the raw drifter dataset and display the result.
gdp1h = gdp_pipeline(ds)
gdp1h

Subsetting on drifter deploy location...
(min_lat, max_lat, min_lon, max_lon: 25, 50, 100, 169.9)
# traj: 788 ; # obs: 8883375
Subsetting to GPS location type...
# traj: 227 ; # obs: 3404389
Subsetting to SVP buoy types...
# traj: 227 ; # obs: 3404389
Subsetting to post 2000 deployments...
Subsetting to pre 2023-06-07 deployments...
# traj: 226 ; # obs: 3385010
Subsetting to post 2000 observations...
Subsetting to pre 2023-06-07 observations...
# traj: 226 ; # obs: 3385010
Subsetting on observation location...
(min_lat, max_lat, min_lon, max_lon: 25, 50, 100, 169.9)
# traj: 225 ; # obs: 1300550
Subsetting to drogued observations...
# traj: 225 ; # obs: 862085
Subsetting to finite value observations...
# traj: 225 ; # obs: 862085
Subsetting to plausible value observations...
(velocity_cutoff: 10m/s, latlon_err_cutoff: 0.5°)
# traj: 225 ; # obs: 858147
Chunking in equally sampled trajectories...
(n_days: 9, dt: 1 hours)
# traj: 3610 ; # obs: 217


### Sea surface currents

In [5]:
import os

# Path of the DUACS zarr store. It defaults to the original local path and can be overridden
# through the DUACS_ZARR_PATH environment variable, so the notebook can run on another machine
# without editing this cell.
duacs_zarr_path = os.environ.get(
    "DUACS_ZARR_PATH",
    "/Users/bertrava/DUACS/cmems_obs-sl_glo_phy-ssh_my_allsat-l4-duacs-0.25deg_P1D.zarr",
)

# Open the store and rename its variables / coordinates to the names used in the rest of the
# notebook (`u`, `v`, `lat_t`, `lon_t`).
ssc = xr.open_zarr(duacs_zarr_path).rename({
    "adt": "ssh",
    "ugos": "u",
    "vgos": "v",
    "longitude": "lon_t",
    "latitude": "lat_t",
})

ssc

Unnamed: 0,Array,Chunk
Bytes,85.86 GiB,2.81 MiB
Shape,"(11115, 720, 1440)","(1, 720, 512)"
Dask graph,33345 chunks in 2 graph layers,33345 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 85.86 GiB 2.81 MiB Shape (11115, 720, 1440) (1, 720, 512) Dask graph 33345 chunks in 2 graph layers Data type float64 numpy.ndarray",1440  720  11115,

Unnamed: 0,Array,Chunk
Bytes,85.86 GiB,2.81 MiB
Shape,"(11115, 720, 1440)","(1, 720, 512)"
Dask graph,33345 chunks in 2 graph layers,33345 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,85.86 GiB,2.81 MiB
Shape,"(11115, 720, 1440)","(1, 720, 512)"
Dask graph,33345 chunks in 2 graph layers,33345 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 85.86 GiB 2.81 MiB Shape (11115, 720, 1440) (1, 720, 512) Dask graph 33345 chunks in 2 graph layers Data type float64 numpy.ndarray",1440  720  11115,

Unnamed: 0,Array,Chunk
Bytes,85.86 GiB,2.81 MiB
Shape,"(11115, 720, 1440)","(1, 720, 512)"
Dask graph,33345 chunks in 2 graph layers,33345 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,85.86 GiB,2.81 MiB
Shape,"(11115, 720, 1440)","(1, 720, 512)"
Dask graph,33345 chunks in 2 graph layers,33345 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 85.86 GiB 2.81 MiB Shape (11115, 720, 1440) (1, 720, 512) Dask graph 33345 chunks in 2 graph layers Data type float64 numpy.ndarray",1440  720  11115,

Unnamed: 0,Array,Chunk
Bytes,85.86 GiB,2.81 MiB
Shape,"(11115, 720, 1440)","(1, 720, 512)"
Dask graph,33345 chunks in 2 graph layers,33345 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


## Batch drifter data

In [6]:
import clouddrift as cd
import jax.numpy as jnp
import torch
from torch.utils.data import Dataset as TorchDataset

from sealagrangiax.sealagrangiax.gridded import Dataset
from sealagrangiax.sealagrangiax.trajectory import Displacement, Location
from sealagrangiax.sealagrangiax.utils import UNIT
from sealagrangiax.sealagrangiax.utils.unit import time_in_seconds

**custom dataset**

In [14]:
class XArrayDataset(TorchDataset):
    """Map-style torch dataset pairing one drifter trajectory with a surrounding gridded patch.

    Item ``idx`` is built from trajectory ``idx`` of ``drifter_ds`` and from the subset of
    ``ssc_ds`` contained in a box centred on the first drifter position, padded by one day in
    time on each side of the trajectory.

    Parameters
    ----------
    drifter_ds : xr.Dataset
        Drifter dataset with a ``traj`` dimension and ``lat``, ``lon``, ``time`` and ``id`` variables.
    ssc_ds : xr.Dataset
        Gridded dataset holding ``u`` and ``v`` on the ``time``, ``lat_t`` and ``lon_t`` coordinates.
    is_ragged : bool, default False
        If True, trajectories are extracted with ``clouddrift.ragged.subset``; otherwise with
        ``drifter_ds.isel(traj=idx)``.
    max_travel_distance_per_day : float, default 0.5
        Half-width growth of the spatial box, in degrees per day of trajectory.
        The default value was inferred from data.
    """

    # Number of seconds in one day; times are manipulated in seconds in this class.
    _ONE_DAY = 60 * 60 * 24
    # Bounds used to clip the requested box.
    # NOTE(review): presumably the outermost cell centres of the gridded product - confirm.
    _LAT_BOUNDS = (-89.875, 89.875)
    _LON_BOUNDS = (-179.875, 179.875)

    def __init__(
        self,
        drifter_ds: xr.Dataset,
        ssc_ds: xr.Dataset,
        is_ragged: bool = False,
        max_travel_distance_per_day: float = .5
    ):
        self.drifter_ds = drifter_ds
        self.ssc_ds = ssc_ds
        self.is_ragged = is_ragged
        self.max_travel_distance_per_day = max_travel_distance_per_day

    def __len__(self):
        """Return the number of trajectories, i.e. the size of the ``traj`` dimension."""
        return self.drifter_ds.traj.size

    def __getitem__(self, idx: int):
        """Build the sample associated with trajectory ``idx``.

        Returns
        -------
        tuple
            ``((lat, lon, time, id), (variables, time, lat, lon))`` where every array is a
            ``torch.Tensor`` and ``variables`` is a dict of tensors keyed by variable name.
        """
        drifter_lat, drifter_lon, drifter_time, drifter_id = self.__get_drifter_arrays(idx)
        ssc_variables, ssc_time, ssc_lat, ssc_lon = self.__get_ssc_arrays(drifter_lat, drifter_lon, drifter_time)

        drifter_tensors = tuple(
            torch.from_numpy(array) for array in (drifter_lat, drifter_lon, drifter_time, drifter_id)
        )
        # A new dict is built rather than overwriting the entries of the one returned by the helper.
        ssc_tensors = (
            {name: torch.from_numpy(values) for name, values in ssc_variables.items()},
            torch.from_numpy(ssc_time),
            torch.from_numpy(ssc_lat),
            torch.from_numpy(ssc_lon),
        )

        return drifter_tensors, ssc_tensors

    def __get_drifter_arrays(self, idx: int):
        """Return the flattened ``lat``, ``lon``, ``time`` (in seconds) and ``id`` arrays of trajectory ``idx``."""
        if self.is_ragged:
            drifter_subset = cd.ragged.subset(self.drifter_ds, {"traj": idx}, row_dim_name="traj")
        else:
            drifter_subset = self.drifter_ds.isel(traj=idx)

        drifter_lat = drifter_subset.lat.data.ravel()
        drifter_lon = drifter_subset.lon.data.ravel()
        drifter_time = time_in_seconds(drifter_subset.time.data.ravel())
        drifter_id = drifter_subset.id.data.ravel()

        return drifter_lat, drifter_lon, drifter_time, drifter_id

    def __get_ssc_arrays(self, drifter_lat, drifter_lon, drifter_time):
        """Extract the ``u`` / ``v`` patch surrounding a trajectory.

        Only the first position and the first / last times of the trajectory are used to
        define the patch.

        Returns
        -------
        tuple
            ``(variables, time, lat, lon)`` as returned by ``Dataset.to_arrays(..., to_jax=False)``.
        """
        one_day = self._ONE_DAY

        # Time window: the trajectory span padded by one day on each side.
        min_time = drifter_time[0] - one_day
        max_time = drifter_time[-1] + one_day
        n_days = (max_time - min_time) / one_day

        # The two padding days are removed so that the distance only depends on the trajectory span.
        max_travel_distance = self.max_travel_distance_per_day * (n_days - 2)  # in °
        max_travel_distance = Displacement(
            jnp.full(2, max_travel_distance, dtype=float), unit=UNIT.degrees
        )

        # Box centred on the first drifter position.
        x0 = Location(jnp.asarray((drifter_lat[0], drifter_lon[0])))
        min_corner = Location(x0 - max_travel_distance)
        max_corner = Location(x0 + max_travel_distance)

        # NOTE(review): the box is clipped at the bounds, not wrapped across the antimeridian -
        # confirm this is acceptable for trajectories starting close to +/-180° of longitude.
        min_latitude = max(self._LAT_BOUNDS[0], min_corner.latitude.item())
        max_latitude = min(self._LAT_BOUNDS[1], max_corner.latitude.item())
        min_longitude = max(self._LON_BOUNDS[0], min_corner.longitude.item())
        max_longitude = min(self._LON_BOUNDS[1], max_corner.longitude.item())

        ssc_patch = self.ssc_ds.sel(
            lon_t=slice(min_longitude, max_longitude),
            lat_t=slice(min_latitude, max_latitude),
            # label-based slices include their end point: one second is removed from the upper bound
            time=slice(np.datetime64(min_time.item(), "s"), np.datetime64(max_time.item() - 1, "s"))
        )

        ssc_variables, ssc_time, ssc_lat, ssc_lon = Dataset.to_arrays(
            ssc_patch,
            variables={"u": "u", "v": "v"},
            coordinates={"time": "time", "latitude": "lat_t", "longitude": "lon_t"},
            to_jax=False
        )

        return ssc_variables, ssc_time, ssc_lat, ssc_lon
    
# Pair the preprocessed drifter chunks with the sea surface currents dataset.
dataset = XArrayDataset(drifter_ds=gdp1h, ssc_ds=ssc)

In [17]:
# Time the construction of a single sample (trajectory index 10); the sample itself is discarded.
%time _ = dataset[10]

CPU times: user 82 ms, sys: 15.9 ms, total: 97.9 ms
Wall time: 66.4 ms


In [18]:
# Display one full sample (trajectory index 1500) to inspect the tensors returned by the dataset.
dataset[1500]

((tensor([35.5267, 35.5232, 35.5180, 35.5101, 35.5014, 35.4929, 35.4854, 35.4785,
          35.4716, 35.4640, 35.4551, 35.4444, 35.4327, 35.4214, 35.4097, 35.3969,
          35.3831, 35.3695, 35.3568, 35.3463, 35.3373, 35.3295, 35.3216, 35.3144,
          35.3079, 35.3012, 35.2935, 35.2847, 35.2750, 35.2648, 35.2541, 35.2429,
          35.2313, 35.2189, 35.2051, 35.1913, 35.1779, 35.1649, 35.1529, 35.1423,
          35.1334, 35.1263, 35.1195, 35.1119, 35.1032, 35.0946, 35.0865, 35.0785,
          35.0694, 35.0588, 35.0475, 35.0357, 35.0233, 35.0103, 34.9967, 34.9830,
          34.9693, 34.9552, 34.9404, 34.9263, 34.9139, 34.9031, 34.8931, 34.8843,
          34.8769, 34.8712, 34.8663, 34.8623, 34.8595, 34.8570, 34.8541, 34.8503,
          34.8456, 34.8394, 34.8297, 34.8170, 34.8012, 34.7853, 34.7704, 34.7574,
          34.7461, 34.7361, 34.7272, 34.7200, 34.7147, 34.7098, 34.7050, 34.7002,
          34.6966, 34.6942, 34.6920, 34.6895, 34.6871, 34.6847, 34.6837, 34.6825,
          34.679

**custom dataloader**

In [19]:
import equinox as eqx
import dask
from torch.utils.data import DataLoader

from sealagrangiax.sealagrangiax.trajectory import Trajectory

In [25]:
# Configure dask to use its threaded scheduler with 4 workers.
dask.config.set(scheduler="threads", num_workers=4)

# Loader settings gathered in one place; the values are starting points, not tuned settings.
loader_options = dict(
    batch_size=64,
    shuffle=True,
    num_workers=8,
    prefetch_factor=3,
    persistent_workers=True,
    multiprocessing_context="fork",
)

dataloader = DataLoader(dataset, **loader_options)

In [26]:
import time

In [27]:
# Iterate once over the loader. For every batch, print its index and the time spent waiting for
# it: the clock is reset at the end of each iteration, so the conversion work done in the loop
# body is excluded from the measurement.
# `time.perf_counter()` is used rather than `time.time()`: it is monotonic and meant for
# measuring short intervals, whereas the wall clock can be adjusted while the loop runs.
t0 = time.perf_counter()
for i, (drifter_arrays, ssc_arrays) in enumerate(dataloader):
    print(i)
    print(time.perf_counter() - t0)

    drifter_lat, drifter_lon, drifter_time, drifter_id = drifter_arrays
    ssc_variables, ssc_time, ssc_lat, ssc_lon = ssc_arrays

    # Stack latitudes and longitudes on a trailing axis, then build one Trajectory per batch element.
    drifter_latlon = jnp.asarray(torch.stack((drifter_lat, drifter_lon), dim=-1))
    drifter_trajectories = eqx.filter_vmap(Trajectory)(
        drifter_latlon, jnp.asarray(drifter_time), jnp.asarray(drifter_id)
    )

    # Convert every gridded variable to a jax array, then build one Dataset per batch element.
    ssc_variables = {name: jnp.asarray(values) for name, values in ssc_variables.items()}
    ssc_patches = eqx.filter_vmap(Dataset.from_arrays)(
        ssc_variables, jnp.asarray(ssc_time), jnp.asarray(ssc_lat), jnp.asarray(ssc_lon)
    )

    sample = (drifter_trajectories, ssc_patches)

    t0 = time.perf_counter()

  self.pid = os.fork()


0
12.419975996017456
1
0.018904924392700195
2
0.0010051727294921875
3
0.0006260871887207031
4
0.0021429061889648438
5
0.0005030632019042969
6
0.000591278076171875
7
0.0003960132598876953
8
4.9452879428863525
9
0.0027511119842529297
10
0.05753326416015625
11
0.00041294097900390625
12
0.00041794776916503906
13
0.007472991943359375
14
0.0004811286926269531
15
0.0006439685821533203
16
5.112571954727173
17
0.00728297233581543
18
0.0004248619079589844
19
0.003239870071411133
20
0.0004029273986816406
21
0.026800870895385742
22
0.0009701251983642578
23
0.0005240440368652344
24
5.564537048339844
25
0.0008490085601806641
26
0.016004085540771484
27
0.01860809326171875
28
0.0004820823669433594
29
0.0006990432739257812
30
0.004866123199462891
31
0.0006721019744873047
32
5.480940103530884
33
0.00040411949157714844
34
0.004733085632324219
35
0.0006108283996582031
36
0.0004968643188476562
37
0.005650997161865234
38
0.0009419918060302734
39
0.0003859996795654297
40
7.755730152130127
41
0.00124406814575

In [28]:
# Batched trajectories built from the last batch yielded by the timing loop.
drifter_trajectories

Trajectory(
  _states=Location(
    value=f32[26,217,2],
    what=EnumerationItem(
      _value=i32[26],
      _enumeration=<class 'sealagrangiax.sealagrangiax.utils.what.WHAT'>
    ),
    unit=EnumerationItem(
      _value=i32[26],
      _enumeration=<class 'sealagrangiax.sealagrangiax.utils.unit.UNIT'>
    )
  ),
  _times=Time(
    value=i32[26,217],
    what=EnumerationItem(
      _value=i32[26],
      _enumeration=<class 'sealagrangiax.sealagrangiax.utils.what.WHAT'>
    ),
    unit=EnumerationItem(
      _value=i32[26],
      _enumeration=<class 'sealagrangiax.sealagrangiax.utils.unit.UNIT'>
    )
  ),
  length=217,
  id=i32[26,1]
)

In [29]:
# Batched gridded patches built from the last batch yielded by the timing loop.
ssc_patches

Dataset(
  variables={
    'u':
    Spatiotemporal(
      values=f32[26,11,36,36],
      temporal_field=Interpolator1D(
        x=f32[26,11],
        f=f32[26,11,36,36],
        derivs={'fx': f32[26,11,36,36]},
        method='linear',
        extrap=True,
        period=None,
        axis=0
      ),
      spatial_field=Interpolator2D(
        x=f32[26,36],
        y=f32[26,36],
        f=f32[26,36,36,11],
        derivs={
          'fx':
          f32[26,36,36,11],
          'fxy':
          f32[26,36,36,11],
          'fy':
          f32[26,36,36,11]
        },
        method='linear',
        extrap=True,
        period=(None, 360),
        axis=0
      ),
      spatiotemporal_field=Interpolator3D(
        x=f32[26,11],
        y=f32[26,36],
        z=f32[26,36],
        f=f32[26,11,36,36],
        derivs={
          'fx':
          f32[26,11,36,36],
          'fxy':
          f32[26,11,36,36],
          'fxyz':
          f32[26,11,36,36],
          'fxz':
          f32[26,11,36,36]

Note: avoid creating intermediate `jax.Array`s inside the Dataset object. In practice, this brings no real performance benefit, though...

During training, the `num_workers=8, prefetch_factor=3` settings should be tuned to the actual workload. See: https://earthmover.io/blog/cloud-native-dataloader/