diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 2cfbdc43..651e9891 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.7.11 ++++++ +* :pr:`224`: support model_id with // to specify a subfolder * :pr:`223`: adds task image-to-video * :pr:`220`: adds option --ort-logs to display onnxruntime logs when creating the session * :pr:`220`: adds a patch for PR `#40791 `_ in transformers diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index 63c624db..084194fe 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -264,6 +264,16 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]: return new_cfg +def _preprocess_model_id(model_id, subfolder): + if subfolder or "//" not in model_id: + return model_id, subfolder + spl = model_id.split("//") + if spl[-1] in {"transformer", "vae"}: + # known subfolder + return "//".join(spl[:-1]), spl[-1] + return model_id, subfolder + + def validate_model( model_id: str, task: Optional[str] = None, @@ -374,6 +384,7 @@ def validate_model( if ``runtime == 'ref'``, ``orteval10`` increases the verbosity. """ + model_id, subfolder = _preprocess_model_id(model_id, subfolder) if isinstance(patch, bool): patch_kwargs = ( dict(patch_transformers=True, patch_diffusers=True, patch=True)