In [1]:
import torch
import torch.nn as nn

* Input dimension - the size of the input at each time step, e.g. an input of dimension 5 looks like [1, 3, 8, 2, 3]

* Hidden dimension - the size of the hidden state and cell state at each time step, e.g. with a hidden dimension of 3, the hidden state and cell state are each a vector of 3 values, such as [3, 5, 4]

* Number of layers - the number of LSTM layers stacked on top of each other

In [2]:
input_dim = 5
hidden_dim = 10
num_layers = 1

In [4]:
lstm_layer = nn.LSTM(input_dim, hidden_dim, 
                     num_layers, batch_first=True)
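
To see how the three settings above translate into the layer itself, we can list the shapes of its learnable parameters. This is a small self-contained sketch that rebuilds the same layer; the name `param_shapes` is only introduced here for illustration.

```python
import torch.nn as nn

input_dim, hidden_dim, num_layers = 5, 10, 1
lstm_layer = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)

# Each weight and bias stacks the four gates (input, forget, cell, output),
# which is why the leading dimension is 4 * hidden_dim.
param_shapes = {name: tuple(p.shape) for name, p in lstm_layer.named_parameters()}
for name, shape in param_shapes.items():
    print(name, shape)
```

The input-to-hidden weight has `input_dim` columns and the hidden-to-hidden weight has `hidden_dim` columns, so the input dimension and hidden dimension directly fix the size of the layer.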

Let's create some dummy data to see how the layer takes in the input. Our input dimension is 5 and we set batch_first=True, so we create a tensor of shape (1, 1, 5), which represents **(batch size, sequence length/rows, input dimension/cols)**.

Additionally, we'll initialize a hidden state and cell state for the LSTM, since this is the first cell and there is no previous state to carry over. The hidden state and cell state are stored in a tuple of the form (hidden_state, cell_state).
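
As a side note, `nn.LSTM` does not strictly require the initial states: when they are omitted, both default to zeros. The sketch below checks this on a separately constructed layer; the names `h0`, `c0` and `same` are introduced only for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=5, hidden_size=10, num_layers=1, batch_first=True)
x = torch.randn(1, 1, 5)  # (batch size, sequence length, input dimension)

# Explicit all-zero initial states, shaped (num_layers, batch size, hidden dimension)
h0 = torch.zeros(1, 1, 10)
c0 = torch.zeros(1, 1, 10)

out_explicit, _ = lstm(x, (h0, c0))
out_default, _ = lstm(x)  # no initial state passed

# The two calls should agree, since the default initial state is all zeros
same = torch.allclose(out_explicit, out_default)
print(same)
```

Passing the tuple explicitly is still useful whenever a state has to be carried over from an earlier call.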

In [5]:
batch_size = 1
seq_len = 1

inp = torch.randn(batch_size, seq_len, input_dim)
hidden_state = torch.randn(num_layers, batch_size, hidden_dim)
cell_state = torch.randn(num_layers, batch_size, hidden_dim)
hidden = (hidden_state, cell_state)

In [14]:
print("Input shape: ", inp.shape)
print("Hidden state shape: ", hidden_state.shape)
print("cell state shape: ", cell_state.shape)

Input shape:  torch.Size([1, 1, 5])
Hidden state shape:  torch.Size([1, 1, 10])
cell state shape:  torch.Size([1, 1, 10])


In [6]:
out, hidden = lstm_layer(inp, hidden)
print("Output shape: ", out.shape)
print("Hidden: ", hidden)

Output shape:  torch.Size([1, 1, 10])
Hidden:  (tensor([[[-0.4332, -0.4763, -0.4782, -0.3733, -0.1707, -0.0073,  0.4739,
           0.2451,  0.1646,  0.2765]]], grad_fn=<StackBackward>), tensor([[[-0.6766, -1.5660, -0.7869, -1.2284, -0.2727, -0.0090,  1.1304,
           1.1287,  0.6611,  0.8783]]], grad_fn=<StackBackward>))


In the process above, we saw how the LSTM cell processes the input and hidden states at a single time step. However, in most cases we'll be processing the input data in longer sequences. The LSTM can also take in sequences of variable length and produce an output at each time step. Let's try changing the sequence length this time.
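
One way to connect the single-step view with full sequences: feeding a sequence one time step at a time, while carrying the (hidden, cell) tuple forward, should give the same outputs as passing the whole sequence in one call. The sketch below assumes a sequence length of 3 purely for illustration, and uses its own layer and variable names.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=5, hidden_size=10, num_layers=1, batch_first=True)
seq = torch.randn(1, 3, 5)  # batch of 1, three time steps, input dimension 5

with torch.no_grad():
    # Whole sequence in a single call
    full_out, (full_h, full_c) = lstm(seq)

    # Same sequence, one time step per call, carrying the state forward
    state = None  # None means zero initial hidden and cell states
    step_outs = []
    for t in range(seq.size(1)):
        step_out, state = lstm(seq[:, t:t + 1, :], state)
        step_outs.append(step_out)
    stepped_out = torch.cat(step_outs, dim=1)

same_outputs = torch.allclose(full_out, stepped_out, atol=1e-5)
print(full_out.shape, same_outputs)
```

In practice the single call is preferred, since the layer handles the loop over time steps internally.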