<a href="https://colab.research.google.com/github/ymoslem/PyTorchNLP/blob/main/Ex6-NMT-Transformer-custom-dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NMT with PyTorch nn.Transformer

* Paper: [Attention is all you need](https://arxiv.org/abs/1706.03762)

* PyTorch `nn.Transformer` class: https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html

* Reference video: https://www.youtube.com/watch?v=M6adRGJe5cQ

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter

from torchtext.data.functional import generate_sp_model, load_sp_model, sentencepiece_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from collections import Counter

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from itertools import islice

In [None]:
# Uncomment to download the Tatoeba ar-en corpus and place it under ./data
# !wget -qq https://nmt-datasets.s3.us-west-2.amazonaws.com/ar-en/Tatoeba.ar-en.ar-filtered.ar.semantic.bz2
# !wget -qq https://nmt-datasets.s3.us-west-2.amazonaws.com/ar-en/Tatoeba.ar-en.en-filtered.en.semantic.bz2

# !bzip2 -d Tatoeba.ar-en.ar-filtered.ar.semantic.bz2
# !bzip2 -d Tatoeba.ar-en.en-filtered.en.semantic.bz2

# !mkdir -p data
# The decompressed files are in the current directory, so move them *into* data/
# !mv Tatoeba.ar-en.ar-filtered.ar.semantic data/Tatoeba.ar-en.ar
# !mv Tatoeba.ar-en.en-filtered.en.semantic data/Tatoeba.ar-en.en

In [None]:
# Line-aligned parallel corpus: English source, Arabic target
source_file, target_file = "data/Tatoeba.ar-en.en", "data/Tatoeba.ar-en.ar"

# Special tokens: a language tag prepended to each side, and an end-of-sentence marker
src_code, tgt_code, eos_code = "<en>", "<ar>", "</s>"

# Paths for tokenized copies of the corpus
source_file_tok = f"{source_file}.tok"
target_file_tok = f"{target_file}.tok"

# Preprocessing and split settings
lower = True
dev_size, test_size = 1000, 1000

# Upper bound on vocabulary size (SentencePiece and torchtext)
max_vocab_size = 32000

In [None]:
# Count words in the file to pick a SentencePiece vocabulary size
# [To-Do] split up to 10 million lines for SentencePiece

def get_data_freq(data_file, min_freq=2, max_vocab_size=32000):
    """Return a vocabulary size derived from word frequencies in ``data_file``.

    Args:
        data_file: path to a UTF-8 text file.
        min_freq: a whitespace-separated word is counted only if it occurs
            at least this many times.
        max_vocab_size: upper bound on the returned size.

    Returns:
        int: min(number of distinct words with frequency >= min_freq, max_vocab_size).
    """
    with open(data_file, encoding="utf-8") as data:
        word_counts = Counter(data.read().split())
    # "At least min_freq" occurrences, not "exactly min_freq"
    frequent_word_count = sum(1 for count in word_counts.values() if count >= min_freq)
    return min(frequent_word_count, max_vocab_size)


# Per-language SentencePiece vocabulary sizes, estimated from the raw corpora
source_vocab_size = get_data_freq(source_file, 2, max_vocab_size)
print("Source Vocab Size:", source_vocab_size)

target_vocab_size = get_data_freq(target_file, 2, max_vocab_size)
print("Target Vocab Size:", target_vocab_size)

Source Vocab Size: 2402
Target Vocab Size: 4640


In [None]:
# Train one unigram SentencePiece model per language; each produces a
# "<prefix>.model" file that is loaded in the next cell
for corpus_path, sp_vocab_size, prefix, side in (
    (source_file, source_vocab_size, "source.spm", "source"),
    (target_file, target_vocab_size, "target.spm", "target"),
):
    generate_sp_model(
        corpus_path,
        vocab_size=sp_vocab_size,
        model_type="unigram",
        model_prefix=prefix,
    )
    print(f"Done! Training SentencePiece {side} model completed.")

Done! Training SentencePiece source model completed.
Done! Training SentencePiece target model completed.


In [None]:
# Load the trained SentencePiece models for both languages
source_sp_model, target_sp_model = (
    load_sp_model(f"{side}.spm.model") for side in ("source", "target")
)

In [None]:
# Tokenize data and split it into train, dev, test
def prepare(source_file, target_file, dev_size, test_size, source_sp_model, target_sp_model, lower=False):
    """Read a line-aligned parallel corpus, tokenize it, and split it.

    The split follows file order (no shuffling):
    - the first ``dev_size`` pairs become the dev set;
    - the next ``test_size`` pairs become the test set;
    - all remaining pairs become the training set.

    Args:
        source_file, target_file: UTF-8 text files with one sentence per line.
        dev_size, test_size: number of pairs for the dev and test splits.
        source_sp_model, target_sp_model: SentencePiece models used to tokenize each side.
        lower: lower-case every line before tokenizing.

    Returns:
        (train_dataset, dev_dataset, test_dataset): lists of
        (source_tokens, target_tokens) pairs.

    Raises:
        ValueError: if the files have different line counts, a split size is
            negative, or dev_size + test_size exceeds the number of lines.
    """
    with open(source_file, encoding="utf-8") as source, open(target_file, encoding="utf-8") as target:
        source_lines = [line.strip() for line in source]
        target_lines = [line.strip() for line in target]

    if lower:
        source_lines = [line.lower() for line in source_lines]
        target_lines = [line.lower() for line in target_lines]

    if len(source_lines) != len(target_lines):
        raise ValueError("Length of source and target lines must be the same!")
    data_size = len(source_lines)

    if dev_size < 0 or test_size < 0:
        raise ValueError("dev_size and test_size must be non-negative")
    train_size = data_size - (dev_size + test_size)
    if train_size < 0:
        raise ValueError(
            f"dev_size + test_size ({dev_size + test_size}) exceeds the corpus size ({data_size})"
        )

    # Tokenizers yield one token list per line, lazily
    source_lines_tok = sentencepiece_tokenizer(source_sp_model)(source_lines)
    target_lines_tok = sentencepiece_tokenizer(target_sp_model)(target_lines)

    # A single zip iterator: each islice continues where the previous one stopped
    data_lines = zip(source_lines_tok, target_lines_tok)
    dev_dataset = list(islice(data_lines, dev_size))
    test_dataset = list(islice(data_lines, test_size))
    train_dataset = list(islice(data_lines, train_size))

    return train_dataset, dev_dataset, test_dataset

train_dataset, dev_dataset, test_dataset = prepare(
    source_file,
    target_file,
    dev_size,
    test_size,
    source_sp_model,
    target_sp_model,
    lower=lower,
)

# Report the three split sizes on one line, separated by " | "
print(
    f"Training dataset size: {len(train_dataset)}",
    f"Dev dataset size: {len(dev_dataset)}",
    f"Test dataset size: {len(test_dataset)}",
    sep=" | ",
)

Training dataset size: 24714 | Dev dataset size: 1000 | Test dataset size: 1000


In [None]:
# Peek at the first five tokenized test pairs: all sources first, then all targets
source, target = zip(*test_dataset)
for side in (source, target):
    for tokens in side[:5]:
        print(tokens)

['▁', 'i', '▁hate', '▁this', '▁weather', '.']
['▁this', '▁movie', '▁is', '▁worth', '▁see', 'ing', '▁again', '.']
['▁who', '▁teach', 'es', '▁you', '▁f', 're', 'n', 'ch', '?']
['▁', 'i', '▁fel', 't', '▁sad', '.']
['▁we', '▁have', '▁to', '▁help', '.']
['▁كم', '▁أكره', '▁هذا', '▁الطقس', '!']
['▁هذا', '▁الفلم', '▁يستحق', '▁ال', 'مشاهدة', '▁لم', 'رة', '▁ثانية', '.']
['▁من', '▁يدر', 'ّ', 'سك', '▁الفرنسية', '؟']
['▁شعر', 'ت', '▁بال', 'حزن', '.']
['▁علينا', '▁أن', '▁ن', 'ساعد']


In [None]:
# Build Vocabulary

source_train_tok, target_train_tok = zip(*train_dataset)
print("First source sentence:", source_train_tok[0])
print("First target sentence:", target_train_tok[0])

# For shared vocabulary, combine the source and target, and build vocab only once
# shared_train_tok = source_train_tok + target_train_tok

# Special tokens, listed in the order they are passed to torchtext
SPECIAL_TOKENS = ["<unk>", "<pad>", "<s>", "</s>", "<en>", "<ar>"]


def _build_vocab(tokenized_sentences):
    """Build a torchtext vocab (min_freq=2, at most max_vocab_size tokens) that maps unknowns to <unk>."""
    vocab = build_vocab_from_iterator(
        tokenized_sentences,
        specials=list(SPECIAL_TOKENS),
        min_freq=2,
        max_tokens=max_vocab_size,
    )
    vocab.set_default_index(vocab["<unk>"])
    return vocab


source_vocab = _build_vocab(source_train_tok)
print("Source vocab Size:", len(source_vocab))

target_vocab = _build_vocab(target_train_tok)
print("Target vocab Size:", len(target_vocab))

First source sentence: ['▁that', '▁clinic', '▁still', '▁exist', 's', '.']
First target sentence: ['▁تلك', '▁العيادة', '▁لا', '▁ت', 'زال', '▁موجودة', '.']
Source vocab Size: 2075
Target vocab Size: 4665


In [None]:
# Map a sample token sequence (language tag ... end token) to source-vocab indices
example_tokens = [src_code, "▁here", "▁is", "▁an", "▁example", "</s>"]
print(source_vocab(example_tokens))

[4, 95, 18, 87, 1637, 3]


In [None]:
# Index of the <pad> token in the source vocabulary
source_vocab.lookup_indices(["<pad>"])[0]

1

In [None]:
# Save both vocabularies (optional)
for vocab_name, vocab_obj in (("source", source_vocab), ("target", target_vocab)):
    torch.save(vocab_obj, f"{vocab_name}_vocab.pth")

# How to load later
# vocab_obj = torch.load('source_vocab.pth')
# vocab_obj["<pad>"]

In [None]:
# List mapping indices to tokens
# itos = source_vocab.get_itos()

# Dictionary mapping tokens to indices
# stoi = source_vocab.get_stoi()

In [None]:
import os

os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # number GPUs by PCI bus id
        "CUDA_VISIBLE_DEVICES": "0",  # expose only GPU 0 to this process
        "CUDA_LAUNCH_BLOCKING": "1",  # synchronous launches, for debugging CUDA errors
    }
)

