6 changes: 6 additions & 0 deletions CHANGELOGS.rst
@@ -1,6 +1,12 @@
Change Logs
===========

0.7.12
++++++

* :pr:`226`: fix input order for models created with modelbuilder


0.7.11
++++++

1 change: 1 addition & 0 deletions _doc/conf.py
@@ -122,6 +122,7 @@ def linkcode_resolve(domain, info):
("py:class", "CacheProcessor"),
("py:class", "default=sklearn.utils.metadata_routing.UNCHANGED"),
("py:class", "diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput"),
("py:class", "MambaCache"),
("py:class", "ModelProto"),
("py:class", "Model"),
("py:class", "Module"),
3 changes: 1 addition & 2 deletions _doc/index.rst
@@ -239,9 +239,8 @@ The function replaces dynamic dimensions defined as strings by
Older versions
==============

* `0.7.12 <../v0.7.12/index.html>`_
* `0.7.11 <../v0.7.11/index.html>`_
* `0.7.10 <../v0.7.10/index.html>`_
* `0.7.9 <../v0.7.9/index.html>`_
* `0.6.3 <../v0.6.3/index.html>`_
* `0.5.0 <../v0.5.0/index.html>`_
* `0.4.4 <../v0.4.4/index.html>`_
5 changes: 3 additions & 2 deletions _unittests/ut_torch_models/test_validate_whole_models.py
@@ -195,17 +195,18 @@ def test_k_filter_inputs(self):
@ignore_warnings(FutureWarning)
@requires_transformers("4.51")
def test_l_validate_model_modelbuilder(self):
mid = "meta-llama/Llama-2-7b-hf"
mid = "microsoft/phi-2"
summary, data = validate_model(
mid,
do_run=True,
verbose=10,
exporter="modelbuilder",
dump_folder="dump_test/validate_model_modelbuilder",
patch=True,
)
self.assertIsInstance(summary, dict)
self.assertIsInstance(data, dict)
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-2)
self.assertLess(summary["disc_onnx_ort_run_abs"], 3e-2)
onnx_filename = data["onnx_filename"]
self.assertExists(onnx_filename)

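For context, a minimal sketch of what the updated test now exercises: validating "microsoft/phi-2" through the ModelBuilder exporter and checking the ONNX/onnxruntime discrepancy against the relaxed 3e-2 tolerance. The arguments mirror the diff above; the import path is an assumption and not taken from this PR.

from onnx_diagnostic.torch_models.validate import validate_model  # assumed import path

summary, data = validate_model(
    "microsoft/phi-2",          # model id used by the updated test
    do_run=True,                # run the exported model to measure discrepancies
    verbose=10,
    exporter="modelbuilder",
    dump_folder="dump_test/validate_model_modelbuilder",
    patch=True,                 # mirrors the argument added by this PR
)
# the test now allows a slightly larger absolute discrepancy for phi-2
assert summary["disc_onnx_ort_run_abs"] < 3e-2
print(data["onnx_filename"])    # path of the exported ONNX model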
2 changes: 1 addition & 1 deletion onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@
Functions, classes to dig into a model when this one is right, slow, wrong...
"""

__version__ = "0.7.11"
__version__ = "0.7.12"
__author__ = "Xavier Dupré"
24 changes: 14 additions & 10 deletions onnx_diagnostic/helpers/rt_helper.py
@@ -41,7 +41,7 @@ def make_feeds(
"""
# NOTE: position_ids is a special case: ModelBuilder does not usually use it
# because it is fused into rotary embedding in GQA.
if isinstance(inputs, dict):
if is_modelbuilder and isinstance(inputs, dict):
inputs.pop("position_ids", None)  # Remove 'position_ids' if present; no error if absent.

flat = flatten_object(inputs, drop_keys=True)
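In other words, dropping position_ids is now guarded by is_modelbuilder, since only ModelBuilder fuses position handling into the GQA/rotary-embedding nodes; other exporters still expect the input. A small hypothetical illustration of the guarded behaviour (the tensors and the standalone flag are made up for the example; in make_feeds the flag is presumably a parameter, the signature being truncated in this diff):

import torch

# hypothetical feed dict mirroring the inputs passed to make_feeds
inputs = {
    "input_ids": torch.randint(0, 1000, (1, 8)),
    "attention_mask": torch.ones((1, 8), dtype=torch.int64),
    "position_ids": torch.arange(8).unsqueeze(0),
}
is_modelbuilder = True  # mirrors the flag checked in the hunk above

# only ModelBuilder models drop position_ids; other exporters keep it
if is_modelbuilder and isinstance(inputs, dict):
    inputs.pop("position_ids", None)  # no error if the key is absent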
@@ -112,19 +112,23 @@ def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]:
Reorders the past_kvs for ModelBuilder to match the expected order
by PyTorch exported models.

NOTE: This function can take either the names or the actual tensors
as long as they are in a list.
.. note::
This function can take either the names or the actual tensors
as long as they are in a list.

Conceptually,

From:
[past_key_values.0.key, past_key_values.0.value,
past_key_values.1.key, past_key_values.1.value, ...]
To:
[past_key_values.0.key, past_key_values.1.key,
..., past_key_values.0.value, past_key_values.1.value, ...]
From::

[past_key_values.0.key, past_key_values.0.value,
past_key_values.1.key, past_key_values.1.value, ...]

To::

[past_key_values.0.key, past_key_values.1.key,
..., past_key_values.0.value, past_key_values.1.value, ...]

:param flat: list of flattened inputs
:param past_kv: list of flattened inputs
:return: reordered list of flattened inputs
"""
total_len = len(past_kv)
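The function body is truncated in this diff, but the reordering the docstring describes is simple to illustrate. Below is a minimal, hypothetical re-implementation (reorder_interleaved_kv is a made-up name, not the library's API), assuming past_kv interleaves key and value per layer:

from typing import Any, List

def reorder_interleaved_kv(past_kv: List[Any]) -> List[Any]:
    # ModelBuilder order:  [k0, v0, k1, v1, ...]
    # torch-export order:  [k0, k1, ..., v0, v1, ...]
    keys = past_kv[0::2]
    values = past_kv[1::2]
    return keys + values

names = [
    "past_key_values.0.key", "past_key_values.0.value",
    "past_key_values.1.key", "past_key_values.1.value",
]
print(reorder_interleaved_kv(names))
# ['past_key_values.0.key', 'past_key_values.1.key',
#  'past_key_values.0.value', 'past_key_values.1.value']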