## Sequence batching: padding, packing, and restoring order

In [1]:
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

In [2]:
def argsort(seq):
    """Return the indices that would sort ``seq`` in ascending order.

    The sort is stable, so equal elements keep their original relative order.
    """
    order = list(range(len(seq)))
    order.sort(key=lambda i: seq[i])
    return order

In [3]:
class SeqBatch:
    """A batch of variable-length sequences that can be padded or packed.

    ``pack_padded_sequence`` (as called here) expects the batch sorted by
    decreasing length, so the sorting permutation ``ind`` and its inverse
    ``inv`` are kept to map results back to the caller's original order.

    Args:
        seqs: Iterable of sequences; each element is converted with
            ``torch.tensor``. May be a one-shot iterable such as a generator.
        dtype: Optional dtype forwarded to ``torch.tensor``.
        device: Optional device forwarded to ``torch.tensor``.
    """

    def __init__(self, seqs, dtype=None, device=None):
        self.seqs = [torch.tensor(s, dtype=dtype, device=device) for s in seqs]
        # Measure the converted tensors rather than ``seqs`` itself, so that a
        # generator (already consumed by the line above) still yields lengths.
        self.lens = [len(x) for x in self.seqs]
        # Batch positions ordered longest-first, as required for packing.
        self.ind = torch.tensor(argsort(self.lens)[::-1], dtype=torch.long)
        # Inverse permutation: position i in the original batch sits at
        # position inv[i] in the sorted batch. ``ind`` has unique values, so
        # torch.argsort yields the exact inverse.
        self.inv = torch.argsort(self.ind)

    def packed(self):
        """Return the batch as a ``PackedSequence`` in longest-first order."""
        padded = self.padded()
        # Batch is dim 1 of the padded tensor (pad_sequence default layout);
        # the index must live on the same device as the data.
        return pack_padded_sequence(padded.index_select(1, self.ind.to(padded.device)),
                                    sorted(self.lens, reverse=True))

    def padded(self):
        """Return the sequences zero-padded into one tensor (batch on dim 1)."""
        return pad_sequence(self.seqs)

    def reorder(self, batch, dim=0):
        """Map ``batch`` from sorted (packed) order back to the original order.

        Args:
            batch: Tensor whose ``dim`` axis is indexed in longest-first order.
            dim: Axis holding the batch entries.
        """
        return batch.index_select(dim, self.inv.to(batch.device))

In [4]:
# Sequence lengths are [2, 3, 1]; packing reorders the batch longest-first.
batch = SeqBatch([[1, 2], [3, 4, 5], [4]])
packed = batch.packed()
packed

PackedSequence(data=tensor([3, 1, 4, 4, 2, 5]), batch_sizes=tensor([3, 2, 1]))

In [5]:
# The first batch_sizes[0] == 3 entries of the packed data are the first time
# step in longest-first order; reorder maps them back to the original order.
batch.reorder(packed.data[:3])

tensor([1, 3, 4])