In [1]:
import xarray as xr
import pydantic
from typing import Any, Optional
import numpy as np
import pandas as pd

OK, I think pandera _isn't_ the way forwards because it appears very tighly coupled to Pandas (so, for example, I don't think it's possible to use pandera with n-dimensional arrays).

And it's not possible to do `class MyDataset(xr.Dataset, pydantic.BaseModel)`.

But Pydantic looks promising. Below is a (very rough) attempt at combining xarray with pydantic. Comments more than welcome! 

The notebook below automatically validates a few things; but it's not super-useful as a human-readable specification for what's going on inside a DataArray or Dataset.

The code below is adapted from [Pydantic's documentation on custom data types](https://pydantic-docs.helpmanual.io/usage/types/#classes-with-__get_validators__).


## Define subclasses of `xr.DataArray` which enable Pydantic validation.

In [2]:
class PydanticXArrayDataArray(xr.DataArray):
    # Adapted from https://pydantic-docs.helpmanual.io/usage/types/#classes-with-__get_validators__
    
    _expected_dimensions = ()  # Subclasses should set this.
    
    @classmethod
    def __get_validators__(cls):
        yield cls.validate_dims
    
    @classmethod
    def validate(cls, v: Any) -> Any:
        v = cls.validate_dims(v)
        v = cls.validate_coords(v)
        return v
        # TODO: How to call multiple validation functions?
        
    @classmethod
    def validate_dims(cls, v: Any) -> Any:
        assert v.dims == cls._expected_dimensions, f"{cls.__name__}.dims is wrong! {cls.__name__}.dims is {v.dims}. But we expected {cls._expected_dimensions}"
        return v
        
    @classmethod
    def validate_coords(cls, v: Any) -> Any:
        for dim in cls._expected_dimensions:
            coord = v.coords[dim]
            assert len(coord) > 0, f"{dim} is empty in {cls.__name__}!"
        return v

        
class ImageDataArray(PydanticXArrayDataArray):
    """Superclass for image data (satellite imagery, NWPs, etc.)"""
    _expected_dimensions = ('time', 'x', 'y')
    

class CoordsDataArray(PydanticXArrayDataArray):
    _expected_dimensions = ('index')

  class PydanticXArrayDataArray(xr.DataArray):
  class ImageDataArray(PydanticXArrayDataArray):
  class CoordsDataArray(PydanticXArrayDataArray):


## Define subclass of `xr.Dataset` to hold the `DataArray`s together

In [3]:
class ImageDataset(xr.Dataset):
    
    data: ImageDataArray
    x_coords: Optional[CoordsDataArray] = None
    y_coords: Optional[CoordsDataArray] = None

    @classmethod
    def __get_validators__(cls):
        yield cls.validate_data
        yield cls.validate_attrs
        yield cls.validate_coords
    
    @classmethod
    def validate_data(cls, v: Any) -> Any:
        v.data = ImageDataArray.validate(v.data)
        return v
        
    @classmethod
    def validate_attrs(cls, v: Any) -> Any:
        expected_attrs = ['data', 'x', 'y']
        for attr in expected_attrs:
            assert attr in v, f"{attr} is missing from {cls.__name__}!"
        return v
        
    @classmethod
    def validate_coords(cls, v: Any) -> Any:
        for dim in ['x_coords', 'y_coords']:
            coords = getattr(v, dim)
            if coords is not None:
                CoordsDataArray.validate(coords)
        return v

  cls = super().__new__(mcls, name, bases, namespace, **kwargs)


## Define a `pydantic.BaseModel`

In [4]:
class Example(pydantic.BaseModel):
    """A single machine learning training example."""
    satellite: Optional[ImageDataset]
    nwp: Optional[ImageDataset]

## Test with some dummy data

In [5]:
def create_image_dataset(dims=('time', 'x', 'y')):
    ALL_COORDS = {
        "time": pd.date_range("2021-01-01", freq="5T", periods=4),
        "x": np.arange(10, 18),
        "y": np.arange(20, 28)
    }
    coords = [(dim, ALL_COORDS[dim]) for dim in dims]
    image_data_array = ImageDataArray(
        0,  # Fake data for testing!
        coords=coords)
    return ImageDataset({'data': image_data_array})

In [6]:
example = Example(
    satellite=create_image_dataset(), 
    nwp=create_image_dataset()
)

  v.data = ImageDataArray.validate(v.data)


In [7]:
example

Example(satellite=<xarray.ImageDataset>
Dimensions:  (time: 4, x: 8, y: 8)
Coordinates:
  * time     (time) datetime64[ns] 2021-01-01 ... 2021-01-01T00:15:00
  * x        (x) int64 10 11 12 13 14 15 16 17
  * y        (y) int64 20 21 22 23 24 25 26 27
Data variables:
    data     (time, x, y) int64 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0, nwp=<xarray.ImageDataset>
Dimensions:  (time: 4, x: 8, y: 8)
Coordinates:
  * time     (time) datetime64[ns] 2021-01-01 ... 2021-01-01T00:15:00
  * x        (x) int64 10 11 12 13 14 15 16 17
  * y        (y) int64 20 21 22 23 24 25 26 27
Data variables:
    data     (time, x, y) int64 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0 0)

### Test with bad data

In [8]:
example = Example(
    satellite=create_image_dataset(dims=('time', 'x')), 
    nwp=create_image_dataset()
)

  v.data = ImageDataArray.validate(v.data)


ValidationError: 1 validation error for Example
satellite
  ImageDataArray.dims is wrong! ImageDataArray.dims is ('time', 'x'). But we expected ('time', 'x', 'y') (type=assertion_error)