Skip to content

Commit

Permalink
fix(api): switch to diffusers ckpt loading, add more pipelines to conversion (#337, #356)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 29, 2023
1 parent d699c75 commit 4c12615
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 1,748 deletions.
20 changes: 5 additions & 15 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@
from .diffusion.control import convert_diffusion_control
from .diffusion.diffusers import convert_diffusion_diffusers
from .diffusion.lora import blend_loras
from .diffusion.original import convert_diffusion_original
from .diffusion.textual_inversion import blend_textual_inversions
from .upscaling.bsrgan import convert_upscaling_bsrgan
from .upscaling.resrgan import convert_upscale_resrgan
from .upscaling.swinir import convert_upscaling_swinir
from .utils import (
ConversionContext,
download_progress,
model_formats_original,
remove_prefix,
source_format,
tuple_to_correction,
Expand Down Expand Up @@ -351,19 +349,11 @@ def convert_models(conversion: ConversionContext, args, models: Models):
conversion, name, model["source"], format=model_format
)

converted = False
if model_format in model_formats_original:
converted, dest = convert_diffusion_original(
conversion,
model,
source,
)
else:
converted, dest = convert_diffusion_diffusers(
conversion,
model,
source,
)
converted, dest = convert_diffusion_diffusers(
conversion,
model,
source,
)

# make sure blending only happens once, not every run
if converted:
Expand Down
37 changes: 28 additions & 9 deletions api/onnx_web/convert/diffusion/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@
AutoencoderKL,
OnnxRuntimeModel,
OnnxStableDiffusionPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionPanoramaPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from onnx import load_model, save_model

Expand All @@ -34,6 +38,17 @@

logger = getLogger(__name__)

available_pipelines = {
"controlnet": StableDiffusionControlNetPipeline,
"img2img": StableDiffusionPipeline,
"inpaint": StableDiffusionInpaintPipeline,
"lpw": StableDiffusionPipeline,
"panorama": StableDiffusionPanoramaPipeline,
"pix2pix": StableDiffusionInstructPix2PixPipeline,
"txt2img": StableDiffusionPipeline,
"upscale": StableDiffusionUpscalePipeline,
}


def convert_diffusion_diffusers_cnet(
conversion: ConversionContext,
Expand Down Expand Up @@ -184,7 +199,7 @@ def convert_diffusion_diffusers(
source = source or model.get("source")
single_vae = model.get("single_vae")
replace_vae = model.get("vae")
pipe_type = model.get("pipeline", "image")
pipe_type = model.get("pipeline", "txt2img")

device = conversion.training_device
dtype = conversion.torch_dtype()
Expand Down Expand Up @@ -213,25 +228,29 @@ def convert_diffusion_diffusers(
logger.info("ONNX model already exists, skipping")
return (False, dest_path)

if pipe_type == "image":
pipeline = StableDiffusionPipeline.from_pretrained(
pipe_class = available_pipelines.get(pipe_type)

if path.exists(source) and path.isdir(source):
logger.debug("loading pipeline from diffusers directory: %s", source)
pipeline = pipe_class.from_pretrained(
source,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(device)
elif pipe_type == "inpaint":
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
elif path.exists(source) and path.isfile(source):
logger.debug("loading pipeline from SD checkpoint: %s", source)
pipeline = pipe_class.from_ckpt(
source,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(device)
else:
raise ValueError(f"unknown pipeline type: {pipe_type}")

output_path = Path(dest_path)
logger.warning("pipeline source not found or not recognized: %s", source)
raise ValueError(f"pipeline source not found or not recognized: {source}")

optimize_pipeline(conversion, pipeline)

output_path = Path(dest_path)

# TEXT ENCODER
num_tokens = pipeline.text_encoder.config.max_position_embeddings
text_hidden_size = pipeline.text_encoder.config.hidden_size
Expand Down

0 comments on commit 4c12615

Please sign in to comment.