In [1]:
import torch
import numpy
import PIL.Image
from tha3.util import resize_PIL_image,extract_PIL_image_from_filelike, extract_pytorch_image_from_PIL_image
from tqdm import tqdm
import onnx
from onnxsim import simplify
import onnxruntime as ort
import numpy as np
from torch import Tensor
from torch.nn import Module
from typing import List, Optional
from torch.nn.functional import interpolate
import onnx_tool

MODEL_NAME = "separable_float"
HALF = False
DEVICE_NAME = 'cuda:0'
# Forward slashes work on Windows as well as POSIX and avoid the invalid "\i" / "\c"
# escape sequences of a backslash-separated literal (a SyntaxWarning since Python 3.12).
IMAGE_INPUT = "data/images/crypko_03.png"
USE_RANDOM_IMAGE = False

device = torch.device(DEVICE_NAME)
dtype = torch.float16 if HALF else torch.float32

# Run ONNX Runtime on the same GPU and CUDA stream that PyTorch uses for DEVICE_NAME, so the
# device is configured in one place instead of being hard-coded a second time here.
providers = [("CUDAExecutionProvider", {"device_id": device.index if device.index is not None else 0,
                                        "user_compute_stream": str(torch.cuda.current_stream(device).cuda_stream)})]
sess_options = ort.SessionOptions()


In [2]:
#Prepare models
def load_poser(model: str, device: torch.device):
    """Build the THA3 poser for one of the supported mode names.

    The mode name selects the module ``tha3.poser.modes.<model>``, whose
    ``create_poser(device)`` factory builds and returns the poser.

    Raises:
        RuntimeError: if ``model`` is not one of the supported mode names.
    """
    print("Using the %s model." % model)
    supported_modes = ("standard_float", "standard_half", "separable_float", "separable_half")
    if model not in supported_modes:
        raise RuntimeError("Invalid model: '%s'" % model)
    # Import lazily so only the selected mode's module gets loaded.
    mode_module = __import__("tha3.poser.modes." + model, fromlist=["create_poser"])
    return mode_module.create_poser(device)
        
poser = load_poser(MODEL_NAME, DEVICE_NAME)
pose_size = poser.get_num_parameters()

# Pull out each sub-network that is exported to ONNX separately below.
poser_modules = poser.get_modules()
eyebrow_decomposer = poser_modules['eyebrow_decomposer']
eyebrow_morphing_combiner = poser_modules['eyebrow_morphing_combiner']
face_morpher = poser_modules['face_morpher']
two_algo_face_body_rotator = poser_modules['two_algo_face_body_rotator']
editor = poser_modules['editor']

Using the separable_float model.
Loading the eyebrow decomposer ... DONE!!!
Loading the eyebrow morphing conbiner ... DONE!!!


  return torch.load(f)


Loading the face morpher ... DONE!!!
Loading the face-body rotator ... DONE!!!
Loading the combiner ... DONE!!!


In [3]:
#Prepare one pass inference image data
if USE_RANDOM_IMAGE:
    # Uniform noise mapped from [0, 1) to [-1, 1).
    pt_img = torch.rand(1, 4, 512, 512, dtype=dtype, device=device) * 2.0 - 1.0
else:
    pil_image = resize_PIL_image(extract_PIL_image_from_filelike(IMAGE_INPUT), size=(512, 512))
    img_tensor = extract_pytorch_image_from_PIL_image(pil_image)
    if HALF:
        img_tensor = img_tensor.half()
    # Add the batch dimension, then move the image to the target device.
    pt_img = img_tensor.reshape(1, 4, 512, 512).to(DEVICE_NAME)

# All-zero pose vector for the reference run and the benchmark below.
zero_pose = torch.zeros(1, pose_size, dtype=dtype, device=device)

# Reference PyTorch output of the full pipeline; the ONNX chain is compared against it later.
poser_torch_res = poser.pose(pt_img, zero_pose)

In [4]:
#Small bench for whole pose
from time import time

# CUDA kernels are launched asynchronously. Synchronize before starting and before stopping
# the clock; otherwise the measurement includes leftover work or misses still-queued work.
if device.type == "cuda":
    torch.cuda.synchronize(device)
t1 = time()
for i in tqdm(range(100)):
    poser.pose(pt_img, zero_pose)
if device.type == "cuda":
    torch.cuda.synchronize(device)
print(time() - t1)

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 27.06it/s]

3.70108699798584





In [5]:
# Work on eyebrow decomposer export and quantize
EYEBROW_DECOMPOSER_ONNX_MODEL_NAME = "eyebrow_decomposer.onnx"
EYEBROW_DECOMPOSER_SIMPLIFIED_ONNX_MODEL_NAME = "eyebrow_decomposer_sim.onnx"
EYEBROW_DECOMPOSER_INPUT_LIST = ['input_image']
EYEBROW_DECOMPOSER_OUTPUT_LIST = ["eyebrow_layer", "background_layer_1"]
EYEBROW_DECOMPOSER_INPUT_SHAPE = (1, 4, 128, 128)
EYEBROW_DECOMPOSER_INPUT_IMAGE_SIZE = 128

