Skip to content

Commit

Permalink
feat(api): add option to reload CNet for conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed May 20, 2023
1 parent 20107f5 commit 9c28154
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
23 changes: 19 additions & 4 deletions api/onnx_web/convert/diffusion/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,8 @@ def convert_diffusion_diffusers(
)

cnet_path = None
if conversion.control and not single_vae:
# if converting only the CNet, the rest of the model has already been converted
if conversion.control and not single_vae and conversion.share_unet:
logger.debug("converting CNet from loaded UNet")
cnet_path = convert_diffusion_diffusers_cnet(
conversion,
source,
Expand All @@ -465,12 +465,27 @@ def convert_diffusion_diffusers(
unet=pipeline.unet,
v2=v2,
)
else:
logger.debug("skipping CNet for single-VAE model")

del pipeline.unet
run_gc()

if conversion.control and not single_vae and not conversion.share_unet:
logger.info("loading and converting CNet")
cnet_path = convert_diffusion_diffusers_cnet(
conversion,
source,
device,
output_path,
dtype,
unet_in_channels,
unet_sample_size,
num_tokens,
text_hidden_size,
unet=None,
v2=v2,
)


if cnet_path is not None:
collate_cnet(cnet_path)

Expand Down
5 changes: 4 additions & 1 deletion api/onnx_web/convert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
prune: Optional[List[str]] = None,
control: bool = True,
reload: bool = True,
share_unet: bool = True,
**kwargs,
) -> None:
super().__init__(model_path=model_path, cache_path=cache_path, **kwargs)
Expand All @@ -53,6 +54,7 @@ def __init__(
self.opset = opset
self.prune = prune or []
self.reload = reload
self.share_unet = share_unet
self.token = token

if device is not None:
Expand All @@ -66,8 +68,9 @@ def __init__(
def from_environ(cls):
context = super().from_environ()
context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True)
context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))
context.reload = get_boolean(environ, "ONNX_WEB_CONVERT_RELOAD", True)
context.share_unet = get_boolean(environ, "ONNX_WEB_CONVERT_SHARE_UNET", True)
context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))
return context


Expand Down

0 comments on commit 9c28154

Please sign in to comment.