In [None]:
import matplotlib.pyplot as plt
from PIL import Image
from tool.config import Cfg
from tool.translate import build_model, process_input, translate
import torch
import onnxruntime
import numpy as np

In [None]:
import os

# The config and checkpoint locations are machine-specific. They can be
# overridden through environment variables; the original absolute paths are
# kept as defaults so existing setups keep working unchanged.
CONFIG_PATH = os.environ.get(
    'VIETOCR_CONFIG_PATH',
    '/home/viethq/vocr/vietnamese_ocr_engine/lib/text_recognition/vietocr/config.yml')
weight_path = os.environ.get(
    'VIETOCR_WEIGHT_PATH',
    '/home/viethq/vocr/vietnamese_ocr_engine/trained_model/text_recognition/vgg_seq2seq_fix.pth')

config = Cfg.load_config_from_file(CONFIG_PATH)
# Weights come from `weight_path` (loaded in the next cell), so the CNN's
# pretrained flag is switched off, and everything is placed on the CPU.
config['cnn']['pretrained']=False
config['device'] = 'cpu'
model, vocab = build_model(config)

In [None]:
# Load the trained weights from `weight_path`, mapped onto the device named in the config.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load(weight_path, map_location=torch.device(config['device'])))
# Put the module in evaluation mode before it is used for export / inference below.
model = model.eval() 

## Export CNN part

In [None]:
def convert_cnn_part(img, save_path, model, max_seq_length=128, sos_token=1, eos_token=2):
    """Export ``model.cnn`` to ONNX and return its output for ``img``.

    Args:
        img: example input used to trace the CNN; the export labels its four
            axes ``batch``, ``channel``, ``height`` and ``width`` as dynamic.
        save_path: destination of the exported ONNX file.
        model: object exposing the sub-module to export as ``model.cnn``.
        max_seq_length: unused here; kept for interface compatibility.
        sos_token: unused here; kept for interface compatibility.
        eos_token: unused here; kept for interface compatibility.

    Returns:
        The result of ``model.cnn(img)``, computed with autograd disabled.
    """
    backbone = model.cnn
    export_options = dict(
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        verbose=True,
        input_names=['img'],
        output_names=['output'],
        dynamic_axes={
            'img': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'},
            'output': {0: 'channel', 1: 'batch'},
        },
    )

    # Both the forward pass and the export run without gradient tracking.
    with torch.no_grad():
        src = backbone(img)
        torch.onnx.export(backbone, img, save_path, **export_options)

    return src

In [None]:
# Random dummy image batch, used only as the example input for tracing the CNN.
img = torch.rand(1, 3, 32, 475)
# Export the CNN and keep its output (`src`) as the example input for the encoder export.
src = convert_cnn_part(img, './weight/cnn.onnx', model)

## Export encoder part

In [None]:
def convert_encoder_part(model, src, save_path):
    """Export ``model.transformer.encoder`` to ONNX and return its outputs for ``src``.

    Args:
        model: object exposing the sub-module to export as
            ``model.transformer.encoder``.
        src: example encoder input; the export labels axis 0 ``channel_input``
            and axis 1 ``batch`` as dynamic.
        save_path: destination of the exported ONNX file.

    Returns:
        ``(hidden, encoder_outputs)``. Note the order is reversed relative to
        the encoder's own ``(encoder_outputs, hidden)`` return order.
    """
    encoder = model.transformer.encoder

    # Run without autograd, consistent with convert_cnn_part: the returned
    # tensors are only used as example inputs for the next export step, so
    # they should not carry a gradient graph.
    with torch.no_grad():
        encoder_outputs, hidden = encoder(src)
        torch.onnx.export(encoder, src, save_path, export_params=True,
                        opset_version=11, do_constant_folding=True, input_names=['src'],
                        output_names=['encoder_outputs', 'hidden'],
                        dynamic_axes={'src': {0: "channel_input", 1: "batch"},
                                        'encoder_outputs': {0: 'channel_output', 1: 'batch'},
                                        'hidden': {0: 'batch'}})
    return hidden, encoder_outputs

In [None]:
# Export the encoder and keep its outputs as example inputs for the decoder export.
hidden, encoder_outputs = convert_encoder_part(model, src, './weight/encoder.onnx')

## Export decoder part

In [None]:
def convert_decoder_part(model, tgt, hidden, encoder_outputs, save_path):
    """Export ``model.transformer.decoder`` to ONNX for a single decoding step.

    Args:
        model: object exposing the sub-module to export as
            ``model.transformer.decoder``.
        tgt: target tokens; only the last entry along axis 0 (``tgt[-1]``) is
            used as the example decoder input.
        hidden: example hidden state passed as the second decoder input.
        encoder_outputs: example encoder outputs passed as the third decoder input.
        save_path: destination of the exported ONNX file.

    Returns:
        None. The function only writes the ONNX file.
    """
    # The decoder is traced on one step, so only the most recent tokens are fed.
    example_inputs = (tgt[-1], hidden, encoder_outputs)

    io_names = {
        'input_names': ['tgt', 'hidden', 'encoder_outputs'],
        'output_names': ['output', 'hidden_out', 'last'],
    }
    axes = {
        'tgt': {0: 'batch'},
        'encoder_outputs': {0: "channel_input", 1: 'batch'},
        'hidden': {0: 'batch'},
        'output': {0: 'batch'},
        'hidden_out': {0: 'batch'},
        'last': {0: 'batch'},
    }

    torch.onnx.export(
        model.transformer.decoder,
        example_inputs,
        save_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        dynamic_axes=axes,
        **io_names,
    )

