In [7]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
import numpy as np
import itertools

In [8]:
def flatten(l):
    return list(itertools.chain.from_iterable(l))

seqs = ['ghatmasala', 'nicela', 'chutpakodas']
vocab = ['<pad>'] + sorted(list(set(flatten(seqs))))

In [9]:
embedding_size = 3
embed = nn.Embedding(len(vocab), embedding_size)
lstm = nn.LSTM(embedding_size, 5)

vectorized_seqs = [[vocab.index(tok) for tok in seq]for seq in seqs]
print("vectorized_seqs", vectorized_seqs)

vectorized_seqs [[5, 6, 1, 15, 10, 1, 14, 1, 9, 1], [11, 7, 2, 4, 9, 1], [2, 6, 16, 15, 13, 1, 8, 12, 3, 1, 14]]


In [10]:
print([x for x in map(len, vectorized_seqs)])

[10, 6, 11]


In [13]:
seq_lengths = torch.LongTensor([x for x in map(len, vectorized_seqs)])

seq_tensor = Variable(torch.zeros((len(vectorized_seqs),seq_lengths.max()))).long()
for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
    seq_tensor[idx, :seqlen] = torch.LongTensor(seq)
    
print("seq_tensor", seq_tensor)

seq_tensor tensor([[ 5,  6,  1, 15, 10,  1, 14,  1,  9,  1,  0],
        [11,  7,  2,  4,  9,  1,  0,  0,  0,  0,  0],
        [ 2,  6, 16, 15, 13,  1,  8, 12,  3,  1, 14]])


In [14]:
# SORT YOUR TENSORS BY LENGTH!
seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
seq_tensor = seq_tensor[perm_idx]
print("seq_tensor after sorting", seq_tensor)

seq_tensor after sorting tensor([[ 2,  6, 16, 15, 13,  1,  8, 12,  3,  1, 14],
        [ 5,  6,  1, 15, 10,  1, 14,  1,  9,  1,  0],
        [11,  7,  2,  4,  9,  1,  0,  0,  0,  0,  0]])


In [15]:
seq_tensor = seq_tensor.transpose(0, 1)  # (B,L,D) -> (L,B,D)
print("seq_tensor after transposing", seq_tensor.size(), seq_tensor.data)

seq_tensor after transposing torch.Size([11, 3]) tensor([[ 2,  5, 11],
        [ 6,  6,  7],
        [16,  1,  2],
        [15, 15,  4],
        [13, 10,  9],
        [ 1,  1,  1],
        [ 8, 14,  0],
        [12,  1,  0],
        [ 3,  9,  0],
        [ 1,  1,  0],
        [14,  0,  0]])


In [16]:
embeded_seq_tensor = embed(seq_tensor)
print("seq_tensor after embeding", embeded_seq_tensor.size(), seq_tensor.data)

seq_tensor after embeding torch.Size([11, 3, 3]) tensor([[ 2,  5, 11],
        [ 6,  6,  7],
        [16,  1,  2],
        [15, 15,  4],
        [13, 10,  9],
        [ 1,  1,  1],
        [ 8, 14,  0],
        [12,  1,  0],
        [ 3,  9,  0],
        [ 1,  1,  0],
        [14,  0,  0]])


In [17]:
# pack them up nicely
packed_input = pack_padded_sequence(
    embeded_seq_tensor, seq_lengths.cpu().numpy())
packed_output, (ht, ct) = lstm(packed_input)

# unpack your output
output, _ = pad_packed_sequence(packed_output)
print("Lstm output", output.size(), output.data)
# the final hidden state
print("Last output", ht[-1].size(), ht[-1].data)

Lstm output torch.Size([11, 3, 5]) tensor([[[-0.0695,  0.0401,  0.0505, -0.0021,  0.0710],
         [-0.0071, -0.0545, -0.0546, -0.0356,  0.1264],
         [-0.0253,  0.0235,  0.1605,  0.0986,  0.0624]],

        [[ 0.0298,  0.0532,  0.1270,  0.0362,  0.2236],
         [ 0.0401, -0.0080,  0.0937,  0.0267,  0.2443],
         [-0.1020, -0.1217, -0.1543, -0.0785,  0.0429]],

        [[-0.0251,  0.0275,  0.2128,  0.1310,  0.0806],
         [-0.0846, -0.0221, -0.0019, -0.0152,  0.0401],
         [-0.0877, -0.0169,  0.0343, -0.0340,  0.0819]],

        [[ 0.0377, -0.1428, -0.0945, -0.0795,  0.2093],
         [ 0.0072, -0.2050, -0.1701, -0.1549,  0.2076],
         [-0.1389, -0.1500, -0.2347, -0.2065, -0.0156]],

        [[-0.0033, -0.0921,  0.0249, -0.0430,  0.2613],
         [ 0.0516, -0.0291,  0.1964,  0.1199,  0.1719],
         [-0.1173, -0.1029, -0.0423, -0.1234,  0.0436]],

        [[-0.0978, -0.0696, -0.0083, -0.0679,  0.0387],
         [-0.0693, -0.0515, -0.0232,  0.0143,  0.0131],
   