In [2]:
import torch
import torch.nn as nn
from diffusers.schedulers import PNDMScheduler
from pathlib import Path
from diffusers import DiffusionPipeline
from util import onnx_export

import gc
# Force a full garbage-collection pass; the value echoed by the cell is the
# number of unreachable objects the collector found.
# NOTE(review): presumably here to release memory left over from earlier runs
# in the same kernel — on a fresh kernel it is a no-op.
gc.collect()

0

In [3]:
# Export settings shared by every cell below: the dummy inputs handed to
# onnx_export are created on `device` with `dtype`.
device = "cpu"
dtype = torch.float32

# Load the pipeline with the same `dtype` variable instead of repeating the
# literal, so the weights and the dummy export inputs cannot drift apart.
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype)

---

## UNet

In [14]:
class UnetModel(nn.Module):
    """Wrap a UNet so classifier-free guidance happens inside ``forward``.

    The wrapped forward pass duplicates the latents along dim 0, runs the UNet
    once on the doubled batch, splits the prediction back into two halves and
    blends them with the guidance scale.

    Args:
        unet: The noise-prediction model. It is called as
            ``unet(latents, timestep, encoder_hidden_states=...,
            cross_attention_kwargs=None, return_dict=False)`` and the first
            element of its return value is used.
        device: Device the UNet is moved to.
        guidance_scale: Weight applied to ``(text - uncond)``. Defaults to 7.5,
            the value that was previously hard-coded. It is stored as a plain
            Python number, so it is baked into the graph as a constant when the
            module is traced.
    """

    def __init__(self, unet, device, guidance_scale=7.5):
        super().__init__()
        self.unet = unet.to(device=device)
        self.guidance_scale = guidance_scale

    def forward(self, latents, prompt_embeds, timestep):
        """Return the guided noise prediction for ``latents``.

        Args:
            latents: Latent batch; it is concatenated with itself along dim 0.
            prompt_embeds: Passed to the UNet as ``encoder_hidden_states``.
                Assumes the unconditional embeddings come first and the text
                embeddings second, matching how the two output halves are
                named below — TODO confirm against the caller.
            timestep: Forwarded unchanged to the UNet.

        Returns:
            ``uncond + guidance_scale * (text - uncond)``, where ``uncond`` and
            ``text`` are the first and second half of the UNet output.
        """
        # One UNet call serves both guidance branches.
        latents = torch.cat([latents] * 2).to(prompt_embeds.device)       # e.g. [2, 4, 64, 64]
        # predict the noise residual
        noise_pred = self.unet(
            latents,
            timestep,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]
        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        return noise_pred

In [13]:
# Export the guidance-wrapped UNet. The dummy inputs only fix the traced
# shapes and dtypes; their random values are irrelevant.
onnx_export(
    UnetModel(pipeline.unet, device),
    model_args=(
        torch.randn([1,4,64,64]).to(device=device, dtype=dtype),
        # NOTE(review): dim 1 is 4 here and is not declared dynamic below —
        # confirm it matches the sequence length fed at inference time.
        torch.randn([2,4,768]).to(device=device, dtype=dtype),
        torch.tensor(981),
    ),
    output_path = Path('../onnx-models/UNet/model.onnx'),
    ordered_input_names=["latents", "prompt_embeds", "timestep"],
    # NOTE(review): the output is the guided noise prediction; the name is kept
    # as-is so existing consumers of the exported graph keep working.
    output_names=["prev_cur_latents"],  # has to be different from "sample" for correct tracing
    dynamic_axes={
        # Keys must be declared input/output names. The previous key
        # "noise_pred" matched none of them, so it could not take effect.
        "prev_cur_latents": {0: "batch"},
        "prompt_embeds": {0: "batch"},
    },
    opset=14,
    use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
)

ONNX export Start🚗


  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


ONNX export Finish🍷


---
## Scheduler

t = 981

In [16]:
class Scheduler_step981(nn.Module):
    """Traceable module performing the scheduler update from timestep 981 to 961.

    Construction configures the given scheduler for 50 inference steps on
    ``device`` and keeps references to its ``alphas_cumprod`` and
    ``final_alpha_cumprod`` values for use in the update formula.
    """

    def __init__(self, scheduler, device):
        super().__init__()
        scheduler.set_timesteps(50, device=device)
        self.scheduler = scheduler
        self.alphas_cumprod = scheduler.alphas_cumprod
        self.final_alpha_cumprod = scheduler.final_alpha_cumprod

    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
        """Apply the update the original comments call "formula (9)".

        ``timestep`` and ``prev_timestep`` index into ``alphas_cumprod``; a
        negative ``prev_timestep`` selects ``final_alpha_cumprod`` instead.
        """
        a_t = self.alphas_cumprod[timestep]
        if prev_timestep >= 0:
            a_prev = self.alphas_cumprod[prev_timestep]
        else:
            a_prev = self.final_alpha_cumprod
        b_t = 1 - a_t
        b_prev = 1 - a_prev

        # Scale applied to the current sample.
        sample_scale = (a_prev / a_t) ** (0.5)
        # Denominator of the model-output term in formula (9).
        denom = a_t * b_prev ** (0.5) + (a_t * b_t * a_prev) ** (0.5)
        # Numerator coefficient of the model-output term.
        alpha_gap = a_prev - a_t

        # Full formula (9); evaluation order matches the reference expression.
        return sample_scale * sample - alpha_gap * model_output / denom

    def forward(self, noise_pred, latents):
        """Return ``[previous sample, current sample]`` stacked along dim 0."""
        stepped = self._get_prev_sample(latents, 981, 961, noise_pred)
        return torch.cat([stepped, latents], dim=0)

# Export the fixed 981 -> 961 scheduler step. The two dummy tensors stand in
# for the noise prediction and the current latents, in that order.
onnx_export(
    Scheduler_step981(pipeline.scheduler, device),
    model_args=(
        torch.randn([1,4,64,64]).to(device=device, dtype=dtype),
        torch.randn([1,4,64,64]).to(device=device, dtype=dtype),
    ),
    output_path = Path(f'../onnx-models/Schedulers/step-{981}.onnx'),
    ordered_input_names=["noise_pred", "latents"],
    output_names=["prev_cur_latents"],  # forward() returns [previous sample, current sample] concatenated on dim 0
    # Only the first input's leading axis is declared dynamic.
    # NOTE(review): "latents" and the output keep their traced size on axis 0 — confirm this is intended.
    dynamic_axes={ 
        "noise_pred": {0: "batch"}
    },
    opset=14,
    use_external_data_format=True,  # NOTE(review): carried over from the UNet export; this module only reads a few scheduler constants, so external data is presumably unnecessary — confirm against util.onnx_export
)

ONNX export Start🚗
ONNX export Finish🍷
