In [1]:
from nestedtensor import torch
def print_eval(s):
    """Echo the expression string ``s`` as a highlighted prompt, then print its value.

    ``s`` is evaluated with ``eval()`` in this module's global namespace (the
    notebook's namespace when this is defined in a cell), so it may refer to
    variables created in earlier cells. The prompt line is wrapped in ANSI
    bold-red escape codes and padded with ``ljust(30)``. The value is rendered
    with ``str()`` and followed by a blank line. Any exception raised by the
    expression propagates before anything is printed.

    Only pass trusted strings: ``eval()`` executes arbitrary code.
    """
    # Evaluate first, while `s` is the only local, so the helper locals below
    # can never shadow a name used inside the expression.
    rendered = str(eval(s))
    header = "".join(["\033[1;31m$ ", s, ":\033[0m"]).ljust(30)
    print(header + "\n" + rendered + "\n")

## Custom `nn.functional` implementations

By default, every `torch.nn.functional` function is lifted to NestedTensors tensorwise, i.e. it is applied separately to each constituent tensor. In some cases, however, we want custom semantics that arise from slight modifications to the lifted function. Take `nn.functional.conv2d` as an example.



In [2]:
# `nt` holds three unbatched components with a shared leading size of 3 and
# ragged trailing sizes. `nt1` holds components with the same trailing sizes
# plus an extra leading batch dimension of size 1.
nt = torch.nested_tensor([torch.rand(3, h, w) for h, w in ((10, 30), (20, 40), (30, 50))])
nt1 = torch.nested_tensor([torch.rand(1, 3, h, w) for h, w in ((10, 30), (20, 40), (30, 50))])
# 4-D conv2d weight; its second dimension (3) matches the components' leading size.
weight = torch.rand(64, 3, 7, 7)
print_eval("nt.size()")

[1;31m$ nt.size():[0m       
(3, 3, None, None)



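By default, the tensorwise-lifted `conv2d` fails on `nt` because its components do not have a batch dimension. Giving each component a leading batch dimension of size 1, as in `nt1`, makes it work.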

In [3]:
# Applying conv2d tensorwise (to each component separately) fails on `nt`:
# its components are 3-D, and the conv2d in use here rejects 3-D input for a
# 4-D weight (see the error message printed below).
try:
    print_eval("torch.tensorwise()(torch.nn.functional.conv2d)(nt, weight)")
except RuntimeError as e:
    # Show the error message instead of letting the cell fail.
    print_eval("str(e)")

# With an explicit batch dimension of size 1 per component (`nt1`), the
# tensorwise version succeeds and keeps that batch dimension in its output.
print_eval("torch.tensorwise()(torch.nn.functional.conv2d)(nt1, weight).size()")

[1;31m$ str(e):[0m          
Expected 4-dimensional input for 4-dimensional weight 64 3 7 7, but got 3-dimensional input of size [3, 10, 30] instead

[1;31m$ torch.tensorwise()(torch.nn.functional.conv2d)(nt1, weight).size():[0m
(3, 1, 64, None, None)



However, NestedTensor implements its own version of `conv2d` that does not require a batch dimension, both for ease of use and for efficiency (more on that later).

In [4]:
# NestedTensor's own conv2d accepts the batch-less 3-D components of `nt` directly.
print_eval("torch.nn.functional.conv2d(nt, weight).size()")
# NOTE(review): disabled call on the batched `nt1`; presumably it does not work
# with NestedTensor's conv2d. Confirm, then either restore it with an
# explanation or delete it.
# print_eval("torch.nn.functional.conv2d(nt1, weight).size()")
# NOTE(review): looks like leftover debugging unrelated to the conv2d demo
# (prints the size of `nt1` flattened from dim 2 onward); consider removing.
print(nt1.flatten(2).size())

[1;31m$ torch.nn.functional.conv2d(nt, weight).size():[0m
(3, 64, None, None)

(3, 1, None)


The story is similar for `nn.functional.embedding_bag`. Without `offsets`, the lifted version only works on components that have a leading batch dimension (such as the `(1, N)` components of `nt2` below), which is an unnecessary annoyance. We extend the lifted `embedding_bag` to also accept 1-D components when `offsets` is `None`.

In [5]:
# Small non-negative integer indices (uniform noise scaled by 10, cast to int64)
# in three layouts:
#   nt2 - components with a leading batch dimension of size 1: (1, N)
#   nt3 - 1-D components: (N,)
#   nt4 - a doubly nested variant of nt2's layout
nt2 = (torch.nested_tensor([torch.rand(1, length) for length in (30, 40, 50)]) * 10).to(torch.int64)
nt3 = (torch.nested_tensor([torch.rand(length) for length in (30, 40, 50)]) * 10).to(torch.int64)
nt4 = (
    torch.nested_tensor([[torch.rand(1, length) for length in group] for group in ((30,), (40, 50))])
    * 10
).to(torch.int64)


In [6]:
# NOTE(review): `weight` is rebound here to an unrelated (100, 256) embedding
# table, replacing the conv2d weight from the earlier cell; a distinct name
# would be clearer.
weight = torch.rand(100, 256)
# Functional form: (1, N) components -> (1, 256); 1-D components -> (256,);
# the doubly nested input keeps its nesting structure.
print_eval("torch.nn.functional.embedding_bag(nt2, weight).nested_size()")
print_eval("torch.nn.functional.embedding_bag(nt3, weight).nested_size()")
print_eval("torch.nn.functional.embedding_bag(nt4, weight).nested_size()")
# Module form (a freshly constructed EmbeddingBag each call): same output sizes.
print_eval("torch.nn.EmbeddingBag(100, 256)(nt2).nested_size()")
print_eval("torch.nn.EmbeddingBag(100, 256)(nt3).nested_size()")
print_eval("torch.nn.EmbeddingBag(100, 256)(nt4).nested_size()")

[1;31m$ torch.nn.functional.embedding_bag(nt2, weight).nested_size():[0m
torch.NestedSize((
	torch.Size([1, 256]),
	torch.Size([1, 256]),
	torch.Size([1, 256])
))

[1;31m$ torch.nn.functional.embedding_bag(nt3, weight).nested_size():[0m
torch.NestedSize((
	torch.Size([256]),
	torch.Size([256]),
	torch.Size([256])
))

[1;31m$ torch.nn.functional.embedding_bag(nt4, weight).nested_size():[0m
torch.NestedSize((
	(
		torch.Size([1, 256])
	),
	(
		torch.Size([1, 256]),
		torch.Size([1, 256])
	)
))

[1;31m$ torch.nn.EmbeddingBag(100, 256)(nt2).nested_size():[0m
torch.NestedSize((
	torch.Size([1, 256]),
	torch.Size([1, 256]),
	torch.Size([1, 256])
))

[1;31m$ torch.nn.EmbeddingBag(100, 256)(nt3).nested_size():[0m
torch.NestedSize((
	torch.Size([256]),
	torch.Size([256]),
	torch.Size([256])
))

[1;31m$ torch.nn.EmbeddingBag(100, 256)(nt4).nested_size():[0m
torch.NestedSize((
	(
		torch.Size([1, 256])
	),
	(
		torch.Size([1, 256]),
		torch.Size([1, 256])
	)
))



In [7]:
# Cast nt3 to float. NOTE(review): this rebinds nt3, so the int64 indices are no
# longer available under that name after this cell.
nt3 = nt3.float()
print_eval("nt3")
print_eval("nt3.size()")
print_eval("nt3.nested_size()")
# nested_size(1) yields each component's length (30, 40, 50), which
# torch.nested_tensor wraps as 0-d tensors (see the output below).
print_eval("torch.nested_tensor(nt3.nested_size(1))")
# Divide each component by its own length (e.g. 3 / 30 -> 0.1 in the output).
# NOTE(review): this rebinds nt4, which previously held the doubly nested
# index tensor; a distinct name would be clearer.
nt4 = nt3 / torch.nested_tensor(nt3.nested_size(1))
print_eval("nt4")
print_eval("nt4.size()")

[1;31m$ nt3:[0m             
nested_tensor([
	tensor([3., 6., 7., 2., 0., 6., 0., 3., 9., 2., 4., 4., 5., 0., 0., 6., 3., 7.,
	        7., 4., 7., 5., 9., 7., 5., 7., 6., 1., 7., 7.]),
	tensor([4., 0., 3., 0., 7., 9., 3., 4., 8., 4., 7., 6., 6., 3., 7., 5., 3., 0.,
	        1., 3., 4., 4., 6., 3., 9., 7., 6., 9., 5., 4., 6., 3., 6., 1., 3., 7.,
	        4., 4., 2., 2.]),
	tensor([4., 2., 9., 3., 4., 3., 1., 2., 0., 2., 1., 4., 3., 1., 5., 2., 8., 6.,
	        7., 3., 8., 3., 2., 0., 9., 0., 7., 4., 9., 2., 7., 6., 3., 3., 6., 4.,
	        2., 4., 3., 2., 0., 0., 0., 4., 9., 4., 7., 7., 6., 0.])
])

[1;31m$ nt3.size():[0m      
(3, None)

[1;31m$ nt3.nested_size():[0m
torch.NestedSize((
	torch.Size([30]),
	torch.Size([40]),
	torch.Size([50])
))

[1;31m$ torch.nested_tensor(nt3.nested_size(1)):[0m
nested_tensor([
	tensor(30),
	tensor(40),
	tensor(50)
])

[1;31m$ nt4:[0m             
nested_tensor([
	tensor([0.1000, 0.2000, 0.2333, 0.0667, 0.0000, 0.2000, 0.0000, 0.1000, 0.3000,

In [8]:
# 2-D components with a ragged first dimension and a shared last dimension of
# 10, so every component can be multiplied by the same (10, 5) matrix.
nt5 = torch.nested_tensor([torch.rand(rows, 10) for rows in (30, 40, 50)])
print_eval("nt5.nested_size()")
print_eval("torch.mm(nt5, torch.rand(10, 5)).nested_size()")

[1;31m$ nt5.nested_size():[0m
torch.NestedSize((
	torch.Size([30, 10]),
	torch.Size([40, 10]),
	torch.Size([50, 10])
))

[1;31m$ torch.mm(nt5, torch.rand(10, 5)).nested_size():[0m
torch.NestedSize((
	torch.Size([30, 5]),
	torch.Size([40, 5]),
	torch.Size([50, 5])
))



In [9]:
# Nested dim 1 is each component's first (ragged) dimension, so argmax(1)
# returns 10 indices per component. Because the results all have the same size,
# size() is fully regular, (3, 10), and to_tensor() can produce a plain tensor.
print_eval("nt5.argmax(1)")
print_eval("nt5.argmax(1).size()")
print_eval("nt5.argmax(1).to_tensor()")

[1;31m$ nt5.argmax(1):[0m   
nested_tensor([
	tensor([19,  5,  1, 17, 27,  3, 27, 26, 24, 28]),
	tensor([ 8, 27, 23, 25, 24, 24,  6, 38,  8,  3]),
	tensor([32, 36,  9,  0, 29, 42, 46, 20, 46, 22])
])

[1;31m$ nt5.argmax(1).size():[0m
(3, 10)

[1;31m$ nt5.argmax(1).to_tensor():[0m
tensor([[19,  5,  1, 17, 27,  3, 27, 26, 24, 28],
        [ 8, 27, 23, 25, 24, 24,  6, 38,  8,  3],
        [32, 36,  9,  0, 29, 42, 46, 20, 46, 22]])



In [10]:
# argmax(2) reduces each component's last dimension (size 10), giving one index
# per row: sizes (30,), (40,), (50,). With those as targets, cross_entropy
# returns one scalar loss per component rather than a single combined loss.
print_eval("nt5.nested_size()")
print_eval("nt5.argmax(2).nested_size()")
print_eval("torch.nn.functional.cross_entropy(nt5, nt5.argmax(2))")

[1;31m$ nt5.nested_size():[0m
torch.NestedSize((
	torch.Size([30, 10]),
	torch.Size([40, 10]),
	torch.Size([50, 10])
))

[1;31m$ nt5.argmax(2).nested_size():[0m
torch.NestedSize((
	torch.Size([30]),
	torch.Size([40]),
	torch.Size([50])
))

[1;31m$ torch.nn.functional.cross_entropy(nt5, nt5.argmax(2)):[0m
nested_tensor([
	tensor(1.9369),
	tensor(1.9437),
	tensor(1.9298)
])



In [11]:
# Square components of different sizes; both parts of the lu() result have a
# ragged size per component.
nt6 = torch.nested_tensor([torch.rand(side, side) for side in (10, 20, 30)])
print_eval("nt6.lu()[0].size()")
print_eval("nt6.lu()[1].size()")

[1;31m$ nt6.lu()[0].size():[0m
(3, None, None)

[1;31m$ nt6.lu()[1].size():[0m
(3, None)



In [12]:
# Doubly nested inputs with matching structure: torch.mm multiplies
# corresponding components, (k, 10k) @ (10k, k) -> (k, k) for k = 1, 2, 3,
# and the result keeps the same nesting.
nt7 = torch.nested_tensor([[torch.rand(1, 10), torch.rand(2, 20)], [torch.rand(3, 30)]])
nt8 = torch.nested_tensor([[torch.rand(10, 1), torch.rand(20, 2)], [torch.rand(30, 3)]])
print_eval("torch.mm(nt7, nt8)")

[1;31m$ torch.mm(nt7, nt8):[0m
nested_tensor([
	[
		tensor([[2.4037]]),
		tensor([[3.6637, 4.1156],
		        [4.9553, 4.3066]])
	],
	[
		tensor([[7.2216, 7.7442, 8.4139],
		        [8.2279, 7.1767, 8.8830],
		        [7.3217, 7.0047, 7.2424]])
	]
])

