# Some useful modules for RNNs in PyTorch

## Imports

In [1]:
import writefile_run

In [2]:
%%writefile_run wrappedLSTM.py


"""
Module which wraps an input and output module around an LSTM.
"""

import torch
import torch.nn as nn
import torch.nn.utils.rnn as utils

## Modules

### Wrapped LSTM with input and output modules

In [3]:
%%writefile_run wrappedLSTM.py -a


class WrappedLSTM(nn.Module):
    """
    An LSTM with optional modules applied to its packed input and output.

    The input module (if given) is applied to the flat ``data`` tensors of the
    packed inputs before the LSTM; the output module (if given) is applied to
    the flat ``data`` tensor of the LSTM output. Results are re-wrapped as
    packed sequences with the same layout as the sequence they came from.
    """

    def __init__(self, lstm_input_size, lstm_hidden_size, input_module=None, output_module=None):
        """
        lstm_input_size should equal input_module output size
        lstm_hidden_size should equal output_module input size

        input_module and output_module are optional; when None the
        corresponding step is skipped.
        """
        super(WrappedLSTM, self).__init__()

        self.input_module = input_module
        self.output_module = output_module

        self.lstm = nn.LSTM(input_size=lstm_input_size, hidden_size=lstm_hidden_size)

    @staticmethod
    def _repack(data, like):
        """
        Wrap `data` in a PackedSequence laid out like the packed sequence `like`.

        Besides batch_sizes, the sort permutation is carried over when `like`
        has one. Dropping it would make the LSTM treat the batch as already
        sorted and would un-pad the output in sorted order instead of the
        caller's original order.
        """
        # getattr keeps this working on PackedSequence types without these fields.
        sorted_indices = getattr(like, "sorted_indices", None)
        if sorted_indices is None:
            return utils.PackedSequence(data, like.batch_sizes)
        return utils.PackedSequence(
            data, like.batch_sizes, sorted_indices, like.unsorted_indices
        )

    def forward(self, hidden, *packed_input):
        """
        Applies the input module to the data of packed_input,
        then applies the LSTM,
        then applies the output module to the data of the LSTM output.

        hidden is handed to the LSTM as its initial state. When there is an
        input module it receives the ``data`` of every packed input, in order;
        without one, only the first packed input is used. The layout
        (batch sizes and sort permutation) is taken from the first packed input.

        Returns packed output sequence and final hidden state of LSTM.
        """
        reference = packed_input[0]

        if self.input_module is not None:
            transformed = self.input_module(*[p.data for p in packed_input])
            rnn_input = self._repack(transformed, reference)
        else:
            rnn_input = reference

        rnn_output, hidden = self.lstm(rnn_input, hidden)

        if self.output_module is not None:
            projected = self.output_module(rnn_output.data)
            output = self._repack(projected, rnn_output)
        else:
            output = rnn_output

        return output, hidden

### Testing

In [4]:
class MyLSTMmodel(nn.Module):
    """
    Small test model: Embedding -> LSTM -> Linear, assembled with WrappedLSTM.
    """

    def __init__(self, num_embeddings, embedding_size, lstm_hidden_size):
        """
        num_embeddings sized input and outputs.
        lstm output interfaced to final output through Dense layer.
        """
        super(MyLSTMmodel, self).__init__()

        embed = nn.Embedding(num_embeddings, embedding_size)

        # hidden to output
        h2o = nn.Linear(lstm_hidden_size, num_embeddings)

        self.wrappedlstm = WrappedLSTM(embedding_size, lstm_hidden_size, embed, h2o)

        self.hidden_size = lstm_hidden_size

    def forward(self, packed_input, hidden):
        """
        Runs the wrapped LSTM on a packed input sequence.

        packed_input is a PackedSequence and hidden is the LSTM state.
        Existing callers pass the two positionally as (hidden, packed_input),
        because the arguments used to be forwarded untouched to
        WrappedLSTM.forward(hidden, *packed_input). Both orders are accepted:
        if the PackedSequence arrives in the `hidden` slot the two are swapped.

        Returns packed output sequence and final hidden state.
        """
        if isinstance(hidden, utils.PackedSequence):
            # Legacy positional order (hidden, packed_input): put each back in place.
            packed_input, hidden = hidden, packed_input
        packed_output, hidden = self.wrappedlstm(hidden, packed_input)
        return packed_output, hidden

    def initHidden(self, num_seqs):
        """
        Returns a pair of random tensors, each of shape
        (1, num_seqs, hidden_size), to use as the initial LSTM state.
        """
        return torch.rand(1, num_seqs, self.hidden_size), torch.rand(1, num_seqs, self.hidden_size)

In [5]:
# Vocabulary of 10 symbols, 100-dimensional embeddings, 100 hidden units.
lstm = MyLSTMmodel(num_embeddings=10, embedding_size=100, lstm_hidden_size=100)

In [6]:
# Three index sequences, ordered longest to shortest (lengths 5, 4, 2).
a = torch.tensor([0, 4, 1, 1, 2], dtype=torch.long)
b = torch.tensor([3, 7, 6, 5], dtype=torch.long)
c = torch.tensor([8, 1], dtype=torch.long)

# Pack them into one PackedSequence and create one initial state per sequence.
sample_input = utils.pack_sequence([a, b, c])
sample_hidden = lstm.initHidden(num_seqs=3)

In [7]:
# The hidden state is passed first, matching WrappedLSTM.forward(hidden, *packed_input).
out, hidden = lstm(sample_hidden, sample_input)

In [8]:
out

PackedSequence(data=tensor([[-0.0884, -0.1806,  0.0011,  0.0866, -0.2564,  0.0708,  0.0576,
         -0.0129, -0.1135,  0.2793],
        [-0.0395, -0.1614,  0.1557,  0.0780, -0.2334, -0.0519,  0.2044,
         -0.1434, -0.1146,  0.1510],
        [-0.1149, -0.1372,  0.0092,  0.1239, -0.0211, -0.0837,  0.0748,
          0.1177, -0.1419,  0.1817],
        [ 0.0140, -0.2058,  0.0085,  0.0069, -0.1778, -0.1465,  0.1653,
         -0.1534, -0.0208,  0.0347],
        [-0.0557, -0.2501,  0.0651, -0.0138, -0.2018, -0.1098,  0.1882,
         -0.0545, -0.0771, -0.0480],
        [-0.1886, -0.0810,  0.0456,  0.2567,  0.0308, -0.0172,  0.1563,
          0.1002,  0.0006,  0.0404],
        [-0.1456, -0.0798,  0.0083,  0.1486, -0.0440, -0.0441,  0.1614,
         -0.0626,  0.1085, -0.0141],
        [-0.0750, -0.0297,  0.1732,  0.0842, -0.0558, -0.1382,  0.1308,
         -0.0385, -0.1046,  0.0234],
        [-0.1973, -0.0132,  0.0204,  0.2051, -0.0107,  0.0015,  0.1529,
          0.0130,  0.1564, -0.0597],