In [None]:
## Nested tensor masking

'''
NestedTensor comes with two APIs: one converts a NestedTensor into a pair of a data tensor and a mask tensor, and the other converts such a pair back into a NestedTensor.
Essentially, the data tensor stores all the values, nesting, dimensionality, and metadata (dtype, layout, etc.) of the original NestedTensor. Because a NestedTensor allows tensors of different sizes, the resulting data tensor will be padded. Therefore, a mask is required to tell us which values are real and which ones are padding.
'''

In [None]:
## to_tensor_mask
# data_tensor, mask = a.to_tensor_mask(mask_dim=None)
# data_tensor - the resulting data tensor.
# mask - the resulting mask.
# a - the original nested tensor.
# mask_dim - desired dimensionality of the resulting mask.

'''
This method returns a data tensor and a mask that together represent the nested tensor. The resulting data tensor is always the same for a given nested tensor, no matter what mask_dim is passed. This tensor represents all the values, dimensionality, and nesting of the original nested tensor.
The resulting mask depends on the mask_dim value. It's important to understand the relationship between the data tensor and the mask: the same nested tensor can be represented by the same data tensor combined with masks of different dimensionality. If no value for mask_dim is passed, the mask with the lowest possible dimensionality is returned. This is easiest to understand by looking at the examples.
'''
import torch
import nestedtensor as nt

# Build a nested tensor holding three nested tensors, each of which wraps a
# single one-element leaf tensor (1, 2 and 3 respectively).
a = nt.nested_tensor(
    [nt.nested_tensor([torch.tensor([value])]) for value in (1, 2, 3)]
)

# No mask_dim given: the mask of the lowest possible dimensionality is returned.
tensor, mask = a.to_tensor_mask()
print("\n\nmask_dim = None")
print("Data tensor: ", tensor)
print("Mask: ", mask)

# The mask above is a single boolean scalar, meaning that no element of the
# data tensor has to be ignored. That is only possible when every leaf node of
# the nested tensor has the same size.

# Now request an explicit mask dimensionality.
tensor, mask = a.to_tensor_mask(mask_dim=3)
print("\n\nmask_dim = 3")
# The data tensor does not depend on mask_dim, so it matches the one above.
print("Data tensor: ", tensor)

# 3 is the highest mask dimensionality available for this nested tensor, so
# the mask carries one value per tensor element.
print("Mask: ", mask)
    

In [20]:
# A mask dimension that is too small or too big to represent the nested tensor
# makes to_tensor_mask raise a RuntimeError. Each failure is caught here so
# that its message can be shown.
a = nt.nested_tensor([
    torch.tensor([1, 2]),
    torch.tensor([3, 4, 5, 6]),
])

# First a mask dimension that is too small, then one that is too big.
for bad_mask_dim in (1, 10):
    try:
        a.to_tensor_mask(mask_dim=bad_mask_dim)
    except RuntimeError as error:
        print(error)

Mask dimension is too small to represent data tensor.
Mask dimension is bigger than nested dimension of a nested tensor.


In [28]:
## nested_tensor_from_tensor_mask
# result_nt = nt.nested_tensor_from_tensor_mask(tensor, mask, nested_dim=1)
# nt - the nestedtensor module (imported as nt)
# result_nt - the resulting nested tensor
# tensor - input data tensor
# mask - input mask
# nested_dim - desired nested dimensionality of the resulting nested tensor

'''
This function returns a nested tensor constructed from the passed tensor, mask, and optional nested_dim value.
The dimensionality of the resulting nested tensor depends on the nested_dim value. nested_dim has a default value of 1 and has to be between 1 and the data tensor's dimensionality.
'''

# Padded data tensor: the two trailing zeros in the second row are masked out
# by the mask defined below.
tensor = torch.tensor([
    [1, 2, 3],
    [4, 0, 0],
])

# True marks a real value, False marks a position to drop.
mask = torch.tensor([
    [True, True, True],
    [True, False, False],
])

# Leaving nested_dim out falls back to its default of 1.
print("nested_dim = 1")
print(nt.nested_tensor_from_tensor_mask(tensor, mask))

# Ask for one more level of nesting in the result.
print("\nnested_dim = 2")
print(nt.nested_tensor_from_tensor_mask(tensor, mask, nested_dim=2))

nested_dim = 1
nested_tensor([
	tensor([1, 2, 3]),
	tensor([4])
])

nested_dim = 2
nested_tensor([
	nested_tensor([
		tensor(1),
		tensor(2),
		tensor(3)
	]),
	nested_tensor([
		tensor(4)
	])
])
