Skip to content

Commit

Permalink
fix(api): ensure VAE is loaded on correct device
Browse files: browse the repository at this point in the history
  • Loading branch information
ssube committed Sep 12, 2023
1 parent fd8b9be commit d4b0130
Showing 1 changed file with 43 additions and 27 deletions.
70 changes: 43 additions & 27 deletions api/onnx_web/diffusers/load.py
@@ -7,7 +7,12 @@
ORTStableDiffusionXLImg2ImgPipeline,
ORTStableDiffusionXLPipeline,
)
from optimum.onnxruntime.modeling_diffusion import ORTModelTextEncoder, ORTModelUnet
from optimum.onnxruntime.modeling_diffusion import (
ORTModelTextEncoder,
ORTModelUnet,
ORTModelVaeDecoder,
ORTModelVaeEncoder,
)
from transformers import CLIPTokenizer

from ..constants import ONNX_MODEL
@@ -363,26 +368,40 @@ def load_pipeline(
sess_options=device.sess_options(),
)
)
elif (
not params.is_xl() and path.exists(vae_decoder) and path.exists(vae_encoder)
):
logger.debug("loading VAE decoder from %s", vae_decoder)
components["vae_decoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
elif path.exists(vae_decoder) and path.exists(vae_encoder):
if params.is_xl():
logger.debug("loading VAE decoder from %s", vae_decoder)
components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
vae_decoder,
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
)

logger.debug("loading VAE encoder from %s", vae_encoder)
components["vae_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
logger.debug("loading VAE encoder from %s", vae_encoder)
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
vae_encoder,
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
)

else:
logger.debug("loading VAE decoder from %s", vae_decoder)
components["vae_decoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
vae_decoder,
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
)

logger.debug("loading VAE encoder from %s", vae_encoder)
components["vae_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
vae_encoder,
provider=device.ort_provider("vae"),
sess_options=device.sess_options(),
)
)

# additional options for panorama pipeline
if params.is_panorama():
@@ -402,33 +421,30 @@ def load_pipeline(

# make sure XL models are actually being used
if "text_encoder_session" in components:
logger.info(
"text encoder matches: %s, %s",
pipe.text_encoder.session == components["text_encoder_session"],
type(pipe.text_encoder),
)
pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)

if "text_encoder_2_session" in components:
logger.info(
"text encoder 2 matches: %s, %s",
pipe.text_encoder_2.session == components["text_encoder_2_session"],
type(pipe.text_encoder_2),
)
pipe.text_encoder_2 = ORTModelTextEncoder(
text_encoder_2_session, text_encoder_2
)

if "unet_session" in components:
logger.info(
"unet matches: %s, %s",
pipe.unet.session == components["unet_session"],
type(pipe.unet),
)
# unload old UNet first
pipe.unet = None
run_gc([device])
# load correct one
pipe.unet = ORTModelUnet(unet_session, unet_model)

if "vae_decoder_session" in components:
pipe.vae_decoder = ORTModelVaeDecoder(
components["vae_decoder_session"], vae_decoder
)

if "vae_encoder_session" in components:
pipe.vae_encoder = ORTModelVaeEncoder(
components["vae_encoder_session"], vae_encoder
)

if not server.show_progress:
pipe.set_progress_bar_config(disable=True)

Expand Down

0 comments on commit d4b0130

Please sign in to comment.