## NestedTensor as unifying datastructure for non-uniform Tensor input


See [the corresponding RFC for more background on motivation](https://docs.google.com/document/d/1VdKG5JA0U8iiwd6eYpUlCItm3zNJns8_ooJvaH_JWV8/edit#).

In general this construct is meant as a container with the following layouts as inspired by the cited operators.

In [1]:
from enum import Enum
class Layout(Enum):
    Masked = 0 # Example: TransformerEncoderLayer or CrossEntropyLoss by using the mask to fill with padding_idx
    Packed = 1 # Example: EmbeddingBag
    PackedSequence = 2 # Restricted to RNN
    List = 3 # Fallback and default for quick creation

The following hidden cell is an incomplete implementation of this using torch_function. This structure does layout conversions via a ```to``` method and provides a unified constructor, which accepts a list of Tensors and that allows the specification of a layout.

In [2]:
#@title
import torch
from enum import Enum

def _nn_functional_embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
                                 scale_grad_by_freq=False, mode='mean', sparse=False,
                                 per_sample_weights=None, include_last_offset=False):
    # [...] Omitted input sanitization
    # [...] Verify that nested_size is shape compliant, i.e. all 1d Tensors (sequences)
    # Design decision: conversion happens automatically. This is similar to how we automatically
    # make Tensor contiguous or convert from fp16 to fp32 or sparse to dense if needed.
    # We could decide to throw a warning here.
    input = input.to(Layout.Packed)
    offsets = torch.tensor([0] + [x[0] for x in input.nested_size()[:-1]]).cumsum(0)
    # We could consider caching this metadata in NestedTensor
    offsets = offsets.to(data.device)
    assert input.layout is Layout.Packed
    return torch.nn.functional.embedding_bag(
        input.data,
        weight,
        offsets,
        max_norm,
        norm_type,
        scale_grad_by_freq,
        mode,
        sparse,
        per_sample_weights,
        include_last_offset)

def nested_tensor(tensors, layout=Layout.List, dtype=None, device=None, requires_grad=False): # pin_memory could be added as a layout
    """
    Given a list of Tensors, each of the same dimension but variable shape, construct a NestedTensor that represents
    this list of Tensors.

    If a given entry of tensors does not match the dtype or device of the others, the result dtype or device needs to
    be specified explicitly
    """
    assert layout is Layout.List # No other layout support for now
    assert isinstance(tensors, list)
    assert len(tensors) > 0
    dtype = tensors[0].dtype if dtype is None else dtype
    device = tensors[0].device if device is None else device
    # Change dtype and device if necessary
    tensors = [t.to(device, dtype) for t in tensors]
    nested_size = tuple(x.size() for x in tensors)
    return NestedTensor(tensors, nested_size, Layout.List, dtype, device, requires_grad).to(layout)

def _from_packed_sequence_to_list(packed_sequence):
    padded, lengths = torch.nn.utils.rnn.pad_packed_sequence(packed_sequence, batch_first=True)
    tensors = []
    for i, length in enumerate(lengths):
        tensors.append(padded[i, :length])
    return tensors

def as_nested_tensor(data, layout=Layout.List, dtype=None, device=None, requires_grad=False): # pin_memory could be added as a layout
    """
    Similar to torch.as_tensor, this converts the given data into a NestedTensor.
    """
    if isinstance(data, torch.nn.utils.rnn.PackedSequence):
        return nested_tensor(_from_packed_sequence_to_list(data))
    raise NotImplementedError("as_nested_tensor cannot convert data of type {} into a NestedTensor.".format(type(data)))


