## Manipulating shape and indexing

In [1]:
import html

import torch
import nestedtensor
from IPython.display import Markdown, display

def print_eval(s):
    """Show an expression as a dark-red Markdown header, then print its value.

    Args:
        s: Python expression source. It is evaluated with ``eval`` against this
            function's module globals only (the notebook namespace), so names
            such as ``nt2`` resolve to notebook variables and are never shadowed
            by this function's own locals (e.g. a notebook variable called
            ``s``). Only pass trusted strings: ``eval`` runs arbitrary code.

    Side effects:
        Renders the header via ``display(Markdown(...))`` and then prints the
        value's ``str()`` followed by a blank line. The header is rendered
        before evaluation, so an expression that raises is still labelled.
    """
    # Escape &, < and > so the expression text is shown verbatim instead of
    # being interpreted as HTML inside the <span>.
    shown = html.escape(s, quote=False)
    header = "<span style='color:darkred'>$ {}</span>".format(shown)
    display(Markdown('**{}**'.format(header)))
    # globals() doubles as the locals mapping, keeping function locals out of scope.
    print(eval(s, globals()), end='\n\n')

In [2]:
# Two-level nesting: an outer list of two entries, each itself a list of two
# 2x2 float tensors.
nt2 = nestedtensor.nested_tensor([
    [torch.tensor([[1.0, 0.5], [0.1, 0.6]]),
     torch.tensor([[5.5, 3.3], [2.2, 6.6]])],
    [torch.tensor([[3.0, 1.0], [0.5, 0.7]]),
     torch.tensor([[5.0, 4.0], [1.0, 2.0]])],
])
print_eval("nt2")

**<span style='color:darkred'>$ nt2</span>**

nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])



In [3]:
# to_tensor(1): judging by the recorded output, the inner nested level is folded
# into each constituent tensor (2x2 -> 2x2x2), leaving a single nested dimension.
nt3 = nt2.to_tensor(1)
for expr in ("nt2", "nt3", "nt3.size()", "nt3.nested_dim()", "nt3.nested_size()"):
    print_eval(expr)

**<span style='color:darkred'>$ nt2</span>**

nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])



**<span style='color:darkred'>$ nt3</span>**

nested_tensor([
	tensor([[[1.0000, 0.5000],
	         [0.1000, 0.6000]],
	
	        [[5.5000, 3.3000],
	         [2.2000, 6.6000]]]),
	tensor([[[3.0000, 1.0000],
	         [0.5000, 0.7000]],
	
	        [[5.0000, 4.0000],
	         [1.0000, 2.0000]]])
])



**<span style='color:darkred'>$ nt3.size()</span>**

(2, 2, 2, 2)



**<span style='color:darkred'>$ nt3.nested_dim()</span>**

1



**<span style='color:darkred'>$ nt3.nested_size()</span>**

torch.NestedSize((
	torch.Size([2, 2, 2]),
	torch.Size([2, 2, 2])
))



In [4]:
# to_tensor(0) collapses every nested level; the recorded output is a plain
# torch.Tensor of size (2, 2, 2, 2) rather than a nested tensor.
nt4 = nt2.to_tensor(0)
for expr in ("nt2", "nt4", "nt4.size()"):
    print_eval(expr)
# nt4.nested_dim() and nt4.nested_size() are intentionally not evaluated:
# nt4 is a regular Tensor, so those calls crash.

**<span style='color:darkred'>$ nt2</span>**

nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])



**<span style='color:darkred'>$ nt4</span>**

tensor([[[[1.0000, 0.5000],
          [0.1000, 0.6000]],

         [[5.5000, 3.3000],
          [2.2000, 6.6000]]],


        [[[3.0000, 1.0000],
          [0.5000, 0.7000]],

         [[5.0000, 4.0000],
          [1.0000, 2.0000]]]])



**<span style='color:darkred'>$ nt4.size()</span>**

torch.Size([2, 2, 2, 2])



In [5]:
# Indexing over the nested dimensions: chained vs. tuple indices and slices.
# NOTE(review): in the recorded output nt2[0, 0] is NOT the same as nt2[0][0];
# the second index appears to land on the tensor rows instead. Confirm whether
# this is intended nestedtensor behaviour.
for expr in ("nt2", "nt2[0][0]", "nt2[0, 0]", "nt2[:, 0]", "nt2[0, :]"):
    print_eval(expr)

