# Efficient TS batching via PackedSequence


- <https://discuss.pytorch.org/t/customized-rnn-cell-which-can-accept-packsequence/1067>
- https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.PackedSequence.html


In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import time

import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import (
    PackedSequence,
    pack_sequence,
    pad_packed_sequence,
    pad_sequence,
)

# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Floating-point type used for the loss accumulators and the toy tensors further down.
dtype = torch.float32

#### Classes:

- [PackedSequence](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.PackedSequence.html#torch.nn.utils.rnn.PackedSequence)

#### Functions:

- [pack_padded_sequence](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html#torch.nn.utils.rnn.pack_padded_sequence): 
    - inputs: `tuple[inputs: Tensor, lengths: Tensor]`
    - output: `PackedSequence[data: Tensor, batch_sizes: Tensor]`
    - signature (with `batch_first=True`; the default layout is time-major, `[max(LEN), BS, *DIMS]`): `[BS, max(LEN), *DIMS], [BS] -> [sum(LEN), *DIMS], [max(LEN)]`

- [pad_packed_sequence](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_packed_sequence.html#torch.nn.utils.rnn.pad_packed_sequence): 
    - inputs: `PackedSequence[data: Tensor, batch_sizes: Tensor]`
    - output: `tuple[inputs: Tensor, lengths: Tensor]`
    - signature (with `batch_first=True`; the default layout is time-major, `[max(LEN), BS, *DIMS]`): `[sum(LEN), *DIMS], [max(LEN)] -> [BS, max(LEN), *DIMS], [BS]`

- [pad_sequence](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_sequence.html#torch.nn.utils.rnn.pad_sequence): 
    - inputs: `list[Tensor]`
    - output: `Tensor`
    - signature (with `batch_first=True`; the default layout is time-major, `[max(LEN), BS, *DIMS]`): `BS×[LEN[k], *DIMS] -> [BS, max(LEN), *DIMS]`

- [pack_sequence](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_sequence.html#torch.nn.utils.rnn.pack_sequence): 
    - inputs: `list[Tensor]`
    - output: `PackedSequence[data: Tensor, batch_sizes: Tensor]`
    - signature: `BS×[LEN[k], *DIMS] -> [sum(LEN), *DIMS], [max(LEN)]`

#### TODO:

- unpad_sequence: tuple[Tensor, Tensor] -> list[Tensor]

- unpack_sequence: PackedSequence -> list[Tensor]

#### Questions: 

- How to apply loss functions directly on packed / padded Tensors?

## Notes

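PackedSequence stores data in a peculiar way: `data` is time-major and interleaved. It holds the first element of every sequence (longest sequence first), then every second element, and so on, while `batch_sizes[t]` records how many sequences are still running at time step `t`. For example, packing `[a0, a1, a2]` and `[b0, b1]` gives `data = [a0, b0, a1, b1, a2]` and `batch_sizes = [2, 2, 1]`.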

In [None]:
# model creation
# Hyper-parameters shared by the data-generation and benchmark cells.
batch_size = 4  # number of sequences per batch
input_size = 3  # feature dimension of every input time step
hidden_size = 5  # RNN hidden size; the generated targets use the same width
seq_len_range = (2, 9)  # unpacked into np.random.randint, so the upper bound is exclusive
num_batches = 10  # number of batches to generate
low, high = 0, 9  # bounds handed to torch.randint for the integer-valued demo data

# 4-layer nn.RNN that takes batch-first input.
rnn = nn.RNN(input_size, hidden_size, num_layers=4, bias=True, batch_first=True)
rnn.to(device)
rnn.zero_grad()  # start from cleared gradients

In [None]:
# data generation
# Build `num_batches` batches. Each batch is a list of `batch_size` (x, y) pairs
# whose two tensors share one randomly drawn sequence length.
batches = list()
for idx in range(num_batches):
    batch = []
    for k in range(batch_size):
        rand_len = np.random.randint(*seq_len_range)  # upper bound is exclusive
        # torch.randint draws integer values from [low, high).
        x = torch.randint(low, high, (rand_len, input_size), device=device)
        y = torch.randint(low, high, (rand_len, hidden_size), device=device)
        batch += [(x, y)]
    # Longest sequence first: the order pack_sequence expects by default.
    batch = sorted(batch, key=lambda x: x[0].size(0), reverse=True)
    batches += [batch]

In [None]:
# Look at the first batch: the shapes of the two tensors of every (x, y) pair.
batch = batches[0]
[[tensor.shape for tensor in x] for x in batch]

In [None]:
# Keep only the first tensor (the input) of every pair in the batch.
x = [x[0] for x in batch]

In [None]:
# Pack the (already length-sorted) inputs into a single PackedSequence.
# packed.data is the concatenation of all time steps: [sum(LEN), input_size].
packed = pack_sequence(x)

In [None]:
# pad_packed_sequence returns a (padded_tensor, lengths) tuple, so `padded` is that tuple.
# With the default batch_first=False the padded tensor is time-major: [max(LEN), BS, *DIMS].
padded = pad_packed_sequence(packed)

In [None]:
# Shapes of the input tensor of every sequence, grouped by batch.
# A batch is a plain list of (x, y) pairs and has no `.shape` of its own.
[[x.shape for x, _ in batch] for batch in batches]

In [None]:
def pack(sequence: list[torch.Tensor], **kwargs) -> tuple[PackedSequence, list[int]]:
    """Pack the non-empty tensors of ``sequence`` and report every original length.

    Zero-length tensors are left out of the packed result, but their length (0)
    is still recorded so the packed data can be expanded again afterwards.

    Args:
        sequence: Tensors to pack; ``len(tensor)`` is used as the sequence length.
        **kwargs: Forwarded unchanged to ``pack_sequence``. With its defaults the
            tensors must already be ordered from longest to shortest.

    Returns:
        ``(packed_sequence, lengths)`` where ``lengths[i] == len(sequence[i])``,
        including the zero-length entries that were not packed.
    """
    lengths = [len(tensor) for tensor in sequence]
    non_empty = [tensor for tensor in sequence if len(tensor) > 0]
    return pack_sequence(non_empty, **kwargs), lengths


def unpack(packed_sequence: PackedSequence, lengths: list[int]) -> list[torch.Tensor]:
    """Split a PackedSequence back into one tensor per original sequence.

    Inverse of ``pack``: entries of ``lengths`` that are 0 stand for sequences
    that were skipped when packing and come back as empty tensors. Sequences
    packed with ``enforce_sorted=False`` are returned in their original order,
    because the permutation stored in ``packed_sequence.sorted_indices`` is
    honoured.

    Args:
        packed_sequence: The packed data, e.g. the output of an RNN.
        lengths: Length of every original sequence, in original order
            (zeros included), as returned by ``pack``.

    Returns:
        A list with one tensor of shape ``[lengths[i], *trailing_dims]`` per
        entry of ``lengths``. The tensors are gathered from
        ``packed_sequence.data`` by indexing, so gradients flow back to it.

    Raises:
        ValueError: If ``lengths`` does not match the packed data (wrong number
            of non-empty sequences or wrong total number of time steps).
    """
    data = packed_sequence.data
    batch_sizes = packed_sequence.batch_sizes
    trailing_dims = data.shape[1:]

    # Original positions of the sequences that actually went into the packing.
    packed_positions = [b_idx for b_idx, length in enumerate(lengths) if length > 0]
    num_packed = int(batch_sizes[0]) if len(batch_sizes) > 0 else 0
    if len(packed_positions) != num_packed or sum(lengths) != data.shape[0]:
        raise ValueError(
            f"lengths {list(lengths)} do not match the packed data "
            f"({num_packed} sequences, {data.shape[0]} time steps in total)"
        )

    # Row `row` of every time step belongs to the packed input number order[row].
    # Without sorted_indices the inputs were already in packed (sorted) order.
    sorted_indices = packed_sequence.sorted_indices
    if sorted_indices is None:
        order = list(range(num_packed))
    else:
        order = sorted_indices.tolist()

    # offsets[t] is the index in `data` of the first element of time step t.
    offsets = (torch.cumsum(batch_sizes, dim=0) - batch_sizes).to(data.device)

    # Skipped (zero-length) sequences keep this empty placeholder.
    unpacked_sequence = [data.new_zeros((0, *trailing_dims)) for _ in lengths]
    for row, packed_idx in enumerate(order):
        b_idx = packed_positions[packed_idx]
        # One gather per sequence instead of one assignment per element.
        unpacked_sequence[b_idx] = data[offsets[: lengths[b_idx]] + row]
    return unpacked_sequence

In [None]:
# data generation
# Rebuild `batches`, this time with float tensors from torch.rand.
# Each batch is again a list of `batch_size` (x, y) pairs sharing one random length.
batches = list()
for idx in range(num_batches):
    batch = []
    for k in range(batch_size):
        rand_len = np.random.randint(*seq_len_range)  # upper bound is exclusive
        x = torch.rand((rand_len, input_size), device=device)
        y = torch.rand((rand_len, hidden_size), device=device)
        batch += [(x, y)]
    # NOTE: sorting is disabled here, so these batches are NOT ordered by length.
    # batch = sorted(batch, key=lambda x: x[0].size(0), reverse=True)
    batches += [batch]

## Python loops = too slow

In [None]:
# Baseline: run every sequence through the RNN on its own (one forward pass per sample).
start = time.perf_counter()
for batch in batches:
    loss = torch.tensor(0, dtype=dtype, device=device)  # summed squared error of the batch
    for x, y in batch:
        # Add a batch dimension of size 1 for the batch-first RNN, then drop it again.
        yhat = rnn(x.unsqueeze(0))[0].squeeze(dim=0)
        r = (y - yhat) ** 2
        loss += torch.sum(r)
    loss.backward()
    # Flatten all parameter gradients into one vector so they can be checked for NaNs.
    g = torch.cat([w.grad.flatten() for w in rnn.parameters()])
    rnn.zero_grad()
if device.type == "cuda":
    # CUDA kernels run asynchronously; wait for them before stopping the clock.
    torch.cuda.synchronize()
end = time.perf_counter()
print(f"elapsed time for python loop: {end - start} secs")
print(torch.sum(torch.isnan(g)))
print(r)

## Padded is much faster!

In [None]:
# for padded input
start = time.perf_counter()
for batch in batches:
    x, y = zip(*batch)
    # True sequence lengths, needed to tell real time steps from padding.
    seq_lens = torch.tensor([len(seq) for seq in x], device=device)
    # Pad with zeros, not NaN: a NaN in the input turns the gradients into NaN
    # (0 * NaN = NaN in the backward pass), even if the loss is masked afterwards.
    x = pad_sequence(x, batch_first=True)
    y = pad_sequence(y, batch_first=True)
    yhat = rnn(x)[0]
    # mask[b, t] is True for the real time steps of sequence b; shape [BS, max(LEN)].
    mask = torch.arange(x.shape[1], device=device) < seq_lens.unsqueeze(1)
    # Zero out the squared error on padded steps so they do not contribute to the loss.
    r = (y - yhat) ** 2 * mask.unsqueeze(-1)
    loss = torch.sum(r)
    loss.backward()
    # Flatten all parameter gradients into one vector so they can be checked for NaNs.
    g = torch.cat([w.grad.flatten() for w in rnn.parameters()])
    rnn.zero_grad()
if device.type == "cuda":
    # CUDA kernels run asynchronously; wait for them before stopping the clock.
    torch.cuda.synchronize()
end = time.perf_counter()
print(f"elapsed time for padded input: {end - start} secs")
print(torch.sum(torch.isnan(g)))
print(r.flatten())

## Packed is also fast!

In [None]:
# for packed input
start = time.perf_counter()
for batch in batches:
    # pack_sequence expects longest-first order by default and these batches are
    # unsorted. Sorting the (x, y) pairs together keeps inputs and targets aligned.
    x, y = zip(*sorted(batch, key=lambda pair: pair[0].size(0), reverse=True))
    x = pack_sequence(x)
    y = pack_sequence(y)
    yhat = rnn(x)[0]
    # x and y have identical lengths, so the packed data tensors line up row by row.
    r = (y.data - yhat.data) ** 2
    loss = torch.sum(r)
    loss.backward()
    # Flatten all parameter gradients into one vector so they can be checked for NaNs.
    g = torch.cat([w.grad.flatten() for w in rnn.parameters()])
    rnn.zero_grad()
if device.type == "cuda":
    # CUDA kernels run asynchronously; wait for them before stopping the clock.
    torch.cuda.synchronize()
end = time.perf_counter()
print(f"elapsed time for packed input: {end - start} secs")
print(torch.sum(torch.isnan(g)))
print(r)

In [None]:
# for packed input with unpack
start = time.perf_counter()
for batch in batches:
    # pack() forwards to pack_sequence, which expects longest-first order by default;
    # these batches are unsorted, so sort the (x, y) pairs together first.
    batch = sorted(batch, key=lambda pair: pair[0].size(0), reverse=True)
    x_batch, y_batch = zip(*batch)
    x_packed, lengths = pack(x_batch)
    yhat_packed = rnn(x_packed)[0]

    r = torch.tensor(0, dtype=dtype, device=device)
    # Compare every target with the matching *prediction*: unpack the RNN output,
    # not the targets, otherwise the loss is identically zero and has no gradient.
    for y, yhat in zip(y_batch, unpack(yhat_packed, lengths)):
        r = r + torch.mean((y - yhat) ** 2)
    r.backward()
    # Flatten all parameter gradients into one vector so they can be checked for NaNs.
    g = torch.cat([w.grad.flatten() for w in rnn.parameters()])
    print(torch.sum(torch.isnan(g)))
    rnn.zero_grad()
if device.type == "cuda":
    # CUDA kernels run asynchronously; wait for them before stopping the clock.
    torch.cuda.synchronize()
end = time.perf_counter()
print(f"elapsed time for packed input with unpack: {end - start} secs")

In [None]:
# Switch to a small CPU model for the hand-checkable example that follows.
dtype = torch.float32
device = torch.device("cpu")
# 4-layer batch-first nn.RNN with input size 2 and hidden size 2.
rnn = nn.RNN(2, 2, num_layers=4, bias=True, batch_first=True)
rnn.to(device)

In [None]:
# Three toy sequences of lengths 5, 4 and 3 with 2 features each, holding random
# integer values (drawn with np.random.randint) converted to `dtype`.
a = torch.tensor(np.random.randint(0, 9, (5, 2)), dtype=dtype, device=device)
b = torch.tensor(np.random.randint(0, 9, (4, 2)), dtype=dtype, device=device)
c = torch.tensor(np.random.randint(0, 9, (3, 2)), dtype=dtype, device=device)

In [None]:
# a, b, c have lengths 5, 4, 3, i.e. longest first, as pack_sequence expects by default.
batch = [a, b, c]
# pack() already returns the per-sequence lengths, so they are not computed separately.
x, lengths = pack(batch)
rnn(x)

In [None]:
# Sanity check: compare the packed forward pass with running each sequence on its own.
y = rnn(x)[0]  # first element of the RNN result: the output as a PackedSequence
y = unpack(y, lengths)  # back to one tensor per sequence
# Reference: each sequence alone, with a leading batch dimension of size 1.
yhat = [rnn(z.unsqueeze(dim=0))[0] for z in batch]
# Element-wise differences (broadcast over the size-1 batch dimension); presumably expected to be ~0.
[z - zhat for z, zhat in zip(y, yhat)]

In [None]:
# Pad the toy batch with NaN and run it through the RNN; presumably this is meant to
# show how the NaN padding shows up in the output. `batch` is rebound to the padded tensor.
batch = pad_sequence(batch, padding_value=np.nan, batch_first=True)
rnn(batch)