# Stable Diffusion Model Conversion for Mobile Deployment

This notebook documents how the UNet, VAE, and Text Encoder components of Stable Diffusion v1.4 are converted for mobile deployment: the UNet and VAE decoder are exported to ONNX, and the Text Encoder is saved as TorchScript for PyTorch Mobile.


## Table of Contents
* [1. Setup and Configuration](#chapter1)
    * [1.1 Download model from HuggingFace](#section_1_1)
    * [1.2 Set dummy inputs for inferencing](#section_1_2)
* [2. Conversion for Mobile Deployment](#chapter2)
    * [2.1 Text Encoder Conversion](#section_2_1)
    * [2.2 UNet Conversion](#section_2_2)
    * [2.3 VAE Conversion](#section_2_3)

### 1. Setup and Configuration <a class="anchor" id="chapter1"></a>

In [1]:
# Automatically install required Python packages if they are missing
import importlib
import subprocess
import sys


def install_if_missing(package, pip_name=None):
    """Import ``package``; if that fails, pip-install it into this kernel's environment.

    Args:
        package: Module name used with ``import`` (e.g. ``"PIL"``).
        pip_name: Distribution name on PyPI when it differs from the import
            name (e.g. ``"pillow"`` for ``PIL``). Defaults to ``package``.

    Raises:
        subprocess.CalledProcessError: If ``pip install`` fails.
    """
    try:
        importlib.import_module(package)
    except ImportError:
        # Use the running interpreter's pip so the package lands in the kernel's environment.
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name or package])
        # Let later `import` statements in this kernel see the freshly installed package.
        importlib.invalidate_caches()


# Core packages
install_if_missing("torch")
install_if_missing("numpy")
install_if_missing("onnxruntime")
install_if_missing("diffusers")
install_if_missing("transformers")
install_if_missing("PIL", pip_name="pillow")  # imported as PIL, but published on PyPI as "pillow"
install_if_missing("matplotlib")
install_if_missing("accelerate")
install_if_missing("tqdm")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import numpy as np
import onnxruntime as ort
from diffusers import DiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import UNet2DConditionModel, PNDMScheduler, AutoencoderKL
from PIL import Image
import os
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

In [3]:
# Report the library versions this notebook was run with (useful when reproducing the export)
from transformers import __version__ as transformers_version
from diffusers import __version__ as diffusers_version

for version_line in (
    f"PyTorch version: {torch.__version__}",
    f"ONNX Runtime version: {ort.__version__}",
    f"Diffusers version: {diffusers_version}",
    f"Transformers version: {transformers_version}",
):
    print(version_line)

PyTorch version: 2.5.1+cpu
ONNX Runtime version: 1.20.1
Diffusers version: 0.32.2
Transformers version: 4.48.0


In [4]:
# Select the device used for inference/conversion.
# CPU is forced by default, which keeps the notebook's original behaviour of
# overriding the CUDA check. Set FORCE_CPU = False to use CUDA when available;
# every example input used for tracing/export must then live on the same device
# as the models.
FORCE_CPU = True
device = "cpu" if FORCE_CPU or not torch.cuda.is_available() else "cuda"
print(f"Device set to {device.upper()}")

Device set to CPU


In [5]:
# Local folder holding the downloaded Stable Diffusion checkpoint
MODEL_PATH = "./stable-diffusion-v1-4"

# Output files for each converted component, in pipeline order.
# The UNet gets its own folder: its ONNX export is large enough that the
# exporter writes many external weight files next to the .onnx file.
ENCODER_FILE_PATH = "encoder_pt.pt"
UNET_FILE_PATH = "unet/unet_onnx.onnx"
VAE_FILE_PATH = "vae_onnx.onnx"

#### 1.1 Download model from HuggingFace <a id="section_1_1"></a>

In [6]:
# Download the checkpoint once and cache it under MODEL_PATH; later runs reuse the local copy.
MODEL_REPO_ID = "CompVis/stable-diffusion-v1-4"

if os.path.exists(MODEL_PATH):
    print("Diffusion model already exists. Skipping download")
else:
    print("Downloading model from hugging face")
    pipeline = DiffusionPipeline.from_pretrained(MODEL_REPO_ID)
    pipeline.save_pretrained(MODEL_PATH)
    # Each component is reloaded individually from MODEL_PATH in the next cell,
    # so release the full pipeline instead of keeping a second copy of the weights in memory.
    del pipeline

Downloading model from hugging face


Loading pipeline components...: 100%|████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16.80it/s]


In [7]:
# Load each component separately so it can be converted on its own.
# Only the neural networks are moved to `device`; the tokenizer and scheduler are used as-is.
tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATH, subfolder="tokenizer")
scheduler = PNDMScheduler.from_pretrained(MODEL_PATH, subfolder="scheduler")

text_encoder = CLIPTextModel.from_pretrained(MODEL_PATH, subfolder="text_encoder").to(device)
unet = UNet2DConditionModel.from_pretrained(MODEL_PATH, subfolder="unet").to(device)
vae = AutoencoderKL.from_pretrained(MODEL_PATH, subfolder="vae").to(device)

#### 1.2 Set dummy inputs for inferencing <a id="section_1_2"></a>

In [8]:
# Example inputs used to trace/export the models and to drive the denoising loop.
prompt = ["A realistic portrait of an old man"]
height, width = 512, 512  # output image size in pixels
num_channel = 4  # number of latent channels fed to the UNet
num_inference_steps = 10  # Number of denoising steps
guidance_scale = 7.5  # Scale for classifier-free guidance
generator = torch.Generator(device=device).manual_seed(0)  # seeded RNG so the latent noise is reproducible
batch_size = len(prompt)

# Tokenize the prompt and an empty negative prompt (used for classifier-free guidance)
prompt_tokens = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
neg_tokens = tokenizer([""] * batch_size, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt")

# Create text embeddings for prompt and negative prompt; no autograd graph is needed for example inputs
with torch.no_grad():
    prompt_embeddings = text_encoder(prompt_tokens.input_ids.to(device))[0]
    neg_embeddings = text_encoder(neg_tokens.input_ids.to(device))[0]
# Negative first, then prompt: this is the order unpacked by `noise_pred.chunk(2)` in the denoising loop
embeddings = torch.cat([neg_embeddings, prompt_embeddings])

# Generate random latent noise (1/8 of the image resolution) from the seeded generator,
# then scale it to the scheduler's initial noise level.
latent_noise = torch.randn(
    (batch_size, num_channel, height // 8, width // 8),
    generator=generator,
    device=device,
)
latent_noise = latent_noise * scheduler.init_noise_sigma

# Generate the denoising timesteps
scheduler.set_timesteps(num_inference_steps)

### 2. Conversion for Mobile Deployment <a class="anchor" id="chapter2"></a>

#### 2.1 Text Encoder Conversion <a id="section_2_1"></a>

In [9]:
class TextEncoderWrapper(torch.nn.Module):
    """Expose only the ``last_hidden_state`` of a text encoder as a plain tensor.

    The wrapped encoder returns a mapping-like output; tracing is simpler when
    the module returns a single tensor, so this wrapper selects that entry.
    """

    def __init__(self, text_encoder):
        super(TextEncoderWrapper, self).__init__()
        self.text_encoder = text_encoder

    def forward(self, input_ids):
        # Index the encoder output by key and hand back only the hidden states
        return self.text_encoder(input_ids)["last_hidden_state"]

# Wrap the original text_encoder model so tracing sees a single tensor output
wrapped_text_encoder = TextEncoderWrapper(text_encoder)

# Trace with example token ids placed on the same device as the model; no autograd graph is needed.
example_input_ids = prompt_tokens.input_ids.to(device)
with torch.no_grad():
    traced_model = torch.jit.trace(wrapped_text_encoder, example_input_ids)

# NOTE(review): torch.jit.save writes a full TorchScript archive. If the target app uses the
# PyTorch Mobile lite interpreter, it may need traced_model._save_for_lite_interpreter(...)
# instead — confirm against the app runtime.
torch.jit.save(traced_model, ENCODER_FILE_PATH)

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
  if input_shape[-1] > 1 or self.sliding_window is not None:
  if past_key_values_length > 0:


#### 2.2 UNet Conversion <a id="section_2_2"></a>

In [10]:
# A single timestep is enough as an example input for exporting the UNet.
t0 = scheduler.timesteps[0]

# Classifier-free guidance runs the negative and prompt branches as one batch,
# so the example latents are duplicated to match the batch of `embeddings`.
expand_latent_noise = torch.cat([latent_noise] * 2)

# By design, when the model is >= 2 GB the ONNX export writes hundreds of
# weight/bias/MatMul/etc. files alongside the .onnx file, so they get a folder of their own.
# https://github.com/pytorch/pytorch/issues/94280
# Derive the folder from UNET_FILE_PATH so the two stay in sync, and use exist_ok=True
# so re-running the cell does not fail with FileExistsError.
unet_dir = os.path.dirname(UNET_FILE_PATH)
if unet_dir:
    os.makedirs(unet_dir, exist_ok=True)
torch.onnx.export(unet, (expand_latent_noise, t0, embeddings), UNET_FILE_PATH)

  if dim % default_overall_up_factor != 0:
  assert hidden_states.shape[1] == self.channels
  assert hidden_states.shape[1] == self.channels
  assert hidden_states.shape[1] == self.channels
  if hidden_states.shape[0] >= 64:
  if hidden_states.numel() * scale_factor > pow(2, 31):
  if not return_dict:


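#### 2.3 VAE Conversion <a id="section_2_3"></a>

Only the VAE decoder is exported. To give the exporter a realistic example input, the first cell below runs a short denoising loop with the UNet. It then rescales the resulting latents, and the second cell exports the decoder using those latents.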

In [11]:
# Run a short denoising loop so the VAE decoder is exported with realistic latents.
# NOTE: this cell overwrites `latent_noise` (and advances the scheduler's state); re-running it
# without first re-running the example-input cell would denoise already-denoised latents.
for t in tqdm(scheduler.timesteps):
    # Duplicate the latents so the negative and prompt branches run in a single forward pass
    latent_model_input = torch.cat([latent_noise] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

    # Predict the noise residual
    with torch.inference_mode():
        noise_pred = unet(latent_model_input, t, embeddings).sample

    # Classifier-free guidance: move the prediction away from the negative branch
    noise_pred_neg, noise_pred_prompt = noise_pred.chunk(2)
    noise_pred = noise_pred_neg + guidance_scale * (noise_pred_prompt - noise_pred_neg)

    # Compute the previous noisy sample x_t -> x_t-1
    latent_noise = scheduler.step(noise_pred, t, latent_noise).prev_sample

# Undo the VAE latent scaling before decoding. Read the factor from the VAE config
# (0.18215 for this checkpoint, per the previous hard-coded value) instead of a magic number.
latent_noise = latent_noise / vae.config.scaling_factor

100%|██████████████████████████████████████████████████████████████████████████████████| 11/11 [00:31<00:00,  2.85s/it]


In [12]:
class VAEWrapper(torch.nn.Module):
    """Expose only the VAE decoder path (latents -> decoded sample) as ``forward``.

    Exporting this wrapper produces a graph that takes latents and returns the
    ``sample`` field of ``vae.decode``'s output.
    """

    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, latents):
        decoded = self.vae.decode(latents)
        return decoded.sample


# Export the decoder-only wrapper, using the denoised latents as the example input
vae_wrapper = VAEWrapper(vae)
torch.onnx.export(
    vae_wrapper,
    (latent_noise,),
    VAE_FILE_PATH,
)

In [13]:
# END