Skip to content

Commit

Permalink
fix(api): trim whitespace from model names because it breaks things (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Dec 10, 2023
1 parent 9c1fcd1 commit 4da4cd9
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/convert/diffusion/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def convert_diffusion_diffusers(
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
name = model.get("name")
name = str(model.get("name")).strip()
source = model.get("source")

# optional
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/convert/diffusion/diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def convert_diffusion_diffusers_xl(
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
name = model.get("name")
name = str(model.get("name")).strip()
source = model.get("source")
replace_vae = model.get("vae", None)

Expand Down
4 changes: 4 additions & 0 deletions api/onnx_web/server/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from jsonschema import ValidationError, validate

from ..convert.utils import fix_diffusion_name
from ..image import ( # mask filters; noise sources
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
Expand Down Expand Up @@ -189,6 +190,9 @@ def load_extras(server: ServerContext):
for model in data[model_type]:
model_name = model["name"]

if model_type == "diffusion":
model_name = fix_diffusion_name(model_name)

if "hash" in model:
logger.debug(
"collecting hash for model %s from %s",
Expand Down

0 comments on commit 4da4cd9

Please sign in to comment.