In [1]:
import torch

In [2]:
# This lives within core; the end user never has to look at this.
class WrapperTensor(torch.Tensor):
    """Base class for wrapper Tensor subclasses.

    Subclasses implement `get_wrapper_properties`, which supplies an example
    Tensor plus a dict of overrides. `__new__` turns those into the arguments
    of `torch.Tensor._make_wrapper_subclass` and validates the new instance.
    """

    @staticmethod
    def __new__(cls, *args, **kwargs):
        """Create the wrapper from the subclass' `get_wrapper_properties`.

        All positional and keyword arguments are forwarded untouched to
        `cls.get_wrapper_properties`. Any of "size", "dtype", "layout",
        "device" and "requires_grad" missing from the returned dict falls
        back to the example Tensor's value (`requires_grad` falls back to
        False).

        Raises:
            NotImplementedError: if the subclass does not implement
                `get_wrapper_properties`.
            RuntimeError: if the subclass overrides a forbidden attribute
                (see `_validate_methods`).
        """
        t, overrides = cls.get_wrapper_properties(*args, **kwargs)
        # Work on a copy: the subclass may hand back a shared/cached dict, and
        # removing "size" or filling defaults in place would leak state from
        # one construction into the next.
        props = dict(overrides)
        size = props.pop("size") if "size" in props else t.size()
        props.setdefault("dtype", t.dtype)
        props.setdefault("layout", t.layout)
        props.setdefault("device", t.device)
        props.setdefault("requires_grad", False)
        # Ignore memory_format and pin memory for now as I don't know how to
        # safely access them on a Tensor (if possible??)

        wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **props)
        wrapper._validate_methods()
        return wrapper

    @classmethod
    def get_wrapper_properties(cls, *args, **kwargs):
        """Return `(example_tensor, overrides)` describing the wrapper.

        Must return both an example Tensor and a dictionary of kwargs that
        override any of that example Tensor's properties.
        This is very similar to the `t.new_*(args)` API.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("You need to implement get_wrapper_properties")

    def _validate_methods(self):
        """Reject subclasses that override core Tensor metadata accessors.

        Raises:
            RuntimeError: if the class attribute for any forbidden name is not
                the very object found on `torch.Tensor`.
        """
        # Skip this if not in debug mode?
        # Changing these on the python side is wrong as it would not be properly reflected
        # on the c++ side
        # This doesn't catch attributes set in the __init__
        forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
        for el in forbidden_overrides:
            if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
                raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the "
                                   f"property {el} but this is not allowed as such change would "
                                   "not be reflected to c++ callers.")

    def __repr__(self):
        """Show the class name and the instance `__dict__` (the wrapped payload)."""
        return f"{self.__class__.__name__}({self.__dict__})"


In [3]:
from torch.utils._pytree import tree_map

# The concept of a wrapper Tensor is that there is a Tensor object without storage
# that represents what your Tensor should be, and you can store any other
# object in it.
# For DiagTensor, the wrapper will be 2D while the stored Tensor is 1D.
class DiagTensor(WrapperTensor):
    """Wrapper Tensor that keeps only a diagonal in `_diag`.

    The wrapper's advertised size is `diag.size() + diag.size()`, so a 1D
    `diag` of length n yields an (n, n) wrapper.
    """

    @classmethod
    def get_wrapper_properties(cls, diag):
        """Use `diag` as the example Tensor and override only the size."""
        return diag, {"size": diag.size() + diag.size()}

    def __init__(self, diag):
        """Store the diagonal values on the instance."""
        self._diag = diag

    @property
    def data(self):
        """The stored diagonal Tensor."""
        return self._diag

    @data.setter
    def data(self, data):
        # Only rebinds the stored Tensor; the wrapper's own size/dtype/device
        # were fixed at construction and are not updated here.
        self._diag = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Run `func` on dense inputs and re-wrap every Tensor output.

        Each DiagTensor found in `args`/`kwargs` is expanded with
        `torch.diag(e._diag)` before the call, and each Tensor in the result
        is passed through `torch.diag` and wrapped in a new DiagTensor.
        """
        # The signature defaults kwargs to None, which cannot be **-expanded.
        if kwargs is None:
            kwargs = {}

        def unwrap(e):
            return torch.diag(e._diag) if isinstance(e, DiagTensor) else e

        def wrap(e):
            # For a 2D result torch.diag keeps the main diagonal only, so any
            # off-diagonal values produced by `func` are dropped here.
            return DiagTensor(torch.diag(e)) if isinstance(e, torch.Tensor) else e

        dense_args = tree_map(unwrap, args)
        dense_kwargs = tree_map(unwrap, kwargs)
        return tree_map(wrap, func(*dense_args, **dense_kwargs))

In [4]:
# Build a wrapper around two random diagonal values.
t = DiagTensor(diag=torch.rand(2))

In [5]:
# Push an add and a mul through DiagTensor.__torch_dispatch__.
print("Doing add and mul")
out = (t + 2) * t
print(out)

# Compare the size reported by the wrapper with the size of what it stores.
print("Wrapper Tensor size", out.size(), sep="\n")
print("Contained Tensor size", out._diag.size(), sep="\n")

Doing add and mul
DiagTensor({'_diag': tensor([1.4086, 2.7097])})
Wrapper Tensor size
torch.Size([2, 2])
Contained Tensor size
torch.Size([2])


In [9]:
# Swap the stored diagonal for an int8 Tensor through the `data` setter.
t.data = torch.tensor([1.0, 2.0], dtype=torch.int8)

In [13]:
# Display the wrapper's repr, which prints the instance __dict__ (the stored `_diag`).
t

DiagTensor({'_diag': tensor([1, 2], dtype=torch.int8)})