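The cell below defines a PyTorch `Dataset` wrapper for xarray data and then builds a toy xarray dataset to try it on. The toy dataset has one 3D variable (`a`, with a vertical `z` dimension) and one 2D variable (`b`, without one).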

In [1]:
import numpy as np
import xarray as xr
from itertools import product

from torch.utils.data import Dataset, DataLoader

class XRTimeSeries(Dataset):
    """A PyTorch Dataset for time series data stored in an xarray Dataset.

    This class assumes the data has dimensions ['time', 'z', 'y', 'x'], and
    that the axes of the data arrays are all stored in that order.

    An individual "sample" is a time series of ``time_length`` steps (the full
    series by default) from a single horizontal location. Variables with a
    'z' dimension have shape (time, z, 1, 1) in the sample; variables without
    one get a singleton z axis, i.e. shape (time, 1, 1, 1), so that both kinds
    broadcast against each other.

    Examples
    --------
    >>> ds = xr.open_dataset("in.nc")
    >>> dataset = XRTimeSeries(ds)
    >>> dataset[0]

    """
    # Names of the expected dimensions. The order here is NOT the storage
    # order (which is time, z, y, x). Instances shadow this attribute with a
    # per-variable dict in __init__.
    dims = ['time', 'z', 'x', 'y']

    def __init__(self, data, time_length=None):
        """
        Parameters
        ----------
        data : xr.Dataset
            An input dataset. This dataset must contain at least some variables
            with all of the dimensions ['time', 'z', 'y', 'x'].
        time_length : int, optional
            The length of the time sequences to use, must evenly divide the
            total number of time points. Defaults to the full length of the
            time axis.
        """
        self.time_length = time_length or len(data.time)
        self.data = data
        # Eagerly load every variable as a numpy array so that __getitem__
        # does plain numpy indexing.
        self.numpy_data = {key: data[key].values for key in data.data_vars}
        self.data_vars = set(data.data_vars)
        # Per-variable dimension names (shadows the class-level list).
        self.dims = {key: data[key].dims for key in data.data_vars}
        # "Constants" are variables that vary along none of x, y, or time;
        # they are excluded from samples and exposed by torch_constants().
        self.constants = {
            key
            for key in data.data_vars
            if len({'x', 'y', 'time'} & set(data[key].dims)) == 0
        }
        self.setup_indices()

    def setup_indices(self):
        """Enumerate the (time_start, y, x) triple of every sample.

        Populates ``self.indices`` with one entry per time chunk and
        horizontal location, ordered with x varying fastest.
        """
        len_x = len(self.data['x'].values)
        len_y = len(self.data['y'].values)
        len_t = len(self.data['time'].values)

        x_iter = range(0, len_x, 1)
        y_iter = range(0, len_y, 1)
        t_iter = range(0, len_t, self.time_length)
        assert len_t % self.time_length == 0, (
            "time_length must evenly divide the number of time points")
        self.indices = list(product(t_iter, y_iter, x_iter))

    def __len__(self):
        """Number of samples: time chunks times horizontal locations."""
        return len(self.indices)

    def __getitem__(self, i):
        """Return sample ``i`` as a dict of float32 numpy arrays.

        Constant variables are skipped. Every returned array has four axes
        ordered (time, z, y, x) with singleton y and x axes.
        """
        t, y, x = self.indices[i]
        time_slice = slice(t, t + self.time_length)
        output_tensors = {}
        # Iterate the dict (not the data_vars set) so the key order of the
        # output is deterministic.
        for key, data_array in self.numpy_data.items():
            if key in self.constants:
                continue

            if 'z' in self.dims[key]:
                this_array_index = (time_slice, slice(None), y, x)
            else:
                # None inserts a singleton z axis for variables without one.
                this_array_index = (time_slice, None, y, x)

            # The scalar y and x indices dropped those axes; restore them as
            # trailing singleton axes.
            sample = data_array[this_array_index][:, :, np.newaxis, np.newaxis]
            output_tensors[key] = sample.astype(np.float32)
        return output_tensors

    @property
    def time_dim(self):
        """Name of the time dimension (the first entry of the class ``dims``)."""
        # self.dims is a per-variable dict on instances, so indexing it with 0
        # raised KeyError; read the class-level list of names instead.
        return type(self).dims[0]

    def torch_constants(self):
        """Return the constant variables as float32 torch tensors."""
        # Only names from torch.utils.data are imported at module level.
        import torch

        return {
            key: torch.tensor(self.data[key].values, requires_grad=False)
            .float()
            for key in self.constants
        }

    @property
    def scale(self):
        """Maximum of each value in ``self.std``, keyed like ``self.std``.

        NOTE(review): ``std`` is not defined anywhere in the visible part of
        this class; it must be supplied elsewhere (e.g. by a subclass) as a
        mapping whose values have a ``max()`` method -- confirm.
        """
        return {key: value.max() for key, value in self.std.items()}
    

def get_xarray_dataset():
    """Build a small all-ones xarray Dataset for demonstration.

    The dataset holds a 3D variable 'a' with dims (time, z, y, x) and a 2D
    variable 'b' with dims (time, y, x), on a grid of 4 time steps, 4 levels,
    5 y points and 2 x points.
    """
    n_time, n_z, n_y, n_x = 4, 4, 5, 2

    variables = {
        'a': (['time', 'z', 'y', 'x'], np.ones((n_time, n_z, n_y, n_x))),
        'b': (['time', 'y', 'x'], np.ones((n_time, n_y, n_x))),
    }
    return xr.Dataset(variables)

# Build the toy dataset and display its summary.
ds = get_xarray_dataset()
ds

<xarray.Dataset>
Dimensions:  (time: 4, x: 2, y: 5, z: 4)
Dimensions without coordinates: time, x, y, z
Data variables:
    a        (time, z, y, x) float64 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0 1.0
    b        (time, y, x) float64 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0 1.0

In [2]:
# time_length=4 matches the 4 time steps of ds, so each sample spans the
# whole time axis.
torch_dataset = XRTimeSeries(ds, time_length=4)

In [3]:
len(torch_dataset)

10

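The length of the dataset is the number of horizontal locations, $n_x \cdot n_y = 2 \cdot 5 = 10$, because `time_length` spans the whole time axis here.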

In [4]:
# Fetch the first sample: a dict mapping each variable name to a float32 array.
sample = torch_dataset[0]
sample

{'a': array([[[[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]]],
 
 
        [[[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]]],
 
 
        [[[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]]],
 
 
        [[[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]]]], dtype=float32), 'b': array([[[[1.]]],
 
 
        [[[1.]]],
 
 
        [[[1.]]],
 
 
        [[[1.]]]], dtype=float32)}

Two-dimensional and three-dimensional variables have broadcastable shapes: the 2D variable `b` gets a singleton `z` axis.

In [5]:
sample['b'].shape

(4, 1, 1, 1)

In [6]:
sample['a'].shape

(4, 4, 1, 1)

Now that we have made the torch dataset object, we can pass it to pytorch's DataLoader class.

In [7]:
# Group the samples into batches of (up to) 4.
train_loader = DataLoader(torch_dataset, batch_size=4)

In [8]:
# Each batch is a dict keyed by variable name; print the batched shape of 'b'.
for batch in train_loader:
    print("shape of b", batch['b'].shape)

shape of b torch.Size([4, 4, 1, 1, 1])
shape of b torch.Size([4, 4, 1, 1, 1])
shape of b torch.Size([2, 4, 1, 1, 1])


The first dimension becomes the "batch" dimension. The other dimensions are the physical dimensions (time, z, y, x). My [model classes](https://github.com/nbren12/uwnet/blob/047a63b70985b12e17013355ecd25c908681ab76/uwnet/modules.py) accept data in this format.