diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index 962398f2..8eb025ff 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ b/_unittests/ut_export/test_dynamic_shapes.py @@ -459,21 +459,31 @@ def test_couple_input_ds_0(self): T3x4 = torch.rand((3, 4)) T3x1 = torch.rand((3, 1)) Cls = CoupleInputsDynamicShapes - self.assertEmpty(Cls((T3x4,), {}, ({0: "batch"},)).invalid_paths()) - self.assertEmpty(Cls((T3x1,), {}, ({0: "batch"},)).invalid_paths()) - self.assertEmpty(Cls((), {"A": T3x1}, {"A": {0: "batch"}}).invalid_paths()) - self.assertEmpty(Cls((), {"A": T3x4}, {"A": {0: "batch"}}).invalid_paths()) + self.assertEmpty(Cls((T3x4,), {}, ({0: "batch"},)).invalid_dimensions_for_export()) + self.assertEmpty(Cls((T3x1,), {}, ({0: "batch"},)).invalid_dimensions_for_export()) + self.assertEmpty( + Cls((), {"A": T3x1}, {"A": {0: "batch"}}).invalid_dimensions_for_export() + ) + self.assertEmpty( + Cls((), {"A": T3x4}, {"A": {0: "batch"}}).invalid_dimensions_for_export() + ) T1x4 = torch.rand((1, 4)) T1x1 = torch.rand((1, 1)) Cls = CoupleInputsDynamicShapes - self.assertEqual([(0, "[0]")], Cls((T1x4,), {}, ({0: "batch"},)).invalid_paths()) - self.assertEqual([(0, "[0]")], Cls((T1x1,), {}, ({0: "batch"},)).invalid_paths()) self.assertEqual( - [("A", "[0]")], Cls((), {"A": T1x1}, {"A": {0: "batch"}}).invalid_paths() + ({0: "d=[1]"},), Cls((T1x4,), {}, ({0: "batch"},)).invalid_dimensions_for_export() + ) + self.assertEqual( + ({0: "d=[1]"},), Cls((T1x1,), {}, ({0: "batch"},)).invalid_dimensions_for_export() + ) + self.assertEqual( + {"A": {0: "d=[1]"}}, + Cls((), {"A": T1x1}, {"A": {0: "batch"}}).invalid_dimensions_for_export(), ) self.assertEqual( - [("A", "[0]")], Cls((), {"A": T1x4}, {"A": {0: "batch"}}).invalid_paths() + {"A": {0: "d=[1]"}}, + Cls((), {"A": T1x4}, {"A": {0: "batch"}}).invalid_dimensions_for_export(), ) def test_couple_input_ds_1(self): @@ -483,8 +493,13 @@ def test_couple_input_ds_1(self): ds_batch_seq = {0: "batch", 1: "seq"} args = (T3x4, T3x1) Cls = CoupleInputsDynamicShapes - self.assertEqual([], Cls(args, {}, (ds_batch, ds_batch)).invalid_paths()) - self.assertEqual([(1, "[1]")], Cls(args, {}, (ds_batch, ds_batch_seq)).invalid_paths()) + self.assertEqual( + None, Cls(args, {}, (ds_batch, ds_batch)).invalid_dimensions_for_export() + ) + self.assertEqual( + (None, {1: "d=[1]"}), + Cls(args, {}, (ds_batch, ds_batch_seq)).invalid_dimensions_for_export(), + ) def test_couple_input_ds_2(self): T3x1 = torch.rand((3, 1)) @@ -493,9 +508,15 @@ def test_couple_input_ds_2(self): ds_batch_seq = {0: "batch", 1: "seq"} kwargs = {"A": T3x4, "B": T3x1} Cls = CoupleInputsDynamicShapes - self.assertEqual([], Cls((), kwargs, {"A": ds_batch, "B": ds_batch}).invalid_paths()) self.assertEqual( - [("B", "[1]")], Cls((), kwargs, {"A": ds_batch, "B": ds_batch_seq}).invalid_paths() + None, + Cls((), kwargs, {"A": ds_batch, "B": ds_batch}).invalid_dimensions_for_export(), + ) + self.assertEqual( + {"B": {1: "d=[1]"}}, + Cls( + (), kwargs, {"A": ds_batch, "B": ds_batch_seq} + ).invalid_dimensions_for_export(), ) def test_couple_input_ds_3(self): @@ -506,11 +527,16 @@ def test_couple_input_ds_3(self): kwargs = {"A": T3x4, "B": (T3x1, T3x1)} Cls = CoupleInputsDynamicShapes self.assertEqual( - [], Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths() + None, + Cls( + (), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)} + ).invalid_dimensions_for_export(), ) self.assertEqual( - [("B", 1, "[1]")], - Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(), + {"B": (None, {1: "d=[1]"})}, + Cls( + (), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} + ).invalid_dimensions_for_export(), ) def test_couple_input_ds_cache(self): @@ -532,15 +558,15 @@ def test_couple_input_ds_cache(self): Cls = CoupleInputsDynamicShapes with bypass_export_some_errors(patch_transformers=True): self.assertEqual( - [], + None, Cls( (), kwargs, {"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])}, - ).invalid_paths(), + ).invalid_dimensions_for_export(), ) self.assertEqual( - [("B", 1, "DynamicCache", 1, "[2]"), ("B", 1, "DynamicCache", 3, "[2]")], + {"B": (None, [None, {2: "d=[1]"}, None, {2: "d=[1]"}])}, Cls( (), kwargs, @@ -548,7 +574,7 @@ def test_couple_input_ds_cache(self): "A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]), }, - ).invalid_paths(), + ).invalid_dimensions_for_export(), ) def test_couple_input_ds_args_kwargs_0(self): @@ -561,17 +587,22 @@ def test_couple_input_ds_args_kwargs_0(self): kwargs = {"A": T3x4, "B": (T3x1, T3x1)} Cls = CoupleInputsDynamicShapes self.assertEqual( - [], Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths() + None, + Cls( + args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)} + ).invalid_dimensions_for_export(), ) self.assertEqual( - [], + None, Cls( args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}, args_names=["X"] - ).invalid_paths(), + ).invalid_dimensions_for_export(), ) self.assertEqual( - [("B", 1, "[1]")], - Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(), + {"B": (None, {1: "d=[1]"})}, + Cls( + args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} + ).invalid_dimensions_for_export(), ) def test_couple_input_ds_args_kwargs_1(self): @@ -584,23 +615,67 @@ def test_couple_input_ds_args_kwargs_1(self): kwargs = {"A": T3x4, "B": (T3x1, T3x1)} Cls = CoupleInputsDynamicShapes self.assertEqual( - [], + None, + Cls( + args, + kwargs, + {"X": ds_batch, "A": ds_batch, "B": (ds_batch, ds_batch)}, + args_names=["X"], + ).invalid_dimensions_for_export(), + ) + self.assertEqual( + {"X": {1: "d=[1]"}, "B": (None, {1: "d=[1]"})}, + Cls( + args, + kwargs, + {"X": ds_batch_seq, "A": ds_batch, "B": (ds_batch, ds_batch_seq)}, + args_names=["X"], + ).invalid_dimensions_for_export(), + ) + + def test_couple_input_ds_replace_string(self): + T3x1 = torch.rand((3, 1)) + T3x4 = torch.rand((3, 4)) + T5x1 = torch.rand((5, 1)) + ds_batch = {0: "batch"} + ds_batch_seq = {0: "batch", 1: "seq"} + args = (T5x1,) + kwargs = {"A": T3x4, "B": (T3x1, T3x1)} + Cls = CoupleInputsDynamicShapes + self.assertEqual( + {"X": {0: "DYN"}, "A": {0: "DYN"}, "B": ({0: "DYN"}, {0: "DYN"})}, Cls( args, kwargs, {"X": ds_batch, "A": ds_batch, "B": (ds_batch, ds_batch)}, args_names=["X"], - ).invalid_paths(), + ).replace_string_by(value="DYN"), ) self.assertEqual( - [("X", "[1]"), ("B", 1, "[1]")], + { + "A": {0: "DYN"}, + "B": ({0: "DYN"}, {0: "DYN", 1: "DYN"}), + "X": {0: "DYN", 1: "DYN"}, + }, Cls( args, kwargs, {"X": ds_batch_seq, "A": ds_batch, "B": (ds_batch, ds_batch_seq)}, args_names=["X"], - ).invalid_paths(), + ).replace_string_by(value="DYN"), + ) + + def test_couple_input_ds_change_dynamic_dimensions(self): + T257 = torch.arange(2 * 5 * 7).reshape((2, 5, 7)) + T29 = torch.arange(2 * 9).reshape((2, 9)) + inst = CoupleInputsDynamicShapes( + (), + {"A": T257, "B": T29}, + {"A": {0: "batch", 2: "last"}, "B": {0: "batch", 1: "seq"}}, ) + new_input = inst.change_dynamic_dimensions() + self.assertEqual((3, 5, 8), new_input["A"].shape) + self.assertEqual((3, 10), new_input["B"].shape) if __name__ == "__main__": diff --git a/_unittests/ut_torch_models/test_hghub_model.py b/_unittests/ut_torch_models/test_hghub_model.py index 4fb9b39d..08699750 100644 --- a/_unittests/ut_torch_models/test_hghub_model.py +++ b/_unittests/ut_torch_models/test_hghub_model.py @@ -62,6 +62,7 @@ def test_get_untrained_model_with_inputs_tiny_gpt_neo(self): self.assertEqual((316712, 79178), (data["size"], data["n_weights"])) @hide_stdout() + @ignore_errors(OSError) def test_get_untrained_model_with_inputs_phi_2(self): mid = "microsoft/phi-2" data = get_untrained_model_with_inputs(mid, verbose=1) @@ -83,6 +84,7 @@ def test_get_untrained_model_with_inputs_beit(self): self.assertIn((data["size"], data["n_weights"]), [(111448, 27862), (56880, 14220)]) @hide_stdout() + @ignore_errors(OSError) def test_get_untrained_model_with_inputs_codellama(self): mid = "codellama/CodeLlama-7b-Python-hf" data = get_untrained_model_with_inputs(mid, verbose=1) diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py index faf6e298..619f1077 100644 --- a/onnx_diagnostic/export/dynamic_shapes.py +++ b/onnx_diagnostic/export/dynamic_shapes.py @@ -1,5 +1,5 @@ import inspect -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from ..helpers import string_type @@ -71,7 +71,7 @@ def forward(self, x, y): ds = mi.guess_dynamic_shapes() pprint.pprint(ds) - **and and kwargs** + **args and kwargs** .. runpython:: :showcode: @@ -449,7 +449,10 @@ def validate_inputs_for_export( if len(self.inputs) == 1: return [] dyn_shapes = self.guess_dynamic_shapes() - return [CoupleInputsDynamicShapes(*i, dyn_shapes).invalid_paths() for i in self.inputs] + return [ + CoupleInputsDynamicShapes(*i, dyn_shapes).invalid_dimensions_for_export() + for i in self.inputs + ] class CoupleInputsDynamicShapes: @@ -488,28 +491,127 @@ def __str__(self) -> str: ] ) - def invalid_paths(self) -> List[Union[str, int]]: + def replace_string_by(self, value: Any = None): """ - Tells the inputs are valid based on the dynamic shapes definition. + Replaces string by the value ``torch.export.Dim.DYNAMIC`` + (default) or any other value specified by value. + + Example: + + .. runpython:: + :showcode: + + import torch + from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes + + T3x1 = torch.rand((3, 1)) + T3x4 = torch.rand((3, 4)) + ds_batch = {0: "batch"} + ds_batch_seq = {0: "batch", 1: "seq"} + kwargs = {"A": T3x4, "B": (T3x1, T3x1)} + ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} + print(CoupleInputsDynamicShapes((), kwargs, ds).replace_string_by()) + """ + return self._generic_walker( + lambda inputs, ds, value=value: self._replace_string_dim_tensor( + inputs, ds, value=value + ) + ) + + @classmethod + def _replace_string_dim_tensor(cls, inputs, ds, value=None): + assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}" + assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), ( + f"Unexpected types, inputs is a Tensor but ds is {ds}, " + f"a dictionary is expected to specify a dimension dimension" + ) + if value is None: + value = torch.export.Dim.DYNAMIC + new_ds = ds.copy() + for i, v in ds.items(): + if isinstance(v, str): + new_ds[i] = value + return new_ds + + def invalid_dimensions_for_export(self): + """ + Tells if the inputs are valid based on the dynamic shapes definition. The method assumes that all custom classes can be serialized. If some patches were applied to export, they should enabled while calling this method if the inputs contains such classes. The function checks that a dynamic dimension does not receive a value - of 0 or 1. It returns a list of invalid path. + of 0 or 1. It returns the unexpected values in the same structure as + the given dynamic shapes. + + Example: + + .. runpython:: + :showcode: + + import torch + from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes + + T3x1 = torch.rand((3, 1)) + T3x4 = torch.rand((3, 4)) + ds_batch = {0: "batch"} + ds_batch_seq = {0: "batch", 1: "seq"} + kwargs = {"A": T3x4, "B": (T3x1, T3x1)} + ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} + print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export()) + + In case it works, it shows: + + .. runpython:: + :showcode: + + import torch + from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes + + T3x2 = torch.rand((3, 2)) + T3x4 = torch.rand((3, 4)) + ds_batch = {0: "batch"} + ds_batch_seq = {0: "batch", 1: "seq"} + kwargs = {"A": T3x4, "B": (T3x2, T3x2)} + ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} + print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export()) + """ + return self._generic_walker(self._valid_shapes_tensor) + + @classmethod + def _valid_shapes_tensor(cls, inputs, ds): + assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}" + assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), ( + f"Unexpected types, inputs is a Tensor but ds is {ds}, " + f"a dictionary is expected to specify a dimension dimension" + ) + issues = {} + for i, d in enumerate(inputs.shape): + if i in ds and not isinstance(ds[i], int): + # dynamic then + if d in {0, 1}: + # export issues for sure + issues[i] = f"d=[{d}]" + return issues if issues else None + + def _generic_walker(self, processor: Callable): + """ + Generic deserializator walking through inputs and dynamic_shapes all along. + The function returns a result with the same structure as the dynamic shapes. """ if not self.args: assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), ( f"Type mismatch, args={string_type(self.args)} and " f"dynamic_shapes={self.dynamic_shapes} should have the same type." ) - return list(self._valid_shapes(self.kwargs, self.dynamic_shapes)) + return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes) + if not self.kwargs: assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), ( f"Type mismatch, args={string_type(self.args)} and " f"dynamic_shapes={self.dynamic_shapes} should have the same type." ) - return list(self._valid_shapes(self.args, self.dynamic_shapes)) + return self._generic_walker_step(processor, self.args, self.dynamic_shapes) assert isinstance(self.dynamic_shapes, dict), ( f"Both positional and named arguments (args and kwargs) are filled. " @@ -519,12 +621,12 @@ def invalid_paths(self) -> List[Union[str, int]]: self.dynamic_shapes ): # No dynamic shapes for the positional arguments. - return list(self._valid_shapes(self.kwargs, self.dynamic_shapes)) + return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes) if isinstance(self.args_names, list): if not set(self.args_names) & set(self.dynamic_shapes): # No dynamic shapes for the positional arguments. - return list(self._valid_shapes(self.kwargs, self.dynamic_shapes)) + return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes) assert self.args_names, ( "args and kwargs are filled, then args_names must be specified in " @@ -537,7 +639,7 @@ def invalid_paths(self) -> List[Union[str, int]]: ) kwargs = dict(zip(self.args_names, self.args)) kwargs.update(self.kwargs) - return list(self._valid_shapes(kwargs, self.dynamic_shapes)) + return self._generic_walker_step(processor, kwargs, self.dynamic_shapes) raise NotImplementedError( f"Not yet implemented when args is filled, " @@ -545,54 +647,144 @@ def invalid_paths(self) -> List[Union[str, int]]: ) @classmethod - def _valid_shapes( - cls, inputs: Any, ds: Any, prefix: Tuple[Union[int, str], ...] = () - ) -> Iterable: - assert all(isinstance(i, (int, str)) for i in prefix), f"Unexpected prefix {prefix}" + def _generic_walker_step(cls, processor: Callable, inputs, ds): if isinstance(inputs, torch.Tensor): - assert isinstance(ds, dict) and all( - isinstance(s, int) for s in ds - ), f"Unexpected types, inputs is a Tensor but ds={ds}, prefix={prefix}" - for i, d in enumerate(inputs.shape): - if i in ds and not isinstance(ds[i], int): - # dynamic then - if d in {0, 1}: - # export issues for sure - yield (*prefix, f"[{i}]") - else: - if isinstance(inputs, (int, float, str)): - pass - elif isinstance(inputs, (tuple, list, dict)): - assert type(ds) is type(inputs), ( - f"Type mismatch between inputs {type(inputs)} " - f"and ds={type(ds)}, prefix={prefix!r}" - ) - assert len(ds) == len(inputs), ( - f"Length mismatch between inputs {len(inputs)} " - f"and ds={len(ds)}, prefix={prefix!r}\n" - f"inputs={string_type(inputs, with_shape=True)}, ds={ds}" + return processor(inputs, ds) + if isinstance(inputs, (int, float, str)): + return None + if isinstance(inputs, (tuple, list, dict)): + assert type(ds) is type( + inputs + ), f"Type mismatch between inputs {type(inputs)} and ds={type(ds)}" + assert len(ds) == len(inputs), ( + f"Length mismatch between inputs {len(inputs)} " + f"and ds={len(ds)}\n" + f"inputs={string_type(inputs, with_shape=True)}, ds={ds}" + ) + if isinstance(inputs, (tuple, list)): + value = [] + for i, d in zip(inputs, ds): + value.append(cls._generic_walker_step(processor, i, d)) + return ( + (value if isinstance(ds, list) else tuple(value)) + if any(v is not None for v in value) + else None ) - if isinstance(inputs, (tuple, list)): - for ind, (i, d) in enumerate(zip(inputs, ds)): - for path in cls._valid_shapes(i, d, prefix=(*prefix, ind)): - yield path - else: - assert set(inputs) == set(ds), ( - f"Keys mismatch between inputs {set(inputs)} " - f"and ds={set(ds)}, prefix={prefix!r}" - ) - for k, v in inputs.items(): - for path in cls._valid_shapes(v, ds[k], prefix=(*prefix, k)): - yield path - else: - # A custom class. - assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, ( - f"Class {inputs.__class__.__name__!r} was not registered using " - f"torch.utils._pytree.register_pytree_node, it is not possible to " - f"map this class with the given dynamic shapes." + assert set(inputs) == set( + ds + ), f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}" + dvalue = {} + for k, v in inputs.items(): + t = cls._generic_walker_step(processor, v, ds[k]) + if t is not None: + dvalue[k] = t + return dvalue if dvalue else None + + # A custom class. + assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, ( + f"Class {inputs.__class__.__name__!r} was not registered using " + f"torch.utils._pytree.register_pytree_node, it is not possible to " + f"map this class with the given dynamic shapes." + ) + flat, _spec = torch.utils._pytree.tree_flatten(inputs) + return cls._generic_walker_step(processor, flat, ds) + + class ChangeDimensionProcessor: + def __init__(self): + self.mapping = {} + + def _build_new_shape( + self, shape: Tuple[int, ...], ds: Dict[int, Any] + ) -> Tuple[int, ...]: + new_shape = list(shape) + for i in range(len(shape)): + if i in ds: + if isinstance(ds[i], str): + d = ds[i] + elif isinstance( + ds[i], + ( + torch.export.dynamic_shapes._DerivedDim, + torch.export.dynamic_shapes._Dim, + ), + ): + d = str(ds[i]) + elif not isinstance(ds[i], int): + raise NotImplementedError(f"Unable to handle type {ds[i]} in {ds}") + if d in self.mapping: + new_dim = self.mapping[d] + else: + new_dim = shape[i] + 1 + self.mapping[d] = new_dim + new_shape[i] = new_dim + return tuple(new_shape) + + def _build_new_tensor(self, tensor: torch.Tensor, new_shape: Tuple[int, ...]): + rank = len(tensor.shape) + for i in range(len(tensor.shape)): + d0 = tensor.shape[i] + d1 = new_shape[i] + if d0 == d1: + continue + alt_shape = list(tensor.shape) + alt_shape[i] = d1 + new_tensor = torch.zeros( + tuple(alt_shape), dtype=tensor.dtype, device=tensor.device ) - flat, _spec = torch.utils._pytree.tree_flatten(inputs) - for path in cls._valid_shapes( - flat, ds, prefix=(*prefix, inputs.__class__.__name__) - ): - yield path + mind = min(d0, d1) + indices = [slice(None) for _ in range(rank)] + indices[i] = slice(0, mind) + ind = tuple(indices) + new_tensor[ind] = tensor[ind] + if d1 > mind: + for k in range(d1 - mind): + indices0 = [slice(None) for _ in range(rank)] + indices1 = [slice(None) for _ in range(rank)] + indices1[i] = mind + k + indices0[i] = k % mind + new_tensor[tuple(indices1)] = tensor[tuple(indices0)] + tensor = new_tensor + return tensor + + def __call__(self, inputs, ds): + assert isinstance( + inputs, torch.Tensor + ), f"unexpected type for inputs {type(inputs)}" + assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), ( + f"Unexpected types, inputs is a Tensor but ds is {ds}, " + f"a dictionary is expected to specify a dimension dimension" + ) + new_shape = self._build_new_shape(inputs.shape, ds) + return self._build_new_tensor(inputs, new_shape) + + def change_dynamic_dimensions(self): + """ + A model exported with dynamic shapes is not necessarily dynamic + just because the user specified dynamic shapes. The algorithm + may discover that a dimension cannot be dynamic and then continues + the export making the assumption it is static. That may lead a wrong + model. This function produces a new set of inputs with different values + for the dimension than the first ones, assuming they were used to export + the model. + + Example: + + .. runpython:: + :showcode: + + import torch + from onnx_diagnostic.helpers import string_type + from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes + + T3x15 = torch.rand((3, 15)) + T3x20 = torch.rand((3, 20)) + T3x4 = torch.rand((3, 4)) + ds_batch = {0: "batch"} + ds_batch_seq = {0: "batch", 1: "seq"} + kwargs = {"A": T3x4, "B": (T3x15, T3x20)} + ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} + new_kwargs = CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimensions() + print("before:", string_type(kwargs, with_shape=True)) + print("-after:", string_type(new_kwargs, with_shape=True)) + """ + return self._generic_walker(self.ChangeDimensionProcessor()) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 225463d3..0ae199f6 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -47,7 +47,7 @@ def _patch_make_causal_mask( if sys.version_info[:2] <= (3, 11): @dataclass - class kkpatched_AttentionMaskConverter: + class patched_AttentionMaskConverter: """ Patches ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. @@ -72,7 +72,7 @@ def _make_causal_mask( else: @dataclass - class kkpatched_AttentionMaskConverter: + class patched_AttentionMaskConverter: """ Patches ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index e39d2b08..dfe1d9c3 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -276,10 +276,7 @@ def validate_model( ) if verbose: print(f"[validate_model] new inputs: {string_type(data['inputs'])}") - print( - f"[validate_model] new dynnamic_shapes: " - f"{_ds_clean(data['dynamic_shapes'])}" - ) + print(f"[validate_model] new dynamic_hapes: {_ds_clean(data['dynamic_shapes'])}") if not empty(dtype): if isinstance(dtype, str):