os.environ["CUDA_VISIBLE_DEVICES"]

'0'

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128  # examples
# Used to pad both sides; both vocabs are built from the same specials list,
# so <pad> is expected to share its index across them
pad_idx = target_vocab["<pad>"]
max_len = 100  # tokens in source or target


def collate_batch(batch, max_tokens=max_len):
    """Turn a list of (source_tokens, target_tokens) pairs into padded index tensors.

    A pair is kept only when each side has more than 1 and fewer than
    ``max_tokens`` subword tokens. Kept sides are wrapped as
    ``[lang_tag] + tokens + [eos_code]``, mapped to vocabulary indices, and
    padded with ``pad_idx``.

    Args:
        batch: iterable of (source_tokens, target_tokens) list pairs.
        max_tokens: exclusive upper bound on tokens per side, before the two
            special tokens are added. Bound when the function is defined, so
            later reassignment of the global ``max_len`` cannot change filtering.

    Returns:
        (sources, targets): int64 tensors of shape (padded_len, kept_pairs) on ``device``.
    """
    sources, targets = [], []
    for source, target in batch:
        if not (1 < len(source) < max_tokens and 1 < len(target) < max_tokens):
            continue

        source_idx = source_vocab([src_code] + source + [eos_code])
        target_idx = target_vocab([tgt_code] + target + [eos_code])

        sources.append(torch.tensor(source_idx, dtype=torch.int64))
        targets.append(torch.tensor(target_idx, dtype=torch.int64))

    sources = pad_sequence(sources, padding_value=pad_idx).to(device)
    targets = pad_sequence(targets, padding_value=pad_idx).to(device)

    return sources, targets

