In [6]:
import torch
import torch.nn as nn
import torch.utils.data as Data
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
import random

# Alias for PyTorch's CPU float32 tensor class.
dtype = torch.FloatTensor
# Use the GPU when PyTorch reports one as available; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
class Encoder(nn.Module):
    """Embed a source sequence and encode it with a single bidirectional GRU.

    Args:
        input_dim: Number of rows in the embedding table (source vocabulary size).
        emb_dim: Size of each embedding vector.
        enc_hid_dim: Hidden size of the GRU, per direction.
        dec_hid_dim: Output size of the linear projection applied to the
            concatenated final hidden states.
        dropout: Dropout probability applied to the embeddings.
    """

    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        # Creation order is kept as-is so seeded parameter initialisation and
        # the registered-module order do not change.
        self.embedding = nn.Embedding(num_embeddings=input_dim, embedding_dim=emb_dim)
        self.rnn = nn.GRU(input_size=emb_dim, hidden_size=enc_hid_dim, bidirectional=True)
        self.fc = nn.Linear(in_features=2 * enc_hid_dim, out_features=dec_hid_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, src):
        """Encode ``src``.

        Args:
            src: Integer token ids laid out as ``[src_len, batch_size]`` (the
                GRU is not ``batch_first``, so axis 0 is treated as time).

        Returns:
            A tuple ``(enc_output, s)`` where ``enc_output`` is the GRU output
            for every time step, ``[src_len, batch_size, enc_hid_dim * 2]``,
            and ``s`` is ``tanh(fc(...))`` of the last two entries of the GRU's
            final hidden state, ``[batch_size, dec_hid_dim]``.
        """
        # Embedding and dropout are applied in batch-major layout and the
        # result is flipped back to time-major before it reaches the GRU.
        batch_major_ids = src.transpose(0, 1)
        batch_major_emb = self.dropout(self.embedding(batch_major_ids))
        embedded = batch_major_emb.transpose(0, 1)

        enc_output, enc_hidden = self.rnn(embedded)

        # The last two slices of the final hidden state hold the two directions
        # of the (only) GRU layer; join them on the feature axis.
        joined_hidden = torch.cat((enc_hidden[-2], enc_hidden[-1]), dim=1)
        s = torch.tanh(self.fc(joined_hidden))

        return enc_output, s

