Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,12 @@ def get_parser_validate() -> ArgumentParser:
"the onnx exporter should use.",
default="",
)
parser.add_argument(
"--ort-logs",
default=False,
action=BooleanOptionalAction,
help="Enables onnxruntime logging when the session is created",
)
return parser


Expand Down Expand Up @@ -601,6 +607,7 @@ def _cmd_validate(argv: List[Any]):
repeat=args.repeat,
warmup=args.warmup,
inputs2=args.inputs2,
ort_logs=args.ort_logs,
output_names=(
None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
),
Expand Down
12 changes: 11 additions & 1 deletion onnx_diagnostic/torch_models/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def validate_model(
warmup: int = 0,
inputs2: int = 1,
output_names: Optional[List[str]] = None,
ort_logs: bool = False,
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
"""
Validates a model.
Expand Down Expand Up @@ -344,6 +345,7 @@ def validate_model(
this ensures that the model does support dynamism, the value is used
as an increment to the first set of values (added to dimensions)
:param output_names: output names the onnx exporter should use
:param ort_logs: increases onnxruntime verbosity when creating the session
:return: two dictionaries, one with some metrics,
another one with whatever the function produces

Expand Down Expand Up @@ -758,6 +760,7 @@ def validate_model(
repeat=repeat,
warmup=warmup,
inputs2=inputs2,
ort_logs=ort_logs,
)
summary.update(summary_valid)

Expand Down Expand Up @@ -1158,6 +1161,7 @@ def validate_onnx_model(
repeat: int = 1,
warmup: int = 0,
inputs2: int = 1,
ort_logs: bool = False,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Verifies that an onnx model produces the same
Expand All @@ -1176,6 +1180,7 @@ def validate_onnx_model(
:param inputs2: to validate the model on the second input set
to make sure the exported model supports dynamism, the value is
used as an increment added to the first set of inputs (added to dimensions)
:param ort_logs: triggers the logs for onnxruntime
:return: two dictionaries, one with some metrics,
another one with whatever the function produces
"""
Expand Down Expand Up @@ -1232,8 +1237,13 @@ def _mk(key, flavour=flavour):

if verbose:
print("[validate_onnx_model] runtime is onnxruntime")
cls_runtime = lambda model, providers: onnxruntime.InferenceSession(
sess_opts = onnxruntime.SessionOptions()
if ort_logs:
sess_opts.log_severity_level = 0
sess_opts.log_verbosity_level = 4
cls_runtime = lambda model, providers, _o=sess_opts: onnxruntime.InferenceSession(
(model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
_o,
providers=providers,
)
elif runtime == "torch":
Expand Down
Loading