In [30]:
import torch
import numpy
import PIL.Image
from tha3.util import resize_PIL_image,extract_PIL_image_from_filelike, extract_pytorch_image_from_PIL_image
from tqdm import tqdm
import onnx
from onnxsim import simplify
import onnxruntime as ort
import numpy as np

MODEL_NAME = "separable_float"
HALF = False
DEVICE_NAME = 'cuda:0'
IMAGE_INPUT = "data\images\crypko_03.png"
USE_RANDOM_IMAGE = False

EYEBROW_DECOMPOSER_ONNX_MODEL_NAME = "eyebrow_decomposer.onnx"
EYEBROW_DECOMPOSER_SIMPLIFIED_ONNX_MODEL_NAME = "eyebrow_decomposer_sim.onnx"
EYEBROW_DECOMPOSER_INPUT_LIST = ['input']
EYEBROW_DECOMPOSER_OUTPUT_LIST = ["eyebrow_layer",  # 0
                                  "eyebrow_layer_alpha",  # 1
                                  "eyebrow_layer_color_change",  # 2
                                  "background_layer_1",  # 3
                                  "background_layer_alpha",  # 4
                                  "background_layer_color_change"]  # 5

EYEBROW_MORPHING_COMBINER_INPUT_LIST = ['eyebrow_background_layer', "eyebrow_layer", 'eyebrow_pose']
EYEBROW_MORPHING_COMBINER_OUTPUT_LIST = ['eyebrow_image',  # 0
                                        'combine_alpha',  # 1
                                        'eyebrow_image_no_combine_alpha',  # 2
                                        'morphed_eyebrow_layer',  # 3
                                        'morphed_eyebrow_layer_alpha',  # 4
                                        'morphed_eyebrow_layer_color_change',  # 5
                                        'warped_eyebrow_layer',  # 6
                                        'morphed_eyebrow_layer_grid_change']  # 7
EYEBROW_POSE_SHAPE = (1, 12)
EYEBROW_MORPHING_COMBINER_ONNX_MODEL_NAME = "eyebrow_morphing_combiner.onnx"
EYEBROW_MORPHING_COMBINER_SIMPLIFIED_ONNX_MODEL_NAME = "eyebrow_morphing_combiner_sim.onnx"

device = torch.device(DEVICE_NAME)
dtype = torch.float16 if HALF else torch.float32

In [2]:
#Prepare models
def load_poser(model: str, device: torch.device):
    print("Using the %s model." % model)
    if model == "standard_float":
        from tha3.poser.modes.standard_float import create_poser
        return create_poser(device)
    elif model == "standard_half":
        from tha3.poser.modes.standard_half import create_poser
        return create_poser(device)
    elif model == "separable_float":
        from tha3.poser.modes.separable_float import create_poser
        return create_poser(device)
    elif model == "separable_half":
        from tha3.poser.modes.separable_half import create_poser
        return create_poser(device)
    else:
        raise RuntimeError("Invalid model: '%s'" % model)
        
poser = load_poser(MODEL_NAME, DEVICE_NAME)
pose_size = poser.get_num_parameters()

eyebrow_decomposer = poser.get_modules()['eyebrow_decomposer']
eyebrow_morphing_combiner = poser.get_modules()['eyebrow_morphing_combiner']
face_morpher = poser.get_modules()['face_morpher']
two_algo_face_body_rotator = poser.get_modules()['two_algo_face_body_rotator']
editor = poser.get_modules()['editor']

Using the separable_float model.
Loading the eyebrow decomposer ... DONE!!!


  return torch.load(f)


Loading the eyebrow morphing conbiner ... DONE!!!
Loading the face morpher ... DONE!!!
Loading the face-body rotator ... DONE!!!
Loading the combiner ... DONE!!!


In [3]:
#Prepare one pass inference image data
pt_img = None
if USE_RANDOM_IMAGE:
    pt_img = torch.rand(1, 4, 512, 512,dtype=dtype, device=device) * 2.0 - 1.0
else:
    pil_image = resize_PIL_image(extract_PIL_image_from_filelike(IMAGE_INPUT), size=(512,512))
    
    if HALF:
        pt_img = extract_pytorch_image_from_PIL_image(pil_image).half().reshape(1,4,512,512).to(DEVICE_NAME)
    else:
        pt_img = extract_pytorch_image_from_PIL_image(pil_image).reshape(1,4,512,512).to(DEVICE_NAME)
zero_pose = torch.zeros(1, pose_size, dtype=dtype, device=device)

In [4]:
#Small bench for whole pose
from time import time
t1 = time()
for i in tqdm(range(10)):
    poser.pose(pt_img, zero_pose)
print(time() - t1)

100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 11.20it/s]

0.8980967998504639





In [5]:
# Work on eyebrow decomposer export and quantize
EYEBROW_DECOMPOSER_INPUT_SHAPE = (1,4,128,128)
EYEBROW_DECOMPOSER_INPUT_IMAGE_SIZE = 128
eyebrow_decomposer_input_img = pt_img[:,:, 64:192, 64 + 128:192 + 128].reshape(EYEBROW_DECOMPOSER_INPUT_SHAPE)
eyebrow_decomposer_input_img_numpy = eyebrow_decomposer_input_img.cpu().numpy()
eyebrow_decomposer_torch_res = eyebrow_decomposer(eyebrow_decomposer_input_img) # Try one round inference to catch problem

