In [2]:
import os
import torch
import torch.nn as nn
from diffusers.schedulers import PNDMScheduler
from pathlib import Path
from diffusers import DiffusionPipeline
import onnx
import onnxruntime
from onnxruntime.quantization.quantize import quantize_dynamic
from onnxruntime.quantization import QuantType

from util_onnx import onnx_export
import utils

import gc

# Release unreferenced Python objects before the pipeline is loaded in the next cell.
gc.collect()

0

In [3]:
# Export configuration: device/dtype used for tracing, and the root directory for ONNX files.
device = 'cuda'
dtype = torch.float32
save_path = '../onnx_models_cuda'
os.makedirs(save_path, exist_ok = True)
# Every UNet segment below is written to {save_path}/unet, so make sure it exists up front.
os.makedirs(os.path.join(save_path, 'unet'), exist_ok = True)

# Load the source pipeline with the same `dtype` the dummy export inputs use,
# so weights and traced inputs stay consistent if `dtype` is changed above.
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype)

---
##### 0. Prepare inputs

In [4]:
import onnxruntime as ort

# Run the already-exported tokenizer and text encoder to get real prompt embeddings
# (`ort_output`); the UNet up-block exports below use them as encoder_hidden_states.
# Request the CPU execution provider explicitly: 'AzureExecutionProvider' targets remote
# Azure endpoints, so local graphs were only ever executed by ORT's CPU fallback anyway.
tokenizer_model_path = f'{save_path}/tokenizer/to_quant.onnx'
sessTokenizer = ort.InferenceSession(tokenizer_model_path, providers=['CPUExecutionProvider'])
text_encoder_model_path = f'{save_path}/text_encoder/te_quant.onnx'
sessionTextEncoder = ort.InferenceSession(text_encoder_model_path, providers=['CPUExecutionProvider'])

ascii_str = utils.toAsciiTensor()
text_ids = sessTokenizer.run(None, {
    'input': ascii_str.detach().cpu().numpy()
})[0]
ort_output = sessionTextEncoder.run(None, {
    'input': text_ids
})[0]
ort_output.shape

(2, 77, 768)

--- 
##### 1. Define model

In [25]:
# Inspect the scheduler's scale_model_input. UnetModel_pre (below) only duplicates the
# latents and never calls it, presumably because it is a no-op for this PNDMScheduler
# configuration. TODO confirm before relying on that in the exported pipeline.
pipeline.scheduler.scale_model_input

<bound method PNDMScheduler.scale_model_input of PNDMScheduler {
  "_class_name": "PNDMScheduler",
  "_diffusers_version": "0.18.0",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "set_alpha_to_one": false,
  "skip_prk_steps": true,
  "steps_offset": 1,
  "timestep_spacing": "leading",
  "trained_betas": null
}
>

In [26]:
class UnetModel_pre(nn.Module):
    """Pre-UNet step: duplicate the latents along the batch axis for classifier-free guidance.

    The scheduler is stored as an attribute only; forward() does not use it.
    """

    def __init__(self, scheduler):
        super().__init__()
        self.scheduler = scheduler

    def forward(self, latents):
        # Two stacked copies on dim 0: one for the unconditional pass and one for the
        # text-conditioned pass (the order the guidance step later chunks them in).
        doubled = torch.cat((latents, latents), dim=0)
        return doubled
        
# Export the CFG pre-step (latent duplication) and quantize it.
upre_origin_path = f'{save_path}/unet/upre_origin.onnx'
upre_quant_path = f'{save_path}/unet/upre_quant.onnx'

onnx_export(
    UnetModel_pre(pipeline.scheduler),
    model_args=(torch.randn([1, 4, 64, 64]).to(device=device, dtype=dtype),),
    output_path=Path(upre_origin_path),
    ordered_input_names=["latents"],
    output_names=["output"],  # has to be different from "sample" for correct tracing
    dynamic_axes={"latents": {0: "batch"}},
    opset=12,
    # UnetModel_pre defines no layers of its own; the flag is kept for consistency with the UNet exports.
    use_external_data_format=True,
)

