In this notebook, we will convert a PyTorch model to the ONNX format, run it with ONNX Runtime, and compare the converted ONNX model against the original PyTorch model. We will also convert the model to TensorFlow.
<!--  -->
 ONNX Runtime is a performance-focused engine for ONNX models that can be used for efficient inference across multiple platforms and hardware. ONNX Runtime has been shown to considerably increase performance across multiple models.

In [None]:
import io
import numpy as np
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
from PIL import Image
import torchvision.transforms as transforms
import netron
from torch.autograd import Variable

In [None]:
# Selects where the pretrained weights come from: True downloads the checkpoint
# from model_url, False reads the local copy at downloaded_model.
load_pretrained_model_from_server = True

Super-resolution is a way of increasing the resolution of images and videos, and is widely used in image processing and video editing.

The super-resolution model used here is described in “Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network”.


In [None]:
# Super Resolution model definition in PyTorch
import torch.nn as nn
import torch.nn.init as init

class SuperResolutionNet(nn.Module):
    """Sub-pixel convolutional network for single-channel super-resolution.

    Three ReLU-activated convolutions are followed by a fourth convolution that
    produces ``upscale_factor ** 2`` channels; ``nn.PixelShuffle`` then folds
    those channels into the spatial dimensions.
    """

    def __init__(self, upscale_factor, inplace=False):
        """Create the layers.

        Args:
            upscale_factor: Factor r by which height and width are enlarged.
            inplace: Passed through to ``nn.ReLU(inplace=...)``.

        Note:
            ``_initialize_weights`` is not invoked here; the convolutions keep
            their default initialisation unless it is called explicitly or a
            state_dict is loaded.
        """
        super().__init__()

        unit_stride = (1, 1)
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=unit_stride, padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=unit_stride, padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=unit_stride, padding=(1, 1))
        self.conv4 = nn.Conv2d(
            32, upscale_factor ** 2, kernel_size=(3, 3), stride=unit_stride, padding=(1, 1)
        )
        # PixelShuffle rearranges a tensor of shape (*, C x r^2, H, W) into
        # (*, C, H x r, W x r), where r is the upscale factor. This is useful for
        # implementing efficient sub-pixel convolution with a stride of 1/r.
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        """Run the three ReLU convolutions, then conv4 followed by pixel shuffle."""
        features = self.relu(self.conv1(x))
        features = self.relu(self.conv2(features))
        features = self.relu(self.conv3(features))
        return self.pixel_shuffle(self.conv4(features))

    def _initialize_weights(self):
        """Orthogonally initialise all four convolution weights.

        ``init.orthogonal_`` fills the tensor with a (semi) orthogonal matrix, as
        described in "Exact solutions to the nonlinear dynamics of learning in
        deep linear neural networks". The ReLU gain is applied to conv1-conv3;
        conv4 uses the default gain.
        """
        relu_gain = init.calculate_gain('relu')
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)
        # Fills the input Tensor with a (semi) orthogonal matrix, as described in Exact solutions to the nonlinear dynamics of learning in deep linear neural networks 

# Create the super-resolution model (3x upscaling) from the definition above.
# The pretrained weights are loaded into it separately, further down.
torch_model = SuperResolutionNet(upscale_factor=3)

 It is important to call torch_model.eval() or torch_model.train(False) before exporting the model, to turn the model to inference mode. This is required since operators like dropout or batchnorm behave differently in inference and training mode.

In [None]:
# Load pretrained model weights from the server or from an already downloaded
# checkpoint in local storage.
downloaded_model = 'C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/superres_epoch100-44c6958e.pth'
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1    # batch dimension of the sample input built further down

# Without CUDA, remap every storage to the CPU while deserialising; with CUDA,
# leave the checkpoint's device placement untouched.
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None

# Initialize the model with the pretrained weights. This must happen whether or
# not CUDA is available, so it is deliberately NOT nested under the CUDA check.
#
# A state_dict is simply a Python dictionary object that maps each layer to its
# parameter tensor. Only layers with learnable parameters (convolutional layers,
# linear layers, etc.) and registered buffers (batchnorm's running_mean) have
# entries in the model's state_dict. Optimizer objects (torch.optim) also have a
# state_dict, which contains information about the optimizer's state as well as
# the hyperparameters used.
#
# Because state_dict objects are Python dictionaries, they can be easily saved,
# updated, altered, and restored, adding a great deal of modularity to PyTorch
# models and optimizers.
if load_pretrained_model_from_server:
    pretrained_state = model_zoo.load_url(model_url, map_location=map_location)
else:
    pretrained_state = torch.load(downloaded_model, map_location=map_location)
torch_model.load_state_dict(pretrained_state)

# set the model to inference mode
torch_model.eval()

