In [2]:
import torch
import torch.nn as nn
from diffusers.schedulers import PNDMScheduler
from pathlib import Path
from diffusers import DiffusionPipeline
from util import onnx_export
from onnxruntime.quantization.quantize import quantize_dynamic
from onnxruntime.quantization import QuantType
import gc
gc.collect()

  from .autonotebook import tqdm as notebook_tqdm


0

In [3]:
# Configuration: everything runs on CPU; `dtype` is the single knob for model precision.
device = "cpu"
dtype = torch.float32

# Load the Stable Diffusion v1.5 pipeline with the configured precision
# (previously `torch.float32` was hard-coded here, leaving `dtype` unused).
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype)

---

#### Text Embedding

In [3]:
class TextEmbedding(nn.Module):
    """Wrap a tokenizer and a text encoder into one module that produces CFG prompt embeddings.

    ``forward`` encodes the given prompt token ids together with the empty
    ("unconditional") prompt and stacks them as ``[uncond; cond]`` along dim 0.
    The tokenizer is only used for the constant empty prompt, so during ONNX
    tracing its ids are baked into the graph as a constant.

    Args:
        tokenizer: callable tokenizer exposing ``model_max_length``; calling it
            must return an object with ``.input_ids``.
        textencoder: text encoder module whose output has ``.last_hidden_state``.
        device: device the encoder is moved to and inputs are placed on.
    """

    def __init__(self, tokenizer, textencoder, device='cpu'):
        super().__init__()
        self.tokenizer = tokenizer
        self.text_encoder = textencoder.to(device=device)
        self.device = device

    def forward(self, text_ids):
        """Return ``torch.cat([uncond_embeds, prompt_embeds])`` for ``text_ids``.

        Args:
            text_ids: prompt token ids; presumably shape (batch, model_max_length) — TODO confirm.

        Returns:
            Tensor whose dim 0 has ``2 * batch`` rows: one unconditional
            embedding per prompt, followed by the prompt embeddings.
        """
        # Prepare the unconditional (empty-prompt) input. Use this module's own
        # tokenizer instead of the notebook-global `pipeline`, so the class is standalone.
        uncond_input = self.tokenizer(
            [""],
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids
        # Encode both the prompt and the empty prompt.
        textembed = self.text_encoder(text_ids.to(device=self.device, dtype=torch.int32)).last_hidden_state
        negative_prompt_embeds = self.text_encoder(uncond_input.to(device=self.device, dtype=torch.int32)).last_hidden_state
        # The empty prompt is encoded once (batch 1); repeat it per prompt so both halves
        # of the output have the same batch size. Identical to the old behavior for batch 1.
        negative_prompt_embeds = negative_prompt_embeds.expand(textembed.shape[0], -1, -1)

        prompt_embeds = torch.cat([negative_prompt_embeds, textembed])
        return prompt_embeds

In [3]:
# Tokenize an example prompt to use as the tracing input for the ONNX export.
text_input = pipeline.tokenizer(
    ["A smile Tiger"],
    padding="max_length",
    max_length=pipeline.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids.to(device=device, dtype=torch.int32)

onnx_export(
    TextEmbedding(pipeline.tokenizer, pipeline.text_encoder, device),
    # A 1-tuple needs the trailing comma: `(text_input)` is just `text_input`.
    model_args=(text_input,),
    output_path=Path('../onnx-models/TextEmbedding.onnx'),
    ordered_input_names=["prompt"],
    output_names=["out_sample"],
    # NOTE(review): the unconditional ids inside the model are always padded to
    # model_max_length, so the batch concat only works when "sequence" equals that length.
    dynamic_axes={
        "prompt": {0: "batch", 1: "sequence"},
    },
    opset=14,
    use_external_data_format=True,  # store weights in an external data file next to the .onnx
)

NameError: name 'TextEmbedding' is not defined

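- test용 Text Encoder (text encoder for testing: the prompt is hard-coded inside `forward`, so the `text_ids` input is ignored and the exported ONNX graph has no inputs)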

In [4]:
class TextEmbedding_test(nn.Module):
    """Test variant of the text-embedding module that always encodes a fixed prompt.

    ``forward`` ignores its ``text_ids`` argument and tokenizes ``prompt`` itself,
    so the traced graph does not depend on the input; the exported ONNX graph
    therefore ends up with no inputs. Output layout is
    ``torch.cat([uncond_embeds, prompt_embeds])`` along dim 0.

    Args:
        tokenizer: callable tokenizer exposing ``model_max_length``; calling it
            must return an object with ``.input_ids``.
        textencoder: text encoder module whose output has ``.last_hidden_state``.
        device: device the encoder is moved to and inputs are placed on.
        prompt: the fixed prompt to encode (defaults to the original hard-coded text).
    """

    def __init__(self, tokenizer, textencoder, device='cpu', prompt="a smile cat"):
        super().__init__()
        self.tokenizer = tokenizer
        self.text_encoder = textencoder.to(device=device)
        self.device = device
        self.prompt = prompt

    def _tokenize(self, text):
        """Tokenize one string, padded/truncated to this tokenizer's max length."""
        # Uses self.tokenizer rather than the notebook-global `pipeline`.
        return self.tokenizer(
            [text],
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids

    def _encode(self, ids):
        """Run the text encoder on token ids and return its last hidden state."""
        return self.text_encoder(ids.to(device=self.device, dtype=torch.int32)).last_hidden_state

    def forward(self, text_ids):
        """Return ``[uncond; cond]`` embeddings for the fixed prompt; ``text_ids`` is ignored."""
        textembed = self._encode(self._tokenize(self.prompt))
        # Unconditional (empty-prompt) embedding for classifier-free guidance.
        negative_prompt_embeds = self._encode(self._tokenize(""))

        prompt_embeds = torch.cat([negative_prompt_embeds, textembed])
        return prompt_embeds

In [8]:
# Export the fixed-prompt test model to ONNX, then dynamically quantize its weights to uint8.
TEST_ONNX_PATH = Path('../onnx-models/TextEmbedding-test.onnx')
TEST_QUANT_ONNX_PATH = Path('../onnx-models/TextEmbedding-test-quant.onnx')

onnx_export(
    TextEmbedding_test(pipeline.tokenizer, pipeline.text_encoder, 'cpu'),
    # Dummy token ids (TextEmbedding_test.forward ignores them); note the 1-tuple comma.
    model_args=(
        torch.zeros([1, pipeline.tokenizer.model_max_length], dtype=torch.int32),
    ),
    output_path=TEST_ONNX_PATH,
    ordered_input_names=["prompt"],
    output_names=["out_sample"],
    dynamic_axes={
        "prompt": {0: "batch", 1: "sequence"},
    },
    opset=12,
    use_external_data_format=True,  # store weights in an external data file next to the .onnx
)
quantize_dynamic(
    model_input=str(TEST_ONNX_PATH),
    model_output=str(TEST_QUANT_ONNX_PATH),
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)

ONNX export Start🚗
ONNX export Finish🍷




In [10]:
import onnxruntime
import numpy as np

# Load the quantized test text-embedding model for local inference.
# AzureExecutionProvider targets remote Azure endpoints; a local model belongs on the CPU provider.
# (The variable name is historical: this session holds the text-embedding model, not a UNet.)
unetSession = onnxruntime.InferenceSession(
    '../onnx-models/TextEmbedding-test-quant.onnx',
    providers=['CPUExecutionProvider'],
)
# Graph input names; empty because TextEmbedding_test ignores its input during export.
[i.name for i in unetSession.get_inputs()]

[]