# The decomposer only sees a 128x128 window of the 512x512 image: rows 64..191, columns 192..319.
eyebrow_rows = slice(64, 64 + EYEBROW_DECOMPOSER_INPUT_IMAGE_SIZE)
eyebrow_cols = slice(64 + 128, 64 + 128 + EYEBROW_DECOMPOSER_INPUT_IMAGE_SIZE)
eyebrow_decomposer_input_img = pt_img[:, :, eyebrow_rows, eyebrow_cols].reshape(EYEBROW_DECOMPOSER_INPUT_SHAPE)
eyebrow_decomposer_input_img_numpy = eyebrow_decomposer_input_img.cpu().numpy()
# Run one inference up front to catch problems before exporting.
eyebrow_decomposer_torch_res = eyebrow_decomposer(eyebrow_decomposer_input_img)

In [6]:
class EyebrowDecomposerWrapper(Module):
    """Crop the eyebrow window from a full image and run the eyebrow decomposer on it.

    The window is rows 64..191 and columns 192..319. Only decomposer outputs 0 and 3
    are returned, in that order.
    """

    def __init__(self, eyebrow_decomposer_obj):
        super().__init__()
        self.eyebrow_decomposer = eyebrow_decomposer_obj

    def forward(self, image: Tensor, *args) -> List[Tensor]:
        top, left, size = 64, 64 + 128, 128
        eyebrow_patch = image[:, :, top:top + size, left:left + size].reshape((1, 4, size, size))
        outputs = self.eyebrow_decomposer(eyebrow_patch)
        eyebrow_layer, background_layer = outputs[0], outputs[3]
        return [eyebrow_layer, background_layer]
# Wrap the decomposer so the exported graph takes the full image and does the eyebrow crop itself.
eyebrow_decomposer_wrapper = EyebrowDecomposerWrapper(eyebrow_decomposer)
# Reference outputs [eyebrow_layer, background_layer]; they feed the combiner cells below.
eyebrow_decomposer_wrapped_torch_res = eyebrow_decomposer_wrapper(pt_img)

In [7]:
# Export the wrapped decomposer to ONNX, then save a simplified copy of the graph.
torch.onnx.export(
    eyebrow_decomposer_wrapper,          # model being run
    pt_img,                              # traced on the full image; the wrapper crops internally
    EYEBROW_DECOMPOSER_ONNX_MODEL_NAME,  # output file
    export_params=True,                  # keep the trained weights inside the model file
    opset_version=16,
    do_constant_folding=True,
    input_names=EYEBROW_DECOMPOSER_INPUT_LIST,
    output_names=EYEBROW_DECOMPOSER_OUTPUT_LIST,
)
onnx_model = onnx.load(EYEBROW_DECOMPOSER_ONNX_MODEL_NAME)
onnx.checker.check_model(onnx_model)
onnx_model_sim, check = simplify(onnx_model)
# Only write the simplified file when onnxsim's own check passes.
if not check:
    print("Simplify error!")
else:
    onnx.save(onnx_model_sim, EYEBROW_DECOMPOSER_SIMPLIFIED_ONNX_MODEL_NAME)



In [8]:
# Verify correctness compare to pytorch
# The exported wrapper takes the full image and crops the eyebrow window itself.
pt_img_np = pt_img.cpu().detach().numpy()
# `(pt_img_np,)` must be a 1-tuple. `(pt_img)` is just the tensor, so zip() would iterate over
# its batch dimension and map 'input_image' to a (4, 512, 512) slice instead of the batch.
input_dict = dict(zip(EYEBROW_DECOMPOSER_INPUT_LIST, (pt_img_np,)))
ort_sess = ort.InferenceSession(EYEBROW_DECOMPOSER_SIMPLIFIED_ONNX_MODEL_NAME, sess_options=sess_options, providers=providers)
onnx_sim_output = ort_sess.run(None, input_dict)
# ONNX outputs follow EYEBROW_DECOMPOSER_OUTPUT_LIST; the wrapper returns decomposer outputs 0 and 3.
print("MSE is: ",((onnx_sim_output[0] - eyebrow_decomposer_torch_res[0].cpu().detach().numpy()) ** 2).mean())
print("MSE is: ",((onnx_sim_output[1] - eyebrow_decomposer_torch_res[3].cpu().detach().numpy()) ** 2).mean())

MSE is:  1.0019383e-16
MSE is:  4.0434206e-16


In [9]:
# Print a per-node table (MACs, memory, params, in/out shapes) for the simplified decomposer.
onnx_tool.model_profile(EYEBROW_DECOMPOSER_SIMPLIFIED_ONNX_MODEL_NAME, None, None)