# Dynamic quantization with QUInt8 weights.
quantize_dynamic(
    model_input=upre_origin_path,
    model_output=upre_quant_path,
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)



ONNX export Start🚗
ONNX export Finish🍷


In [27]:
class UNet_post(nn.Module):
    """Classifier-free guidance combine step applied to a stacked noise prediction.

    Splits the batch into (unconditional, text-conditioned) halves and returns
    ``uncond + guidance_scale * (text - uncond)``.

    Args:
        guidance_scale: CFG weight baked into the exported graph as a constant.
            Defaults to 7.5, the value that was previously hard-coded.
    """

    def __init__(self, guidance_scale: float = 7.5):
        super().__init__()
        self.guidance_scale = guidance_scale

    def forward(
        self,
        sample: torch.FloatTensor                   # uncond + text halves stacked on dim 0 (exported as [2, 4, 64, 64])
    ):
        noise_pred_uncond, noise_pred_text = sample.chunk(2)
        return noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

# Export the CFG combine step (UNet_post) on a dummy noise prediction made of two stacked
# halves (unconditional, text-conditioned), with the batch axis left dynamic.
onnx_export(
    UNet_post(),
    model_args=(
        torch.randn([2, 4, 64, 64]).to(device = device, dtype = dtype),
    ),
    output_path = Path(f'{save_path}/unet/upost_final_origin.onnx'),
    ordered_input_names=[
        "sample", 
    ],
    output_names=["output"],  # has to be different from "sample" for correct tracing
    dynamic_axes={ 
        "sample": {0: "batch"},
    },
    opset=12,
    use_external_data_format=True,  # UNet_post has no layers/weights; flag kept for consistency with the other exports
)
# Dynamic quantization with QUInt8 weights.
quantize_dynamic(
    model_input     =   f'{save_path}/unet/upost_final_origin.onnx', 
    model_output    =   f'{save_path}/unet/upost_final_quant.onnx', 
    per_channel     =   False,
    reduce_range    =   False,
    weight_type     =   QuantType.QUInt8,
)



ONNX export Start🚗
ONNX export Finish🍷


In [13]:
class UNet_post_process(nn.Module):
    """UNet output head followed by the classifier-free guidance combine.

    Applies the UNet's conv_norm_out -> conv_act -> conv_out, then splits the batch into
    (unconditional, text-conditioned) halves and returns
    ``uncond + guidance_scale * (text - uncond)``.

    Args:
        unet: source UNet; its output-head layers are moved to the global ``device``.
        guidance_scale: CFG weight baked into the exported graph as a constant.
            Defaults to 7.5, the value that was previously hard-coded.
    """

    def __init__(self, unet, guidance_scale: float = 7.5):
        super().__init__()
        self.conv_norm_out  = unet.conv_norm_out.to(device = device)    # (32, 320)
        self.conv_act       = unet.conv_act.to(device = device)         # SiLU()
        self.conv_out       = unet.conv_out.to(device = device)         # (320, 4)
        self.guidance_scale = guidance_scale

    def forward(
        self,
        sample: torch.FloatTensor                   # [2, 320, 64, 64]
    ):
        #* 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        # perform classifier-free guidance
        noise_pred_uncond, noise_pred_text = sample.chunk(2)
        return noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

