<a href="https://colab.research.google.com/github/pooriaazami/deep_learning_class_notebooks/blob/main/13_Machine_Translation_From_Scratch_Part_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [36]:
# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
  DEVICE = 'cuda'
else:
  DEVICE = 'cpu'

In [37]:
class Encoder(nn.Module):
  """Bidirectional, single-layer GRU encoder over batches of token ids.

  Args:
    num_tokens: vocabulary size of the embedding table.
    embedding_dim: size of each token embedding.
    latent_dim: hidden size of each GRU direction.
  """

  def __init__(self, num_tokens, embedding_dim, latent_dim):
    super().__init__()

    self.embedding = nn.Embedding(num_embeddings=num_tokens, embedding_dim=embedding_dim)
    self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=latent_dim, num_layers=1, batch_first=True, bidirectional=True)

    self.latent_dim = latent_dim

  def forward(self, x):
    """Encode a batch of token-id sequences.

    Args:
      x: integer token ids laid out batch-first, shape (batch, seq_len).

    Returns:
      A ``(context_vector, outputs)`` tuple:
        context_vector: final GRU hidden state, shape (2, batch, latent_dim),
          with one slice per direction.
        outputs: per-step states, shape (batch, seq_len, 2 * latent_dim),
          with forward and backward states concatenated.
    """
    x = self.embedding(x)
    batch_size, _, _ = x.size()
    # Allocate the initial state alongside the embedded input (same device and
    # dtype) rather than on the global DEVICE. This way the module works
    # wherever it and its input actually live, e.g. a CPU model on a
    # CUDA-capable machine.
    h_0 = x.new_zeros(2, batch_size, self.latent_dim)
    outputs, context_vector = self.rnn(x, h_0)

    return context_vector, outputs

In [38]:
# Encoder used by the shape check below; its parameters are moved onto DEVICE.
encoder = Encoder(
    num_tokens=100,
    embedding_dim=16,
    latent_dim=64,
).to(DEVICE)

In [39]:
batch_size = 10
seq_length = 30

# `encoder` was moved to DEVICE when it was built, so the dummy ids must live
# there too. A CPU index tensor fed to a CUDA embedding table raises a
# device-mismatch error.
test_input = torch.zeros(batch_size, seq_length, dtype=torch.int64, device=DEVICE)
test_output, _ = encoder(test_input)

# First return value is the final hidden state: (2 directions, batch, latent_dim).
test_output.shape

torch.Size([2, 10, 64])

In [40]:
class AttentionBlock(nn.Module):
  """Additive attention: score_i = V(tanh(W(query) + U(key_i))).

  Args:
    hidden_dim: feature size shared by the query and every key.
  """

  def __init__(self, hidden_dim):
    super().__init__()
    self.W = nn.Linear(hidden_dim, hidden_dim, bias=False)
    self.U = nn.Linear(hidden_dim, hidden_dim, bias=False)
    self.V = nn.Linear(hidden_dim, 1, bias=False)

  def forward(self, query, keys):
    """Attend from one query per batch row over that row's keys.

    Args:
      query: shape (batch, hidden_dim).
      keys: shape (batch, seq_len, hidden_dim); also used as the values.

    Returns:
      A ``(context, weights)`` tuple:
        context: weighted sum of the keys, shape (batch, 1, hidden_dim).
        weights: softmax over seq_len, shape (batch, 1, seq_len).
    """
    # Broadcast the projected query across every key position.
    energy = torch.tanh(self.W(query)[:, None] + self.U(keys))
    # (batch, seq_len, 1) -> (batch, 1, seq_len) so the weights can left-multiply the keys.
    logits = self.V(energy).transpose(1, 2)
    attn = logits.softmax(dim=-1)
    return torch.bmm(attn, keys), attn