In [None]:
## Input to the model
# Random sample input of shape (batch_size, 1, 224, 224). It is reused further
# down as the input fed to ONNX Runtime for the output comparison.
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
# Reference PyTorch output that the ONNX Runtime result is compared against.
torch_out = torch_model(x)
print(torch_out)

In [None]:
# Converting the torch model to ONNX format of pytorchModel
'''
 Exporting a model in PyTorch works via tracing or scripting. This tutorial will use as an example a model exported by tracing. 
 To export a model, we call the torch.onnx.export() function. This will execute the model, 
 recording a trace of what operators are used to compute the outputs. Because export runs the model, 
 we need to provide an input tensor x.

'''
class SuperResolutionNet(nn.Module):
    """Sub-pixel convolutional network for single-channel super-resolution.

    Three ReLU-activated convolutions are followed by a fourth convolution that
    produces ``upscale_factor ** 2`` channels; ``nn.PixelShuffle`` then folds
    those channels into the spatial dimensions.
    """

    def __init__(self, upscale_factor, inplace=False):
        """Create the layers.

        Args:
            upscale_factor: Factor r by which height and width are enlarged.
            inplace: Passed through to ``nn.ReLU(inplace=...)``.

        Note:
            ``_initialize_weights`` is not invoked here; the convolutions keep
            their default initialisation unless it is called explicitly or a
            state_dict is loaded.
        """
        super().__init__()

        unit_stride = (1, 1)
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=unit_stride, padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=unit_stride, padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=unit_stride, padding=(1, 1))
        self.conv4 = nn.Conv2d(
            32, upscale_factor ** 2, kernel_size=(3, 3), stride=unit_stride, padding=(1, 1)
        )
        # PixelShuffle rearranges a tensor of shape (*, C x r^2, H, W) into
        # (*, C, H x r, W x r), where r is the upscale factor. This is useful for
        # implementing efficient sub-pixel convolution with a stride of 1/r.
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        """Run the three ReLU convolutions, then conv4 followed by pixel shuffle."""
        features = self.relu(self.conv1(x))
        features = self.relu(self.conv2(features))
        features = self.relu(self.conv3(features))
        return self.pixel_shuffle(self.conv4(features))

    def _initialize_weights(self):
        """Orthogonally initialise all four convolution weights.

        ``init.orthogonal_`` fills the tensor with a (semi) orthogonal matrix, as
        described in "Exact solutions to the nonlinear dynamics of learning in
        deep linear neural networks". The ReLU gain is applied to conv1-conv3;
        conv4 uses the default gain.
        """
        relu_gain = init.calculate_gain('relu')
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)

# Create a fresh instance of the model defined above; this is the instance that
# gets exported to ONNX.
torch_model_ = SuperResolutionNet(upscale_factor=3)

# Without CUDA, remap every storage in the checkpoint to the CPU while
# deserialising (equivalent to map_location=torch.device('cpu')).
if torch.cuda.is_available():
    export_state = torch.load(downloaded_model)
else:
    export_state = torch.load(downloaded_model, map_location=lambda storage, loc: storage)

# load_state_dict() copies the weights into the module in place. Its return
# value only reports missing/unexpected keys, so it must never be assigned back
# to the model variable - that would replace the model we are about to export.
torch_model_.load_state_dict(export_state)

# Export in inference mode (see the note on eval() above the weight-loading cell).
torch_model_.eval()

# Export works by tracing, i.e. by running the model once, so it needs an
# example input of the right shape.
dummy_input = torch.randn(1, 1, 224, 224)

# Alternative export call with explicit names and a dynamic batch axis:
# torch.onnx.export(torch_model,               # model being run
#                   x,                         # model input (or a tuple for multiple inputs)
#                   "C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/super_resolution.onnx",   # where to save the model (can be a file or file-like object)
#                   export_params=True,        # store the trained parameter weights inside the model file
#                   opset_version=10,          # the ONNX version to export the model to
#                   do_constant_folding=True,  # whether to execute constant folding for optimization
#                   input_names = ['input'],   # the model's input names
#                   output_names = ['output'], # the model's output names
#                   dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
#                                 'output' : {0 : 'batch_size'}})
torch.onnx.export(
    torch_model_,
    dummy_input,
    "C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/super_resolution.onnx",
)

But before verifying the model’s output with ONNX Runtime, we will check the ONNX model with ONNX’s API. First, onnx.load("super_resolution.onnx") will load the saved model and will output an onnx.ModelProto structure (a top-level file/container format for bundling an ML model).
Then, onnx.checker.check_model(onnx_model) will verify the model’s structure and confirm that the model has a valid schema. The validity of the ONNX graph is verified by checking the model’s version, the graph’s structure, as well as the nodes and their inputs and outputs.

