Skip to content

Commit

Permalink
fix(api): check diffusers version before imports (#336)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 15, 2023
1 parent ad5c69e commit 95841ff
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 32 deletions.
5 changes: 2 additions & 3 deletions api/onnx_web/convert/diffusion/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from typing import Dict

import torch
from diffusers.models.controlnet import ControlNetModel
from diffusers.models.cross_attention import CrossAttnProcessor
from ...diffusers.version_safe_diffusers import AttnProcessor, ControlNetModel

from ...constants import ONNX_MODEL
from ..utils import ConversionContext, is_torch_2_0, onnx_export
Expand Down Expand Up @@ -43,7 +42,7 @@ def convert_diffusion_control(

# UNET
if is_torch_2_0:
controlnet.set_attn_processor(CrossAttnProcessor())
controlnet.set_attn_processor(AttnProcessor())

cnet_path = output_path / "cnet" / ONNX_MODEL
onnx_export(
Expand Down
6 changes: 3 additions & 3 deletions api/onnx_web/convert/diffusion/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
OnnxStableDiffusionPipeline,
StableDiffusionPipeline,
)
from diffusers.models.cross_attention import CrossAttnProcessor
from onnx import load_model, save_model

from ...constants import ONNX_MODEL, ONNX_WEIGHTS
from ...diffusers.load import optimize_pipeline
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ...diffusers.version_safe_diffusers import AttnProcessor
from ...models.cnet import UNet2DConditionModel_CNet
from ..utils import ConversionContext, is_torch_2_0, onnx_export

Expand All @@ -51,7 +51,7 @@ def convert_diffusion_diffusers_cnet(
)

if is_torch_2_0:
pipe_cnet.set_attn_processor(CrossAttnProcessor())
pipe_cnet.set_attn_processor(AttnProcessor())

cnet_path = output_path / "cnet" / ONNX_MODEL
onnx_export(
Expand Down Expand Up @@ -262,7 +262,7 @@ def convert_diffusion_diffusers(
unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)

if is_torch_2_0:
pipeline.unet.set_attn_processor(CrossAttnProcessor())
pipeline.unet.set_attn_processor(AttnProcessor())

unet_in_channels = pipeline.unet.config.in_channels
unet_sample_size = pipeline.unet.config.sample_size
Expand Down
42 changes: 17 additions & 25 deletions api/onnx_web/diffusers/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,24 @@

import numpy as np
import torch
from diffusers import (
from onnx import load_model
from transformers import CLIPTokenizer

from ..constants import ONNX_MODEL
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.utils import expand_prompt
from ..models.meta import NetworkModel
from ..params import DeviceParams, Size
from ..server import ServerContext
from ..utils import run_gc
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
from .pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
from .version_safe_diffusers import (
DDIMScheduler,
DDPMScheduler,
DEISMultistepScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
Expand All @@ -21,31 +36,8 @@
OnnxStableDiffusionPipeline,
PNDMScheduler,
StableDiffusionPipeline,
UniPCMultistepScheduler,
)
from onnx import load_model
from transformers import CLIPTokenizer

try:
from diffusers import DEISMultistepScheduler
except ImportError:
from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler

try:
from diffusers import UniPCMultistepScheduler
except ImportError:
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler

from ..constants import ONNX_MODEL
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from ..diffusers.pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
from ..diffusers.utils import expand_prompt
from ..models.meta import NetworkModel
from ..params import DeviceParams, Size
from ..server import ServerContext
from ..utils import run_gc
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline

logger = getLogger(__name__)

Expand Down
30 changes: 30 additions & 0 deletions api/onnx_web/diffusers/version_safe_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import diffusers
from diffusers import * # NOQA
from packaging import version

is_diffusers_0_15 = version.parse(
version.parse(diffusers.__version__).base_version
) >= version.parse("0.15")


try:
from diffusers import DEISMultistepScheduler # NOQA
except ImportError:
from ..diffusers.stub_scheduler import (
StubScheduler as DEISMultistepScheduler, # NOQA
)

try:
from diffusers import UniPCMultistepScheduler # NOQA
except ImportError:
from ..diffusers.stub_scheduler import (
StubScheduler as UniPCMultistepScheduler, # NOQA
)


if is_diffusers_0_15:
from diffusers.models.attention_processor import AttnProcessor # NOQA
else:
from diffusers.models.cross_attention import (
CrossAttnProcessor as AttnProcessor, # NOQA
)
3 changes: 2 additions & 1 deletion api/onnx_web/image/laion_face.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# from https://github.com/ForserX/StableDiffusionUI/blob/main/data/repo/diffusion_scripts/modules/controlnet/laion_face_common.py
# from https://huggingface.co/CrucibleAI/ControlNetMediaPipeFace/blob/main/laion_face_common.py
# and https://github.com/ForserX/StableDiffusionUI/blob/main/data/repo/diffusion_scripts/modules/controlnet/laion_face_common.py

from typing import Mapping

Expand Down

0 comments on commit 95841ff

Please sign in to comment.