In [6]:
class UNet2DConditionModel_Down(nn.Module):
    """First UNet segment: time embedding, conv_in and all down blocks.

    forward() returns ``(sample, emb, r0, ..., r11)``: the down-path output, the time
    embedding, and the first 12 collected residual feature maps (the skip tensors the
    up-block exports consume). The mid block is moved to ``device`` here but is not run.
    """

    def __init__(self, unet):
        super().__init__()
        self.num_upsamplers = unet.num_upsamplers
        self.time_proj      = unet.time_proj.to(device = device)
        self.time_embedding = unet.time_embedding.to(device = device)
        self.conv_in        = unet.conv_in.to(device = device)
        self.down_blocks    = unet.down_blocks.to(device = device)
        self.mid_block      = unet.mid_block.to(device = device)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep,
        encoder_hidden_states: torch.Tensor,
    ):
        # 1. time embedding: broadcast the timestep over the batch dimension
        batch_timesteps = timestep.expand(sample.shape[0])
        time_features = self.time_proj(batch_timesteps).to(dtype=sample.dtype)
        emb = self.time_embedding(time_features, None)

        # 2. input convolution
        hidden = self.conv_in(sample).to(device = sample.device)

        # 3. down path, collecting every residual for the skip connections
        residuals = (hidden,)
        for block in self.down_blocks:
            if getattr(block, "has_cross_attention", False):
                hidden, block_residuals = block(
                    hidden_states=hidden,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=None,
                    cross_attention_kwargs=None,
                    encoder_attention_mask=None,
                )
            else:
                hidden, block_residuals = block(hidden_states=hidden, temb=emb)
            residuals = residuals + block_residuals

        # Emit the first 12 residuals as separate graph outputs (IndexError if there are fewer).
        return (hidden, emb) + tuple(residuals[k] for k in range(12))

In [26]:
class UNet2DConditionModel_Mid(nn.Module):
    """UNet mid block wrapped as a standalone exportable module.

    Like the other wrappers, it moves its block to the global ``device`` itself instead
    of relying on another wrapper having moved the shared ``unet.mid_block`` already.
    A missing mid block (``None``) is kept as ``None`` and forward() passes ``sample`` through.
    """

    def __init__(self, unet):
        super().__init__()
        mid_block = unet.mid_block
        self.mid_block = mid_block.to(device = device) if mid_block is not None else None

    def forward(
        self,
        sample: torch.FloatTensor,
        emb: torch.FloatTensor,
        encoder_hidden_states: torch.Tensor,
    ):
        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=None,
                cross_attention_kwargs=None,
                encoder_attention_mask=None,
            )
        return sample

In [35]:
class UNet2DConditionModel_Up(nn.Module):
    """Wrap one UNet up block so it can be exported on its own.

    forward() takes the block input, the time embedding, the text embeddings and the
    three skip-connection tensors for this block, and returns the block output.
    """

    def __init__(self, unet_up_ith_block):
        super().__init__()
        self.upsample_block = unet_up_ith_block.to(device = device)

    def forward(
        self,
        sample: torch.FloatTensor,
        emb: torch.FloatTensor,
        encoder_hidden_states: torch.Tensor,
        down_block_res_samples0,
        down_block_res_samples1,
        down_block_res_samples2,
    ):
        block = self.upsample_block
        skips = [down_block_res_samples0, down_block_res_samples1, down_block_res_samples2]

        if not getattr(block, "has_cross_attention", False):
            # Up block without cross-attention: no text conditioning is passed.
            return block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=skips, upsample_size=None
            )

        return block(
            hidden_states=sample,
            temb=emb,
            res_hidden_states_tuple=skips,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=None,
            upsample_size=None,
            attention_mask=None,
            encoder_attention_mask=None,
        )

In [39]:
class UNet2DConditionModel_Up_3(nn.Module):
    """One (resnet, attention) layer of the last UNet up block, exported separately.

    forward() concatenates the skip tensor onto ``sample`` along the channel axis, runs
    the resnet with the time embedding, then the attention module conditioned on
    ``encoder_hidden_states``.
    """

    def __init__(self, resnets, attentions):
        super().__init__()
        # Despite the plural parameter names, callers pass one resnet and one attention module.
        self.resnet = resnets.to(device = device)
        self.attention = attentions.to(device = device)

    def forward(
        self,
        sample: torch.FloatTensor,
        emb: torch.FloatTensor,
        encoder_hidden_states: torch.Tensor,
        res_hidden_states,
    ):
        merged = torch.cat([sample, res_hidden_states], dim=1)
        merged = self.resnet(merged, emb)
        attn_out = self.attention(
            merged,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=None,
            attention_mask=None,
            encoder_attention_mask=None,
            return_dict=False,
        )
        return attn_out[0]

-----
#### 💛 Conversion

