From 509cfe405f2bf319586ea6c1f46d94d65eb145de Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 22 Jul 2025 09:48:31 +0200 Subject: [PATCH 1/4] Add a unit test about an issue --- _unittests/ut_export/test_issue_2025.py | 54 +++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 _unittests/ut_export/test_issue_2025.py diff --git a/_unittests/ut_export/test_issue_2025.py b/_unittests/ut_export/test_issue_2025.py new file mode 100644 index 00000000..18e5c595 --- /dev/null +++ b/_unittests/ut_export/test_issue_2025.py @@ -0,0 +1,54 @@ +import unittest +import numpy as np +import torch +from onnx_diagnostic.ext_test_case import ExtTestCase + + +class TestIssues2025(ExtTestCase): + def test_issue_158786_qwen2vl(self): + # https://github.com/pytorch/pytorch/issues/158786 + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.spatial_merge_size = 2 # Default + + def forward(self, a): + pos_ids = [] + for t, h, w in a: + t = t.item() + h = h.item() + w = w.item() + torch._constrain_as_size(t) + torch._constrain_as_size(h) + torch._constrain_as_size(w) + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + return pos_ids + + model = Model() + inputs = torch.tensor(np.array([1, 98, 146]).reshape(1, 3)) + ep = torch.export.export(model, (inputs,)) + self.assertIn("torch.ops.aten.cat.default", str(ep)) + + +if __name__ == "__main__": + unittest.main(verbosity=2) From ef2d94b99e3899578cdebdb54797550fcfcbf789 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 22 Jul 2025 10:04:49 +0200 Subject: [PATCH 2/4] fix --- _unittests/ut_export/test_issue_2025.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_export/test_issue_2025.py b/_unittests/ut_export/test_issue_2025.py index 18e5c595..d1f48fac 100644 --- a/_unittests/ut_export/test_issue_2025.py +++ b/_unittests/ut_export/test_issue_2025.py @@ -1,10 +1,11 @@ import unittest import numpy as np import torch -from onnx_diagnostic.ext_test_case import ExtTestCase +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch class TestIssues2025(ExtTestCase): + @requires_torch("2.8") def test_issue_158786_qwen2vl(self): # https://github.com/pytorch/pytorch/issues/158786 class Model(torch.nn.Module): From 15cd008a5fc5b8047261814fd7467cac9a4ccfaf Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 22 Jul 2025 10:13:05 +0200 Subject: [PATCH 3/4] fix mambacache import --- .../test_onnx_export_errors.py | 6 +++++- onnx_diagnostic/helpers/cache_helper.py | 15 +++++++++------ onnx_diagnostic/tasks/text_generation.py | 8 ++++++-- .../onnx_export_serialization.py | 6 +++++- .../serialization/transformers_impl.py | 6 +++++- 5 files changed, 30 insertions(+), 11 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py index b9adbc7d..5935928c 100644 --- a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py +++ b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py @@ -22,7 +22,11 @@ class TestOnnxExportErrors(ExtTestCase): def test_pytree_flatten_mamba_cache(self): import torch import torch.utils._pytree as py_pytree - from transformers.cache_utils import MambaCache + + try: + from transformers.models.mamba.cache_mamba import MambaCache + except ImportError: + from transformers.cache_utils import MambaCache class _config: def __init__(self): diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 820983a4..3afbd46b 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -4,6 +4,11 @@ import transformers import transformers.cache_utils +try: + from transformers.models.mamba.cache_mamba import MambaCache +except ImportError: + from transformers.cache_utils import MambaCache + def flatten_unflatten_for_dynamic_shapes( obj: Any, @@ -242,10 +247,8 @@ def make_encoder_decoder_cache( ) -def make_mamba_cache( - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], -) -> transformers.cache_utils.MambaCache: - "Creates a :class:`transformers.cache_utils.MambaCache`." +def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache: + "Creates a ``MambaCache``." dtype = key_value_pairs[0][0].dtype class _config: @@ -256,7 +259,7 @@ def __init__(self): self.num_hidden_layers = len(key_value_pairs) self.dtype = dtype - cache = transformers.cache_utils.MambaCache( + cache = MambaCache( _config(), max_batch_size=key_value_pairs[0][0].shape[0], device=key_value_pairs[0][0].device, @@ -286,7 +289,7 @@ def __init__(self): def make_sliding_window_cache( key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], -) -> transformers.cache_utils.MambaCache: +) -> transformers.cache_utils.SlidingWindowCache: "Creates a :class:`transformers.cache_utils.SlidingWindowCache`." class _config: diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 599062bc..403350d8 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -1,6 +1,5 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union import torch -import transformers from ..helpers.cache_helper import ( make_dynamic_cache, make_mamba_cache, @@ -95,9 +94,14 @@ def get_inputs( cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096) if config is not None and config.__class__.__name__ == "FalconMambaConfig": + try: + from transformers.models.mamba.cache_mamba import MambaCache + except ImportError: + from transformers.cache_utils import MambaCache + assert cls_cache in ( "MambaCache", - transformers.cache_utils.MambaCache, + MambaCache, ), f"Unexpected value for cls_cache={cls_cache} and config={config}" seq_length_multiple = 8 sequence_length = ( diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 4c4d2507..f7c5e561 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -6,12 +6,16 @@ import transformers from transformers.cache_utils import ( DynamicCache, - MambaCache, EncoderDecoderCache, SlidingWindowCache, StaticCache, ) +try: + from transformers.models.mamba.cache_mamba import MambaCache +except ImportError: + from transformers.cache_utils import MambaCache + from ..helpers import string_type from .serialization import _lower_name_with_ diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py index 3b2dc899..96f29b71 100644 --- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py @@ -3,11 +3,15 @@ import transformers from transformers.cache_utils import ( DynamicCache, - MambaCache, EncoderDecoderCache, SlidingWindowCache, StaticCache, ) + +try: + from transformers.models.mamba.cache_mamba import MambaCache +except ImportError: + from transformers.cache_utils import MambaCache from transformers.modeling_outputs import BaseModelOutput from ...helpers.cache_helper import make_static_cache from . import make_serialization_function_for_dataclass From 244c5efa839fac5bdf85fe4fee5c79f6af780d3a Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 22 Jul 2025 10:17:00 +0200 Subject: [PATCH 4/4] fix import issues --- _unittests/ut_torch_export_patches/test_onnx_export_errors.py | 2 +- onnx_diagnostic/helpers/cache_helper.py | 2 +- onnx_diagnostic/tasks/text_generation.py | 2 +- .../torch_export_patches/onnx_export_serialization.py | 2 +- .../torch_export_patches/serialization/transformers_impl.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py index 5935928c..862848d9 100644 --- a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py +++ b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py @@ -24,7 +24,7 @@ def test_pytree_flatten_mamba_cache(self): import torch.utils._pytree as py_pytree try: - from transformers.models.mamba.cache_mamba import MambaCache + from transformers.models.mamba.modeling_mamba import MambaCache except ImportError: from transformers.cache_utils import MambaCache diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 3afbd46b..838d4dc4 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -5,7 +5,7 @@ import transformers.cache_utils try: - from transformers.models.mamba.cache_mamba import MambaCache + from transformers.models.mamba.modeling_mamba import MambaCache except ImportError: from transformers.cache_utils import MambaCache diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 403350d8..1a617bdd 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -95,7 +95,7 @@ def get_inputs( if config is not None and config.__class__.__name__ == "FalconMambaConfig": try: - from transformers.models.mamba.cache_mamba import MambaCache + from transformers.models.mamba.modeling_mamba import MambaCache except ImportError: from transformers.cache_utils import MambaCache diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index f7c5e561..ca8b66f4 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -12,7 +12,7 @@ ) try: - from transformers.models.mamba.cache_mamba import MambaCache + from transformers.models.mamba.modeling_mamba import MambaCache except ImportError: from transformers.cache_utils import MambaCache diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py index 96f29b71..4507d14c 100644 --- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py @@ -9,7 +9,7 @@ ) try: - from transformers.models.mamba.cache_mamba import MambaCache + from transformers.models.mamba.modeling_mamba import MambaCache except ImportError: from transformers.cache_utils import MambaCache from transformers.modeling_outputs import BaseModelOutput