# <font style="color:blue">Inference in Production</font>

Imagine this situation: you and your team are working on a project that involves several ML problems. You pick one problem and solve it using PyTorch, while a colleague solves another one using TensorFlow. Both solutions are part of a bigger project, so it is natural to want a common format for sharing ML models.

ONNX ([Open Neural Network Exchange](https://onnx.ai/)) is one such open format; it allows models to be exchanged between various [ML frameworks and tools](https://onnx.ai/supported-tools).


**In this notebook, we will see how to convert a checkpoint saved by PyTorch Lightning into an ONNX model. As an example, we use the checkpoint saved by the MNIST training in the previous notebook.**

In [1]:
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

## <font style="color:green">Lightning Module</font>

In [2]:
class LeNet5(pl.LightningModule):  # subclasses LightningModule instead of nn.Module
    """LeNet-5 convolutional classifier packaged as a PyTorch Lightning module.

    Per the layer comments, the network expects single-channel 32x32 inputs and
    returns raw (un-normalized) scores for 10 classes (MNIST digits 0-9).

    Args:
        learning_rate: Learning rate for the SGD optimizer; recorded in
            ``self.hparams`` by ``save_hyperparameters()``.
    """

    def __init__(self, learning_rate=0.01):
        super().__init__()

        # Record the constructor arguments; configure_optimizers reads
        # self.hparams.learning_rate.
        self.save_hyperparameters()

        # Feature extractor. The module order fixes the Sequential indices that
        # appear in state_dict keys, so it must stay stable for saved checkpoints.
        self._body = nn.Sequential(
            # conv 1: (32, 32) -> (28, 28); pooling -> (14, 14)
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
            # conv 2: (14, 14) -> (10, 10); pooling -> (5, 5)
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),
        )

        # Classification head: Linear+ReLU for each consecutive width pair,
        # then a final Linear to the 10 class scores.
        # 16 * 5 * 5 = channels * height * width of the last pooled feature map.
        widths = [16 * 5 * 5, 120, 84]
        head_layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            head_layers += [nn.Linear(in_features=fan_in, out_features=fan_out), nn.ReLU(inplace=True)]
        head_layers.append(nn.Linear(in_features=84, out_features=10))
        self._head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class scores for a batch ``x``."""
        features = self._body(x)
        # One flat feature vector per sample (batch dimension kept).
        flat = features.view(features.size()[0], -1)
        return self._head(flat)

    def _loss_and_accuracy(self, batch):
        """Run the model on ``(data, target)`` and return ``(loss, accuracy)``."""
        images, labels = batch
        logits = self(images)
        loss = F.cross_entropy(logits, labels)
        # Predicted class = index of the largest softmax probability.
        predictions = F.softmax(logits, dim=1).data.max(dim=1)[1]
        return loss, accuracy(pred=predictions, target=labels)

    def training_step(self, batch, batch_idx):
        """Compute the training loss/accuracy for one batch."""
        loss, acc = self._loss_and_accuracy(batch)
        metrics = {'train_loss': loss, 'train_acc': acc}
        # Dict-style return: 'loss' plus the same metrics under 'log' and 'progress_bar'.
        return {'loss': loss, 'log': metrics, 'progress_bar': metrics}

    def training_epoch_end(self, training_step_outputs):
        """Average the per-batch training metrics returned by training_step."""
        losses = [out['progress_bar']['train_loss'] for out in training_step_outputs]
        accs = [out['progress_bar']['train_acc'] for out in training_step_outputs]
        metrics = {
            'epoch_train_loss': torch.tensor(losses).mean(),
            'epoch_train_acc': torch.tensor(accs).mean(),
        }
        return {'log': metrics, 'progress_bar': metrics}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss/accuracy for one batch."""
        loss, acc = self._loss_and_accuracy(batch)
        return {'v_loss': loss, 'v_acc': acc}

    def validation_epoch_end(self, validation_step_outputs):
        """Average the per-batch validation metrics returned by validation_step."""
        losses = [out['v_loss'] for out in validation_step_outputs]
        accs = [out['v_acc'] for out in validation_step_outputs]
        mean_loss = torch.tensor(losses).mean()
        metrics = {
            'avg_val_loss': mean_loss,
            'avg_val_acc': torch.tensor(accs).mean(),
        }
        return {'val_loss': mean_loss, 'log': metrics, 'progress_bar': metrics}

    def configure_optimizers(self):
        """Plain SGD over all parameters using the stored learning rate."""
        return torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate)

## <font style="color:green">Get the Checkpoint</font>

In [3]:
import os