In [None]:
# One way without a bucket iterator
# train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
# valid_dataloader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
# test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)

# Test
# for x, y in test_dataloader:
#    print(x.shape, y.shape)

In [None]:
# Alternatively, add batch_sampler to act as a bucket iterator
# It batches examples of similar lengths together.
# Minimizes amount of padding needed while producing freshly shuffled batches for each new epoch.
# https://colab.research.google.com/drive/1Zg7Csa4NJ1APg5JGR0BakjOudvmqxTpt

from random import shuffle


def batch_sampler(dataset, batch_size=None, pool_factor=100):
    """Yield lists of dataset indices so that each batch holds pairs of similar length.

    Steps:
    1. Shuffle the indices.
    2. Cut them into pools of ``batch_size * pool_factor``.
    3. Sort each pool by (source length, target length).
    4. Chunk the pools into batches of ``batch_size`` indices.

    This is a generator, so a DataLoader built on it can be iterated only
    once. Build a fresh one (with a fresh sampler) per pass.

    Args:
        dataset: sequence of (source_tokens, target_tokens) pairs.
        batch_size: indices per batch; defaults to the notebook-level ``batch_size``.
        pool_factor: number of batches per length-sorted pool.

    Yields:
        list[int]: indices into ``dataset`` for one batch.
    """
    if batch_size is None:
        # Parameter shadows the notebook global of the same name; read it explicitly
        batch_size = globals()["batch_size"]
    pool_size = batch_size * pool_factor

    # Key on token *counts*. Comparing the token lists themselves would order
    # pairs lexicographically rather than by length.
    keyed_indices = [(i, (len(src), len(tgt))) for i, (src, tgt) in enumerate(dataset)]
    shuffle(keyed_indices)

    pooled_indices = []
    for start in range(0, len(keyed_indices), pool_size):
        pool = sorted(keyed_indices[start:start + pool_size], key=lambda item: item[1])
        pooled_indices.extend(idx for idx, _ in pool)

    for start in range(0, len(pooled_indices), batch_size):
        yield pooled_indices[start:start + batch_size]


def _bucketed_loader(split):
    """Wrap a dataset split in a DataLoader driven by the length-bucketing sampler."""
    return DataLoader(split, batch_sampler=batch_sampler(split), collate_fn=collate_batch)


train_dataloader = _bucketed_loader(train_dataset)
valid_dataloader = _bucketed_loader(dev_dataset)
test_dataloader = _bucketed_loader(test_dataset)

In [None]:
# Sanity check: print (source, target) batch shapes for the test split
for src_batch, tgt_batch in test_dataloader:
    print(src_batch.shape, tgt_batch.shape)

torch.Size([55, 127]) torch.Size([46, 127])
torch.Size([62, 126]) torch.Size([71, 126])
torch.Size([43, 128]) torch.Size([34, 128])
torch.Size([59, 128]) torch.Size([40, 128])
torch.Size([35, 128]) torch.Size([44, 128])
torch.Size([60, 128]) torch.Size([48, 128])
torch.Size([36, 128]) torch.Size([40, 128])
torch.Size([83, 104]) torch.Size([55, 104])


# Transformer

In [None]:
# Info: https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html

class Transformer(nn.Module):
    """Sequence-to-sequence model wrapping ``nn.Transformer``.

    Adds learned word and positional embeddings for source and target, and a
    linear projection of decoder outputs onto the target vocabulary.
    Inputs are sequence-first index tensors:
    - ``src`` has shape (src_len, N);
    - ``tgt`` has shape (tgt_len, N).

    Args:
        embedding_size: model dimension (d_model).
        src_vocab_size, tgt_vocab_size: vocabulary sizes.
        src_pad_idx: source padding index. Padded positions are masked in
            encoder self-attention and in decoder cross-attention.
        num_heads, num_encoder_layers, num_decoder_layers: nn.Transformer sizes.
        forward_expantion: feed-forward hidden size (dim_feedforward).
        dropout: dropout rate for the embeddings and inside nn.Transformer.
        max_len: number of learned positions; longer inputs raise ValueError.
        device: device on which masks and position indices are created.
        norm_first: if True, use pre-norm layers (LayerNorm before attention/FFN).
    """

    def __init__(
        self,
        embedding_size,
        src_vocab_size,
        tgt_vocab_size,
        src_pad_idx,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        forward_expantion,
        dropout,
        max_len,
        device,
        norm_first=True
    ):
        super(Transformer, self).__init__()

        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)
        self.src_positional_embedding = nn.Embedding(max_len, embedding_size)

        self.tgt_word_embedding = nn.Embedding(tgt_vocab_size, embedding_size)
        self.tgt_positional_embedding = nn.Embedding(max_len, embedding_size)

        self.device = device
        self.src_pad_idx = src_pad_idx
        self.max_len = max_len
        self.dropout = nn.Dropout(dropout)

        self.transformer = nn.Transformer(
            embedding_size,
            num_heads,
            num_encoder_layers,
            num_decoder_layers,
            forward_expantion,
            dropout,
            norm_first=norm_first,  # honour the constructor argument
        )

        self.fc_out = nn.Linear(embedding_size, tgt_vocab_size)

    def make_src_mask(self, src):
        """Return a (N, src_len) bool mask that is True at source padding positions.

        The shape matches what nn.Transformer expects for
        ``src_key_padding_mask`` / ``memory_key_padding_mask``.
        """
        src_mask = src.transpose(0, 1) == self.src_pad_idx
        return src_mask.to(self.device)

    def _positions(self, seq_length, batch):
        """Position indices 0..seq_length-1, broadcast to shape (seq_length, batch)."""
        return (
            torch.arange(seq_length, device=self.device)
            .unsqueeze(1)
            .expand(seq_length, batch)
        )

    def forward(self, src, tgt):
        """Return logits of shape (tgt_len, N, tgt_vocab_size)."""
        src_seq_length, src_batch = src.shape
        tgt_seq_length, tgt_batch = tgt.shape
        if src_seq_length > self.max_len or tgt_seq_length > self.max_len:
            raise ValueError(
                f"sequence length (src={src_seq_length}, tgt={tgt_seq_length}) "
                f"exceeds max_len={self.max_len} positional embeddings"
            )

        src_embedding = self.dropout(
            self.src_word_embedding(src)
            + self.src_positional_embedding(self._positions(src_seq_length, src_batch))
        )
        tgt_embedding = self.dropout(
            self.tgt_word_embedding(tgt)
            + self.tgt_positional_embedding(self._positions(tgt_seq_length, tgt_batch))
        )

        src_padding_mask = self.make_src_mask(src)
        # Causal mask: decoder position i may only attend to positions <= i
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_seq_length).to(self.device)

        out = self.transformer(
            src_embedding,
            tgt_embedding,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            # Without this, decoder cross-attention also attends to source padding
            memory_key_padding_mask=src_padding_mask,
        )

        return self.fc_out(out)

