55 changes: 55 additions & 0 deletions _unittests/ut_export/test_issue_2025.py
@@ -0,0 +1,55 @@
import unittest
import numpy as np
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch


class TestIssues2025(ExtTestCase):
    @requires_torch("2.8")
    def test_issue_158786_qwen2vl(self):
        # https://github.com/pytorch/pytorch/issues/158786
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.spatial_merge_size = 2  # Default

            def forward(self, a):
                pos_ids = []
                for t, h, w in a:
                    t = t.item()
                    h = h.item()
                    w = w.item()
                    torch._constrain_as_size(t)
                    torch._constrain_as_size(h)
                    torch._constrain_as_size(w)
                    hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
                    hpos_ids = hpos_ids.reshape(
                        h // self.spatial_merge_size,
                        self.spatial_merge_size,
                        w // self.spatial_merge_size,
                        self.spatial_merge_size,
                    )
                    hpos_ids = hpos_ids.permute(0, 2, 1, 3)
                    hpos_ids = hpos_ids.flatten()

                    wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
                    wpos_ids = wpos_ids.reshape(
                        h // self.spatial_merge_size,
                        self.spatial_merge_size,
                        w // self.spatial_merge_size,
                        self.spatial_merge_size,
                    )
                    wpos_ids = wpos_ids.permute(0, 2, 1, 3)
                    wpos_ids = wpos_ids.flatten()
                    pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
                pos_ids = torch.cat(pos_ids, dim=0)
                return pos_ids

        model = Model()
        inputs = torch.tensor(np.array([1, 98, 146]).reshape(1, 3))
        ep = torch.export.export(model, (inputs,))
        self.assertIn("torch.ops.aten.cat.default", str(ep))


if __name__ == "__main__":
    unittest.main(verbosity=2)
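The test above revolves around data-dependent shapes: each of `t`, `h`, `w` comes from `.item()`, so the exporter sees them as unbacked symbolic integers, and `torch._constrain_as_size` tells it they are valid sizes. For context, the same pattern can be reproduced in isolation with `torch._check_is_size`, the spelling recent PyTorch documentation favors over the underscore-deprecated `_constrain_as_size`. The snippet below is a hypothetical minimal sketch, not part of this PR:

import torch

class TinyModel(torch.nn.Module):
    def forward(self, sizes):
        n = sizes[0].item()      # becomes an unbacked SymInt during export
        torch._check_is_size(n)  # hint to the exporter: n is a valid, non-negative size
        return torch.arange(n).sum()

ep = torch.export.export(TinyModel(), (torch.tensor([5]),))
print(ep)  # the graph keeps arange with a symbolic end instead of specializing to 5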
@@ -22,7 +22,11 @@ class TestOnnxExportErrors(ExtTestCase):
     def test_pytree_flatten_mamba_cache(self):
         import torch
         import torch.utils._pytree as py_pytree
-        from transformers.cache_utils import MambaCache
+
+        try:
+            from transformers.models.mamba.modeling_mamba import MambaCache
+        except ImportError:
+            from transformers.cache_utils import MambaCache

         class _config:
             def __init__(self):
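Every file touched below repeats the same guarded import: newer transformers releases moved `MambaCache` out of `transformers.cache_utils` and next to the Mamba model in `transformers.models.mamba.modeling_mamba`, so trying the new location first and falling back keeps both old and new versions working. If the pattern keeps spreading, it could be factored into a single compatibility module; a hypothetical sketch, not part of this PR:

# compat.py -- hypothetical one-stop import for MambaCache
try:
    # newer transformers: MambaCache lives next to the Mamba model
    from transformers.models.mamba.modeling_mamba import MambaCache
except ImportError:
    # older transformers: it is still exposed in cache_utils
    from transformers.cache_utils import MambaCache

__all__ = ["MambaCache"]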
15 changes: 9 additions & 6 deletions onnx_diagnostic/helpers/cache_helper.py
@@ -4,6 +4,11 @@
 import transformers
 import transformers.cache_utils

+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
+

 def flatten_unflatten_for_dynamic_shapes(
     obj: Any,
@@ -242,10 +247,8 @@ def make_encoder_decoder_cache(
     )


-def make_mamba_cache(
-    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
-) -> transformers.cache_utils.MambaCache:
-    "Creates a :class:`transformers.cache_utils.MambaCache`."
+def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
+    "Creates a ``MambaCache``."
     dtype = key_value_pairs[0][0].dtype

     class _config:
@@ -256,7 +259,7 @@ def __init__(self):
            self.num_hidden_layers = len(key_value_pairs)
            self.dtype = dtype

-    cache = MambaCache(
+    cache = MambaCache(
        _config(),
        max_batch_size=key_value_pairs[0][0].shape[0],
        device=key_value_pairs[0][0].device,
@@ -286,7 +289,7 @@ def __init__(self):

 def make_sliding_window_cache(
     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
-) -> transformers.cache_utils.MambaCache:
+) -> transformers.cache_utils.SlidingWindowCache:
     "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."

     class _config:
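With the change above, `make_mamba_cache` returns whichever `MambaCache` class the installed transformers provides, inferring batch size, device, dtype, and layer count from the tensors it receives. A hedged usage sketch follows; the shapes are purely illustrative, and the per-layer pair is assumed to be (conv_state, ssm_state):

import torch
from onnx_diagnostic.helpers.cache_helper import make_mamba_cache

# Illustrative per-layer (conv_state, ssm_state) pairs: batch=2, 3 layers.
pairs = [(torch.randn(2, 16, 4), torch.randn(2, 16, 8)) for _ in range(3)]
cache = make_mamba_cache(pairs)
print(type(cache).__name__)  # MambaCache, from whichever module is available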
8 changes: 6 additions & 2 deletions onnx_diagnostic/tasks/text_generation.py
@@ -1,6 +1,5 @@
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import torch
-import transformers
 from ..helpers.cache_helper import (
     make_dynamic_cache,
     make_mamba_cache,
@@ -95,9 +94,14 @@ def get_inputs(
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)

     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
+        try:
+            from transformers.models.mamba.modeling_mamba import MambaCache
+        except ImportError:
+            from transformers.cache_utils import MambaCache
+
         assert cls_cache in (
             "MambaCache",
-            transformers.cache_utils.MambaCache,
+            MambaCache,
         ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
         seq_length_multiple = 8
         sequence_length = (
@@ -6,12 +6,16 @@
 import transformers
 from transformers.cache_utils import (
     DynamicCache,
-    MambaCache,
     EncoderDecoderCache,
     SlidingWindowCache,
     StaticCache,
 )

+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
+
 from ..helpers import string_type
 from .serialization import _lower_name_with_

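The two serialization modules in this PR exist so that cache objects survive the pytree flattening `torch.export` performs on inputs and outputs. For readers unfamiliar with the mechanism, here is a generic, hypothetical sketch of registering a cache class with torch's pytree; the real registration in this repository goes through its own serialization helpers (`_lower_name_with_`, `make_serialization_function_for_dataclass`):

import torch.utils._pytree as py_pytree

try:
    from transformers.models.mamba.modeling_mamba import MambaCache
except ImportError:
    from transformers.cache_utils import MambaCache

def _flatten_mamba_cache(cache):
    # children must be tensors (or nested pytrees); context carries the rest
    return [cache.conv_states, cache.ssm_states], None

def _unflatten_mamba_cache(children, context):
    conv_states, ssm_states = children
    cache = MambaCache.__new__(MambaCache)  # illustrative only: skip __init__
    cache.conv_states, cache.ssm_states = conv_states, ssm_states
    return cache

py_pytree.register_pytree_node(
    MambaCache,
    _flatten_mamba_cache,
    _unflatten_mamba_cache,
    serialized_type_name="MambaCache",
)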
@@ -3,11 +3,15 @@
 import transformers
 from transformers.cache_utils import (
     DynamicCache,
-    MambaCache,
     EncoderDecoderCache,
     SlidingWindowCache,
     StaticCache,
 )

+try:
+    from transformers.models.mamba.modeling_mamba import MambaCache
+except ImportError:
+    from transformers.cache_utils import MambaCache
 from transformers.modeling_outputs import BaseModelOutput
 from ...helpers.cache_helper import make_static_cache
 from . import make_serialization_function_for_dataclass