
Commit

feat(api): enable optimizations for SD pipelines based on env vars (#155)
ssube committed Feb 18, 2023
1 parent ff57527 commit ab6462d
Showing 3 changed files with 35 additions and 0 deletions.
api/onnx_web/chain/upscale_stable_diffusion.py (3 additions, 0 deletions)
@@ -5,6 +5,7 @@
 from diffusers import StableDiffusionUpscalePipeline
 from PIL import Image
 
+from ..diffusion.load import optimize_pipeline
 from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
     OnnxStableDiffusionUpscalePipeline,
 )
@@ -52,6 +53,8 @@ def load_stable_diffusion(
     if not server.show_progress:
         pipe.set_progress_bar_config(disable=True)
 
+    optimize_pipeline(server, pipe)
+
     server.cache.set("diffusion", cache_key, pipe)
     run_gc([device])
 
api/onnx_web/diffusion/load.py (29 additions, 0 deletions)
@@ -17,6 +17,7 @@
     KDPM2DiscreteScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
+    StableDiffusionPipeline,
 )
 
 try:
@@ -87,6 +88,32 @@ def get_tile_latents(
     return full_latents[:, :, y:yt, x:xt]
 
 
+def optimize_pipeline(
+    server: ServerContext,
+    pipe: StableDiffusionPipeline,
+) -> None:
+    if "attention-slicing" in server.optimizations:
+        logger.debug("enabling attention slicing on SD pipeline")
+        pipe.enable_attention_slicing()
+
+    if "vae-slicing" in server.optimizations:
+        logger.debug("enabling VAE slicing on SD pipeline")
+        pipe.enable_vae_slicing()
+
+    if "sequential-cpu-offload" in server.optimizations:
+        logger.debug("enabling sequential CPU offload on SD pipeline")
+        pipe.enable_sequential_cpu_offload()
+    elif "model-cpu-offload" in server.optimizations:
+        # TODO: check for accelerate
+        logger.debug("enabling model CPU offload on SD pipeline")
+        pipe.enable_model_cpu_offload()
+
+    if "memory-efficient-attention" in server.optimizations:
+        # TODO: check for xformers
+        logger.debug("enabling memory efficient attention for SD pipeline")
+        pipe.enable_xformers_memory_efficient_attention()
+
+
 def load_pipeline(
     server: ServerContext,
     pipeline: DiffusionPipeline,
@@ -151,6 +178,8 @@ def load_pipeline(
     if not server.show_progress:
         pipe.set_progress_bar_config(disable=True)
 
+    optimize_pipeline(server, pipe)
+
     if device is not None and hasattr(pipe, "to"):
         pipe = pipe.to(device.torch_str())
 
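The optimization names checked above map directly onto diffusers' opt-in memory features, and sequential CPU offload takes precedence over model CPU offload because of the if/elif chain. A minimal usage sketch follows; the model id is hypothetical, and it assumes the ServerContext constructor arguments not named in this diff have workable defaults:

from diffusers import StableDiffusionPipeline

from onnx_web.diffusion.load import optimize_pipeline
from onnx_web.utils import ServerContext

# Only the ``optimizations`` keyword is confirmed by this diff; the other
# constructor defaults are assumed to be usable as-is.
server = ServerContext(optimizations=["attention-slicing", "vae-slicing"])

# Hypothetical model id, for illustration only.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Applies each requested optimization; unrecognized names are silently ignored.
optimize_pipeline(server, pipe)

In normal operation this call is not needed: load_pipeline and load_stable_diffusion both invoke optimize_pipeline before caching the pipeline, as the hunks above show.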
api/onnx_web/utils.py (3 additions, 0 deletions)
@@ -28,6 +28,7 @@ def __init__(
         cache: ModelCache = None,
         cache_path: str = None,
         show_progress: bool = True,
+        optimizations: List[str] = [],
     ) -> None:
         self.bundle_path = bundle_path
         self.model_path = model_path
@@ -42,6 +43,7 @@ def __init__(
         self.cache = cache or ModelCache(num_workers)
         self.cache_path = cache_path or path.join(model_path, ".cache")
         self.show_progress = show_progress
+        self.optimizations = optimizations
 
     @classmethod
     def from_environ(cls):
@@ -64,6 +66,7 @@ def from_environ(cls):
             image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
             cache=ModelCache(limit=cache_limit),
             show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
+            optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
         )


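On the configuration side, ONNX_WEB_OPTIMIZATIONS is parsed with a plain split(","), so values are comma-separated with no whitespace trimming; when the variable is unset, the empty-string default splits to [""], which matches none of the optimization names and is effectively a no-op. A short sketch of the flow, assuming the onnx_web package is importable and the other environment variables can fall back to their defaults:

import os

from onnx_web.utils import ServerContext

# Comma-separated names, matching the strings checked in optimize_pipeline.
os.environ["ONNX_WEB_OPTIMIZATIONS"] = "attention-slicing,memory-efficient-attention"

server = ServerContext.from_environ()
print(server.optimizations)
# ['attention-slicing', 'memory-efficient-attention']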

0 comments on commit ab6462d