In [None]:
import onnx
# Load the exported ONNX model from disk.
onnx_model = onnx.load("C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/super_resolution.onnx")
# onnx.checker.check_model(onnx_model) verifies the model's structure and confirms
# that the model has a valid schema. The validity of the ONNX graph is verified by
# checking the model's version, the graph's structure, as well as the nodes and
# their inputs and outputs.
onnx.checker.check_model(onnx_model)
# Print a human-readable rendering of the graph.
print(onnx.helper.printable_graph(onnx_model.graph))

In [None]:
# To compute the output using ONNX Runtime's Python API, we create an
# InferenceSession for the exported model; predictions are then obtained by
# calling .run() on that session.
import onnxruntime

ort_session = onnxruntime.InferenceSession("C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/super_resolution.onnx")


def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array on the CPU.

    A tensor that requires grad is detached first; any other tensor is
    converted as is.
    """
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()

# compute ONNX Runtime output prediction: feed the sample input x under the
# name of the session's first input
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)


# compare ONNX Runtime and PyTorch results; assert_allclose raises an
# AssertionError when the two outputs differ beyond the given tolerances
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")



In [None]:

# Load the sample image and resize it to 224x224.
img = Image.open("C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/cat_224x224.jpg")
resize = transforms.Resize([224, 224])
img = resize(img)

# Split the image into its Y, Cb, and Cr components. These components represent
# a greyscale image (Y), and the blue-difference (Cb) and red-difference (Cr)
# chroma components.

img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()

# Convert the Y channel to a tensor and add a leading dimension in place.
to_tensor = transforms.ToTensor()
img_y = to_tensor(img_y)
img_y.unsqueeze_(0)

In [None]:
# Using the session started earlier: build the input feed, run the session, and
# keep its first output as the prediction.
# The input is the tensor representing the greyscale (Y) channel of the resized
# cat image, run through the super-resolution model in ONNX Runtime as explained
# previously.
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)}
ort_outs = ort_session.run(None, ort_inputs)
img_out_y = ort_outs[0]
print("Output as tensor:\n", img_out_y)

In [None]:
# Take element 0 of the output, scale it by 255, clip to [0, 255], take element 0
# of the result and convert it to an 8-bit greyscale ('L') image.
# NOTE(review): presumably the two [0] indices select batch and channel - confirm.
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L')

# get the output image following the post-processing step from the PyTorch
# implementation: resize Cb and Cr to the size of the new Y image with bicubic
# interpolation, merge the three channels and convert the result to RGB
final_img = Image.merge("YCbCr", (img_out_y, img_cb.resize(img_out_y.size, Image.BICUBIC), img_cr.resize(img_out_y.size, Image.BICUBIC))).convert('RGB')

final_img.save("C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/inference_cat_superres_with_ort.jpg")

In [None]:
# Open the super-resolved image written by the previous cell.
Image.open("C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/inference_cat_superres_with_ort.jpg")

Convert the ONNX model to TensorFlow

In [None]:
# Visualise the model graph using netron
import netron

In [None]:
# View the original PyTorch checkpoint (.pth) in Netron
netron.start("C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/superres_epoch100-44c6958e.pth")

In [None]:
# View the ONNX-converted model in Netron
netron.start("C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/super_resolution.onnx")

In [None]:
# Print the installed TensorFlow version.
import tensorflow
print(tensorflow.__version__)

Convert the ONNX model to TensorFlow

In [None]:
import tensorflow
from onnx_tf.backend import prepare
# Reload the exported ONNX model and hand it to onnx-tf's prepare().
onnx_model  = onnx.load("C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/super_resolution.onnx")
tf_rep = prepare(onnx_model)
# Export the TensorFlow representation. Despite the .pb suffix, later cells read
# "<this path>/saved_model.pb", i.e. they treat this path as a directory.
tf_rep.export_graph("C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/super_resolution.pb")

In [None]:
# View the converted TensorFlow model in Netron
netron.start("C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/super_resolution.pb/saved_model.pb")

Writing Log directory; Generating event files for tensorboard.

In [None]:
import tensorflow as tf
import sys
from tensorflow.python.platform import gfile

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

# Switch eager execution off first. The TensorFlow docs say this should be done
# before graphs, ops or tensors are created, so it belongs ahead of the session
# and graph import rather than after them.
tf.compat.v1.disable_eager_execution()

model_filename = 'C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/super_resolution.pb/saved_model.pb'
LOGDIR = 'C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/super_resolution_events/'

# saved_model.pb is parsed as a SavedModel proto; the GraphDef to visualise is
# taken from its single meta graph.
with gfile.GFile(model_filename, 'rb') as f:
    data = compat.as_bytes(f.read())
sm = saved_model_pb2.SavedModel()
sm.ParseFromString(data)
# print(sm)
if 1 != len(sm.meta_graphs):
    # Bail out when there is not exactly one meta graph to choose from.
    print('More than one graph found. Not sure which to write')
    sys.exit(1)

with tf.compat.v1.Session() as sess:
    # graph_def = tf.GraphDef()
    # graph_def.ParseFromString(sm.meta_graphs[0])
    g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)

    # Write the session's graph as TensorBoard events. Closing the writer flushes
    # the pending events to disk so the graph is available straight away.
    train_writer = tf.compat.v1.summary.FileWriter(LOGDIR)
    # train_writer = tf.summary.create_file_writer(LOGDIR)
    train_writer.add_graph(sess.graph)
    train_writer.close()


https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/tools/import_pb_to_tensorboard.py

Trying to open tensorboard and plotting .pb file as graph

In [None]:
"""Imports a protobuf model as a graph in Tensorboard."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.summary import summary


