Skip to content

Commit

Permalink
feat(api): enable ONNX optimizations through env
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 18, 2023
1 parent 0d2211f commit 5b4c370
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
32 changes: 28 additions & 4 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from enum import IntEnum
from typing import Any, Dict, Literal, Optional, Tuple, Union
from logging import getLogger
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from onnxruntime import SessionOptions
from onnxruntime import GraphOptimizationLevel, SessionOptions

logger = getLogger(__name__)


class SizeChart(IntEnum):
Expand Down Expand Up @@ -75,11 +78,16 @@ def tojson(self) -> Dict[str, int]:

class DeviceParams:
def __init__(
    self,
    device: str,
    provider: str,
    options: Optional[dict] = None,
    optimizations: Optional[List[str]] = None,
) -> None:
    """Describe one execution device for ONNX runtime sessions.

    Args:
        device: device name, e.g. "cpu" or "cuda:0" style identifiers.
        provider: ONNX runtime execution provider name for this device.
        options: provider-specific options dict passed through to ORT
            (e.g. {"device_id": i}); may be None.
        optimizations: list of optimization flag strings (e.g.
            "onnx-low-memory"); may be None.
    """
    self.device = device
    self.provider = provider
    self.options = options
    # normalize None to an empty list so membership tests
    # ("flag" in self.optimizations) never raise TypeError
    self.optimizations = optimizations or []

def __str__(self) -> str:
    """Human-readable summary: device name, provider, and provider options."""
    return f"{self.device} - {self.provider} ({self.options})"
Expand All @@ -91,7 +99,23 @@ def ort_provider(self) -> Tuple[str, Any]:
return (self.provider, self.options)

def sess_options(self) -> SessionOptions:
    """Build an ONNX runtime SessionOptions from this device's optimization flags.

    Recognized flags in self.optimizations:
        - "onnx-low-memory": disable CPU memory arena, memory patterns,
          and memory reuse to reduce peak memory usage.
        - "onnx-optimization-{disable,basic,all}": set the graph
          optimization level (mutually exclusive; first match wins).
        - "onnx-deterministic-compute": request deterministic kernels.

    Returns:
        A configured SessionOptions instance (defaults when no flags set).
    """
    sess = SessionOptions()
    # guard against None: optimizations defaults to None in __init__
    optimizations = self.optimizations or []

    if "onnx-low-memory" in optimizations:
        logger.debug("enabling ONNX low-memory optimizations")
        sess.enable_cpu_mem_arena = False
        sess.enable_mem_pattern = False
        sess.enable_mem_reuse = False

    if "onnx-optimization-disable" in optimizations:
        sess.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
    elif "onnx-optimization-basic" in optimizations:
        sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
    elif "onnx-optimization-all" in optimizations:
        sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    if "onnx-deterministic-compute" in optimizations:
        sess.use_deterministic_compute = True

    # BUG FIX: the original fell off the end and returned None despite the
    # -> SessionOptions annotation, so callers received no session options.
    return sess

def torch_str(self) -> str:
if self.device.startswith("cuda"):
Expand Down
17 changes: 15 additions & 2 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,16 +349,29 @@ def load_platforms(context: ServerContext) -> None:
{
"device_id": i,
},
context.optimizations,
)
)
else:
available_platforms.append(
DeviceParams(potential, platform_providers[potential])
DeviceParams(
potential,
platform_providers[potential],
None,
context.optimizations,
)
)

if context.any_platform:
# the platform should be ignored when the job is scheduled, but set to CPU just in case
available_platforms.append(DeviceParams("any", platform_providers["cpu"]))
available_platforms.append(
DeviceParams(
"any",
platform_providers["cpu"],
None,
context.optimizations,
)
)

# make sure CPU is last on the list
def any_first_cpu_last(a: DeviceParams, b: DeviceParams):
Expand Down

0 comments on commit 5b4c370

Please sign in to comment.