# Some useful modules for RNNs in PyTorch

## Imports

In [1]:
import writefile_run

In [2]:
%%writefile_run wrappedLSTM.py


"""
Module which wraps an input and output module around an LSTM.
"""

import torch
import torch.nn as nn
import torch.nn.utils.rnn as utils

## Modules

### Wrapped LSTM with input and output modules

In [3]:
%%writefile_run wrappedLSTM.py -a


class WrappedLSTM(nn.Module):
    """
    An LSTM placed between an optional input module and an optional output
    module, operating on ``PackedSequence`` objects.

    The input module (if any) is applied to the flat ``.data`` tensor(s) of
    the packed input(s). The output module (if any) is applied to the flat
    ``.data`` tensor of the LSTM output. All packing metadata (batch sizes
    and, on PyTorch versions that have them, the sort/unsort permutations)
    is carried over unchanged, so inputs packed with
    ``enforce_sorted=False`` keep their original sequence order.
    """

    def __init__(self, lstm_input_size, lstm_hidden_size, input_module=None, output_module=None):
        """
        lstm_input_size: feature size fed to the LSTM; should equal the
            output size of input_module (or the feature size of the packed
            input when input_module is None).
        lstm_hidden_size: LSTM hidden size; should equal the input size of
            output_module when one is given.
        input_module: optional module applied to the packed input data
            before the LSTM.
        output_module: optional module applied to the packed LSTM output
            data after the LSTM.
        """
        super(WrappedLSTM, self).__init__()

        self.input_module = input_module
        self.output_module = output_module

        self.lstm = nn.LSTM(input_size=lstm_input_size, hidden_size=lstm_hidden_size)

    def forward(self, hidden, *packed_input):
        """
        Apply input module -> LSTM -> output module to packed sequences.

        hidden: initial LSTM state, passed unchanged as ``hx`` to nn.LSTM
            (``(h_0, c_0)`` or None).
        packed_input: one or more PackedSequence objects. When input_module
            is set, the ``.data`` of each is passed to it as a positional
            argument, and the result is re-packed using the metadata of the
            first one. Presumably all of them share the same packing
            (TODO confirm). When input_module is None, only the first is used.

        Returns ``(output, hidden)``: the packed output sequence and the
        final hidden state of the LSTM.
        """
        first = packed_input[0]

        if self.input_module is not None:
            rnn_input = self.input_module(*[p.data for p in packed_input])
            # _replace keeps batch_sizes *and* sorted/unsorted indices. A bare
            # PackedSequence(data, batch_sizes) would silently drop the latter.
            rnn_input = first._replace(data=rnn_input)
        else:
            rnn_input = first

        rnn_output, hidden = self.lstm(rnn_input, hidden)

        if self.output_module is not None:
            output = rnn_output._replace(data=self.output_module(rnn_output.data))
        else:
            output = rnn_output

        return output, hidden

### Testing

In [4]:
class MyLSTMmodel(nn.Module):
    """
    Embedding -> single-layer LSTM -> Linear model built on WrappedLSTM.

    Inputs are packed sequences of indices in [0, num_embeddings). Outputs
    are packed sequences of num_embeddings values per step.
    """

    def __init__(self, num_embeddings, embedding_size, lstm_hidden_size):
        """
        num_embeddings: size of the index vocabulary; also the size of
            each output vector.
        embedding_size: embedding dimension (the LSTM input size).
        lstm_hidden_size: LSTM hidden size; the LSTM output is mapped to
            the final output through a Linear (dense) layer.
        """
        super(MyLSTMmodel, self).__init__()

        embed = nn.Embedding(num_embeddings, embedding_size)

        # hidden to output
        h2o = nn.Linear(lstm_hidden_size, num_embeddings)

        self.wrappedlstm = WrappedLSTM(embedding_size, lstm_hidden_size, embed, h2o)

        self.hidden_size = lstm_hidden_size

    def forward(self, hidden, packed_input):
        """
        Run the model on one packed batch.

        hidden: initial ``(h_0, c_0)`` state, e.g. from initHidden.
        packed_input: PackedSequence of embedding indices.

        Returns ``(packed_output, hidden)`` from the wrapped LSTM.
        """
        # Argument order mirrors WrappedLSTM.forward(hidden, *packed_input).
        packed_output, hidden = self.wrappedlstm(hidden, packed_input)
        return packed_output, hidden

    def initHidden(self, num_seqs):
        """Return a random ``(h_0, c_0)`` pair, each shaped (1, num_seqs, hidden_size)."""
        return (torch.rand(1, num_seqs, self.hidden_size),
                torch.rand(1, num_seqs, self.hidden_size))

In [5]:
# Vocabulary of 10 indices, 100-dim embeddings, 100-dim LSTM hidden state.
lstm = MyLSTMmodel(num_embeddings=10, embedding_size=100, lstm_hidden_size=100)

In [6]:
# Three toy index sequences, listed longest-first (lengths 5, 4, 2), because
# pack_sequence by default expects sequences sorted by decreasing length.
a, b, c = (torch.LongTensor(seq) for seq in ([0, 4, 1, 1, 2], [3, 7, 6, 5], [8, 1]))
sample_input = utils.pack_sequence([a, b, c])
# One (h_0, c_0) slot per sequence in the batch.
sample_hidden = lstm.initHidden(3)

In [7]:
# The hidden state goes first: the model's first argument is passed on as the
# hidden state of WrappedLSTM.forward(hidden, *packed_input).
out, hidden = lstm(sample_hidden,sample_input)

In [8]:
# Display the packed output (one row of 10 values per packed time step).
out

PackedSequence(data=tensor([[-0.1198,  0.2623, -0.1670,  0.1608, -0.0619, -0.0063, -0.1187,
         -0.0892,  0.1493, -0.0194],
        [-0.1502,  0.2648, -0.0497,  0.1750,  0.0225,  0.0755, -0.1130,
         -0.1128,  0.1811, -0.1531],
        [-0.0749,  0.2787, -0.0570,  0.1000,  0.0272,  0.0323, -0.0653,
         -0.1151,  0.1706, -0.0051],
        [-0.1785,  0.1828, -0.1452,  0.0981, -0.0155, -0.0578, -0.1371,
         -0.0514,  0.2532, -0.0751],
        [-0.1527,  0.2111, -0.1164,  0.0885,  0.0797,  0.0234, -0.0679,
         -0.0652,  0.1986,  0.0064],
        [-0.0912,  0.2530, -0.0036,  0.0616,  0.1297,  0.0063, -0.0319,
         -0.1574,  0.1323, -0.0517],
        [-0.0825,  0.2068, -0.0488,  0.0304,  0.1356,  0.0468, -0.0999,
         -0.1165,  0.1842, -0.1161],
        [-0.0806,  0.2818, -0.1655, -0.1356,  0.1483, -0.0135, -0.0359,
         -0.0588,  0.2090,  0.0693],
        [-0.0647,  0.1890, -0.0069,  0.0352,  0.2219,  0.0822, -0.0694,
         -0.0774,  0.0981, -0.1270],