Skip to content

Commit

Permalink
fix(api): allow SDXL VAE in any supported tensor format, ensure new S…
Browse files Browse the repository at this point in the history
…DXL models get hash file
  • Loading branch information
ssube committed Oct 7, 2023
1 parent 047e58c commit 1351b2f
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions api/onnx_web/convert/diffusion/diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from optimum.exporters.onnx import main_export

from ...constants import ONNX_MODEL
from ..utils import ConversionContext
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext

logger = getLogger(__name__)

Expand Down Expand Up @@ -42,13 +42,14 @@ def convert_diffusion_diffusers_xl(
"converting Stable Diffusion XL model %s: %s -> %s/", name, source, dest_path
)

if "hash" in model and not path.exists(model_hash):
logger.info("ONNX model does not have hash file, adding one")
with open(model_hash, "w") as f:
f.write(model["hash"])

if path.exists(dest_path) and path.exists(model_index):
logger.info("ONNX model already exists, skipping conversion")

if "hash" in model and not path.exists(model_hash):
logger.info("ONNX model does not have hash file, adding one")
with open(model_hash, "w") as f:
f.write(model["hash"])

return (False, dest_path)

# safetensors -> diffusers directory with torch models
Expand All @@ -63,7 +64,7 @@ def convert_diffusion_diffusers_xl(

if replace_vae is not None:
vae_path = path.join(conversion.model_path, replace_vae)
if replace_vae.endswith(".safetensors"):
if check_ext(replace_vae, RESOLVE_FORMATS):
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
else:
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
Expand All @@ -80,6 +81,11 @@ def convert_diffusion_diffusers_xl(
framework="pt",
)

if "hash" in model:
logger.debug("adding hash file to ONNX model")
with open(model_hash, "w") as f:
f.write(model["hash"])

if conversion.half:
unet_path = path.join(dest_path, "unet", ONNX_MODEL)
infer_shapes_path(unet_path)
Expand Down

0 comments on commit 1351b2f

Please sign in to comment.