In [None]:

# Start-of-sequence token id (1) for every item in the batch, shaped (1, batch).
# Token ids must be int64: an int8 CharTensor is not a valid index dtype for
# an embedding lookup (presumably what the decoder does with `tgt` — confirm)
# nor for the ONNX Gather op it is exported to.
tgt = torch.LongTensor([[1] * len(img)])

In [None]:
# Export the decoder, traced on the start-token step plus the encoder's hidden state and outputs.
convert_decoder_part(model, tgt, hidden, encoder_outputs, './weight/decoder.onnx')

## Load and check model

In [1]:
import onnx

In [3]:
# Load the exported graphs so they can be validated and inspected.
cnn = onnx.load('./weight/cnn.onnx')
# NOTE(review): the variable is named `decoder` but it is loaded from encoder.onnx
# (and the commented-out line below pairs `encoder` with decoder.onnx) — the names
# look swapped; confirm before relying on them.
decoder = onnx.load('./weight/encoder.onnx')
# encoder = onnx.load('./weight/decoder.onnx')

In [4]:
# Confirm each loaded model has a valid schema.
onnx.checker.check_model(cnn)
# `decoder` is whatever graph was loaded under that name in the previous cell.
onnx.checker.check_model(decoder)
# onnx.checker.check_model(encoder)

In [5]:
# Show a human-readable representation of the graph held in `decoder`
# (the returned string is displayed as the cell output).
onnx.helper.printable_graph(decoder.graph)

"graph torch_jit (\n  %src[FLOAT, channel_inputxbatchx256]\n) initializers (\n  %fc.weight[FLOAT, 256x512]\n  %fc.bias[FLOAT, 256]\n  %onnx::Concat_149[INT64, 1]\n  %onnx::Concat_150[INT64, 1]\n  %onnx::GRU_191[FLOAT, 2x1536]\n  %onnx::GRU_192[FLOAT, 2x768x256]\n  %onnx::GRU_193[FLOAT, 2x768x256]\n) {\n  %onnx::Gather_11 = Shape(%src)\n  %onnx::Gather_12 = Constant[value = <Scalar Tensor []>]()\n  %onnx::Unsqueeze_13 = Gather[axis = 0](%onnx::Gather_11, %onnx::Gather_12)\n  %onnx::Concat_17 = Unsqueeze[axes = [0]](%onnx::Unsqueeze_13)\n  %onnx::ConstantOfShape_19 = Concat[axis = 0](%onnx::Concat_149, %onnx::Concat_17, %onnx::Concat_150)\n  %hidden.1 = ConstantOfShape[value = <Tensor>](%onnx::ConstantOfShape_19)\n  %onnx::Transpose_137, %onnx::Gather_138 = GRU[direction = 'bidirectional', hidden_size = 256, linear_before_reset = 1](%src, %onnx::GRU_192, %onnx::GRU_193, %onnx::GRU_191, %, %hidden.1)\n  %onnx::Reshape_139 = Transpose[perm = [0, 2, 1, 3]](%onnx::Transpose_137)\n  %onnx::Re

## Inference directly

In [None]:
# Load a sample image and preprocess it with the height / min-width / max-width
# settings from the config, then move it to the configured device.
# NOTE(review): absolute, machine-specific path — point this at your own test image.
img = Image.open('/home/viethq/Downloads/test/3.jpg')
img = process_input(img, config['dataset']['image_height'], 
                config['dataset']['image_min_width'], config['dataset']['image_max_width'])  
img = img.to(config['device'])

In [None]:
# Reference prediction with the PyTorch model: element 0 of `translate`'s result
# is converted to a plain list of token ids and decoded to text with the vocab.
s = translate(img, model)[0].tolist()
s = vocab.decode(s)
# Bare expression so the decoded text is shown as the cell output.
s

## Inference with ONNX Runtime's Python API

In [None]:
# Create one ONNX Runtime inference session per exported sub-model (CNN, encoder, decoder).
cnn_session = onnxruntime.InferenceSession("./weight/cnn.onnx")
encoder_session = onnxruntime.InferenceSession("./weight/encoder.onnx")
decoder_session = onnxruntime.InferenceSession("./weight/decoder.onnx")

