In [1]:
import torch
import torch.nn as nn
from diffusers.schedulers import PNDMScheduler
from pathlib import Path
from diffusers import DiffusionPipeline
from util import onnx_export
from onnxruntime.quantization.quantize import quantize_dynamic
from onnxruntime.quantization import QuantType
import onnx 
from diffusers.utils import BaseOutput
import gc
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

# Run the garbage collector once at start-up; the number of objects it collected
# is the value this cell displays.
gc.collect()

  from .autonotebook import tqdm as notebook_tqdm


0

In [2]:
# Runtime configuration for the PyTorch side of this notebook.
device = "cpu"
dtype = torch.float32
SEED = 0  # fixed so the random `input_sample` drawn later is reproducible
MODEL_ID = "runwayml/stable-diffusion-v1-5"

torch.manual_seed(SEED)

# The cells below use `pipeline.tokenizer` and `pipeline.text_encoder` from this pipeline.
# Load it in the configured `dtype` instead of a hard-coded float32.
pipeline = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)

---

In [3]:
class TextEmbedding(nn.Module):
    """Wrap a tokenizer and a text encoder into one "token ids -> embeddings" module.

    ``forward`` runs the text encoder on an empty prompt and on the given prompt
    ids, then concatenates the results along dim 0 as ``[empty prompt, prompt]``.
    This is presumably the unconditional/conditional pair used for
    classifier-free guidance; the batch-2 UNet inputs in this notebook match
    that layout.

    Args:
        tokenizer: callable tokenizer. It is called with ``padding``,
            ``max_length``, ``truncation`` and ``return_tensors="pt"`` and must
            return an object with an ``input_ids`` tensor.
        textencoder: module whose output exposes ``last_hidden_state``. It is
            moved to ``device`` on construction.
        device: device the encoder and all token ids are placed on.
    """

    def __init__(self, tokenizer, textencoder, device='cpu'):
        super().__init__()
        self.tokenizer = tokenizer
        self.text_encoder = textencoder.to(device=device)
        self.device = device

    def forward(self, text_ids):
        """Return ``cat([encoder(empty prompt), encoder(text_ids)], dim=0)``.

        Args:
            text_ids: 2-D token-id tensor ``(batch, seq_len)``. One empty prompt
                is encoded per row, padded to the same ``seq_len``, so both
                halves of the result always have matching shapes.

        Returns:
            Tensor with ``2 * batch`` rows along dim 0: the empty-prompt
            embeddings first, then the embeddings of ``text_ids``.
        """
        batch_size, seq_len = text_ids.shape[0], text_ids.shape[-1]
        # Prepare the unconditional (empty-prompt) input with this instance's own
        # tokenizer, not the notebook-global `pipeline`. Use one row per prompt
        # and the same length as `text_ids`.
        uncond_input = self.tokenizer(
            [""] * batch_size,
            padding="max_length",
            max_length=seq_len,
            truncation=True,
            return_tensors="pt",
        ).input_ids
        # Encode both inputs. The int32 cast is kept from the original notebook
        # (NOTE(review): presumably for the ONNX export; confirm).
        textembed = self.text_encoder(
            text_ids.to(device=self.device, dtype=torch.int32)
        ).last_hidden_state
        negative_prompt_embeds = self.text_encoder(
            uncond_input.to(device=self.device, dtype=torch.int32)
        ).last_hidden_state

        return torch.cat([negative_prompt_embeds, textembed])
# Tokenize the prompt, padded/truncated to the tokenizer's max length, and cast
# the ids to int32 on the configured device.
prompt_tokens = pipeline.tokenizer(
    ["A smile Tiger"],
    padding="max_length",
    max_length=pipeline.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
text_input = prompt_tokens.input_ids.to(device=device, dtype=torch.int32)

# Text-embedding wrapper around the pipeline's tokenizer and text encoder.
te = TextEmbedding(pipeline.tokenizer, pipeline.text_encoder, device=device)

In [4]:
import onnxruntime
import numpy as np

# First segment of the split UNet. `device` is "cpu", so request the CPU
# execution provider explicitly. 'AzureExecutionProvider' targets remote Azure
# endpoints; the model presumably only ran through ONNX Runtime's implicit CPU
# fallback.
sess_down = onnxruntime.InferenceSession(
    '../onnx-models/UNet-ver2/down.onnx',
    providers=['CPUExecutionProvider'],
)
[node.name for node in sess_down.get_inputs()]

['sample', 'timestep', 'encoder_hidden_states']

In [5]:
# Mid segment of the split UNet, run on the CPU execution provider (device is "cpu").
sess_mid = onnxruntime.InferenceSession(
    '../onnx-models/UNet-ver2/unet2dconditionalmodel_mid.onnx',
    providers=['CPUExecutionProvider'],
)
[node.name for node in sess_mid.get_inputs()]

['sample', 'emb', 'encoder_hidden_states']

In [6]:
# Up segments 0-2 share one directory layout and differ only by index, so build
# them in one expression. Use the CPU execution provider, since device is "cpu".
sess_up0, sess_up1, sess_up2 = (
    onnxruntime.InferenceSession(
        f'../onnx-models/UNet-ver2/unet2dconditionalmodel_up_{k}/model.onnx',
        providers=['CPUExecutionProvider'],
    )
    for k in range(3)
)
# Up segment 0's inputs (note: no `encoder_hidden_states`).
[node.name for node in sess_up0.get_inputs()]

['sample', 'emb', 'res_samples0', 'res_samples1', 'res_samples2']

In [7]:
# The last up block is exported as three sub-segments with the same layout,
# so build them in one expression. Use the CPU execution provider.
sess_up3_0, sess_up3_1, sess_up3_2 = (
    onnxruntime.InferenceSession(
        f'../onnx-models/UNet-ver2/unet2dconditionalmodel_up_3_{k}/model.onnx',
        providers=['CPUExecutionProvider'],
    )
    for k in range(3)
)
[node.name for node in sess_up3_0.get_inputs()], [node.name for node in sess_up3_2.get_inputs()]

(['sample', 'emb', 'encoder_hidden_states', 'res_samples0'],
 ['sample', 'emb', 'encoder_hidden_states', 'res_samples0'])

In [8]:

# Post-processing segment of the split UNet, run on the CPU execution provider.
sess_post = onnxruntime.InferenceSession(
    '../onnx-models/UNet-ver2/UNet_post_process/model.onnx',
    providers=['CPUExecutionProvider'],
)
[node.name for node in sess_post.get_inputs()]

['sample']

---

In [9]:
# Prepare inputs: a random `sample` batch, one timestep, and the text embeddings.
# The sample batch of 2 matches the 2 embedding rows (empty prompt + prompt)
# produced by `te`.
input_sample = torch.randn([2, 4, 64, 64]).to(device=device, dtype=dtype)
# 961 is a single fixed timestep for this test run. (The `timestemp` spelling is
# kept because later cells reference this name.)
input_timestemp = torch.tensor([961]).to(device=device, dtype=dtype)
# Inference only: skip building an autograd graph through the text encoder.
with torch.no_grad():
    input_text_embed = te(text_input)
input_sample.shape, input_timestemp.shape, input_text_embed.shape

(torch.Size([2, 4, 64, 64]), torch.Size([1]), torch.Size([2, 77, 768]))

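- Down: run the `down.onnx` segment, which returns `sample`, `emb`, and the skip tensors (`down_block_res_samples`) consumed later by the up segments.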

In [10]:
# Feed the down segment positionally: its inputs are (sample, timestep, encoder_hidden_states).
outputs = sess_down.run(
    None,
    dict(zip(
        [node.name for node in sess_down.get_inputs()],
        [
            input_sample.detach().numpy(),
            input_timestemp.detach().numpy(),
            input_text_embed.detach().numpy(),
        ],
    )),
)
# The first two outputs are `sample` and `emb`; everything after them is the skip-tensor list.
sample, emb, *down_block_res_samples = outputs

- Mid: run the mid segment on the Down output, reusing `emb` and the text embeddings.

In [11]:
# Mid segment inputs are (sample, emb, encoder_hidden_states), fed in that order.
outputs = sess_mid.run(
    None,
    dict(zip(
        [node.name for node in sess_mid.get_inputs()],
        [sample, emb, input_text_embed.detach().numpy()],
    )),
)
sample = outputs[0]

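- Up: run the up segments, feeding each one skip tensors from the Down step, taken from the end of the list first.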

In [12]:
# Up segments 0-2. Each iteration takes the last three remaining skip tensors and
# shrinks `down_block_res_samples` accordingly. This cell is NOT idempotent:
# re-run it only after re-running the Down and Mid cells.
sessUpList = [sess_up0, sess_up1, sess_up2]
text_embed_np = input_text_embed.detach().numpy()  # loop-invariant, convert once
print('> ', sample.shape)
for i, session in enumerate(sessUpList):
    res_samples = down_block_res_samples[-3:]
    down_block_res_samples = down_block_res_samples[:-3]
    # Up segment 0 has no `encoder_hidden_states` input (see its input list above).
    # Segments 1 and 2 take the text embedding right after `emb`.
    text_inputs = [] if i == 0 else [text_embed_np]
    feed_values = [sample, emb, *text_inputs, *res_samples]
    input_names = [node.name for node in session.get_inputs()]
    outputs = session.run(None, dict(zip(input_names, feed_values)))
    sample = outputs[0]
    print(sample.shape)

>  (2, 1280, 8, 8)
(2, 1280, 16, 16)
(2, 1280, 32, 32)
(2, 640, 64, 64)


In [13]:
# Sanity check: shape of the last skip tensor fed to up segment 2 in the loop above.
last_skip = res_samples[-1]
last_skip.shape

(2, 640, 32, 32)

In [14]:
up_3_list = [sess_up3_0, sess_up3_1, sess_up3_2]
text_embed = input_text_embed.detach().numpy()
# Each final sub-segment takes one remaining skip tensor, walking
# `down_block_res_samples` from its last element backwards.
for session, res_sample in zip(up_3_list, reversed(down_block_res_samples)):
    names = [node.name for node in session.get_inputs()]
    outputs = session.run(None, {
        names[0]: sample,
        names[1]: emb,
        names[2]: text_embed,
        names[3]: res_sample,
    })
    sample = outputs[0]

In [15]:
# Post-processing segment: its single input is the current `sample`.
post_feed = {sess_post.get_inputs()[0].name: sample}
outputs = sess_post.run(None, post_feed)
sample = outputs[0]

In [16]:
# The split UNet's final output must have the same shape as the input `sample`.
assert sample.shape == tuple(input_sample.shape), (sample.shape, tuple(input_sample.shape))
sample.shape

(2, 4, 64, 64)