Skip to content

Commit

Permalink
fix(api): run shape inference before converting models to fp16
Browse files Browse the repository at this point in the history
per discussion in microsoft/onnxruntime#14827
  • Loading branch information
ssube committed Mar 1, 2023
1 parent 86984be commit dbf9eaf
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 16 deletions.
6 changes: 3 additions & 3 deletions api/logging.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ formatters:
handlers:
console:
class: logging.StreamHandler
level: DEBUG
level: INFO
formatter: simple
stream: ext://sys.stdout
loggers:
'':
level: DEBUG
level: INFO
handlers: [console]
propagate: True
root:
level: DEBUG
level: INFO
handlers: [console]
2 changes: 0 additions & 2 deletions api/onnx_web/chain/upscale_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def load_stable_diffusion(
model_path,
provider=device.ort_provider(),
sess_options=device.sess_options(),
torch_dtype=torch.float16,
)
else:
logger.debug(
Expand All @@ -51,7 +50,6 @@ def load_stable_diffusion(
pipe = StableDiffusionUpscalePipeline.from_pretrained(
model_path,
provider=device.provider,
torch_dtype=torch.float16,
)

if not server.show_progress:
Expand Down
20 changes: 16 additions & 4 deletions api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from logging import getLogger
from os import makedirs, path
from sys import exit
from typing import Any, Dict, List, Optional, Tuple
from traceback import format_exception
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse

from jsonschema import ValidationError, validate
Expand Down Expand Up @@ -242,7 +242,11 @@ def convert_models(ctx: ConversionContext, args, models: Models):
)

except Exception as e:
logger.error("error converting diffusion model %s: %s", name, format_exception(type(e), e, e.__traceback__))
logger.error(
"error converting diffusion model %s: %s",
name,
format_exception(type(e), e, e.__traceback__),
)

if args.upscaling and "upscaling" in models:
for model in models.get("upscaling"):
Expand All @@ -260,7 +264,11 @@ def convert_models(ctx: ConversionContext, args, models: Models):
)
convert_upscale_resrgan(ctx, model, source)
except Exception as e:
logger.error("error converting upscaling model %s: %s", name, format_exception(type(e), e, e.__traceback__))
logger.error(
"error converting upscaling model %s: %s",
name,
format_exception(type(e), e, e.__traceback__),
)

if args.correction and "correction" in models:
for model in models.get("correction"):
Expand All @@ -277,7 +285,11 @@ def convert_models(ctx: ConversionContext, args, models: Models):
)
convert_correction_gfpgan(ctx, model, source)
except Exception as e:
logger.error("error converting correction model %s: %s", name, format_exception(type(e), e, e.__traceback__))
logger.error(
"error converting correction model %s: %s",
name,
format_exception(type(e), e, e.__traceback__),
)


def main() -> int:
Expand Down
18 changes: 12 additions & 6 deletions api/onnx_web/convert/diffusion/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
StableDiffusionPipeline,
)
from onnx import load_model, save_model
from onnx.shape_inference import infer_shapes_path
from onnxruntime.transformers.float16 import convert_float_to_float16
from torch.onnx import export

Expand Down Expand Up @@ -64,19 +65,24 @@ def onnx_export(
)

if half:
logger.info("converting model to FP16 internally")
logger.info("converting model to FP16 internally: %s", output_file)
infer_shapes_path(output_file)
base_model = load_model(output_file)
opt_model = convert_float_to_float16(base_model, keep_io_types=True, force_fp16_initializers=True)
opt_model = convert_float_to_float16(
base_model,
disable_shape_infer=True,
keep_io_types=True,
force_fp16_initializers=True,
)
save_model(
opt_model,
f"{output_file}-optimized",
f"{output_file}",
save_as_external_data=external_data,
all_tensors_to_one_file=True,
location=f"{output_file}-tensors",
location=f"{output_file}-weights",
)



@torch.no_grad()
def convert_diffusion_diffusers(
ctx: ConversionContext,
Expand All @@ -91,7 +97,7 @@ def convert_diffusion_diffusers(
single_vae = model.get("single_vae")
replace_vae = model.get("vae")

dtype = torch.float32 # torch.float16 if ctx.half else torch.float32
dtype = torch.float32 # torch.float16 if ctx.half else torch.float32
dest_path = path.join(ctx.model_path, name)

# diffusers go into a directory rather than .onnx file
Expand Down
1 change: 0 additions & 1 deletion api/onnx_web/diffusion/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def load_pipeline(
provider=device.ort_provider(),
sess_options=device.sess_options(),
subfolder="scheduler",
torch_dtype=torch.float16,
)

if device is not None and hasattr(scheduler, "to"):
Expand Down

0 comments on commit dbf9eaf

Please sign in to comment.