In [1]:
from nestedtensor import torch
def print_eval(s):
    """Show the Python expression ``s`` together with the value it evaluates to.

    Prints a red (ANSI-escaped) ``$ <s>:`` header left-justified to 30
    characters (the width count includes the escape codes), then ``str()`` of
    the evaluated value on its own line, then a blank line.

    ``s`` is handed directly to ``eval`` with this function's globals (the
    notebook namespace when defined in a notebook), so only ever pass trusted,
    hard-coded expressions.
    """
    print(
        ("\033[1;31m$ " + s + ":\033[0m").ljust(30)
        + "\n"
        + str(eval(s))
        + "\n"
    )

## Manipulating shape and indexing

In [2]:
# Build a NestedTensor with two levels of nesting: an outer list of two
# entries, each itself a list of two 2x2 float tensors. All constituents share
# the same 2x2 shape; the output below reports nested_dim() == 2 and
# size() == (2, 2, 2, 2).
nt2 = torch.nested_tensor([
    [torch.tensor([[1.0, 0.5], [0.1, 0.6]]), torch.tensor([[5.5, 3.3], [2.2, 6.6]])],
    [torch.tensor([[3.0, 1.0], [0.5, 0.7]]), torch.tensor([[5.0, 4.0], [1.0, 2.0]])],
])

print_eval("nt2")
print_eval("nt2.size()")
print_eval("nt2.nested_dim()")
print_eval("nt2.nested_size()")

[1;31m$ nt2:[0m             
nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])

