Skip to content

Commit

Permalink
fix(api): include SD upscaling in diffusion prefixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Dec 9, 2023
1 parent 46d9fc0 commit 293a1bb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
15 changes: 1 addition & 14 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
DEFAULT_OPSET,
ConversionContext,
download_progress,
fix_diffusion_name,
remove_prefix,
source_format,
tuple_to_correction,
Expand Down Expand Up @@ -265,20 +266,6 @@ def fetch_model(
return source, False


DIFFUSION_PREFIX = ["diffusion-", "diffusion/", "diffusion\\", "stable-diffusion-"]


def fix_diffusion_name(name: str):
if not any([name.startswith(prefix) for prefix in DIFFUSION_PREFIX]):
logger.warning(
"diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match",
name,
)
return f"diffusion-{name}"

return name


def convert_models(conversion: ConversionContext, args, models: Models):
model_errors = []

Expand Down
14 changes: 14 additions & 0 deletions api/onnx_web/convert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,17 @@ def onnx_export(
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
)


DIFFUSION_PREFIX = ["diffusion-", "diffusion/", "diffusion\\", "stable-diffusion-", "upscaling-"]


def fix_diffusion_name(name: str):
if not any([name.startswith(prefix) for prefix in DIFFUSION_PREFIX]):
logger.warning(
"diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match",
name,
)
return f"diffusion-{name}"

return name

0 comments on commit 293a1bb

Please sign in to comment.