## Example of __torch_function__ on a tensor object

The `__torch_function__` will be invoked if it is found on one of the tensors as
part of the arguments into a `torch.` function. It's a layer above the PyTorch
dispatcher. It lets you override PyTorch Python APIs.

There's also `__torch_dispatch__` which corresponds to a specific C++ dispatch
key and lets you provide a backend implementation from Python.

In [1]:
import torch

class ScalarTensor(object):
   def __init__(self, N, value):
       self._N = N
       self._value = value

   def __repr__(self):
       return "ScalarTensor(N={}, value={})".format(self._N, self._value)

   def tensor(self):
       return self._value * torch.eye(self._N)

In [2]:
d = ScalarTensor(5, 2)

In [3]:
d

ScalarTensor(N=5, value=2)

In [4]:
d.tensor()

tensor([[2., 0., 0., 0., 0.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 0., 0.],
        [0., 0., 0., 2., 0.],
        [0., 0., 0., 0., 2.]])

In [5]:
torch.mean(d)  # type: ignore

TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor

In [6]:
d.mean()  # type: ignore

AttributeError: 'ScalarTensor' object has no attribute 'mean'

In [12]:
HANDLED_FUNCTIONS = {}
class ScalarTensor(object):
    def __init__(self, N, value):
        self._N = N
        self._value = value

    def __repr__(self):
        return "ScalarTensor(N={}, value={})".format(self._N, self._value)

    def tensor(self):
        return self._value * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)
    
import functools
def implements(torch_function):
    """Register a torch function override for ScalarTensor"""
    def decorator(func):
        functools.update_wrapper(func, torch_function)
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

@implements(torch.mean)
def mean(input):
    return float(input._value) / input._N

def ensure_tensor(data):
    if isinstance(data, ScalarTensor):
        return data.tensor()
    return torch.as_tensor(data)

@implements(torch.add)
def add(input, other):
   try:
       if input._N == other._N:
           return ScalarTensor(input._N, input._value + other._value)
       else:
           raise ValueError("Shape mismatch!")
   except AttributeError:
       return torch.add(ensure_tensor(input), ensure_tensor(other))


In [13]:
d = ScalarTensor(5, 2)
torch.mean(d)  # type: ignore

0.4

In [14]:
s = ScalarTensor(2, 2)
torch.add(s, s)  # type: ignore

ScalarTensor(N=2, value=4)

In [15]:
t = torch.tensor([[1, 1,], [1, 1]])
torch.add(s, t)  # type: ignore

tensor([[3., 1.],
        [1., 3.]])

Note that so far this only handles `torch.` functions. You cannot call `+` or
`.add`. To do that you will have to manually implement `add` and `__add__` on
`ScalarTensor`, or subclass `torch.Tensor`, which I can never get to work.

In [31]:
s + s  # type: ignore

AttributeError: 'ScalarTensor' object has no attribute '_N'

In [32]:
s.add(s)  # type: ignore

AttributeError: 'ScalarTensor' object has no attribute '_N'

In [37]:
HANDLED_FUNCTIONS = {}
class ScalarTensor(torch.Tensor):
    def __init__(self, N, value):
        self._N = N
        self._value = value

    def __repr__(self):
        return "ScalarTensor(N={}, value={})".format(self._N, self._value)

    def tensor(self):
        return self._value * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
            return super().__torch_function__(func, types, args, kwargs)
        return HANDLED_FUNCTIONS[func](*args, **kwargs)
    
    def __tensor_flatten__(self):
        return ["_N", "_value"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta):
        assert meta is None
        N, value = inner_tensors["_N"], inner_tensors["_value"]
        return ScalarTensor(N, value)


import functools
def implements(torch_function):
    """Register a torch function override for ScalarTensor"""
    def decorator(func):
        functools.update_wrapper(func, torch_function)
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

@implements(torch.mean)
def mean(input):
    return float(input._value) / input._N

def ensure_tensor(data):
    if isinstance(data, ScalarTensor):
        return data.tensor()
    return torch.as_tensor(data)

@implements(torch.add)
def add(input, other):
   try:
       if input._N == other._N:
           return ScalarTensor(input._N, input._value + other._value)
       else:
           raise ValueError("Shape mismatch!")
   except AttributeError:
       return torch.add(ensure_tensor(input), ensure_tensor(other))


In [53]:
s = ScalarTensor(5, 2)
a = s.mean()
print(type(a))
print(type(a.data))
# The answer is wrong still.
a.untyped_storage(), torch.mean(s)

<class '__main__.ScalarTensor'>
<class '__main__.ScalarTensor'>


( 230
  137
  202
  15
 [torch.storage.UntypedStorage(device=cpu) of size 4],
 0.4)