Skip to content

Commit

Permalink
fix(api): use Torch pipelines while loading models for conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Sep 24, 2023
1 parent c99481f commit 5d3a7d7
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions api/onnx_web/convert/diffusion/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,22 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch
from diffusers import AutoencoderKL, OnnxRuntimeModel, OnnxStableDiffusionPipeline
from diffusers import (
AutoencoderKL,
OnnxRuntimeModel,
OnnxStableDiffusionPipeline,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
)
from onnx import load_model, save_model

from ...constants import ONNX_MODEL, ONNX_WEIGHTS
from ...diffusers.load import available_pipelines, optimize_pipeline
from ...diffusers.load import optimize_pipeline
from ...diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ...diffusers.version_safe_diffusers import AttnProcessor
from ...models.cnet import UNet2DConditionModel_CNet
Expand All @@ -33,6 +41,17 @@

logger = getLogger(__name__)

CONVERT_PIPELINES = {
"controlnet": OnnxStableDiffusionControlNetPipeline,
"img2img": StableDiffusionPipeline,
"inpaint": StableDiffusionPipeline,
"lpw": StableDiffusionPipeline,
"panorama": StableDiffusionPipeline,
"pix2pix": StableDiffusionInstructPix2PixPipeline,
"txt2img": StableDiffusionPipeline,
"upscale": StableDiffusionUpscalePipeline,
}


def get_model_version(
source,
Expand Down Expand Up @@ -295,7 +314,7 @@ def convert_diffusion_diffusers(
logger.info("ONNX model already exists, skipping")
return (False, dest_path)

pipe_class = available_pipelines.get(pipe_type)
pipe_class = CONVERT_PIPELINES.get(pipe_type)
v2, pipe_args = get_model_version(
source, conversion.map_location, size=image_size, version=version
)
Expand Down

0 comments on commit 5d3a7d7

Please sign in to comment.