Name                                                                                          Type                   Forward_MACs    FPercent    Memory       MPercent    Params     PPercent    InShape        OutShape
--------------------------------------------------------------------------------------------  ---------------------  --------------  ----------  -----------  ----------  ---------  ----------  -------------  -------------
Slice_0                                                                                       Slice                  0               0.00%       262,208      0.26%       8          0.00%       1x4x512x512    1x4x128x128
/eyebrow_decomposer/body/downsample_blocks.0/downsample_blocks.0.0/Conv                       Conv                   589,824         0.04%       262,288      0.26%       36         0.00%       1x4x128x128    1x4x128x128
/eyebrow_decomposer/body/downsample_blocks.0/downsample_blocks.0.1/Conv                       Conv                   4,19

In [10]:
# Small benchmark of the ONNX Runtime decomposer session.
# NOTE: `providers` requests CUDAExecutionProvider, so this times GPU inference, not CPU.
t1 = time()
for i in tqdm(range(100)):
    ort_sess.run(None, {'input_image':pt_img_np,})
print(time()-t1)
# The decomposer does not need to run on every iteration, so there is no need to quantize it.

100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 119.70it/s]

0.838186502456665





In [11]:
# Start working on the eyebrow morphing combiner

EYEBROW_MORPHING_COMBINER_INPUT_LIST = ['input_image', "eyebrow_layer", 'eyebrow_background_layer', 'eyebrow_pose']
EYEBROW_MORPHING_COMBINER_OUTPUT_LIST = ['eyebrow_image']
EYEBROW_POSE_SHAPE = (1, 12)
EYEBROW_MORPHING_COMBINER_ONNX_MODEL_NAME = "eyebrow_morphing_combiner.onnx"
EYEBROW_MORPHING_COMBINER_SIMPLIFIED_ONNX_MODEL_NAME = "eyebrow_morphing_combiner_sim.onnx"

# Reference run of the bare combiner with a zero eyebrow pose. The argument order is
# (background_layer, eyebrow_layer, pose), the same order EyebrowMorphingCombinerWrapper uses.
eyebrow_pose_zero = torch.zeros(EYEBROW_POSE_SHAPE, dtype=dtype, device=device)
eyebrow_layer_ref = eyebrow_decomposer_wrapped_torch_res[0]
eyebrow_background_ref = eyebrow_decomposer_wrapped_torch_res[1]
eyebrow_morphing_combiner_torch_res = eyebrow_morphing_combiner(eyebrow_background_ref, eyebrow_layer_ref, eyebrow_pose_zero)

#Build a new eyebrow_morphing_combiner that does cropping
class EyebrowMorphingCombinerWrapper(Module):
    """Run the eyebrow morphing combiner and paste its result into a 192x192 face crop.

    The crop is rows 32..223 and columns 160..351 of ``full_image`` (cloned, so the input is
    left untouched). The combiner's output at index 2 replaces the inner 128x128 window at
    offset (32, 32) of that crop.
    """

    def __init__(self, eyebrow_morphing_combiner_obj):
        super().__init__()
        self.eyebrow_morphing_combiner = eyebrow_morphing_combiner_obj

    def forward(self, full_image: Tensor, eyebrow_layer: Tensor, background_layer: Tensor, pose: Tensor, *args) -> Tensor:
        face_top, face_left, face_size = 32, 32 + 128, 192
        face_crop = full_image[:, :, face_top:face_top + face_size, face_left:face_left + face_size].clone()
        # The combiner takes (background, eyebrow, pose); only its third output is used.
        combined = self.eyebrow_morphing_combiner(background_layer, eyebrow_layer, pose)
        face_crop[:, :, 32:32 + 128, 32:32 + 128] = combined[2]
        return face_crop
eyebrow_morphing_combiner_wrapped = EyebrowMorphingCombinerWrapper(eyebrow_morphing_combiner)
eyebrow_morphing_combiner_wrapped_torch_res = eyebrow_morphing_combiner_wrapped(pt_img, eyebrow_decomposer_wrapped_torch_res[0], 
                                                                                eyebrow_decomposer_wrapped_torch_res[1], eyebrow_pose_zero)


# Inputs in EYEBROW_MORPHING_COMBINER_INPUT_LIST order: full image, eyebrow layer, background layer, pose.
input_tuple = (pt_img, eyebrow_decomposer_wrapped_torch_res[0], eyebrow_decomposer_wrapped_torch_res[1], eyebrow_pose_zero)
input_dict = {k:v.cpu().detach().numpy() for k,v in zip(EYEBROW_MORPHING_COMBINER_INPUT_LIST,input_tuple)}
torch.onnx.export(eyebrow_morphing_combiner_wrapped,               # model being run
                  input_tuple,                         # model input (or a tuple for multiple inputs)
                  EYEBROW_MORPHING_COMBINER_ONNX_MODEL_NAME,   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=16,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = EYEBROW_MORPHING_COMBINER_INPUT_LIST,   # the model's input names
                  output_names = EYEBROW_MORPHING_COMBINER_OUTPUT_LIST) 
