In [6]:
import torch
import torch.nn as nn
from diffusers.schedulers import PNDMScheduler
from pathlib import Path
from diffusers import DiffusionPipeline
from util import onnx_export
import numpy as np

import gc
# Run a garbage-collection pass before the pipeline is loaded in the next cell.
# Being the last expression of the cell, its return value is echoed as the cell output.
gc.collect()

0

In [7]:
# Device and precision shared by the pipeline load and the export inputs below.
device = "cpu"
dtype = torch.float32

# Load the pipeline with the dtype configured above (instead of repeating the
# literal) so that changing `dtype` in one place keeps weights and inputs in sync.
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype)

---

## Decoder

In [20]:
from PIL import Image

class Decoder(nn.Module):
    """
    Wrap a VAE decoder so a scaled latent goes in and a channels-last image
    tensor with values in [0, 1] comes out.

    Args:
        vae_decoder: Module that is called with the rescaled latent and whose
            result is indexed with ``'sample'`` to obtain the decoded image.
        device: Device the wrapped decoder is moved to.
        scaling_factor: Value the incoming latent is divided by before
            decoding. Defaults to 0.18215, the constant that used to be
            hard-coded in ``forward``.
    """

    def __init__(self, vae_decoder, device='cpu', scaling_factor=0.18215):
        super().__init__()
        # A single assignment is enough: the module is registered once, already moved.
        self.vae_decoder = vae_decoder.to(device=device)
        self.device = device
        self.scaling_factor = scaling_factor

    def denormalize(self, images):
        """
        Denormalize an image tensor to [0, 1].

        Applies ``images / 2 + 0.5`` and clamps the result to [0, 1].
        """
        return (images / 2 + 0.5).clamp(0, 1)

    def pt_to_numpy(self, images: torch.Tensor) -> np.ndarray:
        """
        Convert a 4-D PyTorch tensor to a NumPy array.

        Axis 1 is moved to the last position (presumably channels-first to
        channels-last), the data is cast to float32, moved to the CPU and
        detached from the autograd graph.
        """
        images = images.permute(0, 2, 3, 1).float().cpu().detach().numpy()
        return images

    def numpy_to_pil(self, images: np.ndarray) -> "list[Image.Image]":
        """
        Convert a numpy image or a batch of images to a list of PIL images.

        A 3-D input is treated as a single image and gets a leading batch
        axis. Values are multiplied by 255, rounded and cast to uint8.

        Returns:
            A list with one PIL image per batch entry.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    def forward(self, latent_sample):
        """
        Decode a latent into an image tensor.

        Args:
            latent_sample: 4-D latent tensor; it is divided by
                ``scaling_factor`` before being passed to the decoder.

        Returns:
            The decoded ``'sample'`` tensor, denormalized to [0, 1] and with
            axis 1 moved to the last position.
        """
        # 0. preprocessing: undo the latent scaling
        latent_sample = latent_sample / self.scaling_factor
        # 1. run (original author's note: output is [1, 3, 512, 512])
        image = self.vae_decoder(latent_sample)['sample']

        # 2. postprocessing: [0, 1] range, axis 1 moved last
        image = self.denormalize(image).permute(0, 2, 3, 1)

        # Conversion to numpy / PIL is intentionally left to the caller so the
        # module returns a tensor.
        return image

In [21]:
# Reuse the pipeline's VAE and point its forward() at decode(), so calling the
# module runs the decoding path.
# NOTE: `vae_decoder` is the same object as `pipeline.vae`, so this rebinding
# also changes forward() on the pipeline's own VAE.
vae_decoder = pipeline.vae
vae_decoder.forward = pipeline.vae.decode
decoder = Decoder(vae_decoder, device)

onnx_export(
    decoder,
    # One-element tuple (note the trailing comma): without the comma the
    # parentheses are only grouping and a bare tensor is passed.
    # NOTE(review): assumes onnx_export forwards this to torch.onnx.export - confirm.
    model_args=(
        torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype),
    ),
    output_path = Path('../onnx-models/Decoder.onnx'),
    ordered_input_names = ["latent_sample"],   # input: latent_sample
    output_names = ["output"],
    dynamic_axes = {
            "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
        },
    opset=14,
)

ONNX export Start🚗
ONNX export Finish🍷