In [5]:
class Attention(nn.Module):
    """Additive attention: score each encoder position against a decoder state.

    Args:
        enc_hid_dim: Per-direction encoder hidden size; encoder outputs are
            expected to carry ``enc_hid_dim * 2`` features.
        dec_hid_dim: Feature size of the decoder state ``s``.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear((enc_hid_dim * 2 + dec_hid_dim), dec_hid_dim, bias=False)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, s, enc_output):
        """Compute attention weights over the source positions.

        Args:
            s: Decoder state, ``[batch_size, dec_hid_dim]``.
            enc_output: Encoder outputs, ``[src_len, batch_size, enc_hid_dim * 2]``.

        Returns:
            Tensor ``[batch_size, src_len]``; each row is a softmax
            distribution over source positions.
        """
        src_len = enc_output.shape[0]

        # Copy the decoder state once per source position: [batch, src_len, dec_hid].
        s = s.unsqueeze(1).repeat(1, src_len, 1)
        # Switch encoder outputs to batch-major: [batch, src_len, enc_hid * 2].
        enc_output = enc_output.transpose(0, 1)

        # Join on the FEATURE axis (dim=2) so each position gets a vector of
        # size enc_hid * 2 + dec_hid, which is what self.attn expects.
        # Joining on dim=1 would stack along the time axis instead.
        energy = torch.tanh(self.attn(torch.cat((s, enc_output), dim=2)))

        # [batch, src_len, 1] -> [batch, src_len]
        attention = self.v(energy).squeeze(2)

        return F.softmax(attention, dim=1)

In [7]:
class Seq2Seq(nn.Module):
    """Tie an encoder and a decoder together and unroll the decoder over the target.

    Args:
        encoder: Module called as ``encoder(src)`` returning ``(enc_output, s)``.
        decoder: Module called as ``decoder(dec_input, s, enc_output)`` returning
            ``(dec_output, s)``; must expose an ``output_dim`` attribute.
        device: Device the tensor of collected outputs is moved to.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Run the encoder once, then decode one target step at a time.

        Args:
            src: Source batch; axis 1 is read as the batch size.
            trg: Target batch; axis 0 is read as the target length and
                ``trg[0]`` seeds the decoder (presumably the start-of-sequence
                tokens; confirm against the data pipeline).
            teacher_forcing_ratio: Probability, drawn independently at every
                step, of feeding the ground-truth token ``trg[t]`` instead of
                the decoder's own argmax prediction.

        Returns:
            Tensor ``[trg_len, batch_size, decoder.output_dim]`` holding the
            decoder outputs; slot 0 is never written and stays all zeros.
        """
        n_steps = trg.shape[0]
        n_batch = src.shape[1]
        vocab_size = self.decoder.output_dim

        outputs = torch.zeros(n_steps, n_batch, vocab_size).to(self.device)

        enc_output, s = self.encoder(src)

        dec_input = trg[0, :]
        for step in range(1, n_steps):
            dec_output, s = self.decoder(dec_input, s, enc_output)
            outputs[step] = dec_output

            # Exactly one random draw per step decides the next decoder input.
            if random.random() < teacher_forcing_ratio:
                dec_input = trg[step]
            else:
                dec_input = dec_output.argmax(1)

        return outputs

In [8]:
class Decoder(nn.Module):
    """Single-step GRU decoder that conditions on an attention-weighted context.

    Args:
        output_dim: Target vocabulary size; also the size of the prediction.
        emb_dim: Size of each target embedding vector.
        enc_hid_dim: Per-direction encoder hidden size; encoder outputs are
            expected to carry ``enc_hid_dim * 2`` features.
        dec_hid_dim: Hidden size of the decoder GRU.
        dropout: Dropout probability applied to the embeddings.
        attention: Module called as ``attention(s, enc_output)``; its result
            is used as ``[batch_size, src_len]`` weights.
    """

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        context_dim = enc_hid_dim * 2

        self.output_dim = output_dim
        self.attention = attention
        # Layer creation order is kept so seeded initialisation is unchanged.
        self.embedding = nn.Embedding(num_embeddings=output_dim, embedding_dim=emb_dim)
        self.rnn = nn.GRU(input_size=context_dim + emb_dim, hidden_size=dec_hid_dim)
        self.fc_out = nn.Linear(
            in_features=context_dim + dec_hid_dim + emb_dim,
            out_features=output_dim,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, dec_input, s, enc_output):
        """Decode one target step.

        Args:
            dec_input: Current input token ids, ``[batch_size]``.
            s: Previous decoder hidden state, ``[batch_size, dec_hid_dim]``.
            enc_output: Encoder outputs,
                ``[src_len, batch_size, enc_hid_dim * 2]``.

        Returns:
            A tuple ``(pred, hidden)``: ``pred`` is ``[batch_size, output_dim]``
            and ``hidden`` is the new decoder state ``[batch_size, dec_hid_dim]``.
        """
        # [batch] -> [batch, 1] -> embed + dropout -> [1, batch, emb_dim]
        token_column = dec_input.unsqueeze(1)
        step_emb = self.dropout(self.embedding(token_column)).transpose(0, 1)

        # [batch, src_len] -> [batch, 1, src_len] so bmm can take a weighted sum.
        weights = self.attention(s, enc_output).unsqueeze(1)

        # Weighted sum over source positions, returned to [1, batch, enc_hid * 2].
        batch_major_enc = enc_output.transpose(0, 1)
        context = torch.bmm(weights, batch_major_enc).transpose(0, 1)

        gru_in = torch.cat((step_emb, context), dim=2)
        dec_output, dec_hidden = self.rnn(gru_in, s.unsqueeze(0))

        # Drop the length-1 time axis from everything fed to the output layer.
        features = torch.cat(
            (dec_output.squeeze(0), context.squeeze(0), step_emb.squeeze(0)),
            dim=1,
        )
        pred = self.fc_out(features)

        return pred, dec_hidden.squeeze(0)