In [41]:
class AttentionGRU(nn.Module):
  """GRU unrolled step by step, re-attending to the encoder outputs at each step.

  Args:
    input_size: feature size of each input step.
    latent_dim: per-direction hidden size of the encoder. Attention and the GRU
      both work in ``2 * latent_dim`` to match the concatenated bidirectional
      encoder outputs.
  """

  def __init__(self, input_size, latent_dim):
    super().__init__()
    self.attention = AttentionBlock(2 * latent_dim)
    self.rnn = nn.GRU(input_size=input_size, hidden_size=2 * latent_dim)
    self.latent_dim = latent_dim

  def forward(self, predicted_label, encoder_outputs):
    """Run the attention GRU over every step of ``predicted_label``.

    Args:
      predicted_label: step inputs, shape (batch, steps, input_size).
      encoder_outputs: attention keys, shape (batch, src_len, 2 * latent_dim).

    Returns:
      An ``(output, h)`` tuple for the LAST step only:
        output: GRU output, shape (1, batch, 2 * latent_dim).
        h: hidden state, shape (batch, 2 * latent_dim).

    Raises:
      ValueError: if ``predicted_label`` has zero steps.
    """
    batch_size, num_steps, _ = predicted_label.size()
    if num_steps == 0:
      raise ValueError('predicted_label must contain at least one step')

    # Start from a zero state on the same device/dtype as the inputs.
    h = predicted_label.new_zeros(batch_size, 2 * self.latent_dim)
    output = None
    # Iterate over time: each `token` is (batch, input_size).
    for token in predicted_label.permute(1, 0, 2):
      context, _ = self.attention(h, encoder_outputs)
      # NOTE(review): the attention context, reshaped to (1, batch, hidden),
      # is fed as the GRU's hidden state for this step. The previous `h` only
      # influences the step through the attention query. This differs from the
      # usual Bahdanau decoder (context concatenated to the input); kept as-is
      # to preserve the parameter shapes.
      hidden = context.permute(1, 0, 2)
      output, h = self.rnn(token.unsqueeze(0), hidden)
      # Drop only the num_layers axis; a bare squeeze() would also drop the
      # batch axis when batch_size == 1 and break the next attention call.
      h = h.squeeze(0)

    return output, h

In [42]:
class Decoder(nn.Module):
  """Attention decoder: embeds target ids, runs AttentionGRU, and scores the vocabulary.

  Args:
    num_tokens: target vocabulary size (embedding rows and output classes).
    embedding_dim: size of each target-token embedding.
    latent_dim: per-direction encoder hidden size; the decoder state is 2 * latent_dim.
  """

  def __init__(self, num_tokens, embedding_dim, latent_dim):
    super().__init__()

    self.embedding = nn.Embedding(num_embeddings=num_tokens, embedding_dim=embedding_dim)
    self.rnn = AttentionGRU(embedding_dim, latent_dim)
    self.fc = nn.Linear(in_features=2 * latent_dim, out_features=num_tokens)
    # Normalize over the last (vocabulary) axis. The GRU output reaching `fc`
    # is (1, batch, 2 * latent_dim), so dim=1 would normalize across the batch.
    self.softmax = nn.LogSoftmax(dim=-1)

  def forward(self, encoder_outputs, predicted_label):
    """Score the vocabulary after reading all of ``predicted_label``.

    Args:
      encoder_outputs: per-step encoder states, shape (batch, src_len, 2 * latent_dim).
      predicted_label: target token ids so far, shape (batch, steps).

    Returns:
      Log-probabilities over the vocabulary, shape (1, batch, num_tokens).
      Only the final step's GRU output is scored.
    """
    x = self.embedding(predicted_label)
    x, _ = self.rnn(x, encoder_outputs)
    x = self.fc(x)
    x = self.softmax(x)

    return x

In [43]:
# Small encoder/decoder pair for the shape check that follows. Neither is moved
# with .to(DEVICE), so their parameters stay where they were created.
encoder = Encoder(
    num_tokens=100,
    embedding_dim=8,
    latent_dim=16,
)
decoder = Decoder(
    num_tokens=100,
    embedding_dim=8,
    latent_dim=16,
)

In [44]:
# Dummy batch: 50 source sequences of 20 ids, each with 10 target ids so far.
batch_size, seq_length, predicted_labels_count = 50, 20, 10

test_input = torch.zeros((batch_size, seq_length), dtype=torch.int64)
predicted_labels = torch.zeros((batch_size, predicted_labels_count), dtype=torch.int64)

# Only the per-step encoder outputs are passed on (as attention keys);
# the encoder's final hidden state is discarded.
_, encoder_output = encoder(test_input)
new_token = decoder(encoder_output, predicted_labels)
new_token.size()

torch.Size([1, 50, 100])