In [1]:
import torch
import torch.nn as nn
from diffusers.schedulers import PNDMScheduler
from pathlib import Path
from diffusers import DiffusionPipeline
from util import onnx_export
from onnxruntime.quantization.quantize import quantize_dynamic
from onnxruntime.quantization import QuantType
import onnx 
from diffusers.utils import BaseOutput
import gc
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

gc.collect()

  from .autonotebook import tqdm as notebook_tqdm


0

In [2]:
device = "cpu"
dtype = torch.float32

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32)

---
### (1) Down

In [3]:

class UNet2DConditionModel_Down(nn.Module):
    def __init__(self, unet):
        super().__init__()
        self.num_upsamplers = unet.num_upsamplers
        self.time_proj = unet.time_proj
        self.time_embedding = unet.time_embedding
        self.conv_in = unet.conv_in
        self.down_blocks = unet.down_blocks
        
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep,
        encoder_hidden_states: torch.Tensor,
    ):
                
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
        
        t_emb = self.time_proj(timesteps)
        
        emb = self.time_embedding(t_emb, None)
        aug_emb = None
        emb = emb + aug_emb if aug_emb is not None else emb
        
        # 2. pre-process
        sample = self.conv_in(sample)   # 2 320 64 64
        print(sample.shape)
        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=None,
                    cross_attention_kwargs=None,
                    encoder_attention_mask=None,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples
        return sample, emb, down_block_res_samples[0], down_block_res_samples[1],down_block_res_samples[2],\
                down_block_res_samples[3], down_block_res_samples[4], down_block_res_samples[5], \
                down_block_res_samples[6], down_block_res_samples[7], down_block_res_samples[8], \
                down_block_res_samples[9], down_block_res_samples[10], down_block_res_samples[11]

In [4]:
onnx_export(
        UNet2DConditionModel_Down(pipeline.unet),
        model_args=(
            torch.randn([2,4,64,64]).to(device=device, dtype=dtype),
            torch.randn(1).to(device=device, dtype=dtype),
            torch.randn([2,77,768]).to(device=device, dtype=dtype),
        ),
        output_path = Path('../onnx-models/UNet-ver2/down.onnx'),
        ordered_input_names=["sample", "timestep", "encoder_hidden_states"],
        output_names=[
            "out_sample", "emb",
            'down_block_res_samples_0', 'down_block_res_samples_1', 'down_block_res_samples_2',
            'down_block_res_samples_3', 'down_block_res_samples_4', 'down_block_res_samples_5',
            'down_block_res_samples_6', 'down_block_res_samples_7', 'down_block_res_samples_8',
            'down_block_res_samples_9', 'down_block_res_samples_10', 'down_block_res_samples_11',
        ],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "timestep": {0: "batch"}
            # "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            # "timestep": {0: "batch"},
            # "encoder_hidden_states": {0: "batch", 1: "sequence"},
        },
        opset=14,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )

ONNX export Start🚗
torch.Size([2, 320, 64, 64])
ONNX export Finish🍷