onnx_model = onnx.load(EYEBROW_MORPHING_COMBINER_ONNX_MODEL_NAME)
onnx.checker.check_model(onnx_model)
eyebrow_morphing_combiner_onnx_model_sim, check = simplify(onnx_model)
if check:
    onnx.save(eyebrow_morphing_combiner_onnx_model_sim,EYEBROW_MORPHING_COMBINER_SIMPLIFIED_ONNX_MODEL_NAME)
else:
    print("Simplify error!")

# Verify correctness compare to pytorch
ort_sess_sim = ort.InferenceSession(EYEBROW_MORPHING_COMBINER_SIMPLIFIED_ONNX_MODEL_NAME, sess_options=sess_options, providers=providers)
onnx_sim_output = ort_sess_sim.run(EYEBROW_MORPHING_COMBINER_OUTPUT_LIST,input_dict)
# The wrapper returns a single tensor, not a list, so wrap it before pairing it with the ONNX
# outputs. Indexing the tensor with [i] would slice its batch dimension instead, and the MSE
# would only come out right through broadcasting.
torch_reference_outputs = [eyebrow_morphing_combiner_wrapped_torch_res]
for onnx_out, torch_out in zip(onnx_sim_output, torch_reference_outputs):
    print("MSE is: ",((onnx_out - torch_out.cpu().detach().numpy()) ** 2).mean())

  if n == self.last_n and device == self.last_device:


MSE is:  2.068599e-12


In [12]:
# Print a per-node table (MACs, memory, params, in/out shapes) for the simplified combiner.
onnx_tool.model_profile(EYEBROW_MORPHING_COMBINER_SIMPLIFIED_ONNX_MODEL_NAME, None, None)

Name                                                                                                     Type                   Forward_MACs    FPercent    Memory       MPercent    Params     PPercent    InShape        OutShape
-------------------------------------------------------------------------------------------------------  ---------------------  --------------  ----------  -----------  ----------  ---------  ----------  -------------  -------------
Slice_0                                                                                                  Slice                  0               0.00%       589,888      0.55%       8          0.00%       1x4x512x512    1x4x192x192
/eyebrow_morphing_combiner/Concat                                                                        Concat                 0               0.00%       524,288      0.49%       0          0.00%       1x4x128x128    1x8x128x128
/eyebrow_morphing_combiner/body/Reshape                                      

In [13]:
# Small benchmark of the ONNX Runtime combiner session.
# NOTE: `providers` requests CUDAExecutionProvider, so this times GPU inference, not CPU.
t1 = time()
for i in tqdm(range(100)):
    ort_sess_sim.run(EYEBROW_MORPHING_COMBINER_OUTPUT_LIST,input_dict)
print(time()-t1)
# The eyebrow morpher does not need to run on every iteration, so there is no need to quantize it.

100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 110.23it/s]

0.9091794490814209





In [14]:
# Face morpher: reference run of the bare module with a zero face pose.
FACE_POSE_SHAPE = (1,27)
face_pose_zero = torch.zeros(FACE_POSE_SHAPE, dtype=dtype, device=device)
# Input is the 192x192 face crop produced by the eyebrow combiner wrapper.
face_morpher_torch_res = face_morpher(eyebrow_morphing_combiner_wrapped_torch_res, face_pose_zero)

class FaceMorpherWrapped(Module):
    """Run the face morpher on the face crop and paste the result back into the full image.

    Returns ``[full, half]``. ``full`` is a clone of ``input_image`` with the morpher's first
    output written at rows 32..223 / columns 160..351. ``half`` is ``full`` bilinearly
    resized to 256x256.
    """

    def __init__(self, face_morpher_obj):
        super().__init__()
        self.face_morpher = face_morpher_obj

    def forward(self, input_image: Tensor, im_morpher_crop: Tensor, face_pose: Tensor, *args) -> List[Tensor]:
        rows = slice(32, 32 + 192)
        cols = slice(32 + 128, 32 + 192 + 128)
        morphed_full = input_image.clone()
        morphed_full[:, :, rows, cols] = self.face_morpher(im_morpher_crop, face_pose)[0]
        morphed_half = interpolate(morphed_full, size=(256, 256), mode='bilinear', align_corners=False)
        return [morphed_full, morphed_half]
face_morpher_wrapped = FaceMorpherWrapped(face_morpher)
# Reference outputs [face_morphed_full, face_morphed_half], used to verify the ONNX export below.
face_morpher_wrapped_torch_res = face_morpher_wrapped(pt_img, eyebrow_morphing_combiner_wrapped_torch_res, face_pose_zero)

