fix(api): pass correct text model type when converting v2 checkpoints
ssube committed Apr 30, 2023
1 parent 4eba9a6 commit 2690eaf
Showing 2 changed files with 56 additions and 7 deletions.
59 changes: 53 additions & 6 deletions api/onnx_web/convert/diffusion/diffusers.py
@@ -13,7 +13,7 @@
 from os import mkdir, path
 from pathlib import Path
 from shutil import rmtree
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 from diffusers import (
@@ -25,14 +25,17 @@
     StableDiffusionPipeline,
     StableDiffusionUpscalePipeline,
 )
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+    download_from_original_stable_diffusion_ckpt,
+)
 from onnx import load_model, save_model
 
 from ...constants import ONNX_MODEL, ONNX_WEIGHTS
 from ...diffusers.load import optimize_pipeline
 from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
 from ...diffusers.version_safe_diffusers import AttnProcessor
 from ...models.cnet import UNet2DConditionModel_CNet
-from ..utils import ConversionContext, is_torch_2_0, onnx_export
+from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
 
 logger = getLogger(__name__)
 
@@ -48,6 +51,47 @@
 }
 
 
+def get_model_version(
+    checkpoint,
+    size=None,
+) -> Tuple[bool, Dict[str, Union[bool, int, str]]]:
+    if "global_step" in checkpoint:
+        global_step = checkpoint["global_step"]
+    else:
+        print("global_step key not found in model")
+        global_step = None
+
+    if size is None:
+        # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
+        # as it relies on a brittle global step parameter here
+        size = 512 if global_step == 875000 else 768
+
+    v2 = False
+    opts = {
+        "extract_ema": True,
+        "image_size": size,
+    }
+
+    key_name = (
+        "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+    )
+    if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+        v2 = True
+        if size != 512:
+            # v2.1 needs to upcast attention
+            logger.debug("setting upcast_attention")
+            opts["upcast_attention"] = True
+
+    if v2 and size != 512:
+        opts["model_type"] = "FrozenOpenCLIPEmbedder"
+        opts["prediction_type"] = "v_prediction"
+    else:
+        opts["model_type"] = "FrozenCLIPEmbedder"
+        opts["prediction_type"] = "epsilon"
+
+    return (v2, opts)
+
+
 def convert_diffusion_diffusers_cnet(
     conversion: ConversionContext,
     source: str,
@@ -199,16 +243,18 @@ def convert_diffusion_diffusers(
     """
     name = model.get("name")
     source = source or model.get("source")
+    config = model.get("config", None)
     single_vae = model.get("single_vae")
     replace_vae = model.get("vae")
    pipe_type = model.get("pipeline", "txt2img")
-    pipe_config = model.get("config", None)
 
     device = conversion.training_device
     dtype = conversion.torch_dtype()
     logger.debug("using Torch dtype %s for pipeline", dtype)
 
-    config_path = None if pipe_config is None else path.join(conversion.model_path, "config", pipe_config)
+    config_path = (
+        None if config is None else path.join(conversion.model_path, "config", config)
+    )
     dest_path = path.join(conversion.model_path, name)
     model_index = path.join(dest_path, "model_index.json")
     model_cnet = path.join(dest_path, "cnet", ONNX_MODEL)
@@ -233,7 +279,7 @@
         return (False, dest_path)
 
     pipe_class = available_pipelines.get(pipe_type)
-    pipe_args = {}
+    v2, pipe_args = get_model_version(load_tensor(source, conversion.map_location))
 
     if pipe_type == "inpaint":
         pipe_args["num_in_channels"] = 9
@@ -247,9 +293,10 @@
         ).to(device)
     elif path.exists(source) and path.isfile(source):
         logger.debug("loading pipeline from SD checkpoint: %s", source)
-        pipeline = pipe_class.from_ckpt(
+        pipeline = download_from_original_stable_diffusion_ckpt(
             source,
             original_config_file=config_path,
+            pipeline_class=pipe_class,
             torch_dtype=dtype,
             **pipe_args,
         ).to(device)
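For reference, a minimal sketch of how the changed pieces fit together (not part of the commit; the checkpoint path is hypothetical, and `conversion` stands for the ConversionContext already available inside convert_diffusion_diffusers): the raw checkpoint is loaded once, get_model_version inspects it to pick the text encoder type, prediction type, and image size, and those options are forwarded to diffusers' download_from_original_stable_diffusion_ckpt together with the pipeline class.

checkpoint = load_tensor("/models/v2-1_768-ema-pruned.ckpt", conversion.map_location)

v2, opts = get_model_version(checkpoint)
# for a 768px v2.x checkpoint: model_type == "FrozenOpenCLIPEmbedder",
# prediction_type == "v_prediction", and upcast_attention is enabled
# for v1.x or the 512px v2 base: model_type == "FrozenCLIPEmbedder",
# prediction_type == "epsilon"

pipeline = download_from_original_stable_diffusion_ckpt(
    "/models/v2-1_768-ema-pruned.ckpt",
    original_config_file=None,
    pipeline_class=StableDiffusionPipeline,
    torch_dtype=torch.float32,
    **opts,
)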
4 changes: 3 additions & 1 deletion api/onnx_web/diffusers/run.py
@@ -49,7 +49,9 @@ def run_loopback(
     # load img2img pipeline once
     pipe_type = params.get_valid_pipeline("img2img")
     if pipe_type == "controlnet":
-        logger.debug("controlnet pipeline cannot be used for loopback, switching to img2img")
+        logger.debug(
+            "controlnet pipeline cannot be used for loopback, switching to img2img"
+        )
         pipe_type = "img2img"
 
     logger.debug("using %s pipeline for loopback", pipe_type)
