### The `tensorwise` decorator

The nestedtensor package lets the user decorate existing functions with a `tensorwise` decorator. The decorator lifts the given function so that it checks its arguments for NestedTensors and, when it finds one, applies the function recursively to their constituents, leaving all other arguments untouched.
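To make the lifting concrete, here is a minimal pure-Python sketch. It is a toy model and not the nestedtensor implementation: the `Nested` class and the `toy_tensorwise` decorator are hypothetical stand-ins, and plain numbers play the role of Tensors so the sketch runs without PyTorch.

```python
import functools


class Nested:
    """Hypothetical stand-in for a NestedTensor: an ordered list of constituents,
    each either a leaf value or another Nested."""

    def __init__(self, constituents):
        self.constituents = list(constituents)

    def __eq__(self, other):
        return isinstance(other, Nested) and self.constituents == other.constituents

    def __repr__(self):
        return f"Nested({self.constituents!r})"


def toy_tensorwise(fn):
    """Toy tensorwise-style decorator (hypothetical, not the library's code)."""

    @functools.wraps(fn)
    def wrapper(*args):
        nested = [arg for arg in args if isinstance(arg, Nested)]
        if not nested:
            # No Nested argument anywhere: call the original function unchanged.
            return fn(*args)
        n = len(nested[0].constituents)
        if any(len(arg.constituents) != n for arg in nested):
            raise ValueError("Nested arguments must have the same number of constituents")
        results = []
        for i in range(n):
            # Swap each Nested argument for its i-th constituent; other args pass through.
            step_args = [arg.constituents[i] if isinstance(arg, Nested) else arg
                         for arg in args]
            # Recurse so that deeper levels of nesting are handled the same way.
            results.append(wrapper(*step_args))
        return Nested(results)

    return wrapper


@toy_tensorwise
def simple_fn(t1, t2):
    return t1 + 1 + t2


print(simple_fn(1, 7))                            # 9
print(simple_fn(Nested([1, 7]), Nested([7, 1])))  # Nested([9, 9])
print(simple_fn(1, Nested([Nested([1, 7]), Nested([7, 1])])))
# Nested([Nested([3, 9]), Nested([9, 3])])
```

The sketch only handles positional arguments and raises on mismatched lengths; both are simplifications chosen for illustration.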

```python
from nestedtensor import torch
from torch import tensorwise
```

```python
@tensorwise()
def simple_fn(t1, t2):
    return t1 + 1 + t2


a = torch.tensor([1, 2])
b = torch.tensor([7, 8])
print(simple_fn(a, b))
```

```
tensor([ 9, 11])
```


Decorating a function with tensorwise does not change its behavior on non-NestedTensor arguments: the decorator searches all arguments for a NestedTensor and, if none is found, dispatches directly to the original function.

The next example creates two NestedTensors, each built from the Tensors defined above, which we can then pass to the decorated function.

```python
nt1 = torch.nested_tensor([a, b])
nt2 = torch.nested_tensor([b, a])
print(nt1)
print(nt2)
```

```
nested_tensor([
	tensor([1, 2]),
	tensor([7, 8])
])
nested_tensor([
	tensor([7, 8]),
	tensor([1, 2])
])
```


```python
print(simple_fn(nt1, nt2))  # constituents are paired index by index
print(a)                    # the inputs themselves are left unchanged
print(nt1)
print(nt2)
print(simple_fn(a, nt2))    # the plain Tensor a is reused for every constituent of nt2
print(simple_fn(a, torch.nested_tensor([nt1, nt2])))    # reused at every level of nesting
print(simple_fn(nt1, torch.nested_tensor([nt1, nt2])))  # nt1[i] is paired with the i-th inner NestedTensor
```

```
nested_tensor([
	tensor([ 9, 11]),
	tensor([ 9, 11])
])
tensor([1, 2])
nested_tensor([
	tensor([1, 2]),
	tensor([7, 8])
])
nested_tensor([
	tensor([7, 8]),
	tensor([1, 2])
])
nested_tensor([
	tensor([ 9, 11]),
	tensor([3, 5])
])
nested_tensor([
	[
		tensor([3, 5]),
		tensor([ 9, 11])
	],
	[
		tensor([ 9, 11]),
		tensor([3, 5])
	]
])
nested_tensor([
	[
		tensor([3, 5]),
		tensor([ 9, 11])
	],
	[
		tensor([15, 17]),
		tensor([ 9, 11])
	]
])
```


To further illustrate the behavior, we can write functionally equivalent code with a regular Python loop (here a list comprehension):

```python
print([simple_fn(t1, t2) for t1, t2 in zip([a, b], [b, a])])
```

```
[tensor([ 9, 11]), tensor([ 9, 11])]
```
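The same unrolling extends to the mixed and doubly nested calls from the earlier cell. The sketch below reproduces those two outputs in plain Python, assuming lists stand in for NestedTensors and tuples for 1-D Tensors; the `add` helper is hypothetical and only mimics elementwise addition with broadcasting of plain numbers.

```python
def add(*values):
    """Hypothetical helper: elementwise sum of equal-length tuples, broadcasting plain numbers."""
    length = next(len(v) for v in values if isinstance(v, tuple))
    return tuple(sum(v[i] if isinstance(v, tuple) else v for v in values)
                 for i in range(length))


def simple_fn(t1, t2):
    return add(t1, 1, t2)


a, b = (1, 2), (7, 8)
nt1 = [a, b]  # stand-in for torch.nested_tensor([a, b])
nt2 = [b, a]  # stand-in for torch.nested_tensor([b, a])

# simple_fn(a, torch.nested_tensor([nt1, nt2])): the plain `a` is reused for every leaf.
out_a = [[simple_fn(a, leaf) for leaf in inner] for inner in [nt1, nt2]]

# simple_fn(nt1, torch.nested_tensor([nt1, nt2])): nt1[i] is paired with the i-th inner
# list and then reused for every leaf of that inner list.
out_nt1 = [[simple_fn(x, leaf) for leaf in inner] for x, inner in zip(nt1, [nt1, nt2])]

print(out_a)    # [[(3, 5), (9, 11)], [(9, 11), (3, 5)]]
print(out_nt1)  # [[(3, 5), (9, 11)], [(15, 17), (9, 11)]]
```

The pairing rule is the same at every level: arguments that are nested get indexed, everything else is broadcast unchanged.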


In some sense we can view this as an unrolling operation or, in PyTorch terms, unbind. Called without arguments, `unbind` returns a tuple of the constituent Tensors along dimension 0.

```python
print(a.unbind())
print(nt1.unbind())
```

```
(tensor(1), tensor(2))
(tensor([1, 2]), tensor([7, 8]))
```


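Sometimes we might also want to unbind non-Tensor arguments. For this case, tensorwise requires these arguments to implement `__getitem__` (e.g. lists, tuples or torch.Tensors), and their positions are passed via `unbind_args`. In the example below, `unbind_args=[2]` refers to the argument at position 2 (zero-based), `scalar`. When no NestedTensor is involved, as in the first call, that argument is passed through as-is.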

```python
@tensorwise(unbind_args=[2])
def simple_fn_scalar(t1, t2, scalar):
    return t1 + scalar + t2

print(a)
print(b)
print(simple_fn_scalar(a, b, 3.0))
```

```
tensor([1, 2])
tensor([7, 8])
tensor([11., 13.])
```


```python
print(nt1)
print(nt2)
print(simple_fn_scalar(nt1, nt2, (2.0, 3.0)))
```

```
nested_tensor([
	tensor([1, 2]),
	tensor([7, 8])
])
nested_tensor([
	tensor([7, 8]),
	tensor([1, 2])
])
nested_tensor([
	tensor([10., 12.]),
	tensor([11., 13.])
])
```
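The unbinding of `scalar` can likewise be written out by hand. In this plain-Python sketch (lists stand in for NestedTensors, tuples for 1-D Tensors), the i-th call receives the i-th constituent of each NestedTensor and `scalars[i]` via `__getitem__`. It mirrors the results of the two `simple_fn_scalar` calls above rather than the library's internals.

```python
def simple_fn_scalar(t1, t2, scalar):
    # Elementwise t1 + scalar + t2 on equal-length tuples (stand-ins for 1-D Tensors).
    return tuple(x + scalar + y for x, y in zip(t1, t2))


a, b = (1, 2), (7, 8)
nt1 = [a, b]          # stand-in for torch.nested_tensor([a, b])
nt2 = [b, a]          # stand-in for torch.nested_tensor([b, a])
scalars = (2.0, 3.0)  # the unbound argument, indexed as scalars[i] via __getitem__

# No nesting: the scalar is passed through as-is.
flat = simple_fn_scalar(a, b, 3.0)

# With nesting: the i-th call gets nt1[i], nt2[i] and scalars[i].
unrolled = [simple_fn_scalar(nt1[i], nt2[i], scalars[i]) for i in range(len(nt1))]

print(flat)      # (11.0, 13.0)
print(unrolled)  # [(10.0, 12.0), (11.0, 13.0)]
```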