**<span style='color:darkred'>$ nt2</span>**

nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])



**<span style='color:darkred'>$ nt2[0][0]</span>**

tensor([[1.0000, 0.5000],
        [0.1000, 0.6000]])



**<span style='color:darkred'>$ nt2[0, 0]</span>**

nested_tensor([
  tensor([1.0000, 0.5000]),
  tensor([5.5000, 3.3000]),
])



**<span style='color:darkred'>$ nt2[:, 0]</span>**

nested_tensor([
  tensor([[1.0000, 0.5000],
        [0.1000, 0.6000]]),
  tensor([[3.0000, 1.0000],
        [0.5000, 0.7000]]),
])



**<span style='color:darkred'>$ nt2[0, :]</span>**

nested_tensor([
  tensor([[1.0000, 0.5000],
        [0.1000, 0.6000]]),
  tensor([[5.5000, 3.3000],
        [2.2000, 6.6000]]),
])



In [6]:
# Advanced indexing is allowed over tensor dimensions: the index (1, 0) on the
# first tensor dimension reverses the two rows of every constituent tensor.
for expr in ("nt2", "nt2[:, :, (1, 0)]"):
    print_eval(expr)

**<span style='color:darkred'>$ nt2</span>**

nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])



**<span style='color:darkred'>$ nt2[:, :, (1, 0)]</span>**

nested_tensor([
  nested_tensor([
  tensor([[0.1000, 0.6000],
        [1.0000, 0.5000]]),
  tensor([[2.2000, 6.6000],
        [5.5000, 3.3000]]),
]),
  nested_tensor([
  tensor([[0.5000, 0.7000],
        [3.0000, 1.0000]]),
  tensor([[1., 2.],
        [5., 4.]]),
]),
])



In [7]:
# Advanced indexing with a 2-D *integer* index tensor (not a boolean mask).
# Each entry of `ind` selects a row, so every 2x2 constituent becomes 2x2x2:
# the first slice holds rows (1, 0) and the second holds rows (0, 1).
print_eval("nt2")
ind = torch.tensor(((1, 0), (0, 1)), dtype=torch.long)
print_eval("ind")
print_eval("nt2[:, :, ind]")

**<span style='color:darkred'>$ nt2</span>**

nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])



**<span style='color:darkred'>$ ind</span>**

tensor([[1, 0],
        [0, 1]])



**<span style='color:darkred'>$ nt2[:, :, ind]</span>**

nested_tensor([
  nested_tensor([
  tensor([[[0.1000, 0.6000],
         [1.0000, 0.5000]],

        [[1.0000, 0.5000],
         [0.1000, 0.6000]]]),
  tensor([[[2.2000, 6.6000],
         [5.5000, 3.3000]],

        [[5.5000, 3.3000],
         [2.2000, 6.6000]]]),
]),
  nested_tensor([
  tensor([[[0.5000, 0.7000],
         [3.0000, 1.0000]],

        [[3.0000, 1.0000],
         [0.5000, 0.7000]]]),
  tensor([[[1., 2.],
         [5., 4.]],

        [[5., 4.],
         [1., 2.]]]),
]),
])



In [8]:
# Ellipsis
print_eval("nt2")
print_eval("nt2[:, :, ..., 0]")
# An Ellipsis that would expand over the nested dimensions is rejected with
# NotImplementedError (message recorded below). If a later nestedtensor version
# supports it, print the result so the cell does not silently show nothing.
print("$ nt2[..., 0]")
try:
    print(nt2[..., 0])
except NotImplementedError as e:
    print(e)

**<span style='color:darkred'>$ nt2</span>**

nested_tensor([
	[
		tensor([[1.0000, 0.5000],
		        [0.1000, 0.6000]]),
		tensor([[5.5000, 3.3000],
		        [2.2000, 6.6000]])
	],
	[
		tensor([[3.0000, 1.0000],
		        [0.5000, 0.7000]]),
		tensor([[5., 4.],
		        [1., 2.]])
	]
])



**<span style='color:darkred'>$ nt2[:, :, ..., 0]</span>**

nested_tensor([
  nested_tensor([
  tensor([1.0000, 0.1000]),
  tensor([5.5000, 2.2000]),
]),
  nested_tensor([
  tensor([3.0000, 0.5000]),
  tensor([5., 1.]),
]),
])

$ nt2[..., 0]
Ellipsis is not yet supported for nested dimensions