# Helper Functions

In [None]:
import torch
from torchtext.data.functional import load_sp_model, sentencepiece_tokenizer
from torchtext.data.metrics import bleu_score
import sys
from random import random


def translate(text, model, sp_model, source_vocab, target_vocab, device, max_length=100, lower=True):
    """Greedy-decode a translation of the first sentence in ``text``.

    Args:
        text: list of source sentences (only the first is translated), or a single string.
        model: seq2seq model taking (src, tgt) index tensors of shape (len, 1).
        sp_model: path to the source SentencePiece model file.
        source_vocab, target_vocab: torchtext vocabularies.
        device: device for the input tensors.
        max_length: maximum number of tokens to generate.
        lower: lower-case the input first, matching ``prepare(..., lower=True)``.

    Returns:
        str: the detokenized translation, without the start and end tokens.
    """
    if isinstance(text, str):
        text = [text]  # a bare string would otherwise be tokenized character by character
    if lower:
        text = [sentence.lower() for sentence in text]

    # Tokenize the text with the source SentencePiece model
    sp_tokens_generator = sentencepiece_tokenizer(load_sp_model(sp_model))
    # [To-Do] Adjust to translate multiple sentences
    tokenized_text = next(sp_tokens_generator(text))
    print("• Source text:", tokenized_text, sep="\n")

    # Same framing as collate_batch: language tag + tokens + end token
    text_to_indices = source_vocab([src_code] + tokenized_text + [eos_code])
    sentence_tensor = torch.tensor(text_to_indices, dtype=torch.long, device=device).unsqueeze(1)

    model.eval()

    eos_idx = target_vocab[eos_code]
    # Decoding starts from the *target* language tag, as training targets do in collate_batch
    outputs = target_vocab([tgt_code])

    for _ in range(max_length):
        trg_tensor = torch.tensor(outputs, dtype=torch.long, device=device).unsqueeze(1)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        # Greedy choice for the last decoder position
        best_guess = output.argmax(2)[-1, :].item()
        outputs.append(best_guess)

        if best_guess == eos_idx:
            break

    target_vocab_itos = target_vocab.get_itos()
    # Drop the start token
    translated_tokens = [target_vocab_itos[idx] for idx in outputs[1:]]
    # Drop the end token, if generation stopped on it
    if translated_tokens and translated_tokens[-1] == eos_code:
        translated_tokens = translated_tokens[:-1]
    # "▁" marks the start of a word in SentencePiece pieces
    return "".join(translated_tokens).replace("▁", " ")


