## Accelerated Training of Sequence-to-Sequence Transformer Model
This notebook demonstrates using [ONNX Runtime](https://cloudblogs.microsoft.com/opensource/2020/05/19/announcing-support-for-accelerated-training-with-onnx-runtime/) to accelerate the training of a simple sequence-to-sequence model. It uses a slightly modified version of the implementation in the [PyTorch tutorial](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) on the [nn.Transformer](https://pytorch.org/docs/master/nn.html?highlight=nn%20transformer#torch.nn.Transformer) module. This example uses the [Adam](https://arxiv.org/abs/1412.6980) optimizer instead of the [SGD](https://pytorch.org/docs/stable/optim.html#torch.optim.SGD) optimizer used in the original implementation. This notebook is available at [https://github.com/skaarthik/onnxruntime-training-databricks](https://github.com/skaarthik/onnxruntime-training-databricks). Recipes for using ONNX Runtime to accelerate pretraining and fine-tuning of BERT and GPT-2 models are available at [https://github.com/microsoft/onnxruntime-training-examples](https://github.com/microsoft/onnxruntime-training-examples).

####Prerequisites for the demo
Databricks cluster:
* Databricks cluster with `7.3 LTS ML` runtime
* Node with a `V100` GPU (for example, the `Standard_NC6s_v3` node type in Azure)

Custom containers are not supported on ML Runtime in Databricks, and this demo needs both ML Runtime and GPUs. As an alternative to packaging dependencies in a custom container, use the commands below to prepare the Databricks environment for this notebook. This is a one-time process for a given Databricks cluster.

Run the following commands to install the required packages into the Conda environment you will use for this demo:
* `pip install torchtext`
* `pip install onnx`
* `pip install cerberus`
* `pip install https://onnxtraining.blob.core.windows.net/ort-databricks-demo/onnxruntime_gpu-1.5.1-cp37-cp37m-linux_x86_64.whl`
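
Before running the rest of the notebook, a quick check that the packages above are importable can save a failed run later. The sketch below is an illustration and not part of the original demo. It assumes the top-level module names `torchtext`, `onnx`, `cerberus`, and `onnxruntime` (the module provided by the `onnxruntime_gpu` wheel), and the helper name `check_packages` is hypothetical.

```python
import importlib.util

# Top-level module names for the packages installed above (assumption:
# the onnxruntime_gpu wheel provides the `onnxruntime` module).
REQUIRED_MODULES = ["torchtext", "onnx", "cerberus", "onnxruntime"]


def check_packages(modules=REQUIRED_MODULES):
    """Hypothetical helper: map each module name to whether it can be imported."""
    # find_spec returns None, without importing anything, when a top-level
    # module cannot be found on the current Python path.
    return {name: importlib.util.find_spec(name) is not None for name in modules}


for name, found in check_packages().items():
    print(f"{name:12s} {'OK' if found else 'MISSING'}")
```

Only top-level names are checked, because looking up a submodule such as `onnxruntime.training` fails with an exception when the parent package is missing.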

####Define the model
Refer to https://pytorch.org/tutorials/beginner/transformer_tutorial.html#define-the-model for details.

In [4]:
import math
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, src):
        sz = src.size(0)
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        if self.src_mask is None or self.src_mask.size(0) != src.size(0):
            device = src.device
            mask = self._generate_square_subsequent_mask(src).to(device)
            self.src_mask = mask

        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output
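
For reference, the two less obvious pieces of the model cell can be written as equations, with d = `d_model` (i.e. `ninp`), pos the token position, and i indexing pairs of embedding dimensions. `div_term` equals 10000^(-2i/d), so `PositionalEncoding` fills the even and odd columns of `pe` as shown below, and `TransformerModel.forward` adds that encoding to the embeddings after scaling them by sqrt(d) (dropout is then applied). `_generate_square_subsequent_mask` builds an additive mask M that lets position i attend only to positions j ≤ i.

```latex
\mathrm{div}_i = \exp\!\left(-\frac{2i \ln 10000}{d}\right) = 10000^{-2i/d}

PE_{(pos,\,2i)} = \sin\!\left(pos \cdot \mathrm{div}_i\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(pos \cdot \mathrm{div}_i\right)

x_{pos} = \sqrt{d}\; E[\mathrm{token}_{pos}] + PE_{pos}

M_{ij} = \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases}
```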

####Load and batch data
Refer to https://pytorch.org/tutorials/beginner/transformer_tutorial.html#load-and-batch-data for details.

In [6]:
import torchtext
from torchtext.data.utils import get_tokenizer
TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
                            init_token='<sos>',
                            eos_token='<eos>',
                            lower=True)
train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
TEXT.build_vocab(train_txt)
device = torch.device("cuda")

####Functions to generate input and target sequence
Refer to https://pytorch.org/tutorials/beginner/transformer_tutorial.html#functions-to-generate-input-and-target-sequence for details.

In [8]:
bptt = 35
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target
  
def batchify(data, bsz):
    data = TEXT.numericalize([data.examples[0].text])
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_txt, batch_size)
val_data = batchify(val_txt, eval_batch_size)
test_data = batchify(test_txt, eval_batch_size)  
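
To see how `batchify` and `get_batch` lay out the data, the sketch below reimplements both on a plain Python list of token ids instead of a tensor. The names `batchify_ids` and `get_batch_ids` are hypothetical and only mirror the logic of the cell above: each column holds one contiguous chunk of the corpus, and the target is the input shifted down by one row, flattened row by row as `.view(-1)` does.

```python
def batchify_ids(ids, bsz):
    """Hypothetical helper: list version of batchify (rows = time steps, columns = streams)."""
    nbatch = len(ids) // bsz
    ids = ids[: nbatch * bsz]  # trim remainders that don't fill a full column
    # Column j holds the j-th contiguous chunk of length nbatch, like .view(bsz, -1).t()
    return [[ids[j * nbatch + r] for j in range(bsz)] for r in range(nbatch)]


def get_batch_ids(source, i, bptt):
    """Hypothetical helper: list version of get_batch."""
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]
    # Targets are the next rows, flattened row-major like .view(-1)
    target = [tok for row in source[i + 1:i + 1 + seq_len] for tok in row]
    return data, target


source = batchify_ids(list(range(26)), 4)  # 6 rows x 4 columns; ids 24 and 25 are trimmed
data, target = get_batch_ids(source, 0, bptt=2)
print(data)    # [[0, 6, 12, 18], [1, 7, 13, 19]]
print(target)  # [1, 7, 13, 19, 2, 8, 14, 20]
```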

####Initialize variables needed for model creation
Refer to https://pytorch.org/tutorials/beginner/transformer_tutorial.html#initiate-an-instance for details.

In [10]:
ntokens = len(TEXT.vocab.stoi) # the size of vocabulary
emsize = 200 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
dropout = 0.2 # the dropout value
lr = 0.001 # learning rate

def calculate_loss(output, targets):
    output = output.view(-1, len(TEXT.vocab.stoi))
    return criterion(output, targets) 
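
For a rough sense of model size, the parameter count of `TransformerModel` can be derived from the layer definitions (the positional encoding is a registered buffer, not a parameter). The sketch below is a back-of-the-envelope calculation that assumes the standard parameterization of `nn.Embedding`, `nn.Linear`, `nn.MultiheadAttention` (a packed Q/K/V input projection plus an output projection, all with biases), and the two `LayerNorm`s in each `TransformerEncoderLayer`. `count_parameters` is a hypothetical helper; you can compare its result against `sum(p.numel() for p in model.parameters())`.

```python
def count_parameters(ntokens, emsize, nhid, nlayers):
    """Hypothetical helper: closed-form parameter count for TransformerModel."""
    embedding = ntokens * emsize                     # nn.Embedding(ntoken, ninp)
    decoder = emsize * ntokens + ntokens             # nn.Linear(ninp, ntoken): weight + bias
    # One TransformerEncoderLayer:
    attention = 4 * emsize * emsize + 4 * emsize     # packed Q/K/V projection + output projection
    feedforward = 2 * emsize * nhid + nhid + emsize  # linear1 and linear2 with biases
    layer_norms = 2 * (2 * emsize)                   # norm1 and norm2: weight + bias each
    per_layer = attention + feedforward + layer_norms
    return embedding + decoder + nlayers * per_layer


# With emsize = nhid = 200, each encoder layer has 242,000 parameters and the
# embedding plus decoder add 401 parameters per vocabulary entry.
print(count_parameters(ntokens=10, emsize=200, nhid=200, nlayers=2))  # 488010
```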

####Initiate an instance of the model for training `without ONNX Runtime acceleration` (referred to as `baseline` training in this notebook)
Refer to https://pytorch.org/tutorials/beginner/transformer_tutorial.html#initiate-an-instance for details.

After the next few steps, which are required for `baseline` training, this notebook covers the steps required for accelerated training using ONNX Runtime (referred to as `ort` training).

In [12]:
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
criterion = nn.CrossEntropyLoss()
learning_rate = lr
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
metric_prefix = 'baseline'  
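
This notebook swaps the tutorial's SGD for Adam. For intuition, the sketch below applies one Adam update to a single scalar parameter using the bias-corrected update rule from the Adam paper, with `torch.optim.Adam`'s default `betas=(0.9, 0.999)` and `eps=1e-8`. It is a plain-Python illustration, not how PyTorch implements the optimizer, and `adam_step` is a hypothetical name.

```python
import math


def adam_step(param, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical helper: one Adam update for a scalar; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad * grad   # second-moment (uncentered variance) estimate
    m_hat = m / (1 - beta1 ** t)                # bias correction for zero-initialized moments
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# On the first step m_hat == grad and v_hat == grad**2, so the update is about
# lr * sign(grad), regardless of the gradient's magnitude.
p, m, v = adam_step(param=1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(round(p, 6))  # 0.999
```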

####Define a function for MLflow logging

In [14]:
import mlflow
def log_metrics(epoch, batch, train_data_len, bptt, lr, elapsed, log_interval, cur_loss, log_prefix):
  ms = elapsed * 1000 / log_interval
  ppl = math.exp(cur_loss)
  print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.3f} | ms/batch {:5.2f} | loss {:5.2f} | ppl {:8.2f}'.
        format(epoch, batch, train_data_len // bptt, lr, ms, cur_loss, ppl))
  mlflow.log_metric(log_prefix + '_milliseconds/batch', ms, step=batch)
  mlflow.log_metric(log_prefix + '_loss', cur_loss, step=batch)
  mlflow.log_metric(log_prefix + '_ppl', ppl, step=batch)
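
`log_metrics` derives two of its logged values: milliseconds per batch is the elapsed wall-clock time of the logging window divided by `log_interval`, and perplexity (ppl) is the exponential of the average cross-entropy loss. The sketch below isolates that arithmetic without MLflow so it can be checked directly; `window_metrics` is a hypothetical name.

```python
import math


def window_metrics(elapsed_seconds, log_interval, cur_loss):
    """Hypothetical helper: the ms/batch and ppl values that log_metrics reports."""
    ms_per_batch = elapsed_seconds * 1000 / log_interval
    ppl = math.exp(cur_loss)  # perplexity = e ** (mean cross-entropy in nats)
    return ms_per_batch, ppl


ms, ppl = window_metrics(elapsed_seconds=2.5, log_interval=50, cur_loss=math.log(100))
print(f"{ms:.2f} ms/batch, ppl {ppl:.2f}")  # 50.00 ms/batch, ppl 100.00
```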

####Define `train` and `evaluate` methods
Refer to https://pytorch.org/tutorials/beginner/transformer_tutorial.html#run-the-model for details.

The `train` and `evaluate` methods are shared by the `baseline` and `ort` training approaches. The boolean flag `accelerate_using_ort` selects the code path for each approach.

In [16]:
import time
def train(epoch, accelerate_using_ort):
    if not accelerate_using_ort:
      model.train() # Turn on the train mode
      
    total_loss = 0.
    start_time = time.time()

    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        
        if not accelerate_using_ort:
          optimizer.zero_grad()
          output = model(data)
          loss = calculate_loss(output, targets)
          loss.backward()
          torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
          optimizer.step()
          current_learning_rate = learning_rate
        else:
          loss, output = trainer.train_step(data, targets)
          current_learning_rate = learning_rate

        total_loss += loss.item()
        log_interval = 50

        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time

            log_metrics(epoch, 
                        batch, 
                        len(train_data),
                        bptt, 
                        current_learning_rate,
                        elapsed,
                        log_interval,
                        cur_loss,
                        metric_prefix)
            
            total_loss = 0
            start_time = time.time()           
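
The baseline branch calls `torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)` before `optimizer.step()`, which rescales all gradients together when their combined L2 norm exceeds the threshold. The plain-Python sketch below illustrates that behavior on nested lists. It follows the scaling rule in PyTorch's source (multiply by `max_norm / (total_norm + 1e-6)` only when that factor is below 1) but is not PyTorch's implementation, and `clip_by_global_norm` is a hypothetical name.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Hypothetical helper: scale gradient vectors so their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        # Every gradient is multiplied by the same factor, so directions are preserved.
        grads = [[g * clip_coef for g in vec] for vec in grads]
    return grads, total_norm


clipped, norm = clip_by_global_norm([[3.0], [4.0]], max_norm=0.5)
print(norm)  # 5.0
print([[round(g, 4) for g in vec] for vec in clipped])  # [[0.3], [0.4]]
```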

In [17]:
def evaluate(eval_model, data_source, accelerate_using_ort=False):
    if not accelerate_using_ort:
      eval_model.eval() # Turn on the evaluation mode
      
    total_loss = 0.
    ntokens = len(TEXT.vocab.stoi)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            if not accelerate_using_ort:
              output = eval_model(data)
              output_flat = output.view(-1, ntokens)
              total_loss += len(data) * criterion(output_flat, targets).item()
            else:
              loss, outputs = trainer.eval_step(data, targets)
              total_loss += len(data) * loss.item()              
            
    return total_loss / (len(data_source) - 1)
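
`evaluate` weights each chunk's mean loss by its sequence length, `len(data)`, and divides by `len(data_source) - 1`, which equals the sum of those lengths because `get_batch` covers every row except the last. The result is therefore a token-weighted average, even though the final chunk is usually shorter than `bptt`. The hypothetical helper below reproduces that arithmetic.

```python
def weighted_eval_loss(chunk_lengths, chunk_mean_losses):
    """Hypothetical helper: combine per-chunk mean losses the way evaluate() does."""
    total = sum(n * loss for n, loss in zip(chunk_lengths, chunk_mean_losses))
    # In evaluate(), len(data_source) - 1 equals sum(chunk_lengths).
    return total / sum(chunk_lengths)


# Two full chunks of bptt=35 rows and a short final chunk of 5 rows (76 rows in total).
print(weighted_eval_loss([35, 35, 5], [2.0, 4.0, 6.0]))  # 3.2
```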

####Define `train_model` method to execute the train loop
Refer to https://pytorch.org/tutorials/beginner/transformer_tutorial.html#run-the-model for details.

This method is shared by the `baseline` and `ort` training approaches. Its boolean parameter `accelerate_using_ort` is passed to the `train` and `evaluate` methods to select the code path for each approach.

In [19]:
def train_model(accelerate_using_ort=False):
  best_val_loss = float("inf")
  epochs = 3 # The number of epochs
  best_model = None

  for epoch in range(1, epochs + 1):
      epoch_start_time = time.time()
      
      train(epoch, accelerate_using_ort)
        
      val_loss = evaluate(model, val_data, accelerate_using_ort)
      print('-' * 89)
      print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid ppl {:8.2f}'
            .format(epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss)))
      print('-' * 89)

      if val_loss < best_val_loss:
          best_val_loss = val_loss
          best_model = model  # note: a reference to the live model, not a snapshot of its weights
  return best_model
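
In `train_model`, `best_model = model` binds a second name to the same object, so the returned "best" model carries whatever weights the model has after the last epoch. The toy sketch below uses a hypothetical `TinyModel` class and made-up validation losses to show the difference between keeping a reference and keeping a snapshot with `copy.deepcopy`. For a PyTorch module, `copy.deepcopy(model)` or `copy.deepcopy(model.state_dict())` are common ways to keep the best weights.

```python
import copy


class TinyModel:
    """Hypothetical stand-in for a network: a single trainable weight."""

    def __init__(self, weight):
        self.weight = weight


model = TinyModel(weight=1.0)
val_losses = [3.0, 2.0, 2.5]  # made-up validation losses; epoch 2 is the best

best_val_loss = float("inf")
best_ref, best_snapshot = None, None
for epoch, val_loss in enumerate(val_losses, start=1):
    model.weight = float(epoch)  # pretend that training changed the weights
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_ref = model                       # same object: keeps changing afterwards
        best_snapshot = copy.deepcopy(model)   # independent copy of the current weights

print(best_ref.weight, best_snapshot.weight)  # 3.0 2.0
```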

####Train model `without ONNX Runtime acceleration`

In [21]:
baseline_model = train_model(accelerate_using_ort=False)

####Evaluate the model with the test dataset
The `baseline` model, trained without ONNX Runtime acceleration, is evaluated in the step below.

Refer to https://pytorch.org/tutorials/beginner/transformer_tutorial.html#evaluate-the-model-with-the-test-dataset for details.

In [23]:
test_loss = evaluate(baseline_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(test_loss, math.exp(test_loss)))
print('=' * 89)

###Accelerated model training

####Initiate an instance of the model for training `with ONNX Runtime acceleration` (referred to as `ort` training in this notebook)

To start from scratch, the step below instantiates a new model along with the additional code needed for accelerated training using ONNX Runtime. This ensures that the model created in the earlier section `Initiate an instance of the model for training without ONNX Runtime acceleration (referred to as baseline training in this notebook)` is not reused in the steps below; otherwise, the loss and ppl metrics would not be comparable between the `baseline` and `ort` training approaches shown in this notebook.

Refer to https://pytorch.org/tutorials/beginner/transformer_tutorial.html#initiate-an-instance for details.
Refer to https://github.com/microsoft/onnxruntime-training-examples/tree/master/getting-started for details on changes needed to accelerate a PyTorch implementation using ONNX Runtime (ORT).

In [26]:
from onnxruntime.training import ORTTrainer, optim

model_description = {'inputs':  [('src', ['bptt', 'batch_size']),
                                 ('label', ['bptt_x_batch_size'])],
                     'outputs': [('loss', [], True),
                                 ('output', ['bptt', 'batch_size', ntokens])]}

model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
criterion = nn.CrossEntropyLoss()
learning_rate = lr
optimizer_config = optim.AdamConfig(lr=learning_rate)

trainer = ORTTrainer(model,               # model
                     model_description,   # model description
                     optimizer_config,    # optimizer configuration
                     calculate_loss)      # loss function
metric_prefix = 'onnxruntime'
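
In `model_description`, each input and output is listed as a name and a shape. String entries such as `'bptt'` and `'batch_size'` are symbolic dimensions, `ntokens` is a fixed size, and the third element `True` on `'loss'` marks that output as the loss. Symbolic dimensions are what allow inputs whose sizes vary, such as the last batch, which is shorter than `bptt`. The toy checker below is a hypothetical helper that binds each symbolic name to the first concrete size it sees and reports inconsistencies; it does not reproduce `ORTTrainer`'s own validation.

```python
def bind_symbolic_dims(entries, shapes):
    """Hypothetical helper: check concrete shapes against (name, dims, ...) entries."""
    bindings = {}
    for entry, shape in zip(entries, shapes):
        name, dims = entry[0], entry[1]  # ignore the optional is-loss flag
        if len(dims) != len(shape):
            raise ValueError(f"{name}: expected rank {len(dims)}, got {len(shape)}")
        for dim, size in zip(dims, shape):
            if isinstance(dim, str):
                # Symbolic dimension: bind on first use, then require the same size.
                if bindings.setdefault(dim, size) != size:
                    raise ValueError(f"{name}: '{dim}' is {bindings[dim]}, got {size}")
            elif dim != size:
                raise ValueError(f"{name}: fixed dimension {dim}, got {size}")
    return bindings


inputs = [('src', ['bptt', 'batch_size']), ('label', ['bptt_x_batch_size'])]
print(bind_symbolic_dims(inputs, [(35, 20), (700,)]))
# {'bptt': 35, 'batch_size': 20, 'bptt_x_batch_size': 700}
```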

####Train model `with ONNX Runtime acceleration`

In [28]:
ort_model = train_model(accelerate_using_ort=True)

####Evaluate the model with the test dataset
The `ort` model, trained with ONNX Runtime acceleration, is evaluated in the step below.

Refer to https://pytorch.org/tutorials/beginner/transformer_tutorial.html#evaluate-the-model-with-the-test-dataset for details.

In [30]:
test_loss = evaluate(ort_model, test_data, accelerate_using_ort=True)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(test_loss, math.exp(test_loss)))
print('=' * 89)

###Summary
This notebook shows how to accelerate training of a `Transformer` model using `ONNX Runtime`. In this demo, the compute and memory-utilization optimizations in `ONNX Runtime` **reduced the training time** of a simple sequence-to-sequence model by **~40%** (both per epoch and in total) without any change to hyperparameters and without affecting the training and test loss and ppl metrics.
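
To make the same comparison from your own runs, the relative reduction can be computed from the epoch times printed by `train_model` (or from the `_milliseconds/batch` metrics logged to MLflow). The helper below and its input values are hypothetical and only show the arithmetic; a ~40% reduction in time corresponds to roughly a 1.67x speedup.

```python
def time_reduction(baseline_seconds, ort_seconds):
    """Hypothetical helper: fraction of time saved and the corresponding speedup factor."""
    reduction = (baseline_seconds - ort_seconds) / baseline_seconds
    speedup = baseline_seconds / ort_seconds
    return reduction, speedup


# Illustrative numbers only, not measurements from this notebook.
reduction, speedup = time_reduction(baseline_seconds=100.0, ort_seconds=60.0)
print(f"{reduction:.0%} less time, {speedup:.2f}x speedup")  # 40% less time, 1.67x speedup
```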