[1;31m$ nt2.size():[0m      
(2, 2, 2, 2)

[1;31m$ nt2.nested_dim():[0m
2

[1;31m$ nt2.nested_size():[0m
torch.NestedSize((
	(
		torch.Size([2, 2]),
		torch.Size([2, 2])
	),
	(
		torch.Size([2, 2]),
		torch.Size([2, 2])
	)
))



In [3]:
# to_tensor(1): judging from the output below, the nested dimensions from
# index 1 onward are folded into the constituent tensors. nt3 keeps a single
# nested dimension whose two entries are 2x2x2 tensors, and size() is still
# (2, 2, 2, 2). nt2 is left unmodified (it prints exactly as before).
nt3 = nt2.to_tensor(1)
print_eval("nt2")
print_eval("nt3")
print_eval("nt3.size()")
print_eval("nt3.nested_dim()")
print_eval("nt3.nested_size()")

[1;31m$ nt2:[0m             
nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])

[1;31m$ nt3:[0m             
nested_tensor([
	tensor([[[1.0000, 0.5000],
	         [0.1000, 0.6000]],
	
	        [[5.5000, 3.3000],
	         [2.2000, 6.6000]]]),
	tensor([[[3.0000, 1.0000],
	         [0.5000, 0.7000]],
	
	        [[5.0000, 4.0000],
	         [1.0000, 2.0000]]])
])

[1;31m$ nt3.size():[0m      
(2, 2, 2, 2)

[1;31m$ nt3.nested_dim():[0m
1

[1;31m$ nt3.nested_size():[0m
torch.NestedSize((
	torch.Size([2, 2, 2]),
	torch.Size([2, 2, 2])
))



In [4]:
# to_tensor(0) folds every nested dimension into one dense result. Per the
# output below, nt4 prints as a plain tensor with torch.Size([2, 2, 2, 2]).
nt4 = nt2.to_tensor(0)
print_eval("nt2")
print_eval("nt4")
print_eval("nt4.size()")
# The NestedTensor-only accessors are deliberately left disabled so the cell
# still runs under Restart & Run All:
# print_eval("nt4.nested_dim()") Will crash. nt4 is a regular Tensor!
# print_eval("nt4.nested_size()") Will crash. nt4 is a regular Tensor!

[1;31m$ nt2:[0m             
nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])

[1;31m$ nt4:[0m             
tensor([[[[1.0000, 0.5000],
          [0.1000, 0.6000]],

         [[5.5000, 3.3000],
          [2.2000, 6.6000]]],


        [[[3.0000, 1.0000],
          [0.5000, 0.7000]],

         [[5.0000, 4.0000],
          [1.0000, 2.0000]]]])

[1;31m$ nt4.size():[0m      
torch.Size([2, 2, 2, 2])



In [5]:
# Integer and slice indexing over the nested dimensions of nt2.
# NOTE(review): per the recorded output, nt2[0, 0] is NOT the same as
# nt2[0][0]. nt2[0][0] is the first constituent tensor, while nt2[0, 0]
# returns a nested_tensor holding row 0 of each constituent of nt2[0]. Confirm
# whether this is the intended nestedtensor semantics or a library bug.
print_eval("nt2")
print_eval("nt2[0][0]")
print_eval("nt2[0, 0]")
print_eval("nt2[:, 0]")
print_eval("nt2[0, :]")

[1;31m$ nt2:[0m             
nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])

[1;31m$ nt2[0][0]:[0m       
tensor([[1.0000, 0.5000],
        [0.1000, 0.6000]])

[1;31m$ nt2[0, 0]:[0m       
nested_tensor([
	tensor([1.0000, 0.5000]),
	tensor([5.5000, 3.3000])
])

[1;31m$ nt2[:, 0]:[0m       
nested_tensor([
	tensor([[1.0000, 0.5000],
	        [0.1000, 0.6000]]),
	tensor([[3.0000, 1.0000],
	        [0.5000, 0.7000]])
])

[1;31m$ nt2[0, :]:[0m       
nested_tensor([
	tensor([[1.0000, 0.5000],
	        [0.1000, 0.6000]]),
	tensor([[5.5000, 3.3000],
	        [2.2000, 6.6000]])
])



In [6]:
# Advanced indexing is allowed over tensor dimensions. After the two nested
# dimensions (`:, :`), the index sequence (1, 0) is applied to the first tensor
# dimension of every constituent. In the output below, this swaps each
# constituent's two rows.
print_eval("nt2")
print_eval("nt2[:, :, (1, 0)]")

[1;31m$ nt2:[0m             
nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])

[1;31m$ nt2[:, :, (1, 0)]:[0m
nested_tensor([
	[
		tensor([[0.1000, 0.6000],
		        [1.0000, 0.5000]]),
		tensor([[2.2000, 6.6000],
		        [5.5000, 3.3000]])
	],
	[
		tensor([[0.5000, 0.7000],
		        [3.0000, 1.0000]]),
		tensor([[1., 2.],
		        [5., 4.]])
	]
])



In [7]:
# Advanced indexing with an integer index tensor. This is not a boolean mask:
# `ind` holds row indices, not True/False flags. It is applied to the first
# tensor dimension of each 2x2 constituent t, giving t[ind] of shape (2, 2, 2)
# with t[ind][i, j] == t[ind[i, j]]. For example, the first block of each
# result is (row 1, row 0) and the second is (row 0, row 1), as the output
# below shows.
print_eval("nt2")
ind = torch.tensor(((1, 0), (0, 1)))
print_eval("ind")
print_eval("nt2[:, :, ind]")

[1;31m$ nt2:[0m             
nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])

[1;31m$ ind:[0m             
tensor([[1, 0],
        [0, 1]])

[1;31m$ nt2[:, :, ind]:[0m  
nested_tensor([
	[
		tensor([[[0.1000, 0.6000],
		         [1.0000, 0.5000]],
		
		        [[1.0000, 0.5000],
		         [0.1000, 0.6000]]]),
		tensor([[[2.2000, 6.6000],
		         [5.5000, 3.3000]],
		
		        [[5.5000, 3.3000],
		         [2.2000, 6.6000]]])
	],
	[
		tensor([[[0.5000, 0.7000],
		         [3.0000, 1.0000]],
		
		        [[3.0000, 1.0000],
		         [0.5000, 0.7000]]]),
		tensor([[[1., 2.],
		         [5., 4.]],
		
		        [[5., 4.],
		         [1., 2.]]])
	]
])



In [8]:
# Ellipsis: in this example it works once both nested dimensions of nt2 are
# indexed explicitly (`:, :`), so `...` only spans tensor dimensions. A leading
# `...` that would also cover nested dimensions raises NotImplementedError
# (message shown in the output below).
print_eval("nt2")
print_eval("nt2[:, :, ..., 0]")
print("$ nt2[..., 0]")
# The failure is the point of the demo: catch it so Run All continues and the
# library's message is displayed instead of a traceback.
try:
    nt2[..., 0]
except NotImplementedError as e:
    print(str(e))

[1;31m$ nt2:[0m             
nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])

[1;31m$ nt2[:, :, ..., 0]:[0m
nested_tensor([
	[
		tensor([1.0000, 0.1000]),
		tensor([5.5000, 2.2000])
	],
	[
		tensor([3.0000, 0.5000]),
		tensor([5., 1.])
	]
])

$ nt2[..., 0]
Ellipsis is not yet supported for nested dimensions
