In [1]:
from typing import Union, Optional, List, Tuple, Dict, Any, Iterable, TypeVar, Type, NamedTuple, Sequence, Generic, _GenericAlias

from beartype import beartype
from edf_interface.data.base import DataAbstractBase

import torch

In [2]:
@beartype
class SmallData(DataAbstractBase):
    """Minimal data container: one tensor ``x`` plus a ``name`` label.

    ``x`` is declared as a data field and ``name`` as a metadata field through
    the ``*_args_hint`` class dictionaries. Presumably ``DataAbstractBase``
    reads these hints to decide how each attribute is displayed, moved between
    devices and exported (TODO confirm in ``edf_interface``).

    Args:
        x: Tensor payload; it also determines the container's ``device``.
        name: Free-form label stored as metadata.
    """

    # Attribute name -> expected type for the data (tensor-like) fields.
    data_args_hint: Dict[str, type] = dict(x=torch.Tensor)

    # Attribute name -> expected type for the metadata fields.
    metadata_args_hint: Dict[str, type] = dict(name=str)

    def __init__(self, x: torch.Tensor, name: str):
        # Let the base class set itself up before any field is attached.
        super().__init__()
        self.x = x
        self.name = name

    @property
    def device(self) -> torch.device:
        """Device on which the ``x`` tensor currently lives."""
        return self.x.device

@beartype
class TestData(DataAbstractBase):
    """Example container mixing tensor, scalar, dict and nested-container fields.

    Data fields (``data_args_hint``): ``x`` (tensor), ``y`` (int), ``z``
    (dict) and ``k`` (a nested ``SmallData``). Metadata fields
    (``metadata_args_hint``): ``a`` (list) and ``b`` (int). Constructor
    arguments are type-checked at call time via ``@beartype``.

    Args:
        x: Primary tensor; it also determines the container's ``device``.
        y: Integer data field.
        z: Dict data field.
        k: Nested ``SmallData`` container.
        a: List metadata field.
        b: Integer metadata field.
    """

    # Attribute name -> expected type for the data fields (insertion order kept).
    data_args_hint: Dict[str, type] = dict(
        x=torch.Tensor,
        y=int,
        z=dict,
        k=SmallData,
    )

    # Attribute name -> expected type for the metadata fields.
    metadata_args_hint: Dict[str, type] = dict(a=list, b=int)

    def __init__(self, x: torch.Tensor, y: int, z: Dict, k: SmallData, a: List, b: int):
        super().__init__()
        # Attach fields in the same order as before: data, metadata, nested container.
        self.x = x
        self.y = y
        self.z = z
        self.a = a
        self.b = b
        self.k = k

    @property
    def device(self) -> torch.device:
        """Device of the primary tensor ``x``.

        Only ``x`` is consulted; the nested ``k`` container is not inspected.
        """
        return self.x.device

In [3]:
# Seed torch's global RNG so the random tensors below are identical on every
# Restart & Run All (the unseeded version changed on each run).
torch.manual_seed(0)
data = TestData(
    x=torch.randn(5, 3),
    y=3,
    z={'a': 5},
    k=SmallData(x=torch.randn(3, 4), name='small'),
    a=[1., 2.],
    b=3,
)

In [4]:
# Prefer the GPU, but fall back to CPU so the notebook still runs end-to-end
# on machines without CUDA (a hard-coded 'cuda' raises there).
target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
data = data.to(target_device)

In [5]:
# Render the container's repr: its device, metadata fields and data fields
# (including the nested SmallData).
data

<TestData>  (device: cuda:0)
  Metadata: 
    - a: [1.0, 2.0]
    - b: 3
  Data: 
    - x: <Tensor> (Shape: torch.Size([5, 3]))
          tensor([[-2.0896e-01,  2.0279e+00,  1.9534e+00],
                  [-2.0516e-01, -1.6438e+00,  1.4453e+00],
                  [ 5.6487e-01,  1.5794e-03, -1.3912e+00],
                  [ 3.6207e-01,  3.1641e-01,  4.6279e-01],
                  [ 1.1365e+00, -1.5696e+00,  4.5848e-01]], device='cuda:0')
    - y: <int>
          3
    - z: <dict>
          {'a': 5}
    - k: <SmallData>
          - name: 'small'
          - x: <Tensor> (Shape: torch.Size([3, 4]))
            
          

In [6]:
# Export the container as a plain nested dict with tensors placed on CPU.
state_dict = data.get_data_dict(device='cpu')

# Inline sanity checks: the expected top-level keys (data fields plus a
# 'metadata' entry) and tensors, including the nested one, really on CPU.
assert set(state_dict) == {'x', 'y', 'z', 'k', 'metadata'}, set(state_dict)
assert state_dict['x'].device.type == 'cpu'
assert state_dict['k']['x'].device.type == 'cpu'

In [7]:
# Show the exported dict: the nested SmallData appears as a sub-dict, and each
# level carries its own 'metadata' entry.
state_dict

{'x': tensor([[-2.0896e-01,  2.0279e+00,  1.9534e+00],
         [-2.0516e-01, -1.6438e+00,  1.4453e+00],
         [ 5.6487e-01,  1.5794e-03, -1.3912e+00],
         [ 3.6207e-01,  3.1641e-01,  4.6279e-01],
         [ 1.1365e+00, -1.5696e+00,  4.5848e-01]]),
 'y': 3,
 'z': {'a': 5},
 'k': {'x': tensor([[ 1.3250,  1.2905,  1.6491, -0.5659],
          [-1.9802, -0.5141,  0.7187,  0.2536],
          [ 1.0510, -0.5383,  0.8729,  0.5419]]),
  'metadata': {'name': 'small'}},
 'metadata': {'a': [1.0, 2.0], 'b': 3}}