In [1]:
import torch
from torch import LongTensor
from torch.nn import Embedding, LSTM
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# https://github.com/HarshTrivedi/packing-unpacking-pytorch-minimal-tutorial

In [2]:
seqs = ["long_str",
       "tiny",
       "medium"]

In [3]:
vocab = ["<pad>"] + sorted(set([c for seq in seqs for c in seq]))
vocab

['<pad>', '_', 'd', 'e', 'g', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'y']

In [4]:
vectorized_seqs = [[vocab.index(tok) for tok in seq] for seq in seqs]
vectorized_seqs

[[6, 9, 8, 4, 1, 11, 12, 10], [12, 5, 8, 14], [7, 3, 2, 5, 13, 7]]

In [5]:
embed = Embedding(len(vocab), 4)
lstm = LSTM(input_size=4, hidden_size=5, batch_first=True)

In [6]:
seq_lengths = LongTensor(list(map(len, vectorized_seqs)))
seq_lengths

tensor([8, 4, 6])

In [8]:
seq_tensor = Variable(torch.zeros((len(vectorized_seqs), seq_lengths.max()))).long()
seq_tensor

tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]])

In [11]:
for idx, (seq, seq_len) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seq_len] = LongTensor(seq)
    
seq_tensor

tensor([[ 6,  9,  8,  4,  1, 11, 12, 10],
        [12,  5,  8, 14,  0,  0,  0,  0],
        [ 7,  3,  2,  5, 13,  7,  0,  0]])

In [13]:
embedded_seq_tensor = embed(seq_tensor)
embedded_seq_tensor # batch_size * max_seq_len * embedding_dim

