The idea is to test out the workflow proposed in the following issue comment:

https://github.com/openclimatefix/nowcasting_dataset/issues/213#issuecomment-942080996

In [39]:
# imports
from __future__ import annotations
from typing import Optional
from typing import Union, List

import time
import numpy as np
import pandas as pd
import torch
import xarray as xr
from pydantic import BaseModel, Field

In [40]:
# pydantic dataset

class PydanticXArrayDataSet(xr.Dataset):
    """An ``xr.Dataset`` subclass that can be used as a pydantic field type.

    Adapted from
    https://pydantic-docs.helpmanual.io/usage/types/#classes-with-__get_validators__
    """

    # xarray expects subclasses of xr.Dataset to declare __slots__ explicitly.
    __slots__ = []

    # TODO add validation of the dataset contents (expected variables, dims, ...)

    @classmethod
    def __get_validators__(cls):
        """Yield the validators pydantic runs for fields declared with this type."""
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Check that ``v`` is an ``xr.Dataset`` and return it unchanged.

        Any ``xr.Dataset`` is accepted, not only instances of ``cls``, so a plain
        dataset (for example one loaded from a netcdf file) still passes.

        Raises:
            TypeError: if ``v`` is not an ``xr.Dataset``. pydantic reports this as a
                validation error on the field.
        """
        if not isinstance(v, xr.Dataset):
            raise TypeError(f"expected an xarray.Dataset, got {type(v).__name__}")
        return v

In [41]:
# xr array and xr dataset --> to torch functions
if not hasattr(xr.DataArray, "torch"):
    # Guarded so that re-running this cell does not register the accessor a second time.
    # NOTE: because of the guard, edits to the class only take effect after a kernel restart.

    @xr.register_dataarray_accessor("torch")
    class TorchAccessor:
        """``DataArray.torch`` accessor: convert a DataArray into torch tensors."""

        def __init__(self, xarray_obj):
            # xarray passes in the DataArray the accessor is attached to
            self._obj = xarray_obj

        def to_tensor(self):
            """Convert this DataArray to a torch.Tensor"""
            # torch is imported once at the top of the notebook
            return torch.tensor(self._obj.data)

        def to_named_tensor(self):
            """Convert this DataArray to a torch.Tensor with named dimensions"""
            # the DataArray's dimension names become the tensor's dimension names
            return torch.tensor(self._obj.data, names=self._obj.dims)

if not hasattr(xr.Dataset, "torch"):
    # Guarded so that re-running this cell does not register the accessor a second time.
    # NOTE: because of the guard, edits to the class only take effect after a kernel restart.

    @xr.register_dataset_accessor("torch")
    class TorchAccessor:
        """``Dataset.torch`` accessor: convert selected variables into torch tensors."""

        def __init__(self, xdataset_obj: xr.Dataset):
            # xarray passes in the Dataset the accessor is attached to
            self._obj = xdataset_obj

        def to_tensor(self, dims: List[str]) -> dict:
            """Convert this Dataset to dictionary of torch tensors

            Args:
                dims: names of the variables/coordinates of the Dataset to convert.

            Returns:
                A dict mapping each requested name to a ``torch.Tensor``. The entry
                named "time" is cast to int64 before conversion.
            """

            torch_dict = {}

            for dim in dims:
                # Item access instead of getattr: it can never resolve to a Dataset
                # attribute or method that happens to share the requested name.
                v = self._obj[dim]
                if dim == "time":
                    # torch has no datetime dtype, so time is converted to integers.
                    # int64 is required: datetime64[ns] values (nanoseconds since the
                    # epoch) do not fit in int32 and would silently wrap around.
                    v = v.astype(np.int64)

                torch_dict[dim] = v.torch.to_tensor()

            return torch_dict

In [42]:
# useful functions

def from_list_data_array_to_batch_dataset(image_data_arrays: List[xr.DataArray]) -> xr.Dataset:
    """Join a list of DataArrays into one Dataset along a new "example" dimension.

    Each DataArray is first converted with ``convert_data_array_to_dataset`` and then
    given a length-one "example" dimension whose coordinate is the DataArray's position
    in the input list. The resulting Datasets are concatenated along "example".
    """
    examples = []
    for example_number, data_array in enumerate(image_data_arrays):
        dataset = convert_data_array_to_dataset(data_array)
        dataset = dataset.expand_dims(dim="example")
        examples.append(dataset.assign_coords(example=("example", [example_number])))

    return xr.concat(examples, dim="example")


def convert_data_array_to_dataset(data_xarray: xr.DataArray) -> xr.Dataset:
    """Wrap a DataArray in a Dataset and re-index every dimension by position.

    The input is stored under the variable name "data". For each dimension ``dim`` of
    the input, the dimension is renamed to ``f"{dim}_index"`` with the coordinate
    ``0 .. n-1``, and the original coordinate values are kept as a variable named
    ``dim`` along that new index dimension.
    """

    dims = data_xarray.dims
    data = xr.Dataset({"data": data_xarray})

    for dim in dims:
        # keep hold of the original coordinate before it is overwritten
        coord = data[dim]
        # replace the coordinate values with a positional index
        data[dim] = np.arange(len(coord))

        # rename the dimension (and its index coordinate) to "<dim>_index"
        data = data.rename({dim: f"{dim}_index"})

        # store the original coordinate values under the original name, along "<dim>_index"
        data[dim] = xr.DataArray(coord, coords=data[f"{dim}_index"].coords, dims=[f"{dim}_index"])

    return data


In [43]:
# fake image xr array function
def create_image_array(dims=("time", "x", "y", "channels")) -> xr.DataArray:
    """Create a fake image DataArray for testing.

    Args:
        dims: names of the dimensions to include, in order. Each name must be one of
            "time", "x", "y" or "channels".

    Returns:
        A DataArray filled with zeros. "time" has 4 steps, 5 minutes apart, starting at
        2021-01-01; "x" and "y" each have 8 random integer coordinates in [0, 1000);
        "channels" is 0..4.

    Raises:
        KeyError: if ``dims`` contains a name that is not listed above.
    """
    all_coords = {
        # "5min" is the non-deprecated spelling of the "5T" offset alias
        "time": pd.date_range("2021-01-01", freq="5min", periods=4),
        "x": np.random.randint(low=0, high=1000, size=8),
        "y": np.random.randint(low=0, high=1000, size=8),
        "channels": np.arange(5),
    }
    coords = [(dim, all_coords[dim]) for dim in dims]
    # the scalar 0 is broadcast to the shape implied by the coordinates
    image_data_array = xr.DataArray(0, coords=coords)  # Fake data for testing!
    return image_data_array

In [44]:
# define satellite models. 
# 1. xr.Dataset
# 2. Pydantic model

class Satellite(PydanticXArrayDataSet):
    """Satellite data held as an xr.Dataset (through ``PydanticXArrayDataSet``)."""

    # Use to store xr.Dataset data
    # __slots__ is declared explicitly, as in the parent class
    __slots__ = []

    # TODO add validation here
    
class SatelliteML(BaseModel):
    """Satellite data as torch tensors, ready for machine learning."""

    # Use to store data ready for ml
    # One tensor per entry; in this notebook they are produced by
    # ``Dataset.torch.to_tensor(['data','time','x','y'])``.
    data: torch.Tensor
    time: torch.Tensor
    x: torch.Tensor
    y: torch.Tensor

    class Config:
        # torch.Tensor is not a type pydantic knows how to validate itself,
        # so arbitrary types have to be allowed explicitly
        arbitrary_types_allowed = True


In [45]:
# set up batch models

class Batch(BaseModel):
    """A batch of xr.Datasets."""

    satellite: Optional[Satellite]
    # nwp
    # metadata

    def to_tensor(self):
        """Placeholder for converting every data source to tensors.

        Not implemented yet: currently does nothing and returns None.
        """
        # loop through data_sources, and change to tensors
        pass

    def save_netcdf(self, path: str = "test.nc"):
        """Save the satellite dataset to a netcdf file.

        Args:
            path: file to write to. Defaults to "test.nc" in the working directory.

        Raises:
            ValueError: if this batch holds no satellite data.
        """
        if self.satellite is None:
            # fail with a clear message instead of an AttributeError on None
            raise ValueError("Batch.satellite is None; there is nothing to save")
        self.satellite.to_netcdf(path)

    @staticmethod
    def load_netcdf(path: str = "test.nc"):
        """Load a batch from a netcdf file written by ``save_netcdf``.

        Args:
            path: file to read from. Defaults to "test.nc" in the working directory.

        Returns:
            A ``Batch`` whose ``satellite`` is the dataset loaded from ``path``.
        """
        return Batch(satellite=xr.load_dataset(path))
        

    
class BatchML(BaseModel):
    """A batch of machine learning training examples."""

    # tensor version of the satellite data; may be None
    satellite: Optional[SatelliteML]
    # presumably placeholders for further data sources that are not implemented yet:
    # nwp
    # metadata



In [46]:
# Workflow 1: create the data and save it to a netcdf file
# two fake satellite images (all zeros, random x/y coordinates)
sat_1 = create_image_array()
sat_2 = create_image_array()

# join the two images into one dataset along the "example" dimension
satellite_batch = Satellite(from_list_data_array_to_batch_dataset([sat_1, sat_2]))

batch = Batch(satellite=satellite_batch)

# writes the batch to 'test.nc' in the working directory
batch.save_netcdf()

In [47]:
# Workflow 2: load the data back from the netcdf file and convert it to torch tensors
batch = Batch.load_netcdf()

# change to torch
# first a dict of tensors keyed by variable name ...
satellite_batch_ml = batch.satellite.torch.to_tensor(['data','time','x','y'])
# ... then the pydantic model built from that dict (the name is re-used)
satellite_batch_ml = SatelliteML(**satellite_batch_ml)


batch_ml = BatchML(satellite=satellite_batch_ml)

In [50]:
# create dataset, 

class FakeDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves the batch saved on disk, converted to tensors.

    Every item is the same batch: the index is ignored and the netcdf file is re-read
    on each access.
    """

    def __init__(self, length: int = 10):
        # number of items the dataset reports via __len__
        self.length = length

    def __len__(self):
        """Number of pieces of data"""
        return self.length

    def per_worker_init(self, worker_id: int):
        """Not needed"""

    def __getitem__(self, idx):
        """Load the saved batch and return it as a nested dict of torch tensors."""
        loaded_batch = Batch.load_netcdf()

        # xr.Dataset -> dict of tensors -> pydantic models -> plain dict
        satellite_tensors = loaded_batch.satellite.torch.to_tensor(["data", "time", "x", "y"])
        satellite_ml = SatelliteML(**satellite_tensors)

        return BatchML(satellite=satellite_ml).dict()

In [58]:
# run dataloader: time how long it takes to fetch each batch and rebuild the BatchML model

train = torch.utils.data.DataLoader(FakeDataset())
i = iter(train)

# one iteration per batch the loader provides, so next() cannot run past the end
for _ in range(len(train)):
    # perf_counter is a monotonic, high-resolution clock meant for measuring intervals
    t = time.perf_counter()
    x = next(i)
    x = BatchML(**x)

    print(time.perf_counter() - t)

# IT WORKS
assert isinstance(x.satellite.data, torch.Tensor)

0.01802206039428711
0.014360666275024414
0.012000083923339844
0.012310028076171875
0.011492729187011719
0.01065683364868164
0.012231826782226562
0.011026144027709961
0.010531187057495117
0.011240243911743164
