From cc64994dc24951887bda1fc6d4cd101f3585516f Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 19 Sep 2025 14:16:43 +0200 Subject: [PATCH] Handle more models --- onnx_diagnostic/_command_lines_parser.py | 7 +++++++ onnx_diagnostic/torch_models/validate.py | 12 +++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index db488e46..1c98a90c 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -542,6 +542,12 @@ def get_parser_validate() -> ArgumentParser: "the onnx exporter should use.", default="", ) + parser.add_argument( + "--ort-logs", + default=False, + action=BooleanOptionalAction, + help="Enables onnxruntime logging when the session is created", + ) return parser @@ -601,6 +607,7 @@ def _cmd_validate(argv: List[Any]): repeat=args.repeat, warmup=args.warmup, inputs2=args.inputs2, + ort_logs=args.ort_logs, output_names=( None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",") ), diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index b0b69e50..63c624db 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -292,6 +292,7 @@ def validate_model( warmup: int = 0, inputs2: int = 1, output_names: Optional[List[str]] = None, + ort_logs: bool = False, ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]: """ Validates a model. @@ -344,6 +345,7 @@ def validate_model( this ensures that the model does support dynamism, the value is used as an increment to the first set of values (added to dimensions) :param output_names: output names the onnx exporter should use + :param ort_logs: increases onnxruntime verbosity when creating the session :return: two dictionaries, one with some metrics, another one with whatever the function produces @@ -758,6 +760,7 @@ def validate_model( repeat=repeat, warmup=warmup, inputs2=inputs2, + ort_logs=ort_logs, ) summary.update(summary_valid) @@ -1158,6 +1161,7 @@ def validate_onnx_model( repeat: int = 1, warmup: int = 0, inputs2: int = 1, + ort_logs: bool = False, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ Verifies that an onnx model produces the same @@ -1176,6 +1180,7 @@ def validate_onnx_model( :param inputs2: to validate the model on the second input set to make sure the exported model supports dynamism, the value is used as an increment added to the first set of inputs (added to dimensions) + :param ort_logs: triggers the logs for onnxruntime :return: two dictionaries, one with some metrics, another one with whatever the function produces """ @@ -1232,8 +1237,13 @@ def _mk(key, flavour=flavour): if verbose: print("[validate_onnx_model] runtime is onnxruntime") - cls_runtime = lambda model, providers: onnxruntime.InferenceSession( + sess_opts = onnxruntime.SessionOptions() + if ort_logs: + sess_opts.log_severity_level = 0 + sess_opts.log_verbosity_level = 4 + cls_runtime = lambda model, providers, _o=sess_opts: onnxruntime.InferenceSession( (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model), + _o, providers=providers, ) elif runtime == "torch":