# Some useful modules for RNNs in PyTorch

## Imports

In [1]:
import writefile_run

In [2]:
%%writefile_run wrappedLSTM.py


"""
Module which wraps an input and output module around an LSTM.
"""

import torch
import torch.nn as nn
import torch.nn.utils.rnn as utils

## Modules

### Wrapped LSTM with input and output modules

In [3]:
%%writefile_run wrappedLSTM.py -a


class WrappedLSTM(nn.Module):
    """
    An LSTM with optional modules applied before and after it.

    The input module (if any) is applied to the ``.data`` of the packed input
    sequence(s), and the output module (if any) is applied to the ``.data`` of
    the LSTM's packed output. Both therefore operate row-wise on the flattened
    time steps, while the packing metadata is carried through unchanged.
    """

    def __init__(self, lstm_input_size, lstm_hidden_size, input_module=None, output_module=None, num_layers=1):
        """
        Args:
            lstm_input_size: input size of the LSTM; should equal the output
                size of ``input_module`` (or the feature size of the packed
                input when ``input_module`` is None).
            lstm_hidden_size: hidden size of the LSTM; should equal the input
                size of ``output_module``.
            input_module: optional module applied to the packed input data.
            output_module: optional module applied to the LSTM output data.
            num_layers: number of stacked LSTM layers.
        """
        super(WrappedLSTM, self).__init__()

        self.input_module = input_module
        self.output_module = output_module

        self.lstm = nn.LSTM(input_size=lstm_input_size, hidden_size=lstm_hidden_size, num_layers=num_layers)

    def forward(self, hidden, *packed_input):
        """
        Apply input module -> LSTM -> output module to packed sequences.

        Args:
            hidden: initial LSTM state ``(h_0, c_0)``, or None to use
                nn.LSTM's default initial state.
            *packed_input: one or more PackedSequence objects. The ``.data``
                of every one of them is passed, in order, to ``input_module``.
                Only the first one's packing metadata is used, so they
                presumably must share the same layout. Without an input module,
                only the first PackedSequence is fed to the LSTM.

        Returns:
            (output, hidden): the packed output sequence and the final
            ``(h_n, c_n)`` state of the LSTM.
        """
        # Every rebuilt PackedSequence copies its packing metadata from this one.
        template = packed_input[0]

        if self.input_module is not None:
            rnn_input = self.input_module(*[p.data for p in packed_input])
            # _replace keeps batch_sizes AND sorted_indices/unsorted_indices
            # (set when packing with enforce_sorted=False). Rebuilding from
            # (data, batch_sizes) alone would silently drop the sequence order.
            rnn_input = template._replace(data=rnn_input)
        else:
            rnn_input = template

        rnn_output, hidden = self.lstm(rnn_input, hidden)

        if self.output_module is not None:
            # Same reasoning: keep the LSTM output's packing metadata intact.
            output = rnn_output._replace(data=self.output_module(rnn_output.data))
        else:
            output = rnn_output

        return output, hidden

### Testing

