feat(api): add flag for ORT float16 optimizations
ssube committed Mar 19, 2023
1 parent e4b59f0 commit 1c631c2
Showing 2 changed files with 21 additions and 11 deletions.
18 changes: 12 additions & 6 deletions api/onnx_web/diffusers/load.py
@@ -353,16 +353,21 @@ def optimize_pipeline(


 class UNetWrapper(object):
-    def __init__(self, wrapped):
+    def __init__(self, server, wrapped):
+        self.server = server
         self.wrapped = wrapped

     def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
         global timestep_dtype
         timestep_dtype = timestep.dtype

         logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
-        if sample.dtype != timestep.dtype:
-            logger.info("converting UNet sample dtype")
+        if "onnx-fp16" in self.server.optimizations:
+            logger.info("converting UNet sample to ONNX fp16")
+            sample = sample.astype(np.float16)
+            encoder_hidden_states = encoder_hidden_states.astype(np.float16)
+        elif sample.dtype != timestep.dtype:
+            logger.info("converting UNet sample to timestep dtype")
             sample = sample.astype(timestep.dtype)

         return self.wrapped(
@@ -376,7 +381,8 @@ def __getattr__(self, attr):


 class VAEWrapper(object):
-    def __init__(self, wrapped):
+    def __init__(self, server, wrapped):
+        self.server = server
         self.wrapped = wrapped

     def __call__(self, latent_sample=None):
@@ -404,5 +410,5 @@ def patch_pipeline(
     original_unet = pipe.unet
     original_vae = pipe.vae_decoder

-    pipe.unet = UNetWrapper(original_unet)
-    pipe.vae_decoder = VAEWrapper(original_vae)
+    pipe.unet = UNetWrapper(server, original_unet)
+    pipe.vae_decoder = VAEWrapper(server, original_vae)
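For illustration, a rough usage sketch of the wrapper behavior added above. StubServer and StubUnet are hypothetical stand-ins for onnx-web's ServerContext and the converted ONNX UNet session, and the sketch assumes the truncated `return self.wrapped(` call forwards the same keyword arguments:

    import numpy as np

    from onnx_web.diffusers.load import UNetWrapper

    class StubServer:
        # stand-in for ServerContext; only the optimizations list is used here
        optimizations = ["onnx-fp16"]

    class StubUnet:
        # stand-in for an fp16-optimized ORT UNet, which would reject float32
        # inputs; it returns the sample so the caller can check its dtype
        def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
            return sample

    unet = UNetWrapper(StubServer(), StubUnet())
    result = unet(
        sample=np.zeros((1, 4, 64, 64), dtype=np.float32),  # schedulers emit fp32
        timestep=np.array([1], dtype=np.int64),
        encoder_hidden_states=np.zeros((1, 77, 768), dtype=np.float32),
    )
    print(result.dtype)  # float16: the wrapper cast both tensors before the call

With the flag set, the wrapper casts unconditionally rather than only on a sample/timestep dtype mismatch, which is what an fp16-converted ORT graph requires.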
14 changes: 9 additions & 5 deletions docs/server-admin.md
@@ -99,20 +99,24 @@ Others:
   - not available for ONNX pipelines (most of them)
   - https://huggingface.co/docs/diffusers/optimization/fp16#sliced-vae-decode-for-larger-batches
 - `onnx-*`
-  - `onnx-low-memory`
-    - disable ONNX features that allocate more memory than is strictly required or keep memory after use
-  - `onnx-deterministic-compute`
-    - enable ONNX deterministic compute
+  - `onnx-fp16`
+    - force 16-bit floating point values when running pipelines
+    - use with https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/stable_diffusion#optimize-onnx-pipeline and the `--float16` flag
   - `onnx-graph-*`
     - `onnx-graph-disable`
       - disable all ONNX graph optimizations
     - `onnx-graph-basic`
       - enable basic ONNX graph optimizations
     - `onnx-graph-all`
       - enable all ONNX graph optimizations
+  - `onnx-deterministic-compute`
+    - enable ONNX deterministic compute
+  - `onnx-low-memory`
+    - disable ONNX features that allocate more memory than is strictly required or keep memory after use
 - `torch-*`
   - `torch-fp16`
-    - use 16-bit floating point values when loading and converting pipelines
+    - use 16-bit floating point values when converting and running pipelines
+    - applies during conversion as well
     - only available on CUDA platform
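As a worked example of the `onnx-fp16` flow: the linked onnxruntime README documents an optimize_pipeline.py script that converts a Stable Diffusion ONNX pipeline to float16. The paths below are placeholders, and the ONNX_WEB_OPTIMIZATIONS variable is assumed to be how the server reads its optimizations list; check the server environment docs for the authoritative name and flags.

    # convert the pipeline with ORT's optimizer (script from the linked README)
    python optimize_pipeline.py -i ./models/stable-diffusion -o ./models/stable-diffusion-fp16 --float16

    # then enable the matching runtime flag (variable name assumed; see the
    # server environment docs)
    export ONNX_WEB_OPTIMIZATIONS="onnx-fp16,onnx-low-memory"

The conversion step and the runtime flag go together: the flag only casts inputs, so it helps only when the model graph itself has been converted to float16.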
