From a4a5c0e6aa16eef0cc23e5acd098d88e88e23210 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 3 Apr 2025 14:35:28 +0200 Subject: [PATCH 1/5] refactor algorithm to validate shapes --- _unittests/ut_export/test_dynamic_shapes.py | 37 +++--- onnx_diagnostic/export/dynamic_shapes.py | 126 +++++++++++--------- 2 files changed, 91 insertions(+), 72 deletions(-) diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index 962398f2..5b6f8841 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ b/_unittests/ut_export/test_dynamic_shapes.py @@ -467,13 +467,13 @@ def test_couple_input_ds_0(self): T1x4 = torch.rand((1, 4)) T1x1 = torch.rand((1, 1)) Cls = CoupleInputsDynamicShapes - self.assertEqual([(0, "[0]")], Cls((T1x4,), {}, ({0: "batch"},)).invalid_paths()) - self.assertEqual([(0, "[0]")], Cls((T1x1,), {}, ({0: "batch"},)).invalid_paths()) + self.assertEqual(({0: "d=[1]"},), Cls((T1x4,), {}, ({0: "batch"},)).invalid_paths()) + self.assertEqual(({0: "d=[1]"},), Cls((T1x1,), {}, ({0: "batch"},)).invalid_paths()) self.assertEqual( - [("A", "[0]")], Cls((), {"A": T1x1}, {"A": {0: "batch"}}).invalid_paths() + {"A": {0: "d=[1]"}}, Cls((), {"A": T1x1}, {"A": {0: "batch"}}).invalid_paths() ) self.assertEqual( - [("A", "[0]")], Cls((), {"A": T1x4}, {"A": {0: "batch"}}).invalid_paths() + {"A": {0: "d=[1]"}}, Cls((), {"A": T1x4}, {"A": {0: "batch"}}).invalid_paths() ) def test_couple_input_ds_1(self): @@ -483,8 +483,10 @@ def test_couple_input_ds_1(self): ds_batch_seq = {0: "batch", 1: "seq"} args = (T3x4, T3x1) Cls = CoupleInputsDynamicShapes - self.assertEqual([], Cls(args, {}, (ds_batch, ds_batch)).invalid_paths()) - self.assertEqual([(1, "[1]")], Cls(args, {}, (ds_batch, ds_batch_seq)).invalid_paths()) + self.assertEqual(None, Cls(args, {}, (ds_batch, ds_batch)).invalid_paths()) + self.assertEqual( + (None, {1: "d=[1]"}), Cls(args, {}, (ds_batch, ds_batch_seq)).invalid_paths() + ) def test_couple_input_ds_2(self): T3x1 = 
torch.rand((3, 1)) @@ -493,9 +495,10 @@ def test_couple_input_ds_2(self): ds_batch_seq = {0: "batch", 1: "seq"} kwargs = {"A": T3x4, "B": T3x1} Cls = CoupleInputsDynamicShapes - self.assertEqual([], Cls((), kwargs, {"A": ds_batch, "B": ds_batch}).invalid_paths()) + self.assertEqual(None, Cls((), kwargs, {"A": ds_batch, "B": ds_batch}).invalid_paths()) self.assertEqual( - [("B", "[1]")], Cls((), kwargs, {"A": ds_batch, "B": ds_batch_seq}).invalid_paths() + {"B": {1: "d=[1]"}}, + Cls((), kwargs, {"A": ds_batch, "B": ds_batch_seq}).invalid_paths(), ) def test_couple_input_ds_3(self): @@ -506,10 +509,10 @@ def test_couple_input_ds_3(self): kwargs = {"A": T3x4, "B": (T3x1, T3x1)} Cls = CoupleInputsDynamicShapes self.assertEqual( - [], Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths() + None, Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths() ) self.assertEqual( - [("B", 1, "[1]")], + {"B": (None, {1: "d=[1]"})}, Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(), ) @@ -532,7 +535,7 @@ def test_couple_input_ds_cache(self): Cls = CoupleInputsDynamicShapes with bypass_export_some_errors(patch_transformers=True): self.assertEqual( - [], + None, Cls( (), kwargs, @@ -540,7 +543,7 @@ def test_couple_input_ds_cache(self): ).invalid_paths(), ) self.assertEqual( - [("B", 1, "DynamicCache", 1, "[2]"), ("B", 1, "DynamicCache", 3, "[2]")], + {"B": (None, [None, {2: "d=[1]"}, None, {2: "d=[1]"}])}, Cls( (), kwargs, @@ -561,16 +564,16 @@ def test_couple_input_ds_args_kwargs_0(self): kwargs = {"A": T3x4, "B": (T3x1, T3x1)} Cls = CoupleInputsDynamicShapes self.assertEqual( - [], Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths() + None, Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths() ) self.assertEqual( - [], + None, Cls( args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}, args_names=["X"] ).invalid_paths(), ) self.assertEqual( - [("B", 1, 
"[1]")], + {"B": (None, {1: "d=[1]"})}, Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(), ) @@ -584,7 +587,7 @@ def test_couple_input_ds_args_kwargs_1(self): kwargs = {"A": T3x4, "B": (T3x1, T3x1)} Cls = CoupleInputsDynamicShapes self.assertEqual( - [], + None, Cls( args, kwargs, @@ -593,7 +596,7 @@ def test_couple_input_ds_args_kwargs_1(self): ).invalid_paths(), ) self.assertEqual( - [("X", "[1]"), ("B", 1, "[1]")], + {"X": {1: "d=[1]"}, "B": (None, {1: "d=[1]"})}, Cls( args, kwargs, diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py index faf6e298..a89877bb 100644 --- a/onnx_diagnostic/export/dynamic_shapes.py +++ b/onnx_diagnostic/export/dynamic_shapes.py @@ -488,7 +488,7 @@ def __str__(self) -> str: ] ) - def invalid_paths(self) -> List[Union[str, int]]: + def invalid_paths(self) -> Any: """ Tells the inputs are valid based on the dynamic shapes definition. The method assumes that all custom classes can be serialized. @@ -498,18 +498,42 @@ def invalid_paths(self) -> List[Union[str, int]]: The function checks that a dynamic dimension does not receive a value of 0 or 1. It returns a list of invalid path. 
""" + return self._generic_walker(self._valid_shapes_tensor) + + @classmethod + def _valid_shapes_tensor(cls, inputs: Any, ds: Any) -> Iterable: + assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}" + assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), ( + f"Unexpected types, inputs is a Tensor but ds is {ds}, " + f"a dictionary is expected to specify a dimension dimension" + ) + issues = {} + for i, d in enumerate(inputs.shape): + if i in ds and not isinstance(ds[i], int): + # dynamic then + if d in {0, 1}: + # export issues for sure + issues[i] = f"d=[{d}]" + return issues if issues else None + + def _generic_walker(self, method_to_call: Callable) -> Any: + """ + Generic deserializator walking through inputs and dynamic_shapes all along. + The function returns a result with the same structure as the dynamic shapes. + """ if not self.args: assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), ( f"Type mismatch, args={string_type(self.args)} and " f"dynamic_shapes={self.dynamic_shapes} should have the same type." ) - return list(self._valid_shapes(self.kwargs, self.dynamic_shapes)) + return self._generic_walker_step(method_to_call, self.kwargs, self.dynamic_shapes) + if not self.kwargs: assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), ( f"Type mismatch, args={string_type(self.args)} and " f"dynamic_shapes={self.dynamic_shapes} should have the same type." ) - return list(self._valid_shapes(self.args, self.dynamic_shapes)) + return self._generic_walker_step(method_to_call, self.args, self.dynamic_shapes) assert isinstance(self.dynamic_shapes, dict), ( f"Both positional and named arguments (args and kwargs) are filled. " @@ -519,12 +543,14 @@ def invalid_paths(self) -> List[Union[str, int]]: self.dynamic_shapes ): # No dynamic shapes for the positional arguments. 
- return list(self._valid_shapes(self.kwargs, self.dynamic_shapes)) + return self._generic_walker_step(method_to_call, self.kwargs, self.dynamic_shapes) if isinstance(self.args_names, list): if not set(self.args_names) & set(self.dynamic_shapes): # No dynamic shapes for the positional arguments. - return list(self._valid_shapes(self.kwargs, self.dynamic_shapes)) + return self._generic_walker_step( + method_to_call, self.kwargs, self.dynamic_shapes + ) assert self.args_names, ( "args and kwargs are filled, then args_names must be specified in " @@ -537,7 +563,7 @@ def invalid_paths(self) -> List[Union[str, int]]: ) kwargs = dict(zip(self.args_names, self.args)) kwargs.update(self.kwargs) - return list(self._valid_shapes(kwargs, self.dynamic_shapes)) + return self._generic_walker_step(method_to_call, kwargs, self.dynamic_shapes) raise NotImplementedError( f"Not yet implemented when args is filled, " @@ -545,54 +571,44 @@ def invalid_paths(self) -> List[Union[str, int]]: ) @classmethod - def _valid_shapes( - cls, inputs: Any, ds: Any, prefix: Tuple[Union[int, str], ...] 
= () - ) -> Iterable: - assert all(isinstance(i, (int, str)) for i in prefix), f"Unexpected prefix {prefix}" + def _generic_walker_step(cls, method_to_call: Callable, inputs: Any, ds: Any) -> Iterable: if isinstance(inputs, torch.Tensor): - assert isinstance(ds, dict) and all( - isinstance(s, int) for s in ds - ), f"Unexpected types, inputs is a Tensor but ds={ds}, prefix={prefix}" - for i, d in enumerate(inputs.shape): - if i in ds and not isinstance(ds[i], int): - # dynamic then - if d in {0, 1}: - # export issues for sure - yield (*prefix, f"[{i}]") - else: - if isinstance(inputs, (int, float, str)): - pass - elif isinstance(inputs, (tuple, list, dict)): - assert type(ds) is type(inputs), ( - f"Type mismatch between inputs {type(inputs)} " - f"and ds={type(ds)}, prefix={prefix!r}" - ) - assert len(ds) == len(inputs), ( - f"Length mismatch between inputs {len(inputs)} " - f"and ds={len(ds)}, prefix={prefix!r}\n" - f"inputs={string_type(inputs, with_shape=True)}, ds={ds}" - ) - if isinstance(inputs, (tuple, list)): - for ind, (i, d) in enumerate(zip(inputs, ds)): - for path in cls._valid_shapes(i, d, prefix=(*prefix, ind)): - yield path - else: - assert set(inputs) == set(ds), ( - f"Keys mismatch between inputs {set(inputs)} " - f"and ds={set(ds)}, prefix={prefix!r}" - ) - for k, v in inputs.items(): - for path in cls._valid_shapes(v, ds[k], prefix=(*prefix, k)): - yield path - else: - # A custom class. - assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, ( - f"Class {inputs.__class__.__name__!r} was not registered using " - f"torch.utils._pytree.register_pytree_node, it is not possible to " - f"map this class with the given dynamic shapes." 
+ return method_to_call(inputs, ds) + if isinstance(inputs, (int, float, str)): + return None + if isinstance(inputs, (tuple, list, dict)): + assert type(ds) is type( + inputs + ), f"Type mismatch between inputs {type(inputs)} and ds={type(ds)}" + assert len(ds) == len(inputs), ( + f"Length mismatch between inputs {len(inputs)} " + f"and ds={len(ds)}\n" + f"inputs={string_type(inputs, with_shape=True)}, ds={ds}" + ) + if isinstance(inputs, (tuple, list)): + value = [] + for i, d in zip(inputs, ds): + value.append(cls._generic_walker_step(method_to_call, i, d)) + return ( + (value if isinstance(ds, list) else tuple(value)) + if any(v is not None for v in value) + else None ) - flat, _spec = torch.utils._pytree.tree_flatten(inputs) - for path in cls._valid_shapes( - flat, ds, prefix=(*prefix, inputs.__class__.__name__) - ): - yield path + assert set(inputs) == set( + ds + ), f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}" + dvalue = {} + for k, v in inputs.items(): + t = cls._generic_walker_step(method_to_call, v, ds[k]) + if t is not None: + dvalue[k] = t + return dvalue if dvalue else None + + # A custom class. + assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, ( + f"Class {inputs.__class__.__name__!r} was not registered using " + f"torch.utils._pytree.register_pytree_node, it is not possible to " + f"map this class with the given dynamic shapes." 
+ ) + flat, _spec = torch.utils._pytree.tree_flatten(inputs) + return cls._generic_walker_step(method_to_call, flat, ds) From 735b3738d4984467a7c303eaf9496b89f1c9c027 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 3 Apr 2025 14:44:37 +0200 Subject: [PATCH 2/5] issues --- .../ut_torch_models/test_hghub_model.py | 2 ++ onnx_diagnostic/export/dynamic_shapes.py | 30 +++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/_unittests/ut_torch_models/test_hghub_model.py b/_unittests/ut_torch_models/test_hghub_model.py index 4fb9b39d..08699750 100644 --- a/_unittests/ut_torch_models/test_hghub_model.py +++ b/_unittests/ut_torch_models/test_hghub_model.py @@ -62,6 +62,7 @@ def test_get_untrained_model_with_inputs_tiny_gpt_neo(self): self.assertEqual((316712, 79178), (data["size"], data["n_weights"])) @hide_stdout() + @ignore_errors(OSError) def test_get_untrained_model_with_inputs_phi_2(self): mid = "microsoft/phi-2" data = get_untrained_model_with_inputs(mid, verbose=1) @@ -83,6 +84,7 @@ def test_get_untrained_model_with_inputs_beit(self): self.assertIn((data["size"], data["n_weights"]), [(111448, 27862), (56880, 14220)]) @hide_stdout() + @ignore_errors(OSError) def test_get_untrained_model_with_inputs_codellama(self): mid = "codellama/CodeLlama-7b-Python-hf" data = get_untrained_model_with_inputs(mid, verbose=1) diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py index a89877bb..1cb14ca4 100644 --- a/onnx_diagnostic/export/dynamic_shapes.py +++ b/onnx_diagnostic/export/dynamic_shapes.py @@ -1,5 +1,5 @@ import inspect -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from ..helpers import string_type @@ -488,7 +488,7 @@ def __str__(self) -> str: ] ) - def invalid_paths(self) -> Any: + def invalid_paths(self): """ Tells the inputs are valid based on the dynamic shapes 
definition. The method assumes that all custom classes can be serialized. @@ -501,7 +501,7 @@ def invalid_paths(self) -> Any: return self._generic_walker(self._valid_shapes_tensor) @classmethod - def _valid_shapes_tensor(cls, inputs: Any, ds: Any) -> Iterable: + def _valid_shapes_tensor(cls, inputs, ds): assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}" assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), ( f"Unexpected types, inputs is a Tensor but ds is {ds}, " @@ -516,7 +516,7 @@ def _valid_shapes_tensor(cls, inputs: Any, ds: Any) -> Iterable: issues[i] = f"d=[{d}]" return issues if issues else None - def _generic_walker(self, method_to_call: Callable) -> Any: + def _generic_walker(self, processor: Callable): """ Generic deserializator walking through inputs and dynamic_shapes all along. The function returns a result with the same structure as the dynamic shapes. @@ -526,14 +526,14 @@ def _generic_walker(self, method_to_call: Callable) -> Any: f"Type mismatch, args={string_type(self.args)} and " f"dynamic_shapes={self.dynamic_shapes} should have the same type." ) - return self._generic_walker_step(method_to_call, self.kwargs, self.dynamic_shapes) + return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes) if not self.kwargs: assert isinstance(self.args, tuple) and isinstance(self.dynamic_shapes, tuple), ( f"Type mismatch, args={string_type(self.args)} and " f"dynamic_shapes={self.dynamic_shapes} should have the same type." ) - return self._generic_walker_step(method_to_call, self.args, self.dynamic_shapes) + return self._generic_walker_step(processor, self.args, self.dynamic_shapes) assert isinstance(self.dynamic_shapes, dict), ( f"Both positional and named arguments (args and kwargs) are filled. " @@ -543,14 +543,12 @@ def _generic_walker(self, method_to_call: Callable) -> Any: self.dynamic_shapes ): # No dynamic shapes for the positional arguments. 
- return self._generic_walker_step(method_to_call, self.kwargs, self.dynamic_shapes) + return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes) if isinstance(self.args_names, list): if not set(self.args_names) & set(self.dynamic_shapes): # No dynamic shapes for the positional arguments. - return self._generic_walker_step( - method_to_call, self.kwargs, self.dynamic_shapes - ) + return self._generic_walker_step(processor, self.kwargs, self.dynamic_shapes) assert self.args_names, ( "args and kwargs are filled, then args_names must be specified in " @@ -563,7 +561,7 @@ def _generic_walker(self, method_to_call: Callable) -> Any: ) kwargs = dict(zip(self.args_names, self.args)) kwargs.update(self.kwargs) - return self._generic_walker_step(method_to_call, kwargs, self.dynamic_shapes) + return self._generic_walker_step(processor, kwargs, self.dynamic_shapes) raise NotImplementedError( f"Not yet implemented when args is filled, " @@ -571,9 +569,9 @@ def _generic_walker(self, method_to_call: Callable) -> Any: ) @classmethod - def _generic_walker_step(cls, method_to_call: Callable, inputs: Any, ds: Any) -> Iterable: + def _generic_walker_step(cls, processor: Callable, inputs, ds): if isinstance(inputs, torch.Tensor): - return method_to_call(inputs, ds) + return processor(inputs, ds) if isinstance(inputs, (int, float, str)): return None if isinstance(inputs, (tuple, list, dict)): @@ -588,7 +586,7 @@ def _generic_walker_step(cls, method_to_call: Callable, inputs: Any, ds: Any) -> if isinstance(inputs, (tuple, list)): value = [] for i, d in zip(inputs, ds): - value.append(cls._generic_walker_step(method_to_call, i, d)) + value.append(cls._generic_walker_step(processor, i, d)) return ( (value if isinstance(ds, list) else tuple(value)) if any(v is not None for v in value) @@ -599,7 +597,7 @@ def _generic_walker_step(cls, method_to_call: Callable, inputs: Any, ds: Any) -> ), f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}" dvalue = {} for k, v in 
inputs.items(): - t = cls._generic_walker_step(method_to_call, v, ds[k]) + t = cls._generic_walker_step(processor, v, ds[k]) if t is not None: dvalue[k] = t return dvalue if dvalue else None @@ -611,4 +609,4 @@ def _generic_walker_step(cls, method_to_call: Callable, inputs: Any, ds: Any) -> f"map this class with the given dynamic shapes." ) flat, _spec = torch.utils._pytree.tree_flatten(inputs) - return cls._generic_walker_step(method_to_call, flat, ds) + return cls._generic_walker_step(processor, flat, ds) From ab0f3de01d94883ce1ff75000a557bda898f9253 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 3 Apr 2025 15:52:51 +0200 Subject: [PATCH 3/5] add change dynamic --- _unittests/ut_export/test_dynamic_shapes.py | 44 ++++++ onnx_diagnostic/export/dynamic_shapes.py | 156 +++++++++++++++++++- 2 files changed, 199 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index 5b6f8841..71980b64 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ b/_unittests/ut_export/test_dynamic_shapes.py @@ -605,6 +605,50 @@ def test_couple_input_ds_args_kwargs_1(self): ).invalid_paths(), ) + def test_couple_input_ds_replace_string(self): + T3x1 = torch.rand((3, 1)) + T3x4 = torch.rand((3, 4)) + T5x1 = torch.rand((5, 1)) + ds_batch = {0: "batch"} + ds_batch_seq = {0: "batch", 1: "seq"} + args = (T5x1,) + kwargs = {"A": T3x4, "B": (T3x1, T3x1)} + Cls = CoupleInputsDynamicShapes + self.assertEqual( + {"X": {0: "DYN"}, "A": {0: "DYN"}, "B": ({0: "DYN"}, {0: "DYN"})}, + Cls( + args, + kwargs, + {"X": ds_batch, "A": ds_batch, "B": (ds_batch, ds_batch)}, + args_names=["X"], + ).replace_string_by(value="DYN"), + ) + self.assertEqual( + { + "A": {0: "DYN"}, + "B": ({0: "DYN"}, {0: "DYN", 1: "DYN"}), + "X": {0: "DYN", 1: "DYN"}, + }, + Cls( + args, + kwargs, + {"X": ds_batch_seq, "A": ds_batch, "B": (ds_batch, ds_batch_seq)}, + args_names=["X"], + ).replace_string_by(value="DYN"), + ) + + def 
test_couple_input_ds_change_dynamic_dimensions(self): + T257 = torch.arange(2 * 5 * 7).reshape((2, 5, 7)) + T29 = torch.arange(2 * 9).reshape((2, 9)) + inst = CoupleInputsDynamicShapes( + (), + {"A": T257, "B": T29}, + {"A": {0: "batch", 2: "last"}, "B": {0: "batch", 1: "seq"}}, + ) + new_input = inst.change_dynamic_dimensions() + self.assertEqual((3, 5, 8), new_input["A"].shape) + self.assertEqual((3, 10), new_input["B"].shape) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py index 1cb14ca4..58209b83 100644 --- a/onnx_diagnostic/export/dynamic_shapes.py +++ b/onnx_diagnostic/export/dynamic_shapes.py @@ -488,15 +488,73 @@ def __str__(self) -> str: ] ) + def replace_string_by(self, value: Any = None): + """ + Replaces string by the value ``torch.export.Dim.DYNAMIC`` + (default) or any other value specified by value. + + Example: + + .. runpython:: + :showcode: + + import torch + from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes + + T3x1 = torch.rand((3, 1)) + T3x4 = torch.rand((3, 4)) + ds_batch = {0: "batch"} + ds_batch_seq = {0: "batch", 1: "seq"} + kwargs = {"A": T3x4, "B": (T3x1, T3x1)} + ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} + print(CoupleInputsDynamicShapes((), kwargs, ds).replace_string_by()) + """ + return self._generic_walker( + lambda inputs, ds, value=value: self._replace_string_dim_tensor( + inputs, ds, value=value + ) + ) + + @classmethod + def _replace_string_dim_tensor(cls, inputs, ds, value=None): + assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}" + assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), ( + f"Unexpected types, inputs is a Tensor but ds is {ds}, " + f"a dictionary is expected to specify a dimension dimension" + ) + if value is None: + value = torch.export.Dim.DYNAMIC + new_ds = ds.copy() + for i, v in ds.items(): + if isinstance(v, str): + 
new_ds[i] = value + return new_ds + def invalid_paths(self): """ - Tells the inputs are valid based on the dynamic shapes definition. + Tells if the inputs are valid based on the dynamic shapes definition. The method assumes that all custom classes can be serialized. If some patches were applied to export, they should enabled while calling this method if the inputs contains such classes. The function checks that a dynamic dimension does not receive a value of 0 or 1. It returns a list of invalid path. + + Example: + + .. runpython:: + :showcode: + + import torch + from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes + + T3x1 = torch.rand((3, 1)) + T3x4 = torch.rand((3, 4)) + ds_batch = {0: "batch"} + ds_batch_seq = {0: "batch", 1: "seq"} + kwargs = {"A": T3x4, "B": (T3x1, T3x1)} + ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} + print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_paths()) """ return self._generic_walker(self._valid_shapes_tensor) @@ -610,3 +668,99 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds): ) flat, _spec = torch.utils._pytree.tree_flatten(inputs) return cls._generic_walker_step(processor, flat, ds) + + class ChangeDimensionProcessor: + def __init__(self): + self.mapping = {} + + def _build_new_shape( + self, shape: Tuple[int, ...], ds: Dict[int, Any] + ) -> Tuple[int, ...]: + new_shape = list(shape) + for i in range(len(shape)): + if i in ds: + if isinstance(ds[i], str): + d = ds[i] + elif isinstance( + ds[i], + ( + torch.export.dynamic_shapes._DerivedDim, + torch.export.dynamic_shapes._Dim, + ), + ): + d = str(ds[i]) + elif not isinstance(ds[i], int): + raise NotImplementedError(f"Unable to handle type {ds[i]} in {ds}") + if d in self.mapping: + new_dim = self.mapping[d] + else: + new_dim = shape[i] + 1 + self.mapping[d] = new_dim + new_shape[i] = new_dim + return tuple(new_shape) + + def _build_new_tensor(self, tensor: torch.Tensor, new_shape: Tuple[int, ...]): + rank = len(tensor.shape) + 
for i in range(len(tensor.shape)): + d0 = tensor.shape[i] + d1 = new_shape[i] + if d0 == d1: + continue + alt_shape = list(tensor.shape) + alt_shape[i] = d1 + new_tensor = torch.zeros( + tuple(alt_shape), dtype=tensor.dtype, device=tensor.device + ) + mind = min(d0, d1) + indices = [slice(None) for _ in range(rank)] + indices[i] = slice(0, mind) + ind = tuple(indices) + new_tensor[ind] = tensor[ind] + if d1 > mind: + for k in range(d1 - mind): + indices0 = [slice(None) for _ in range(rank)] + indices1 = [slice(None) for _ in range(rank)] + indices1[i] = mind + k + indices0[i] = k % mind + new_tensor[tuple(indices1)] = tensor[tuple(indices0)] + tensor = new_tensor + return tensor + + def __call__(self, inputs, ds): + assert isinstance( + inputs, torch.Tensor + ), f"unexpected type for inputs {type(inputs)}" + assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), ( + f"Unexpected types, inputs is a Tensor but ds is {ds}, " + f"a dictionary is expected to specify a dimension dimension" + ) + new_shape = self._build_new_shape(inputs.shape, ds) + return self._build_new_tensor(inputs, new_shape) + + def change_dynamic_dimensions(self): + """ + A model exported with dynamic shapes is not necessarily dynamic + just because the user specified dynamic shapes. The algorithm + may discover that a dimension cannot be dynamic and then continues + the export making the assumption it is static. That may lead a wrong + model. This function produces a new set of inputs with different values + for the dimension than the first ones, assuming they were used to export + the model. + + Example: + + .. 
runpython:: + :showcode: + + import torch + from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes + + T3x1 = torch.rand((3, 1)) + T3x4 = torch.rand((3, 4)) + ds_batch = {0: "batch"} + ds_batch_seq = {0: "batch", 1: "seq"} + kwargs = {"A": T3x4, "B": (T3x1, T3x1)} + ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} + print(CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimension()) + """ + return self._generic_walker(self.ChangeDimensionProcessor()) From 6e1c7e6de4aeb1a50fd74ae19ccf00a75796f726 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 3 Apr 2025 17:04:19 +0200 Subject: [PATCH 4/5] fix documentation --- onnx_diagnostic/export/dynamic_shapes.py | 10 +++++++--- .../torch_export_patches/patches/patch_transformers.py | 4 ++-- onnx_diagnostic/torch_models/test_helper.py | 3 +-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py index 58209b83..e8db99ca 100644 --- a/onnx_diagnostic/export/dynamic_shapes.py +++ b/onnx_diagnostic/export/dynamic_shapes.py @@ -753,14 +753,18 @@ def change_dynamic_dimensions(self): :showcode: import torch + from onnx_diagnostic.helpers import string_type from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes - T3x1 = torch.rand((3, 1)) + T3x15 = torch.rand((3, 15)) + T3x20 = torch.rand((3, 20)) T3x4 = torch.rand((3, 4)) ds_batch = {0: "batch"} ds_batch_seq = {0: "batch", 1: "seq"} - kwargs = {"A": T3x4, "B": (T3x1, T3x1)} + kwargs = {"A": T3x4, "B": (T3x15, T3x20)} ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} - print(CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimension()) + new_kwargs = CoupleInputsDynamicShapes((), kwargs, ds).change_dynamic_dimensions() + print("before:", string_type(kwargs, with_shape=True)) + print("-after:", string_type(new_kwargs, with_shape=True)) """ return self._generic_walker(self.ChangeDimensionProcessor()) diff --git 
a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 225463d3..0ae199f6 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -47,7 +47,7 @@ def _patch_make_causal_mask( if sys.version_info[:2] <= (3, 11): @dataclass - class kkpatched_AttentionMaskConverter: + class patched_AttentionMaskConverter: """ Patches ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. @@ -72,7 +72,7 @@ def _make_causal_mask( else: @dataclass - class kkpatched_AttentionMaskConverter: + class patched_AttentionMaskConverter: """ Patches ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index e39d2b08..270c5028 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -277,8 +277,7 @@ def validate_model( if verbose: print(f"[validate_model] new inputs: {string_type(data['inputs'])}") print( - f"[validate_model] new dynnamic_shapes: " - f"{_ds_clean(data['dynamic_shapes'])}" + f"[validate_model] new dynamic_hapes: {_ds_clean(data['dynamic_shapes'])}" ) if not empty(dtype): From c058fd18f681f47ca9d120a108f30b85f930f662 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 3 Apr 2025 17:25:44 +0200 Subject: [PATCH 5/5] rename --- _unittests/ut_export/test_dynamic_shapes.py | 70 ++++++++++++++------- onnx_diagnostic/export/dynamic_shapes.py | 30 +++++++-- onnx_diagnostic/torch_models/test_helper.py | 4 +- 3 files changed, 75 insertions(+), 29 deletions(-) diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index 71980b64..8eb025ff 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ b/_unittests/ut_export/test_dynamic_shapes.py @@ 
-459,21 +459,31 @@ def test_couple_input_ds_0(self): T3x4 = torch.rand((3, 4)) T3x1 = torch.rand((3, 1)) Cls = CoupleInputsDynamicShapes - self.assertEmpty(Cls((T3x4,), {}, ({0: "batch"},)).invalid_paths()) - self.assertEmpty(Cls((T3x1,), {}, ({0: "batch"},)).invalid_paths()) - self.assertEmpty(Cls((), {"A": T3x1}, {"A": {0: "batch"}}).invalid_paths()) - self.assertEmpty(Cls((), {"A": T3x4}, {"A": {0: "batch"}}).invalid_paths()) + self.assertEmpty(Cls((T3x4,), {}, ({0: "batch"},)).invalid_dimensions_for_export()) + self.assertEmpty(Cls((T3x1,), {}, ({0: "batch"},)).invalid_dimensions_for_export()) + self.assertEmpty( + Cls((), {"A": T3x1}, {"A": {0: "batch"}}).invalid_dimensions_for_export() + ) + self.assertEmpty( + Cls((), {"A": T3x4}, {"A": {0: "batch"}}).invalid_dimensions_for_export() + ) T1x4 = torch.rand((1, 4)) T1x1 = torch.rand((1, 1)) Cls = CoupleInputsDynamicShapes - self.assertEqual(({0: "d=[1]"},), Cls((T1x4,), {}, ({0: "batch"},)).invalid_paths()) - self.assertEqual(({0: "d=[1]"},), Cls((T1x1,), {}, ({0: "batch"},)).invalid_paths()) self.assertEqual( - {"A": {0: "d=[1]"}}, Cls((), {"A": T1x1}, {"A": {0: "batch"}}).invalid_paths() + ({0: "d=[1]"},), Cls((T1x4,), {}, ({0: "batch"},)).invalid_dimensions_for_export() ) self.assertEqual( - {"A": {0: "d=[1]"}}, Cls((), {"A": T1x4}, {"A": {0: "batch"}}).invalid_paths() + ({0: "d=[1]"},), Cls((T1x1,), {}, ({0: "batch"},)).invalid_dimensions_for_export() + ) + self.assertEqual( + {"A": {0: "d=[1]"}}, + Cls((), {"A": T1x1}, {"A": {0: "batch"}}).invalid_dimensions_for_export(), + ) + self.assertEqual( + {"A": {0: "d=[1]"}}, + Cls((), {"A": T1x4}, {"A": {0: "batch"}}).invalid_dimensions_for_export(), ) def test_couple_input_ds_1(self): @@ -483,9 +493,12 @@ def test_couple_input_ds_1(self): ds_batch_seq = {0: "batch", 1: "seq"} args = (T3x4, T3x1) Cls = CoupleInputsDynamicShapes - self.assertEqual(None, Cls(args, {}, (ds_batch, ds_batch)).invalid_paths()) self.assertEqual( - (None, {1: "d=[1]"}), Cls(args, {}, 
(ds_batch, ds_batch_seq)).invalid_paths() + None, Cls(args, {}, (ds_batch, ds_batch)).invalid_dimensions_for_export() + ) + self.assertEqual( + (None, {1: "d=[1]"}), + Cls(args, {}, (ds_batch, ds_batch_seq)).invalid_dimensions_for_export(), ) def test_couple_input_ds_2(self): @@ -495,10 +508,15 @@ def test_couple_input_ds_2(self): ds_batch_seq = {0: "batch", 1: "seq"} kwargs = {"A": T3x4, "B": T3x1} Cls = CoupleInputsDynamicShapes - self.assertEqual(None, Cls((), kwargs, {"A": ds_batch, "B": ds_batch}).invalid_paths()) + self.assertEqual( + None, + Cls((), kwargs, {"A": ds_batch, "B": ds_batch}).invalid_dimensions_for_export(), + ) self.assertEqual( {"B": {1: "d=[1]"}}, - Cls((), kwargs, {"A": ds_batch, "B": ds_batch_seq}).invalid_paths(), + Cls( + (), kwargs, {"A": ds_batch, "B": ds_batch_seq} + ).invalid_dimensions_for_export(), ) def test_couple_input_ds_3(self): @@ -509,11 +527,16 @@ def test_couple_input_ds_3(self): kwargs = {"A": T3x4, "B": (T3x1, T3x1)} Cls = CoupleInputsDynamicShapes self.assertEqual( - None, Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths() + None, + Cls( + (), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)} + ).invalid_dimensions_for_export(), ) self.assertEqual( {"B": (None, {1: "d=[1]"})}, - Cls((), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(), + Cls( + (), kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} + ).invalid_dimensions_for_export(), ) def test_couple_input_ds_cache(self): @@ -540,7 +563,7 @@ def test_couple_input_ds_cache(self): (), kwargs, {"A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch, ds_batch, ds_batch])}, - ).invalid_paths(), + ).invalid_dimensions_for_export(), ) self.assertEqual( {"B": (None, [None, {2: "d=[1]"}, None, {2: "d=[1]"}])}, @@ -551,7 +574,7 @@ def test_couple_input_ds_cache(self): "A": ds_batch, "B": (ds_batch, [ds_batch, ds_batch_seq, ds_batch, ds_batch_seq]), }, - ).invalid_paths(), + ).invalid_dimensions_for_export(), ) def 
test_couple_input_ds_args_kwargs_0(self): @@ -564,17 +587,22 @@ def test_couple_input_ds_args_kwargs_0(self): kwargs = {"A": T3x4, "B": (T3x1, T3x1)} Cls = CoupleInputsDynamicShapes self.assertEqual( - None, Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}).invalid_paths() + None, + Cls( + args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)} + ).invalid_dimensions_for_export(), ) self.assertEqual( None, Cls( args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch)}, args_names=["X"] - ).invalid_paths(), + ).invalid_dimensions_for_export(), ) self.assertEqual( {"B": (None, {1: "d=[1]"})}, - Cls(args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)}).invalid_paths(), + Cls( + args, kwargs, {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} + ).invalid_dimensions_for_export(), ) def test_couple_input_ds_args_kwargs_1(self): @@ -593,7 +621,7 @@ def test_couple_input_ds_args_kwargs_1(self): kwargs, {"X": ds_batch, "A": ds_batch, "B": (ds_batch, ds_batch)}, args_names=["X"], - ).invalid_paths(), + ).invalid_dimensions_for_export(), ) self.assertEqual( {"X": {1: "d=[1]"}, "B": (None, {1: "d=[1]"})}, @@ -602,7 +630,7 @@ def test_couple_input_ds_args_kwargs_1(self): kwargs, {"X": ds_batch_seq, "A": ds_batch, "B": (ds_batch, ds_batch_seq)}, args_names=["X"], - ).invalid_paths(), + ).invalid_dimensions_for_export(), ) def test_couple_input_ds_replace_string(self): diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py index e8db99ca..619f1077 100644 --- a/onnx_diagnostic/export/dynamic_shapes.py +++ b/onnx_diagnostic/export/dynamic_shapes.py @@ -71,7 +71,7 @@ def forward(self, x, y): ds = mi.guess_dynamic_shapes() pprint.pprint(ds) - **and and kwargs** + **args and kwargs** .. 
runpython:: :showcode: @@ -449,7 +449,10 @@ def validate_inputs_for_export( if len(self.inputs) == 1: return [] dyn_shapes = self.guess_dynamic_shapes() - return [CoupleInputsDynamicShapes(*i, dyn_shapes).invalid_paths() for i in self.inputs] + return [ + CoupleInputsDynamicShapes(*i, dyn_shapes).invalid_dimensions_for_export() + for i in self.inputs + ] class CoupleInputsDynamicShapes: @@ -530,7 +533,7 @@ def _replace_string_dim_tensor(cls, inputs, ds, value=None): new_ds[i] = value return new_ds - def invalid_paths(self): + def invalid_dimensions_for_export(self): """ Tells if the inputs are valid based on the dynamic shapes definition. The method assumes that all custom classes can be serialized. @@ -538,7 +541,8 @@ def invalid_paths(self): calling this method if the inputs contains such classes. The function checks that a dynamic dimension does not receive a value - of 0 or 1. It returns a list of invalid path. + of 0 or 1. It returns the unexpected values in the same structure as + the given dynamic shapes. Example: @@ -554,7 +558,23 @@ def invalid_paths(self): ds_batch_seq = {0: "batch", 1: "seq"} kwargs = {"A": T3x4, "B": (T3x1, T3x1)} ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} - print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_paths()) + print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export()) + + If every dynamic dimension is valid, the method returns None: + + .. 
runpython:: + :showcode: + + import torch + from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes + + T3x2 = torch.rand((3, 2)) + T3x4 = torch.rand((3, 4)) + ds_batch = {0: "batch"} + ds_batch_seq = {0: "batch", 1: "seq"} + kwargs = {"A": T3x4, "B": (T3x2, T3x2)} + ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} + print(CoupleInputsDynamicShapes((), kwargs, ds).invalid_dimensions_for_export()) """ return self._generic_walker(self._valid_shapes_tensor) diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 270c5028..dfe1d9c3 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -276,9 +276,7 @@ def validate_model( ) if verbose: print(f"[validate_model] new inputs: {string_type(data['inputs'])}") - print( - f"[validate_model] new dynamic_hapes: {_ds_clean(data['dynamic_shapes'])}" - ) + print(f"[validate_model] new dynamic_hapes: {_ds_clean(data['dynamic_shapes'])}") if not empty(dtype): if isinstance(dtype, str):