In [8]:
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

In [9]:
class EncoderRNN(nn.Module):
    """Embed source token ids and encode them with a single-layer GRU.

    Args:
        input_size: source vocabulary size (number of embedding rows).
        hidden_size: width of both the embeddings and the GRU hidden state.
        dropout: dropout probability applied to the embeddings.
    """

    def __init__(self, input_size, hidden_size, dropout = 0.1):
        super().__init__()
        # Registration order (embedding -> dropout -> gru) is kept stable so the
        # parameter order and seeded initialisation stay reproducible.
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, x):
        """Encode a batch of token ids.

        Args:
            x: LongTensor of token ids; with ``batch_first=True`` a
                (batch, seq) input is expected.

        Returns:
            The GRU's ``(output, hidden)`` pair: for a (batch, seq) input,
            ``output`` is (batch, seq, hidden_size) and ``hidden`` is
            (1, batch, hidden_size). ``hidden`` is not batch-first.
        """
        return self.gru(self.dropout(self.embedding(x)))

In [10]:
class Attention(nn.Module):
    """Additive attention: score = V(tanh(W_enc(x) + W_dec(hidden))).

    Args:
        hidden_size: feature width shared by the decoder state and the
            encoder outputs.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.w_enc = nn.Linear(hidden_size, hidden_size)
        self.w_dec = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, 1)

    def forward(self, hidden, x):
        """Attend over ``x`` using ``hidden`` as the query.

        Args:
            hidden: decoder state laid out as (layers, batch, hidden_size).
                It must broadcast against ``x`` after the transpose, so in
                practice a single layer (1, batch, hidden_size) is expected.
            x: encoder outputs, (batch, seq, hidden_size).

        Returns:
            ``(context, weights)``: ``context`` is the weighted sum of ``x``
            with shape (batch, 1, hidden_size); ``weights`` is
            (batch, seq, 1) and sums to 1 over the seq dimension.
        """
        # (layers, batch, H) -> (batch, layers, H) so it broadcasts over seq.
        query = hidden.transpose(0, 1)
        energy = torch.tanh(self.w_dec(query) + self.w_enc(x))
        weights = self.V(energy).softmax(dim=1)
        # (batch, 1, seq) @ (batch, seq, H) -> (batch, 1, H)
        context = weights.transpose(1, 2).bmm(x)
        return context, weights

In [11]:
class DecoderRNN(nn.Module):
    """One-step GRU decoder conditioned on an attention context vector.

    Args:
        hidden_size: embedding width and GRU hidden width.
        output_size: target vocabulary size (embedding rows and logit count).
        dropout: dropout probability applied to the target embeddings.
    """

    def __init__(self, hidden_size, output_size, dropout = 0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.attention = Attention(hidden_size)
        # The GRU consumes [embedding ; context], hence twice the hidden width.
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input_step, hidden, encoder_outputs):
        """Decode a single time step.

        Args:
            input_step: (batch, 1) token ids, i.e. one token per sequence.
                The concat with the (batch, 1, H) context and the
                ``squeeze(1)`` below both require a step length of 1.
            hidden: previous GRU state, (1, batch, hidden_size).
            encoder_outputs: (batch, src_len, hidden_size) encoder outputs.

        Returns:
            ``(logits, hidden)``: logits are (batch, output_size); ``hidden``
            is the updated GRU state. Attention weights are not returned.
        """
        embedded = self.dropout(self.embedding(input_step))
        context = self.attention(hidden, encoder_outputs)[0]
        gru_out, new_hidden = self.gru(torch.cat([embedded, context], dim=2), hidden)
        logits = self.out(gru_out.squeeze(1))
        return logits, new_hidden

In [12]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target sequence step by step.

    Args:
        encoder: module mapping ``source`` to ``(encoder_outputs, hidden)``.
        decoder: module mapping ``(input_step, hidden, encoder_outputs)`` to
            ``(logits, hidden)``. It must expose ``out.out_features``, which is
            read as the target vocabulary size.
        device: device on which the returned logits tensor is allocated.
    """

    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """Encode ``source`` once, then run the decoder for ``target_len - 1`` steps.

        Args:
            source: batch-first source token ids; ``source.shape[0]`` is the batch.
            target: (batch, target_len) token ids. ``target[:, 0]`` is the first
                decoder input (presumably a start-of-sequence token).
            teacher_forcing_ratio: per-step probability of feeding the
                ground-truth token ``target[:, t]`` instead of the model's own
                argmax prediction. One draw is made per step for the whole batch.

        Returns:
            (batch, target_len, vocab) logits. Position 0 stays all zeros
            because no prediction is made for the given first token.
        """
        batch_size, target_len = source.shape[0], target.shape[1]
        target_vocab_size = self.decoder.out.out_features
        # Allocate directly on the target device rather than on the CPU
        # followed by a copy.
        outputs = torch.zeros(batch_size, target_len, target_vocab_size, device=self.device)

        encoder_outputs, hidden = self.encoder(source)
        decoder_input = target[:, 0].unsqueeze(1)

        for t in range(1, target_len):
            output, hidden = self.decoder(decoder_input, hidden, encoder_outputs)
            outputs[:, t, :] = output
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            if teacher_force:
                decoder_input = target[:, t].unsqueeze(1)
            else:
                # Greedy feedback; only computed when it is actually used.
                decoder_input = output.argmax(1).unsqueeze(1)

        return outputs

In [13]:
# Smoke test: push one random batch through the model and check the output shape.
input_dim = 100     # source vocabulary size
output_dim = 100    # target vocabulary size
hidden_dim = 256
batch_size = 32
seq_len = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed before building the model so the weights, inputs, dropout masks and
# teacher-forcing draws are reproducible across Restart & Run All.
torch.manual_seed(0)

enc = EncoderRNN(input_dim, hidden_dim).to(device)
dec = DecoderRNN(hidden_dim, output_dim).to(device)
model = Seq2Seq(enc, dec, device).to(device)

# Create the inputs directly on the target device (no CPU round-trip).
src = torch.randint(0, input_dim, (batch_size, seq_len), device=device)
trg = torch.randint(0, output_dim, (batch_size, seq_len), device=device)

output = model(src, trg)
assert output.shape == (batch_size, seq_len, output_dim), output.shape

print(f"input shape: {src.shape}")
print(f"output shape: {output.shape}")

input shape: torch.Size([32, 10])
output shape: torch.Size([32, 10, 100])
