Skip to content

Commit

Permalink
fix(api): specify input channels when converting inpainting models (#356
Browse files Browse the repository at this point in the history
)
  • Loading branch information
ssube committed Apr 29, 2023
1 parent 4c12615 commit 0175d7e
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion api/onnx_web/convert/diffusion/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
available_pipelines = {
"controlnet": StableDiffusionControlNetPipeline,
"img2img": StableDiffusionPipeline,
"inpaint": StableDiffusionInpaintPipeline,
"inpaint": StableDiffusionPipeline,
"lpw": StableDiffusionPipeline,
"panorama": StableDiffusionPanoramaPipeline,
"pix2pix": StableDiffusionInstructPix2PixPipeline,
Expand Down Expand Up @@ -229,6 +229,10 @@ def convert_diffusion_diffusers(
return (False, dest_path)

pipe_class = available_pipelines.get(pipe_type)
pipe_args = {}

if pipe_type == "inpaint":
pipe_args["num_in_channels"] = 9

if path.exists(source) and path.isdir(source):
logger.debug("loading pipeline from diffusers directory: %s", source)
Expand All @@ -242,6 +246,7 @@ def convert_diffusion_diffusers(
pipeline = pipe_class.from_ckpt(
source,
torch_dtype=dtype,
**pipe_args,
).to(device)
else:
logger.warning("pipeline source not found or not recognized: %s", source)
Expand Down

0 comments on commit 0175d7e

Please sign in to comment.