FACE_MORPHER_ONNX_MODEL_NAME = 'face_morpher.onnx'
FACE_MORPHER_SIM_ONNX_MODEL_NAME = 'face_morpher_sim.onnx'

FACE_MORPHER_OUTPUT_LIST = ['face_morphed_full', 'face_morphed_half']
FACE_MORPHER_INPUT_LIST = ['image_input', 'im_morpher_crop', 'face_pose']
input_tuple = (pt_img, eyebrow_morphing_combiner_wrapped_torch_res, face_pose_zero)
input_dict = {k:v.cpu().detach().numpy() for k,v in zip(FACE_MORPHER_INPUT_LIST,input_tuple)}

torch.onnx.export(face_morpher_wrapped,               # model being run
                  input_tuple,                         # model input (or a tuple for multiple inputs)
                  FACE_MORPHER_ONNX_MODEL_NAME,   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=16,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = FACE_MORPHER_INPUT_LIST,   # the model's input names
                  output_names = FACE_MORPHER_OUTPUT_LIST) 

onnx_model = onnx.load(FACE_MORPHER_ONNX_MODEL_NAME)
onnx.checker.check_model(onnx_model)
onnx_model_sim, check = simplify(onnx_model)
if check:
    onnx.save(onnx_model_sim,FACE_MORPHER_SIM_ONNX_MODEL_NAME)
else:
    print("Simplify error!")

ort_sess_sim = ort.InferenceSession(FACE_MORPHER_SIM_ONNX_MODEL_NAME, sess_options=sess_options, providers=providers)
onnx_sim_output = ort_sess_sim.run(None,input_dict)
# Independent reference for the full-resolution output, built from the bare face morpher result.
face_morphed_full = pt_img.clone()
face_morphed_full[:, :, 32:32 + 192, 32 + 128:32 + 192 + 128] = face_morpher_torch_res[0]
print("MSE is: ",((onnx_sim_output[0] - face_morphed_full.cpu().detach().numpy()) ** 2).mean())
# Also verify the 256x256 output, which is the input to the rotator.
print("MSE is: ",((onnx_sim_output[1] - face_morpher_wrapped_torch_res[1].cpu().detach().numpy()) ** 2).mean())

MSE is:  5.0948877e-11


In [15]:
# Try to split out the encoder part of the morpher model.
# Tensor at which the morpher graph is cut: the output of downsample_blocks.3.
FACE_MORPHER_ENCODER_OUTPUT = '/face_morpher/body/downsample_blocks.3/downsample_blocks.3.3/Relu_output_0'

# Encoder sub-graph: im_morpher_crop -> feature tensor.
FACE_MORPHER_ENCODER = 'face_morpher_sim_encoder.onnx'
onnx.utils.extract_model(FACE_MORPHER_SIM_ONNX_MODEL_NAME, FACE_MORPHER_ENCODER,
                         ['im_morpher_crop'],
                         [FACE_MORPHER_ENCODER_OUTPUT])
onnx.checker.check_model(onnx.load(FACE_MORPHER_ENCODER))

# Remaining graph: takes the precomputed feature tensor as an extra graph input.
FACE_MORPHER_NEW = 'face_morpher_new.onnx'
onnx.utils.extract_model(FACE_MORPHER_SIM_ONNX_MODEL_NAME, FACE_MORPHER_NEW,
                         ['im_morpher_crop', 'image_input', 'face_pose', FACE_MORPHER_ENCODER_OUTPUT],
                         ['face_morphed_full', 'face_morphed_half'])
onnx.checker.check_model(onnx.load(FACE_MORPHER_NEW))

In [16]:
# Fuse the face-morpher encoder into the eyebrow combiner model, so a single session yields both
# the combiner image and the encoder features.
EYEBROW_COMBINER_NEW = 'eyebrow_morphing_combiner_new.onnx'
eyebrow_combiner_model =  onnx.load(EYEBROW_MORPHING_COMBINER_SIMPLIFIED_ONNX_MODEL_NAME)
face_morpher_encoder_model = onnx.load(FACE_MORPHER_ENCODER)

# Connect the combiner's "eyebrow_image" output to the encoder's "im_morpher_crop" input.
eyebrow_combiner_new_model = onnx.compose.merge_models(
    eyebrow_combiner_model, face_morpher_encoder_model,
    io_map=[("eyebrow_image", "im_morpher_crop")]
)
# Intermediate file; it is only read back by extract_model on the next line.
onnx.save(eyebrow_combiner_new_model, "temp.onnx")
# Re-extract with both the combiner image and the encoder features as graph outputs.
onnx.utils.extract_model("temp.onnx", EYEBROW_COMBINER_NEW, ['input_image', 'eyebrow_background_layer', 'eyebrow_layer', 'eyebrow_pose'], 
                         ['eyebrow_image', '/face_morpher/body/downsample_blocks.3/downsample_blocks.3.3/Relu_output_0'])
