Skip to content

Commit

Permalink
fix(api): only fetch diffusion models if they have not already been c…
Browse files Browse the repository at this point in the history
…onverted (#398)
  • Loading branch information
ssube committed Dec 10, 2023
1 parent c9b1df9 commit 9c1fcd1
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 45 deletions.
2 changes: 0 additions & 2 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,12 @@ def convert_model_diffusion(conversion: ConversionContext, model):
model["name"] = name

model_format = source_format(model)
dest = fetch_model(conversion, name, model["source"], format=model_format)

pipeline = model.get("pipeline", "txt2img")
converter = model_converters.get(pipeline)
converted, dest = converter(
conversion,
model,
dest,
model_format,
)

Expand Down
90 changes: 51 additions & 39 deletions api/onnx_web/convert/diffusion/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,16 @@
from ...diffusers.version_safe_diffusers import AttnProcessor
from ...models.cnet import UNet2DConditionModel_CNet
from ...utils import run_gc
from ..client import fetch_model
from ..client.huggingface import HuggingfaceClient
from ..utils import (
RESOLVE_FORMATS,
ConversionContext,
check_ext,
is_torch_2_0,
load_tensor,
onnx_export,
remove_prefix,
)
from .checkpoint import convert_extract_checkpoint

Expand Down Expand Up @@ -267,14 +270,13 @@ def collate_cnet(cnet_path):
def convert_diffusion_diffusers(
conversion: ConversionContext,
model: Dict,
source: str,
format: Optional[str],
hf: bool = False,
) -> Tuple[bool, str]:
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
name = model.get("name")
source = model.get("source")

# optional
config = model.get("config", None)
Expand Down Expand Up @@ -320,9 +322,11 @@ def convert_diffusion_diffusers(
logger.info("ONNX model already exists, skipping")
return (False, dest_path)

cache_path = fetch_model(conversion, name, source, format=format)

pipe_class = CONVERT_PIPELINES.get(pipe_type)
v2, pipe_args = get_model_version(
source, conversion.map_location, size=image_size, version=version
cache_path, conversion.map_location, size=image_size, version=version
)

is_inpainting = False
Expand All @@ -334,50 +338,58 @@ def convert_diffusion_diffusers(
pipe_args["from_safetensors"] = True

torch_source = None
if path.exists(source) and path.isdir(source):
logger.debug("loading pipeline from diffusers directory: %s", source)
pipeline = pipe_class.from_pretrained(
source,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(device)
elif path.exists(source) and path.isfile(source):
if conversion.extract:
logger.debug("extracting SD checkpoint to Torch models: %s", source)
torch_source = convert_extract_checkpoint(
conversion,
source,
f"{name}-torch",
is_inpainting=is_inpainting,
config_file=config,
vae_file=replace_vae,
)
logger.debug("loading pipeline from extracted checkpoint: %s", torch_source)
if path.exists(cache_path):
if path.isdir(cache_path):
logger.debug("loading pipeline from diffusers directory: %s", source)
pipeline = pipe_class.from_pretrained(
torch_source,
cache_path,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(device)

# VAE replacement already happened during extraction, skip
replace_vae = None
else:
logger.debug("loading pipeline from SD checkpoint: %s", source)
pipeline = download_from_original_stable_diffusion_ckpt(
source,
original_config_file=config_path,
pipeline_class=pipe_class,
**pipe_args,
).to(device, torch_dtype=dtype)
elif hf:
logger.debug("downloading pretrained model from Huggingface hub: %s", source)
elif path.isfile(source):
if conversion.extract:
logger.debug("extracting SD checkpoint to Torch models: %s", source)
torch_source = convert_extract_checkpoint(
conversion,
source,
f"{name}-torch",
is_inpainting=is_inpainting,
config_file=config,
vae_file=replace_vae,
)
logger.debug(
"loading pipeline from extracted checkpoint: %s", torch_source
)
pipeline = pipe_class.from_pretrained(
torch_source,
torch_dtype=dtype,
).to(device)

# VAE replacement already happened during extraction, skip
replace_vae = None
else:
logger.debug("loading pipeline from SD checkpoint: %s", source)
pipeline = download_from_original_stable_diffusion_ckpt(
source,
original_config_file=config_path,
pipeline_class=pipe_class,
**pipe_args,
).to(device, torch_dtype=dtype)
elif source.startswith(HuggingfaceClient.protocol):
hf_path = remove_prefix(source, HuggingfaceClient.protocol)
logger.debug("downloading pretrained model from Huggingface hub: %s", hf_path)
pipeline = pipe_class.from_pretrained(
source,
hf_path,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(device)
else:
logger.warning("pipeline source not found or not recognized: %s", source)
raise ValueError(f"pipeline source not found or not recognized: {source}")
logger.warning(
"pipeline source not found and protocol not recognized: %s", source
)
raise ValueError(
f"pipeline source not found and protocol not recognized: {source}"
)

if replace_vae is not None:
vae_path = path.join(conversion.model_path, replace_vae)
Expand Down
9 changes: 5 additions & 4 deletions api/onnx_web/convert/diffusion/diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from optimum.exporters.onnx import main_export

from ...constants import ONNX_MODEL
from ..client import fetch_model
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext

logger = getLogger(__name__)
Expand All @@ -19,14 +20,13 @@
def convert_diffusion_diffusers_xl(
conversion: ConversionContext,
model: Dict,
source: str,
format: Optional[str],
hf: bool = False,
) -> Tuple[bool, str]:
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
name = model.get("name")
source = model.get("source")
replace_vae = model.get("vae", None)

device = conversion.training_device
Expand All @@ -52,15 +52,16 @@ def convert_diffusion_diffusers_xl(

return (False, dest_path)

cache_path = fetch_model(conversion, name, model["source"], format=format)
# safetensors -> diffusers directory with torch models
temp_path = path.join(conversion.cache_path, f"{name}-torch")

if format == "safetensors":
pipeline = StableDiffusionXLPipeline.from_single_file(
source, use_safetensors=True
cache_path, use_safetensors=True
)
else:
pipeline = StableDiffusionXLPipeline.from_pretrained(source)
pipeline = StableDiffusionXLPipeline.from_pretrained(cache_path)

if replace_vae is not None:
vae_path = path.join(conversion.model_path, replace_vae)
Expand Down

0 comments on commit 9c1fcd1

Please sign in to comment.