In [1]:
import gc
import os
from pathlib import Path

import torch
import torch.nn as nn
from diffusers.schedulers import PNDMScheduler
from diffusers import DiffusionPipeline
import onnx
import onnxruntime
import onnxruntime as ort
from onnxruntime.quantization.quantize import quantize_dynamic
from onnxruntime.quantization import QuantType

from util_onnx import onnx_export
import utils

# Run a full garbage-collection pass. The call's return value (the number of
# unreachable objects found) is what this cell displays as its output.
gc.collect()

0

In [2]:
# Shared configuration for the cells below.
device = 'cpu'                          # device that tensors and the wrapped encoder are moved to
dtype = torch.float32                   # precision used when loading the pipeline weights
save_path = f'../onnx_models_{device}'  # root directory for the ONNX artifacts
os.makedirs(save_path, exist_ok=True)

# Load the Stable Diffusion v1.5 pipeline. `dtype` is passed through instead of
# repeating the literal so the precision is configured in exactly one place.
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=dtype,
)

Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

---
##### 0. Prepare input

In [3]:
# Load the tokenizer ONNX model stored under `save_path`.
onnx_model_path = f'{save_path}/tokenizer/to_quant.onnx'
# The model file is local and `device` is 'cpu', so request the CPU execution
# provider explicitly instead of the Azure (remote endpoint) provider.
sessTokenizer = ort.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])

# NOTE(review): presumably returns the prompt encoded as a tensor of ASCII
# codes -- confirm against utils.toAsciiTensor.
ascii_str = utils.toAsciiTensor()

# Run the tokenizer graph; its first output holds the token ids.
token_ids_np = sessTokenizer.run(None, {
    'input': ascii_str.detach().cpu().numpy()
})[0]

# Convert the numpy result to a torch tensor on the configured device.
text_ids = torch.tensor(token_ids_np).to(device=device)
text_ids

tensor([[49407,   320,  3490,  2368, 49406,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]])

--- 
##### 1. Define model

In [4]:
# Token ids of the empty prompt, padded/truncated to 77 tokens. They are fed to
# the text encoder as the unconditional (negative) prompt further below.
uncond_encoding = pipeline.tokenizer(
    [""],
    padding="max_length",
    max_length=77,
    truncation=True,
    return_tensors="pt",
)
uncond_input = uncond_encoding.input_ids

In [5]:
class TextEmbedding(nn.Module):
    """Wrap a text encoder so a single call yields negative and prompt embeddings.

    ``forward`` encodes the given token ids and the stored ``uncond_input`` ids
    with the same encoder, then concatenates both ``last_hidden_state`` tensors
    along dim 0 with the negative-prompt embeddings first.
    """

    def __init__(self, textencoder, uncond_input, device='cpu'):
        """Keep the encoder (moved to ``device``) and the unconditional token ids.

        Args:
            textencoder: Module whose call result exposes ``last_hidden_state``.
            uncond_input: Token ids used for the negative prompt.
            device: Device the encoder and every input tensor are moved to.
        """
        super().__init__()
        self.device = device
        self.uncond_input = uncond_input
        self.text_encoder = textencoder.to(device=device)

    def _encode(self, token_ids):
        """Cast ``token_ids`` to int32 on ``self.device`` and return the encoder's last hidden state."""
        ids_on_device = token_ids.to(device=self.device, dtype=torch.int32)
        return self.text_encoder(ids_on_device).last_hidden_state

    def forward(self, text_ids):
        """Return the negative-prompt and prompt hidden states stacked along dim 0."""
        # Encode the prompt first, then the unconditional ids.
        prompt_hidden = self._encode(text_ids)
        negative_hidden = self._encode(self.uncond_input)
        return torch.cat([negative_hidden, prompt_hidden])
# Sanity check: run the wrapper once on the tokenized prompt and display the
# shape of the concatenated embeddings.
embedder_check = TextEmbedding(pipeline.text_encoder, uncond_input)
embedder_check(text_ids).shape

torch.Size([2, 77, 768])

-----
#### 💛 Conversion

In [6]:
# ONNX conversion.
# Build each artifact path once so the export and quantization steps address
# the same files in the same way.
text_encoder_dir = f'{save_path}/text_encoder'
te_origin_path = f'{text_encoder_dir}/te_origin.onnx'
te_quant_path = f'{text_encoder_dir}/te_quant.onnx'
os.makedirs(text_encoder_dir, exist_ok=True)

torch.onnx.export(
    model=TextEmbedding(pipeline.text_encoder, uncond_input, device=device),  # model to export
    args=(text_ids,),            # model inputs; the trailing comma makes this a real 1-tuple
    f=te_origin_path,            # where the exported model is saved
    export_params=True,          # store the trained weights inside the model file
    opset_version=14,            # ONNX opset used for the conversion
    do_constant_folding=True,    # apply constant folding during export
    input_names=['input'],
    output_names=["output"],
    dynamic_axes={
        'input': {0: 'batch_size'},  # dimension 0 of the input has variable length
    },
)

# Model quantization: dynamically quantize the exported model with unsigned
# 8-bit weights and write the result next to the original.
quantize_dynamic(
    model_input=te_origin_path,
    model_output=te_quant_path,
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)

  if input_shape[-1] > 1 or self.sliding_window is not None:
  if past_key_values_length > 0:
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):


---
#### 💚 ONNX-Runtime Test

In [8]:
# Load the quantized text-encoder ONNX model. The provider is named explicitly
# so session creation does not depend on which providers this onnxruntime
# build happens to ship with.
onnx_model_path = f'{save_path}/text_encoder/te_quant.onnx'
session = ort.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])

# Collect and print the names of the graph's inputs and outputs.
input_names = [node.name for node in session.get_inputs()]
output_names = [node.name for node in session.get_outputs()]

print("Input names:", input_names)
print("Output names:", output_names)

Input names: ['input']
Output names: ['output']


In [9]:
# Smoke test: feed the tokenized prompt to the ONNX session and print the shape
# and values of its first output.
ort_inputs = {'input': text_ids.detach().cpu().numpy()}
ort_outputs = session.run(None, ort_inputs)

first_output = ort_outputs[0]
print(first_output.shape)
print(first_output)

(2, 77, 768)
[[[-0.3885849   0.02323209 -0.05236366 ... -0.48972186 -0.30705208
    0.06718016]
  [-0.3963155  -1.440579   -0.3369534  ...  0.9637965   0.17683399
   -1.0900829 ]
  [-0.52471566 -1.461724   -0.30159184 ...  1.0555189   0.0728555
   -1.0248568 ]
  ...
  [ 0.5502783  -0.9023385  -0.5174666  ...  1.6339296  -1.0447981
   -0.26709685]
  [ 0.5530994  -0.89288574 -0.51752    ...  1.681461   -1.0652006
   -0.26443684]
  [ 0.5487205  -0.73710895 -0.35334367 ...  1.632203   -1.0066445
   -0.2860986 ]]

 [[-0.367568    0.05707917  0.01100038 ... -0.469685   -0.21827777
    0.05954561]
  [-0.27619705 -1.1812177  -0.16578478 ...  0.3780036   0.4180037
   -1.3891574 ]
  [-0.5270904   0.56846243  1.5313094  ...  1.4484036  -0.8622036
   -0.73963755]
  ...
  [ 0.34101295 -0.5812703   0.76695126 ...  0.04324484  0.87072957
   -1.2074572 ]
  [ 0.7566298  -1.0436481   0.5474509  ...  0.12419611  0.77519834
   -1.2423668 ]
  [-0.14187689 -0.872546    1.1237452  ... -0.5237114   0.9084585
