In [1]:
import os
import torch
import torch.nn as nn
from diffusers.schedulers import PNDMScheduler
from pathlib import Path
from diffusers import DiffusionPipeline
import onnx
import onnxruntime
from onnxruntime.quantization.quantize import quantize_dynamic
from onnxruntime.quantization import QuantType
from util_onnx import onnx_export
import utils

import gc
# Run a garbage-collection pass up front; presumably meant to release memory
# held over from earlier runs in the same kernel — TODO confirm it is still needed.
gc.collect()

0

In [2]:
# Notebook-wide configuration read by the cells below.
MODEL_ID = "runwayml/stable-diffusion-v1-5"
device = 'cpu'                      # device for the wrapped decoder and the dummy input
dtype = torch.float32               # single source of truth for the tensor/weight dtype
save_path = '../onnx_models_cuda'   # root directory for every exported artifact
os.makedirs(save_path, exist_ok=True)

# Load the full pipeline; the conversion cell only uses `pipeline.vae`.
# `dtype` is reused here so the weights and the dummy input cannot drift apart.
pipeline = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)

---
##### 0. Prepare input

In [3]:
# Dummy latent batch: serves as the example input for the ONNX export and as the
# ONNX Runtime smoke-test input further down.
# A private, seeded generator makes Restart & Run All reproduce the same tensor
# without touching the global RNG state.
# NOTE: `input` shadows the Python builtin; the name is kept because later cells read it.
latent_rng = torch.Generator().manual_seed(0)
input = torch.randn(1, 4, 64, 64, generator=latent_rng).to(device=device, dtype=dtype)
input.shape

torch.Size([1, 4, 64, 64])

--- 
##### 1. Define model

In [4]:
class Decoder(nn.Module):
    """Wrap a VAE decoder so one forward pass maps scaled latents to image tensors.

    The wrapper undoes the latent scaling, calls the wrapped decoder, rescales the
    decoded sample to [0, 1] and moves the channel axis to the last position.

    Args:
        vae_decoder: module that, when called with latents, returns a mapping with
            a ``'sample'`` entry. It is moved to the notebook-level ``device``.
        scaling_factor: factor whose inverse is applied to the latents before
            decoding. Defaults to 0.18215, the value previously hard-coded here.
    """

    def __init__(self, vae_decoder, scaling_factor=0.18215):
        super().__init__()
        self.vae_decoder = vae_decoder.to(device = device)
        self.scaling_factor = scaling_factor

    def denormalize(self, images):
        """Shift/scale ``images`` by ``x / 2 + 0.5`` and clamp the result to [0, 1]."""
        return (images / 2 + 0.5).clamp(0, 1)

    def forward(self, latent_sample):
        """Decode ``latent_sample`` and return a channels-last tensor in [0, 1].

        Args:
            latent_sample: latent tensor laid out as (batch, channels, H, W).

        Returns:
            The decoded sample after ``denormalize``, permuted to
            (batch, H, W, channels).
        """
        # 0. preprocessing: undo the latent scaling (kept as `1 / s * x` so the
        #    arithmetic matches the previous hard-coded expression exactly)
        latent_sample = 1 / self.scaling_factor * latent_sample
        # 1. run the wrapped decoder and pick its 'sample' output
        image = self.vae_decoder(latent_sample)['sample']
        # 2. postprocessing: single shared denormalize step, then NCHW -> NHWC
        return self.denormalize(image).permute(0, 2, 3, 1)

-----
#### 💛 Conversion

In [5]:
# Artifact locations. The `decoder/` sub-directory is created explicitly because
# the config cell only creates `save_path` itself.
decoder_dir = Path(save_path) / 'decoder'
decoder_dir.mkdir(parents=True, exist_ok=True)
origin_path = decoder_dir / 'de_origin.onnx'
quant_path = decoder_dir / 'de_quant.onnx'

# Point the VAE's forward() at decode() so that calling the module (as the
# Decoder wrapper does) runs decode().
vae_decoder = pipeline.vae
vae_decoder.forward = pipeline.vae.decode

onnx_export(
    Decoder(vae_decoder),
    model_args=(input,),  # trailing comma makes this a real 1-tuple of positional inputs
    output_path=origin_path,
    ordered_input_names=["input"],
    output_names=["output"],  # has to be different from "sample" for correct tracing
    dynamic_axes={
        "input": {0: "batch"}
    },
    opset=12,
    # NOTE(review): the previous comment here referred to the >2GB UNet; this cell
    # exports the VAE decoder. The flag is kept unchanged — confirm whether
    # external data is still wanted for this model.
    use_external_data_format=True,
)

# Dynamic (weight-only) quantization of the exported graph to unsigned 8-bit weights.
quantize_dynamic(
    model_input=str(origin_path),
    model_output=str(quant_path),
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)

ONNX export Start🚗




ONNX export Finish🍷


In [7]:
!du -sh ../onnx_models_cuda/decoder/**

189M	../onnx_models_cuda/decoder/de_origin.onnx
48M	../onnx_models_cuda/decoder/de_quant.onnx


---
#### 💚 ONNX-Runtime Test

In [8]:
import onnxruntime as ort

# Load the quantized decoder for a local smoke test.
# CPUExecutionProvider is requested explicitly: it is available in every
# onnxruntime build and matches `device = 'cpu'`. AzureExecutionProvider is meant
# for calling remote Azure endpoints and is not needed to run a local model file.
onnx_model_path = f'{save_path}/decoder/de_quant.onnx'
session_de = ort.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])

In [9]:
# Smoke test: run the quantized decoder once on the dummy latent.
# The feed dict is keyed by the input name given at export time ("input").
ort_inputs = {'input': input.cpu().detach().numpy()}
# NOTE: despite its name, `latents` holds the first model output, i.e. the decoded
# image batch; the name is kept in case later cells read it.
latents = session_de.run(None, ort_inputs)[0]
latents.shape

(1, 512, 512, 3)