In [2]:
import torch
import torch.nn as nn
from diffusers.schedulers import PNDMScheduler
import onnx
import onnxruntime
from torch.onnx import export
from pathlib import Path
from packaging import version
from diffusers import DiffusionPipeline
import os
import shutil
from typing import List, Optional, Tuple, Union
import numpy as np
import inspect
from tqdm import tqdm

import gc
# Force a full garbage-collection pass (presumably to free memory before the large
# pipeline load below). gc.collect() returns the number of unreachable objects it
# found, which is what this cell displays as its output.
gc.collect()

  from .autonotebook import tqdm as notebook_tqdm


0

In [3]:
# Target device and precision used when loading the pipeline and building example inputs.
device, dtype = "cpu", torch.float32

In [4]:
# Hugging Face Hub repo the Stable Diffusion weights are loaded from.
# NOTE(review): the `runwayml` repo has reportedly been removed from the Hub; if loading
# fails, try the community mirror "stable-diffusion-v1-5/stable-diffusion-v1-5".
MODEL_ID = "runwayml/stable-diffusion-v1-5"
pipeline = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)

In [5]:
is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")


def onnx_export(
    model,
    model_args: tuple,
    output_path: Path,
    ordered_input_names,
    output_names,
    dynamic_axes,
    opset,
    use_external_data_format=False,
):
    """Export ``model`` to an ONNX file with ``torch.onnx.export``.

    Args:
        model: the ``nn.Module`` to trace and export.
        model_args: example positional inputs passed to ``model`` while tracing.
        output_path: destination ``.onnx`` file; missing parent directories are created.
        ordered_input_names: graph input names, in the same order as ``model_args``.
        output_names: graph output names.
        dynamic_axes: mapping of input/output name -> {axis index: symbolic axis name}.
        opset: ONNX opset version to target.
        use_external_data_format: store weights in separate files. Only forwarded on
            torch < 1.11; newer PyTorch dropped this argument, so it is ignored there.
    """
    print('ONNX export Start🚗')
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Arguments accepted by every supported torch version.
    export_kwargs = dict(
        f=output_path.as_posix(),
        input_names=ordered_input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        do_constant_folding=True,
        opset_version=opset,
    )
    # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
    # so only pass them on older versions for backwards compatibility.
    if is_torch_less_than_1_11:
        export_kwargs.update(
            use_external_data_format=use_external_data_format,
            enable_onnx_checker=True,
        )

    export(model, model_args, **export_kwargs)
    print('ONNX export Finish🍷')

---
### Text embedding

