### The tensorwise decorator

The nestedtensor package allows the user to decorate existing functions with a tensorwise decorator. This decorator lifts the given function so that it checks its arguments for NestedTensors and recursively applies itself to their constituents, leaving all other arguments untouched.

In [1]:
import torch
import nestedtensor
from nestedtensor import tensorwise
from IPython.display import Markdown, display

def print_eval(s):
    colorS = "<span style='color:darkred'>$ {}</span>".format(s)
    display(Markdown('**{}**'.format(colorS))) 
    print('{}\n'.format(str(eval(s))))

In [2]:
@tensorwise()
def sum_plus_one(t1, t2):
    return t1 + 1 + t2


a = torch.tensor([1, 2])
b = torch.tensor([7, 8])
print_eval("sum_plus_one(a, b)")

**<span style='color:darkred'>$ sum_plus_one(a, b)</span>**

tensor([ 9, 11])



Decorating the function as tensorwise does not affect its behavior with respect to non-NestedTensor arguments. In particular, the tensorwise decorator searches all arguments for a NestedTensor and, if none is found, dispatches to exactly the given function.
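
To make the dispatch rule concrete, here is a minimal pure-Python sketch of a decorator with the same shape. It is not the nestedtensor implementation: `Nested` and `tensorwise_sketch` are hypothetical names, Python numbers stand in for Tensors, the length check is an assumption of the sketch, and the decorator is applied without parentheses, unlike `@tensorwise()`.

```python
import functools


class Nested:
    """Hypothetical stand-in for a NestedTensor: an ordered list of constituents."""

    def __init__(self, items):
        self.items = list(items)

    def __eq__(self, other):
        return isinstance(other, Nested) and self.items == other.items

    def __repr__(self):
        return "Nested({!r})".format(self.items)


def tensorwise_sketch(fn):
    """Lift fn so that it maps over Nested arguments; other arguments pass through."""

    @functools.wraps(fn)
    def wrapper(*args):
        nested = [arg for arg in args if isinstance(arg, Nested)]
        if not nested:
            # No nested argument found: dispatch to exactly the given function.
            return fn(*args)
        length = len(nested[0].items)
        if any(len(n.items) != length for n in nested):
            # Assumption of this sketch: nested arguments must line up.
            raise ValueError("nested arguments must have the same length")
        results = []
        for i in range(length):
            # Take constituent i of every nested argument; leave the rest untouched.
            call_args = [arg.items[i] if isinstance(arg, Nested) else arg for arg in args]
            # Recurse so that deeper nesting levels are handled the same way.
            results.append(wrapper(*call_args))
        return Nested(results)

    return wrapper


@tensorwise_sketch
def sum_plus_one(x, y):
    return x + 1 + y


plain = sum_plus_one(1, 7)
lifted = sum_plus_one(Nested([1, 7]), Nested([7, 1]))
mixed = sum_plus_one(1, Nested([Nested([1, 7]), Nested([7, 1])]))
print(plain, lifted, mixed)
```

Here `plain` takes the direct path and evaluates to 9, `lifted` pairs constituents position by position and yields `Nested([9, 9])`, and `mixed` reuses the plain argument at every level, yielding `Nested([Nested([3, 9]), Nested([9, 3])])`.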

The next cell creates two NestedTensors, each a combination of the Tensors defined above, which we can then pass into the decorated function.

In [3]:
nt1 = nestedtensor.nested_tensor([a, b])
nt2 = nestedtensor.nested_tensor([b, a])
print_eval('nt1')
print_eval('nt2')

**<span style='color:darkred'>$ nt1</span>**

nested_tensor([
	tensor([1, 2]),
	tensor([7, 8])
])



**<span style='color:darkred'>$ nt2</span>**

nested_tensor([
	tensor([7, 8]),
	tensor([1, 2])
])



In [4]:
print_eval('sum_plus_one(nt1, nt2)')
print_eval('a')
print_eval('nt1')
print_eval('nt2')
print_eval('sum_plus_one(a, nt2)')
print_eval('sum_plus_one(a, nestedtensor.nested_tensor([nt1, nt2]))')
print_eval('sum_plus_one(nt1, nestedtensor.nested_tensor([nt1, nt2]))')

**<span style='color:darkred'>$ sum_plus_one(nt1, nt2)</span>**

nested_tensor([
	tensor([ 9, 11]),
	tensor([ 9, 11])
])



**<span style='color:darkred'>$ a</span>**

tensor([1, 2])



**<span style='color:darkred'>$ nt1</span>**

nested_tensor([
	tensor([1, 2]),
	tensor([7, 8])
])



**<span style='color:darkred'>$ nt2</span>**

nested_tensor([
	tensor([7, 8]),
	tensor([1, 2])
])



**<span style='color:darkred'>$ sum_plus_one(a, nt2)</span>**

nested_tensor([
	tensor([ 9, 11]),
	tensor([3, 5])
])



**<span style='color:darkred'>$ sum_plus_one(a, nestedtensor.nested_tensor([nt1, nt2]))</span>**

nested_tensor([
	[
		tensor([3, 5]),
		tensor([ 9, 11])
	],
	[
		tensor([ 9, 11]),
		tensor([3, 5])
	]
])



**<span style='color:darkred'>$ sum_plus_one(nt1, nestedtensor.nested_tensor([nt1, nt2]))</span>**

nested_tensor([
	[
		tensor([3, 5]),
		tensor([ 9, 11])
	],
	[
		tensor([15, 17]),
		tensor([ 9, 11])
	]
])



We can write functionally equivalent code with a regular Python for-loop to further illustrate the behavior of sum_plus_one(nt1, nt2):

In [5]:
print([sum_plus_one(t1, t2) for t1, t2 in zip([a, b], [b, a])])

[tensor([ 9, 11]), tensor([ 9, 11])]
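
The mixed and two-level calls from In [4] can be unrolled in the same way. The sketch below is an illustration only: plain Python lists stand in for Tensors and NestedTensors (an assumption, so that it runs without torch or nestedtensor), and `sum_plus_one` is redefined as an element-wise list operation.

```python
def sum_plus_one(t1, t2):
    # Element-wise t1 + 1 + t2 on plain lists, standing in for Tensor arithmetic.
    return [x + 1 + y for x, y in zip(t1, t2)]


a = [1, 2]
b = [7, 8]
nt1 = [a, b]  # stands in for nested_tensor([a, b])
nt2 = [b, a]  # stands in for nested_tensor([b, a])

# sum_plus_one(a, nt2): `a` is reused for every constituent of nt2.
mixed = [sum_plus_one(a, t2) for t2 in nt2]

# sum_plus_one(a, nested_tensor([nt1, nt2])): one loop per nesting level.
two_levels = [[sum_plus_one(a, t2) for t2 in inner] for inner in [nt1, nt2]]

# sum_plus_one(nt1, nested_tensor([nt1, nt2])): nt1 is unbound once, at the
# outer level, and each of its constituents is reused across the inner level.
paired = [[sum_plus_one(t1, t2) for t2 in inner] for t1, inner in zip(nt1, [nt1, nt2])]

print(mixed)       # [[9, 11], [3, 5]]
print(two_levels)  # [[[3, 5], [9, 11]], [[9, 11], [3, 5]]]
print(paired)      # [[[3, 5], [9, 11]], [[15, 17], [9, 11]]]
```

The printed values match the corresponding nested_tensor outputs of In [4], with lists in place of Tensors.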


In some sense we can view this as an unrolling operation or, in PyTorch terms, an unbind. Called without further arguments, unbind returns a tuple of Tensor constituents along the 0-th dimension.

In [6]:
print_eval('a.unbind()')
print_eval('nt1.unbind()')

**<span style='color:darkred'>$ a.unbind()</span>**

(tensor(1), tensor(2))



**<span style='color:darkred'>$ nt1.unbind()</span>**

(tensor([1, 2]), tensor([7, 8]))



Sometimes we might also want to unbind non-Tensor arguments. For this case, tensorwise requires these arguments to define ```__getitem__```, as lists or torch.Tensors do. In the cell below, unbind_args=[2] marks the third positional argument, scalar, as one to unbind.

In [7]:
@tensorwise(unbind_args=[2])
def sum_plus_scalar(t1, t2, scalar):
    return t1 + scalar + t2

print_eval('a')
print_eval('b')
print_eval('sum_plus_scalar(a, b, 3.0)')

**<span style='color:darkred'>$ a</span>**

tensor([1, 2])



**<span style='color:darkred'>$ b</span>**

tensor([7, 8])



**<span style='color:darkred'>$ sum_plus_scalar(a, b, 3.0)</span>**

tensor([11., 13.])



In [8]:
print_eval('nt1')
print_eval('nt2')
print_eval('sum_plus_scalar(nt1, nt2, (2.0, 3.0))')

**<span style='color:darkred'>$ nt1</span>**

nested_tensor([
	tensor([1, 2]),
	tensor([7, 8])
])



**<span style='color:darkred'>$ nt2</span>**

nested_tensor([
	tensor([7, 8]),
	tensor([1, 2])
])



**<span style='color:darkred'>$ sum_plus_scalar(nt1, nt2, (2.0, 3.0))</span>**

nested_tensor([
	tensor([10., 12.]),
	tensor([11., 13.])
])
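
One way to picture the effect of unbind_args is a pure-Python sketch. It rests on stated assumptions and is not the library's implementation: `Nested` and `tensorwise_sketch` are hypothetical names, Python numbers stand in for Tensors, and arguments at the listed positions are indexed with `__getitem__` only when at least one nested argument is present.

```python
import functools


class Nested:
    """Hypothetical stand-in for a NestedTensor: an ordered list of constituents."""

    def __init__(self, items):
        self.items = list(items)

    def __eq__(self, other):
        return isinstance(other, Nested) and self.items == other.items

    def __repr__(self):
        return "Nested({!r})".format(self.items)


def tensorwise_sketch(unbind_args=()):
    """Decorator factory; unbind_args lists positions of extra arguments to index."""
    positions = set(unbind_args)

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args):
            nested = [arg for arg in args if isinstance(arg, Nested)]
            if not nested:
                # Nothing nested: call fn directly, so a bare scalar is accepted.
                return fn(*args)
            results = []
            for i in range(len(nested[0].items)):
                call_args = []
                for pos, arg in enumerate(args):
                    if isinstance(arg, Nested):
                        call_args.append(arg.items[i])
                    elif pos in positions:
                        call_args.append(arg[i])  # relies on __getitem__
                    else:
                        call_args.append(arg)
                results.append(wrapper(*call_args))
            return Nested(results)

        return wrapper

    return decorator


@tensorwise_sketch(unbind_args=[2])
def sum_plus_scalar(x, y, scalar):
    return x + scalar + y


direct = sum_plus_scalar(1, 7, 3.0)
lifted = sum_plus_scalar(Nested([1, 7]), Nested([7, 1]), (2.0, 3.0))
print(direct, lifted)
```

In this sketch `direct` evaluates to 11.0 without touching the scalar, while `lifted` pairs 2.0 with the first constituents and 3.0 with the second, giving `Nested([10.0, 11.0])`, which mirrors the pairing seen in the output of In [8].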

