In [1]:
import torch
from torch import LongTensor
from torch.nn import Embedding, LSTM
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [2]:
seqs = ['long_str',  # len = 8
        'tiny',      # len = 4
        'medium']    # len = 6

In [3]:
vocab = ['<pad>'] + sorted(set([char for seq in seqs for char in seq]))

In [4]:
vectorized_seqs = [[vocab.index(tok) for tok in seq]for seq in seqs]
vectorized_seqs

[[6, 9, 8, 4, 1, 11, 12, 10], [12, 5, 8, 14], [7, 3, 2, 5, 13, 7]]

In [5]:
embed = Embedding(len(vocab), 4) # embedding_dim = 4
lstm = LSTM(input_size=4, hidden_size=5, batch_first=True) # input_dim = 4, hidden_dim = 5

In [6]:
seq_lengths = LongTensor(list(map(len, vectorized_seqs)))

In [7]:
seq_tensor = Variable(torch.zeros((len(vectorized_seqs), seq_lengths.max()))).long()

In [8]:
for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seqlen] = LongTensor(seq)

In [9]:
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]

In [10]:
embedded_seq_tensor = embed(seq_tensor)
embedded_seq_tensor

tensor([[[-0.0191,  0.8068,  0.9808,  0.2794],
         [ 1.2822,  0.1383, -0.6989,  0.4806],
         [-0.8367, -0.1104,  1.0734,  0.3339],
         [ 0.5013,  0.3875, -1.3025, -0.7538],
         [-0.3312, -0.9663, -1.8228, -1.0092],
         [-1.3694, -1.4217, -0.0769, -1.1875],
         [-0.5485, -0.9826, -0.5783, -0.0862],
         [-0.1736, -0.1003,  0.9977,  1.3328]],

        [[ 1.0520,  0.7034, -1.0505, -0.1413],
         [-0.5545,  0.4678,  1.8334, -1.0459],
         [ 0.5947,  0.1594,  2.2202, -1.7805],
         [-0.3083,  0.1105, -0.7959,  0.1952],
         [ 0.4516, -1.0611,  0.4607,  1.1831],
         [ 1.0520,  0.7034, -1.0505, -0.1413],
         [-0.7071,  0.3720,  2.1704, -0.2844],
         [-0.7071,  0.3720,  2.1704, -0.2844]],

        [[-0.5485, -0.9826, -0.5783, -0.0862],
         [-0.3083,  0.1105, -0.7959,  0.1952],
         [-0.8367, -0.1104,  1.0734,  0.3339],
         [-0.0707,  0.2808,  0.3322,  0.3542],
         [-0.7071,  0.3720,  2.1704, -0.2844],
         

In [20]:
packed_input = pack_padded_sequence(embedded_seq_tensor, seq_lengths.cpu().numpy(), batch_first=True)
packed_input

PackedSequence(data=tensor([[ 0.5991, -0.9387, -1.0265,  0.0279],
        [ 0.3411,  0.0770, -0.0044, -0.5827],
        [ 1.0636,  1.1258, -0.6755,  0.8060],
        [-1.7511, -0.6925,  0.6365,  1.2075],
        [-0.7073,  0.1721, -0.4367,  0.4844],
        [-0.7713, -0.8426,  0.8001, -0.4661],
        [ 0.6446,  1.7838, -0.7365,  0.5896],
        [-1.9668, -1.2308,  0.0957, -0.3078],
        [ 0.6446,  1.7838, -0.7365,  0.5896],
        [ 0.3751,  2.5702, -1.8247,  0.6942],
        [-0.7713, -0.8426,  0.8001, -0.4661],
        [-0.9686,  0.2134,  0.5014,  0.5733],
        [-0.2036, -0.0290, -0.6138, -1.3588],
        [-0.6084, -0.2529, -1.1476, -0.2269],
        [-0.4690, -2.2208,  1.4078,  2.0441],
        [ 0.3411,  0.0770, -0.0044, -0.5827],
        [ 1.0636,  1.1258, -0.6755,  0.8060],
        [-0.5070, -1.5001,  1.1586,  1.4968]],
       grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([3, 3, 3, 3, 2, 2, 1, 1]), sorted_indices=None, unsorted_indices=None)

In [24]:
packed_input.batch_sizes

tensor([3, 3, 3, 3, 2, 2, 1, 1])

In [25]:
packed_output, (ht, ct) = lstm(packed_input)

In [26]:
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)

In [27]:
output.shape

torch.Size([3, 8, 5])

In [28]:
input_sizes

tensor([8, 6, 4])

In [29]:
output

tensor([[[ 0.0526, -0.0341,  0.1088, -0.1267,  0.1887],
         [-0.0093, -0.2676, -0.1365, -0.0057,  0.0400],
         [-0.0768, -0.0222,  0.0690,  0.0608,  0.2656],
         [-0.0789,  0.1583,  0.1094,  0.0640,  0.4118],
         [-0.0426, -0.0508,  0.1951,  0.0337,  0.2560],
         [-0.1172, -0.2423, -0.0604, -0.0702,  0.0647],
         [-0.1896,  0.0283,  0.0827,  0.0109,  0.2962],
         [-0.1923, -0.1930, -0.1195, -0.0132,  0.0988]],

        [[-0.0535, -0.0966,  0.0429,  0.0427,  0.1012],
         [-0.0228, -0.1252,  0.0128,  0.0282,  0.1335],
         [ 0.0399, -0.4182, -0.1953, -0.1160, -0.0116],
         [ 0.0296, -0.4048, -0.2311, -0.0691, -0.0738],
         [ 0.1416, -0.3529,  0.0964, -0.1331,  0.0718],
         [ 0.0637, -0.2601,  0.1475, -0.0587,  0.1242],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.0875,  0.1523,  0.0884,  0.0499,  0.2663],
         [-0.0882, -0.1859, -0.1651,  0.0744