CUDA optimizations #177

Merged 8 commits on Feb 18, 2023
3 changes: 3 additions & 0 deletions api/onnx_web/chain/upscale_stable_diffusion.py
@@ -5,6 +5,7 @@
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image

from ..diffusion.load import optimize_pipeline
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
    OnnxStableDiffusionUpscalePipeline,
)
@@ -52,6 +53,8 @@ def load_stable_diffusion(
    if not server.show_progress:
        pipe.set_progress_bar_config(disable=True)

    optimize_pipeline(server, pipe)

    server.cache.set("diffusion", cache_key, pipe)
    run_gc([device])

45 changes: 45 additions & 0 deletions api/onnx_web/diffusion/load.py
@@ -17,6 +17,7 @@
    KDPM2DiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
)

try:
@@ -87,6 +88,48 @@ def get_tile_latents(
    return full_latents[:, :, y:yt, x:xt]


def optimize_pipeline(
    server: ServerContext,
    pipe: StableDiffusionPipeline,
) -> None:
    if "diffusers-attention-slicing" in server.optimizations:
        logger.debug("enabling attention slicing on SD pipeline")
        try:
            pipe.enable_attention_slicing()
        except Exception as e:
            logger.warning("error while enabling attention slicing: %s", e)

    if "diffusers-vae-slicing" in server.optimizations:
        logger.debug("enabling VAE slicing on SD pipeline")
        try:
            pipe.enable_vae_slicing()
        except Exception as e:
            logger.warning("error while enabling VAE slicing: %s", e)

    if "diffusers-cpu-offload-sequential" in server.optimizations:
        logger.debug("enabling sequential CPU offload on SD pipeline")
        try:
            pipe.enable_sequential_cpu_offload()
        except Exception as e:
            logger.warning("error while enabling sequential CPU offload: %s", e)

    elif "diffusers-cpu-offload-model" in server.optimizations:
        # TODO: check for accelerate
        logger.debug("enabling model CPU offload on SD pipeline")
        try:
            pipe.enable_model_cpu_offload()
        except Exception as e:
            logger.warning("error while enabling model CPU offload: %s", e)

    if "diffusers-memory-efficient-attention" in server.optimizations:
        # TODO: check for xformers
        logger.debug("enabling memory efficient attention for SD pipeline")
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            logger.warning("error while enabling memory efficient attention: %s", e)


def load_pipeline(
    server: ServerContext,
    pipeline: DiffusionPipeline,
@@ -151,6 +194,8 @@ def load_pipeline(
    if not server.show_progress:
        pipe.set_progress_bar_config(disable=True)

    optimize_pipeline(server, pipe)

    if device is not None and hasattr(pipe, "to"):
        pipe = pipe.to(device.torch_str())

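The helper is called from `load_pipeline` right after the pipeline is constructed, and each optimization is applied best-effort. A minimal sketch of the same flow outside the server, assuming `ServerContext` lives in `onnx_web.utils` (as the utils.py change below suggests) and using a placeholder model path:

```python
# Illustrative sketch only: how optimize_pipeline fits into pipeline loading.
# ServerContext is assumed to live in onnx_web.utils; the model path is a placeholder.
from diffusers import OnnxStableDiffusionPipeline

from onnx_web.diffusion.load import optimize_pipeline
from onnx_web.utils import ServerContext

server = ServerContext.from_environ()  # picks up ONNX_WEB_OPTIMIZATIONS
pipe = OnnxStableDiffusionPipeline.from_pretrained("../models/stable-diffusion-onnx-v1-5")

# each optimization is applied best-effort; failures only log a warning
optimize_pipeline(server, pipe)
```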
38 changes: 34 additions & 4 deletions api/onnx_web/params.py
@@ -1,7 +1,10 @@
from enum import IntEnum
from typing import Any, Dict, Literal, Optional, Tuple, Union
from logging import getLogger
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from onnxruntime import SessionOptions
from onnxruntime import GraphOptimizationLevel, SessionOptions

logger = getLogger(__name__)


class SizeChart(IntEnum):
@@ -75,11 +78,16 @@ def tojson(self) -> Dict[str, int]:

class DeviceParams:
    def __init__(
        self, device: str, provider: str, options: Optional[dict] = None
        self,
        device: str,
        provider: str,
        options: Optional[dict] = None,
        optimizations: Optional[List[str]] = None,
    ) -> None:
        self.device = device
        self.provider = provider
        self.options = options
        self.optimizations = optimizations

    def __str__(self) -> str:
        return "%s - %s (%s)" % (self.device, self.provider, self.options)
@@ -91,7 +99,29 @@ def ort_provider(self) -> Tuple[str, Any]:
        return (self.provider, self.options)

    def sess_options(self) -> SessionOptions:
        return SessionOptions()
        sess = SessionOptions()

        if "onnx-low-memory" in self.optimizations:
            logger.debug("enabling ONNX low-memory optimizations")
            sess.enable_cpu_mem_arena = False
            sess.enable_mem_pattern = False
            sess.enable_mem_reuse = False

        if "onnx-graph-disable" in self.optimizations:
            logger.debug("disabling all ONNX graph optimizations")
            sess.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
        elif "onnx-graph-basic" in self.optimizations:
            logger.debug("enabling basic ONNX graph optimizations")
            sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
        elif "onnx-graph-all" in self.optimizations:
            logger.debug("enabling all ONNX graph optimizations")
            sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        if "onnx-deterministic-compute" in self.optimizations:
            logger.debug("enabling ONNX deterministic compute")
            sess.use_deterministic_compute = True

        return sess

    def torch_str(self) -> str:
        if self.device.startswith("cuda"):
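To see the new session options in action, here is a short sketch of building a CUDA `DeviceParams` with a few of the ONNX flags; the provider string is the standard onnxruntime CUDA provider name, the options dict mirrors the CUDA branch in serve.py below, and the flag values are only examples:

```python
# Hedged example: per-device SessionOptions built from the optimization flags.
# Provider string and device_id follow the CUDA branch in serve.py; the flag
# names are taken from the documented list, chosen here only for illustration.
from onnx_web.params import DeviceParams

device = DeviceParams(
    "cuda",
    "CUDAExecutionProvider",
    {"device_id": 0},
    ["onnx-low-memory", "onnx-graph-all", "onnx-deterministic-compute"],
)

sess = device.sess_options()
# with these flags: the CPU memory arena and memory patterns are disabled,
# all graph optimizations are enabled, and deterministic compute is turned on
```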
17 changes: 15 additions & 2 deletions api/onnx_web/serve.py
@@ -349,16 +349,29 @@ def load_platforms(context: ServerContext) -> None:
                            {
                                "device_id": i,
                            },
                            context.optimizations,
                        )
                    )
            else:
                available_platforms.append(
                    DeviceParams(potential, platform_providers[potential])
                    DeviceParams(
                        potential,
                        platform_providers[potential],
                        None,
                        context.optimizations,
                    )
                )

    if context.any_platform:
        # the platform should be ignored when the job is scheduled, but set to CPU just in case
        available_platforms.append(DeviceParams("any", platform_providers["cpu"]))
        available_platforms.append(
            DeviceParams(
                "any",
                platform_providers["cpu"],
                None,
                context.optimizations,
            )
        )

    # make sure CPU is last on the list
    def any_first_cpu_last(a: DeviceParams, b: DeviceParams):
14 changes: 7 additions & 7 deletions api/onnx_web/server/upscale.py
@@ -34,6 +34,7 @@ def run_upscale_correction(

    chain = ChainPipeline()

    upscale_stage = None
    if upscale.scale > 1:
        if "esrgan" in upscale.upscale_model:
            esrgan_params = StageParams(
@@ -42,23 +43,22 @@
            upscale_stage = (upscale_resrgan, esrgan_params, None)
        elif "stable-diffusion" in upscale.upscale_model:
            mini_tile = min(SizeChart.mini, stage.tile_size)
            sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
            upscale_stage = (upscale_stable_diffusion, sd_stage, None)
            sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
            upscale_stage = (upscale_stable_diffusion, sd_params, None)
        else:
            logger.warn("unknown upscaling model: %s", upscale.upscale_model)
            upscale_stage = None

    correct_stage = None
    if upscale.faces:
        face_stage = StageParams(
        face_params = StageParams(
            tile_size=stage.tile_size, outscale=upscale.face_outscale
        )
        if "codeformer" in upscale.correction_model:
            correct_stage = (correct_codeformer, face_stage, None)
            correct_stage = (correct_codeformer, face_params, None)
        elif "gfpgan" in upscale.correction_model:
            correct_stage = (correct_gfpgan, face_stage, None)
            correct_stage = (correct_gfpgan, face_params, None)
        else:
            logger.warn("unknown correction model: %s", upscale.correction_model)
            correct_stage = None

    if upscale.upscale_order == "correction-both":
        chain.append(correct_stage)
3 changes: 3 additions & 0 deletions api/onnx_web/utils.py
@@ -28,6 +28,7 @@ def __init__(
        cache: ModelCache = None,
        cache_path: str = None,
        show_progress: bool = True,
        optimizations: List[str] = [],
    ) -> None:
        self.bundle_path = bundle_path
        self.model_path = model_path
@@ -42,6 +43,7 @@ def __init__(
        self.cache = cache or ModelCache(num_workers)
        self.cache_path = cache_path or path.join(model_path, ".cache")
        self.show_progress = show_progress
        self.optimizations = optimizations

    @classmethod
    def from_environ(cls):
@@ -64,6 +66,7 @@ def from_environ(cls):
            image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
            cache=ModelCache(limit=cache_limit),
            show_progress=get_boolean(environ, "ONNX_WEB_SHOW_PROGRESS", True),
            optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","),
        )


34 changes: 34 additions & 0 deletions docs/server-admin.md
@@ -11,6 +11,7 @@ Please see [the user guide](user-guide.md) for descriptions of the client and ea
- [Configuration](#configuration)
  - [Debug Mode](#debug-mode)
  - [Environment Variables](#environment-variables)
  - [Pipeline Optimizations](#pipeline-optimizations)
  - [Server Parameters](#server-parameters)
- [Containers](#containers)
  - [CPU](#cpu)
@@ -73,6 +74,39 @@ Others:
- `ONNX_WEB_SHOW_PROGRESS`
  - show progress bars in the logs
  - disabling this can reduce noise in server logs, especially when logging to a file
- `ONNX_WEB_OPTIMIZATIONS`
  - comma-delimited list of optimizations to enable

### Pipeline Optimizations

- `diffusers-*`
  - `diffusers-attention-slicing`
    - https://huggingface.co/docs/diffusers/optimization/fp16#sliced-attention-for-additional-memory-savings
  - `diffusers-cpu-offload-*`
    - `diffusers-cpu-offload-sequential`
      - not available for ONNX pipelines (most of them)
      - https://huggingface.co/docs/diffusers/optimization/fp16#offloading-to-cpu-with-accelerate-for-memory-savings
    - `diffusers-cpu-offload-model`
      - not available for ONNX pipelines (most of them)
      - https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings
  - `diffusers-memory-efficient-attention`
    - requires [the `xformers` library](https://huggingface.co/docs/diffusers/optimization/xformers)
    - https://huggingface.co/docs/diffusers/optimization/fp16#memory-efficient-attention
  - `diffusers-vae-slicing`
    - not available for ONNX pipelines (most of them)
    - https://huggingface.co/docs/diffusers/optimization/fp16#sliced-vae-decode-for-larger-batches
- `onnx-*`
  - `onnx-low-memory`
    - disable ONNX features that allocate more memory than is strictly required or keep memory after use
  - `onnx-graph-*`
    - `onnx-graph-disable`
      - disable all ONNX graph optimizations
    - `onnx-graph-basic`
      - enable basic ONNX graph optimizations
    - `onnx-graph-all`
      - enable all ONNX graph optimizations
  - `onnx-deterministic-compute`
    - enable ONNX deterministic compute
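
The server reads the variable as a comma-delimited list; a small sketch of the parsing done by `ServerContext.from_environ`, using illustrative flag values:

```python
# Sketch of how ONNX_WEB_OPTIMIZATIONS is split into a list of flags
# (mirrors ServerContext.from_environ); the values shown are only examples.
from os import environ

environ["ONNX_WEB_OPTIMIZATIONS"] = "diffusers-attention-slicing,onnx-low-memory,onnx-graph-all"

optimizations = environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(",")
print(optimizations)
# ['diffusers-attention-slicing', 'onnx-low-memory', 'onnx-graph-all']
```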

### Server Parameters
