Skip to content

Commit

Permalink
fix(api): various controlnet fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Dec 20, 2023
1 parent a716f6d commit ba9982a
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 4 deletions.
8 changes: 8 additions & 0 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,14 @@ def main(args=None) -> int:
server.opset = args.opset
server.token = args.token

# debug options
if server.debug:
import debugpy

debugpy.listen(5678)
logger.warning("waiting for debugger")
debugpy.wait_for_client()

register_plugins(server)

logger.info(
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/convert/diffusion/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def convert_diffusion_diffusers(
run_gc()

if conversion.control and not single_vae and not conversion.share_unet:
cnet_source = torch_source or source
cnet_source = torch_source or cache_path
logger.info("loading and converting CNet from %s", cnet_source)
cnet_path = convert_diffusion_diffusers_cnet(
conversion,
Expand Down
3 changes: 2 additions & 1 deletion api/onnx_web/diffusers/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ def load_pipeline(
pipeline_class.__name__,
)
pipe = pipeline_class(
components["vae"],
components["vae_encoder"],
components["vae_decoder"],
components["text_encoder"],
components["tokenizer"],
components["unet"],
Expand Down
4 changes: 2 additions & 2 deletions api/onnx_web/image/source_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def source_filter_hed(server: ServerContext, source: Image.Image) -> Image.Image
logger.debug("running HED detection on source image")

hed = HEDdetector.from_pretrained(
"lllyasviel/ControlNet",
"lllyasviel/Annotators",
cache_dir=server.cache_path,
)
image = hed(source)
Expand All @@ -172,7 +172,7 @@ def source_filter_scribble(server: ServerContext, source: Image.Image) -> Image.
logger.debug("running scribble detection on source image")

hed = HEDdetector.from_pretrained(
"lllyasviel/ControlNet",
"lllyasviel/Annotators",
cache_dir=server.cache_path,
)
image = hed(source, scribble=True)
Expand Down

0 comments on commit ba9982a

Please sign in to comment.