* TODO: change the padding index used by the embedding layer
* TODO: add a CNN sentence encoder

In [None]:
import numpy as np

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as f

In [None]:
# Default vocabulary size: number of rows of the embedding table that is built
# when no pretrained weights are supplied.
DEFAULT_VOCAB_LEN = 401004
# Default word-vector width (same value as the word_emb_size defaults below).
DEFAULT_WORD_DIM = 100

In [None]:
def create_embedding_layer(weights, word_emb_size=100, non_trainable=False,
                           padding_idx=None):
    '''
    Build the word-embedding layer, optionally initialised from pretrained weights.

    weights:       2D tensor (vocab_size * word_dims) of pretrained embeddings, or
                   None to create a randomly initialised table with
                   DEFAULT_VOCAB_LEN rows and word_emb_size columns.
    word_emb_size: embedding width; only used when weights is None.
    non_trainable: when True the embedding matrix is frozen (requires_grad=False).
    padding_idx:   row treated as padding. None keeps the historical defaults:
                   401001 with pretrained weights, DEFAULT_VOCAB_LEN - 1 otherwise.

    RETURNS: (embedding layer, number of rows, embedding width)
    RAISES:  ValueError if padding_idx does not address a row of the table.
    '''
    if weights is not None:
        emb_len, word_dims = weights.size()
        if padding_idx is None:
            # Historical padding row of the pretrained vocabulary.
            padding_idx = 401001
    else:
        # Report the width the layer is actually built with, not a fixed constant.
        emb_len, word_dims = DEFAULT_VOCAB_LEN, word_emb_size
        if padding_idx is None:
            padding_idx = DEFAULT_VOCAB_LEN - 1

    # Fail with a readable message instead of the bare assertion torch raises.
    if not -emb_len <= padding_idx < emb_len:
        raise ValueError(
            'padding_idx {} is out of range for an embedding table with {} rows; '
            'pass a valid padding_idx explicitly'.format(padding_idx, emb_len))

    emb_layer = torch.nn.Embedding(emb_len, word_dims, padding_idx=padding_idx)
    if weights is not None:
        # Copies the pretrained values into the layer; this also overwrites the
        # padding row that nn.Embedding zero-initialises.
        emb_layer.load_state_dict({'weight': weights})
    if non_trainable:
        emb_layer.weight.requires_grad = False
    return emb_layer, emb_len, word_dims

In [None]:
class LSTMSentenceEncoderParallel(nn.Module):
    '''
    Encodes every sentence of every document in the batch with one LSTM call and
    uses the final hidden state of each sentence as its embedding.

    INPUT: 3D Tensor of word Ids (batch_size * no_sentences_per_doc * no_words_per_sen)
    OUTPUT: 3D Tensor of sentence Embeddings (batch_size * no_sentence_per_doc * sen_emb_size)

    NOTE: the sen_emb_size constructor argument is the LSTM hidden size. With
    bidirectional=True the stored self.sen_emb_size (the output width) is twice
    that value, because forward and backward final states are concatenated.
    '''
    def __init__(self, weights=None,
                 word_emb_size=100,
                 sen_emb_size=150,
                 sen_len=50,
                 batch_size=20,
                 bidirectional=True):

        super().__init__()
        # sen_len / batch_size are only fallbacks for inputs that are not 3D.
        self.sen_len = sen_len
        self.batch_size = batch_size
        self.embeddings, _, _ = create_embedding_layer(weights, word_emb_size)
        # Take the width from the layer itself: pretrained weights may not be
        # word_emb_size wide, and the LSTM input size must match the real width.
        self.word_emb_size = self.embeddings.embedding_dim
        self.sen_emb_size = sen_emb_size * 2 if bidirectional else sen_emb_size
        self.sentenceEncoder = nn.LSTM(self.word_emb_size, sen_emb_size,
                                       batch_first=True, bidirectional=bidirectional)

    def forward(self, input):
        # Read the batch size and sentence length from the input so that a
        # smaller final batch (or a different sentence length) still works.
        if input.dim() == 3:
            batch_size, sen_len = input.size(0), input.size(2)
        else:
            batch_size, sen_len = self.batch_size, self.sen_len

        # (batch * sentences, words, word_emb_size)
        words = self.embeddings(input.reshape(-1)).view(-1, sen_len, self.word_emb_size)

        # h_n is (num_directions, batch * sentences, hidden). The direction axis
        # must be moved next to the hidden axis before flattening; viewing h_n
        # directly would glue together states of *different* sentences.
        h_n = self.sentenceEncoder(words)[1][0]
        sentences = h_n.transpose(0, 1).reshape(h_n.size(1), -1)

        return sentences.reshape(batch_size, -1, self.sen_emb_size)

In [None]:
class SourceBias(nn.Module):
    '''
    NOTE: Forward Prop is not parallel (one Linear call per sentence).
    Transforms each sentence embedding with the linear layer belonging to the
    source (url id) it is cited from, then applies the non-linearity.
    There is no separate fallback inside this module: sentences without a
    citation must be mapped by the caller to a url id reserved for the default
    transformation.

    INPUT: sentence embeddings (... * sen_emb_size) and integer url ids holding
           one id per sentence (same leading shape as the embeddings)
    OUTPUT: Tensor with the same shape as the input embeddings
    '''
    def __init__(self, sen_emb_size, no_urls, non_linearity=torch.tanh):
        super().__init__()
        # ModuleList rather than a plain list, so the per-source layers are
        # registered: they show up in parameters()/state_dict() and follow .to().
        self.trans = nn.ModuleList(
            [nn.Linear(sen_emb_size, sen_emb_size) for _ in range(no_urls)])
        self.non_linearity = non_linearity

    def forward(self, input, urls):
        # One row per sentence; the row length is the embedding size (last axis).
        rows = input.reshape(-1, input.size(-1))
        url_ids = urls.reshape(-1).tolist()
        if len(url_ids) != rows.size(0):
            # zip() would silently drop the surplus sentences or url ids.
            raise ValueError(
                'expected one url id per sentence: got {} url ids for {} sentences'
                .format(len(url_ids), rows.size(0)))

        output = torch.stack(
            [self.trans[url](emb) for emb, url in zip(rows, url_ids)], 0)

        return self.non_linearity(output).reshape(input.size())

In [None]:
class Attention(nn.Module):
    '''
    Placeholder for an attention module: it defines no parameters and no
    forward pass yet.
    '''

In [None]:
class 