# Some useful modules for RNNs in PyTorch

## Imports

In [1]:
import writefile_run
# Destination path handed to the %%writefile_run cells below via $filename.
filename = '../package/pytorch_utils/wrapped_lstm.py'

In [2]:
%%writefile_run $filename


"""
Module which wraps an input and output module around an LSTM.
"""

import torch
import torch.nn as nn
import torch.nn.utils.rnn as utils

## Modules

### Wrapped LSTM with input and output modules

In [3]:
%%writefile_run $filename -a


class WrappedLSTM(nn.Module):
    """
    LSTM over packed sequences with optional input and output modules.

    The input module (if any) is applied to the flat ``data`` tensors of the
    packed inputs before the LSTM; the output module (if any) is applied to the
    flat ``data`` tensor of the LSTM output. Packing metadata is carried through
    unchanged, so the result can be unpacked exactly like the input.
    """

    def __init__(self, lstm_input_size, lstm_hidden_size, input_module=None, output_module=None, num_layers=1):
        """
        lstm_input_size should equal input_module output size
        lstm_hidden_size should equal output_module input size

        input_module / output_module may be None, in which case the
        corresponding step is skipped.
        """
        super(WrappedLSTM, self).__init__()

        self.input_module = input_module
        self.output_module = output_module

        self.lstm = nn.LSTM(input_size=lstm_input_size, hidden_size=lstm_hidden_size, num_layers=num_layers)

    def forward(self, hidden, *packed_input):
        """
        Applies the input module to the data of the packed inputs,
        then applies the LSTM layers,
        then applies the output module to the output data of the LSTM.

        hidden: initial state passed straight to nn.LSTM.
        packed_input: one or more PackedSequence objects. Their ``data``
            tensors are passed, in order, to the input module. The packing
            layout is taken from the first one. Without an input module only
            the first one is used, and it is fed to the LSTM as is.

        Returns the packed output sequence and the final hidden state of the LSTM.

        Raises TypeError if no packed input is given.
        """
        if not packed_input:
            raise TypeError("forward() requires at least one packed input sequence")

        layout = packed_input[0]

        if self.input_module is not None:
            rnn_data = self.input_module(*[p.data for p in packed_input])
            # _replace swaps only the data field and keeps the rest of the
            # packing metadata. Rebuilding from (data, batch_sizes) alone would
            # drop the sort indices of sequences packed with
            # enforce_sorted=False, misaligning hidden states and outputs.
            rnn_input = layout._replace(data=rnn_data)
        else:
            rnn_input = layout

        rnn_output, hidden = self.lstm(rnn_input, hidden)

        if self.output_module is not None:
            out_data = self.output_module(rnn_output.data)
            # The LSTM output already carries the input's packing metadata.
            output = rnn_output._replace(data=out_data)
        else:
            output = rnn_output

        return output, hidden

### Testing

In [4]:
class MyLSTMmodel(nn.Module):
    """
    Embedding -> LSTM -> Linear model built on WrappedLSTM.
    """

    def __init__(self, num_embeddings, embedding_size, lstm_hidden_size, num_layers=1):
        """
        num_embeddings sized input and outputs.
        lstm output interfaced to final output through Dense layer.
        """
        super(MyLSTMmodel, self).__init__()

        # index -> embedding vector, applied to the packed input data
        embed = nn.Embedding(num_embeddings, embedding_size)

        # hidden to output
        h2o = nn.Linear(lstm_hidden_size, num_embeddings)

        self.wrappedlstm = WrappedLSTM(embedding_size, lstm_hidden_size, embed, h2o, num_layers=num_layers)

        self.hidden_size = lstm_hidden_size
        self.num_layers = num_layers

    def forward(self, hidden, packed_input):
        """
        Runs the wrapped LSTM.

        hidden: initial LSTM state (e.g. from initHidden).
        packed_input: packed sequence of embedding indices.

        The positional order (hidden first) is the same as before. The
        parameters used to be named the other way round, which only worked
        because callers passed the hidden state first anyway.

        Returns the packed output sequence and the final hidden state.
        """
        return self.wrappedlstm(hidden, packed_input)

    def initHidden(self, num_seqs):
        """
        Returns a pair of uniformly random tensors of shape
        (num_layers, num_seqs, hidden_size) to use as the initial LSTM state.
        """
        state_shape = (self.num_layers, num_seqs, self.hidden_size)
        return (torch.rand(state_shape),
                torch.rand(state_shape))

In [5]:
# 10 embeddings of size 100, LSTM hidden size 100, two LSTM layers.
lstm = MyLSTMmodel(num_embeddings=10, embedding_size=100, lstm_hidden_size=100, num_layers=2)

In [6]:
# Three index sequences of decreasing length, packed into one batch.
a, b, c = (
    torch.LongTensor(indices)
    for indices in ([0, 4, 1, 1, 2], [3, 7, 6, 5], [8, 1])
)
sample_input = utils.pack_sequence([a, b, c])
# Initial state for a batch of three sequences.
sample_hidden = lstm.initHidden(3)

In [7]:
# Hidden state goes first, then the packed input: the order that
# WrappedLSTM.forward(hidden, *packed_input) expects.
out, hidden = lstm(sample_hidden,sample_input)

In [8]:
out

PackedSequence(data=tensor([[ 0.0977,  0.0047,  0.1723, -0.0532,  0.2271,  0.2114, -0.0280,
         -0.0200, -0.0315,  0.0284],
        [ 0.1886,  0.0175,  0.1198, -0.0272,  0.2438,  0.2450,  0.0635,
          0.0546,  0.0195,  0.0870],
        [ 0.0810, -0.0442,  0.1582, -0.0489,  0.2122,  0.1573, -0.0846,
         -0.0411, -0.0178,  0.0773],
        [ 0.0916,  0.0159,  0.1076, -0.0891,  0.1828,  0.1796, -0.0201,
          0.0178,  0.0028,  0.0753],
        [ 0.1512, -0.0135,  0.1055, -0.0557,  0.1852,  0.1878, -0.0095,
          0.0680,  0.0164,  0.1166],
        [ 0.0697, -0.0740,  0.1117, -0.1219,  0.1426,  0.1006, -0.0587,
         -0.0058, -0.0324,  0.1461],
        [ 0.0925,  0.0025,  0.0698, -0.1073,  0.1409,  0.1123, -0.0376,
          0.0218,  0.0152,  0.1058],
        [ 0.1072, -0.0421,  0.0735, -0.0550,  0.1302,  0.1511, -0.0303,
          0.0637,  0.0422,  0.1428],
        [ 0.1013, -0.0092,  0.0460, -0.1180,  0.1158,  0.0621, -0.0503,
          0.0202,  0.0229,  0.1225],

In [9]:
hidden[0].shape

torch.Size([2, 3, 100])