6 changes: 3 additions & 3 deletions _doc/technical/plot_broadcast_export_issue.py
@@ -80,8 +80,8 @@ def forward(self, x, y):
# d1 = shape_env.create_unbacked_symint()
# d2 = shape_env.create_unbacked_symint()
fake_inputs = fake_mode.from_tensor(
torch.zeros((2,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((2,), dtype=torch.float32), static_shapes=False)
torch.zeros((3,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False)

print("fake_inputs are ", fake_inputs)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
@@ -115,7 +115,7 @@ def forward(self, x, y):
try:
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
except Exception as e:
print(e)
print("error", e)

# %%
# By applying the patches:
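A minimal, self-contained sketch of the FakeTensorProp pattern this script exercises (not part of the PR; the toy Model below is invented for illustration): fake tensors flow through an FX graph so output shapes can be inferred without running real computation.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.fx.passes.fake_tensor_prop import FakeTensorProp


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y  # broadcasting kicks in when the shapes differ


gm = torch.fx.symbolic_trace(Model())
# A ShapeEnv is required to obtain symbolic (non-static) shapes.
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
fake_inputs = tuple(
    fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False)
    for _ in range(2)
)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print(res)  # a FakeTensor carrying the inferred (symbolic) output shape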
19 changes: 18 additions & 1 deletion _unittests/ut_tasks/test_tasks_text_generation.py
@@ -16,7 +16,7 @@ class TestTasksTextGeneration(ExtTestCase):
@hide_stdout()
@requires_transformers("4.53")
@requires_torch("2.7.99")
def test_image_text_to_text_gemma3_for_causallm(self):
def test_text_generation_gemma3_for_causallm(self):
mid = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
self.assertEqual(data["task"], "text-generation")
@@ -28,6 +28,23 @@ def test_image_text_to_text_gemma3_for_causallm(self):
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)

@hide_stdout()
@requires_transformers("4.53")
@requires_torch("2.7.99")
def test_text_generation_phi_3_mini_128k_instruct(self):
mid = "microsoft/Phi-3-mini-128k-instruct"
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
self.assertEqual(data["task"], "text-generation")
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
print("--", self.string_type(inputs, with_shape=True))
print("--", self.string_type(ds))
model(**torch_deepcopy(inputs))
model(**data["inputs2"])
with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)


if __name__ == "__main__":
unittest.main(verbosity=2)
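Both tests call the model on torch_deepcopy(inputs) before exporting. A plausible rationale, sketched below under the assumption that the inputs hold a mutable KV cache (the toy forward is invented): a warm-up forward pass grows the cache in place, so without a deep copy the export would see already-mutated inputs.

import copy
import torch

inputs = {"past": [torch.zeros(1, 3)]}  # stand-in for a KV cache

def forward(past):
    # decoding steps typically extend the cache along the sequence axis
    past[0] = torch.cat([past[0], torch.zeros(1, 1)], dim=1)

forward(**copy.deepcopy(inputs))  # warm-up call runs on a copy
assert inputs["past"][0].shape == (1, 3)  # the original inputs stay pristine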
110 changes: 60 additions & 50 deletions onnx_diagnostic/_command_lines_parser.py
@@ -371,30 +371,34 @@ def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, d)


def get_parser_validate() -> ArgumentParser:
def get_parser_validate(name: str = "validate") -> ArgumentParser:
parser = ArgumentParser(
prog="validate",
prog=name,
description=textwrap.dedent(
"""
Prints out dummy inputs for a particular task or a model id.
If both mid and task are empty, the command line displays the list
of supported tasks.
Validates a model for a particular task given the model id.
It exports the model and then validates it by computing the discrepancies
on different input sets.
"""
if name == "validate"
else """
Creates a script to export a model for a particular task given the model id.
"""
),
epilog=textwrap.dedent(
"""
f"""
If the model id is specified, one untrained version of it is instantiated.
Examples:

python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
--dtype float16 --device cuda --patch --export onnx-dynamo --opt ir

python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
--dtype float16 --device cuda --patch --export custom --opt default

python -m onnx_diagnostic validate -m microsoft/Phi-4-mini-reasoning \\
python -m onnx_diagnostic {name} -m microsoft/Phi-4-mini-reasoning \\
--run -v 1 -o dump_test --no-quiet --repeat 2 --warmup 2 \\
--dtype float16 --device cuda --export modelbuilder

@@ -405,12 +409,12 @@ def get_parser_validate() -> ArgumentParser:
The behaviour may be modified compared to the original configuration;
the following argument sets rope_scaling to dynamic:

--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\""
--mop \"rope_scaling={{'rope_type': 'dynamic', 'factor': 10.0}}\""

You can profile the command line by running:

pyinstrument -m onnx_diagnostic validate ...
pyinstrument -r html -o profile.html -m onnx_diagnostic validate ...
pyinstrument -m onnx_diagnostic {name} ...
pyinstrument -r html -o profile.html -m onnx_diagnostic {name} ...
"""
),
formatter_class=RawTextHelpFormatter,
@@ -460,19 +464,19 @@ def get_parser_validate() -> ArgumentParser:
"--same-as-trained",
default=False,
action=BooleanOptionalAction,
help="Validates a model identical to the trained model but not trained.",
help="Validates or exports a model identical to the trained model but not trained.",
)
parser.add_argument(
"--trained",
default=False,
action=BooleanOptionalAction,
help="Validates the trained model (requires downloading).",
help="Validates or exports the trained model (requires downloading).",
)
parser.add_argument(
"--inputs2",
default=1,
type=int,
help="Validates the model on a second set of inputs\n"
help="Validates or exports the model on a second set of inputs\n"
"to check the exported model supports dynamism. The values is used "
"as an increment to the first set of inputs. A high value may trick "
"a different behavior in the model and missed by the exporter.",
@@ -504,13 +508,14 @@ def get_parser_validate() -> ArgumentParser:
"--subfolder",
help="Subfolder where to find the model and the configuration.",
)
parser.add_argument(
"--ortfusiontype",
required=False,
help="Applies onnxruntime fusion, this parameter should contain the\n"
"model type or multiple values separated by `|`. `ALL` can be used\n"
"to run them all.",
)
if name == "validate":
parser.add_argument(
"--ortfusiontype",
required=False,
help="Applies onnxruntime fusion, this parameter should contain the\n"
"model type or multiple values separated by `|`. `ALL` can be used\n"
"to run them all.",
)
parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
parser.add_argument("--dtype", help="Changes dtype if necessary.")
parser.add_argument("--device", help="Changes the device if necessary.")
@@ -532,33 +537,38 @@ def get_parser_validate() -> ArgumentParser:
"--mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\"",
action=_ParseDict,
)
parser.add_argument(
"--repeat",
default=1,
type=int,
help="number of times to run the model to measures inference time",
)
parser.add_argument(
"--warmup", default=0, type=int, help="number of times to run the model to do warmup"
)
if name == "validate":
parser.add_argument(
"--repeat",
default=1,
type=int,
help="number of times to run the model to measures inference time",
)
parser.add_argument(
"--warmup",
default=0,
type=int,
help="number of times to run the model to do warmup",
)
parser.add_argument(
"--outnames",
help="This comma separated list defines the output names "
"the onnx exporter should use.",
default="",
)
parser.add_argument(
"--ort-logs",
default=False,
action=BooleanOptionalAction,
help="Enables onnxruntime logging when the session is created",
)
parser.add_argument(
"--quiet-input-sets",
default="",
help="Avoids raising an exception when an input sets does not work with "
"the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
)
if name == "validate":
parser.add_argument(
"--ort-logs",
default=False,
action=BooleanOptionalAction,
help="Enables onnxruntime logging when the session is created",
)
parser.add_argument(
"--quiet-input-sets",
default="",
help="Avoids raising an exception when an input sets does not work with "
"the exported model.\nExample: --quiet-input-sets=inputs,inputs22",
)
return parser


@@ -637,7 +647,7 @@ def _cmd_export_sample(argv: List[Any]):
from .torch_models.code_sample import code_sample
from .tasks import supported_tasks

parser = get_parser_validate()
parser = get_parser_validate("exportsample")
args = parser.parse_args(argv[1:])
if not args.task and not args.mid:
print("-- list of supported tasks:")
@@ -693,16 +703,16 @@ def _cmd_export_sample(argv: List[Any]):
os.makedirs(args.dump_folder, exist_ok=True)
name = (
_make_folder_name(
model_id=args.model_id,
exporter=args.exporter,
optimization=args.optimization,
model_id=args.mid,
exporter=args.export,
optimization=args.opt,
dtype=args.dtype,
device=args.device,
subfolder=args.subfolder,
opset=args.opset,
drop_inputs=None if not args.drop else args.drop.split(","),
same_as_pretrained=args.same_as_pretrained,
use_pretrained=args.use_pretrained,
same_as_pretrained=args.same_as_trained,
use_pretrained=args.trained,
task=args.task,
).replace("/", "-")
+ ".py"
@@ -1111,7 +1121,7 @@ def main(argv: Optional[List[Any]] = None):
validate=get_parser_validate,
stats=get_parser_stats,
agg=get_parser_agg,
exportsample=get_parser_validate,
exportsample=lambda: get_parser_validate("exportsample"), # type: ignore[operator]
)
cmd = argv[0]
if cmd not in parsers:
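A hypothetical usage sketch of the refactored factory (the module path is taken from this file; the -m and -v spellings follow the epilog examples, the rest is assumed): one parser definition now serves both subcommands, and the validation-only options simply do not exist on the exportsample variant.

from onnx_diagnostic._command_lines_parser import get_parser_validate

validate_parser = get_parser_validate()  # full option set
export_parser = get_parser_validate("exportsample")  # no --ortfusiontype, --repeat, --warmup, ...

# Shared options parse the same way for both subcommands.
args = export_parser.parse_args(["-m", "microsoft/Phi-4-mini-reasoning", "-v", "1"])
print(args.mid, args.verbose)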
2 changes: 1 addition & 1 deletion onnx_diagnostic/tasks/image_text_to_text.py
@@ -271,7 +271,7 @@ def get_inputs_default(
"input_ids": {0: batch, 1: seq_length},
"token_type_ids": {0: batch, 1: seq_length},
"attention_mask": {0: batch, 1: "cache+seq"},
"position_ids": {0: batch, 1: "cache+seq"},
"position_ids": {0: batch, 1: seq_length},
"past_key_values": [
[{0: batch} for _ in range(num_hidden_layers)],
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
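The fix pins the second axis of position_ids to seq_length instead of cache+seq; the same correction appears in text_generation below. An illustrative shape check (not from the PR): during a decode step the attention mask spans cache plus new tokens, while position ids only cover the new tokens.

import torch

batch, cache_len, seq_len = 2, 6, 1  # one new token per decode step
attention_mask = torch.ones((batch, cache_len + seq_len))
position_ids = torch.arange(cache_len, cache_len + seq_len).expand(batch, -1)
assert attention_mask.shape == (batch, cache_len + seq_len)
assert position_ids.shape == (batch, seq_len)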
5 changes: 1 addition & 4 deletions onnx_diagnostic/tasks/text_generation.py
@@ -220,10 +220,7 @@ def get_inputs(
0: batch,
1: "cache+seq", # cache_length + seq_length
},
"position_ids": {
0: batch,
1: "cache+seq", # cache_length + seq_length
},
"position_ids": {0: batch, 1: seq_length},
"past_key_values": [
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
@@ -1312,6 +1312,10 @@ def patched_sdpa_attention_forward(
# is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
is_causal = attention_mask is None and is_causal

torch._check(
attention_mask is None or attention_mask.shape[3] == key.shape[2],
"Attention mask shape incompatible with key shape.",
)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
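A standalone sketch of the added guard (the wrapper below is invented for illustration; torch._check evaluates its message lazily, only when the condition fails): asserting that the mask's last dimension equals the key's sequence length both documents the invariant and hands the exporter a fact it can rely on when the shapes are symbolic.

import torch

def guarded_sdpa(query, key, value, attention_mask=None, is_causal=False):
    torch._check(
        attention_mask is None or attention_mask.shape[3] == key.shape[2],
        lambda: "Attention mask shape incompatible with key shape.",
    )
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=is_causal
    )

q = torch.randn(1, 2, 4, 8)  # (batch, heads, query_len, head_dim)
k = torch.randn(1, 2, 5, 8)  # key_len = 5
v = torch.randn(1, 2, 5, 8)
mask = torch.ones(1, 1, 4, 5, dtype=torch.bool)  # last dim matches key_len
print(guarded_sdpa(q, k, v, mask).shape)  # torch.Size([1, 2, 4, 8])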