# Patch summary (commentary only; everything before the first "diff --git"
# header is not part of any hunk).
#
# 1. Adds _unittests/ut_export/test_issue_2025.py: TestIssues2025 exports a
#    small module with torch.export.export and asserts that
#    "torch.ops.aten.cat.default" appears in the printed exported program.
#    The test is gated by requires_torch("2.8") and points at
#    https://github.com/pytorch/pytorch/issues/158786.
# 2. In five places (test_onnx_export_errors.py, cache_helper.py,
#    text_generation.py, onnx_export_serialization.py, transformers_impl.py)
#    MambaCache is now imported from transformers.models.mamba.modeling_mamba,
#    falling back to transformers.cache_utils on ImportError.
#    NOTE(review): presumably needed because some transformers releases expose
#    the class in only one of the two modules - confirm the version boundary.
# 3. cache_helper.py: make_mamba_cache uses the imported MambaCache name, and
#    the return annotation of make_sliding_window_cache is corrected from
#    MambaCache to SlidingWindowCache.
# 4. text_generation.py: the module-level "import transformers" is removed and
#    the MambaCache import happens inside the FalconMambaConfig branch.
#
# NOTE(review): the same try/except import is repeated in each file; consider
# importing MambaCache from helpers.cache_helper instead - confirm there is no
# import cycle first.
# NOTE(review): line breaks, indentation and blank context lines below were
# reconstructed from the hunk line counts; verify against the original commit
# before applying.
diff --git a/_unittests/ut_export/test_issue_2025.py b/_unittests/ut_export/test_issue_2025.py
new file mode 100644
index 00000000..d1f48fac
--- /dev/null
+++ b/_unittests/ut_export/test_issue_2025.py
@@ -0,0 +1,55 @@
+import unittest
+import numpy as np
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
+
+
+class TestIssues2025(ExtTestCase):
+    @requires_torch("2.8")
+    def test_issue_158786_qwen2vl(self):
+        # https://github.com/pytorch/pytorch/issues/158786
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.spatial_merge_size = 2  # Default
+
+            def forward(self, a):
+                pos_ids = []
+                for t, h, w in a:
+                    t = t.item()
+                    h = h.item()
+                    w = w.item()
+                    torch._constrain_as_size(t)
+                    torch._constrain_as_size(h)
+                    torch._constrain_as_size(w)
+                    hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+                    hpos_ids = hpos_ids.reshape(
+                        h // self.spatial_merge_size,
+                        self.spatial_merge_size,
+                        w // self.spatial_merge_size,
+                        self.spatial_merge_size,
+                    )
+                    hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+                    hpos_ids = hpos_ids.flatten()
+
+                    wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+                    wpos_ids = wpos_ids.reshape(
+                        h // self.spatial_merge_size,
+                        self.spatial_merge_size,
+                        w // self.spatial_merge_size,
+                        self.spatial_merge_size,
+                    )
+                    wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+                    wpos_ids = wpos_ids.flatten()
+                    pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+                pos_ids = torch.cat(pos_ids, dim=0)
+                return pos_ids
+
+        model = Model()
+        inputs = torch.tensor(np.array([1, 98, 146]).reshape(1, 3))
+        ep = torch.export.export(model, (inputs,))
+        self.assertIn("torch.ops.aten.cat.default", str(ep))
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py
index b9adbc7d..862848d9 100644
--- a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py
+++ b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py
@@ -22,7 +22,11 @@ class TestOnnxExportErrors(ExtTestCase):
     def test_pytree_flatten_mamba_cache(self):
         import torch
         import torch.utils._pytree as py_pytree
-        from transformers.cache_utils import MambaCache
+
+        try:
+            from transformers.models.mamba.modeling_mamba import MambaCache
+        except ImportError:
+            from transformers.cache_utils import MambaCache
 
         class _config:
             def __init__(self):
diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py
index 820983a4..838d4dc4 100644
--- a/onnx_diagnostic/helpers/cache_helper.py
+++ b/onnx_diagnostic/helpers/cache_helper.py
@@ -4,6 +4,11 @@
 import transformers
 import transformers.cache_utils
 
+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
+
 
 def flatten_unflatten_for_dynamic_shapes(
     obj: Any,
@@ -242,10 +247,8 @@ def make_encoder_decoder_cache(
     )
 
 
-def make_mamba_cache(
-    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
-) -> transformers.cache_utils.MambaCache:
-    "Creates a :class:`transformers.cache_utils.MambaCache`."
+def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
+    "Creates a ``MambaCache``."
     dtype = key_value_pairs[0][0].dtype
 
     class _config:
@@ -256,7 +259,7 @@ def __init__(self):
             self.num_hidden_layers = len(key_value_pairs)
             self.dtype = dtype
 
-    cache = transformers.cache_utils.MambaCache(
+    cache = MambaCache(
         _config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         device=key_value_pairs[0][0].device,
@@ -286,7 +289,7 @@ def __init__(self):
 
 def make_sliding_window_cache(
     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
-) -> transformers.cache_utils.MambaCache:
+) -> transformers.cache_utils.SlidingWindowCache:
     "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
 
     class _config:
diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py
index 599062bc..1a617bdd 100644
--- a/onnx_diagnostic/tasks/text_generation.py
+++ b/onnx_diagnostic/tasks/text_generation.py
@@ -1,6 +1,5 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import torch
-import transformers
 from ..helpers.cache_helper import (
     make_dynamic_cache,
     make_mamba_cache,
@@ -95,9 +94,14 @@ def get_inputs(
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
 
     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
+        try:
+            from transformers.models.mamba.modeling_mamba import MambaCache
+        except ImportError:
+            from transformers.cache_utils import MambaCache
+
         assert cls_cache in (
             "MambaCache",
-            transformers.cache_utils.MambaCache,
+            MambaCache,
         ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
         seq_length_multiple = 8
         sequence_length = (
diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
index 4c4d2507..ca8b66f4 100644
--- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
+++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py
@@ -6,12 +6,16 @@
 import transformers
 from transformers.cache_utils import (
     DynamicCache,
-    MambaCache,
     EncoderDecoderCache,
     SlidingWindowCache,
     StaticCache,
 )
 
+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
+
 from ..helpers import string_type
 from .serialization import _lower_name_with_
 
diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
index 3b2dc899..4507d14c 100644
--- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
+++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
@@ -3,11 +3,15 @@
 import transformers
 from transformers.cache_utils import (
     DynamicCache,
-    MambaCache,
     EncoderDecoderCache,
     SlidingWindowCache,
     StaticCache,
 )
+
+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
 from transformers.modeling_outputs import BaseModelOutput
 from ...helpers.cache_helper import make_static_cache
 from . import make_serialization_function_for_dataclass