def _from_list_to_layout(list_nt, target_layout):
    assert list_nt.layout is Layout.List
    if target_layout is Layout.List:
        return list_nt
    if target_layout is Layout.Masked:
        max_size = [len(list_nt.data)]
        for d in range(list_nt.data[0].dim()):
            max_size.append(max(x.size(d) for x in list_nt.data))
        # This approach doesn't support autograd and can also be used during construction or without autograd
        # An approach that does work with autograd uses pad and cat, but is a bit more involved
        # See https://github.com/pytorch/nestedtensor/blob/master/nestedtensor/nested/masking.py#L142 for a complete implementation
        data = torch.zeros(*max_size, dtype=list_nt.dtype, device=list_nt.device)
        mask = torch.zeros(*max_size, dtype=torch.bool, device=list_nt.device)
        for d_t, d_m, t in zip(data, mask, list_nt.data):
            for d in range(t.dim()):
                d_t = d_t.narrow(d, 0, t.size(d))
                d_m = d_m.narrow(d, 0, t.size(d))
            d_t.copy_(t.detach())
            d_m.fill_(1)
        return NestedTensor(data, list_nt.nested_size(), Layout.Masked, list_nt.dtype, list_nt.device, list_nt.requires_grad, metadata=mask)
    if target_layout is Layout.Packed:
        offsets_ = list_nt.nested_size()
        data = torch.cat([x.reshape(-1) for x in list_nt.data]) # shape information is stored in nested_size
        return NestedTensor(data, list_nt.nested_size(), Layout.Packed, list_nt.dtype, list_nt.device, list_nt.requires_grad)
    if target_layout is Layout.PackedSequence:
        return NestedTensor(torch.nn.utils.rnn.pack_sequence(list_nt.data, enforce_sorted=False), # enforce_sorted set to False doesn't support ONNX for now,
                            list_nt.nested_size(),
                            Layout.PackedSequence,
                            list_nt.dtype,
                            list_nt.device,
                            list_nt.requires_grad)
    raise NotImplemented("Converstion from list to target layout {} not supported.".format(target_layout.name))
            
class NestedTensor(object):
    def __init__(self, data, nested_size, layout, dtype, device, requires_grad, metadata=None):
        # Can be list of tensors, single packed or masked Tensor or PackedSequence
        self.data = data
        # Metadata is overloaded with type and meaning
        # Masked: Stores bool mask where True means included, False means excluded
        # Packed: Stores 1d Tensor of offsets. offsets are the length of each entry in the flat data. Packed currently only supports 2d NestedTensors
        # PackedSequence: Stores the lengths of the PackedSequence
        self.metadata = metadata
        self._nested_size = nested_size
        self._layout = layout
        self._dtype = dtype
        self._device = device
        # Gradient is supported by differentiable layout conversion functions a tracked by data field
        self._requires_grad = requires_grad 

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if func is torch.nn.functional.embedding_bag:
            # Design decision pending: We could make conversion to Layout.Padding automatic
            return _nn_functional_embedding_bag(*args, **kwargs)
        raise NotImplementedError("Given func {} does not support NestedTensor.".format(func))

    def nested_size(self):
        return self._nested_size

    @property
    def dtype(self):
        return self._dtype

    @property
    def layout(self):
        return self._layout

    @property
    def device(self):
        return self._device

    @property
    def requires_grad(self):
        return self._requires_grad

    # There are 5 layouts, therefore there are 20 possible
    # conversions excluding identities
    def to(self, target_layout):
        assert isinstance(target_layout, Layout)
        if self.layout is target_layout:
            return self
        if self.layout is Layout.List:
            return _from_list_to_layout(self, target_layout)
        raise NotImplementedError(
            "Cannot convert {} to desired layout {}".format(
                self.layout.name, target_layout.name))

    
    def to_tensor_list(self):
        # Returns a list of Tensors
        return self.to(Layout.List).data

    def to_padded(self, padding_value=-1):
        # Returns a Tensor padded with padding_value
        converted = self.to(Layout.Masked)
        return converted.data.masked_fill_(~converted.metadata, padding_value)

    def to_masked(self):
        # Returns a Tensor plus a Bool mask of same shape
        converted = self.to(Layout.Masked)
        return converted.data, converted.mask

    def to_packed_sequence(self):
        return self.to(Layout.PackedSequence).data
              

Let's step through an intended usecase and compare it a current application.

The following EmbeddingBag represents a lookupt table of 10 vectors, each of dimension 3.

In [3]:
import torch
from torch import nn
embedding_bag = nn.EmbeddingBag(10, 3)

Let's construct a list of tensors filled with a varying degree of word ids and feed it into EmbeddingBag as we were to right now.

In [4]:
sentences = [torch.tensor([0, 3, 1]), torch.tensor([5, 1, 2, 4]), torch.tensor([3, 2])]

