In [1]:
import torch

In [2]:
torch.__version__  # show which PyTorch version is installed in this kernel

'2.0.0'

In [3]:
class BLSTM(torch.nn.Module):
    """Thin wrapper around a bidirectional, batch-first ``torch.nn.LSTM``.

    Args:
        input_size: number of features per timestep of the input.
        hidden_units: hidden-state size of the LSTM (per direction).
        n_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_units, n_layers):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_units,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, input_tensor):
        """Run the LSTM over ``input_tensor`` and flatten its return value.

        No initial state is passed to the LSTM, so its default is used.

        Returns:
            ``(output, hidden, cell)``: the per-timestep outputs followed by
            the final hidden state and final cell state, unpacked from the
            ``(output, (h_n, c_n))`` structure that ``torch.nn.LSTM`` returns.
        """
        output, final_state = self.lstm(input_tensor)
        hidden, cell = final_state
        return output, hidden, cell

In [4]:
class ULSTM(torch.nn.Module):
    """Thin wrapper around a unidirectional, batch-first ``torch.nn.LSTM``.

    Args:
        input_size: number of features per timestep of the input.
        hidden_units: hidden-state size of the LSTM.
        n_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_units, n_layers):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_units,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, input_tensor):
        """Run the LSTM over ``input_tensor`` and flatten its return value.

        No initial state is passed to the LSTM, so its default is used.

        Returns:
            ``(output, hidden, cell)``: the per-timestep outputs followed by
            the final hidden state and final cell state, unpacked from the
            ``(output, (h_n, c_n))`` structure that ``torch.nn.LSTM`` returns.
        """
        output, final_state = self.lstm(input_tensor)
        hidden, cell = final_state
        return output, hidden, cell

In [9]:
# Toy dimensions shared by the LSTM demos below.
batch_size, timesteps, input_size = 2, 3, 7  # input tensor is (batch, time, features)
hidden_units, n_layers = 5, 1                # LSTM hidden size and stack depth

In [11]:
# Seed the global RNG first: without it the random input (and any model weights
# initialised afterwards) differ on every Restart & Run All, so none of the
# values shown in this notebook could be reproduced.
torch.manual_seed(0)
input_tensor = torch.randn(size=(batch_size, timesteps, input_size))
input_tensor.shape  # (batch, time, features): one batch of 2 sequences, 3 timesteps each, 7 features per step

torch.Size([2, 3, 7])

In [12]:
input_tensor  # inspect the raw random values fed to the models

tensor([[[-4.6078e-01,  3.1411e-01, -3.9345e-01,  1.1954e-02, -1.9683e+00,
          -8.4002e-01, -8.3430e-01],
         [-7.0688e-01, -1.5562e-03, -1.4663e-01,  5.0143e-02, -1.3607e+00,
           1.1165e+00,  2.6314e-01],
         [ 6.0850e-01, -1.1142e+00, -9.4470e-02, -1.6163e-01,  2.6845e-01,
           1.8741e-01,  1.9345e-01]],

        [[ 2.3831e+00, -1.9397e-01, -3.5056e-01,  1.2811e-03,  8.9334e-01,
           1.6937e+00,  9.1503e-01],
         [ 6.7038e-01,  5.4674e-01, -1.1532e+00, -3.5611e-01,  4.9124e-02,
          -4.6650e-01,  9.5078e-01],
         [-7.5790e-01,  4.3687e-01, -1.0456e+00, -9.8695e-01,  1.0886e+00,
           2.9417e-01,  1.2182e+00]]])

In [20]:
# Build one bidirectional and one unidirectional model from the shared config.
blstm = BLSTM(input_size=input_size, hidden_units=hidden_units, n_layers=n_layers)
ulstm = ULSTM(input_size=input_size, hidden_units=hidden_units, n_layers=n_layers)

In [21]:
blstm  # module repr: shows how the wrapped LSTM is configured

BLSTM(
  (lstm): LSTM(7, 5, batch_first=True, bidirectional=True)
)

In [22]:
ulstm  # module repr: shows how the wrapped LSTM is configured

ULSTM(
  (lstm): LSTM(7, 5, batch_first=True)
)

In [23]:
# Re-create the unidirectional model with a single layer. This binds `ulstm`
# to a brand-new instance with freshly initialised weights.
ulstm = ULSTM(input_size=input_size, hidden_units=hidden_units, n_layers=1)

