In [1]:
from typing import Union, Optional, List, Tuple, Dict, Any, Iterable, TypeVar, Type, NamedTuple, Sequence, Generic, _GenericAlias

from beartype import beartype
from edf_interface.data.base import DataAbstractBase

import torch

In [2]:
@beartype
class SmallData(DataAbstractBase):
    """Minimal ``DataAbstractBase`` subclass holding one tensor and one name.

    ``x`` is declared as a data field and ``name`` as a metadata field through
    the two ``*_args_hint`` properties.
    """

    def __init__(self, x: torch.Tensor, name: str):
        """Store the tensor ``x`` and the string ``name`` on the instance."""
        super().__init__()
        self.x: torch.Tensor = x
        self.name: str = name

    @property
    def data_args_hint(self) -> Dict[str, type]:
        """Return the mapping from data attribute names to their types."""
        return {'x': torch.Tensor}

    @property
    def metadata_args_hint(self) -> Dict[str, type]:
        """Return the mapping from metadata attribute names to their types."""
        return {'name': str}

@beartype
class TestData(DataAbstractBase):
    """``DataAbstractBase`` subclass mixing several field types, including a nested ``SmallData``.

    Data fields (``data_args_hint``): ``x``, ``y``, ``z``, ``k``.
    Metadata fields (``metadata_args_hint``): ``a``, ``b``.
    """

    def __init__(self, x: torch.Tensor, y: int, z: Dict, k: SmallData, a: List, b: int):
        """Store every constructor argument on the instance under its own name."""
        super().__init__()
        # Tensor / scalar / mapping payload.
        self.x: torch.Tensor = x
        self.y: int = y
        self.z: Dict = z
        # Metadata values.
        self.a: List = a
        self.b: int = b
        # Nested data object.
        self.k: SmallData = k

    @property
    def data_args_hint(self) -> Dict[str, type]:
        """Return the mapping from data attribute names to their types."""
        return {
            'x': torch.Tensor,
            'y': int,
            'z': dict,
            'k': SmallData,
        }

    @property
    def metadata_args_hint(self) -> Dict[str, type]:
        """Return the mapping from metadata attribute names to their types."""
        return {
            'a': list,
            'b': int,
        }

In [3]:
# Build one TestData sample; the `k` argument nests a SmallData instance.
data = TestData(
    x=torch.randn(5, 3),
    y=3,
    z={'a': 5},
    k=SmallData(x=torch.randn(3, 4), name='small'),
    a=[1., 2.],
    b=3,
)

In [4]:
# Move the sample to the GPU when CUDA is available; otherwise fall back to
# the CPU so the notebook still runs top-to-bottom on machines without a GPU.
# NOTE(review): assumes DataAbstractBase.to accepts 'cpu' as well as 'cuda' — confirm.
data = data.to('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Bare expression as the last line of the cell: the notebook renders its repr as the cell output.
data

<TestData>
  Metadata: 
    - a: [1.0, 2.0]
    - b: 3
  Data: 
    - x: <Tensor> (Shape: torch.Size([5, 3]))
          tensor([[ 0.7419,  1.1894,  2.1646],
                  [-0.0760, -0.4688, -0.0328],
                  [ 0.7164,  1.4404,  0.6246],
                  [ 2.1232,  2.0259, -0.6372],
                  [ 0.2361, -1.6544,  0.1931]], device='cuda:0')
    - y: <int>
          3
    - z: <dict>
          {'a': 5}
    - k: <SmallData>
          - name: 'small'
          - x: <Tensor> (Shape: torch.Size([3, 4]))
            
          