def get_latest_run_version_ckpt_epoch_no(lightning_logs_dir='lightning_logs', run_version=None):
    """Return the path of a checkpoint file from a Lightning run directory.

    Args:
        lightning_logs_dir: Directory containing ``version_<N>`` run folders.
        run_version: Run number to use. If None, the highest ``N`` among the
            ``version_<N>`` entries is used (0 if there are none).

    Returns:
        Path to a ``.ckpt`` file inside
        ``<lightning_logs_dir>/version_<N>/checkpoints``. If several exist,
        the most recently modified one is returned.

    Raises:
        FileNotFoundError: If the checkpoints directory does not exist or
            contains no ``.ckpt`` file.
    """
    if run_version is None:
        versions = []
        for dir_name in os.listdir(lightning_logs_dir):
            # Only accept names of the exact form "version_<digits>"; anything
            # else (e.g. "version_x", "notes") is ignored instead of crashing.
            prefix, _, suffix = dir_name.partition('_')
            if prefix == 'version' and suffix.isdecimal():
                versions.append(int(suffix))
        run_version = max(versions, default=0)

    checkpoints_dir = os.path.join(lightning_logs_dir, 'version_{}'.format(run_version), 'checkpoints')

    ckpt_files = [name for name in os.listdir(checkpoints_dir) if name.endswith('.ckpt')]
    if not ckpt_files:
        # Previously this only printed and then failed with UnboundLocalError.
        raise FileNotFoundError('CKPT file is not present in {}'.format(checkpoints_dir))

    # os.listdir order is arbitrary; choose deterministically by modification time.
    latest = max(ckpt_files, key=lambda name: os.path.getmtime(os.path.join(checkpoints_dir, name)))
    return os.path.join(checkpoints_dir, latest)

In [4]:
# Resolve the checkpoint file written by training run version_5.
ckpt_path = get_latest_run_version_ckpt_epoch_no(run_version=5)
print('ckpt_path:', ckpt_path)

ckpt_path: lightning_logs/version_5/checkpoints/epoch=7.ckpt


## <font style="color:green">Convert to ONNX Format</font>

More details are available [here](https://pytorch-lightning.readthedocs.io/en/latest/production_inference.html).

In [5]:
import onnxruntime

def convert_to_onnx_model(model, ckpt_path, onnx_path=None, input_shape=(1, 1, 32, 32)):
    """Export a Lightning checkpoint to an ONNX file.

    Args:
        model: Lightning module (instance or class) whose ``load_from_checkpoint``
            restores the weights stored at ``ckpt_path``.
        ckpt_path: Path to the ``.ckpt`` file.
        onnx_path: Output path. Defaults to ``ckpt_path`` with its extension
            replaced by ``.onnx`` (or ``.onnx`` appended if it has none).
        input_shape: Shape of the random example tensor used for export,
            ``(batch_size, num_channels, height, width)``.

    Returns:
        The path the ONNX model was written to.
    """
    if onnx_path is None:
        # splitext works for any extension; slicing a fixed 4 characters
        # mangled paths that did not end in "ckpt".
        onnx_path = os.path.splitext(ckpt_path)[0] + '.onnx'

    # load_from_checkpoint returns a new module built from the checkpoint.
    ckpt_model = model.load_from_checkpoint(ckpt_path)

    # Freeze the network before export.
    ckpt_model.freeze()

    # Random example input; only its shape/dtype matter for the export.
    input_sample = torch.randn(tuple(input_shape))

    ckpt_model.to_onnx(onnx_path, input_sample, export_params=True)

    return onnx_path

In [6]:
# A LeNet5 instance is needed only so convert_to_onnx_model can call
# load_from_checkpoint on it.
model = LeNet5()
onnx_model_path = convert_to_onnx_model(model, ckpt_path)
print('onnx_model_path:', onnx_model_path)

onnx_model_path: lightning_logs/version_5/checkpoints/epoch=7.onnx


## <font style="color:green">Sample Inference</font>

In [7]:
import numpy as np

# Create an ONNX Runtime session. Since onnxruntime 1.9, builds with more than
# one execution provider (e.g. GPU builds) raise unless `providers` is passed
# explicitly, so pass every provider available in this installation.
sess = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())

# Name of the model's (single) input, as recorded in the ONNX graph.
input_name = sess.get_inputs()[0].name

# Random input matching the example shape used at export time.
inputs = {input_name: np.random.randn(1, 1, 32, 32).astype(np.float32)}

# None -> return all model outputs.
outputs = sess.run(None, inputs)

print(outputs)

[array([[-0.7201431 ,  0.54916894,  2.4066284 ,  0.9954219 , -1.7776055 ,
         0.395705  , -1.6241193 ,  0.5397513 , -0.983725  , -1.1929322 ]],
      dtype=float32)]


# <font style="color:blue">References</font>


1. https://pytorch-lightning.readthedocs.io/en/latest/production_inference.html

1. https://docs.microsoft.com/en-us/windows/ai/windows-ml/get-onnx-model

1. https://onnx.ai/