Skip to content

Commit

Permalink
feat(api): add conversion for SDXL models
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Sep 10, 2023
1 parent 78f834a commit fe68670
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 8 deletions.
28 changes: 20 additions & 8 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
from onnx import load_model, save_model
from transformers import CLIPTokenizer

from .diffusion.diffusion_xl import convert_diffusion_diffusers_xl

from ..constants import ONNX_MODEL, ONNX_WEIGHTS
from ..utils import load_config
from .correction.gfpgan import convert_correction_gfpgan
from .diffusion.control import convert_diffusion_control
from .diffusion.diffusers import convert_diffusion_diffusers
from .diffusion.diffusion import convert_diffusion_diffusers
from .diffusion.lora import blend_loras
from .diffusion.textual_inversion import blend_textual_inversions
from .upscaling.bsrgan import convert_upscaling_bsrgan
Expand Down Expand Up @@ -357,13 +359,23 @@ def convert_models(conversion: ConversionContext, args, models: Models):
conversion, name, model["source"], format=model_format
)

converted, dest = convert_diffusion_diffusers(
conversion,
model,
source,
model_format,
hf=hf,
)
pipeline = model.get("pipeline", "txt2img")
if pipeline.endswith("-sdxl"):
converted, dest = convert_diffusion_diffusers_xl(
conversion,
model,
source,
model_format,
hf=hf,
)
else:
converted, dest = convert_diffusion_diffusers(
conversion,
model,
source,
model_format,
hf=hf,
)

# make sure blending only happens once, not every run
if converted:
Expand Down
File renamed without changes.
72 changes: 72 additions & 0 deletions api/onnx_web/convert/diffusion/diffusion_xl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from logging import getLogger
from os import path
from typing import Dict, Optional, Tuple

import torch
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
from optimum.exporters.onnx import main_export

from ..utils import ConversionContext

logger = getLogger(__name__)


@torch.no_grad()
def convert_diffusion_diffusers_xl(
conversion: ConversionContext,
model: Dict,
source: str,
format: Optional[str],
hf: bool = False,
) -> Tuple[bool, str]:
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
name = model.get("name")
# TODO: support alternate VAE

device = conversion.training_device
dtype = conversion.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", dtype)

dest_path = path.join(conversion.model_path, name)
model_index = path.join(dest_path, "model_index.json")
model_hash = path.join(dest_path, "hash.txt")

# diffusers go into a directory rather than .onnx file
logger.info(
"converting Stable Diffusion XL model %s: %s -> %s/", name, source, dest_path
)

if "hash" in model and not path.exists(model_hash):
logger.info("ONNX model does not have hash file, adding one")
with open(model_hash, "w") as f:
f.write(model["hash"])

if path.exists(dest_path) and path.exists(model_index):
logger.info("ONNX model already exists, skipping conversion")
return (False, dest_path)

# safetensors -> diffusers directory with torch models
temp_path = path.join(conversion.cache_path, f"{name}-torch")

if format == "safetensors":
pipeline = StableDiffusionXLPipeline.from_single_file(source, use_safetensors=True)
else:
pipeline = StableDiffusionXLPipeline.from_pretrained(source)

pipeline.save_pretrained(temp_path)

# directory -> onnx using optimum exporters
main_export(
temp_path,
output=dest_path,
task="stable-diffusion-xl",
device=device,
fp16=conversion.half,
framework="pt",
)

# TODO: optimize UNet to fp16

return False, dest_path

0 comments on commit fe68670

Please sign in to comment.