In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('../')
from constant import *

In [19]:
class EncoderSentence(nn.Module):
    
    def __init__(self,word_size,word_dim, hidden_size, pretrained_word_embeds=None, output_type = 'sum'):
        super(EncoderSentence, self).__init__()
        
        self.output_type = output_type
        self.word_size = word_size
        self.word_dim = word_dim
        self.hidden_size = hidden_size
        self.pretrained_word_embeds = pretrained_word_embeds
        self.embedding = nn.Embedding(self.word_size,self.word_dim,padding_idx=0)
        self.lstm = nn.LSTM(self.word_dim,self.hidden_size,batch_first = True,bidirectional=True)
        self._init_weights()
        
    def forward(self,x,input_lengths):
        embedded = self.embedding(x)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths,batch_first=True)
        outputs, hidden_cell = self.lstm(packed)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs,batch_first=True)
        #return outputs,hidden_cell
        if self.output_type == 'sum':
            outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
        elif self.output_type =='concat':
            outputs = torch.cat((outputs[:, :, :self.hidden_size], outputs[:, : ,self.hidden_size:]),dim=2)
        else:
            raise NotImplementedError 
        return outputs,hidden_cell

    def _init_weights(self):
        if PRE_TRAINED_EMBEDDING or WORD2VEC_EMBEDDING :
            self.embedding.weight.data.copy_(torch.from_numpy(self.pretrained_word_embeds))
            if NON_TRAINABLE:
                self.embedding.weight.requires_grad = False
            else:
                self.embedding.weight.requires_grad = True
        else:
            nn.init.xavier_uniform_(self.embedding.weight.data)

## Test

In [20]:
from load_data_exp import *

In [21]:
for i in train_dataloader:
    break

In [22]:
i[3]

tensor([62, 62, 62, 54, 54, 54, 50, 50, 43, 43, 43, 41, 35, 35, 35, 33, 31, 30,
        30, 30, 24, 23, 23, 21, 20, 19, 17, 17, 15, 13, 10,  9])

In [23]:
enc = EncoderSentence(len(word_mapping)+1,WORD_DIM,128,pretrained_word_embeds,'sum')

In [24]:
e,f = enc(i[1],i[3])

In [25]:
e.size()

torch.Size([32, 62, 128])

In [18]:
e[0].size()

torch.Size([61, 256])

In [8]:
f[0].size()

torch.Size([2, 32, 128])

In [9]:

f[0, :, :] + f[1, :, :]

TypeError: tuple indices must be integers, not tuple

In [10]:
f[0, :, :]

TypeError: tuple indices must be integers, not tuple