In [12]:
!du -sh ../onnx-models/UNet-ver2/*

961M	../onnx-models/UNet-ver2/down.onnx


---
### (2) Mid

In [13]:
class UNet2DConditionModel_Mid(nn.Module):
    def __init__(self, unet):
        super().__init__()
        self.mid_block = unet.mid_block
        
    def forward(
        self,
        sample: torch.FloatTensor,
        emb: torch.FloatTensor,
        encoder_hidden_states: torch.Tensor,
    ):
        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=None,
                cross_attention_kwargs=None,
                encoder_attention_mask=None,
            )
        return sample

In [14]:
onnx_export(
        UNet2DConditionModel_Mid(pipeline.unet),
        model_args=(
            torch.randn([2, 1280, 8, 8]).to(device=device, dtype=dtype),
            torch.randn(2, 1280).to(device=device, dtype=dtype),
            torch.randn([2, 77, 768]).to(device=device, dtype=dtype),
        ),
        output_path = Path('../onnx-models/UNet-ver2/unet2dconditionalmodel_mid.onnx'),
        ordered_input_names=["sample", "emb", "encoder_hidden_states"],
        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "emb": {1: "sequence"}
            # "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            # "timestep": {0: "batch"},
            # "encoder_hidden_states": {0: "batch", 1: "sequence"},
        },
        opset=14,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )

ONNX export Start🚗
ONNX export Finish🍷


In [15]:
!du -sh ../onnx-models/UNet-ver2/*

961M	../onnx-models/UNet-ver2/down.onnx
371M	../onnx-models/UNet-ver2/unet2dconditionalmodel_mid.onnx


---
### Up

In [16]:
class UNet2DConditionModel_Up(nn.Module):
    def __init__(self, unet_up_ith_block):
        super().__init__()
        self.upsample_block = unet_up_ith_block
        
    def forward(
        self,
        sample: torch.FloatTensor,
        emb: torch.FloatTensor,
        encoder_hidden_states: torch.Tensor,
        down_block_res_samples0,
        down_block_res_samples1,
        down_block_res_samples2,
    ):
        res_samples = [
            down_block_res_samples0, down_block_res_samples1, down_block_res_samples2
        ]
        # print(sample.shape)
        if hasattr(self.upsample_block, "has_cross_attention") and self.upsample_block.has_cross_attention:
            sample = self.upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=None,
                upsample_size=None,
                attention_mask=None,
                encoder_attention_mask=None,
            )
        else:
            sample = self.upsample_block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=None
            )
        return sample

In [17]:
class TextEmbedding(nn.Module):
    def __init__(self, tokenizer, textencoder, device='cpu'):
        super().__init__()
        self.tokenizer = tokenizer
        self.text_encoder = textencoder.to(device = device)
        self.device = device
       
    def forward(self, text_ids):
        # uncond-input 준비
        uncond_input = self.tokenizer(
            [""],
            padding="max_length",
            max_length=pipeline.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids
        # 인코딩 
        textembed               = self.text_encoder(text_ids.to(device=self.device, dtype=torch.int32)).last_hidden_state
        negative_prompt_embeds  = self.text_encoder(uncond_input.to(device=self.device, dtype=torch.int32)).last_hidden_state
        
        prompt_embeds = torch.cat([negative_prompt_embeds, textembed])
        return prompt_embeds
text_input = pipeline.tokenizer(
    ["A smile Tiger"],
    padding="max_length",
    max_length=pipeline.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
    ).input_ids.to(device=device, dtype=torch.int32)
te = TextEmbedding(pipeline.tokenizer, pipeline.text_encoder, device)
text_embed = te(text_input)

In [18]:
down_block_res_samples = [
    torch.randn([2, 320, 64, 64]).to(device=device, dtype=dtype),
    torch.randn([2, 320, 64, 64]).to(device=device, dtype=dtype),
    torch.randn([2, 320, 64, 64]).to(device=device, dtype=dtype),
    
    torch.randn([2, 320, 32, 32]).to(device=device, dtype=dtype),
    torch.randn([2, 640, 32, 32]).to(device=device, dtype=dtype),
    torch.randn([2, 640, 32, 32]).to(device=device, dtype=dtype),
    
    torch.randn([2, 640, 16, 16]).to(device=device, dtype=dtype),
    torch.randn([2, 1280, 16, 16]).to(device=device, dtype=dtype),
    torch.randn([2, 1280, 16, 16]).to(device=device, dtype=dtype),
            
    torch.randn([2, 1280, 8, 8]).to(device=device, dtype=dtype),
    torch.randn([2, 1280, 8, 8]).to(device=device, dtype=dtype),
    torch.randn([2, 1280, 8, 8]).to(device=device, dtype=dtype)   
]

pipeline.unet.ir_version = 11
for i in range(4):
    res_samples = down_block_res_samples[-3 :]
    down_block_res_samples = down_block_res_samples[: -3]
    if i == 3: continue
    onnx_export(
            UNet2DConditionModel_Up(pipeline.unet.up_blocks[i]),
            model_args=(
                torch.randn([2, 1280 if i!=3 else 640, 8*(2**(i)), 8*(2**(i))]).to(device=device, dtype=dtype),
                torch.randn([2, 1280]).to(device=device, dtype=dtype),
                text_embed, #torch.randn([2, 77, 768]).to(device=device, dtype=dtype),
                res_samples[0], res_samples[1], res_samples[2]
            ),
            output_path = Path(f'../onnx-models/UNet-ver2/unet2dconditionalmodel_up_{i}/model.onnx'),
            ordered_input_names=[
                "sample", "emb", "encoder_hidden_states", 
                "res_samples0", "res_samples1", "res_samples2",
            ],
            output_names=["out_sample"],  # has to be different from "sample" for correct tracing
            dynamic_axes={
                "sample": {0:"b"},
                "emb": {1: "sequence"},
                
            },
            opset=11,
            use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
        )
    # break

ONNX export Start🚗
ONNX export Finish🍷
ONNX export Start🚗
ONNX export Finish🍷
ONNX export Start🚗
ONNX export Finish🍷


In [20]:
class UNet2DConditionModel_Up_3(nn.Module):
    def __init__(self, resnets, attentions):
        super().__init__()
        self.resnet = resnets
        self.attention = attentions
    
    def forward(
        self,
        sample: torch.FloatTensor,
        emb: torch.FloatTensor,
        encoder_hidden_states: torch.Tensor,
        res_hidden_states,
    ):
        sample = torch.cat([sample, res_hidden_states], dim=1)
        sample = self.resnet(sample, emb)
        sample = self.attention(
            sample,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=None,
            attention_mask=None,
            encoder_attention_mask=None,
            return_dict=False,
        )[0]
        print(sample.shape)
        return sample

In [21]:
res_samples = [          
    torch.randn([2, 320, 64, 64]).to(device=device, dtype=dtype),
    torch.randn([2, 320, 64, 64]).to(device=device, dtype=dtype),
    torch.randn([2, 320, 64, 64]).to(device=device, dtype=dtype),
]
pipeline.unet.ir_version = 11
for i in range(3):
    onnx_export(
        UNet2DConditionModel_Up_3(
            pipeline.unet.up_blocks[-1].resnets[i],
            pipeline.unet.up_blocks[-1].attentions[i]
        ),
        model_args=(
            torch.randn([2, 640 if i == 0 else 320, 64, 64]).to(device=device, dtype=dtype),
            torch.randn([2, 1280]).to(device=device, dtype=dtype),
            text_embed,
            res_samples[-i], 
        ),
        output_path = Path(f'../onnx-models/UNet-ver2/unet2dconditionalmodel_up_{3}_{i}/model.onnx'),
        ordered_input_names=[
            "sample", "emb", "encoder_hidden_states", 
            "res_samples0"
        ],
        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "sample": {0:"b"},        
        },
        opset=11,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )

ONNX export Start🚗
torch.Size([2, 320, 64, 64])
ONNX export Finish🍷
ONNX export Start🚗
torch.Size([2, 320, 64, 64])
ONNX export Finish🍷
ONNX export Start🚗
torch.Size([2, 320, 64, 64])
ONNX export Finish🍷


In [22]:
!du -sh ../onnx-models/UNet-ver2/**

961M	../onnx-models/UNet-ver2/down.onnx
371M	../onnx-models/UNet-ver2/unet2dconditionalmodel_mid.onnx
619M	../onnx-models/UNet-ver2/unet2dconditionalmodel_up_0
1002M	../onnx-models/UNet-ver2/unet2dconditionalmodel_up_1
480M	../onnx-models/UNet-ver2/unet2dconditionalmodel_up_2
1.1G	../onnx-models/UNet-ver2/unet2dconditionalmodel_up_3_0
1.1G	../onnx-models/UNet-ver2/unet2dconditionalmodel_up_3_1
1.1G	../onnx-models/UNet-ver2/unet2dconditionalmodel_up_3_2


---
### (4) post-process

In [23]:
class UNet_post_process(nn.Module):
    def __init__(self, unet):
        super().__init__()
        self.conv_norm_out  = unet.conv_norm_out
        self.conv_act       = unet.conv_act
        self.conv_out       = unet.conv_out
    def forward(
        self,
        sample: torch.FloatTensor
    ):
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        return sample

In [25]:
onnx_export(
        UNet_post_process(pipeline.unet),
        model_args=(
            torch.randn([2, 320, 64, 64]).to(device=device, dtype=dtype),
        ),
        output_path = Path('../onnx-models/UNet-ver2/UNet_post_process/model.onnx'),
        ordered_input_names=["sample"],
        output_names=["out_sample"],  # has to be different from "sample" for correct tracing
        dynamic_axes={
            "sample": {0: "batch"}
        },
        opset=14,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )
!du -sh ../onnx-models/UNet-ver2/**

ONNX export Start🚗
ONNX export Finish🍷
56K	../onnx-models/UNet-ver2/UNet_post_process
961M	../onnx-models/UNet-ver2/down.onnx
371M	../onnx-models/UNet-ver2/unet2dconditionalmodel_mid.onnx
619M	../onnx-models/UNet-ver2/unet2dconditionalmodel_up_0
1002M	../onnx-models/UNet-ver2/unet2dconditionalmodel_up_1
480M	../onnx-models/UNet-ver2/unet2dconditionalmodel_up_2
1.1G	../onnx-models/UNet-ver2/unet2dconditionalmodel_up_3_0
1.1G	../onnx-models/UNet-ver2/unet2dconditionalmodel_up_3_1
1.1G	../onnx-models/UNet-ver2/unet2dconditionalmodel_up_3_2


---
---
### 양자화

In [26]:
def quantize_onnx_model(input_path, output_path):
    print(f"start quantize : {input_path} => {output_path}")
    quantize_dynamic(
        model_input     = input_path,
        model_output    = output_path,
        per_channel     = False,
        reduce_range    = False,
        weight_type     = QuantType.QUInt8,
    )
    print(f"end : {output_path}")