2 changes: 1 addition & 1 deletion CHANGELOGS.rst

@@ -4,9 +4,9 @@ Change Logs
 0.7.12
 ++++++
 
+* :pr:`227`: better support for ``model_id//pretrained``, adds a speed-up when running the command validate
 * :pr:`226`: fix input order for models created with modelbuilder
 
 
 0.7.11
 ++++++
 
2 changes: 1 addition & 1 deletion onnx_diagnostic/torch_models/hghub/model_inputs.py

@@ -189,7 +189,7 @@ def get_untrained_model_with_inputs(
             f"subfolder={subfolder!r}"
         )
     model = transformers.AutoModel.from_pretrained(
-        model_id, subfolder=subfolder, trust_remote_code=True, **mkwargs
+        model_id, subfolder=subfolder or "", trust_remote_code=True, **mkwargs
     )
     if verbose:
         print(
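
Note: the one-line change above normalizes subfolder before it reaches transformers: from_pretrained expects a string there (its own default appears to be ""), and a None passed through from the caller can break the path handling inside the hub lookup. A minimal sketch of what `subfolder or ""` does, with hypothetical values:

    # `or ""` maps None (and "") to "", and leaves real subfolder names untouched
    for subfolder in (None, "", "vae"):
        print(repr(subfolder or ""))
    # prints: '', '', 'vae'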
79 changes: 71 additions & 8 deletions onnx_diagnostic/torch_models/validate.py

@@ -264,14 +264,18 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
     return new_cfg
 
 
-def _preprocess_model_id(model_id, subfolder):
+def _preprocess_model_id(
+    model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
+) -> Tuple[str, Optional[str], bool, bool]:
     if subfolder or "//" not in model_id:
-        return model_id, subfolder
+        return model_id, subfolder, same_as_pretrained, use_pretrained
     spl = model_id.split("//")
+    if spl[-1] == "pretrained":
+        return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
     if spl[-1] in {"transformer", "vae"}:
         # known subfolder
-        return "//".join(spl[:-1]), spl[-1]
-    return model_id, subfolder
+        return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
+    return model_id, subfolder, same_as_pretrained, use_pretrained
 
 
 def validate_model(
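
Note: the reworked helper lets a plain model id carry extra instructions. A trailing //pretrained strips the suffix and forces same_as_pretrained and use_pretrained to True (the behavior this PR adds), while //transformer and //vae resolve to known subfolders. A short trace of the three recognized cases, with placeholder model ids:

    # "...//pretrained": strip the suffix and force both flags to True
    _preprocess_model_id("microsoft/phi-2//pretrained", None, False, False)
    # -> ("microsoft/phi-2", "", True, True)

    # "...//vae" (or "...//transformer"): the suffix is a known subfolder
    _preprocess_model_id("stabilityai/sdxl-turbo//vae", None, False, False)
    # -> ("stabilityai/sdxl-turbo", "vae", False, False)

    # no "//" marker: everything passes through unchanged
    _preprocess_model_id("gpt2", None, False, False)
    # -> ("gpt2", None, False, False)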
@@ -384,7 +388,12 @@ def validate_model(
         if ``runtime == 'ref'``,
         ``orteval10`` increases the verbosity.
     """
-    model_id, subfolder = _preprocess_model_id(model_id, subfolder)
+    model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
+        model_id,
+        subfolder,
+        same_as_pretrained=same_as_pretrained,
+        use_pretrained=use_pretrained,
+    )
     if isinstance(patch, bool):
         patch_kwargs = (
             dict(patch_transformers=True, patch_diffusers=True, patch=True)
@@ -841,6 +850,8 @@ def validate_model(
         )
         summary.update(summary_valid)
 
+    _compute_final_statistics(summary)
+
     if verbose:
         print("[validate_model] -- done (final)")
     if dump_stats:
@@ -853,15 +864,24 @@
 def compute_statistics(onnx_filename: str) -> Dict[str, Union[float, int]]:
     """Computes some statistics on the model itself."""
     onx = onnx.load(onnx_filename, load_external_data=False)
+    cache_functions = {(f.domain, f.name): f for f in onx.functions}
+    local_domains = set(f.domain for f in onx.functions)
 
     def node_iter(proto):
         if isinstance(proto, onnx.ModelProto):
-            yield from node_iter(proto.graph)
             for f in proto.functions:
                 yield from node_iter(f)
+            yield from node_iter(proto.graph)
         elif isinstance(proto, (onnx.FunctionProto, onnx.GraphProto)):
             for node in proto.node:
                 yield node
+
+                # Let's inline the function
+                key = node.domain, node.op_type
+                if key in cache_functions:
+                    yield from node_iter(cache_functions[key])
+
+                # Let's continue
                 for att in node.attribute:
                     if att.type == onnx.AttributeProto.GRAPH:
                         yield from node_iter(att.g)
@@ -879,6 +899,11 @@ def node_iter(proto):
             n_nodes += 1
             if proto.op_type != "Constant":
                 n_nodes_nocst += 1
+            if proto.domain in local_domains:
+                key = "n_node_local_function"
+                if key not in counts:
+                    counts[key] = 0
+                counts[key] += 1
         else:
             key = f"n_node_initializer_{proto.data_type}"
 
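
Note: since local functions are now inlined at every call site, the counters describe the effective graph rather than the stored one, and every node living in a local-function domain is additionally tallied under n_node_local_function. A hypothetical reading of the result (the key names come from this diff; the exact key set depends on code outside the hunks shown here):

    stats = compute_statistics("model.onnx")
    # e.g. {"n_nodes": 1234, "n_nodes_nocst": 980, "n_node_local_function": 21, ...}
    print(stats.get("n_node_local_function", 0))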
@@ -1400,7 +1425,7 @@ def call_torch_export_onnx(
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
-    available = {None, "", "ir", "os_ort"}
+    available = {None, "", "ir", "os_ort", "ir+default"}
     assert (
         optimization in available
     ), f"unexpected value for optimization={optimization}, available={available}"
@@ -1490,11 +1515,31 @@ def call_torch_export_onnx(
             print(epo)
             print("[call_torch_export_onnx] -- End of ONNXProgram")
 
-    if optimization in {"ir", "os_ort"}:
+    if optimization in {"ir", "os_ort", "ir+default"}:
         if verbose:
             print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
         if optimization == "ir":
             label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
+        elif optimization == "ir+default":
+            import onnxscript
+            from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions
+
+            def _ir_default_opt(epo):
+                onnxscript.optimizer.optimize_ir(epo.model)
+                onx = epo.model_proto
+                # not very efficient
+                gr = GraphBuilder(
+                    onx,
+                    infer_shapes_options=True,
+                    optimization_options=OptimizationOptions(patterns="default"),
+                )
+                cont = gr.to_onnx(large_model=True)
+                epo.model = cont.to_ir()
+
+            label, f_optim = "export_onnx_opt_ir_default", (
+                lambda epo=epo: _ir_default_opt(epo)
+            )
+
         else:
             import onnxscript
             import onnxscript.rewriter.ort_fusions as ort_fusions
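
Note: the new "ir+default" branch chains two optimizers. onnxscript's IR optimizer first runs in place on the exported program's IR model, then the serialized proto takes a round-trip through experimental_experiment's GraphBuilder to apply its "default" rewriting patterns (the "# not very efficient" comment refers to that proto/IR round-trip). A minimal standalone sketch of the same chain, assuming a recent torch with the dynamo exporter plus onnxscript and experimental_experiment installed (the toy module is hypothetical):

    import torch
    import onnxscript
    from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    # the dynamo-based exporter returns an ONNXProgram, `epo` in the diff above
    epo = torch.onnx.export(Tiny(), (torch.randn(2, 3),), dynamo=True)
    onnxscript.optimizer.optimize_ir(epo.model)  # stage 1: IR-level optimization, in place
    gr = GraphBuilder(  # stage 2: pattern-based rewrites on the proto
        epo.model_proto,
        infer_shapes_options=True,
        optimization_options=OptimizationOptions(patterns="default"),
    )
    epo.model = gr.to_onnx(large_model=True).to_ir()  # back to the IR for saving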
@@ -1893,3 +1938,21 @@ def run_ort_fusion(
         f"opt_ort_{model_type}_duration": duration,
         f"opt_ort_{model_type}_duration_save": d,
     }, {f"opt_ort_{model_type}": output_path}
+
+
+def _compute_final_statistics(summary: Dict[str, Any]):
+    """
+    Updates the summary in place. It adds:
+
+    - the estimated speedup
+    """
+    stats = {}
+    if (
+        "time_run_latency" in summary
+        and "time_run_onnx_ort_latency" in summary
+        and summary["time_run_onnx_ort_latency"] > 0
+    ):
+        stats["stat_estimated_speedup_ort"] = (
+            summary["time_run_latency"] / summary["time_run_onnx_ort_latency"]
+        )
+    summary.update(stats)
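
Note: the estimated speedup is simply the eager latency divided by the onnxruntime latency, so any value above 1.0 means the exported model runs faster. A worked example with invented timings:

    summary = {"time_run_latency": 0.050, "time_run_onnx_ort_latency": 0.040}
    _compute_final_statistics(summary)
    print(summary["stat_estimated_speedup_ort"])  # 0.050 / 0.040 = 1.25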