In [5]:
data = torch.cat(sentences)
offsets = torch.tensor([0] + [len(x) for x in sentences[:-1]]).cumsum(0)
print(offsets)
print(embedding_bag(data, offsets))

tensor([0, 3, 7])
tensor([[-0.0482,  0.0242, -0.6505],
        [-0.6074,  0.6866, -0.4335],
        [ 0.5125, -0.1862, -0.8296]], grad_fn=<EmbeddingBagBackward>)


And this is what it'll look like with NestedTensor

In [6]:
nt = nested_tensor(sentences)
print(nt.nested_size())
embedding_bag(nt)

(torch.Size([3]), torch.Size([4]), torch.Size([2]))


tensor([[-0.0482,  0.0242, -0.6505],
        [-0.6074,  0.6866, -0.4335],
        [ 0.5125, -0.1862, -0.8296]], grad_fn=<EmbeddingBagBackward>)

Is it going to be less efficient to first construct a NestedTensor and then convert into an operator specific data structure? If we do this automatically we have the chance of optimizing a conversion, but we also run the risk of converting prematurely or in an inefficient way. This is the usual lazy vs. eager tradeoff and the current PyTorch convention seem to lean towards automatic conversion (e.g. when given non-contiguous inputs, sparse inputs (usually) or inputs of other dtype).

In [7]:
print(nt.to_padded())

tensor([[ 0,  3,  1, -1],
        [ 5,  1,  2,  4],
        [ 3,  2, -1, -1]])


In [8]:
print(nt.to_tensor_list())

[tensor([0, 3, 1]), tensor([5, 1, 2, 4]), tensor([3, 2])]


In [9]:
rnn = nn.RNN(5, #embedding dimension
             3, 2)
h0 = torch.randn(2, 3, 3)
embeddings = [s.unsqueeze(1).repeat(1, 5) #emulating embedding
              for s in sentences]
nt = nested_tensor(embeddings, dtype=torch.float)

try:
  rnn(nt) # 
except AttributeError as e:
  print(e)

'NestedTensor' object has no attribute 'size'


RNN doesn't have good torch_function support, but luckily we can just convert manually into the desired format.

In [10]:
ps = nt.to_packed_sequence()
output, hn = rnn(ps, h0)
print(output)


PackedSequence(data=tensor([[ 0.1349, -0.1506,  0.8108],
        [ 0.6356,  0.2794,  0.7581],
        [-0.1012,  0.3027,  0.9623],
        [ 0.3990, -0.0811,  0.6990],
        [ 0.0292, -0.2913,  0.7972],
        [ 0.3070, -0.4692,  0.7617],
        [ 0.2164, -0.0570,  0.7273],
        [ 0.4771,  0.0845,  0.6256],
        [-0.0036, -0.2968,  0.7427]], grad_fn=<CatBackward>), batch_sizes=tensor([3, 3, 2, 1]), sorted_indices=tensor([1, 0, 2]), unsorted_indices=tensor([1, 0, 2]))


And now we use the as_nested_tensor function (similar to torch.as_tensor) to interpret the resulting value (which is also a PackedSequence) as a NestedTensor again. This is useful in particular when you're about to feed this output into a linear layer as your final projection before the loss, because you can retrieve the padded version of your output.

In [11]:
output_nt = as_nested_tensor(output)
padded_output = output_nt.to_padded(0)
print(padded_output.size())
print(padded_output)

torch.Size([3, 4, 3])
tensor([[[ 0.6356,  0.2794,  0.7581],
         [ 0.0292, -0.2913,  0.7972],
         [ 0.4771,  0.0845,  0.6256],
         [ 0.0000,  0.0000,  0.0000]],

        [[ 0.1349, -0.1506,  0.8108],
         [ 0.3990, -0.0811,  0.6990],
         [ 0.2164, -0.0570,  0.7273],
         [-0.0036, -0.2968,  0.7427]],

        [[-0.1012,  0.3027,  0.9623],
         [ 0.3070, -0.4692,  0.7617],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]]])


In [12]:
loss = nn.NLLLoss()

In [13]:
targets = torch.tensor([1, 2, 1, -100, 2, 1, 1, 2, 1, 1, -100, -100])
loss(padded_output.reshape(-1, 3), targets)

tensor(-0.2678)