In [4]:
class MyLSTMmodel(nn.Module):
    """
    Embedding -> LSTM -> Linear model built on WrappedLSTM.

    Inputs and outputs are both sized ``num_embeddings``: token indices are
    embedded, run through the LSTM, and each LSTM output row is projected back
    to ``num_embeddings`` values by a Linear ("hidden to output") layer.
    """

    def __init__(self, num_embeddings, embedding_size, lstm_hidden_size, num_layers=1):
        """
        Args:
            num_embeddings: number of embedding rows (token vocabulary); also
                the output size of the final Linear layer.
            embedding_size: embedding dimension, used as the LSTM input size.
            lstm_hidden_size: LSTM hidden size, used as the Linear layer's
                input size.
            num_layers: number of stacked LSTM layers.
        """
        super(MyLSTMmodel, self).__init__()

        embed = nn.Embedding(num_embeddings, embedding_size)

        # hidden to output
        h2o = nn.Linear(lstm_hidden_size, num_embeddings)

        self.wrappedlstm = WrappedLSTM(embedding_size, lstm_hidden_size, embed, h2o, num_layers=num_layers)

        self.hidden_size = lstm_hidden_size
        self.num_layers = num_layers

    def forward(self, packed_input, hidden):
        """
        Run the model over a packed batch of token-index sequences.

        The arguments are accepted in either order:
        - ``(packed_input, hidden)``, as the parameter names say;
        - ``(hidden, packed_input)``, the positional order existing callers use.

        The PackedSequence argument is identified by its type. The other
        argument is the initial ``(h_0, c_0)`` state, or None.

        Returns:
            (packed_output, hidden): packed per-token outputs of size
            ``num_embeddings`` and the final LSTM ``(h_n, c_n)`` state.
        """
        if not isinstance(packed_input, utils.PackedSequence):
            # Called as (hidden, packed_input): swap back into the named order.
            packed_input, hidden = hidden, packed_input
        # WrappedLSTM.forward takes the hidden state first.
        packed_output, hidden = self.wrappedlstm(hidden, packed_input)
        return packed_output, hidden

    def initHidden(self, num_seqs):
        """
        Return a random initial ``(h_0, c_0)`` state for ``num_seqs`` sequences.

        Both tensors have shape ``(num_layers, num_seqs, hidden_size)``. They
        are drawn with torch.rand (uniform on [0, 1)), not zero-initialised.
        """
        return (torch.rand(self.num_layers, num_seqs, self.hidden_size),
                torch.rand(self.num_layers, num_seqs, self.hidden_size))

In [5]:
# Vocabulary of 10 tokens, 100-dim embeddings, 100-unit hidden state, 2 stacked LSTM layers.
lstm = MyLSTMmodel(num_embeddings=10, embedding_size=100, lstm_hidden_size=100, num_layers=2)

In [6]:
# Three token-index sequences of lengths 5, 4 and 2, ordered longest first as
# pack_sequence expects by default.
a, b, c = (torch.LongTensor(tokens) for tokens in ([0, 4, 1, 1, 2], [3, 7, 6, 5], [8, 1]))
sample_input = utils.pack_sequence([a, b, c])
# One random initial (h_0, c_0) slot per sequence in the batch.
sample_hidden = lstm.initHidden(3)

In [7]:
# Hidden state is passed first and the packed input second: MyLSTMmodel.forward
# hands its first positional argument to WrappedLSTM.forward's `hidden` parameter.
out, hidden = lstm(sample_hidden,sample_input)

In [8]:
# Inspect the packed output: out.data has one row per input token (the h2o
# projection of that step's LSTM output), each of size num_embeddings = 10.
out

PackedSequence(data=tensor([[ 4.5356e-02, -1.3117e-02, -7.5379e-02,  1.8073e-01,  2.2724e-03,
         -2.0360e-01, -9.1893e-02, -9.3743e-02,  1.4752e-01, -1.1098e-01],
        [ 1.9661e-02, -1.3342e-01, -9.6478e-03,  1.0177e-01,  1.2161e-02,
         -2.2733e-01, -9.8890e-02, -4.4546e-02,  9.2363e-02,  8.4648e-02],
        [ 5.2866e-02, -1.0012e-01, -6.6004e-02,  2.2615e-01,  6.1098e-02,
         -1.9439e-01, -1.0957e-01, -7.8684e-02,  7.6870e-02, -1.3136e-01],
        [ 1.4732e-02, -1.7848e-02, -9.9732e-03,  1.4387e-01,  3.3322e-02,
         -1.6140e-01, -7.7387e-02, -4.4499e-02,  1.0875e-01, -9.5107e-02],
        [-6.1885e-02, -6.9286e-02,  4.2829e-02,  9.8437e-02,  2.0012e-02,
         -1.4493e-01, -9.0458e-02, -4.3558e-02,  7.0776e-02,  3.8683e-02],
        [ 1.1775e-04, -8.9769e-02, -1.9824e-02,  1.5231e-01,  3.3981e-02,
         -1.7862e-01, -8.4037e-02, -5.4298e-02,  6.8807e-02, -1.3086e-01],
        [ 1.0023e-02, -2.5405e-02,  1.0689e-02,  1.0071e-01,  3.5320e-02,
         -1.

In [9]:
# Final hidden state h_n: (num_layers, num_sequences, hidden_size) = (2, 3, 100).
hidden[0].shape

torch.Size([2, 3, 100])