From e7b9dc1b6ea459c75940c147d116297c17f84512 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 23 Oct 2025 02:06:53 +0200 Subject: [PATCH 1/4] fix patches --- _doc/technical/plot_broadcast_export_issue.py | 6 +++--- onnx_diagnostic/_command_lines_parser.py | 10 +++++----- .../torch_export_patches/patches/patch_transformers.py | 4 ++++ 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/_doc/technical/plot_broadcast_export_issue.py b/_doc/technical/plot_broadcast_export_issue.py index 02778455..2015dc28 100644 --- a/_doc/technical/plot_broadcast_export_issue.py +++ b/_doc/technical/plot_broadcast_export_issue.py @@ -80,8 +80,8 @@ def forward(self, x, y): # d1 = shape_env.create_unbacked_symint() # d2 = shape_env.create_unbacked_symint() fake_inputs = fake_mode.from_tensor( - torch.zeros((2,), dtype=torch.float32), static_shapes=False -), fake_mode.from_tensor(torch.zeros((2,), dtype=torch.float32), static_shapes=False) + torch.zeros((3,), dtype=torch.float32), static_shapes=False +), fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False) print("fake_inputs are ", fake_inputs) res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs) @@ -115,7 +115,7 @@ def forward(self, x, y): try: res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs) except Exception as e: - print(e) + print("error", e) # %% # By applying the patches: diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index f4c2eca5..e99d3c2e 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -693,16 +693,16 @@ def _cmd_export_sample(argv: List[Any]): os.makedirs(args.dump_folder, exist_ok=True) name = ( _make_folder_name( - model_id=args.model_id, - exporter=args.exporter, - optimization=args.optimization, + model_id=args.mid, + exporter=args.export, + optimization=args.opt, dtype=args.dtype, device=args.device, subfolder=args.subfolder, opset=args.opset, drop_inputs=None if not args.drop else args.drop.split(","), - same_as_pretrained=args.same_as_pretrained, - use_pretrained=args.use_pretrained, + same_as_pretrained=args.same_as_trained, + use_pretrained=args.trained, task=args.task, ).replace("/", "-") + ".py" diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 493ff4fd..7e911cd6 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -1312,6 +1312,10 @@ def patched_sdpa_attention_forward( # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal is_causal = attention_mask is None and is_causal + torch._check( + attention_mask.shape[3] == key.shape[2], + "Attention mask shape incompatible with key shape.", + ) attn_output = torch.nn.functional.scaled_dot_product_attention( query, key, From ae0b77c39a345946ac219f8b5fbd0a2d9e4d0108 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 23 Oct 2025 13:03:11 +0200 Subject: [PATCH 2/4] fix patch --- .../torch_export_patches/patches/patch_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 7e911cd6..2d8a72d3 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -1313,7 +1313,7 @@ def patched_sdpa_attention_forward( is_causal = attention_mask is None and is_causal torch._check( - attention_mask.shape[3] == key.shape[2], + attention_mask is None or attention_mask.shape[3] == key.shape[2], "Attention mask shape incompatible with key shape.", ) attn_output = torch.nn.functional.scaled_dot_product_attention( From 743d71e44b2c7eda9c4e3f874505fdbd6765ae64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 23 Oct 2025 13:37:57 +0200 Subject: [PATCH 3/4] fixes shape information --- .../ut_tasks/test_tasks_text_generation.py | 19 +++- onnx_diagnostic/_command_lines_parser.py | 100 ++++++++++-------- onnx_diagnostic/tasks/image_text_to_text.py | 2 +- onnx_diagnostic/tasks/text_generation.py | 5 +- 4 files changed, 75 insertions(+), 51 deletions(-) diff --git a/_unittests/ut_tasks/test_tasks_text_generation.py b/_unittests/ut_tasks/test_tasks_text_generation.py index 8303b788..95cc6329 100644 --- a/_unittests/ut_tasks/test_tasks_text_generation.py +++ b/_unittests/ut_tasks/test_tasks_text_generation.py @@ -16,7 +16,7 @@ class TestTasksTextGeneration(ExtTestCase): @hide_stdout() @requires_transformers("4.53") @requires_torch("2.7.99") - def test_image_text_to_text_gemma3_for_causallm(self): + def test_tet_generation_gemma3_for_causallm(self): mid = "hf-internal-testing/tiny-random-Gemma3ForCausalLM" data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True) self.assertEqual(data["task"], "text-generation") @@ -28,6 +28,23 @@ def test_image_text_to_text_gemma3_for_causallm(self): model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False ) + @hide_stdout() + @requires_transformers("4.53") + @requires_torch("2.7.99") + def test_itext_generation_phi_3_mini_128k_instruct(self): + mid = "microsoft/Phi-3-mini-128k-instruct" + data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True) + self.assertEqual(data["task"], "text-generation") + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + print("--", self.string_type(inputs, with_shape=True)) + print("--", self.string_type(ds)) + model(**torch_deepcopy(inputs)) + model(**data["inputs2"]) + with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False): + torch.export.export( + model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index e99d3c2e..415e3572 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -371,30 +371,34 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, d) -def get_parser_validate() -> ArgumentParser: +def get_parser_validate(name: str = "validate") -> ArgumentParser: parser = ArgumentParser( - prog="validate", + prog=name, description=textwrap.dedent( """ - Prints out dummy inputs for a particular task or a model id. - If both mid and task are empty, the command line displays the list - of supported tasks. + Validates a model for a particular task given the model id. + It exports the model and then validates it by computing the discrepancies + on different input sets. + """ + if name == "validate" + else """ + Creates a script to export a model for a particular task given the model id. """ ), epilog=textwrap.dedent( - """ + f""" If the model id is specified, one untrained version of it is instantiated. Examples: - python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\ + python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\ --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\ --dtype float16 --device cuda --patch --export onnx-dynamo --opt ir - python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\ + python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\ --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\ --dtype float16 --device cuda --patch --export custom --opt default - python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\ + python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\ --run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\ --dtype float16 --device cuda --export modelbuilder @@ -405,12 +409,12 @@ def get_parser_validate() -> ArgumentParser: The behaviour may be modified compare the original configuration, the following argument can be rope_scaling to dynamic: - --mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"" + --mop \"rope_scaling={{'rope_type': 'dynamic', 'factor': 10.0}}\"" You can profile the command line by running: - pyinstrument -m onnx_diagnostic validate ... - pyinstrument -r html -o profile.html -m onnx_diagnostic validate ... + pyinstrument -m onnx_diagnostic {name} ... + pyinstrument -r html -o profile.html -m onnx_diagnostic {name} ... """ ), formatter_class=RawTextHelpFormatter, @@ -460,19 +464,19 @@ def get_parser_validate() -> ArgumentParser: "--same-as-trained", default=False, action=BooleanOptionalAction, - help="Validates a model identical to the trained model but not trained.", + help="Validates or exports a model identical to the trained model but not trained.", ) parser.add_argument( "--trained", default=False, action=BooleanOptionalAction, - help="Validates the trained model (requires downloading).", + help="Validates or exports the trained model (requires downloading).", ) parser.add_argument( "--inputs2", default=1, type=int, - help="Validates the model on a second set of inputs\n" + help="Validates or exports the model on a second set of inputs\n" "to check the exported model supports dynamism. The values is used " "as an increment to the first set of inputs. A high value may trick " "a different behavior in the model and missed by the exporter.", @@ -504,13 +508,14 @@ def get_parser_validate() -> ArgumentParser: "--subfolder", help="Subfolder where to find the model and the configuration.", ) - parser.add_argument( - "--ortfusiontype", - required=False, - help="Applies onnxruntime fusion, this parameter should contain the\n" - "model type or multiple values separated by `|`. `ALL` can be used\n" - "to run them all.", - ) + if name == "validate": + parser.add_argument( + "--ortfusiontype", + required=False, + help="Applies onnxruntime fusion, this parameter should contain the\n" + "model type or multiple values separated by `|`. `ALL` can be used\n" + "to run them all.", + ) parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity") parser.add_argument("--dtype", help="Changes dtype if necessary.") parser.add_argument("--device", help="Changes the device if necessary.") @@ -532,33 +537,38 @@ def get_parser_validate() -> ArgumentParser: "--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"", action=_ParseDict, ) - parser.add_argument( - "--repeat", - default=1, - type=int, - help="number of times to run the model to measures inference time", - ) - parser.add_argument( - "--warmup", default=0, type=int, help="number of times to run the model to do warmup" - ) + if name == "validate": + parser.add_argument( + "--repeat", + default=1, + type=int, + help="number of times to run the model to measures inference time", + ) + parser.add_argument( + "--warmup", + default=0, + type=int, + help="number of times to run the model to do warmup", + ) parser.add_argument( "--outnames", help="This comma separated list defines the output names " "the onnx exporter should use.", default="", ) - parser.add_argument( - "--ort-logs", - default=False, - action=BooleanOptionalAction, - help="Enables onnxruntime logging when the session is created", - ) - parser.add_argument( - "--quiet-input-sets", - default="", - help="Avoids raising an exception when an input sets does not work with " - "the exported model.\nExample: --quiet-input-sets=inputs,inputs22", - ) + if name == "validate": + parser.add_argument( + "--ort-logs", + default=False, + action=BooleanOptionalAction, + help="Enables onnxruntime logging when the session is created", + ) + parser.add_argument( + "--quiet-input-sets", + default="", + help="Avoids raising an exception when an input sets does not work with " + "the exported model.\nExample: --quiet-input-sets=inputs,inputs22", + ) return parser @@ -637,7 +647,7 @@ def _cmd_export_sample(argv: List[Any]): from .torch_models.code_sample import code_sample from .tasks import supported_tasks - parser = get_parser_validate() + parser = get_parser_validate("exportsample") args = parser.parse_args(argv[1:]) if not args.task and not args.mid: print("-- list of supported tasks:") @@ -1111,7 +1121,7 @@ def main(argv: Optional[List[Any]] = None): validate=get_parser_validate, stats=get_parser_stats, agg=get_parser_agg, - exportsample=get_parser_validate, + exportsample=lambda: get_parser_validate("exportsample"), ) cmd = argv[0] if cmd not in parsers: diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index a15b0ee9..0bb8a4e9 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -271,7 +271,7 @@ def get_inputs_default( "input_ids": {0: batch, 1: seq_length}, "token_type_ids": {0: batch, 1: seq_length}, "attention_mask": {0: batch, 1: "cache+seq"}, - "position_ids": {0: batch, 1: "cache+seq"}, + "position_ids": {0: batch, 1: seq_length}, "past_key_values": [ [{0: batch} for _ in range(num_hidden_layers)], [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 9643dc34..eebd4aa9 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -220,10 +220,7 @@ def get_inputs( 0: batch, 1: "cache+seq", # cache_length + seq_length }, - "position_ids": { - 0: batch, - 1: "cache+seq", # cache_length + seq_length - }, + "position_ids": {0: batch, 1: seq_length}, "past_key_values": [ [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], From 8e8ad095f6ecfee5d2300a4f1163477052e43aae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 23 Oct 2025 13:42:38 +0200 Subject: [PATCH 4/4] fix issues --- onnx_diagnostic/_command_lines_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 415e3572..46a2737a 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1121,7 +1121,7 @@ def main(argv: Optional[List[Any]] = None): validate=get_parser_validate, stats=get_parser_stats, agg=get_parser_agg, - exportsample=lambda: get_parser_validate("exportsample"), + exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator] ) cmd = argv[0] if cmd not in parsers: