
Fixed obtaining input names from ONNX file for TensorRT conversion
jkosek committed Aug 30, 2023
1 parent 3ee5574 commit 8baf510
Showing 8 changed files with 137 additions and 180 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/stale.yaml
@@ -0,0 +1,35 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: 'Close stale issues and PRs'
on:
  schedule:
    - cron: "30 1 * * *"
jobs:
  stale:
    if: github.repository_owner == 'triton-inference-server'
    runs-on: ubuntu-latest
    permissions:
      issues: write
      pull-requests: write
    steps:
      - uses: actions/stale@v8
        with:
          days-before-stale: 60
          days-before-close: 7
          stale-issue-message: 'This issue is stale because it has been open for 60 days with no activity. Remove the stale label or add a comment, or this issue will be closed in 7 days.'
          stale-pr-message: 'This PR is stale because it has been open for 60 days with no activity. Remove the stale label or add a comment, or this PR will be closed in 7 days.'
          close-issue-message: 'This issue was closed because it remained stale for 7 days with no activity.'
          close-pr-message: 'This PR was closed because it remained stale for 7 days with no activity.'
          exempt-issue-labels: 'non-stale'
          exempt-pr-labels: 'non-stale'
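
For orientation, a rough Python model (not part of the commit) of the timeline this configuration produces. It is only an approximation: the real action resets the clock on new activity and counts days-before-close from the moment the stale label is applied.

from datetime import datetime, timedelta, timezone

DAYS_BEFORE_STALE = 60  # mirrors days-before-stale above
DAYS_BEFORE_CLOSE = 7   # mirrors days-before-close above


def stale_status(last_activity: datetime, now: datetime) -> str:
    """Approximate the stale workflow: label after 60 idle days, close 7 days later."""
    idle_days = (now - last_activity).days
    if idle_days >= DAYS_BEFORE_STALE + DAYS_BEFORE_CLOSE:
        return "close"
    if idle_days >= DAYS_BEFORE_STALE:
        return "mark stale"
    return "leave open"


now = datetime(2023, 8, 30, tzinfo=timezone.utc)
print(stale_status(now - timedelta(days=61), now))  # mark stale
print(stale_status(now - timedelta(days=70), now))  # close
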
17 changes: 17 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,23 @@ limitations under the License.

# Changelog

## 0.7.2
- fix: Obtain input names from the ONNX file for TensorRT conversion
- change: Raise an exception instead of exiting with an error code when a required command fails

- Version of external components used during testing:
  - [PyTorch 2.1.0a0+b5021ba](https://github.com/pytorch/pytorch/commit/b5021ba9)
  - [TensorFlow 2.12.0](https://github.com/tensorflow/tensorflow/releases/tag/v2.12.0)
  - [TensorRT 8.6.1](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html)
  - [ONNX Runtime 1.15.1](https://github.com/microsoft/onnxruntime/tree/v1.15.1)
  - [Polygraphy](https://github.com/NVIDIA/TensorRT/tree/master/tools/Polygraphy/): 0.47.1
  - [GraphSurgeon](https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon/): 0.3.27
  - [tf2onnx v1.14.0](https://github.com/onnx/tensorflow-onnx/releases/tag/v1.14.0)
- Other component versions depend on the versions of the used framework containers.
  See the [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
  for a detailed summary.


## 0.7.1
- fix: gather onnx input names based on model's forward signature
- fix: do not run TensorRT max batch size search when max batch size is None
2 changes: 1 addition & 1 deletion model_navigator/__version__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# noqa: D100
__version__ = "0.7.1"
__version__ = "0.7.2"
74 changes: 9 additions & 65 deletions model_navigator/commands/convert/onnx/onnx2trt.py
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""ConvertONNX2TRT command."""
import json
import pathlib
import sys
import tempfile
from distutils.version import LooseVersion
from typing import Any, Dict, List, Optional

@@ -24,14 +22,14 @@
TensorRTPrecision,
TensorRTPrecisionMode,
TensorRTProfile,
TensorType,
)
from model_navigator.commands.base import CommandOutput, CommandStatus
from model_navigator.commands.convert.base import Convert2TensorRTWithMaxBatchSizeSearch
from model_navigator.commands.execution_context import ExecutionContext
from model_navigator.core.logger import LOGGER
from model_navigator.core.tensor import TensorMetadata
from model_navigator.core.workspace import Workspace
from model_navigator.frameworks.onnx.utils import get_onnx_io_names
from model_navigator.frameworks.tensorrt import utils as tensorrt_utils
from model_navigator.runners.tensorrt import TensorRTRunner
from model_navigator.utils import devices
@@ -109,8 +107,6 @@ def _run(
input_model_path=input_model_path,
converted_model_path=converted_model_path,
workspace=workspace,
input_metadata=input_metadata,
output_metadata=output_metadata,
dataloader_trt_profile=dataloader_trt_profile,
optimized_trt_profiles=optimized_trt_profiles,
batch_dim=batch_dim,
@@ -119,7 +115,6 @@
precision=precision,
precision_mode=precision_mode,
max_workspace_size=max_workspace_size,
verbose=verbose,
custom_args=custom_args,
)

@@ -167,8 +162,6 @@ def _get_get_args_callable(
input_model_path: pathlib.Path,
converted_model_path: pathlib.Path,
workspace: Workspace,
input_metadata: TensorMetadata,
output_metadata: TensorMetadata,
dataloader_trt_profile: TensorRTProfile,
custom_args: Dict[str, Any],
optimized_trt_profiles: Optional[List[TensorRTProfile]] = None,
@@ -178,7 +171,6 @@
precision: Optional[TensorRTPrecision] = None,
precision_mode: Optional[TensorRTPrecisionMode] = None,
max_workspace_size: Optional[int] = None,
verbose: bool = False,
):
convert_cmd = ["polygraphy", "convert", input_model_path.relative_to(workspace.path).as_posix()]
convert_cmd.extend(["--convert-to", "trt"])
@@ -214,35 +206,28 @@ def _get_get_args_callable(
else:
convert_cmd.extend(["--pool-limit", f"workspace:{max_workspace_size}"])

onnx_input_metadata = self._get_onnx_input_metadata(
input_model_path=input_model_path,
input_metadata=input_metadata,
output_metadata=output_metadata,
workspace=workspace,
reproduce_script_path=converted_model_path.parent,
verbose=verbose,
)

for k, v in (custom_args or {}).items():
if isinstance(v, bool) and v is True:
convert_cmd.append(k)
else:
convert_cmd.extend([k, v])

onnx_input_names, _ = get_onnx_io_names(onnx_path=input_model_path)

def get_args(max_batch_size=None):
if optimized_trt_profiles:
shape_args = []
for trt_profile in optimized_trt_profiles:
shape_args.extend(
self._trt_profile_to_shape_args(
onnx_input_metadata=onnx_input_metadata,
onnx_input_names=onnx_input_names,
trt_profile=trt_profile,
)
)
return convert_cmd + shape_args
else:
return convert_cmd + self._get_shape_args(
onnx_input_metadata=onnx_input_metadata,
onnx_input_names=onnx_input_names,
trt_profile=dataloader_trt_profile,
batch_dim=batch_dim,
max_batch_size=max_batch_size,
@@ -252,15 +237,15 @@ def get_args(max_batch_size=None):

@staticmethod
def _trt_profile_to_shape_args(
onnx_input_metadata: TensorMetadata,
onnx_input_names: List[str],
trt_profile: TensorRTProfile,
):
shape_args = []
for attr in ("min", "opt", "max"):
arg = f"--trt-{attr}-shapes"
shapes = []
for input_name in trt_profile:
if input_name not in onnx_input_metadata:
if input_name not in onnx_input_names:
continue
shape = ",".join([str(d) for d in getattr(trt_profile[input_name], attr)])
shapes.append(f"{input_name}:[{shape}]")
@@ -286,7 +271,7 @@ def _get_conversion_profiles(

@staticmethod
def _get_shape_args(
onnx_input_metadata: TensorMetadata,
onnx_input_names: List[str],
trt_profile: TensorRTProfile,
batch_dim: Optional[int] = None,
max_batch_size: Optional[int] = None,
@@ -302,52 +287,11 @@ def _get_shape_args(
arg = f"--trt-{attr}-shapes"
shapes = []
for input_name in trt_profile:
if input_name not in onnx_input_metadata:
if input_name not in onnx_input_names:
continue
shape = ",".join([str(d) for d in getattr(trt_profile[input_name], attr)])
shapes.append(f"{input_name}:[{shape}]")
if shapes:
shape_args.extend([f"{arg}"] + shapes)

return shape_args

def _get_onnx_input_metadata(
self,
workspace: Workspace,
input_model_path: pathlib.Path,
input_metadata: TensorMetadata,
output_metadata: TensorMetadata,
reproduce_script_path: pathlib.Path,
verbose: bool,
):
with ExecutionContext(
script_path=reproduce_script_path / "reproduce_onnx_input_metadata.py",
cmd_path=reproduce_script_path / "reproduce_onnx_input_metadata.sh",
workspace=workspace,
verbose=verbose,
) as context, tempfile.NamedTemporaryFile() as temp_file:
kwargs = {
"model_path": input_model_path.relative_to(workspace.path).as_posix(),
"input_metadata": input_metadata.to_json(),
"output_metadata": output_metadata.to_json(),
"results_path": temp_file.name,
}
args = parse_kwargs_to_cmd(kwargs)
from . import collect_onnx_input_metadata

try:
context.execute_external_runtime_script(collect_onnx_input_metadata.__file__, args)
with open(temp_file.name) as fp:
input_metadata = json.load(fp)
LOGGER.info("Input metadata collected from ONNX model.")
except Exception as e:
LOGGER.warning(
"Unable to collect metadata from ONNX model. The evaluation failed. Empty metadata used."
)
LOGGER.warning(f"Error during obtaining metadata: {str(e)}")
input_metadata = {
"metadata": [],
"pytree_metadata": {"metadata": None, "tensor_type": TensorType.NUMPY.value},
}

return TensorMetadata.from_json(input_metadata)
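
Net effect of this change: the removed _get_onnx_input_metadata helper (which ran an external script and fell back to empty metadata on failure) is replaced by a direct read of the input names from the ONNX file. Below is a minimal standalone sketch of the resulting filtering logic; ShapeTuple and the profile values are stand-ins for TensorRTProfile entries, not the library's API.

from collections import namedtuple
from typing import Dict, List

# Stand-in for TensorRTProfile entries: min/opt/max shape per input.
ShapeTuple = namedtuple("ShapeTuple", ["min", "opt", "max"])


def trt_profile_to_shape_args(onnx_input_names: List[str], trt_profile: Dict[str, ShapeTuple]) -> List[str]:
    # Mirrors _trt_profile_to_shape_args: build polygraphy shape flags,
    # skipping profile entries whose names are not actual ONNX inputs.
    shape_args = []
    for attr in ("min", "opt", "max"):
        shapes = [
            f"{name}:[{','.join(str(d) for d in getattr(trt_profile[name], attr))}]"
            for name in trt_profile
            if name in onnx_input_names
        ]
        if shapes:
            shape_args.extend([f"--trt-{attr}-shapes"] + shapes)
    return shape_args


# In the command the names come from get_onnx_io_names(onnx_path=...); these are illustrative.
onnx_input_names = ["input_0"]
profile = {
    "input_0": ShapeTuple(min=(1, 3), opt=(8, 3), max=(16, 3)),
    "renamed_input": ShapeTuple(min=(1,), opt=(1,), max=(1,)),  # filtered out
}
print(trt_profile_to_shape_args(onnx_input_names, profile))
# ['--trt-min-shapes', 'input_0:[1,3]', '--trt-opt-shapes', 'input_0:[8,3]',
#  '--trt-max-shapes', 'input_0:[16,3]']
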
13 changes: 13 additions & 0 deletions model_navigator/frameworks/onnx/utils.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""ONNX utils."""
import pathlib
from typing import List, Tuple

import numpy as np

@@ -30,3 +32,14 @@
"tensor(bool)": bool,
"tensor(string)": str,
}


def get_onnx_io_names(onnx_path: pathlib.Path) -> Tuple[List[str], List[str]]:
    """Get input and output tensor names from an ONNX model."""
    import onnx

    model = onnx.load_model(onnx_path.as_posix())

    input_names = [tensor.name for tensor in model.graph.input]
    output_names = [tensor.name for tensor in model.graph.output]
    return input_names, output_names
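
A hedged usage sketch of the new helper (assuming model_navigator and onnx are installed); the single-node Add graph below is built in memory purely for illustration.

import pathlib
import tempfile

import onnx
from onnx import TensorProto, helper

from model_navigator.frameworks.onnx.utils import get_onnx_io_names

# Build a minimal graph: output = input_0 + input_1.
inputs = [
    helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [1, 3]),
    helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [1, 3]),
]
outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 3])]
node = helper.make_node("Add", inputs=["input_0", "input_1"], outputs=["output"])
model = helper.make_model(helper.make_graph([node], "add_graph", inputs, outputs))

with tempfile.TemporaryDirectory() as tmp:
    model_path = pathlib.Path(tmp) / "add.onnx"
    onnx.save_model(model, model_path.as_posix())
    input_names, output_names = get_onnx_io_names(onnx_path=model_path)

print(input_names)   # ['input_0', 'input_1']
print(output_names)  # ['output']
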
12 changes: 9 additions & 3 deletions model_navigator/pipelines/pipeline.py
@@ -13,15 +13,18 @@
# limitations under the License.
"""Definition of Pipeline module - Direct Acyclic Graph (DAG) of commands execution."""
import contextlib
import sys
import traceback
from typing import List

from model_navigator.commands.base import CommandOutput, CommandStatus, ExecutionUnit
from model_navigator.configuration.common_config import CommonConfig
from model_navigator.core.logger import LOGGER, LoggingContext, StdoutLogger, pad_string
from model_navigator.core.workspace import Workspace
from model_navigator.exceptions import ModelNavigatorCommandNotExecutable, ModelNavigatorUserInputError
from model_navigator.exceptions import (
ModelNavigatorCommandNotExecutable,
ModelNavigatorRuntimeError,
ModelNavigatorUserInputError,
)
from model_navigator.pipelines.pipeline_context import PipelineContext


@@ -125,6 +128,9 @@ def _execute_unit(
LOGGER.error(f"Command finished with unexpected error: {error}")

if command_output.status != CommandStatus.OK and execution_unit.command.is_required():
sys.exit("The required command has failed. Please, review the log and verify the reported problems.")
raise ModelNavigatorRuntimeError(
"The required command has failed. Please, review the log and verify the reported problems: \n"
f"{command_output.output}."
)

return command_output
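
A hedged caller-side sketch of the new behavior; run_pipeline below merely simulates a required command failing, where in real use the exception would propagate out of a pipeline run.

from model_navigator.exceptions import ModelNavigatorRuntimeError


def run_pipeline() -> None:
    # Simulates a required command failing inside Pipeline._execute_unit.
    raise ModelNavigatorRuntimeError("The required command has failed. Please review the log.")


try:
    run_pipeline()
except ModelNavigatorRuntimeError as error:
    # Before this change the process exited via sys.exit(); now the caller
    # can log, clean up, or fall back instead of losing the interpreter.
    print(f"Pipeline failed: {error}")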