Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Change Logs
0.7.12
++++++

* :pr:`232`: fixes ``--patch`` argument so that ``--patch=0`` works
* :pr:`231`: better statistics about fusions
* :pr:`227`: better support for ``model_id//pretrained``, adds speed up when running command validate
* :pr:`226`: fix input order for models created with modelbuilder

Expand Down
5 changes: 3 additions & 2 deletions onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ def _cmd_validate(argv: List[Any]):
):
print(f"validate - unsupported args: export={args.export!r}, opt={args.opt!r}")
return
patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
summary, _data = validate_model(
model_id=args.mid,
task=args.task,
Expand All @@ -591,8 +592,8 @@ def _cmd_validate(argv: List[Any]):
use_pretrained=args.trained,
dtype=args.dtype,
device=args.device,
patch=args.patch,
rewrite=args.rewrite,
patch=patch_dict,
rewrite=args.rewrite and patch_dict.get("patch", True),
stop_if_static=args.stop_if_static,
optimization=args.opt,
exporter=args.export,
Expand Down
2 changes: 1 addition & 1 deletion onnx_diagnostic/torch_models/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def validate_model(
assert not rewrite or patch_kwargs.get("patch", False), (
f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
f"patch must be True to enable rewriting, "
f"if --no-patch was specified on the command line, --no-rewrite must be added."
f"if --patch=0 was specified on the command line, rewrites are disabled."
)
summary = version_summary()
summary.update(
Expand Down
Loading