fix(api): only use optimum's fp16 mode for SDXL export when torch fp16 is enabled
ssube committed Nov 17, 2023
1 parent b31227e commit eb3f147
Showing 4 changed files with 13 additions and 10 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/convert/__main__.py
@@ -599,7 +599,7 @@ def main(args=None) -> int:
logger.info("CLI arguments: %s", args)

server = ConversionContext.from_environ()
- server.half = args.half or "onnx-fp16" in server.optimizations
+ server.half = args.half or server.has_optimization("onnx-fp16")
server.opset = args.opset
server.token = args.token
logger.info(
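
For reference, a hedged sketch of the CLI flags referenced above (--half, --opset, --token); the names, types, and defaults here are assumptions for illustration, not the actual onnx-web argument parser.

# Hypothetical sketch of the converter CLI flags, not the real onnx-web parser.
from argparse import ArgumentParser

parser = ArgumentParser(prog="onnx_web.convert")
parser.add_argument("--half", action="store_true",
                    help="export fp16 ONNX models (same effect as the onnx-fp16 optimization)")
parser.add_argument("--opset", type=int, default=14,
                    help="ONNX opset version to use during export")
parser.add_argument("--token", type=str, default=None,
                    help="Hugging Face Hub token for gated models")

args = parser.parse_args(["--half"])
print(args.half, args.opset, args.token)  # True 14 None
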
2 changes: 1 addition & 1 deletion api/onnx_web/convert/diffusion/diffusion_xl.py
@@ -81,7 +81,7 @@ def convert_diffusion_diffusers_xl(
output=dest_path,
task="stable-diffusion-xl",
device=device,
- fp16=conversion.half,
+ fp16=conversion.has_optimization("torch-fp16"), # optimum's fp16 mode only works on CUDA or ROCm
framework="pt",
)

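
The hunk above is the core of the fix: the exporter's fp16 flag now follows the torch-fp16 optimization rather than the general half setting. A hedged sketch of the gated call, assuming optimum's main_export entry point and placeholder names for the model, destination path, and device:

# Hedged sketch of the gated SDXL export; main_export is assumed from
# optimum's ONNX exporters, and model/dest_path/device are placeholders
# for the converter's actual values.
from optimum.exporters.onnx import main_export

def export_sdxl(conversion, model: str, dest_path: str, device: str) -> None:
    main_export(
        model,
        output=dest_path,
        task="stable-diffusion-xl",
        device=device,
        # optimum's fp16 export path only works on CUDA or ROCm, so gate it
        # on torch-fp16 instead of the half/onnx-fp16 conversion flag
        fp16=conversion.has_optimization("torch-fp16"),
        framework="pt",
    )
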
14 changes: 7 additions & 7 deletions api/onnx_web/diffusers/load.py
@@ -563,45 +563,45 @@ def optimize_pipeline(
pipe: StableDiffusionPipeline,
) -> None:
if (
"diffusers-attention-slicing" in server.optimizations
or "diffusers-attention-slicing-auto" in server.optimizations
server.has_optimization("diffusers-attention-slicing")
or server.has_optimization("diffusers-attention-slicing-auto")
):
logger.debug("enabling auto attention slicing on SD pipeline")
try:
pipe.enable_attention_slicing(slice_size="auto")
except Exception as e:
logger.warning("error while enabling auto attention slicing: %s", e)

if "diffusers-attention-slicing-max" in server.optimizations:
if server.has_optimization("diffusers-attention-slicing-max"):
logger.debug("enabling max attention slicing on SD pipeline")
try:
pipe.enable_attention_slicing(slice_size="max")
except Exception as e:
logger.warning("error while enabling max attention slicing: %s", e)

if "diffusers-vae-slicing" in server.optimizations:
if server.has_optimization("diffusers-vae-slicing"):
logger.debug("enabling VAE slicing on SD pipeline")
try:
pipe.enable_vae_slicing()
except Exception as e:
logger.warning("error while enabling VAE slicing: %s", e)

if "diffusers-cpu-offload-sequential" in server.optimizations:
if server.has_optimization("diffusers-cpu-offload-sequential"):
logger.debug("enabling sequential CPU offload on SD pipeline")
try:
pipe.enable_sequential_cpu_offload()
except Exception as e:
logger.warning("error while enabling sequential CPU offload: %s", e)

elif "diffusers-cpu-offload-model" in server.optimizations:
elif server.has_optimization("diffusers-cpu-offload-model"):
# TODO: check for accelerate
logger.debug("enabling model CPU offload on SD pipeline")
try:
pipe.enable_model_cpu_offload()
except Exception as e:
logger.warning("error while enabling model CPU offload: %s", e)

if "diffusers-memory-efficient-attention" in server.optimizations:
if server.has_optimization("diffusers-memory-efficient-attention"):
# TODO: check for xformers
logger.debug("enabling memory efficient attention for SD pipeline")
try:
5 changes: 4 additions & 1 deletion api/onnx_web/server/context.py
@@ -129,8 +129,11 @@ def from_environ(cls):
def has_feature(self, flag: str) -> bool:
return flag in self.feature_flags

+ def has_optimization(self, opt: str) -> bool:
+ return opt in self.optimizations

def torch_dtype(self):
if "torch-fp16" in self.optimizations:
if self.has_optimization("torch-fp16"):
return torch.float16
else:
return torch.float32
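
Taken together, the new helper and the existing dtype check behave like this minimal, self-contained sketch; ContextSketch is a stand-in for onnx-web's server context class, not the real implementation.

# Stand-in for the server context helpers shown above, for illustration only.
from typing import List

import torch

class ContextSketch:
    def __init__(self, optimizations: List[str]):
        self.optimizations = optimizations

    def has_optimization(self, opt: str) -> bool:
        return opt in self.optimizations

    def torch_dtype(self):
        # only torch-fp16 switches Torch-side tensors to half precision
        if self.has_optimization("torch-fp16"):
            return torch.float16
        else:
            return torch.float32

ctx = ContextSketch(["onnx-fp16"])
print(ctx.has_optimization("torch-fp16"))  # False
print(ctx.torch_dtype())  # torch.float32, since torch-fp16 is not enabled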
