In [1]:
import torch
import torch.nn as nn
from diffusers.schedulers import PNDMScheduler
from pathlib import Path
from diffusers import DiffusionPipeline
from util import onnx_export
from onnxruntime.quantization.quantize import quantize_dynamic
from onnxruntime.quantization import QuantType
import onnx 

import gc
# Reclaim unreferenced objects before the memory-heavy cells below; the cell
# displays the value returned by gc.collect() (0 in the recorded run).
gc.collect()

  from .autonotebook import tqdm as notebook_tqdm


0

In [5]:
device = "cpu"
dtype = torch.float32
MODEL_ID = "runwayml/stable-diffusion-v1-5"

# Load the weights in the same dtype (and on the same device) that the export cells
# use for their dummy tensors, so the traced UNet's weights and inputs agree.
pipeline = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
pipeline = pipeline.to(device)

File sizes before quantization (`UNet-raw-dummy/`):

- 3.3G    weights.pb
- 1.9M    model.onnx

After dynamic quantization:

- 824M    model-quant.onnx

In [6]:
# Sanity check: a pair of timesteps (value 981) has shape [2].
torch.full((2,), 981, dtype=dtype, device=device).shape

torch.Size([2])

In [7]:
class UnetModelCustom(nn.Module):
    """Wrap a UNet so that ``timestep`` is the only input of the traced/exported graph.

    ``sample`` and ``encoder_hidden_states`` are not graph inputs: fresh ``torch.randn``
    dummies of fixed shape are drawn on every forward call, so outputs are random.

    Args:
        unet: the UNet to wrap (``pipeline.unet`` in this notebook). It is called as
            ``unet(sample, timestep, encoder_hidden_states, return_dict=False)``.
        batch_size: leading dimension of both dummy tensors. The default of 2 matches
            the two halves split apart by ``UnetModel_post`` (presumably the
            unconditional / text-conditioned pair).
        latent_shape: per-item shape of the dummy ``sample``.
        text_embed_shape: per-item shape of the dummy ``encoder_hidden_states``.
        dtype: dtype of both dummy tensors (previously read from the notebook-global
            ``dtype``, which is ``torch.float32``).
    """

    def __init__(self, unet, batch_size=2, latent_shape=(4, 64, 64),
                 text_embed_shape=(77, 768), dtype=torch.float32):
        super().__init__()
        self.unet = unet
        self.batch_size = batch_size
        self.latent_shape = tuple(latent_shape)
        self.text_embed_shape = tuple(text_embed_shape)
        self.dummy_dtype = dtype

    def forward(self, timestep):
        """Run the wrapped UNet on random dummy inputs at ``timestep``.

        Returns:
            Whatever ``unet(..., return_dict=False)`` returns; for a diffusers UNet this
            is presumably a one-element tuple holding the noise prediction — confirm.
        """
        # Create the dummies on the timestep's device rather than the notebook-global
        # ``device`` so the module does not depend on hidden kernel state.
        device = timestep.device
        sample = torch.randn(self.batch_size, *self.latent_shape,
                             device=device, dtype=self.dummy_dtype)
        encoder_hidden_states = torch.randn(self.batch_size, *self.text_embed_shape,
                                            device=device, dtype=self.dummy_dtype)
        # return_dict=False yields a plain tuple, which maps directly onto output_names.
        return self.unet(sample, timestep, encoder_hidden_states, return_dict=False)
# Export the wrapped UNet with "timestep" as its only graph input.
onnx_export(
    UnetModelCustom(pipeline.unet),
    # model_args must be a tuple: without the trailing comma the parentheses are only
    # grouping and a bare tensor would be passed instead.
    model_args=(torch.randn(2).to(device=device, dtype=dtype),),
    output_path=Path('../onnx-models/UNet-raw-dummy/model.onnx'),
    ordered_input_names=["timestep"],
    output_names=["out_sample"],  # has to be different from "sample" for correct tracing
    dynamic_axes={
        # NOTE(review): the wrapper's internal dummy batch is fixed (2 by default), so
        # presumably only timestep lengths compatible with that batch work at runtime.
        "timestep": {0: "batch"},
    },
    opset=14,
    use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
)


ONNX export Start🚗


  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


ONNX export Finish🍷


In [8]:
import os
import shutil

# Re-save the raw UNet export so that all external tensors live in one "weights.pb".
unet_model_path = Path('../onnx-models/UNet-raw-dummy/model.onnx').absolute().as_posix()
unet_dir = os.path.dirname(unet_model_path)
staging_dir = unet_dir + '-collating'