In [14]:

# Export the UNet output head + CFG combine on a dummy [2, 320, 64, 64] feature map.
onnx_export(
    UNet_post_process(pipeline.unet),
    # The trailing comma matters: model_args must be a tuple of inputs, not a bare tensor.
    model_args=(
        torch.randn([2, 320, 64, 64]).to(device = device, dtype = dtype),
    ),
    output_path = Path(f'{save_path}/unet/upost_origin.onnx'),
    ordered_input_names=["input"],
    output_names=["output"],  # has to be different from "sample" for correct tracing
    dynamic_axes={
        "input": {0: "batch"}
    },
    opset=12,
    use_external_data_format=True,  # same storage setting as the other UNet segment exports
)

# Dynamic quantization with QUInt8 weights.
quantize_dynamic(
    model_input     =   f'{save_path}/unet/upost_origin.onnx',
    model_output    =   f'{save_path}/unet/upost_quant.onnx',
    per_channel     =   False,
    reduce_range    =   False,
    weight_type     =   QuantType.QUInt8,
)

ONNX export Start🚗




ONNX export Finish🍷


In [8]:
# Export the down path. Graph outputs: down-path sample, time embedding, then the 12 skip tensors.
udown_output_names = ["out_sample", "emb"] + [f'down_block_res_samples_{k}' for k in range(12)]

onnx_export(
    UNet2DConditionModel_Down(pipeline.unet),
    model_args=(
        torch.randn([2, 4, 64, 64]).to(device=device, dtype=dtype),
        torch.randn([1]).to(device=device, dtype=dtype),
        torch.randn([2, 77, 768]).to(device=device, dtype=dtype),
    ),
    output_path=Path(f'{save_path}/unet/udown_origin.onnx'),
    ordered_input_names=["sample", "timestep", "encoder_hidden_states"],
    output_names=udown_output_names,  # has to be different from "sample" for correct tracing
    dynamic_axes={"timestep": {0: "batch"}},
    opset=14,
    use_external_data_format=True,  # large weight segment; store weights as external data
)

# Dynamic quantization with QUInt8 weights.
quantize_dynamic(
    model_input=f'{save_path}/unet/udown_origin.onnx',
    model_output=f'{save_path}/unet/udown_quant.onnx',
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)

ONNX export Start🚗
ONNX export Finish🍷




In [27]:
# Export the mid block on dummy inputs: [2, 1280, 8, 8] sample (same shape as the last
# down-path residuals), [2, 1280] time embedding and [2, 77, 768] text embeddings.
onnx_export(
    UNet2DConditionModel_Mid(pipeline.unet),
    model_args=(
        torch.randn([2, 1280, 8, 8]).to(device=device, dtype=dtype),
        torch.randn([2, 1280]).to(device=device, dtype=dtype),
        torch.randn([2, 77, 768]).to(device=device, dtype=dtype),
    ),
    output_path=Path(f'{save_path}/unet/umid_origin.onnx'),
    ordered_input_names=["sample", "emb", "encoder_hidden_states"],
    output_names=["out_sample"],  # has to be different from "sample" for correct tracing
    # All shapes static. The previous {"timestep": ...} entry named an input this model
    # does not have, so the exporter ignored it (with a warning).
    dynamic_axes={},
    opset=14,
    use_external_data_format=True,  # store weights as external data, like the other UNet segments
)

# Dynamic quantization with QUInt8 weights.
quantize_dynamic(
    model_input=f'{save_path}/unet/umid_origin.onnx',
    model_output=f'{save_path}/unet/umid_quant.onnx',
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)

ONNX export Start🚗




ONNX export Finish🍷