def _detokenize(pieces):
    """Join SentencePiece pieces back into text ("▁" starts a word); strings pass through unchanged."""
    if isinstance(pieces, str):
        return pieces
    return "".join(pieces).replace("▁", " ")


def bleu(data_iter, model, tokenizer, source_vocab, target_vocab, device):
    """Corpus BLEU of greedy translations over (source, target) pairs.

    Each side may be raw text or a list of SentencePiece pieces, as produced
    by ``prepare``.
    - Sources are detokenized back to text for ``translate``, which
      re-tokenizes them with the SentencePiece model at path ``tokenizer``.
    - Predictions and references are compared as whitespace-separated words.

    Returns:
        torchtext's ``bleu_score`` for the corpus.
    """
    candidates = []
    references = []

    for source, target in data_iter:
        source_text = _detokenize(source).strip()
        prediction = translate([source_text], model, tokenizer, source_vocab, target_vocab, device)

        # translate() already removes the end token, so nothing is trimmed here
        candidates.append(prediction.split())
        # bleu_score expects a list of reference token lists per candidate
        references.append([_detokenize(target).split()])

    return bleu_score(candidates, references)


def save_checkpoint(model, optimizer, epoch, mean_valid_loss):
    """Write model/optimizer state plus metadata to ``model_checkpoint.<epoch>.tar``.

    Returns:
        str: the path of the written checkpoint.
    """
    path = f"model_checkpoint.{epoch}.tar"
    torch.save(
        dict(
            state_dict=model.state_dict(),
            opt=optimizer.state_dict(),
            epoch=epoch,
            valid_loss=mean_valid_loss,
            encoder_type="transformer",
        ),
        path,
    )
    return path


def load_checkpoint(checkpoint_path, model, optimizer):
  """Restore model and optimizer state from a checkpoint written by ``save_checkpoint``.

  Args:
    checkpoint_path: path to a checkpoint file, or an already-loaded checkpoint
      dict (accepted so callers that already ran ``torch.load`` keep working).
    model: module whose weights are replaced in place.
    optimizer: optimizer whose state is replaced in place.

  Returns:
    The same ``model``.
  """
  print("=> Loading checkpoint")
  if isinstance(checkpoint_path, dict):
    checkpoint = checkpoint_path
  else:
    # Load tensors onto the model's device so GPU checkpoints also load on CPU-only machines
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else None
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
  model.load_state_dict(checkpoint["state_dict"])
  optimizer.load_state_dict(checkpoint["opt"])

  return model


def load_checkpoint_for_inference(checkpoint_path, model):
  """Load model weights from a ``save_checkpoint`` file and report their device.

  Args:
    checkpoint_path: path to the checkpoint file.
    model: module whose weights are replaced in place.

  Returns:
    The same ``model``.
  """
  # Load tensors onto the model's device so GPU checkpoints also load on CPU-only machines
  first_param = next(model.parameters(), None)
  map_location = first_param.device if first_param is not None else None
  checkpoint = torch.load(checkpoint_path, map_location=map_location)
  model.load_state_dict(checkpoint['state_dict'])
  device_name = "GPU" if next(model.parameters()).is_cuda else "CPU"
  print("Model checkpoint loaded to %s" %device_name)

  return model


# Reference: https://stackoverflow.com/a/73704579/7614380
class EarlyStopper:
  """Decide when to stop training based on a validation loss history.

  A loss below the best seen so far becomes the new best and resets the
  counter. A loss above ``best + min_delta`` increments the counter. Once the
  counter reaches ``patience``, ``early_stop`` returns True.
  """

  def __init__(self, patience=1, min_delta=0):
    self.patience = patience
    self.min_delta = min_delta
    self.counter = 0
    self.min_validation_loss = float("inf")

  def early_stop(self, validation_loss):
    """Record ``validation_loss`` and return True if training should stop."""
    if validation_loss < self.min_validation_loss:
      self.min_validation_loss, self.counter = validation_loss, 0
      return False
    # Written as `not (a > b)` so NaN losses neither count nor reset
    if not (validation_loss > self.min_validation_loss + self.min_delta):
      return False
    self.counter += 1
    return self.counter >= self.patience

