In [1]:
import numpy as np
import onnxruntime as ort
import time
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static
import cv2
import os
import matplotlib.pyplot as plt
from onnxruntime.quantization import CalibrationDataReader
from tqdm import tqdm

In [2]:
def _preprocess_images(images_folder: str):
    providers = ['DmlExecutionProvider']
    options = ort.SessionOptions()
    options.enable_mem_pattern = False
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    decomposer_sess = ort.InferenceSession('./preprocessed/decomposer.onnx', sess_options=options, providers=providers)
    batch_filenames = os.listdir(images_folder)
    batch_data = []
    for image_name in tqdm(batch_filenames):
        image_filepath = os.path.join(images_folder, image_name)
        img = cv2.imread(image_filepath, cv2.IMREAD_UNCHANGED)
        if img is None:
            continue
        img = cv2.resize(img, (450,900),interpolation = cv2.INTER_LANCZOS4)
        padding_img = np.zeros((1024,1024,4),np.uint8)
        padding_img[40:40+900,287:287+450,:] = img
        img = cv2.resize(padding_img, (512,512),interpolation = cv2.INTER_LANCZOS4)

        decomposer_res = decomposer_sess.run(None, {'input_image':img})
        for i in range(3): #Repeat 3 times
            eyebrow_pose = np.zeros((1,12), np.float32)
            rand_idx = int(np.random.randint(low = 0, high = 5))
            eyebrow_pose[0, 2 * rand_idx] = np.random.default_rng().uniform(0.0,1.0)
            eyebrow_pose[0, 2 * rand_idx + 1] = np.random.default_rng().uniform(0.0,1.0)
            batch_data.append({'image_prepared':decomposer_res[2], 'eyebrow_background_layer':decomposer_res[0], 'eyebrow_layer':decomposer_res[1], 'eyebrow_pose':eyebrow_pose})
        
    return batch_data

In [3]:
class CombinerDataReader(CalibrationDataReader):
    def __init__(self, calibration_image_folder: str):
        self.enum_data = None

        # Convert image to input data
        self.data_list = _preprocess_images(
            calibration_image_folder
        )
        self.datasize = len(self.data_list)

    def get_next(self):
        if self.enum_data is None:
            self.enum_data = iter(self.data_list)
        return next(self.enum_data, None)

    def rewind(self):
        self.enum_data = None

In [4]:
dr = CombinerDataReader('Z:/ComfyUI-aki-v1.5/output/')

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 298/298 [00:23<00:00, 12.88it/s]


In [5]:
quantize_static(
        './preprocessed/combiner.onnx',
        './quantized/combiner.onnx',
        dr,
        quant_format=QuantFormat.QDQ,
        per_channel=True,
        weight_type=QuantType.QInt8,
        nodes_to_exclude = ['/eyebrow_morphing_combiner/Sub', '/eyebrow_morphing_combiner/Add_2', '/eyebrow_morphing_combiner/Sub_1'],
        extra_options = {
            'ActivationSymmetric':True,
            'QuantizeBias':False
        }
    )



['Reshape', 'Transpose', 'Concat', 'Resize', 'Gather', 'Slice', 'AveragePool', 'Unsqueeze', 'GlobalAveragePool', 'BatchNormalization', 'Add', 'Pad', 'Conv', 'ArgMax', 'LeakyRelu', 'Gemm', 'InstanceNormalization', 'Softmax', 'Clip', 'GatherElements', 'Split', 'MaxPool', 'Relu', 'MatMul', 'EmbedLayerNormalization', 'LayerNormalization', 'Sigmoid', 'ConvTranspose', 'Squeeze', 'Where', 'Mul']
 16
 16
com.microsoft.nchwc 1
ai.onnx.ml 5
ai.onnx.training 1
ai.onnx.preview.training 1
com.microsoft 1
com.microsoft.experimental 1
org.pytorch.aten 1
com.microsoft.dml 1
[domain: ""
version: 16
, domain: ""
version: 16
]


In [6]:
non_quantized_session =ort.InferenceSession('./preprocessed/combiner.onnx', None)
quantized_session =ort.InferenceSession('./quantized/combiner.onnx', None)

In [7]:
dr.rewind()
inp = dr.get_next()
non_res =  non_quantized_session.run(None, inp)
qt_res = quantized_session.run(None, inp)
((qt_res[0] -non_res[0])**2).mean()

np.float32(0.00036856649)