In [18]:
import xarray as xr
import numpy as np
import pandas as pd
from typing import List
from pydantic import BaseModel, Field

In [19]:
# Make Pydantic Data Array and Data set. No validation is done here, but we could add some
class PydanticXArrayDataArray(xr.DataArray):
    """An ``xr.DataArray`` subclass that can be used as a pydantic field type.

    Pydantic is told which validators to run through the
    ``__get_validators__`` hook. The only validator registered here is a
    pass-through, so the value is accepted without any checks.
    """
    # Adapted from https://pydantic-docs.helpmanual.io/usage/types/#classes-with-__get_validators__

    # Empty slots declaration: this subclass declares no additional slots.
    __slots__ = []

    @classmethod
    def validate(cls, v):
        """Return ``v`` unchanged; real checks could be added here later."""
        return v

    @classmethod
    def __get_validators__(cls):
        """Yield the validators to run for this type (just ``validate``)."""
        yield cls.validate


class PydanticXArrayDataSet(xr.Dataset):
    """An ``xr.Dataset`` subclass that can be used as a pydantic field type.

    Pydantic is told which validators to run through the
    ``__get_validators__`` hook. The only validator registered here is a
    pass-through, so the value is accepted without any checks.
    """
    # Adapted from https://pydantic-docs.helpmanual.io/usage/types/#classes-with-__get_validators__

    # Empty slots declaration: this subclass declares no additional slots.
    __slots__ = []

    @classmethod
    def validate(cls, v):
        """Return ``v`` unchanged; real checks could be added here later."""
        return v

    @classmethod
    def __get_validators__(cls):
        """Yield the validators to run for this type (just ``validate``)."""
        yield cls.validate

The idea of this notebook is to explore pydantic models of the data structure. Some cells are there just as examples.

The general aim is to remove any `to_numpy` functions.

In [20]:
# lets define a general data source that stores batch dataset information

class BatchDataSource(BaseModel):
    """Superclass for image data (satellite imagery, NWPs, etc.)

    The batch itself is held as a single xarray dataset in ``data``.
    Both netcdf methods are placeholders that do nothing yet.
    """

    # The batched data for this data source.
    data: PydanticXArrayDataSet

    def to_netcdf(self):
        """Placeholder for saving to netcdf; not implemented, returns ``None``."""
        return None

    def from_netcdf(self):
        """Placeholder for loading from netcdf; not implemented, returns ``None``."""
        return None


In [21]:
# some useful functions
def convert_data_array_to_dataset(data):
    """Wrap a data array in a dataset and swap its coordinates for integer positions.

    The array is stored under the variable name ``"data"``. For every
    dimension of the array, the dimension's coordinate is overwritten with
    ``0 .. len - 1`` and the previous coordinate values are kept as an extra
    variable named ``"<dim>_coords"``, indexed by the new integer coordinate.

    Args:
        data: The array to convert (an ``xr.DataArray`` is expected).

    Returns:
        The resulting ``xr.Dataset``.
    """
    dimension_names = data.dims
    dataset = xr.Dataset({'data': data})

    for name in dimension_names:
        # Grab the current coordinate before it is overwritten below.
        original_coord = dataset[name]
        positions = np.arange(len(original_coord))
        dataset[name] = positions

        # Keep the old coordinate values, indexed by the new integer coordinate.
        dataset[f"{name}_coords"] = xr.DataArray(
            original_coord, coords=[dataset[name]], dims=[name]
        )

    return dataset


def from_list_data_array_to_dataset(image_data_arrays: List[xr.DataArray]) -> xr.Dataset:
    """Combine a list of data arrays into one dataset along an ``example`` dimension.

    Each array is first converted with ``convert_data_array_to_dataset``. Each
    resulting dataset is then given a new ``example`` dimension whose single
    coordinate value is the array's position in the input list, and the
    datasets are concatenated along that dimension.

    Args:
        image_data_arrays: The arrays to combine, one per example.

    Returns:
        The concatenated dataset.
    """
    # NOTE(review): the original comment read "might need to example dims here",
    # presumably meaning "expand dims" - confirm with the author.

    # First pass: convert every array to a dataset.
    converted = [convert_data_array_to_dataset(array) for array in image_data_arrays]

    # Second pass: tag each dataset with its position along 'example'.
    labelled = [
        dataset.expand_dims(dim='example').assign_coords(example=("example", [position]))
        for position, dataset in enumerate(converted)
    ]

    return xr.concat(labelled, dim="example")