In [None]:
# Validate the model

def validate(valid_dataloader):
  """Return the mean per-batch loss of ``model`` on ``dev_dataset``.

  The ``valid_dataloader`` argument is not used. A fresh DataLoader is built
  on every call because ``batch_sampler`` is a one-shot generator, so a
  loader built on it can only be iterated once.

  Relies on these notebook globals:
  - ``model``, ``criterion``, ``device``;
  - ``dev_dataset``, ``collate_batch``, ``batch_sampler``.

  Leaves the model in eval mode.
  """
  loader = DataLoader(dev_dataset, collate_fn=collate_batch, batch_sampler=batch_sampler(dev_dataset))

  model.eval()  # disable dropout for evaluation

  with torch.no_grad():
    batch_losses = []

    for source_batch, target_batch in loader:
      source = source_batch.to(device)
      target = target_batch.to(device)

      # Teacher forcing: feed every target token but the last...
      logits = model(source, target[:-1, :])
      # ...and score against every target token but the first (start token)
      flat_logits = logits.reshape(-1, logits.shape[2])
      flat_target = target[1:].reshape(-1)

      batch_losses.append(criterion(flat_logits, flat_target).item())

    mean_valid_loss = sum(batch_losses) / len(batch_losses)

  return mean_valid_loss

# Training Setup

In [None]:
load_model = False
checkpoint_to_load = "my_checkpoint.pth.tar"  # only used when load_model is True

# Training Hyperparameters
num_epochs = 1000
learning_rate = 3e-4
batch_size = 128  # examples - make sure it is the same as in data preparation
early_stopping_epochs = 4  # stop training if the validation loss has not improved for n epochs

# Model Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
src_vocab_size = len(source_vocab)
tgt_vocab_size = len(target_vocab)
embedding_size = 512
num_heads = 8
num_encoder_layers = 3  # 6
num_decoder_layers = 3  # 6
dropout = 0.1
# Number of learned positional embeddings. Deliberately not named `max_len`:
# collate_batch filters pairs using `max_len`, so reusing that name here could
# silently change which pairs are kept.
max_position_len = 200
forward_expansion = 2048
norm_first = True
src_pad_idx = source_vocab["<pad>"]


# Tensorboard
writer = SummaryWriter("runs/loss_plot")
step = 0

model = Transformer(
  embedding_size,
  src_vocab_size,
  tgt_vocab_size,
  src_pad_idx,
  num_heads,
  num_encoder_layers,
  num_decoder_layers,
  forward_expansion,
  dropout,
  max_position_len,
  device,
  norm_first=norm_first,
).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
  optimizer, factor=0.1, patience=10, verbose=True
)


# Padding positions in the target do not contribute to the loss
tgt_pad_idx = target_vocab["<pad>"]
criterion = nn.CrossEntropyLoss(ignore_index=tgt_pad_idx)

if load_model:
  # load_checkpoint expects a path; it calls torch.load itself
  load_checkpoint(checkpoint_to_load, model, optimizer)

# Training Loop

In [None]:
# Example sentence translated after every epoch, to eyeball progress
sentence = ["where is your school?"]

# Lowest validation loss so far; a checkpoint is saved whenever it improves
best_model_loss = float("inf")

# Signals a stop after `early_stopping_epochs` epochs worse than the best validation loss
early_stopper = EarlyStopper(patience=early_stopping_epochs, min_delta=0)

