In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.autograd import Variable
import numpy as np

In [2]:
def create_embedding_layer(weights, padding_idx=None):
    """Build an ``nn.Embedding`` initialised from a pre-trained weight matrix.

    Args:
        weights: 2-D array-like accepted by ``torch.FloatTensor``; one row per
            vocabulary entry, one column per embedding dimension.
        padding_idx: Row index to register as the padding entry. When ``None``
            (the default) the last row of ``weights`` is used.

    Returns:
        torch.nn.Embedding: layer with the same shape as ``weights`` and the
        given values loaded into its ``weight`` parameter.
    """
    weights = torch.FloatTensor(weights)
    len_, dims = weights.size()
    # Compare against None explicitly: a truthiness test would treat a
    # legitimate padding index of 0 as "not supplied" and silently move the
    # padding row to the end of the table.
    if padding_idx is None:
        padding_idx = len_ - 1
    emb_layer = torch.nn.Embedding(len_, dims, padding_idx=padding_idx)
    # load_state_dict copies the values over the random initialisation.
    emb_layer.load_state_dict({'weight': weights})
    return emb_layer

In [3]:
class LSTMSentenceEncoderParallel(nn.Module):
    """Encode every sentence of a batch with a single shared LSTM.

    All sentences of all documents are flattened into one large batch, pushed
    through the LSTM together, and the final hidden state of each sentence is
    returned as its embedding.

    Args:
        weights: Embedding matrix, forwarded to ``create_embedding_layer``.
        word_emb_size: Width of a word vector (LSTM input size).
        sen_emb_size: Width of a sentence vector (LSTM hidden size).
        sen_len: Number of tokens per sentence used when reshaping the input.
        batch_size: Stored on the instance; not read by ``forward``.
    """

    def __init__(self, weights=None,
                       word_emb_size=100,
                       sen_emb_size=150,
                       sen_len=50,
                       batch_size=20):
        super().__init__()

        # Plain hyper-parameters kept for the reshapes in forward().
        self.word_emb_size = word_emb_size
        self.sen_emb_size = sen_emb_size
        self.sen_len = sen_len
        self.batch_size = batch_size

        # Learnable sub-modules.
        self.embeddings = create_embedding_layer(weights)
        self.sentenceEncoder = nn.LSTM(word_emb_size, sen_emb_size, batch_first=True)

    def forward(self, input):
        """Return one embedding per sentence.

        Args:
            input: Tensor of token indices whose first dimension is the batch.
                Presumably shaped (batch, n_sentences, sen_len) -- inferred
                from the reshapes below; confirm against the caller.

        Returns:
            Tensor of shape (batch, -1, sen_emb_size) built from the LSTM's
            final hidden state.
        """
        n_documents = input.size(0)

        # Look up every token at once, then regroup the vectors per sentence.
        token_ids = input.view(-1)
        word_vectors = self.embeddings(token_ids)
        word_vectors = word_vectors.view(-1, self.sen_len, self.word_emb_size)

        # Only the final hidden state h_n is needed; per-step outputs and the
        # cell state are discarded.
        _, (final_hidden, _) = self.sentenceEncoder(word_vectors)

        return final_hidden.view(n_documents, -1, self.sen_emb_size)

In [4]:
class KimCNN(nn.Module):
    """Convolutional sentence encoder with max-over-time pooling.

    One ``Conv2d`` branch is created per entry in ``kernels``; each branch
    spans the full embedding width and ``k`` consecutive tokens. The branch
    outputs are ReLU-activated, max-pooled over the sequence and concatenated.

    Args:
        embedding_weights: Accepted but currently unused -- the embedding
            table keeps its random initialisation.
        embedding_dim: ``(vocabulary_size, embedding_width)``; unpacked into
            ``nn.Embedding``.
        padding_idx: Accepted but currently unused -- no padding row is
            registered on the embedding table.
        kernels: Kernel heights (number of tokens covered), one per branch.
        out_channels: Number of feature maps produced by each branch.
    """

    def __init__(self, #mode='singlechannel',
                 embedding_weights=None, embedding_dim=(400001, 300), padding_idx=400000,
                 kernels=[3,4,5], out_channels=100):
        super().__init__()

        # Randomly initialised lookup table (see the class docstring regarding
        # embedding_weights / padding_idx).
        self.embeddings = nn.Embedding(*embedding_dim)

        emb_width = embedding_dim[1]
        branches = []
        for height in kernels:
            branches.append(
                nn.Conv2d(in_channels=1,
                          out_channels=out_channels,
                          kernel_size=(height, emb_width))
            )
        self.convolutions = nn.ModuleList(branches)

    def forward(self, input, dropout=False):
        """Encode a batch of token-index sequences.

        Args:
            input: Integer tensor of token indices, (batch, seq_len).
            dropout: Accepted but not used by this implementation.

        Returns:
            Tensor of shape (batch, len(kernels) * out_channels).
        """
        # Insert a singleton channel axis so Conv2d sees (batch, 1, seq, emb).
        embedded = self.embeddings(input).unsqueeze(1)

        pooled_per_branch = []
        for conv in self.convolutions:
            # (batch, out_channels, positions): the width axis collapses to 1
            # because each kernel is as wide as the embedding.
            activated = f.relu(conv(embedded).squeeze(-1))
            # Max over all positions -> (batch, out_channels).
            strongest = f.max_pool1d(activated, activated.size(2)).squeeze(-1)
            pooled_per_branch.append(strongest)

        return torch.cat(pooled_per_branch, 1)