onnx.checker.check_model(onnx.load(EYEBROW_COMBINER_NEW))

In [17]:
# Per-node profile of the trimmed face morpher (FACE_MORPHER_NEW == 'face_morpher_new.onnx').
onnx_tool.model_profile(FACE_MORPHER_NEW, None, None)

Name                                                                                    Type                   Forward_MACs    FPercent    Memory       MPercent    Params     PPercent    InShape        OutShape
--------------------------------------------------------------------------------------  ---------------------  --------------  ----------  -----------  ----------  ---------  ----------  -------------  -------------
/face_morpher/body/Reshape                                                              Reshape                0               0.00%       140          0.00%       4          0.00%       1x27           1x27x1x1
/face_morpher/body/Tile                                                                 Tile                   0               0.00%       62,240       0.04%       4          0.00%       1x27x1x1       1x27x24x24
/face_morpher/body/Concat                                                               Concat                 0               0.00%       1,241,856 

In [18]:
# Rotator: reference run of the bare module on the 256x256 face-morpher output, with a zero rotation pose.
ROTATION_POSE_SHAPE = (1,6)
rotation_pose_zero = torch.zeros(ROTATION_POSE_SHAPE, dtype=dtype, device=device)
two_algo_face_body_rotator_torch_res = two_algo_face_body_rotator(face_morpher_wrapped_torch_res[1], rotation_pose_zero)
class TwoAlgoFaceBodyRotatorWrapped(Module):
    """Run the face/body rotator and upsample two of its outputs to 512x512.

    Returns ``[full_warped_image, full_grid_change]``: the rotator's outputs at
    indices 1 and 2, each bilinearly resized to 512x512.
    """

    def __init__(self, two_algo_face_body_rotator_obj):
        super().__init__()
        self.two_algo_face_body_rotator = two_algo_face_body_rotator_obj

    def forward(self, image: Tensor, pose: Tensor, *args) -> List[Tensor]:
        rotator_outputs = self.two_algo_face_body_rotator(image, pose)
        return [interpolate(rotator_outputs[idx], size=(512, 512), mode='bilinear', align_corners=False)
                for idx in (1, 2)]
two_algo_face_body_rotator_wrapped = TwoAlgoFaceBodyRotatorWrapped(two_algo_face_body_rotator)
rotator_wrapped_torch_res = two_algo_face_body_rotator_wrapped(face_morpher_wrapped_torch_res[1], rotation_pose_zero)

ROTATOR_ONNX_MODEL_NAME = 'two_algo_face_body_rotator.onnx'
ROTATOR_SIM_ONNX_MODEL_NAME = 'two_algo_face_body_rotator_sim.onnx'

ROTATOR_OUTPUT_LIST = ['full_warped_image', 'full_grid_change']
ROTATOR_INPUT_LIST = ['face_morphed_half', 'rotation_pose']
# The rotator consumes the 256x256 output of the face morpher wrapper.
input_tuple = (face_morpher_wrapped_torch_res[1], rotation_pose_zero)
input_dict = {name: tensor.cpu().detach().numpy() for name, tensor in zip(ROTATOR_INPUT_LIST, input_tuple)}

torch.onnx.export(
    two_algo_face_body_rotator_wrapped,
    input_tuple,
    ROTATOR_ONNX_MODEL_NAME,
    export_params=True,        # keep the trained weights inside the model file
    opset_version=16,
    do_constant_folding=True,
    input_names=ROTATOR_INPUT_LIST,
    output_names=ROTATOR_OUTPUT_LIST,
)

onnx_model = onnx.load(ROTATOR_ONNX_MODEL_NAME)
onnx.checker.check_model(onnx_model)
onnx_model_sim, check = simplify(onnx_model)
if not check:
    print("Simplify error!")
else:
    onnx.save(onnx_model_sim, ROTATOR_SIM_ONNX_MODEL_NAME)

ort_sess_sim = ort.InferenceSession(ROTATOR_SIM_ONNX_MODEL_NAME, sess_options=sess_options, providers=providers)
onnx_sim_output = ort_sess_sim.run(None, input_dict)
# Reference: upsample the bare rotator's outputs 1 and 2 exactly as the wrapper does.
res1, res2 = (
    interpolate(two_algo_face_body_rotator_torch_res[idx], size=(512, 512), mode='bilinear', align_corners=False).cpu().detach().numpy()
    for idx in (1, 2)
)

for onnx_out, reference in ((onnx_sim_output[0], res1), (onnx_sim_output[1], res2)):
    print("MSE is: ", ((onnx_out - reference) ** 2).mean())
onnx.checker.check_model(onnx.load(ROTATOR_SIM_ONNX_MODEL_NAME))

