In [1]:
# This dispatch setup must only be executed once. The monkey patching determines whether it is overwriting already
# set attributes to make sure it is not overwriting critical code. This will change in the near future
# with PyTorch's improved extension support being added (akin to numpy).
from nestedtensor import torch
import time as time_module
def print_eval(s):
    print(("\033[1;31m$ " + s + ":\033[0m").ljust(30) + "\n{}\n".format(str(eval(s))))
def time(fn):
    t0 = time_module.time()
    count = 0
    past = 0
    while past < 10.0:
        fn()
        past = time_module.time() - t0
        count += 1
    past = past / count
    return "average {:2.4f}ms based on {} samples".format(past * 1000, count)

In [2]:
def generate_tensors(num_tensor, vocab_size):
    sentence_lengths = torch.normal(75.0, 10.0, size=(num_tensor,)).long()
    return [(torch.rand(l) * vocab_size).long() for l in sentence_lengths]

def generate_text(text):
    offsets = [0] + [len(entry) for entry in text]
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text)
    return text.to(torch.int64), offsets

class TextSentiment(torch.nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = torch.nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = torch.nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        emb = self.embedding(text, offsets)
        return self.fc(emb)
    
vocab_size = 10000
model = TextSentiment(10000, 256, 5)
tensors = generate_tensors(16, 10000)
text, offsets = generate_text(tensors)
nt_text = torch.nested_tensor(tensors)

print_eval("time(lambda: model(text, offsets))")
print_eval("time(lambda: model(nt_text, None))")

[1;31m$ time(lambda: model(text, offsets)):[0m
average 0.2323ms based on 43047 samples

[1;31m$ time(lambda: model(nt_text, None)):[0m
average 0.3689ms based on 27109 samples



In [3]:
from torchvision import models

model = models.resnet18(pretrained=False)
images = torch.rand(128, 3, 40, 50)
print_eval("images.numel()")
print_eval("time(lambda: model(images))")

nested_images = torch.nested_tensor(torch.rand(128, 3, 40, 50).unbind())
print_eval("time(lambda: model(nested_images))")

# There is still about a 10x gap in performance, which however
# can be significantly allieviated via custom code (e.g. using im2col).
images = [torch.rand(3, (i * 16) % 40 + 40, (i * 16) % 50 + 40) for i in range(64)]
nested_irregular_images = torch.nested_tensor(images)
print_eval("nested_irregular_images.numel()")
print_eval("nested_irregular_images.size()")
print_eval("time(lambda: model(nested_irregular_images))")

[1;31m$ images.numel():[0m  
768000

[1;31m$ time(lambda: model(images)):[0m
average 47.6147ms based on 211 samples

[1;31m$ time(lambda: model(nested_images)):[0m
average 113.2366ms based on 89 samples

[1;31m$ nested_irregular_images.numel():[0m
692112

[1;31m$ nested_irregular_images.size():[0m
(64, 3, None, None)

[1;31m$ time(lambda: model(nested_irregular_images)):[0m
average 1435.1884ms based on 7 samples



In [6]:
def generate_tensors(num_tensor, num_features):
    sentence_lengths = torch.normal(75.0, 10.0, size=(num_tensor,)).long()
    return [torch.rand(l.item(), num_features) for l in sentence_lengths]

tensors = generate_tensors(32, 256)
nt_text = torch.nested_tensor(tensors)
text = torch.rand(32, 75, 256)

h0 = torch.randn(6, len(nt_text), 512)
c0 = torch.randn(6, len(nt_text), 512)
print_eval("nt_text.nested_size(1)")
print_eval("nt_text.numel()")
print_eval("text.numel()")
print_eval("time(lambda: torch.nn.LSTM(256, 512, 6, batch_first=True)(nt_text, (h0, c0)))")
print_eval("time(lambda: torch.nn.LSTM(256, 512, 6, batch_first=True)(text, (h0, c0)))")

[1;31m$ nt_text.nested_size(1):[0m
(80, 71, 63, 65, 77, 85, 59, 62, 65, 80, 78, 85, 77, 78, 64, 80, 82, 72, 67, 90, 70, 67, 63, 70, 68, 79, 74, 86, 76, 69, 50, 70)

[1;31m$ nt_text.numel():[0m 
594432

[1;31m$ text.numel():[0m    
614400

[1;31m$ time(lambda: torch.nn.LSTM(256, 512, 6, batch_first=True)(nt_text, (h0, c0))):[0m
average 1812.3047ms based on 6 samples

[1;31m$ time(lambda: torch.nn.LSTM(256, 512, 6, batch_first=True)(text, (h0, c0))):[0m
average 372.5485ms based on 28 samples