# onnx.load also reads the external tensor files by default, so the whole model is in
# memory before anything on disk is touched.
unet = onnx.load(unet_model_path)

# Write the collated copy into a sibling staging directory first; the original export
# (the only copy on disk) is removed only after the new copy has been saved.
if os.path.exists(staging_dir):
    shutil.rmtree(staging_dir)
os.mkdir(staging_dir)
onnx.save_model(
    unet,
    os.path.join(staging_dir, os.path.basename(unet_model_path)),
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="weights.pb",
    convert_attribute=False,
)
shutil.rmtree(unet_dir)
os.rename(staging_dir, unet_dir)

In [10]:
# Dynamic quantization of the collated raw UNet export, with weights as QUInt8.
# The InferenceSession cell loads this same output path, so keep the two in sync.
quantize_dynamic(
    model_input='../onnx-models/UNet-raw-dummy/model.onnx',
    model_output='../onnx-models/UNet-raw-dummy-quant.onnx',
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)



In [11]:
import onnxruntime
import numpy as np

# Run on the local CPU execution provider, matching device = "cpu". The Azure EP is
# meant for invoking remote Azure endpoints, not for local inference.
unetSession = onnxruntime.InferenceSession(
    '../onnx-models/UNet-raw-dummy-quant.onnx',
    providers=['CPUExecutionProvider'],
)
[inp.name for inp in unetSession.get_inputs()]

['timestep']

In [19]:
# Test run of the quantized UNet. The exported graph's only input is "timestep":
# sample and encoder_hidden_states are generated inside the wrapper's forward, so
# feeding them here would be rejected by ONNX Runtime as unknown input names.
ort_inputs = {
    # Length 2, matching the fixed dummy batch of 2 generated inside the wrapper;
    # float32 because the timestep was traced as a float32 tensor.
    'timestep': np.array([981, 981], dtype=np.float32),
}
assert [inp.name for inp in unetSession.get_inputs()] == list(ort_inputs), \
    "feed names must match the session's input names"
ort_outputs = unetSession.run(None, ort_inputs)

In [15]:
# Shape of the first session output.
tuple(ort_outputs[0].shape)

(2, 4, 64, 64)

In [21]:
# Scratch check: a one-element random tensor as a float32 numpy array.
torch.randn(1, device=device, dtype=dtype).numpy()

array([0.16983595], dtype=float32)

---
### Raw UNet 모델의 후처리(classifier-free guidance)를 위한 모듈 추가

In [30]:
class UnetModel_post(nn.Module):
    """Classifier-free guidance applied to a stacked UNet noise prediction.

    The input is split in half along dim 0 into (unconditional, text-conditioned)
    predictions, which are combined as ``uncond + guidance_scale * (text - uncond)``.

    Args:
        guidance_scale: guidance weight; defaults to 7.5, the value previously
            hard-coded in ``forward``.
    """

    def __init__(self, guidance_scale=7.5):
        super().__init__()
        self.guidance_scale = guidance_scale

    def forward(self, noise_pred):
        """Return the guided prediction; its dim 0 is half that of ``noise_pred``.

        Args:
            noise_pred: tensor whose dim 0 stacks the unconditional half first and the
                text-conditioned half second (assumes an even-sized dim 0).
        """
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        return noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

# Export the guidance module. It has no parameters, so the external-data format that
# the multi-GB UNet needs is unnecessary here.
post_model_path = Path('../onnx-models/UNet_post/model.onnx')
onnx_export(
    UnetModel_post(),
    # Trailing comma: model_args must be a tuple, not a bare tensor.
    model_args=(torch.randn(2, 4, 64, 64).to(device=device, dtype=dtype),),
    output_path=post_model_path,
    ordered_input_names=["noise_pred"],
    output_names=["noise_pred_out"],  # kept distinct from the input name "noise_pred"
    dynamic_axes={
        "noise_pred": {0: "batch"},
    },
    opset=12,
    use_external_data_format=False,
)

# Quantize from the same relative path the export just wrote (previously hard-coded
# absolute /root/... paths, which only matched when the notebook ran from one location).
# NOTE(review): this graph has no weight initializers, so dynamic quantization
# presumably leaves it effectively unchanged.
quantize_dynamic(
    model_input=post_model_path.as_posix(),
    model_output=post_model_path.with_name('model-quant.onnx').as_posix(),
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)



ONNX export Start🚗
ONNX export Finish🍷