MSE is:  1.2770033e-09
MSE is:  2.6083244e-10


In [19]:
# Print a per-node table (MACs, memory, params, in/out shapes) for the simplified rotator.
onnx_tool.model_profile(ROTATOR_SIM_ONNX_MODEL_NAME, None, None)

Name                                                                                                           Type                   Forward_MACs    FPercent    Memory       MPercent    Params     PPercent    InShape        OutShape
-------------------------------------------------------------------------------------------------------------  ---------------------  --------------  ----------  -----------  ----------  ---------  ----------  -------------  -------------
/two_algo_face_body_rotator/Reshape                                                                            Reshape                0               0.00%       56           0.00%       4          0.00%       1x6            1x6x1x1
/two_algo_face_body_rotator/Tile                                                                               Tile                   0               0.00%       1,572,896    0.36%       4          0.00%       1x6x1x1        1x6x256x256
/two_algo_face_body_rotator/Concat                       

In [20]:
# Editor: reference run of the bare module on the full-resolution morphed image, the two
# upsampled rotator outputs, and the zero rotation pose.
editor_torch_res = editor(face_morpher_wrapped_torch_res[0], 
                          rotator_wrapped_torch_res[0], 
                          rotator_wrapped_torch_res[1], 
                          rotation_pose_zero)
class EditorWrapped(Module):
    """Run the editor and turn its first output into a channel-last image in [0, 255].

    The editor's first output must hold a single image (batch size 1). It is converted from
    (1, C, H, W) to (H, W, C) and mapped by ``(x / 2 + 0.5) * 255``, which sends [-1, 1] to
    [0, 255], and the result is clipped to [0, 255]. Works for any H, W, C; the original
    pipeline uses 512x512x4.
    """

    def __init__(self, editor_obj):
        super().__init__()
        self.editor = editor_obj

    def forward(self,
                morphed_image: Tensor,
                rotated_warped_image: Tensor,
                rotated_grid_change: Tensor,
                pose: Tensor,
                *args) -> Tensor:
        res = self.editor(morphed_image, rotated_warped_image, rotated_grid_change, pose)[0]
        _, channels, height, width = res.shape
        # Reshaping to (C, H, W) drops the batch dimension and fails loudly if it is not 1.
        image = res.reshape(channels, height, width).permute(1, 2, 0)
        return ((image / 2.0 + 0.5) * 255).clip(0.0, 255.0)
editor_wrapped = EditorWrapped(editor)
# Editor inputs: full-resolution morphed image, the two upsampled rotator outputs, rotation pose.
input_tuple = (face_morpher_wrapped_torch_res[0],
               rotator_wrapped_torch_res[0],
               rotator_wrapped_torch_res[1],
               rotation_pose_zero)
editor_wrapped_torch_res = editor_wrapped(*input_tuple)

EDITOR_ONNX_MODEL_NAME = 'editor.onnx'
EDITOR_SIM_ONNX_MODEL_NAME = 'editor_sim.onnx'

EDITOR_OUTPUT_LIST = ['result']
EDITOR_INPUT_LIST = ['morphed_image', 'rotated_warped_image', 'rotated_grid_change', 'rotation_pose']
input_dict = {name: tensor.cpu().detach().numpy() for name, tensor in zip(EDITOR_INPUT_LIST, input_tuple)}

torch.onnx.export(
    editor_wrapped,
    input_tuple,
    EDITOR_ONNX_MODEL_NAME,
    export_params=True,        # keep the trained weights inside the model file
    opset_version=16,
    do_constant_folding=True,
    input_names=EDITOR_INPUT_LIST,
    output_names=EDITOR_OUTPUT_LIST,
)

onnx_model = onnx.load(EDITOR_ONNX_MODEL_NAME)
onnx.checker.check_model(onnx_model)
onnx_model_sim, check = simplify(onnx_model)
if not check:
    print("Simplify error!")
else:
    onnx.save(onnx_model_sim, EDITOR_SIM_ONNX_MODEL_NAME)

ort_sess_sim = ort.InferenceSession(EDITOR_SIM_ONNX_MODEL_NAME, sess_options=sess_options, providers=providers)
onnx_sim_output = ort_sess_sim.run(None, input_dict)

In [21]:
# Compare the editor's ONNX result with the full PyTorch poser output, converted the same way as
# EditorWrapped: channel-last (512, 512, 4) with values mapped to [0, 255].
poser_res = poser_torch_res[0].reshape(4, 512 * 512).transpose(0,1).reshape(512,512,4)
poser_res = ((poser_res/2.0 + 0.5)*255).clip(0.0, 255.0).cpu().detach().numpy()
# `run` returns a list of outputs. Compare against the single 'result' array explicitly instead of
# relying on NumPy to turn the list into a (1, 512, 512, 4) array and broadcast it.
print("MSE is: ",((onnx_sim_output[0] - poser_res) ** 2).mean())

