In [15]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [59]:
class LSTM(nn.Module):
    """Single-layer LSTM followed by a per-timestep MLP head with sigmoid outputs.

    The head (ReLU -> fc1 -> ReLU -> fc2 -> sigmoid) is applied independently at
    every timestep, so the model emits one ``num_outputs``-vector per input step.

    Args:
        num_inputs: Number of features per timestep (``input_size`` of the LSTM).
        num_outputs: Number of sigmoid outputs produced per timestep.
        hidden_size: LSTM hidden-state width. Defaults to 16 (the previously
            hard-coded value).
        fc_size: Width of the intermediate fully connected layer. Defaults to 128
            (the previously hard-coded value).
    """

    def __init__(self, num_inputs, num_outputs, hidden_size=16, fc_size=128):
        super().__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        # Modules are created in the same order as before (lstm, fc1, fc2) so a
        # given random seed still produces the same initial weights.
        self.lstm = nn.LSTM(input_size=num_inputs, hidden_size=hidden_size, batch_first=True)
        self.fc1 = nn.Linear(hidden_size, fc_size)
        self.fc2 = nn.Linear(fc_size, num_outputs)

    def _head(self, features):
        """Map LSTM features to outputs along the last dimension (any leading shape)."""
        # NOTE(review): LSTM outputs lie in (-1, 1), so this ReLU discards every
        # negative activation. Kept as-is to preserve the model's behaviour.
        out = torch.relu(features)
        out = torch.relu(self.fc1(out))
        return torch.sigmoid(self.fc2(out))

    def forward(self, x, hidden=None, return_hidden=False):
        """Run the LSTM over ``x`` and apply the MLP head at every timestep.

        Args:
            x: Either a tensor of shape ``(batch, seq_len, num_inputs)`` (the LSTM
                is built with ``batch_first=True``) or a ``PackedSequence``. With a
                zero-padded tensor the LSTM also steps through the padding; pack
                the batch to skip it.
            hidden: Optional initial ``(h_0, c_0)`` tuple passed to ``nn.LSTM``.
            return_hidden: If True, also return the final ``(h_n, c_n)`` so the
                state can be fed back in on the next call.

        Returns:
            Outputs in [0, 1]: a tensor of shape ``(batch, seq_len, num_outputs)``
            for tensor input, or a ``PackedSequence`` with the same batch layout
            for packed input. If ``return_hidden`` is True, a tuple
            ``(outputs, (h_n, c_n))``.
        """
        out, new_hidden = self.lstm(x, hidden)
        if isinstance(out, nn.utils.rnn.PackedSequence):
            # The head acts on the last dimension only, so it can run directly on
            # the flat packed data; re-wrap it with the original batch metadata so
            # callers can unpack it with pad_packed_sequence as usual.
            out = nn.utils.rnn.PackedSequence(
                self._head(out.data),
                out.batch_sizes,
                out.sorted_indices,
                out.unsorted_indices,
            )
        else:
            out = self._head(out)
        return (out, new_hidden) if return_hidden else out

In [60]:
# 5x5 float identity matrix: row i is the one-hot encoding of index i.
# Not referenced by any of the cells shown here.
onehot_source = np.identity(5)

In [30]:
# Fix the torch RNG before the first random weight initialisation so that this
# model and the layers created in later cells are reproducible on
# Restart & Run All.
torch.manual_seed(0)
lstm = LSTM(num_inputs=2, num_outputs=1)


In [61]:
# Three variable-length toy sequences of 2-feature steps (lengths 4, 1, 2).
a = torch.arange(1, 9).reshape(4, 2)    # [[1, 2], [3, 4], [5, 6], [7, 8]]
b = torch.tensor([[7, 8]])
c = torch.arange(13, 17).reshape(2, 2)  # [[13, 14], [15, 16]]

In [62]:
# Zero-pad every sequence to the longest one -> shape (batch=3, max_len=4, features=2).
data = nn.utils.rnn.pad_sequence(sequences=[a, b, c], batch_first=True, padding_value=0)
data

tensor([[[ 1,  2],
         [ 3,  4],
         [ 5,  6],
         [ 7,  8]],

        [[ 7,  8],
         [ 0,  0],
         [ 0,  0],
         [ 0,  0]],

        [[13, 14],
         [15, 16],
         [ 0,  0],
         [ 0,  0]]])

In [63]:
# Pack the padded batch so the LSTM skips the padding. The lengths are the true
# lengths of a, b and c. enforce_sorted=False lets PyTorch sort by length
# itself; the permutation is recorded in `sorted_indices` / `unsorted_indices`.
# `data` is rebound to the packed result, so the guard keeps this cell
# idempotent: an unguarded re-run would try to pack a PackedSequence and fail.
if not isinstance(data, nn.utils.rnn.PackedSequence):
    data = nn.utils.rnn.pack_padded_sequence(data, lengths=[4, 1, 2], batch_first=True, enforce_sorted=False)