In [37]:
# Dummy skip-connection tensors in down-path order (the shapes UNet2DConditionModel_Down
# emits for a [2, 4, 64, 64] latent). Up block i consumes the i-th group of three,
# counted from the end of this list.
down_block_res_samples = [
    torch.randn([2, 320, 64, 64]).to(device=device, dtype=dtype),
    torch.randn([2, 320, 64, 64]).to(device=device, dtype=dtype),
    torch.randn([2, 320, 64, 64]).to(device=device, dtype=dtype),

    torch.randn([2, 320, 32, 32]).to(device=device, dtype=dtype),
    torch.randn([2, 640, 32, 32]).to(device=device, dtype=dtype),
    torch.randn([2, 640, 32, 32]).to(device=device, dtype=dtype),

    torch.randn([2, 640, 16, 16]).to(device=device, dtype=dtype),
    torch.randn([2, 1280, 16, 16]).to(device=device, dtype=dtype),
    torch.randn([2, 1280, 16, 16]).to(device=device, dtype=dtype),

    torch.randn([2, 1280, 8, 8]).to(device=device, dtype=dtype),
    torch.randn([2, 1280, 8, 8]).to(device=device, dtype=dtype),
    torch.randn([2, 1280, 8, 8]).to(device=device, dtype=dtype),
]

# NOTE(review): `ir_version` is an ONNX ModelProto field; setting it on the torch module
# does not appear to influence the export. Kept so the cell leaves the same state behind.
pipeline.unet.ir_version = 11

# Inputs shared by all up-block exports (hoisted out of the loop).
up_emb = torch.randn([2, 1280]).to(device=device, dtype=dtype)
up_encoder_hidden_states = torch.tensor(ort_output).to(device=device, dtype=dtype)

# Up blocks 0-2 are exported whole; the last up block is exported layer by layer in the next cell.
n_res = len(down_block_res_samples)
for i in range(3):
    end = n_res - 3 * i
    res_samples = down_block_res_samples[end - 3:end]
    spatial = 8 * 2 ** i  # 8, 16, 32: matches the spatial size of this block's skip tensors
    onnx_export(
        UNet2DConditionModel_Up(pipeline.unet.up_blocks[i]),
        model_args=(
            torch.randn([2, 1280, spatial, spatial]).to(device=device, dtype=dtype),
            up_emb,
            up_encoder_hidden_states,
            res_samples[0], res_samples[1], res_samples[2],
        ),
        output_path=Path(f'{save_path}/unet/uup-{i}_origin.onnx'),
        ordered_input_names=[
            "sample", "emb", "encoder_hidden_states",
            "res_samples0", "res_samples1", "res_samples2",
        ],
        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "sample": {0: "b"},
            # NOTE(review): dim 1 of `emb` is the 1280-wide embedding, not a sequence axis;
            # this probably meant {0: "b"}. Confirm before changing the exported graph.
            "emb": {1: "sequence"},
        },
        opset=11,
        use_external_data_format=True,  # store weights as external data, like the other UNet segments
    )
    # Dynamic quantization with QUInt8 weights.
    quantize_dynamic(
        model_input=f'{save_path}/unet/uup-{i}_origin.onnx',
        model_output=f'{save_path}/unet/uup-{i}_quant.onnx',
        per_channel=False,
        reduce_range=False,
        weight_type=QuantType.QUInt8,
    )

ONNX export Start🚗
ONNX export Finish🍷




ONNX export Start🚗
ONNX export Finish🍷




ONNX export Start🚗
ONNX export Finish🍷