MSE is:  8.392428e-06


In [22]:
# So far, the model sequence is:
[
    EYEBROW_DECOMPOSER_SIMPLIFIED_ONNX_MODEL_NAME,  # eyebrow_decomposer_sim.onnx
    EYEBROW_COMBINER_NEW,                           # eyebrow_morphing_combiner_new.onnx
    FACE_MORPHER_NEW,                               # face_morpher_new.onnx
    ROTATOR_SIM_ONNX_MODEL_NAME,                    # two_algo_face_body_rotator_sim.onnx
    EDITOR_SIM_ONNX_MODEL_NAME,                     # editor_sim.onnx
]

['eyebrow_decomposer_sim.onnx',
 'eyebrow_morphing_combiner_new.onnx',
 'face_morpher_new.onnx',
 'two_algo_face_body_rotator_sim.onnx',
 'editor_sim.onnx']

In [23]:
# Inspect the graph inputs of the trimmed face morpher. It also takes the encoder feature
# tensor as an input, which RunTest feeds from the combiner model's second output.
model = onnx.load(FACE_MORPHER_NEW)
model.graph.input

[name: "im_morpher_crop"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 4
      }
      dim {
        dim_value: 192
      }
      dim {
        dim_value: 192
      }
    }
  }
}
, name: "image_input"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 4
      }
      dim {
        dim_value: 512
      }
      dim {
        dim_value: 512
      }
    }
  }
}
, name: "face_pose"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 27
      }
    }
  }
}
, name: "/face_morpher/body/downsample_blocks.3/downsample_blocks.3.3/Relu_output_0"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 512
      }
      dim {
        dim_value: 24
      }
      dim {
        dim_value: 24
      }
    }
  }
}
]

In [24]:
class RunTest():
    """Chain the five exported ONNX models on a random image, stage by stage.

    Every stage gets a zero pose. ``runAll`` returns the editor's output array.
    """
    def __init__(self):
        self.decomposer_sess = ort.InferenceSession("eyebrow_decomposer_sim.onnx", sess_options=sess_options, providers=providers)
        self.combiner_sess = ort.InferenceSession("eyebrow_morphing_combiner_new.onnx", sess_options=sess_options, providers=providers)
        self.morpher_sess = ort.InferenceSession("face_morpher_new.onnx", sess_options=sess_options, providers=providers)
        self.rotator_sess = ort.InferenceSession("two_algo_face_body_rotator_sim.onnx", sess_options=sess_options, providers=providers)
        self.editor_sess = ort.InferenceSession("editor_sim.onnx", sess_options=sess_options, providers=providers)
        # Random input image with values in [-1, 1).
        self.img = np.random.rand(1, 4, 512, 512).astype(np.float32) * 2.0 - 1.0
        self.eyebrow_pose_zero = np.zeros((1,12), dtype=np.float32)
        self.face_pose_zero = np.zeros((1,27), dtype=np.float32)
        self.rotation_pose_zero = np.zeros((1,6), dtype=np.float32)
    def runAll(self):
        decomposer_res = self.decomposer_sess.run(None, {'input_image':self.img,})
        # Decomposer outputs are ordered ["eyebrow_layer", "background_layer_1"]
        # (EYEBROW_DECOMPOSER_OUTPUT_LIST): index 0 is the eyebrow layer, index 1 the background.
        eyebrow_layer, eyebrow_background_layer = decomposer_res[0], decomposer_res[1]
        combiner_res = self.combiner_sess.run(None, {'input_image':self.img,
                                                     'eyebrow_background_layer': eyebrow_background_layer,
                                                     "eyebrow_layer": eyebrow_layer,
                                                     'eyebrow_pose':self.eyebrow_pose_zero,})
        # combiner_res = [eyebrow_image (the face crop), face-morpher encoder features]
        morpher_res = self.morpher_sess.run(None, {'image_input':self.img,
                                                   'im_morpher_crop': combiner_res[0], 
                                                   'face_pose': self.face_pose_zero,
                                                   '/face_morpher/body/downsample_blocks.3/downsample_blocks.3.3/Relu_output_0':combiner_res[1]})
        # morpher_res = [face_morphed_full, face_morphed_half]; the rotator takes the half-size image.
        rotator_res = self.rotator_sess.run(None, {'face_morphed_half':morpher_res[1], 
                                                   'rotation_pose':self.rotation_pose_zero})
        editor_res = self.editor_sess.run(None, {'morphed_image':morpher_res[0],
                                                 'rotated_warped_image':rotator_res[0],
                                                 'rotated_grid_change': rotator_res[1], 
                                                 'rotation_pose':self.rotation_pose_zero})
        return editor_res[0]
# Assign the result so the cell does not dump a 512x512x4 array as its output.
pipeline_result = RunTest().runAll()