In [1]:
from nestedtensor import torch
from IPython.display import Markdown, display

def print_eval(s):
    """Display an expression as a highlighted heading, then print its value.

    The expression text is rendered as bold, dark-red Markdown prefixed with
    ``$``. Afterwards the string form of the evaluated expression is printed,
    followed by an empty line.

    The expression is evaluated against this module's global namespace only,
    so the local names of this function (for example ``s``) can never shadow
    a notebook variable of the same name.

    Args:
        s: Source text of a single Python expression.

    Raises:
        Whatever evaluating ``s`` raises. The heading has already been
        displayed at that point.

    Note:
        ``eval`` executes arbitrary code; only pass expressions written in
        this notebook, never untrusted input.
    """
    heading = "<span style='color:darkred'>$ {}</span>".format(s)
    display(Markdown('**{}**'.format(heading)))
    # Passing globals() explicitly keeps this function's locals out of the
    # evaluation scope; a bare eval(s) would resolve `s` to the parameter.
    value = eval(s, globals())
    print('{}\n'.format(str(value)))

## Custom nn.functionals

By default, all nn.functionals are implemented as tensorwise functions. However, in some cases we want to support custom semantics that come about through slight modifications to the lifted function. Take nn.functional.conv2d as an example.



In [2]:
# Three constituents of shape (3, H, W) with differing H and W and no
# leading batch dimension.
nt = torch.nested_tensor(
    [torch.rand(3, height, width) for height, width in ((10, 30), (20, 40), (30, 50))]
)
# The same H and W, but every constituent carries a leading dimension of size 1.
nt1 = torch.nested_tensor(
    [torch.rand(1, 3, height, width) for height, width in ((10, 30), (20, 40), (30, 50))]
)
# Kernel handed to conv2d in the cells that follow.
weight = torch.rand(64, 3, 7, 7)
print_eval("nt.size()")

**<span style='color:darkred'>$ nt.size()</span>**

(3, 3, None, None)



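By default, i.e. when conv2d is lifted with `torch.tensorwise()`, this function fails on `nt`, because its components do not have a batch dimension. It works on `nt1`, whose components each carry a leading batch dimension of size 1.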

In [3]:
# The constituents of `nt` are 3-dimensional, so the tensorwise-lifted conv2d
# is expected to raise here; the RuntimeError is caught and its message is
# displayed instead of aborting the cell.
try:
    print_eval("torch.tensorwise()(torch.nn.functional.conv2d)(nt, weight)")
except RuntimeError as e:
    # `e` is only bound inside this handler, so it has to be displayed here.
    print_eval("str(e)")

# The constituents of `nt1` have a leading dimension of size 1, so the same
# lifted call goes through.
print_eval("torch.tensorwise()(torch.nn.functional.conv2d)(nt1, weight).size()")

**<span style='color:darkred'>$ torch.tensorwise()(torch.nn.functional.conv2d)(nt, weight)</span>**

**<span style='color:darkred'>$ str(e)</span>**

Expected 4-dimensional input for 4-dimensional weight 64 3 7 7, but got 3-dimensional input of size [3, 10, 30] instead



**<span style='color:darkred'>$ torch.tensorwise()(torch.nn.functional.conv2d)(nt1, weight).size()</span>**

(3, 1, 64, None, None)



However, NestedTensors implement a version of conv2d that doesn't require a batch dimension for ease of use and for efficiency (more on that later).

In [4]:
# conv2d called directly on `nt`, whose constituents have no batch dimension.
print_eval("torch.nn.functional.conv2d(nt, weight).size()")

**<span style='color:darkred'>$ torch.nn.functional.conv2d(nt, weight).size()</span>**

(3, 64, None, None)



We have a similar story for nn.functional.embedding_bag. The lifted version only works on elements of batch size 1 unless it is given `offsets`, which is an unnecessary annoyance. We extend the lifted embedding_bag to support inputs of dimension 1 if `offsets` is set to None.

In [5]:
# Index inputs: uniform samples scaled by 10 and converted to int64.
# nt2: constituents of shape (1, length).
nt2 = (
    torch.nested_tensor([torch.rand(1, length) for length in (30, 40, 50)]) * 10
).to(torch.int64)
# nt3: one-dimensional constituents of the same lengths.
nt3 = (
    torch.nested_tensor([torch.rand(length) for length in (30, 40, 50)]) * 10
).to(torch.int64)
# nt4: two levels of nesting, holding one and two (1, length) constituents.
nt4 = (
    torch.nested_tensor(
        [[torch.rand(1, length) for length in lengths] for lengths in ((30,), (40, 50))]
    ) * 10
).to(torch.int64)


In [6]:
# Rebinds `weight`: from here on it is the (100, 256) matrix passed to
# embedding_bag, no longer the conv2d kernel created earlier. Re-running the
# conv2d cells after this one would evaluate them against this tensor.
weight = torch.rand(100, 256)
# Functional form with the explicit weight, once per index layout.
print_eval("torch.nn.functional.embedding_bag(nt2, weight).nested_size()")
print_eval("torch.nn.functional.embedding_bag(nt3, weight).nested_size()")
print_eval("torch.nn.functional.embedding_bag(nt4, weight).nested_size()")
# Module form; every expression constructs its own EmbeddingBag(100, 256).
print_eval("torch.nn.EmbeddingBag(100, 256)(nt2).nested_size()")
print_eval("torch.nn.EmbeddingBag(100, 256)(nt3).nested_size()")
print_eval("torch.nn.EmbeddingBag(100, 256)(nt4).nested_size()")

**<span style='color:darkred'>$ torch.nn.functional.embedding_bag(nt2, weight).nested_size()</span>**

torch.NestedSize((
	torch.Size([1, 256]),
	torch.Size([1, 256]),
	torch.Size([1, 256])
))



**<span style='color:darkred'>$ torch.nn.functional.embedding_bag(nt3, weight).nested_size()</span>**

torch.NestedSize((
	torch.Size([256]),
	torch.Size([256]),
	torch.Size([256])
))



**<span style='color:darkred'>$ torch.nn.functional.embedding_bag(nt4, weight).nested_size()</span>**

torch.NestedSize((
	(
		torch.Size([1, 256])
	),
	(
		torch.Size([1, 256]),
		torch.Size([1, 256])
	)
))



**<span style='color:darkred'>$ torch.nn.EmbeddingBag(100, 256)(nt2).nested_size()</span>**

torch.NestedSize((
	torch.Size([1, 256]),
	torch.Size([1, 256]),
	torch.Size([1, 256])
))



**<span style='color:darkred'>$ torch.nn.EmbeddingBag(100, 256)(nt3).nested_size()</span>**

torch.NestedSize((
	torch.Size([256]),
	torch.Size([256]),
	torch.Size([256])
))



**<span style='color:darkred'>$ torch.nn.EmbeddingBag(100, 256)(nt4).nested_size()</span>**

torch.NestedSize((
	(
		torch.Size([1, 256])
	),
	(
		torch.Size([1, 256]),
		torch.Size([1, 256])
	)
))



In [7]:
# Rebinds `nt3` to its floating-point version (it was int64 before),
# presumably so that the division below produces fractional values.
nt3 = nt3.float()
print_eval("nt3")
print_eval("nt3.size()")
print_eval("nt3.nested_size()")
# The recorded output of this expression is a nested tensor holding the
# constituent lengths 30, 40 and 50 as scalar tensors.
print_eval("torch.nested_tensor(nt3.nested_size(1))")
# Rebinds `nt4` (previously the doubly nested int64 index tensor): `nt3` is
# divided by the nested tensor of lengths shown above.
nt4 = nt3 / torch.nested_tensor(nt3.nested_size(1))
print_eval("nt4")
print_eval("nt4.size()")

**<span style='color:darkred'>$ nt3</span>**

nested_tensor([
	tensor([3., 9., 4., 9., 7., 5., 0., 4., 4., 6., 7., 4., 4., 9., 0., 4., 9., 2.,
	        4., 4., 7., 3., 1., 7., 4., 1., 7., 1., 8., 9.]),
	tensor([5., 2., 6., 3., 9., 2., 9., 6., 7., 8., 0., 4., 5., 7., 3., 0., 5., 4.,
	        2., 5., 0., 8., 1., 8., 8., 0., 4., 8., 5., 8., 3., 2., 6., 8., 0., 3.,
	        8., 3., 7., 5.]),
	tensor([2., 0., 5., 1., 1., 7., 9., 3., 7., 1., 6., 2., 2., 4., 5., 2., 5., 2.,
	        4., 6., 7., 1., 0., 2., 4., 7., 8., 2., 1., 1., 4., 7., 1., 4., 9., 6.,
	        1., 6., 0., 2., 1., 6., 8., 7., 1., 5., 0., 4., 7., 0.])
])



**<span style='color:darkred'>$ nt3.size()</span>**

(3, None)



**<span style='color:darkred'>$ nt3.nested_size()</span>**

torch.NestedSize((
	torch.Size([30]),
	torch.Size([40]),
	torch.Size([50])
))



**<span style='color:darkred'>$ torch.nested_tensor(nt3.nested_size(1))</span>**

nested_tensor([
	tensor(30),
	tensor(40),
	tensor(50)
])



**<span style='color:darkred'>$ nt4</span>**

nested_tensor([
	tensor([0.1000, 0.3000, 0.1333, 0.3000, 0.2333, 0.1667, 0.0000, 0.1333, 0.1333,
	        0.2000, 0.2333, 0.1333, 0.1333, 0.3000, 0.0000, 0.1333, 0.3000, 0.0667,
	        0.1333, 0.1333, 0.2333, 0.1000, 0.0333, 0.2333, 0.1333, 0.0333, 0.2333,
	        0.0333, 0.2667, 0.3000]),
	tensor([0.1250, 0.0500, 0.1500, 0.0750, 0.2250, 0.0500, 0.2250, 0.1500, 0.1750,
	        0.2000, 0.0000, 0.1000, 0.1250, 0.1750, 0.0750, 0.0000, 0.1250, 0.1000,
	        0.0500, 0.1250, 0.0000, 0.2000, 0.0250, 0.2000, 0.2000, 0.0000, 0.1000,
	        0.2000, 0.1250, 0.2000, 0.0750, 0.0500, 0.1500, 0.2000, 0.0000, 0.0750,
	        0.2000, 0.0750, 0.1750, 0.1250]),
	tensor([0.0400, 0.0000, 0.1000, 0.0200, 0.0200, 0.1400, 0.1800, 0.0600, 0.1400,
	        0.0200, 0.1200, 0.0400, 0.0400, 0.0800, 0.1000, 0.0400, 0.1000, 0.0400,
	        0.0800, 0.1200, 0.1400, 0.0200, 0.0000, 0.0400, 0.0800, 0.1400, 0.1600,
	        0.0400, 0.0200, 0.0200, 0.0800, 0.1400, 0.0200, 0.0800, 0.1800, 0.1200,
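	        0.0200, 0.1200, 0.0000, 0.0400, 0.0200, 0.1200, 0.1600, 0.1400, 0.0200,
	        0.1000, 0.0000, 0.0800, 0.1400, 0.0000])
])
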

**<span style='color:darkred'>$ nt4.size()</span>**

(3, None)



In [8]:
# Constituents share the trailing dimension (10) and differ in row count.
nt5 = torch.nested_tensor([torch.rand(rows, 10) for rows in (30, 40, 50)])
print_eval("nt5.nested_size()")
# Matrix product of the nested tensor with a single regular (10, 5) tensor.
print_eval("torch.mm(nt5, torch.rand(10, 5)).nested_size()")

**<span style='color:darkred'>$ nt5.nested_size()</span>**

torch.NestedSize((
	torch.Size([30, 10]),
	torch.Size([40, 10]),
	torch.Size([50, 10])
))



**<span style='color:darkred'>$ torch.mm(nt5, torch.rand(10, 5)).nested_size()</span>**

torch.NestedSize((
	torch.Size([30, 5]),
	torch.Size([40, 5]),
	torch.Size([50, 5])
))



In [9]:
# argmax over nested dimension 1; per the recorded output each constituent is
# reduced to 10 indices, i.e. the reduction runs over its varying-length rows.
print_eval("nt5.argmax(1)")
# The recorded size is the regular shape (3, 10), with no `None` entries.
print_eval("nt5.argmax(1).size()")
# Conversion of that regularly shaped result into a single tensor.
print_eval("nt5.argmax(1).to_tensor()")

**<span style='color:darkred'>$ nt5.argmax(1)</span>**

nested_tensor([
	tensor([21,  0, 24, 12, 16, 10, 21, 20,  5, 20]),
	tensor([ 4,  1, 31,  2, 15, 25,  0,  2, 32, 19]),
	tensor([42,  6, 14, 34, 42, 11, 48, 13, 36, 41])
])



**<span style='color:darkred'>$ nt5.argmax(1).size()</span>**

(3, 10)



**<span style='color:darkred'>$ nt5.argmax(1).to_tensor()</span>**

tensor([[21,  0, 24, 12, 16, 10, 21, 20,  5, 20],
        [ 4,  1, 31,  2, 15, 25,  0,  2, 32, 19],
        [42,  6, 14, 34, 42, 11, 48, 13, 36, 41]])



In [10]:
print_eval("nt5.nested_size()")
# argmax over nested dimension 2 (the size-10 dimension of every constituent);
# the recorded output keeps one index per row: 30, 40 and 50 entries.
print_eval("nt5.argmax(2).nested_size()")
# Those indices serve as the targets; the recorded output holds one scalar
# loss per constituent.
print_eval("torch.nn.functional.cross_entropy(nt5, nt5.argmax(2))")

**<span style='color:darkred'>$ nt5.nested_size()</span>**

torch.NestedSize((
	torch.Size([30, 10]),
	torch.Size([40, 10]),
	torch.Size([50, 10])
))



**<span style='color:darkred'>$ nt5.argmax(2).nested_size()</span>**

torch.NestedSize((
	torch.Size([30]),
	torch.Size([40]),
	torch.Size([50])
))



**<span style='color:darkred'>$ torch.nn.functional.cross_entropy(nt5, nt5.argmax(2))</span>**

nested_tensor([
	tensor(1.9496),
	tensor(1.9408),
	tensor(1.9305)
])



In [11]:
# Square constituents of increasing size.
nt6 = torch.nested_tensor([torch.rand(side, side) for side in (10, 20, 30)])
# lu() is indexed as a pair here; the sizes of both entries are displayed.
print_eval("nt6.lu()[0].size()")
print_eval("nt6.lu()[1].size()")

**<span style='color:darkred'>$ nt6.lu()[0].size()</span>**

(3, None, None)



**<span style='color:darkred'>$ nt6.lu()[1].size()</span>**

(3, None)



In [12]:
# Two nested tensors with identical nesting; each (k, n) constituent of `nt7`
# lines up with an (n, k) constituent of `nt8`.
nt7 = torch.nested_tensor([
    [torch.rand(1, 10), torch.rand(2, 20)],
    [torch.rand(3, 30)],
])
nt8 = torch.nested_tensor([
    [torch.rand(10, 1), torch.rand(20, 2)],
    [torch.rand(30, 3)],
])
print_eval("torch.mm(nt7, nt8)")

**<span style='color:darkred'>$ torch.mm(nt7, nt8)</span>**

nested_tensor([
	[
		tensor([[3.0984]]),
		tensor([[5.3822, 5.3800],
		        [4.8560, 5.2896]])
	],
	[
		tensor([[9.7878, 5.3839, 7.7967],
		        [8.1811, 6.1959, 7.1056],
		        [8.6106, 5.1237, 6.6942]])
	]
])

