In [1]:
from typing import Union, Optional, List, Tuple, Dict, Any, Iterable, TypeVar, Type, NamedTuple, Sequence, Generic, _GenericAlias, get_origin, get_args
from typing_extensions import Self
from abc import abstractmethod

from beartype import beartype
from edf_interface.data.base import DataAbstractBase, _device, _dtype, _torch_tensor_to
from edf_interface.data.se3 import SE3
from edf_interface.data.pointcloud import PointCloud

import torch

In [2]:
inputs = [SE3(poses=torch.tensor([[1., 0., 0., 0., 0., 0., 1.]])).to('cuda'),
          SE3(poses=torch.tensor([[1., 0., 0., 0., 0., 0., 1.]])).to('cuda'),
          SE3(poses=torch.tensor([[1., 0., 0., 0., 0., 0., 1.]])).to('cuda'),
          SE3(poses=torch.tensor([[1., 0., 0., 0., 0., 0., 1.]])).to('cuda'),
          ]

In [7]:
@beartype
class Test():
    data_seq: Sequence[int]
    def __init__(self, data_seq: Sequence[int]):
        self.data_seq = data_seq

    def __len__(self) -> int:
        return len(self.data_seq)

    def __getitem__(self, idx) -> Union[Self, Any]:
        assert type(idx) == slice or type(idx) == int, "Indexing must be an integer or a slice with single axis."
        return self.data_seq[idx]
    
    def __getattr__(self, name: str):
        if name.startswith('data_'):
            index = name[5:]
            try:
                index = int(name[5:])
                return self[index]
            except ValueError:
                pass

        if hasattr(super(), '__getattr__'):
            return super().__getattr__(name=name)
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
        



In [9]:
test = Test([1,2,3,4])

In [12]:
test.data_3 = 6

In [13]:
test.data_3

6

In [14]:
test.data_seq

[1, 2, 3, 4]

