encoderRNN.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class encoder_RNN(nn.Module):
    def __init__(self, embedding_size, vocab_size, hidden_size, n_layers=1, bidirectional=False, dropout=0):
        super(encoder_RNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        # Note: nn.GRU applies dropout between stacked layers only, so the
        # argument has no effect (and PyTorch warns) when n_layers == 1.
        self.source_rnn = nn.GRU(embedding_size, hidden_size, dropout=dropout,
                                 num_layers=n_layers, batch_first=True,
                                 bidirectional=bidirectional)
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.hidden_size = hidden_size
    def forward(self, input_wv, seq_len):
        # Embed the word ids, apply dropout, then pack so the GRU skips padding.
        ip = F.dropout(self.embedding(input_wv), p=self.dropout, training=self.training)
        # seq_len must live on the CPU and be sorted in decreasing order
        # (or pass enforce_sorted=False to let PyTorch sort the batch itself).
        packed_ip_seq = nn.utils.rnn.pack_padded_sequence(ip, seq_len, batch_first=True)
        ## https://pytorch.org/docs/stable/nn.html#torch.nn.GRU
        rnn_output, last_hidden = self.source_rnn(packed_ip_seq)
        encoding_output, _ = nn.utils.rnn.pad_packed_sequence(rnn_output, batch_first=True)
        if self.bidirectional:
            ## Add the contributions from both directions.
            ## torch.cat also works, but the decoder hidden size would have to be doubled.
            encoding_output = torch.add(encoding_output[:, :, :self.hidden_size],
                                        encoding_output[:, :, self.hidden_size:])
            ## last_hidden is laid out as [layer0_fwd, layer0_bwd, layer1_fwd, ...],
            ## so even indices hold the forward direction and odd indices the backward one.
            last_hidden = torch.add(last_hidden[0::2, :, :], last_hidden[1::2, :, :])
        return encoding_output, last_hidden
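
# Example usage (a minimal sketch; the sizes and lengths below are illustrative
# assumptions, not taken from this repository):
#
#   encoder = encoder_RNN(embedding_size=32, vocab_size=100, hidden_size=64,
#                         bidirectional=True)
#   tokens = torch.randint(0, 100, (4, 10))   # (batch, max_seq_len) of word ids
#   lengths = torch.tensor([10, 9, 7, 5])     # true lengths, sorted descending
#   output, hidden = encoder(tokens, lengths)
#   # output: (4, 10, 64) and hidden: (1, 4, 64) after the directions are summed
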
if __name__ == "__main__":
    raise NotImplementedError("Submodules are not meant to be run directly")
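
# Sketch of the torch.cat alternative mentioned in the bidirectional branch
# (an assumption about how it might look, not code from this file):
#
#   # Keep pad_packed_sequence's (batch, seq, 2 * hidden_size) output as-is and
#   # concatenate the direction-wise hidden states along the feature dimension:
#   last_hidden = torch.cat([last_hidden[0::2], last_hidden[1::2]], dim=2)
#   # A decoder consuming this state needs its hidden size doubled to match.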