def import_to_tensorboard(model_dir, log_dir):
    """View an imported protobuf model (`.pb` file) as a graph in Tensorboard.

    Args:
      model_dir: The location of the protobuf (`pb`) model to visualize. Despite
        the name, this is opened as a file and its contents are parsed as a
        serialized GraphDef.
      log_dir: The location for the Tensorboard log to begin visualization from.
    Usage:
      Call this function with your model location and desired log directory.
      Launch Tensorboard by pointing it to the log directory.
      View your imported `.pb` model as a graph.
    """
    # Import into a fresh graph so nothing from the default graph leaks in.
    with session.Session(graph=ops.Graph()) as sess:
        with gfile.GFile(model_dir, "rb") as f:
            graph_def = graph_pb2.GraphDef()
            graph_def.ParseFromString(f.read())
            importer.import_graph_def(graph_def)

        pb_visual_writer = summary.FileWriter(log_dir)
        pb_visual_writer.add_graph(sess.graph)
        # Close the writer so the graph event is flushed to disk before returning.
        pb_visual_writer.close()
        print("Model Imported. Visualize by running: "
              "tensorboard --logdir={}".format(log_dir))


# if __name__ == "__main__":
#   parser = argparse.ArgumentParser()
#   parser.register("type", "bool", lambda v: v.lower() == "true")
#   parser.add_argument(
#       "--model_dir",
#       type=str,
#       default="",
#       required=True,
#       help="The location of the protobuf (\'pb\') model to visualize.")
#   parser.add_argument(
#       "--log_dir",
#       type=str,
#       default="",
#       required=True,
#       help="The location for the Tensorboard log to begin visualization from.")
#   FLAGS, unparsed = parser.parse_known_args()
#   app.run(main=main, argv=[sys.argv[0]] + unparsed)


In [None]:
# model_dir = "C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/super_resolution.pb/saved_model.pb"
# log_dir = "C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/super_resolution_events/"
# import_to_tensorboard(model_dir, log_dir)

In [None]:
# Inference using .pb file

In [None]:
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
import logging, os
# Suppress log records of severity WARNING and below. logging.disable() takes a
# level, i.e. the logging.WARNING constant, not the logging.warning function.
logging.disable(logging.WARNING)

# INPUT_TENSOR_NAME = 'input.1:0'
# get_tensor_by_name() expects tensor names of the form "<op_name>:<output_index>".
# NOTE(review): confirm both names against the exported graph (e.g. in Netron).
INPUT_TENSOR_NAME = 'serving_default_input.1:0'
OUTPUT_TENSOR_NAME = 'output:0'
image_path = 'C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/cat_224x224.jpg'
pb_path = "C:/Users/uib43225/PycharmProjects/DSAlgo/onnx/super_resolution.pb/saved_model.pb"

img = Image.open(image_path)
resize = transforms.Resize([224, 224])
img = resize(img)

# The network's first convolution takes a single input channel, so feed the
# luma (Y) channel as a float array with a leading batch dimension - the same
# preprocessing used for the ONNX Runtime run - instead of the raw RGB image.
pb_input = transforms.ToTensor()(img.convert('YCbCr').split()[0]).unsqueeze(0).numpy()

# NOTE(review): pb_path points at the saved_model.pb that the TensorBoard cell
# parses as a SavedModel proto; parsing it as a bare GraphDef here may fail.
# Confirm, and if so take the GraphDef from SavedModel.meta_graphs[0].graph_def.
with tf.io.gfile.GFile(pb_path, 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")

input_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME)
output_tensor = graph.get_tensor_by_name(OUTPUT_TENSOR_NAME)

# The tf.compat.v1 API is used so this runs on TensorFlow 2.x, matching the
# other TensorFlow cells in this notebook.
with tf.compat.v1.Session(graph=graph) as sess:
    output_vals = sess.run(output_tensor, feed_dict={input_tensor: pb_input})

print(output_vals)