data

PackedSequence(data=tensor([[ 1,  2],
        [13, 14],
        [ 7,  8],
        [ 3,  4],
        [15, 16],
        [ 5,  6],
        [ 7,  8]]), batch_sizes=tensor([3, 2, 1, 1]), sorted_indices=tensor([0, 2, 1]), unsorted_indices=tensor([0, 2, 1]))

In [64]:
# Toy layers for the packing experiment: an LSTM over 2-feature steps (matching
# a, b, c) with 3 hidden units, and a projection from those 3 units to 4 outputs.
lstm_layer = nn.LSTM(2, 3, batch_first=True)
linear_layer = nn.Linear(in_features=3, out_features=4)

In [65]:
# Run the LSTM on the packed batch. `.float()` is needed because the toy data is
# integer while the LSTM weights are float32. `o` is a PackedSequence. For
# packed input, h_n / c_n hold each sequence's state at its true last step.
# They are named h_n / c_n (not h / c) so the cell state does not overwrite
# the input sequence `c`.
o, (h_n, c_n) = lstm_layer(data.float())

In [66]:
# Unpack the LSTM output back to a zero-padded (batch, max_len, hidden) tensor.
# The per-sequence lengths it also returns are discarded. `o` is rebound in
# place, so the guard keeps the cell re-runnable: without it, a second run
# would try to unpack a plain Tensor and fail.
if isinstance(o, nn.utils.rnn.PackedSequence):
    o, _ = nn.utils.rnn.pad_packed_sequence(sequence=o, batch_first=True)
o, o.shape

(tensor([[[ 7.3189e-02,  1.0773e-01, -6.9654e-02],
          [ 4.2135e-02,  7.9337e-02, -2.8355e-02],
          [ 1.4783e-02,  3.5969e-02, -9.6717e-03],
          [ 4.0341e-03,  1.2493e-02, -3.2697e-03]],
 
         [[ 2.7512e-04,  5.1486e-03, -2.1126e-05],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],
 
         [[ 3.9987e-07,  1.3631e-04, -3.0929e-09],
          [ 1.0617e-07,  4.9671e-05, -1.2481e-09],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00]]],
        grad_fn=<IndexSelectBackward>),
 torch.Size([3, 4, 3]))

In [25]:
# nn.Linear only transforms the last dimension: (3, 4, 3) -> (3, 4, 4).
# The padded timesteps (all-zero rows of `o`) come out equal to the layer's
# bias rather than zero.
linear_layer(o).size()

torch.Size([3, 4, 4])

In [26]:
# Recreate the toy inputs from scratch and pad them again (without packing),
# so this comparison does not depend on state left by the packed-sequence
# cells above.
a, b, c = (
    torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]),
    torch.tensor([[7, 8]]),
    torch.tensor([[13, 14], [15, 16]]),
)
data = nn.utils.rnn.pad_sequence(sequences=[a, b, c], batch_first=True)

In [27]:
# Run the same LSTM on the plain zero-padded tensor for comparison. Without
# packing, the LSTM also steps through the padding, so the outputs at padded
# positions are non-zero, and the final states of the shorter sequences
# include those padding steps. The states are named h_n / c_n so they do not
# overwrite the input sequence `c`.
o, (h_n, c_n) = lstm_layer(data.float())

In [28]:
# Compare with the unpacked output of the packed run: here the rows for the
# shorter sequences (batch entries 1 and 2) keep changing over the padded
# timesteps instead of being zero.
o

tensor([[[-1.3476e-02,  2.5510e-01,  3.8923e-02],
         [-6.8503e-03,  5.0041e-01,  2.9872e-02],
         [-2.3512e-03,  6.8747e-01, -1.4301e-02],
         [-7.0616e-04,  8.0632e-01, -7.6779e-02]],

        [[-1.4196e-06,  4.0640e-01, -1.7599e-02],
         [ 4.5898e-02,  2.9161e-01, -3.0840e-02],
         [ 7.3712e-02,  3.1488e-01, -3.2047e-02],
         [ 8.7233e-02,  3.3480e-01, -3.5870e-02]],

        [[ 1.5862e-09,  2.9619e-01, -1.1440e-02],
         [ 5.7016e-10,  4.8869e-01, -2.2434e-02],
         [ 4.2685e-02,  3.0739e-01, -3.6683e-02],
         [ 7.1743e-02,  3.2167e-01, -3.4562e-02]]],
       grad_fn=<TransposeBackward0>)