In [1]:
from typing import Union, Optional, List, Tuple, Dict, Any, Iterable, TypeVar, Type, NamedTuple, Sequence, Generic, _GenericAlias

from beartype import beartype
from edf_interface.data.base import DataAbstractBase

import torch

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
@beartype
class SmallData(DataAbstractBase):
    """Small nested data container: one tensor field ``j`` plus a ``name``.

    ``j`` is declared as a data field in ``data_args_type`` and ``name`` is
    listed in ``metadata_args``. ``DataAbstractBase`` presumably reads these
    two class attributes for repr / serialization (in this notebook,
    ``get_data_dict`` emits ``j`` as data and ``name`` under ``'metadata'``).
    """

    # Data field name -> required type.
    data_args_type: Dict[str, type] = dict(j=torch.Tensor)

    # Attribute names treated as metadata rather than data.
    metadata_args: List[str] = ['name']

    def __init__(self, j: torch.Tensor, name: str):
        self.j = j
        self.name = name

    @property
    def device(self) -> torch.device:
        """Device on which the ``j`` tensor lives."""
        return self.j.device

@beartype
class TestData(DataAbstractBase):
    """Example container mixing several kinds of fields.

    Data fields (``data_args_type``):
        x: tensor; also determines ``device``.
        y: int.
        z: dict.
        k: nested ``SmallData`` object.
    Metadata (``metadata_args``):
        a: list.
        b: int.
    """

    # Data field name -> required type. In this notebook, ``get_data_dict``
    # serializes the nested SmallData field ``k`` as a sub-dict.
    data_args_type: Dict[str, type] = dict(
        x=torch.Tensor,
        y=int,
        z=dict,
        k=SmallData,
    )

    # Attribute names treated as metadata rather than data.
    metadata_args: List[str] = ['a', 'b']

    def __init__(self, x: torch.Tensor, y: int, z: Dict, k: SmallData, a: List, b: int):
        # Plain data fields.
        self.x = x
        self.y = y
        self.z = z
        # Metadata fields.
        self.a = a
        self.b = b
        # Nested data object.
        self.k = k

    @property
    def device(self) -> torch.device:
        """Device of ``x``; the other fields are not consulted."""
        return self.x.device

In [3]:
# Seed torch's global RNG so the random tensors below are identical on every
# Restart & Run All.
torch.manual_seed(0)
data = TestData(
    x=torch.randn(5, 3),
    y=3,
    z={'a': 5},
    k=SmallData(j=torch.randn(3, 4), name='small'),
    a=[1., 2.],
    b=3,
)
data

<TestData>  (device: cpu)
  Metadata: 
    - a: [1.0, 2.0]
    - b: 3
  Data: 
    - x: <Tensor> (Shape: torch.Size([5, 3]))
          tensor([[ 2.1426,  0.8252,  0.3087],
                  [ 0.0946, -0.8636,  0.9128],
                  [-0.1423, -0.2763,  1.2188],
                  [ 0.3210, -0.8009, -0.4178],
                  [ 1.1104,  0.2976, -0.1872]])
    - y: <int>
          3
    - z: <dict>
          {'a': 5}
    - k: <SmallData>
          - name: 'small'
          - j: <Tensor> (Shape: torch.Size([3, 4]))
            
          

In [4]:
# Move `data` to the GPU when one is available. Fall back to CPU so the
# notebook still runs end-to-end on CPU-only machines.
target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
data = data.to(target_device)
data

<TestData>  (device: cuda:0)
  Metadata: 
    - a: [1.0, 2.0]
    - b: 3
  Data: 
    - x: <Tensor> (Shape: torch.Size([5, 3]))
          tensor([[ 2.1426,  0.8252,  0.3087],
                  [ 0.0946, -0.8636,  0.9128],
                  [-0.1423, -0.2763,  1.2188],
                  [ 0.3210, -0.8009, -0.4178],
                  [ 1.1104,  0.2976, -0.1872]], device='cuda:0')
    - y: <int>
          3
    - z: <dict>
          {'a': 5}
    - k: <SmallData>
          - name: 'small'
          - j: <Tensor> (Shape: torch.Size([3, 4]))
            
          

In [6]:
# Serialize `data` into a plain nested dict with tensors placed on CPU.
# As the output shows, the nested SmallData field becomes a sub-dict, and each
# level carries a 'metadata' entry with '__type__' plus that class's metadata_args.
state_dict = data.get_data_dict(device='cpu')
state_dict

{'x': tensor([[ 2.1426,  0.8252,  0.3087],
         [ 0.0946, -0.8636,  0.9128],
         [-0.1423, -0.2763,  1.2188],
         [ 0.3210, -0.8009, -0.4178],
         [ 1.1104,  0.2976, -0.1872]]),
 'y': 3,
 'z': {'a': 5},
 'k': {'j': tensor([[ 1.8516,  0.7965,  0.2105, -0.5913],
          [-0.6595,  0.7075, -1.9977, -0.1598],
          [-0.3256,  0.5113, -0.9532,  0.2929]]),
  'metadata': {'__type__': 'SmallData', 'name': 'small'}},
 'metadata': {'__type__': 'TestData', 'a': [1.0, 2.0], 'b': 3}}

In [7]:
# Rebuild a TestData from the serialized dict. Tensors go on the GPU when one is
# available and on CPU otherwise, so this cell does not crash on CPU-only machines.
TestData.from_data_dict(state_dict, device='cuda' if torch.cuda.is_available() else 'cpu')

<TestData>  (device: cuda:0)
  Metadata: 
    - a: [1.0, 2.0]
    - b: 3
  Data: 
    - x: <Tensor> (Shape: torch.Size([5, 3]))
          tensor([[ 2.1426,  0.8252,  0.3087],
                  [ 0.0946, -0.8636,  0.9128],
                  [-0.1423, -0.2763,  1.2188],
                  [ 0.3210, -0.8009, -0.4178],
                  [ 1.1104,  0.2976, -0.1872]], device='cuda:0')
    - y: <int>
          3
    - z: <dict>
          {'a': 5}
    - k: <SmallData>
          - name: 'small'
          - j: <Tensor> (Shape: torch.Size([3, 4]))
            
          