def create_image_array(dims=("time", "x", "y", "channels")):
    """Create a fake image data array for testing.

    The array has one coordinate per requested dimension and every value is 0.

    Args:
        dims: Names of the dimensions to include, in order. Each name must be
            one of ``"time"``, ``"x"``, ``"y"`` or ``"channels"``.

    Returns:
        An ``xr.DataArray`` filled with 0 over the requested coordinates.

    Raises:
        KeyError: If ``dims`` contains a name with no predefined coordinate.
    """
    ALL_COORDS = {
        # "5min" is the supported spelling of the deprecated "5T" alias
        # (same 5-minute frequency, no deprecation warning on recent pandas).
        "time": pd.date_range("2021-01-01", freq="5min", periods=4),
        # x / y are random integers in [0, 1000); they are not seeded, so they
        # differ between calls and may contain duplicates.
        "x": np.random.randint(low=0, high=1000, size=8),
        "y": np.random.randint(low=0, high=1000, size=8),
        "channels": np.arange(5),
    }
    coords = [(dim, ALL_COORDS[dim]) for dim in dims]
    image_data_array = xr.DataArray(0, coords=coords)  # Fake data for testing!
    return image_data_array

def create_image_dataset(dims=("time", "x", "y", "channels")):
    """Create a fake image dataset for testing.

    Builds a fake array with ``create_image_array`` and converts it with
    ``convert_data_array_to_dataset``.

    Args:
        dims: Names of the dimensions to include, passed to ``create_image_array``.

    Returns:
        The converted dataset.
    """
    return convert_data_array_to_dataset(data=create_image_array(dims=dims))

In [22]:
# let's define the satellite models

class Satellite(BaseModel):
    """Pydantic model holding satellite data as a single xarray data array."""

    data: PydanticXArrayDataArray
    # can validate here satellite data

    # Declared static because the function takes no ``self``: without the
    # decorator, calling it on an instance raises TypeError. Calling it on the
    # class (``Satellite.fake()``) works exactly as before.
    @staticmethod
    def fake():
        """Placeholder for building fake satellite data; not implemented yet."""
        # this could be in testing folder
        pass


class BatchSatellite(BatchDataSource):
    """Batch data source for satellite data."""

    # Re-declares the ``data`` field with the same type as on BatchDataSource.
    data: PydanticXArrayDataSet


In [23]:
# set up Batch class

class Batch(BaseModel):
    """A batch of data from the different data sources.

    Only the satellite data source is wired in so far; the other sources
    are listed as commented-out placeholders below.
    """

    # ``ge=0`` enforces a non-negative batch size. The original keyword ``g=0``
    # is not a pydantic constraint, so no lower bound was actually checked.
    batch_size: int = Field(
        ...,
        ge=0,
        description="The size of this batch. If the batch size is 0, "
        "then this item stores one data item",
    )

    satellite: BatchSatellite
#     nwp ....
#     pv ...
#     gsp ...
#     metadata

    def from_netcdf(self, folder):
        """Placeholder for loading every data source from ``folder``; not implemented."""
        # loop through data_sources, and load netcdf
        pass

    def to_tensor(self):
        """Placeholder for converting every data source to tensors; not implemented."""
        # loop through data_sources, and change to tensors
        pass


In [24]:
# lets get some test satellite data

# Two fake satellite arrays. The original `sat_1 = data=create_image_array()`
# was a chained assignment that also bound a stray global named `data`.
sat_1 = create_image_array()
sat_2 = create_image_array()

# Stack the two arrays along a new 'example' dimension.
satellite = from_list_data_array_to_dataset([sat_1, sat_2])

satellite_batch = BatchSatellite(data=satellite)

# 'satellite_batch' can be then saved to a netcdf file
# satellite_batch.to_netcdf()


In [36]:
# we can then load the batch from all the different data sources
# batch = Batch.from_netcdf(path)

# in the data loader, torch.utils.data.Dataset can then return
batch = Batch(batch_size=2, satellite=satellite_batch)
batch.to_tensor()

print(batch.satellite.data.to_dict())
# return
# return batch.dict()

{'coords': {'time': {'dims': ('time',), 'attrs': {}, 'data': [0, 1, 2, 3]}, 'x': {'dims': ('x',), 'attrs': {}, 'data': [0, 1, 2, 3, 4, 5, 6, 7]}, 'y': {'dims': ('y',), 'attrs': {}, 'data': [0, 1, 2, 3, 4, 5, 6, 7]}, 'channels': {'dims': ('channels',), 'attrs': {}, 'data': [0, 1, 2, 3, 4]}, 'example': {'dims': ('example',), 'attrs': {}, 'data': [0, 1]}}, 'attrs': {}, 'dims': {'time': 4, 'x': 8, 'y': 8, 'channels': 5, 'example': 2}, 'data_vars': {'data': {'dims': ('example', 'time', 'x', 'y', 'channels'), 'attrs': {}, 'data': [[[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [

In [None]:
# we might want a function to split the Batch into a List[Example] and maybe a function that joins Examples to a Batch
class Example(BaseModel):
    """A single example holding one item per data source.

    Only the satellite field is defined so far. The remaining data sources
    are placeholders, commented out so that the cell is valid Python.
    """

    satellite: Satellite
#     nwp ....
#     pv ...
#     gsp ...
#     metadata