tensor([[[ 1.3746,  1.7937,  0.4702, -0.1266],
         [ 0.0530,  0.5602, -1.2176, -1.3647],
         [ 1.1737, -0.8779, -1.5115,  2.4027],
         [-1.5345, -0.2037, -0.3288,  0.6264],
         [-0.0025,  0.9873, -0.0149,  0.9598],
         [-0.6353, -1.9604, -0.6830,  0.6624],
         [ 0.8536,  0.5150,  0.8952,  0.3396],
         [-1.2177, -1.5947,  0.2392, -1.0436]],

        [[ 0.8536,  0.5150,  0.8952,  0.3396],
         [-2.3059, -0.2358,  0.9411, -0.3228],
         [ 1.1737, -0.8779, -1.5115,  2.4027],
         [ 0.3702,  0.7784,  1.0046, -1.4837],
         [ 0.2173,  0.5291,  0.5392, -0.5921],
         [ 0.2173,  0.5291,  0.5392, -0.5921],
         [ 0.2173,  0.5291,  0.5392, -0.5921],
         [ 0.2173,  0.5291,  0.5392, -0.5921]],

        [[ 0.0230, -2.0035, -0.2308,  1.2062],
         [ 0.1585,  0.5577,  0.6678, -1.4421],
         [ 1.2074, -1.1181,  0.9672, -0.6232],
         [-2.3059, -0.2358,  0.9411, -0.3228],
         [ 1.4578, -0.1260, -0.9923,  0.0044],
         

In [16]:
# pack_padded_sequence with embedded instances and sequence lengths
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu().numpy(), 
                                    batch_first=True, enforce_sorted=False)
packed_input.data # sum_batch_seq_len * embedding_dim

tensor([[ 1.3746,  1.7937,  0.4702, -0.1266],
        [ 0.0230, -2.0035, -0.2308,  1.2062],
        [ 0.8536,  0.5150,  0.8952,  0.3396],
        [ 0.0530,  0.5602, -1.2176, -1.3647],
        [ 0.1585,  0.5577,  0.6678, -1.4421],
        [-2.3059, -0.2358,  0.9411, -0.3228],
        [ 1.1737, -0.8779, -1.5115,  2.4027],
        [ 1.2074, -1.1181,  0.9672, -0.6232],
        [ 1.1737, -0.8779, -1.5115,  2.4027],
        [-1.5345, -0.2037, -0.3288,  0.6264],
        [-2.3059, -0.2358,  0.9411, -0.3228],
        [ 0.3702,  0.7784,  1.0046, -1.4837],
        [-0.0025,  0.9873, -0.0149,  0.9598],
        [ 1.4578, -0.1260, -0.9923,  0.0044],
        [-0.6353, -1.9604, -0.6830,  0.6624],
        [ 0.0230, -2.0035, -0.2308,  1.2062],
        [ 0.8536,  0.5150,  0.8952,  0.3396],
        [-1.2177, -1.5947,  0.2392, -1.0436]],
       grad_fn=<PackPaddedSequenceBackward>)

In [18]:
packed_output, (ht, ct) = lstm(packed_input)
packed_output.data # sum_batch_seq_len * embedding_dim

tensor([[ 0.0803, -0.0008,  0.1106, -0.0464,  0.0202],
        [ 0.1156, -0.0932,  0.1572, -0.0889,  0.2454],
        [ 0.1609, -0.0393,  0.1558, -0.0386,  0.0827],
        [ 0.0642,  0.1866,  0.0519, -0.0030,  0.0769],
        [ 0.2906,  0.0038,  0.0832,  0.0241,  0.1964],
        [ 0.2751,  0.0974,  0.2204, -0.0334,  0.0410],
        [-0.0676, -0.1134,  0.2056, -0.2809,  0.1253],
        [ 0.3752, -0.0331,  0.1202,  0.0687,  0.3952],
        [ 0.0017, -0.1094,  0.2199, -0.2948,  0.1525],
        [-0.1339, -0.0508,  0.2306, -0.3195,  0.0483],
        [ 0.4770,  0.0974,  0.1844,  0.0531,  0.1096],
        [ 0.2375, -0.0085,  0.1454, -0.0260,  0.1530],
        [-0.1744, -0.0583,  0.2565, -0.2723,  0.0005],
        [ 0.2957, -0.0123,  0.1131, -0.0499,  0.2820],
        [-0.0698, -0.0815,  0.2116, -0.1960,  0.2247],
        [ 0.2971, -0.0797,  0.2064, -0.1237,  0.3482],
        [ 0.1297, -0.0813,  0.2466, -0.0871,  0.2772],
        [ 0.3546,  0.0675,  0.1533,  0.0024,  0.2131]], grad_fn=<

In [19]:
packed_output.batch_sizes

tensor([3, 3, 3, 3, 2, 2, 1, 1])

In [20]:
# unpack
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
output

tensor([[[ 0.0803, -0.0008,  0.1106, -0.0464,  0.0202],
         [ 0.0642,  0.1866,  0.0519, -0.0030,  0.0769],
         [-0.0676, -0.1134,  0.2056, -0.2809,  0.1253],
         [-0.1339, -0.0508,  0.2306, -0.3195,  0.0483],
         [-0.1744, -0.0583,  0.2565, -0.2723,  0.0005],
         [-0.0698, -0.0815,  0.2116, -0.1960,  0.2247],
         [ 0.1297, -0.0813,  0.2466, -0.0871,  0.2772],
         [ 0.3546,  0.0675,  0.1533,  0.0024,  0.2131]],

        [[ 0.1609, -0.0393,  0.1558, -0.0386,  0.0827],
         [ 0.2751,  0.0974,  0.2204, -0.0334,  0.0410],
         [ 0.0017, -0.1094,  0.2199, -0.2948,  0.1525],
         [ 0.2375, -0.0085,  0.1454, -0.0260,  0.1530],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.1156, -0.0932,  0.1572, -0.0889,  0.2454],
         [ 0.2906,  0.0038,  0.0832,  0.0241

In [21]:
ht

tensor([[[ 0.3546,  0.0675,  0.1533,  0.0024,  0.2131],
         [ 0.2375, -0.0085,  0.1454, -0.0260,  0.1530],
         [ 0.2971, -0.0797,  0.2064, -0.1237,  0.3482]]],
       grad_fn=<IndexSelectBackward>)

In [None]:
# (batch_size X max_seq_len X embedding_dim) --> Sort by seqlen ---> (batch_size X max_seq_len X embedding_dim)
# (batch_size X max_seq_len X embedding_dim) --->      Pack     ---> (batch_sum_seq_len X embedding_dim)
# (batch_sum_seq_len X embedding_dim)        --->      LSTM     ---> (batch_sum_seq_len X hidden_dim)
# (batch_sum_seq_len X hidden_dim)           --->    UnPack     ---> (batch_size X max_seq_len X hidden_dim)