In [7]:
#Export onnx model finally get a simplified decomposer onnx model
torch.onnx.export(eyebrow_decomposer,               # model being run
                  eyebrow_decomposer_input_img,                         # model input (or a tuple for multiple inputs)
                  EYEBROW_DECOMPOSER_ONNX_MODEL_NAME,   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=16,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = EYEBROW_DECOMPOSER_INPUT_LIST,   # the model's input names
                  output_names = EYEBROW_DECOMPOSER_OUTPUT_LIST) # the model's output names
onnx_model = onnx.load(EYEBROW_DECOMPOSER_ONNX_MODEL_NAME)
onnx.checker.check_model(onnx_model)
eyebrow_decomposer_onnx_model_sim, check = simplify(onnx_model)
if check:
    onnx.save(eyebrow_decomposer_onnx_model_sim,EYEBROW_DECOMPOSER_SIMPLIFIED_ONNX_MODEL_NAME)
else:
    print("Simplify error!")



In [15]:
# Verify correctness compare to pytorch
ort_sess_decomposer_sim = ort.InferenceSession(EYEBROW_DECOMPOSER_SIMPLIFIED_ONNX_MODEL_NAME)
onnx_sim_output = ort_sess_decomposer_sim.run(EYEBROW_DECOMPOSER_OUTPUT_LIST,{'input': eyebrow_decomposer_input_img_numpy})
for i in range(len(onnx_sim_output)):
    print("MSE is: ",((onnx_sim_output[i] - eyebrow_decomposer_torch_res[i].cpu().detach().numpy()) ** 2).mean())

MSE is:  1.4288554e-09
MSE is:  4.8480964e-10
MSE is:  9.468815e-10
MSE is:  1.0209479e-10
MSE is:  9.383485e-10
MSE is:  1.7238013e-07


In [17]:
# Small bench on cpu to check performance
t1 = time()
for i in tqdm(range(100)):
    ort_sess_decomposer_sim.run(EYEBROW_DECOMPOSER_OUTPUT_LIST,{'input': eyebrow_decomposer_input_img_numpy})
print(time()-t1)
#Since decomposer is not required for every run iteration, no need to quantize

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.41it/s]

4.907040357589722





In [26]:
# Start working on eyebrow morphine combiner
eyebrow_pose_zero = torch.zeros(EYEBROW_POSE_SHAPE, dtype=dtype, device=device)
eyebrow_morphing_combiner_torch_res = eyebrow_morphing_combiner(eyebrow_decomposer_torch_res[3], eyebrow_decomposer_torch_res[0], eyebrow_pose_zero)
input_tuple = (eyebrow_decomposer_torch_res[3], eyebrow_decomposer_torch_res[0], eyebrow_pose_zero)
input_dict = {k:v.cpu().detach().numpy() for k,v in zip(EYEBROW_MORPHING_COMBINER_INPUT_LIST,input_tuple)}
torch.onnx.export(eyebrow_morphing_combiner,               # model being run
                  input_tuple,                         # model input (or a tuple for multiple inputs)
                  EYEBROW_MORPHING_COMBINER_ONNX_MODEL_NAME,   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=16,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = EYEBROW_MORPHING_COMBINER_INPUT_LIST,   # the model's input names
                  output_names = EYEBROW_MORPHING_COMBINER_OUTPUT_LIST) 
onnx_model = onnx.load(EYEBROW_MORPHING_COMBINER_ONNX_MODEL_NAME)
onnx.checker.check_model(onnx_model)
eyebrow_morphing_combiner_onnx_model_sim, check = simplify(onnx_model)
if check:
    onnx.save(eyebrow_morphing_combiner_onnx_model_sim,EYEBROW_MORPHING_COMBINER_SIMPLIFIED_ONNX_MODEL_NAME)
else:
    print("Simplify error!")

# Verify correctness compare to pytorch
ort_sess_sim = ort.InferenceSession(EYEBROW_MORPHING_COMBINER_SIMPLIFIED_ONNX_MODEL_NAME)
onnx_sim_output = ort_sess_sim.run(EYEBROW_MORPHING_COMBINER_OUTPUT_LIST,input_dict)
for i in range(len(onnx_sim_output)):
    print("MSE is: ",((onnx_sim_output[i] - eyebrow_morphing_combiner_torch_res[i].cpu().detach().numpy()) ** 2).mean())

MSE is:  5.700757e-09
MSE is:  7.0313577e-09
MSE is:  2.3814898e-11
MSE is:  5.180266e-11
MSE is:  7.75215e-17
MSE is:  1.320137e-08
MSE is:  5.180364e-11
MSE is:  2.2449425e-09


In [27]:
# Small bench on cpu to check performance
t1 = time()
for i in tqdm(range(100)):
    ort_sess_sim.run(EYEBROW_MORPHING_COMBINER_OUTPUT_LIST,input_dict)
print(time()-t1)
#Since decomposer is not required for every run iteration, no need to quantize


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.84it/s]

4.804630517959595





In [None]:
#Verify quantization for combiner, this function could be called in runtime loop


In [37]:
from thop import profile
poser.training = False
poser.eval = lambda : True
poser.apply = lambda input: poser.pose(self,input)
profile(poser, inputs=(pt_img, zero_pose))

NameError: name 'self' is not defined