for epoch in range(num_epochs):
  print(f"Epoch [{epoch} / {num_epochs}]")

  # batch_sampler is a one-shot generator, so a fresh DataLoader (and a fresh
  # shuffle) is built every epoch
  train_dataloader = DataLoader(train_dataset, collate_fn=collate_batch, batch_sampler=batch_sampler(train_dataset))

  # important if model.eval() was called earlier, as in translate() and validate()
  model.train()

  train_losses = []

  for b, (source_batch, target_batch) in enumerate(train_dataloader):
    source = source_batch.to(device)
    target = target_batch.to(device)

    # Forward propagation (teacher forcing): the decoder sees every target token except the last
    output = model(source, target[:-1, :])
    # output shape: (target_len - 1, batch_size, tgt_vocab_size)

    # Exclude the start token: output position t is scored against target token t+1
    # Reshape to match the accepted input form of CrossEntropyLoss:
    # keep the last dimension (size tgt_vocab_size, one logit per token)...
    # and flatten the first two dimensions
    output = output.reshape(-1, output.shape[2])
    target = target[1:].reshape(-1)

    optimizer.zero_grad()
    # Padding positions are ignored via the criterion's ignore_index
    loss = criterion(output, target)
    train_losses.append(loss.item())

    # Back propagation
    loss.backward()

    # Clip to avoid exploding gradients, makes sure grads are within a healthy range
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

    # Gradient descent step
    optimizer.step()

    writer.add_scalar("Training Loss", loss, global_step=step)
    step += 1

  mean_train_loss = sum(train_losses) / len(train_losses)
  # NOTE(review): the LR scheduler is stepped with the *training* loss.
  # ReduceLROnPlateau is more commonly driven by the validation loss. Its
  # patience (10) also exceeds early_stopping_epochs (4), so early stopping
  # may trigger before any LR reduction.
  scheduler.step(mean_train_loss)


  ### Validate the model
  mean_valid_loss = validate(valid_dataloader)

  # Save the best checkpoint
  if mean_valid_loss < best_model_loss:
    best_model_loss = mean_valid_loss
    checkpoint_path = save_checkpoint(model, optimizer, epoch, mean_valid_loss)
    print("Model checkpoint saved at:", checkpoint_path)
  else:
    print("Validation loss did not improve from the best model!")

  # Translate the example sentence (leaves the model in eval mode; reset at the top of the next epoch)
  translated_sentence = translate(sentence, model, "source.spm.model", source_vocab, target_vocab, device, max_length=100)
  print("Translated text:", translated_sentence, sep="\n")

  print("Training Loss: %s" %round(mean_train_loss, 2), "Validation Loss: %s" %round(mean_valid_loss, 2), sep="\t")
  print("---"*10)

  # Early Stopping, if the validation loss is not improving
  if early_stopper.early_stop(mean_valid_loss):
    print("Early stopping. The best model has been saved at:", checkpoint_path)
    break


Epoch [0 / 1000]
Saving checkpoint...
Saving checkpoint at: model_checkpoint.0.tar
• Source text:
['▁where', '▁is', '▁your', '▁school', '?']
Translated text:
 هل هل هل هل هل هل هل هل هل هل هل؟
Training Loss: 5.83	Validation Loss: 5.46
------------------------------
Epoch [1 / 1000]
Saving checkpoint...
Saving checkpoint at: model_checkpoint.1.tar
• Source text:
['▁where', '▁is', '▁your', '▁school', '?']
Translated text:
 هل هل؟
Training Loss: 5.25	Validation Loss: 5.05
------------------------------
Epoch [2 / 1000]
Saving checkpoint...
Saving checkpoint at: model_checkpoint.2.tar
• Source text:
['▁where', '▁is', '▁your', '▁school', '?']
Translated text:
 هل تُِِكَكَكَكَكََكَكَِكََكَََِكَكََََََِِكََََِكََََََََََََََََََََََََََََََََََََََََََََََََََََِ
Training Loss: 4.79	Validation Loss: 4.84
------------------------------
Epoch [3 / 1000]
Saving checkpoint...
Saving checkpoint at: model_checkpoint.3.tar
• Source text:
['▁where', '▁is', '▁your', '▁school', '?']
Translated text:
كت

---

# Translation

In [None]:
# Reload the best checkpoint saved by the training loop (weights are loaded into `model` and it is returned)
loaded_model = load_checkpoint_for_inference(checkpoint_path=checkpoint_path, model=model)

Model checkpoint loaded to GPU


In [None]:
sentence = ["where is your school?"]
translate(
    text=sentence,
    model=loaded_model,
    sp_model="source.spm.model",
    source_vocab=source_vocab,
    target_vocab=target_vocab,
    device=device,
    max_length=100,
)

• Source text:
['▁where', '▁is', '▁your', '▁school', '?']


' أين مدرستك؟'

### To-do

* [DONE] Prenorm in nn.Transformer (norm_first=True)
* Arrange sentences by length in one batch (<a href="https://github.com/pytorch/text/blob/master/examples/legacy_tutorial/migration_tutorial.ipynb">example</a>)
* Beam Search at inference time
* Try a random train/test split instead of a fixed one, e.g. `train_dataset, test_dataset = torch.utils.data.random_split(dataset, (train_length, test_length))`
* Add (more) validation data