In [6]:
class EncoderDecoder(nn.Module):
    """Per-sentence classifier over documents: CNN sentence encoder followed
    by an LSTM document encoder and an LSTM decoder.

    Each sentence is embedded by ``KimCNN``; the sentence sequence is read by
    ``documentEncoder``; its final state initialises ``documentDecoder``,
    which reads the same sequence again; a linear layer scores every decoder
    step.

    Args:
        weights: Embedding matrix; only its ``.shape`` is used to size the
            sentence encoder's embedding table (it is also passed on as
            ``embedding_weights``).
        word_emb_size: Not used by the CNN configuration below.
        sen_emb_size: Input size of both document LSTMs. Must equal the
            sentence encoder's output width, which the kernel/out_channels
            setting below fixes at 7 * 50 = 350.
        doc_emb_size: Hidden size of both document LSTMs.
        sen_len: Not used by the CNN configuration below.
        batch_size: Not used by the CNN configuration below.
        output_dim: Number of outputs per sentence; ``1`` selects a sigmoid
            output, anything else a log-softmax over ``output_dim`` classes.
        reverse: If true, the sentence order is reversed before the document
            LSTMs see it.
    """

    def __init__(self, weights,
                       word_emb_size=300,
                       sen_emb_size=350,
                       doc_emb_size=600,
                       sen_len=50,
                       batch_size=20,
                       output_dim=2,
                       reverse=False):
        super(EncoderDecoder, self).__init__()
        self.output_dim = output_dim
        self.reverse = reverse

        # LSTM Sentence Encoder (alternative, currently disabled)
        #self.sentenceEncoder = LSTMSentenceEncoderParallel(weights, word_emb_size, sen_emb_size, sen_len, batch_size)

        # CNN Sentence Encoder
        self.sentenceEncoder = KimCNN(embedding_weights=weights,
                                      embedding_dim=weights.shape,
                                      kernels=[1,2,3,4,5,6,7],
                                      out_channels=50)
        self.documentEncoder = nn.LSTM(sen_emb_size, doc_emb_size, batch_first=True)
        self.documentDecoder = nn.LSTM(sen_emb_size, doc_emb_size, batch_first=True)
        self.classifier = nn.Linear(doc_emb_size, output_dim)

    def forward(self, input):
        """Score every sentence of every document in the batch.

        Args:
            input: Token-index tensor of shape (batch, n_sentences, sen_len).

        Returns:
            ``output_dim == 1``: tensor (batch, n_sentences) of sigmoid
            probabilities. Otherwise: tensor (batch, n_sentences, output_dim)
            of log-probabilities.
        """
        # LSTM Sentence Encoder (alternative, currently disabled)
        #sentences = self.sentenceEncoder(input)

        # CNN sentence encoder: fold documents into the batch axis so all
        # sentences are encoded in one pass, then unfold again.
        batch_size, no_sentences, sen_len = input.size()
        words = input.reshape(-1, sen_len)
        sentences = self.sentenceEncoder(words).reshape(batch_size, no_sentences, -1)

        if self.reverse:
            # torch.flip runs on whatever device `sentences` lives on, so no
            # CUDA probing or exception-driven fallback is needed.
            sentences = torch.flip(sentences, dims=[1])

        # Only the encoder's final (h, c) state is carried into the decoder.
        _, document_state = self.documentEncoder(sentences)
        decoder_outputs, _ = self.documentDecoder(sentences, document_state)

        scores = self.classifier(decoder_outputs)
        if self.output_dim == 1:
            # Drop the trailing singleton class axis.
            return torch.sigmoid(scores).reshape(batch_size, -1)
        return torch.log_softmax(scores, dim=-1)