In [7]:
class TextEmbedding(nn.Module):
    """Encode prompt token ids into classifier-free-guidance text embeddings.

    Wraps a tokenizer and a text encoder so that encoding the empty (unconditional)
    prompt and the user prompt can be traced and exported as a single ONNX graph.

    Args:
        tokenizer: called with ``[""]`` to build the unconditional input ids; its
            ``model_max_length`` sets the padding length.
        textencoder: module whose output exposes ``last_hidden_state``.
        device: device the text encoder is moved to and inputs are cast onto.
            ``Module.to`` moves ``textencoder`` in place, so the caller's encoder
            object is moved as well.
    """

    def __init__(self, tokenizer, textencoder, device='cpu'):
        super().__init__()
        self.tokenizer = tokenizer
        self.text_encoder = textencoder.to(device=device)
        self.device = device

    def forward(self, text_ids):
        """Return ``cat([negative_embeds, prompt_embeds])`` along dim 0.

        Args:
            text_ids: prompt token ids; assumed (batch, sequence) with sequence equal to
                ``tokenizer.model_max_length``. The unconditional ids are padded to that
                length and both embeddings are concatenated (TODO confirm for other lengths).

        Returns:
            Tensor whose first ``batch`` rows are the empty-prompt embedding (one copy
            per prompt) and whose last ``batch`` rows are the prompt embeddings.
        """
        # Prepare the unconditional (empty-prompt) input. Under ONNX tracing this tokenizer
        # call runs in Python, so its result is baked into the graph as a constant.
        uncond_input = self.tokenizer(
            [""],
            padding="max_length",
            # Use the wrapped tokenizer, not a notebook-global pipeline.
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids

        # Encode the prompt and the empty prompt.
        prompt_embeds = self.text_encoder(text_ids.to(device=self.device, dtype=torch.int32)).last_hidden_state
        negative_prompt_embeds = self.text_encoder(uncond_input.to(device=self.device, dtype=torch.int32)).last_hidden_state

        # Encode the empty prompt once, then broadcast it to one row per prompt so the
        # result splits into two equal halves for classifier-free guidance.
        negative_prompt_embeds = negative_prompt_embeds.expand_as(prompt_embeds)

        return torch.cat([negative_prompt_embeds, prompt_embeds])
        
# Build the wrapper on the configured `device`. Hard-coding 'cuda' fails on CPU-only
# machines, and Module.to would also move pipeline.text_encoder in place.
te = TextEmbedding(pipeline.tokenizer, pipeline.text_encoder, device)

# Example prompt ids used as the tracing input for the ONNX export.
text_input = pipeline.tokenizer(
    ["A sample prompt"],
    padding="max_length",
    max_length=pipeline.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids.to(device=device, dtype=torch.int32)

In [9]:
onnx_export(
    # Same device as `text_input`, which was moved to `device`.
    TextEmbedding(pipeline.tokenizer, pipeline.text_encoder, device),
    model_args=(text_input,),  # trailing comma: a real 1-tuple, not a parenthesized tensor
    output_path=Path('./onnx-models/TextEmbedding/model.onnx'),
    ordered_input_names=["prompt"],
    output_names=["out_sample"],  # has to be different from "sample" for correct tracing
    dynamic_axes={
        # NOTE(review): TextEmbedding concatenates with unconditional ids padded to
        # tokenizer.model_max_length, so at runtime `sequence` presumably must equal
        # that length even though it is declared dynamic here.
        "prompt": {0: "batch", 1: "sequence"},
    },
    opset=14,
    # Only takes effect on torch < 1.11, where onnx_export forwards it to torch.onnx.export.
    use_external_data_format=True,
)

ONNX export Start🚗
ONNX export Finish🍷


---
### VAE Decoder

In [8]:
class Decoder(nn.Module):
    """Wrap a VAE decoder so latents map to images clamped to [0, 1] for ONNX export.

    Args:
        vae_decoder: module called as ``vae_decoder(latents)`` whose result supports
            ``['sample']`` indexing (e.g. a VAE whose ``forward`` is routed to ``decode``).
        device: device the wrapped module is moved to (``Module.to`` moves it in place).
        scaling_factor: latent scaling factor to undo before decoding. Defaults to
            0.18215, the value this wrapper previously hard-coded.
    """

    def __init__(self, vae_decoder, device='cpu', scaling_factor=0.18215):
        super().__init__()
        self.vae_decoder = vae_decoder.to(device=device)
        self.device = device
        self.scaling_factor = scaling_factor

    def forward(self, latent_sample):
        """Decode ``latent_sample`` and return the image rescaled and clamped to [0, 1]."""
        # Undo the latent scaling before decoding.
        latent_sample = 1 / self.scaling_factor * latent_sample
        # Pass through the decoder; presumably [B, 3, 8H, 8W] for [B, 4, H, W] latents
        # (e.g. [1, 3, 512, 512] for the 64x64 example) -- TODO confirm.
        image = self.vae_decoder(latent_sample)['sample']
        # Decoder output is assumed to be roughly in [-1, 1]; rescale to [0, 1] and clamp.
        image = torch.clip(image / 2 + 0.5, 0, 1)
        return image

In [9]:
# `Decoder` calls its wrapped module directly, so route the VAE's forward to `decode`
# (decoder half only). `vae_decoder` is the same object as `pipeline.vae`, so this
# patches the pipeline's VAE too.
vae_decoder = pipeline.vae
vae_decoder.forward = pipeline.vae.decode

onnx_export(
    Decoder(vae_decoder, device),
    model_args=(
        # Example latents used only for tracing; batch/height/width are declared dynamic below.
        torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype),
    ),
    output_path=Path('./onnx-models/Decoder/model.onnx'),
    ordered_input_names=["latent_sample"],
    output_names=["out_sample"],  # has to be different from "sample" for correct tracing
    dynamic_axes={
        "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
    },
    opset=14,
    # Only takes effect on torch < 1.11, where onnx_export forwards it to torch.onnx.export.
    use_external_data_format=True,
)

ONNX export Start🚗
ONNX export Finish🍷
