This notebook presents two approaches to load data from Xarray Datasets as batches of JAX custom Pytrees.

Best approache seems to be the use of PyTorch dataloading mechanism but enforcing manipulation of Numpy arrays (instead of PyTorch tensors).

Limitations that might degrade loading timings are mostly related to the fact that creating multiple processes for loading the data (as done by PyTorch) does not interact nicely with JAX and creates deadlocks.
What we might like to modify in an ideal world is to use JAX arrays (rather Numpy arrays) straight after extracting relevant data from Xarray datasets.

It is appealing to use [Grain](https://google-grain.readthedocs.io/en/latest/index.html) as it should allow to manipulate JAX arrays from the start and it is said to provide sharding (so multiple GPUs) capabilities, but unfortunately I was not able to even import it as a module after pip-installing it. 

In [1]:
%load_ext autoreload
%autoreload 2

## Xarray data

In [3]:
data_root = "/summer/meom"

In [4]:
import clouddrift as cd
import numpy as np
import xarray as xr

### Drifters

In [5]:
drifter_ds = xr.open_zarr(f"{data_root}/workdir/bertrava/noaa-oar-hourly-gdp-pds.zarr")
drifter_ds

Unnamed: 0,Array,Chunk
Bytes,10.80 kiB,10.80 kiB
Shape,"(1382,)","(1382,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 10.80 kiB 10.80 kiB Shape (1382,) (1382,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1382  1,

Unnamed: 0,Array,Chunk
Bytes,10.80 kiB,10.80 kiB
Shape,"(1382,)","(1382,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,24.33 MiB,2.94 MiB
Shape,"(3189334,)","(385186,)"
Dask graph,9 chunks in 2 graph layers,9 chunks in 2 graph layers
Data type,datetime64[ns] numpy.ndarray,datetime64[ns] numpy.ndarray
"Array Chunk Bytes 24.33 MiB 2.94 MiB Shape (3189334,) (385186,) Dask graph 9 chunks in 2 graph layers Data type datetime64[ns] numpy.ndarray",3189334  1,

Unnamed: 0,Array,Chunk
Bytes,24.33 MiB,2.94 MiB
Shape,"(3189334,)","(385186,)"
Dask graph,9 chunks in 2 graph layers,9 chunks in 2 graph layers
Data type,datetime64[ns] numpy.ndarray,datetime64[ns] numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,24.33 MiB,2.94 MiB
Shape,"(3189334,)","(385186,)"
Dask graph,9 chunks in 2 graph layers,9 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 24.33 MiB 2.94 MiB Shape (3189334,) (385186,) Dask graph 9 chunks in 2 graph layers Data type float64 numpy.ndarray",3189334  1,

Unnamed: 0,Array,Chunk
Bytes,24.33 MiB,2.94 MiB
Shape,"(3189334,)","(385186,)"
Dask graph,9 chunks in 2 graph layers,9 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,24.33 MiB,2.94 MiB
Shape,"(3189334,)","(385186,)"
Dask graph,9 chunks in 2 graph layers,9 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 24.33 MiB 2.94 MiB Shape (3189334,) (385186,) Dask graph 9 chunks in 2 graph layers Data type float64 numpy.ndarray",3189334  1,

Unnamed: 0,Array,Chunk
Bytes,24.33 MiB,2.94 MiB
Shape,"(3189334,)","(385186,)"
Dask graph,9 chunks in 2 graph layers,9 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,5.40 kiB,5.40 kiB
Shape,"(1382,)","(1382,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int32 numpy.ndarray,int32 numpy.ndarray
"Array Chunk Bytes 5.40 kiB 5.40 kiB Shape (1382,) (1382,) Dask graph 1 chunks in 2 graph layers Data type int32 numpy.ndarray",1382  1,

Unnamed: 0,Array,Chunk
Bytes,5.40 kiB,5.40 kiB
Shape,"(1382,)","(1382,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int32 numpy.ndarray,int32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,12.17 MiB,1.47 MiB
Shape,"(3189334,)","(385186,)"
Dask graph,9 chunks in 2 graph layers,9 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 12.17 MiB 1.47 MiB Shape (3189334,) (385186,) Dask graph 9 chunks in 2 graph layers Data type float32 numpy.ndarray",3189334  1,

Unnamed: 0,Array,Chunk
Bytes,12.17 MiB,1.47 MiB
Shape,"(3189334,)","(385186,)"
Dask graph,9 chunks in 2 graph layers,9 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,12.17 MiB,1.47 MiB
Shape,"(3189334,)","(385186,)"
Dask graph,9 chunks in 2 graph layers,9 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 12.17 MiB 1.47 MiB Shape (3189334,) (385186,) Dask graph 9 chunks in 2 graph layers Data type float32 numpy.ndarray",3189334  1,

Unnamed: 0,Array,Chunk
Bytes,12.17 MiB,1.47 MiB
Shape,"(3189334,)","(385186,)"
Dask graph,9 chunks in 2 graph layers,9 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [6]:
def chunk_trajectories(
    ds: xr.Dataset, n_days: int = 5, dt: np.timedelta64 = np.timedelta64(1, "h"), to_ragged: bool = False
) -> xr.Dataset:
    def ragged_chunk(arr: xr.DataArray | np.ndarray, is_metadata: bool = False) -> np.ndarray:
        arr = cd.ragged.apply_ragged(cd.ragged.chunk, arr, row_size, chunk_size)  # noqa
        if is_metadata:
            arr = arr[:, 0]
        if to_ragged:
            arr = arr.ravel()
        return arr

    if dt is None:
        dt = (ds.isel(traj=0).time[1] - ds.isel(traj=0).time[0])

    row_size = cd.ragged.segment(ds.time, dt, ds.rowsize)  # if holes, divide into segments
    chunk_size = int(n_days / (dt / np.timedelta64(1, "D"))) + 1

    # chunk along `obs` dimension (data)
    data = dict(
        [(d, ragged_chunk(ds[d])) for d in ["time", "lat", "lon"]]
    )

    # chunk along `traj` dimension (metadata)
    metadata = {"id": ragged_chunk(np.repeat(ds["id"], ds.rowsize), is_metadata=True)}
    metadata["rowsize"] = np.full(metadata["id"].size, chunk_size)  # noqa - after chunking the rowsize is constant

    # create xr.Dataset
    attrs_global = ds.attrs

    coord_dims = {}
    attrs_variables = {}
    for var in ds.coords.keys():
        var = str(var)
        coord_dims[var] = str(ds[var].dims[-1])
        attrs_variables[var] = ds[var].attrs

    for var in data.keys():
        attrs_variables[var] = ds[var].attrs

    for var in metadata.keys():
        attrs_variables[var] = ds[var].attrs

    metadata["drifter_id"] = metadata["id"]  # noqa
    del metadata["id"]
    attrs_variables["drifter_id"] = attrs_variables["id"]
    attrs_variables["id"] = {}

    if to_ragged:
        coords = {"id": np.arange(metadata["drifter_id"].size), "time": data.pop("time")}
        ragged_array = cd.RaggedArray(
            coords, metadata, data, attrs_global, attrs_variables, {"traj": "rows", "obs": "obs"}, coord_dims
        )
        ds = ragged_array.to_xarray()
    else:
        coords = {"id": np.arange(metadata["drifter_id"].size)}
        
        xr_coords = {}
        for var in coords.keys():
            xr_coords[var] = (
                [coord_dims[var]],
                coords[var],
                attrs_variables[var],
            )

        xr_data = {}
        for var in metadata.keys():
            xr_data[var] = (
                ["traj"],
                metadata[var],
                attrs_variables[var],
            )

        for var in data.keys():
            xr_data[var] = (
                ["traj", "obs"],
                data[var],
                attrs_variables[var],
            )

        ds = xr.Dataset(coords=xr_coords, data_vars=xr_data, attrs=attrs_global)

    return ds

In [7]:
traj_ds = chunk_trajectories(drifter_ds)
traj_ds

### Surface currents

In [8]:
ssc_ds = xr.open_zarr(f"{data_root}/workdir/bertrava/cmems_obs-sl_glo_phy-ssh_my_allsat-l4-duacs-0.125deg_P1D.zarr")
ssc_ds

Unnamed: 0,Array,Chunk
Bytes,11.28 GiB,4.00 MiB
Shape,"(365, 1440, 2880)","(1, 512, 1024)"
Dask graph,3285 chunks in 2 graph layers,3285 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 11.28 GiB 4.00 MiB Shape (365, 1440, 2880) (1, 512, 1024) Dask graph 3285 chunks in 2 graph layers Data type float64 numpy.ndarray",2880  1440  365,

Unnamed: 0,Array,Chunk
Bytes,11.28 GiB,4.00 MiB
Shape,"(365, 1440, 2880)","(1, 512, 1024)"
Dask graph,3285 chunks in 2 graph layers,3285 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,11.28 GiB,4.00 MiB
Shape,"(365, 1440, 2880)","(1, 512, 1024)"
Dask graph,3285 chunks in 2 graph layers,3285 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 11.28 GiB 4.00 MiB Shape (365, 1440, 2880) (1, 512, 1024) Dask graph 3285 chunks in 2 graph layers Data type float64 numpy.ndarray",2880  1440  365,

Unnamed: 0,Array,Chunk
Bytes,11.28 GiB,4.00 MiB
Shape,"(365, 1440, 2880)","(1, 512, 1024)"
Dask graph,3285 chunks in 2 graph layers,3285 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


## Datasets

In [12]:
import torch
from torch.utils.data import Dataset as TorchDataset

### Xarray to Torch

In [None]:
class XarrayTorchDataset(TorchDataset):
    def __init__(self, traj_ds: xr.Dataset, ssc_ds: xr.Dataset):
        self.traj_ds = traj_ds
        self.ssc_ds = ssc_ds

    def __len__(self):
        return self.traj_ds.traj.size

    def __getitem__(self, idx: int):
        traj_arrays = self.__get_traj_arrays(idx)
        ssc_arrays = self.__get_ssc_arrays(*traj_arrays[:3])

        traj_arrays = [torch.from_numpy(arr) for arr in traj_arrays]
        ssc_arrays = [torch.from_numpy(arr) for arr in ssc_arrays]
        
        return traj_arrays, ssc_arrays
    
    def __get_traj_arrays(self, idx: int):
        traj_subset = self.traj_ds.isel(traj=idx)
        
        traj_lat = traj_subset.lat.values.ravel()
        traj_lon = traj_subset.lon.values.ravel()
        traj_time = traj_subset.time.values.ravel().astype("datetime64[s]").astype(int)  # in seconds
        traj_id = traj_subset.id.values.ravel()
        
        return traj_lat, traj_lon, traj_time, traj_id
    
    def __get_ssc_arrays(self, traj_lat, traj_lon, traj_time):
        one_day_in_seconds = 60 * 60 * 24
        min_time = traj_time[0] - one_day_in_seconds
        max_time = traj_time[-1] + one_day_in_seconds
        n_days = (max_time - min_time) / one_day_in_seconds

        max_travel_distance = .5  # in °/day ; inferred from data
        max_travel_distance *= n_days  # in °

        min_latitude = traj_lat[0] - max_travel_distance
        max_latitude = traj_lat[0] + max_travel_distance
        min_longitude = traj_lon[0] - max_travel_distance
        max_longitude = traj_lon[0] + max_travel_distance

        min_latitude = max(-90, min_latitude)
        max_latitude = min(90, max_latitude)
        min_longitude = max(-180, min_longitude)
        max_longitude = min(180, max_longitude)
        
        ssc_patch = self.ssc_ds.sel(
            longitude=slice(min_longitude, max_longitude),
            latitude=slice(min_latitude, max_latitude),
            time=slice(np.datetime64(min_time.item(), "s"), np.datetime64(max_time.item() - 1, "s"))
        )

        ssc_u = ssc_patch.ugos.values
        ssc_v = ssc_patch.vgos.values
        ssc_time = ssc_patch.time.values.astype("datetime64[s]").astype(int)  # in seconds
        ssc_lat = ssc_patch.latitude.values
        ssc_lon = ssc_patch.longitude.values
        
        return ssc_u, ssc_v, ssc_time, ssc_lat, ssc_lon

In [None]:
xr_torch_dataset = XarrayTorchDataset(traj_ds, ssc_ds)

In [None]:
%time _ = xr_torch_dataset[10]

CPU times: user 82.6 ms, sys: 43.2 ms, total: 126 ms
Wall time: 221 ms


In [None]:
xr_torch_dataset[1500]

([tensor([30.2386, 30.2465, 30.2498, 30.2541, 30.2573, 30.2614, 30.2654, 30.2692,
          30.2743, 30.2797, 30.2855, 30.2905, 30.2954, 30.2999, 30.3038, 30.3074,
          30.3111, 30.3153, 30.3197, 30.3240, 30.3310, 30.3360, 30.3401, 30.3453,
          30.3486, 30.3513, 30.3550, 30.3556, 30.3570, 30.3583, 30.3595, 30.3606,
          30.3617, 30.3631, 30.3645, 30.3672, 30.3696, 30.3708, 30.3738, 30.3760,
          30.3780, 30.3806, 30.3829, 30.3861, 30.3903, 30.3954, 30.3996, 30.4042,
          30.4101, 30.4146, 30.4204, 30.4191, 30.4208, 30.4217, 30.4217, 30.4224,
          30.4232, 30.4247, 30.4249, 30.4253, 30.4259, 30.4261, 30.4269, 30.4277,
          30.4287, 30.4294, 30.4307, 30.4324, 30.4340, 30.4353, 30.4382, 30.4420,
          30.4474, 30.4542, 30.4612, 30.4681, 30.4738, 30.4796, 30.4856, 30.4919,
          30.4988, 30.4972, 30.4956, 30.4940, 30.4909, 30.4889, 30.4881, 30.4866,
          30.4865, 30.4879, 30.4935, 30.4983, 30.5036, 30.5096, 30.5143, 30.5198,
          30.524

### Xarray to JAX

In [13]:
class XarrayJAXDataset(TorchDataset):
    def __init__(self, traj_ds: xr.Dataset, ssc_ds: xr.Dataset):
        self.traj_ds = traj_ds
        self.ssc_ds = ssc_ds

    def __len__(self):
        return self.traj_ds.traj.size

    def __getitem__(self, idx: int):
        traj_arrays = self.__get_traj_arrays(idx)
        ssc_arrays = self.__get_ssc_arrays(*traj_arrays[:3])

        traj_arrays = [np.asarray(arr) for arr in traj_arrays]
        ssc_arrays = [np.asarray(arr) for arr in ssc_arrays]
        
        return traj_arrays, ssc_arrays
    
    def __get_traj_arrays(self, idx: int):
        traj_subset = self.traj_ds.isel(traj=idx)
        
        traj_lat = traj_subset.lat.values.ravel()
        traj_lon = traj_subset.lon.values.ravel()
        traj_time = traj_subset.time.values.ravel().astype("datetime64[s]").astype(int)  # in seconds
        traj_id = traj_subset.id.values.ravel()
        
        return traj_lat, traj_lon, traj_time, traj_id
    
    def __get_ssc_arrays(self, traj_lat, traj_lon, traj_time):
        one_day_in_seconds = 60 * 60 * 24
        min_time = traj_time[0] - one_day_in_seconds
        max_time = traj_time[-1] + one_day_in_seconds
        n_days = (max_time - min_time) / one_day_in_seconds

        max_travel_distance = .5  # in °/day ; inferred from data
        max_travel_distance *= n_days  # in °

        x0 = np.asarray((traj_lat[0], traj_lon[0]))
        min_corner = x0 - max_travel_distance
        max_corner = x0 + max_travel_distance

        min_latitude = max(-90, min_corner[0])
        max_latitude = min(90, max_corner[0])
        min_longitude = max(-180, min_corner[1])
        max_longitude = min(180, max_corner[1])
        
        ssc_patch = self.ssc_ds.sel(
            longitude=slice(min_longitude, max_longitude),
            latitude=slice(min_latitude, max_latitude),
            time=slice(np.datetime64(min_time.item(), "s"), np.datetime64(max_time.item() - 1, "s"))
        )

        ssc_u = ssc_patch.ugos.values
        ssc_v = ssc_patch.vgos.values
        ssc_time = ssc_patch.time.values.astype("datetime64[s]").astype(int)  # in seconds
        ssc_lat = ssc_patch.latitude.values
        ssc_lon = ssc_patch.longitude.values
        
        return ssc_u, ssc_v, ssc_time, ssc_lat, ssc_lon

In [14]:
xr_jax_dataset = XarrayJAXDataset(traj_ds, ssc_ds)

In [15]:
%time _ = xr_jax_dataset[10]

CPU times: user 101 ms, sys: 36 ms, total: 137 ms
Wall time: 72.4 ms


In [16]:
xr_jax_dataset[1500]

([array([30.23863029, 30.24650955, 30.24976921, 30.25411987, 30.25728035,
         30.26136017, 30.26536942, 30.26919937, 30.27428055, 30.27972984,
         30.28549957, 30.29051971, 30.29538918, 30.29993057, 30.30378914,
         30.3073597 , 30.31108093, 30.3152504 , 30.31967926, 30.32403946,
         30.3309803 , 30.33595085, 30.34007072, 30.34528923, 30.34860992,
         30.35132027, 30.35497093, 30.35563087, 30.35704041, 30.35832977,
         30.35950089, 30.36063957, 30.36174965, 30.36311913, 30.36449051,
         30.36724091, 30.36961937, 30.37080002, 30.3737793 , 30.37602043,
         30.37795067, 30.38056946, 30.3828907 , 30.38606071, 30.39034081,
         30.39536095, 30.39963913, 30.40423965, 30.41012001, 30.41460991,
         30.42036057, 30.41913033, 30.42078972, 30.42168045, 30.42173958,
         30.42238998, 30.42318916, 30.42465973, 30.42486954, 30.42531967,
         30.42593956, 30.42613983, 30.42688942, 30.42773056, 30.42868042,
         30.42938995, 30.43069077, 30.

## Dataloader

In [17]:
import time

import dask
import equinox as eqx
import jax.numpy as jnp
from pastax.gridded import Gridded
from pastax.trajectory import Trajectory
from torch.utils.data import DataLoader

In [18]:
DASK_N_WORKERS = 4
DL_N_WORKERS = 8
BATCH_SIZE = 64
PREFECT_FACTOR = 2

In [19]:
dask.config.set(scheduler="threads", num_workers=DASK_N_WORKERS)

<dask.config.set at 0x7f19576939d0>

In [20]:
@eqx.filter_jit
def to_trajectories(traj_arrays, stack_fn):
    traj_lat, traj_lon, traj_time, traj_id = traj_arrays

    traj_latlon = jnp.asarray(stack_fn((traj_lat, traj_lon)))
    trajectories = eqx.filter_vmap(
        lambda _latlon, _time, _id: Trajectory.from_array(values=_latlon, times=_time, id=_id)
    )(
        traj_latlon, jnp.asarray(traj_time), jnp.asarray(traj_id)
    )

    return trajectories


@eqx.filter_jit
def to_gridded(ssc_arrays):
    ssc_u, ssc_v, ssc_time, ssc_lat, ssc_lon = ssc_arrays
    
    gridded = eqx.filter_vmap(Gridded.from_array)(
        {"u": jnp.asarray(ssc_u), "v": jnp.asarray(ssc_v)}, 
        jnp.asarray(ssc_time), jnp.asarray(ssc_lat), jnp.asarray(ssc_lon)
    )

    return gridded

### Xarray to Torch

In [None]:
def xr_torch_collate_fn(batch: ((torch.Tensor, ...), (torch.Tensor, ...))):
    traj_arrays = [[] for _ in range(4)]
    ssc_arrays = [[] for _ in range(5)]

    for elem in batch:
        _traj_arrays, _ssc_arrays = elem

        if _ssc_arrays[0].shape != (7, 56, 56):  # dirty trick. prefer to extend the domain in the previous step
                continue
        
        [traj_arrays[i].append(_traj_arrays[i]) for i in range(len(_traj_arrays))]
        [ssc_arrays[i].append(_ssc_arrays[i]) for i in range(len(_ssc_arrays))]

    traj_arrays = [torch.stack(traj_array) for traj_array in traj_arrays]
    ssc_arrays = [torch.stack(ssc_array) for ssc_array in ssc_arrays]

    return traj_arrays, ssc_arrays

In [None]:
xr_torch_dataloader = DataLoader(
    xr_torch_dataset,
    batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=xr_torch_collate_fn,
    num_workers=DL_N_WORKERS, prefetch_factor=PREFECT_FACTOR,
    persistent_workers=True, multiprocessing_context="fork"
)

t0 = time.time()
t1 = time.time()
for i, (traj_arrays, ssc_arrays) in enumerate(xr_torch_dataloader):
    print(i)

    trajectories = to_trajectories(traj_arrays, lambda arrs: torch.stack(arrs, dim=-1))
    gridded = to_gridded(ssc_arrays)
    
    print(time.time() - t1)
    t1 = time.time()

    if i == 30:
        break

print(time.time() - t0)

0
6.760746479034424
1
2.5802338123321533
2
2.2076003551483154
3
2.1688451766967773
4
2.337625741958618
5
2.247772693634033
6
2.284332275390625
7
2.339719295501709
8
2.264909029006958
9
2.3202197551727295
10
2.428603172302246
11
2.2939839363098145
12
2.12227463722229
13
2.185750961303711
14
2.2217981815338135
15
2.198852062225342
16
2.2385616302490234
17
2.2551610469818115
18
2.271639823913574
19
2.1832783222198486
20
2.312917709350586
21
2.286311626434326
22
2.7407989501953125
23
2.2531964778900146
24
2.288090705871582
25
2.317044496536255
26
2.3815598487854004
27
2.332911729812622
28
2.277844190597534
29
2.2347397804260254
30
2.4589788913726807
75.80062460899353


### Xarray to JAX

In [None]:
# @jax.jit
def xr_jax_collate_fn(batch: [((np.ndarray, ...), (np.ndarray, ...)), ...]):
    traj_arrays = [[] for _ in range(4)]
    ssc_arrays = [[] for _ in range(5)]

    for elem in batch:
        _traj_arrays, _ssc_arrays = elem

        if _ssc_arrays[0].shape != (7, 56, 56):  # dirty trick. prefer to extend the domain in the previous step
                continue
        
        [traj_arrays[i].append(_traj_arrays[i]) for i in range(len(_traj_arrays))]
        [ssc_arrays[i].append(_ssc_arrays[i]) for i in range(len(_ssc_arrays))]

    traj_arrays = [np.stack(traj_array) for traj_array in traj_arrays]
    ssc_arrays = [np.stack(ssc_array) for ssc_array in ssc_arrays]

    return traj_arrays, ssc_arrays

In [24]:
xr_jax_dataloader = DataLoader(
    xr_jax_dataset,
    batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=xr_jax_collate_fn,
    num_workers=DL_N_WORKERS, prefetch_factor=PREFECT_FACTOR,
    persistent_workers=True, multiprocessing_context="fork"
)

t0 = time.time()
t1 = time.time()
for i, (traj_arrays, ssc_arrays) in enumerate(xr_jax_dataloader):
    print(i)

    trajectories = to_trajectories(traj_arrays, lambda arrs: jnp.stack(arrs, axis=-1))
    gridded = to_gridded(ssc_arrays)
    
    print(time.time() - t1)
    t1 = time.time()

    if i == 50:
        break

print(time.time() - t0)

0
5.054365873336792
1
0.597553014755249
2
0.10978198051452637
3
0.753359317779541
4
0.07685422897338867
5
0.07655978202819824
6
0.07599329948425293
7
0.8620905876159668
8
0.8352811336517334
9
0.07669925689697266
10
0.07505631446838379
11
0.31084775924682617
12
0.07366681098937988
13
0.0746762752532959
14
0.07396078109741211
15
0.5756821632385254
16
3.6916134357452393
17
0.10847330093383789
18
0.07839369773864746
19
0.10602045059204102
20
0.07659792900085449
21
0.0758817195892334
22
0.0768580436706543
23
0.07637929916381836
24
2.2968297004699707
25
0.26642537117004395
26
0.07398104667663574
27
0.10855364799499512
28
0.07687950134277344
29
0.0765070915222168
30
0.07763481140136719
31
0.08120512962341309
32
2.6927623748779297
33
0.17136812210083008
34
0.07501721382141113
35
0.07961416244506836
36
0.07714700698852539
37
0.07816934585571289
38
0.08017921447753906
39
0.07987332344055176
40
2.8314993381500244
41
0.1198577880859375
42
0.07844376564025879
43
0.07611775398254395
44
0.07745385169