In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.autograd import Variable
import numpy as np

In [5]:
def create_embedding_layer(weights, padding_idx=None):
    """Build a trainable ``torch.nn.Embedding`` initialised from a pretrained matrix.

    Args:
        weights: 2-D array-like (nested lists, NumPy array or tensor) of shape
            ``(num_embeddings, embedding_dim)``. It is converted to ``float32``.
        padding_idx: row index forwarded to ``nn.Embedding(padding_idx=...)``.
            When ``None`` (the default) the last row, ``num_embeddings - 1``, is used.
            An explicit ``0`` is honoured.

    Returns:
        An ``nn.Embedding`` whose weight parameter holds a copy of ``weights``.

    Raises:
        ValueError: if ``weights`` is not 2-D.
    """
    weights = torch.as_tensor(weights, dtype=torch.float32)
    if weights.dim() != 2:
        raise ValueError(
            "weights must be 2-D (num_embeddings, embedding_dim), got shape {}".format(
                tuple(weights.shape)))
    num_embeddings, embedding_dim = weights.size()
    # Compare against None explicitly: a truthiness test would treat padding_idx=0
    # as "not given" and silently move padding to the last row.
    if padding_idx is None:
        padding_idx = num_embeddings - 1
    emb_layer = torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    # Copy the pretrained values over the randomly initialised weight. The layer stays
    # trainable, and the padding row keeps whatever `weights` holds.
    emb_layer.load_state_dict({'weight': weights})
    return emb_layer

In [6]:
class LSTMSentenceEncoderParallel(nn.Module):
    """Encode every sentence in a batch with one shared LSTM, all sentences at once.

    The token indices of all documents in the batch are flattened into a single
    batch of ``sen_len``-token sentences (hence "Parallel"). The LSTM's final hidden
    state for each sentence becomes that sentence's embedding.

    Args:
        weights: pretrained embedding matrix, forwarded unchanged to
            ``create_embedding_layer``. That helper does not special-case ``None``,
            so despite the default a real matrix is required in practice. Its
            column count must equal ``word_emb_size``.
        word_emb_size: width of each word embedding (LSTM input size).
        sen_emb_size: LSTM hidden size, i.e. the width of each sentence embedding.
        sen_len: number of tokens per sentence. Every sentence is assumed to be
            padded/truncated to exactly this length.
        batch_size: stored as an attribute only. ``forward`` reads the batch size
            from its input instead.

    Raises:
        ValueError: if the embedding width does not match ``word_emb_size``.
    """

    def __init__(self, weights=None,
                       word_emb_size=100,
                       sen_emb_size=150,
                       sen_len=50,
                       batch_size=20):

        super(LSTMSentenceEncoderParallel, self).__init__()
        self.word_emb_size = word_emb_size
        self.sen_emb_size = sen_emb_size
        self.sen_len = sen_len
        self.batch_size = batch_size
        self.embeddings = create_embedding_layer(weights)
        # A width mismatch is not always caught by forward(): the
        # .view(-1, sen_len, word_emb_size) can succeed while regrouping tokens
        # into the wrong number of "sentences". Fail loudly here instead.
        if self.embeddings.embedding_dim != word_emb_size:
            raise ValueError(
                "embedding width {} does not match word_emb_size {}".format(
                    self.embeddings.embedding_dim, word_emb_size))
        self.sentenceEncoder = nn.LSTM(word_emb_size, sen_emb_size, batch_first=True)

    def forward(self, input):
        """Return sentence embeddings of shape ``(batch, n_sentences, sen_emb_size)``.

        ``input`` holds token indices. Its first dimension is the batch, and the
        number of tokens per batch row must be a multiple of ``sen_len``.
        Presumably it is ``(batch, n_sentences, sen_len)`` — TODO confirm against
        the data loader. Non-contiguous index tensors are accepted.

        No sequence packing is used, so each sentence's final hidden state is taken
        after all ``sen_len`` positions, including any padding tokens.
        """
        batch_size = input.size(0)
        # reshape (not view) so transposed/sliced index tensors do not raise.
        flat_ids = input.reshape(-1)
        words = self.embeddings(flat_ids).view(-1, self.sen_len, self.word_emb_size)
        # nn.LSTM returns (output, (h_n, c_n)). h_n is (num_layers, n_total_sentences,
        # sen_emb_size), and h_n[-1] is the top layer's final state.
        _, (h_n, _) = self.sentenceEncoder(words)
        # Sentences were flattened document-major, so this regroups them per document.
        sentences = h_n[-1].view(batch_size, -1, self.sen_emb_size)

        return sentences

In [8]:
class EncoderDecoder(nn.Module):
    """Hierarchical model that scores every sentence of every document.

    Pipeline:
      1. ``sentenceEncoder`` (an ``LSTMSentenceEncoderParallel``) turns token indices
         into per-sentence vectors.
      2. ``documentEncoder`` (LSTM) reads the sentence vectors. Only its final
         ``(h, c)`` state is kept.
      3. ``documentDecoder`` (LSTM) re-reads the same sentence vectors, starting
         from that state.
      4. ``classifier`` (Linear) maps every decoder step to ``output_dim`` scores.

    Args:
        weights: pretrained embedding matrix for the sentence encoder.
        word_emb_size, sen_emb_size, sen_len, batch_size: forwarded to
            ``LSTMSentenceEncoderParallel``.
        doc_emb_size: hidden size of both document-level LSTMs.
        output_dim: ``1`` selects a sigmoid output reshaped to ``(batch, -1)``.
            Any other value selects a softmax over the last dimension.

    NOTE(review): ``forward`` returns probabilities, not logits. If this is trained
    with ``nn.CrossEntropyLoss`` or ``nn.BCEWithLogitsLoss``, the squashing would be
    applied twice — confirm the training loss.
    """

    def __init__(self, weights,
                 word_emb_size=100,
                 sen_emb_size=150,
                 doc_emb_size=300,
                 sen_len=50,
                 batch_size=20,
                 output_dim=2):
        super().__init__()
        self.output_dim = output_dim
        # Submodules are created in this fixed order so seeded random
        # initialisation (and state_dict key order) stays reproducible.
        self.sentenceEncoder = LSTMSentenceEncoderParallel(
            weights, word_emb_size, sen_emb_size, sen_len, batch_size)
        self.documentEncoder = nn.LSTM(sen_emb_size, doc_emb_size, batch_first=True)
        self.documentDecoder = nn.LSTM(sen_emb_size, doc_emb_size, batch_first=True)
        self.classifier = nn.Linear(doc_emb_size, output_dim)

    def forward(self, input):
        """Return per-sentence probabilities for a batch of documents.

        With ``output_dim == 1`` the result is ``(batch, n_sentences)``. Otherwise it
        is the classifier output passed through softmax over the last dimension.
        """
        sentence_vecs = self.sentenceEncoder(input)
        # Keep only the encoder's final (h, c); its per-step outputs are unused.
        _, encoder_state = self.documentEncoder(sentence_vecs)
        decoded, _ = self.documentDecoder(sentence_vecs, encoder_state)
        scores = self.classifier(decoded)
        if self.output_dim != 1:
            return torch.softmax(scores, dim=-1)
        # Binary case: drop the trailing singleton dimension per document.
        n_docs = input.size(0)
        return torch.sigmoid(scores).reshape(n_docs, -1)