In [None]:
def translate_onnx(img, session, max_seq_length=128, sos_token=1, eos_token=2):
    """Greedy-decode a batch of images with the three ONNX Runtime sessions.

    Args:
        img: image batch laid out as BxCxHxW, in a form accepted by
            ``cnn_session.run``; ``len(img)`` is used as the batch size.
        session: ``(cnn_session, encoder_session, decoder_session)`` triple.
        max_seq_length: decoding continues while the step counter is
            ``<= max_seq_length``.
        sos_token: token id every sequence starts with.
        eos_token: token id that marks a sequence as finished.

    Returns:
        Array of token ids with one row per batch item; the first column is
        ``sos_token``.
    """
    cnn_sess, enc_sess, dec_sess = session

    # CNN stage: only the first output of the session is forwarded.
    cnn_feed = {cnn_sess.get_inputs()[0].name: img}
    features = cnn_sess.run(None, cnn_feed)[0]

    # Encoder stage.
    enc_feed = {enc_sess.get_inputs()[0].name: features}
    encoder_outputs, hidden = enc_sess.run(None, enc_feed)

    tokens = [[sos_token] * len(img)]
    step = 0

    while step <= max_seq_length:
        # Stop as soon as every sequence in the batch contains an EOS token.
        finished = np.any(np.asarray(tokens).T == eos_token, axis=1)
        if all(finished):
            break

        dec_inputs = dec_sess.get_inputs()
        dec_feed = {
            dec_inputs[0].name: tokens[-1],
            dec_inputs[1].name: hidden,
            dec_inputs[2].name: encoder_outputs,
        }
        logits, hidden, _ = dec_sess.run(None, dec_feed)

        # Insert a length-1 step axis, then keep the top-scoring token per item.
        scores = torch.Tensor(np.expand_dims(logits, axis=1))
        _, best = torch.topk(scores, 1)
        tokens.append(best[:, -1, 0].tolist())
        step += 1

    return np.asarray(tokens).T

In [None]:
# Run the same image through the ONNX pipeline: bundle the three sessions,
# convert the image to a NumPy array, take row 0 of the result and decode it.
session = (cnn_session, encoder_session, decoder_session)
s = translate_onnx(np.array(img), session)[0].tolist()
s = vocab.decode(s)
# Bare expression so the decoded text is shown as the cell output.
s

In [None]:
from lib.config.settings import SIMPLE_MODEL_PATH, USE_GPU
from lib.text_recognition_v2.aster_pytorch.demo import create_model, batch_prediction, prediction
import os
import torch
class TextRecognitor:
    """Wrapper that builds the recognition model once and runs batched prediction with it."""

    def __init__(self):
        """Create the model from ``SIMPLE_MODEL_PATH``; ``USE_GPU`` decides whether CUDA is used."""
        super().__init__()
        # The attribute name (including its spelling) is kept for existing callers.
        self.text_recogition = create_model(
            resume=SIMPLE_MODEL_PATH,
            decoder_sdim=50,
            attDim=64,
            use_cuda=USE_GPU,
        )

    def recognition(self, output_after_text_detect, debug=False, log_dir=None):
        """Run batched recognition on the output of the text-detection stage.

        Args:
            output_after_text_detect: 4-item sequence; item 0 is passed to
                ``batch_prediction`` as ``line_list`` and item 3 as
                ``list_ids``. The two middle items are ignored.
            debug: unused; kept for interface compatibility.
            log_dir: unused; kept for interface compatibility.

        Returns:
            The ``(kq, sentences, text_blocks)`` triple produced by
            ``batch_prediction`` with a batch size of 8.
        """
        line_list, _skip_a, _skip_b, list_ids = output_after_text_detect
        prediction_result = batch_prediction(
            line_list=line_list,
            model=self.text_recogition,
            list_ids=list_ids,
            batch_size=8,
        )
        kq, sentences, text_blocks = prediction_result
        return kq, sentences, text_blocks

def onnx_model(model, data_input, model_name="text_recognizer_faster.onnx", logs_dir = "./trained_model/text_recognition"):
    """Export ``model`` to ``logs_dir/model_name`` in ONNX format.

    Args:
        model: the model to export.
        data_input: example model input (or a tuple for multiple inputs) used
            for the export.
        model_name: file name of the exported ONNX model.
        logs_dir: output directory; created if it does not exist.

    Returns:
        None. The function only creates the directory and writes the file.
    """
    os.makedirs(logs_dir, exist_ok=True)
    target_path = os.path.join(logs_dir, model_name)

    # Every image axis is dynamic; both outputs vary in batch size and sequence length.
    axes = {
        'img': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'},
        'outputs': {0: 'batch_size', 1: 'sequence_len'},
        'probs': {0: 'batch_size', 1: 'sequence_len'},
    }

    torch.onnx.export(
        model,
        data_input,
        target_path,
        export_params=True,        # store the trained weights inside the file
        opset_version=11,          # ONNX opset to target
        do_constant_folding=True,  # fold constants as an export-time optimisation
        input_names=['img'],
        output_names=['outputs', 'probs'],
        dynamic_axes=axes,
    )
# Script entry point: build the recognition model and export it to ONNX using a
# random dummy input as the example for the export.
# NOTE(review): inside a notebook kernel `__name__` is normally '__main__', so this
# block runs when the cell executes and rebinds the global `model` — confirm that
# is intended.
if __name__ == '__main__':
    model = create_model(resume=SIMPLE_MODEL_PATH, decoder_sdim=50, attDim=64, use_cuda=USE_GPU)
    data_input = torch.rand(1, 3, 32, 475)
    onnx_model(model, data_input)