In [6]:
@beartype
class DataList(DataAbstractBase):
    dataseq: List[Any]
    metadata_args: List[str]
    _data_name_prefix: str = 'data'
    
    @property
    def data_args_type(self) -> Dict[str, type]:
        outputs = {}
        for i, data in enumerate(self.dataseq):
            outputs[f"{self._data_name_prefix}{i}": type(data)]

    def __len__(self) -> int:
        return len(self.dataseq)

    def __getitem__(self, idx) -> Union[Self, Any]:
        assert type(idx) == slice or type(idx) == int, "Indexing must be an integer or a slice with single axis."
        if type(idx) == int:
            return self.dataseq[idx]
        else:
            return self.new(dataseq=self.dataseq[idx])
    
    def __getattr__(self, name: str):
        if name.startswith(self._data_name_prefix):
            index = name.lstrip(self._data_name_prefix)
            try:
                index = int(name.lstrip(self._data_name_prefix))
                return self[index]
            except ValueError:
                pass

        if hasattr(super(), '__getattr__'):
            return super().__getattr__(name=name)
        else:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
        
    def __setattr__(self, name: str, value: Any):
        if name.startswith(self._data_prefix):
            index = name.lstrip(self._data_prefix)
            try:
                index = int(name.lstrip(self._data_prefix))
            except ValueError:
                raise IndexError(f"Wrong data index: {name}")
            

            if index is None:
                pass
            elif index == self.__len__():
                self.dataseq.append()
            elif index < self.__len__():
                self.ddataseq[index] = 
        else:
            super().__setattr__(name, value)

    
    @property
    def is_empty(self) -> bool:
        if len(self) == 0:
            return True
        else:
            return False
        
    @classmethod
    def empty(cls, *args, **kwargs) -> Self:
        if 'sequence_type' in kwargs.keys():
            sequence_type = kwargs['sequence_type']
        else:
            sequence_type = cls.default_sequence_type
        return cls(sequence_type(), *args, **kwargs)

    @property
    def device(self) -> Optional[torch.device]:
        if self.is_empty:
            return None
        
        device = None
        for data in self.data_seq:
            if hasattr(data, 'device'):
                device = data.device
        
        return device

    def new(self, **kwargs) -> Self:
        """
        Returns a new object which is a shallow copy of original object, but with data and metadata that are specified as kwargs being replaced. 
        """
        for arg in (['data_seq'] + list(self.metadata_args)):
            if arg not in kwargs.keys():
                kwargs[arg] = getattr(self, arg)

        return self.__class__(**kwargs)
    
    def to(self, *args, **kwargs) -> Self:
        """
        similar to pytorch Tensor objects' .to() method
        """
        if self.is_empty:
            return self
        else:
            data_seq = self.sequence_type()
            for data in self.data_seq:
                if isinstance(data, DataAbstractBase):
                    data = data.to(*args, **kwargs)
                    data_seq.append(data)
                elif isinstance(data, torch.Tensor):
                    assert '__tensor' not in kwargs.keys(), f"Don't use __tensor as a keyward arguments. It is reserved."
                    data = _torch_tensor_to(data, *args, **kwargs)
                    data_seq.append(data)
                elif hasattr(data, 'to'):
                    raise NotImplementedError(f"'to()' is not implemented for data type {type(data)}")
                else:
                    data_seq.append(data)

            return self.new(data_seq=data_seq)
    
    def get_data_dict(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Returns recursive data dictionary.
        Similar to torch.nn.Module's .state_dict() method.
        """
        data_seq = self.sequence_type()
        for data in self.data_seq:
            if isinstance(data, DataAbstractBase):
                data = data.get_data_dict(*args, **kwargs)
                data_seq.append(data)
            elif isinstance(data, torch.Tensor):
                assert '__tensor' not in kwargs.keys(), f"Don't use __tensor as a keyward arguments. It's reserved."
                data = _torch_tensor_to(data, *args, **kwargs)
                data_seq.append(data)
            else:
                data_seq.append(data)
        
        data_dict = {
            'metadata': self.metadata,
            'data_seq': data_seq
        }
        
        return data_dict
    
    @classmethod
    def from_data_dict(cls, data_dict: Dict[str, Any], *args, **kwargs) -> Self:
        """
        Reconstruct data object from dictionary.
        """
        inputs: Dict[str, Any] = {}
        for arg, val in data_dict.items():
            if arg == 'metadata':
                assert isinstance(val, Dict), f"data_dict['metadata'] must be a dictionary but {type(val)} is provided."
                assert cls.__name__ == val['__type__'], f"Class type {cls.__name__} does not match with type annotated in metadata ({val['__type__']})"
            elif arg == 'data_seq':
                assert isinstance(val, Sequence)
                data_seq = type(val)()

                if issubclass(type_, DataAbstractBase):
                    assert isinstance(val, Dict), f"For arg of type {type(arg)}, data_dict[arg] must be a dictionary"
                    assert 'metadata' in val.keys(), f"For arg of type {type(arg)}, data_dict[arg] must be a dictionary, and has 'metadata' as a key"
                    assert type_.__name__ == val['metadata']['__type__'], f"{type_.__name__} != {val['metadata']['__type__']}"
                    val = type_.from_data_dict(data_dict=val, *args, **kwargs)
                else:
                    assert isinstance(val, type_), f"type({arg}) = {type(val)} != {type_}"
                    if isinstance(val, torch.Tensor):
                        val = _torch_tensor_to(__tensor=val, *args, **kwargs)
                inputs[arg] = val
            else:
                raise KeyError(f"Unknown attribute {arg} found in data_dict.")
        
        if 'metadata' in data_dict.keys():
            metadata = data_dict['metadata']
            assert isinstance(metadata, Dict)
            for arg in metadata.keys():
                assert arg not in inputs.keys(), f"metadata_arg {arg} already exists as a data argument!"
        else:
            metadata = {}
        
        input_kwargs = {}
        for k,v in {**inputs, **metadata}.items():
            if k=='__type__':  # __type__ is not required as an argument to the class constructor
                pass
            else:
                input_kwargs[k] = v
        return cls(**input_kwargs)
    
    def __repr__(self, abbrv: bool = False) -> str:
        if abbrv:
            prefix = ''
            bullet = '- '
        else:
            prefix = '  '
            bullet = prefix + '  - '            

        if abbrv:
            repr = ""
        else:
            repr = f"<{self.__class__.__name__}>  (device: {str(self.device)})\n"

        if not abbrv:
            repr += prefix + "Metadata: \n"
        # for arg in self.metadata_args:
        #     obj = getattr(self, arg)
        #     repr += bullet + f"{arg}: {obj.__repr__()}\n"
        for arg, obj in self.metadata.items():
            if arg == '__type__':
                pass
            else:
                repr += bullet + f"{arg}: {obj.__repr__()}\n"

        if not abbrv:
            repr += prefix + "Data: \n"
        for arg in self.data_args_type.keys():
            obj = getattr(self, arg)

            repr += bullet + f"{arg}: <{type(obj).__name__}>"
            if hasattr(obj, 'shape'):
                repr += ' (Shape: ' + obj.shape.__repr__() + ')\n'
                if abbrv:
                    subrepr = ''
                else:
                    subrepr: str = obj.__repr__()
            elif isinstance(obj, DataAbstractBase):
                repr += '\n'
                if abbrv:
                    subrepr = ''
                else:
                    subrepr: str = obj.__repr__(abbrv=True)
            else:
                repr += '\n'
                if abbrv:
                    subrepr = ''
                else:
                    subrepr: str = obj.__repr__()
            
            if abbrv:
                indent = ' ' * (len(bullet))
            else:
                indent = ' ' * (len(bullet) + 4)
            subrepr = subrepr.replace('\n', '\n' + indent)
            subrepr += '\n'
            repr += indent + subrepr

        return repr

IndentationError: expected an indented block (3307311504.py, line 56)

In [None]:
sdaf

In [None]:
# @beartype
# class DataSequenceAbstract(DataAbstractBase):
#     sequence_type: type = list
#     data_seq: Sequence[DataAbstractBase]

#     data_args_type: Dict[str, type] = {}
#     metadata_args: List[str]
    

#     def __len__(self) -> int:
#         return len(self.data_seq)
    
#     @property
#     def is_empty(self) -> bool:
#         if len(self) == 0:
#             return True
#         else:
#             return False
        
#     @classmethod
#     def empty(cls, *args, **kwargs) -> Self:
#         return cls(cls.sequence_type(), *args, **kwargs)

#     @property
#     def device(self) -> torch.device:
#         if self.is_empty:
#             raise AttributeError("The 'device' property is ambiguous for empty data seqeunce.")
        
#         return self.data_seq[0].device

#     def __init__(self):
#         assert not self.data_args_type, f"Don't use self.data_args_hint."
#         super().__init__()
#         # for data in data_seq:
#         #     assert type(data_type_list)
#         # if device is None:
#         #     self.data_seq = self.sequence_type(data for data in data_seq)
#         # else:
#         #     self.data_seq = self.sequence_type(data.to(device) for data in data_seq)

#     def new(self, **kwargs) -> Self:
#         """
#         Returns a new object which is a shallow copy of original object, but with data and metadata that are specified as kwargs being replaced. 
#         """
#         for arg in (['data_seq'] + list(self.metadata_args)):
#             if arg not in kwargs.keys():
#                 kwargs[arg] = getattr(self, arg)

#         return self.__class__(**kwargs)
    
#     def to(self, *args, **kwargs) -> Self:
#         if self.is_empty:
#             return self
#         else:
#             return self.__class__(data_seq=self.sequence_type(data.to(*args, **kwargs) for data in self.data_seq))
    
#     def get_data_dict(self, *args, **kwargs) -> Dict[str, Any]:
#         """
#         Returns recursive data dictionary.
#         Similar to torch.nn.Module's .state_dict() method.
#         """
#         data_dict = {}
#         for i, data in enumerate(self.data_seq):
#             if isinstance(data, DataAbstractBase):
#                 data = data.get_data_dict(*args, **kwargs)
#             elif isinstance(data, torch.Tensor):
#                 assert '__tensor' not in kwargs.keys(), f"Don't use __tensor as a keyward arguments. It's reserved."
#                 data = _torch_tensor_to(data, *args, **kwargs)
#             data_dict[data] = data
#         data_dict['metadata'] = self.metadata
        
#         return data_dict
    
#     @classmethod
#     def from_data_dict(cls, data_dict: Dict[str, Any], *args, **kwargs) -> Self:
#         inputs: Dict[str, Any] = {}
#         for arg, val in data_dict.items():
#             if arg == 'metadata':
#                 continue
#             else:
#                 assert arg in cls.data_args_hint.keys(), f"Unknown data argument: {arg}"
#                 hint = cls.data_args_hint[arg]
#                 if issubclass(hint, DataAbstractBase):
#                     val = hint.from_data_dict(data_dict=val, *args, **kwargs)
#                 else:
#                     assert isinstance(val, hint), f"type({arg}) = {type(val)} != {hint}"
#                     if isinstance(val, torch.Tensor):
#                         val = _torch_tensor_to(__tensor=val, *args, **kwargs)
#                 inputs[arg] = val
        
#         if 'metadata' in data_dict.keys():
#             metadata = data_dict['metadata']
#             assert isinstance(metadata, Dict)
#             for arg, val in metadata.items():
#                 assert arg not in inputs.keys(), f"metadata_arg {arg} already exists as a data argument!"
#         else:
#             metadata = {}
        
#         inputs = {**inputs, **metadata}
        
#         return cls(**inputs)
    
#     def __repr__(self, abbrv: bool = False) -> str:
#         if abbrv:
#             prefix = ''
#             bullet = '- '
#         else:
#             prefix = '  '
#             bullet = prefix + '  - '            

#         if abbrv:
#             repr = ""
#         else:
#             repr = f"<{self.__class__.__name__}>  (device: {str(self.device)})\n"

#         if not abbrv:
#             repr += prefix + "Metadata: \n"
#         for arg in self.metadata_args_hint.keys():
#             obj = getattr(self, arg)
#             repr += bullet + f"{arg}: {obj.__repr__()}\n"

#         if not abbrv:
#             repr += prefix + "Data: \n"
#         for arg in self.data_args_hint.keys():
#             obj = getattr(self, arg)

#             repr += bullet + f"{arg}: <{type(obj).__name__}>"
#             if hasattr(obj, 'shape'):
#                 repr += ' (Shape: ' + obj.shape.__repr__() + ')\n'
#                 if abbrv:
#                     subrepr = ''
#                 else:
#                     subrepr: str = obj.__repr__()
#             elif isinstance(obj, DataAbstractBase):
#                 repr += '\n'
#                 if abbrv:
#                     subrepr = ''
#                 else:
#                     subrepr: str = obj.__repr__(abbrv=True)
#             else:
#                 repr += '\n'
#                 if abbrv:
#                     subrepr = ''
#                 else:
#                     subrepr: str = obj.__repr__()
            
#             if abbrv:
#                 indent = ' ' * (len(bullet))
#             else:
#                 indent = ' ' * (len(bullet) + 4)
#             subrepr = subrepr.replace('\n', '\n' + indent)
#             subrepr += '\n'
#             repr += indent + subrepr

#         return repr
    
#     def __str__(self) -> str:
#         return self.__repr__()

In [None]:
@beartype
class DataSequenceAbstract(DataAbstractBase):
    data_list: List[DataAbstractBase]
    data_type_list: List[type]
    metadata_args_hint: Dict[str, type]

    def __len__(self) -> int:
        return len(self.data_list)
    
    @property
    def is_empty(self) -> bool:
        if len(self) == 0:
            return True
        else:
            return False
        
    @classmethod
    def empty(cls, *args, **kwargs) -> Self:
        return cls([], *args, **kwargs)

    @property
    def device(self) -> torch.device:
        if self.is_empty:
            raise AttributeError("The 'device' property is ambiguous for empty DataList.")
        
        return self.data_list[0].device

    def __init__(self, data_list: Iterable[DataAbstractBase], device: Optional[Union[str, torch.device]] = None):
        super().__init__()
        if device is None:
            self.data_list = list(data for data in data_list)
        else:
            self.data_list = list(data.to(device) for data in data_list)

    def new(self, **kwargs) -> Self:
        raise NotImplementedError

    def _torch_tensor_to(self, device: Optional[_device]=None, 
                         dtype: Optional[_dtype]=None, 
                         non_blocking: bool = False, 
                         copy: bool = False, 
                         *args, **kwargs) -> Dict[str, torch.Tensor]:
        raise NotImplementedError
    
    def _data_to(self, *args, **kwargs) -> Dict[str, DataAbstractBase]:
        raise NotImplementedError
    
    def to(self, *args, **kwargs) -> Self:
        if self.is_empty:
            return self
        else:
            return self.__class__(data_list=list(data.to(*args, **kwargs) for data in self.data_list))
    
    def get_data_dict(self, *args, **kwargs) -> Dict[str, Any]:
        data_dict = {}
        for arg in self.data_args_hint.keys():
            obj = getattr(self, arg)
            if isinstance(obj, DataAbstractBase):
                obj = obj.get_data_dict(*args, **kwargs)
            elif isinstance(obj, torch.Tensor):
                assert '__tensor' not in kwargs.keys(), f"Don't use __tensor as a keyward arguments. It is reserved."
                obj = _torch_tensor_to(__tensor = obj, *args, **kwargs)
            data_dict[arg] = obj
        data_dict['metadata'] = self.metadata
        
        return data_dict
    
    @classmethod
    def from_data_dict(cls, data_dict: Dict[str, Any], *args, **kwargs) -> Self:
        inputs: Dict[str, Any] = {}
        for arg, val in data_dict.items():
            if arg == 'metadata':
                continue
            else:
                assert arg in cls.data_args_hint.keys(), f"Unknown data argument: {arg}"
                hint = cls.data_args_hint[arg]
                if issubclass(hint, DataAbstractBase):
                    val = hint.from_data_dict(data_dict=val, *args, **kwargs)
                else:
                    assert isinstance(val, hint), f"type({arg}) = {type(val)} != {hint}"
                    if isinstance(val, torch.Tensor):
                        val = _torch_tensor_to(__tensor=val, *args, **kwargs)
                inputs[arg] = val
        
        if 'metadata' in data_dict.keys():
            metadata = data_dict['metadata']
            assert isinstance(metadata, Dict)
            for arg, val in metadata.items():
                assert arg not in inputs.keys(), f"metadata_arg {arg} already exists as a data argument!"
        else:
            metadata = {}
        
        inputs = {**inputs, **metadata}
        
        return cls(**inputs)
    
    def __repr__(self, abbrv: bool = False) -> str:
        if abbrv:
            prefix = ''
            bullet = '- '
        else:
            prefix = '  '
            bullet = prefix + '  - '            

        if abbrv:
            repr = ""
        else:
            repr = f"<{self.__class__.__name__}>  (device: {str(self.device)})\n"

        if not abbrv:
            repr += prefix + "Metadata: \n"
        for arg in self.metadata_args_hint.keys():
            obj = getattr(self, arg)
            repr += bullet + f"{arg}: {obj.__repr__()}\n"

        if not abbrv:
            repr += prefix + "Data: \n"
        for arg in self.data_args_hint.keys():
            obj = getattr(self, arg)

            repr += bullet + f"{arg}: <{type(obj).__name__}>"
            if hasattr(obj, 'shape'):
                repr += ' (Shape: ' + obj.shape.__repr__() + ')\n'
                if abbrv:
                    subrepr = ''
                else:
                    subrepr: str = obj.__repr__()
            elif isinstance(obj, DataAbstractBase):
                repr += '\n'
                if abbrv:
                    subrepr = ''
                else:
                    subrepr: str = obj.__repr__(abbrv=True)
            else:
                repr += '\n'
                if abbrv:
                    subrepr = ''
                else:
                    subrepr: str = obj.__repr__()
            
            if abbrv:
                indent = ' ' * (len(bullet))
            else:
                indent = ' ' * (len(bullet) + 4)
            subrepr = subrepr.replace('\n', '\n' + indent)
            subrepr += '\n'
            repr += indent + subrepr

        return repr
    
    def __str__(self) -> str:
        return self.__repr__()