In [41]:
# Dummy skip tensors for the last up block: one [2, 320, 64, 64] tensor per layer.
res_hidden_states_tuple = [
    torch.randn([2, 320, 64, 64]).to(device=device, dtype=dtype),
    torch.randn([2, 320, 64, 64]).to(device=device, dtype=dtype),
    torch.randn([2, 320, 64, 64]).to(device=device, dtype=dtype),
]
pipeline.unet.ir_version = 11
last_up_block = pipeline.unet.up_blocks[-1]
for i in range(3):
    # Each layer consumes the skip tensor at the end of the remaining list.
    res_hidden_states = res_hidden_states_tuple.pop()
    layer_stem = f'{save_path}/unet/uup-3-{i}'
    layer_model = UNet2DConditionModel_Up_3(last_up_block.resnets[i], last_up_block.attentions[i])
    # Input sample channels: 640 for layer 0, 320 for the following layers.
    in_channels = 640 if i == 0 else 320
    onnx_export(
        layer_model,
        model_args=(
            torch.randn([2, in_channels, 64, 64]).to(device=device, dtype=dtype),
            torch.randn([2, 1280]).to(device=device, dtype=dtype),
            torch.tensor(ort_output).to(device=device, dtype=dtype),
            res_hidden_states,
        ),
        output_path=Path(f'{layer_stem}_origin.onnx'),
        ordered_input_names=["sample", "emb", "encoder_hidden_states", "res_samples0"],
        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
        dynamic_axes={"sample": {0: "b"}},
        opset=11,
        use_external_data_format=True,  # store weights as external data, like the other UNet segments
    )
    # Dynamic quantization with QUInt8 weights.
    quantize_dynamic(
        model_input=f'{layer_stem}_origin.onnx',
        model_output=f'{layer_stem}_quant.onnx',
        per_channel=False,
        reduce_range=False,
        weight_type=QuantType.QUInt8,
    )

ONNX export Start🚗
ONNX export Finish🍷




ONNX export Start🚗
ONNX export Finish🍷




ONNX export Start🚗
ONNX export Finish🍷




In [43]:
# Sizes of the exported (un-quantized) segments; reuses save_path instead of repeating the directory.
!du -sh {save_path}/unet/*_origin.onnx

961M	../onnx_models_cuda/unet/udown_origin.onnx
371M	../onnx_models_cuda/unet/umid_origin.onnx
4.0K	../onnx_models_cuda/unet/upost_origin.onnx
4.0K	../onnx_models_cuda/unet/upre_origin.onnx
619M	../onnx_models_cuda/unet/uup-0_origin.onnx
1002M	../onnx_models_cuda/unet/uup-1_origin.onnx
479M	../onnx_models_cuda/unet/uup-2_origin.onnx
1.1G	../onnx_models_cuda/unet/uup-3-0_origin.onnx
1.1G	../onnx_models_cuda/unet/uup-3-1_origin.onnx
1.1G	../onnx_models_cuda/unet/uup-3-2_origin.onnx


In [44]:
# Sizes of the quantized segments; reuses save_path instead of repeating the directory.
!du -sh {save_path}/unet/*_quant.onnx

242M	../onnx_models_cuda/unet/udown_quant.onnx
94M	../onnx_models_cuda/unet/umid_quant.onnx
4.0K	../onnx_models_cuda/unet/upost_quant.onnx
4.0K	../onnx_models_cuda/unet/upre_quant.onnx
155M	../onnx_models_cuda/unet/uup-0_quant.onnx
263M	../onnx_models_cuda/unet/uup-1_quant.onnx
275M	../onnx_models_cuda/unet/uup-2_quant.onnx
1.1G	../onnx_models_cuda/unet/uup-3-0_quant.onnx
1.1G	../onnx_models_cuda/unet/uup-3-1_quant.onnx
1.1G	../onnx_models_cuda/unet/uup-3-2_quant.onnx


---
#### 💚 ONNX-Runtime Test

In [23]:
import onnxruntime as ort

# Load the quantized pre-step and down-path segments for an ONNX Runtime smoke test.
# Request the CPU execution provider explicitly; 'AzureExecutionProvider' targets remote
# Azure endpoints rather than local graphs.
onnx_model_path = f'{save_path}/unet/upre_quant.onnx'
udown_model_path = f'{save_path}/unet/udown_quant.onnx'
session_pre = ort.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
# The down session must load the down-path model, not the pre-step model.
session_down = ort.InferenceSession(udown_model_path, providers=['CPUExecutionProvider'])

In [25]:
# test running: push a random latent through the exported pre-step (CFG duplication).
ort_inputs  = {}
# Feed by the model's own input name (exported as "latents"); a hard-coded 'input' key
# does not match the graph and is rejected by InferenceSession.run.
latents = session_pre.run(None, {
    session_pre.get_inputs()[0].name: torch.randn([1, 4, 64, 64]).to(dtype = dtype).numpy()
})[0]