In [24]:
# Forward pass through the unidirectional model; unpack its three results.
lstm_output, lstm_hidden, lstm_cell = ulstm(input_tensor)

In [25]:
lstm_output # per-timestep LSTM outputs: one hidden_units-sized vector for every timestep of every sequence in the batch

tensor([[[-0.0696,  0.0567,  0.3007,  0.1432, -0.1309],
         [-0.1405,  0.2685,  0.4282, -0.0762, -0.0924],
         [-0.1389,  0.0556,  0.1144,  0.0413, -0.1802]],

        [[ 0.0047,  0.1243, -0.0229, -0.2999,  0.0050],
         [-0.0381, -0.0792,  0.1330,  0.0447, -0.1051],
         [-0.1483, -0.1258,  0.2783,  0.0280, -0.1503]]],
       grad_fn=<TransposeBackward0>)

In [27]:
lstm_hidden  # final hidden state of each of the two sequences in the batch; for this 1-layer unidirectional LSTM it matches lstm_output[:, -1, :]

tensor([[[-0.1389,  0.0556,  0.1144,  0.0413, -0.1802],
         [-0.1483, -0.1258,  0.2783,  0.0280, -0.1503]]],
       grad_fn=<StackBackward0>)

In [28]:
lstm_cell  # final cell state of each sequence in the batch (same layout as lstm_hidden)

tensor([[[-0.3425,  0.0870,  0.2907,  0.0770, -0.2978],
         [-0.2746, -0.1988,  0.4332,  0.0504, -0.5045]]],
       grad_fn=<StackBackward0>)

## Bidirectional/Unidirectional with n_layers = 1

In [16]:
# Bidirectional, single layer: run the batch through and report the three shapes.
lstm_output, lstm_hidden, lstm_cell = blstm(input_tensor)
tuple(t.shape for t in (lstm_output, lstm_hidden, lstm_cell))

(torch.Size([2, 3, 10]), torch.Size([2, 2, 5]), torch.Size([2, 2, 5]))

In [117]:
# Element-wise arithmetic on the output keeps its shape.
# NOTE(review): the saved output shows a batch dimension of 1 although
# batch_size is 2 -- likely stale from out-of-order execution; re-run all cells.
m = 2 * lstm_output + lstm_output
m.shape

torch.Size([1, 3, 10])

In [106]:
# Unidirectional, single layer: run the batch through and report the three shapes.
# NOTE(review): the saved output shows a batch dimension of 1 although
# batch_size is 2 -- likely stale from out-of-order execution; re-run all cells.
lstm_output, lstm_hidden, lstm_cell = ulstm(input_tensor)
tuple(t.shape for t in (lstm_output, lstm_hidden, lstm_cell))

(torch.Size([1, 3, 5]), torch.Size([1, 1, 5]), torch.Size([1, 1, 5]))

## Bidirectional/Unidirectional with n_layers = 2

In [107]:
# Rebuild both models as two-layer stacks; input and hidden sizes are unchanged.
blstm = BLSTM(input_size=input_size, hidden_units=hidden_units, n_layers=2)
ulstm = ULSTM(input_size=input_size, hidden_units=hidden_units, n_layers=2)

In [108]:
# Bidirectional, two layers: run the batch through and report the three shapes.
# NOTE(review): the saved output shows a batch dimension of 1 although
# batch_size is 2 -- likely stale from out-of-order execution; re-run all cells.
lstm_output, lstm_hidden, lstm_cell = blstm(input_tensor)
tuple(t.shape for t in (lstm_output, lstm_hidden, lstm_cell))

(torch.Size([1, 3, 10]), torch.Size([4, 1, 5]), torch.Size([4, 1, 5]))

In [109]:
# Unidirectional, two layers: run the batch through and report the three shapes.
# NOTE(review): the saved output shows a batch dimension of 1 although
# batch_size is 2 -- likely stale from out-of-order execution; re-run all cells.
lstm_output, lstm_hidden, lstm_cell = ulstm(input_tensor)
tuple(t.shape for t in (lstm_output, lstm_hidden, lstm_cell))

(torch.Size([1, 3, 5]), torch.Size([2, 1, 5]), torch.Size([2, 1, 5]))