diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index e25c41b6..a4026994 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -84,14 +84,14 @@ jobs: - name: Check for errors and warnings run: | - if [[ $(grep ERROR doc.txt | grep -v 'Unknown target name: "l_shape"' | grep -v 'Unknown target name: "l_x"') ]]; then + if [[ $(grep ERROR doc.txt | grep -v 'l-plot-tiny-llm-export') ]]; then echo "Documentation produces errors." - grep ERROR doc.txt + grep ERROR doc.txt | grep -v 'l-plot-tiny-llm-export' exit 1 fi - if [[ $(grep WARNING doc.txt) ]]; then + if [[ $(grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export') ]]; then echo "Documentation produces warnings." - grep WARNING doc.txt + grep WARNING doc.txt | grep -v 'l-plot-tiny-llm-export' exit 1 fi diff --git a/_doc/api/index.rst b/_doc/api/index.rst index fa6c9afe..35d2319c 100644 --- a/_doc/api/index.rst +++ b/_doc/api/index.rst @@ -7,6 +7,8 @@ API of onnx_diagnostic :maxdepth: 1 :caption: submodules + torch_export_patches/index + torch_models/index .. toctree:: :maxdepth: 1 diff --git a/_doc/api/torch_export_patches/index.rst b/_doc/api/torch_export_patches/index.rst new file mode 100644 index 00000000..38910161 --- /dev/null +++ b/_doc/api/torch_export_patches/index.rst @@ -0,0 +1,6 @@ +onnx_diagnostic.torch_export_patches +==================================== + +.. automodule:: onnx_diagnostic.torch_export_patches + :members: + :no-undoc-members: diff --git a/_doc/api/torch_models/index.rst b/_doc/api/torch_models/index.rst new file mode 100644 index 00000000..291c1c6d --- /dev/null +++ b/_doc/api/torch_models/index.rst @@ -0,0 +1,12 @@ +onnx_diagnostic.torch_models +============================ + +.. toctree:: + :maxdepth: 1 + :caption: submodules + + llms + +.. automodule:: onnx_diagnostic.torch_models + :members: + :no-undoc-members: diff --git a/_doc/api/torch_models/llms.rst b/_doc/api/torch_models/llms.rst new file mode 100644 index 00000000..c1ce1317 --- /dev/null +++ b/_doc/api/torch_models/llms.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.torch_models.llms +================================= + +.. automodule:: onnx_diagnostic.torch_models.llms + :members: + :no-undoc-members: diff --git a/_doc/conf.py b/_doc/conf.py index 9352bbf1..ab4972eb 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -113,6 +113,7 @@ ("py:class", "torch.utils._pytree.Context"), ("py:class", "torch.utils._pytree.KeyEntry"), ("py:class", "torch.utils._pytree.TreeSpec"), + ("py:class", "transformers.LlamaConfig"), ("py:class", "transformers.cache_utils.Cache"), ("py:class", "transformers.cache_utils.DynamicCache"), ("py:class", "transformers.cache_utils.MambaCache"), @@ -154,7 +155,7 @@ } if int(os.environ.get("UNITTEST_GOING", "0")): - sphinx_gallery_conf["ignore_pattern"] = ".*((_oe_)|(dort)|(draft_mode)).*" + sphinx_gallery_conf["ignore_pattern"] = ".*((tiny_llm)|(dort)|(draft_mode)).*" elif pv.Version(torch.__version__) < pv.Version("2.8"): sphinx_gallery_conf["ignore_pattern"] = ".*((_oe_)|(dort)|(draft_mode)).*" diff --git a/_doc/examples/plot_export_tiny_llm.py b/_doc/examples/plot_export_tiny_llm.py index 0a70f44e..afb4c697 100644 --- a/_doc/examples/plot_export_tiny_llm.py +++ b/_doc/examples/plot_export_tiny_llm.py @@ -1,4 +1,6 @@ """ +.. _l-plot-tiny-llm-export: + Export LLM with dynamic shapes ============================== @@ -15,11 +17,11 @@ We use the dummy example from the model page. 
""" -from typing import Any, Dict +import copy import torch import transformers from onnx_diagnostic.helpers import string_type -from onnx_diagnostic.cache_helpers import make_dynamic_cache +from onnx_diagnostic.torch_models.llms import get_tiny_llm MODEL_NAME = "arnir0/Tiny-LLM" @@ -30,21 +32,6 @@ # We rewrite the forward method to print the cache dimension. -def string_inputs(args, kwargs): - def _cache(a): - if len(a.key_cache): - return f"n_caches={len(a.key_cache)}, shape={a.key_cache[0].shape}" - return f"n_caches={len(a.key_cache)}" - - for a in args: - if isinstance(a, transformers.cache_utils.DynamicCache): - return _cache(a) - for k, a in kwargs.items(): - if isinstance(a, transformers.cache_utils.DynamicCache): - return f"{k}={_cache(a)}" - return "no_cache" - - def _forward_(*args, _f=None, **kwargs): assert _f is not None if not torch.compiler.is_exporting(): @@ -83,100 +70,6 @@ def _forward_(*args, _f=None, **kwargs): # Let's create an untrained model. -def get_tiny_llm( - batch_size: int = 2, - input_cache: bool = True, - common_dynamic_shapes: bool = True, - dynamic_rope: bool = False, - **kwargs, -) -> Dict[str, Any]: - """ - Gets a non initialized model. - - :param batch_size: batch size - :param input_cache: generate data for this iteration with or without cache - :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1`` - :param common_dynamic_shapes: if True returns dynamic shapes as well - :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`) - :return: dictionary - """ - import transformers - - config = { - "architectures": ["LlamaForCausalLM"], - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 192, - "initializer_range": 0.02, - "intermediate_size": 1024, - "max_position_embeddings": 1024, - "model_type": "llama", - "num_attention_heads": 2, - "num_hidden_layers": 1, - "num_key_value_heads": 1, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None, - "tie_word_embeddings": False, - "torch_dtype": "float32", - "transformers_version": "4.31.0.dev0", - "use_cache": True, - "vocab_size": 32000, - } - - config.update(**kwargs) - conf = transformers.LlamaConfig(**config) - model = transformers.LlamaForCausalLM(conf) - model.eval() - - # now the inputs - cache_last_dim = 96 - sequence_length = 30 - sequence_length2 = 3 - num_key_value_heads = 1 - max_token_id = config["vocab_size"] - 1 - n_layers = config["num_hidden_layers"] - - batch = torch.export.Dim("batch", min=1, max=1024) - seq_length = torch.export.Dim("seq_length", min=1, max=4096) - cache_length = torch.export.Dim("cache_length", min=1, max=4096) - - shapes = { - "input_ids": {0: batch, 1: seq_length}, - "attention_mask": { - 0: batch, - 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length - }, - "past_key_values": [ - [{0: batch, 2: cache_length} for _ in range(n_layers)], - [{0: batch, 2: cache_length} for _ in range(n_layers)], - ], - } - inputs = dict( - input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to( - torch.int64 - ), - attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to( - torch.int64 - ), - past_key_values=make_dynamic_cache( - [ - ( - torch.randn( - batch_size, num_key_value_heads, sequence_length, cache_last_dim - ), - torch.randn( - batch_size, num_key_value_heads, sequence_length, cache_last_dim - ), - ) - for i in range(n_layers) - ] - ), - ) - return dict(inputs=inputs, model=model, 
dynamic_shapes=shapes) - - # %% # Let's get the model, inputs and dynamic shapes. @@ -187,9 +80,25 @@ def get_tiny_llm( experiment["dynamic_shapes"], ) +# %% +# Before we run it, we make a copy of the inputs because the cache +# gets modified by the execution. It would then no longer be consistent +# with the previous input_ids and mask. +cloned_inputs = copy.deepcopy(inputs) + + # %% Let's run it. -expected_output = model(**inputs) -print("result type", type(expected_output)) +print("input type", string_type(inputs, with_shape=True)) + +expected_output = untrained_model(**inputs) + + +print("input after the execution", string_type(inputs, with_shape=True)) +print("result type", string_type(expected_output, with_shape=True)) + +ep = torch.export.export( + untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes +) # %% # It works. @@ -199,7 +108,7 @@ def get_tiny_llm( try: ep = torch.export.export( - untrained_model, (), inputs, dynamic_shapes=dynamic_shapes, strict=False + untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes ) print("It worked:") print(ep) @@ -217,7 +126,7 @@ def get_tiny_llm( # Let's use the same dummy inputs but we use the downloaded model. try: - ep = torch.export.export(model, (), inputs, dynamic_shapes=dynamic_shapes, strict=False) + ep = torch.export.export(model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes) print("It worked:") print(ep) except Exception as e: diff --git a/_doc/examples/plot_exporter_exporter_dynamic_shapes_auto.py b/_doc/examples/plot_export_with_dynamic_shapes_auto.py similarity index 100% rename from _doc/examples/plot_exporter_exporter_dynamic_shapes_auto.py rename to _doc/examples/plot_export_with_dynamic_shapes_auto.py diff --git a/_doc/galleries.rst b/_doc/galleries.rst deleted file mode 100644 index a7d9d0de..00000000 --- a/_doc/galleries.rst +++ /dev/null @@ -1,7 +0,0 @@ -Galleries of Examples and Recipes -================================= - -.. toctree:: - :maxdepth: 2 - - auto_examples/index diff --git a/_doc/index.rst b/_doc/index.rst index 65fe6ceb..7bb563f2 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -36,7 +36,7 @@ Source are `sdpython/onnx-diagnostic :caption: Contents api/index - galleries + auto_examples/index ..
toctree:: :maxdepth: 1 diff --git a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py new file mode 100644 index 00000000..8b92a760 --- /dev/null +++ b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py @@ -0,0 +1,120 @@ +import unittest +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + requires_torch, + requires_transformers, + skipif_ci_windows, + ignore_warnings, +) +from onnx_diagnostic.helpers import string_type +from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( + bypass_export_some_errors, +) + + +class TestOnnxExportErrors(ExtTestCase): + @requires_transformers("4.49.999") + @skipif_ci_windows("not working on Windows") + @ignore_warnings(UserWarning) + def test_pytree_flatten_mamba_cache(self): + import torch + import torch.utils._pytree as py_pytree + from transformers.cache_utils import MambaCache + + class _config: + def __init__(self): + self.intermediate_size = 8 + self.state_size = 16 + self.conv_kernel = 32 + self.num_hidden_layers = 64 + self.dtype = torch.float16 + + cache = MambaCache(_config(), max_batch_size=1, device="cpu") + + with bypass_export_some_errors(): + values, spec = py_pytree.tree_flatten(cache) + cache2 = py_pytree.tree_unflatten(values, spec) + self.assertEqual(cache.dtype, cache2.dtype) + self.assertEqual(cache.max_batch_size, cache2.max_batch_size) + self.assertEqual(cache.intermediate_size, cache2.intermediate_size) + self.assertEqual(cache.ssm_state_size, cache2.ssm_state_size) + self.assertEqual(cache.conv_kernel_size, cache2.conv_kernel_size) + self.assertEqualArrayAny(cache.conv_states, cache2.conv_states) + self.assertEqualArrayAny(cache.ssm_states, cache2.ssm_states) + + @requires_transformers("4.43") + @requires_torch("2.7") + @skipif_ci_windows("not working on Windows") + @ignore_warnings(UserWarning) + def test_exportable_mamba_cache(self): + import torch + from transformers.models.mamba.modeling_mamba import MambaCache + + class _config: + def __init__(self): + self.intermediate_size = 8 + self.state_size = 16 + self.conv_kernel = 32 + self.num_hidden_layers = 64 + self.dtype = torch.float16 + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor, cache: MambaCache): + x1 = cache.ssm_states[0] + x + x2 = cache.conv_states[0][:, :, ::2] + x1 + return x2 + + cache = MambaCache(_config(), max_batch_size=1, device="cpu") + self.assertEqual( + string_type(cache), "MambaCache(conv_states=[T10r3,...], ssm_states=[T10r3,...])" + ) + x = torch.ones(2, 8, 16).to(torch.float16) + model = Model() + model(x, cache) + + with bypass_export_some_errors(): + cache = MambaCache(_config(), max_batch_size=1, device="cpu") + torch.export.export(Model(), (x, cache)) + + @requires_transformers("4.49.999") + @skipif_ci_windows("not working on Windows") + @ignore_warnings(UserWarning) + def test_exportable_mamba_cache_dynamic(self): + import torch + from transformers.models.mamba.modeling_mamba import MambaCache + + class _config: + def __init__(self): + self.intermediate_size = 8 + self.state_size = 16 + self.conv_kernel = 32 + self.num_hidden_layers = 2 + self.dtype = torch.float16 + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor, cache: MambaCache): + x1 = cache.ssm_states[0] + x + x2 = cache.conv_states[0][:, :, ::2] + x1 + return x2 + + cache = MambaCache(_config(), max_batch_size=1, device="cpu") + self.assertEqual( + string_type(cache), + "MambaCache(conv_states=#2[T10r3,T10r3], ssm_states=#2[T10r3,T10r3])", + 
) + x = torch.ones(2, 8, 16).to(torch.float16) + model = Model() + model(x, cache) + DYN = torch.export.Dim.DYNAMIC + + with bypass_export_some_errors(): + cache = MambaCache(_config(), max_batch_size=1, device="cpu") + torch.export.export( + Model(), + (x, cache), + dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]), + ) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_models/test_llms.py b/_unittests/ut_torch_models/test_llms.py new file mode 100644 index 00000000..8814cb1c --- /dev/null +++ b/_unittests/ut_torch_models/test_llms.py @@ -0,0 +1,38 @@ +import unittest +import torch +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings +from onnx_diagnostic.torch_models.llms import get_tiny_llm +from onnx_diagnostic.helpers import string_type +from onnx_diagnostic.torch_export_patches import bypass_export_some_errors + + +class TestLlms(ExtTestCase): + def test_get_tiny_llm(self): + data = get_tiny_llm() + model, inputs = data["model"], data["inputs"] + self.assertIn("DynamicCache", string_type(inputs)) + model(**inputs) + + @ignore_warnings(UserWarning) + def test_export_tiny_llm_1(self): + data = get_tiny_llm() + model, inputs = data["model"], data["inputs"] + ep = torch.export.export( + model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"] + ) + assert ep + print(ep) + + @ignore_warnings(UserWarning) + def test_export_tiny_llm_2_bypassed(self): + data = get_tiny_llm() + model, inputs = data["model"], data["inputs"] + with bypass_export_some_errors(): + ep = torch.export.export( + model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"] + ) + assert ep + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py index 62354666..d27b112f 100644 --- a/_unittests/ut_xrun_doc/test_documentation_examples.py +++ b/_unittests/ut_xrun_doc/test_documentation_examples.py @@ -5,7 +5,8 @@ import subprocess import time from onnx_diagnostic import __file__ as onnx_diagnostic_file -from onnx_diagnostic.ext_test_case import ExtTestCase, is_windows +from onnx_diagnostic.ext_test_case import ExtTestCase, is_windows, has_transformers + VERBOSE = 0 ROOT = os.path.realpath(os.path.abspath(os.path.join(onnx_diagnostic_file, "..", ".."))) @@ -69,6 +70,13 @@ def add_test_methods(cls): continue reason = None + if ( + not reason + and name in {"plot_export_tiny_llm.py"} + and not has_transformers("4.51") + ): + reason = "transformers<4.51" + if reason: @unittest.skip(reason) diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py index 3be0631a..ffa37269 100644 --- a/onnx_diagnostic/ext_test_case.py +++ b/onnx_diagnostic/ext_test_case.py @@ -774,13 +774,21 @@ def requires_sklearn(version: str, msg: str = "") -> Callable: def has_torch(version: str) -> bool: - "Returns True if torch verions is higher." + "Returns True if torch version is higher." import packaging.version as pv import torch return pv.Version(torch.__version__) >= pv.Version(version) +def has_transformers(version: str) -> bool: + "Returns True if transformers version is higher."
+ import packaging.version as pv + import transformers + + return pv.Version(transformers.__version__) >= pv.Version(version) + + def requires_torch(version: str, msg: str = "") -> Callable: """Skips a unit test if :epkg:`pytorch` is not recent enough.""" import packaging.version as pv diff --git a/onnx_diagnostic/torch_export_patches/__init__.py b/onnx_diagnostic/torch_export_patches/__init__.py new file mode 100644 index 00000000..ff978ee3 --- /dev/null +++ b/onnx_diagnostic/torch_export_patches/__init__.py @@ -0,0 +1,4 @@ +from .onnx_export_errors import ( + bypass_export_some_errors, + register_additional_serialization_functions, +) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py new file mode 100644 index 00000000..0fbc0da6 --- /dev/null +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -0,0 +1,494 @@ +import contextlib +import pprint +from typing import Any, Callable, Dict +from .onnx_export_serialization import ( + flatten_with_keys_dynamic_cache, + flatten_dynamic_cache, + unflatten_dynamic_cache, + unflatten_pached_dynamic_cache, + flatten_mamba_cache, + flatten_with_keys_mamba_cache, + unflatten_mamba_cache, +) + + +def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: + # Cache serialization: to be moved into appropriate packages + import torch + + try: + from transformers.cache_utils import DynamicCache + except ImportError: + DynamicCache = None + + try: + from transformers.cache_utils import MambaCache + except ImportError: + MambaCache = None + + # MambaCache + unregistered_mamba_cache = True + if MambaCache is not None and MambaCache in torch.utils._pytree.SUPPORTED_NODES: + if verbose > 1: + print(f"[_register_cache_serialization] {MambaCache} already registered") + # It is already registered because bypass_export_some_errors was called + # within a section already calling bypass_export_some_errors or transformers + # has updated its code to do it. + # No need to register and unregister then. 
+ unregistered_mamba_cache = False + else: + if verbose: + print("[_register_cache_serialization] register MambaCache") + torch.utils._pytree.register_pytree_node( + MambaCache, + flatten_mamba_cache, + unflatten_mamba_cache, + serialized_type_name=f"{MambaCache.__module__}.{MambaCache.__name__}", + flatten_with_keys_fn=flatten_with_keys_mamba_cache, + ) + + # DynamicCache + unregistered_dynamic_cache = True + if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES: + if verbose > 1: + print(f"[_register_cache_serialization] {DynamicCache} already registered") + unregistered_dynamic_cache = False + else: + if verbose: + print("[_register_cache_serialization] register DynamicCache") + torch.utils._pytree.register_pytree_node( + DynamicCache, + flatten_dynamic_cache, + unflatten_dynamic_cache, + serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", + flatten_with_keys_fn=flatten_with_keys_dynamic_cache, + ) + torch.fx._pytree.register_pytree_flatten_spec( + DynamicCache, lambda x, _: [x.key_cache, x.value_cache] + ) + + # check + from ..cache_helpers import make_dynamic_cache + + cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) + values, spec = torch.utils._pytree.tree_flatten(cache) + cache2 = torch.utils._pytree.tree_unflatten(values, spec) + # torch.fx._pytree.tree_flatten(cache) + assert len(cache2.key_cache) == 1 + + # patched_DynamicCache + from .patches.patch_transformers import patched_DynamicCache + + unregistered_patched_dynamic_cache = True + if ( + patched_DynamicCache is not None + and patched_DynamicCache in torch.utils._pytree.SUPPORTED_NODES + ): + if verbose > 1: + print(f"[_register_cache_serialization] {patched_DynamicCache} already registered") + unregistered_patched_dynamic_cache = False + else: + if verbose: + print("[_register_cache_serialization] register patched_DynamicCache") + + torch.utils._pytree.register_pytree_node( + patched_DynamicCache, + flatten_dynamic_cache, + unflatten_pached_dynamic_cache, + serialized_type_name=f"{patched_DynamicCache.__module__}.{patched_DynamicCache.__name__}", + flatten_with_keys_fn=flatten_with_keys_dynamic_cache, + ) + torch.fx._pytree.register_pytree_flatten_spec( + patched_DynamicCache, lambda x, _: [x.key_cache, x.value_cache] + ) + + return dict( + DynamicCache=unregistered_dynamic_cache, + MambaCache=unregistered_mamba_cache, + patched_DynamicCache=unregistered_patched_dynamic_cache, + ) + + +def _unregister(cls: type, verbose: int = 0): + import optree + import torch + + # torch.fx._pytree._deregister_pytree_flatten_spec(cls) + if cls in torch.fx._pytree.SUPPORTED_NODES: + del torch.fx._pytree.SUPPORTED_NODES[cls] + if cls in torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH: + del torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH[cls] + if hasattr(torch.utils._pytree, "_deregister_pytree_node"): + # torch >= 2.7 + torch.utils._pytree._deregister_pytree_node(cls) + optree.unregister_pytree_node(cls, namespace="torch") + assert cls not in torch.utils._pytree.SUPPORTED_NODES, ( + f"{cls} was not successfull unregistered " + f"from torch.utils._pytree.SUPPORTED_NODES=" + f"{pprint.pformat(list(torch.utils._pytree.SUPPORTED_NODES))}" + ) + if verbose: + print(f"[_unregister_cache_serialization] unregistered {cls.__name__}") + + +def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): + + if undo.get("MambaCache", False): + from transformers.cache_utils import MambaCache + + _unregister(MambaCache, verbose) + elif verbose > 1: + 
print("[_unregister_cache_serialization] skip unregister MambaCache") + + if undo.get("DynamicCache", False): + from transformers.cache_utils import DynamicCache + + _unregister(DynamicCache, verbose) + elif verbose > 1: + print("[_unregister_cache_serialization] skip unregister DynamicCache") + + if undo.get("patched_DynamicCache", False): + from .patches.patch_transformers import patched_DynamicCache + + _unregister(patched_DynamicCache, verbose) + elif verbose > 1: + print("[_unregister_cache_serialization] skip unregister patched_DynamicCache") + + +@contextlib.contextmanager +def register_additional_serialization_functions( + verbose: int = 0, replace_dynamic_cache: bool = False +) -> Callable: + """The necessary modification to run the fx Graph.""" + fct_callable = replacement_before_exporting if replace_dynamic_cache else (lambda x: x) + done = _register_cache_serialization(verbose=verbose) + try: + yield fct_callable + finally: + _unregister_cache_serialization(done, verbose=verbose) + + +@contextlib.contextmanager +def bypass_export_some_errors( + patch_sympy: bool = True, + patch_torch: bool = True, + patch_transformers: bool = False, + replace_dynamic_cache: bool = False, + catch_constraints: bool = True, + verbose: int = 0, + patch: bool = True, +) -> Callable: + """ + Tries to bypass some situations :func:`torch.export.export` does not support. + + :param patch_sympy: fix missing method ``name`` for IntegerConstant + :param patch_torch: patches :epkg:`torch` with supported implementation + :param patch_transformers: patches :epkg:`transformers` with supported implementation + :param replace_dynamic_cache: replaces DynamicCache by a patched class + avoiding issues with the dynamic shapes inferences, + it should be True with LLM using that class and only during the export + :param catch_constraints: catch constraints related to dynamic shapes, + as a result, some dynamic dimension may turn into static ones, + the environment variable ``SKIP_SOLVE_CONSTRAINTS=0`` + can be put to stop at that stage. + :param patch: if False, disable all patches except the registration of + serialization function + + The list of available patches. + + * ``torch.jit.isinstance`` + * ``torch._dynamo.mark_static_address`` + * ``torch._subclasses.fake_impls.infer_size`` + * fix missing method ``name`` for ``sympy.S.IntegerConstant`` + * ``AttentionMaskConverter._make_causal_mask`` + * Serialialization of ``MambaCache`` (in :epkg:`transformers`) + * Serialialization of ``DynamicCache`` (in :epkg:`transformers`) + * reduce errors due to shape inference + * replaces :class:`transformers.cache_utils.DynamicCache` with + :class:`patched_DynamicCache + ` + + Serialization issues happen when a module takes one input or output + has a type :func:`torch.export.export` cannot serialize. + + Examples: + + :: + + with bypass_export_some_errors( + patch_transformers=True, + replace_dynamic_cache=True, + ) as modificator: + inputs = modificator(inputs) + onx = to_onnx(..., inputs, ...) + + :: + + with bypass_export_some_errors( + patch_transformers=True, + replace_dynamic_cache=True, + ) as modificator: + inputs = modificator(inputs) + onx = torch.onnx.export(..., inputs, ...) + + It can be used as well to fix the torch export: + + :: + + with bypass_export_some_errors( + patch_transformers=True, + replace_dynamic_cache=True, + ) as modificator: + inputs = modificator(inputs) + ep = torch.export.export(..., inputs, ...) 
+ + When running the model through the exported program, only the + serialization functions need to be restored: + + :: + + with register_additional_serialization_functions() as modificator: + inputs = modificator(inputs) + ep = torch.export.export(..., inputs, ...) + + When exporting a model with a cache, the following error message + may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``. + It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`. + """ + if not patch: + fct_callable = replacement_before_exporting if replace_dynamic_cache else (lambda x: x) + done = _register_cache_serialization(verbose=verbose) + try: + yield fct_callable + finally: + _unregister_cache_serialization(done, verbose=verbose) + else: + import torch + import torch._export.non_strict_utils # produce_guards_and_solve_constraints + import torch.jit + + if verbose: + print( + "[bypass_export_some_errors] replace torch.jit.isinstance, " + "torch._dynamo.mark_static_address" + ) + + ######## + # caches + ######## + + cache_done = _register_cache_serialization(verbose=verbose) + + ############# + # patch sympy + ############# + + if patch_sympy: + import sympy + + f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None) + + if verbose: + print("[bypass_export_some_errors] patch sympy") + + sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}" + + ############### + # patch pytorch + ############### + + if patch_torch: + from .patches.patch_torch import ( + patched_infer_size, + patched__broadcast_shapes, + _catch_produce_guards_and_solve_constraints, + patch__check_input_constraints_for_graph, + ) + + if verbose: + print("[bypass_export_some_errors] patch pytorch") + + # torch.jit.isinstance + f_jit_isinstance = torch.jit.isinstance + torch.jit.isinstance = isinstance + + # torch._dynamo.mark_static_address + f_mark_static_address = torch._dynamo.mark_static_address + torch._dynamo.mark_static_address = lambda *_, **y_: None + + # torch._subclasses.fake_impls.infer_size + f_infer_size = torch._subclasses.fake_impls.infer_size + torch._subclasses.fake_impls.infer_size = patched_infer_size + + # torch._refs._broadcast_shapes + f__broadcast_shapes = torch._refs._broadcast_shapes + torch._refs._broadcast_shapes = patched__broadcast_shapes + torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes + + # torch._export.non_strict_utils.produce_guards_and_solve_constraints + if catch_constraints: + if verbose: + print("[bypass_export_some_errors] modifies shape constraints") + f_produce_guards_and_solve_constraints = ( + torch._export.non_strict_utils.produce_guards_and_solve_constraints + ) + f__check_input_constraints_for_graph = ( + torch._export.utils._check_input_constraints_for_graph + ) + torch._export.non_strict_utils.produce_guards_and_solve_constraints = ( + lambda *args, **kwargs: _catch_produce_guards_and_solve_constraints( + f_produce_guards_and_solve_constraints, *args, verbose=verbose, **kwargs + ) + ) + torch._export.utils._check_input_constraints_for_graph = ( + lambda *args, **kwargs: patch__check_input_constraints_for_graph( + f__check_input_constraints_for_graph, *args, verbose=verbose, **kwargs + ) + ) + + #################### + # patch transformers + #################### + + if patch_transformers: + import transformers + from transformers.modeling_attn_mask_utils import AttentionMaskConverter + from .patches.patch_transformers import patched_AttentionMaskConverter + + if verbose: + 
print("[bypass_export_some_errors] patch transformers") + keep__make_causal_mask = AttentionMaskConverter._make_causal_mask + AttentionMaskConverter._make_causal_mask = ( + patched_AttentionMaskConverter._make_causal_mask + ) + + if replace_dynamic_cache: + import transformers + from transformers.modeling_attn_mask_utils import AttentionMaskConverter + from .patches.patch_transformers import patched_DynamicCache + + def raise_assert(): + raise AssertionError("One replacement of DynamicCache was not patched.") + + if verbose: + print("[bypass_export_some_errors] replace DynamicCache") + keep_DynamicCache = transformers.cache_utils.DynamicCache + keep_DynamicCache_init = keep_DynamicCache.__init__ + keep_DynamicCache.__init__ = lambda *args, **kwargs: raise_assert() + + transformers.cache_utils.DynamicCache = patched_DynamicCache + transformers.generation.utils.DynamicCache = patched_DynamicCache + transformers.models.llama.modeling_llama.DynamicCache = patched_DynamicCache + transformers.models.phi.modeling_phi.DynamicCache = patched_DynamicCache + transformers.models.phi3.modeling_phi3.DynamicCache = patched_DynamicCache + + ######## + # export + ######## + + fct_callable = replacement_before_exporting if replace_dynamic_cache else (lambda x: x) + + if verbose: + print("[bypass_export_some_errors] done patching") + + try: + yield fct_callable + finally: + ####### + # sympy + ####### + + if verbose: + print("[bypass_export_some_errors] remove patches") + + if patch_sympy: + + # tracked by https://github.com/pytorch/pytorch/issues/143494 + if f_sympy_name: + sympy.core.numbers.IntegerConstant.name = f_sympy_name + else: + delattr(sympy.core.numbers.IntegerConstant, "name") + + if verbose: + print("[bypass_export_some_errors] restored sympy functions") + + ####### + # torch + ####### + + if patch_torch: + # this should disappear when torch.jit is removed + torch.jit.isinstance = f_jit_isinstance + torch._dynamo.mark_static_address = f_mark_static_address + # tracked by https://github.com/pytorch/pytorch/issues/143495 + torch._subclasses.fake_impls.infer_size = f_infer_size + torch._refs._broadcast_shapes = f__broadcast_shapes + torch._meta_registrations._broadcast_shapes = f__broadcast_shapes + + if verbose: + print("[bypass_export_some_errors] restored pytorch functions") + + if catch_constraints: + # to catch or skip dynamic_shapes issues + torch._export.non_strict_utils.produce_guards_and_solve_constraints = ( + f_produce_guards_and_solve_constraints + ) + torch._export.utils._check_input_constraints_for_graph = ( + f__check_input_constraints_for_graph + ) + if verbose: + print("[bypass_export_some_errors] restored shape constraints") + + ############## + # transformers + ############## + + if patch_transformers: + AttentionMaskConverter._make_causal_mask = keep__make_causal_mask + if verbose: + print("[bypass_export_some_errors] restored transformer") + + if replace_dynamic_cache: + keep_DynamicCache.__init__ = keep_DynamicCache_init + transformers.cache_utils.DynamicCache = keep_DynamicCache + transformers.generation.utils.DynamicCache = keep_DynamicCache + transformers.models.llama.modeling_llama.DynamicCache = keep_DynamicCache + transformers.models.phi.modeling_phi.DynamicCache = keep_DynamicCache + transformers.models.phi3.modeling_phi3.DynamicCache = keep_DynamicCache + if verbose: + print("[bypass_export_some_errors] restored DynamicCache") + + ######## + # caches + ######## + + _unregister_cache_serialization(cache_done, verbose=verbose) + + +def replacement_before_exporting(args: 
Any) -> Any: + """ + Does replacements on the given inputs such replacing + :class:`transformers.cache_utils.DynamicCache` by + :class:`onnx_diagnostic.torch_export_patches.patches.patch_transformers.patched_DynamicCache`. + """ + if args is None: + return None + if isinstance(args, (int, float)): + return args + if isinstance(args, dict): + return {k: replacement_before_exporting(v) for k, v in args.items()} + if isinstance(args, tuple): + return tuple(replacement_before_exporting(v) for v in args) + if isinstance(args, list): + return [replacement_before_exporting(v) for v in args] + + if args.__class__.__name__ == "DynamicCache": + # Do not use isinstance, the class may have been replaced. + from .patches.patch_transformers import patched_DynamicCache + + patched = patched_DynamicCache() + for k in ["_seen_tokens", "key_cache", "value_cache"]: + setattr(patched, k, getattr(args, k)) + return patched + + return args diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py new file mode 100644 index 00000000..86d5d2fe --- /dev/null +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -0,0 +1,149 @@ +from typing import Any, Dict, List, Tuple +import torch +import transformers + +############ +# MambaCache +############ + + +# self.conv_states: torch.Tensor = torch.zeros( +# config.num_hidden_layers, +# self.max_batch_size, +# self.intermediate_size, +# self.conv_kernel_size, +# device=device, +# dtype=dtype, +# ) +# self.ssm_states: torch.Tensor = torch.zeros( +# config.num_hidden_layers, +# self.max_batch_size, +# self.intermediate_size, +# self.ssm_state_size, +# device=device, +# dtype=dtype, +# ) +def flatten_mamba_cache( + mamba_cache: transformers.cache_utils.MambaCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" + flat = [ + (k, getattr(mamba_cache, k)) + for k in [ + # "max_batch_size", # new in transformers==4.47 + # "intermediate_size", + # "ssm_state_size", + # "conv_kernel_size", + "conv_states", + "ssm_states", + ] + if hasattr(mamba_cache, k) + ] + return [f[1] for f in flat], [f[0] for f in flat] + + +def unflatten_mamba_cache( + values: List[Any], + context: torch.utils._pytree.Context, + output_type=None, +) -> transformers.cache_utils.MambaCache: + """Restores a :class:`transformers.cache_utils.MambaCache` from python objects.""" + conv_states, ssm_states = values + + class _config: + def __init__(self): + if isinstance(conv_states, list): + self.intermediate_size = conv_states[0].shape[1] + self.state_size = ssm_states[0].shape[2] + self.conv_kernel = conv_states[0].shape[2] + self.num_hidden_layers = len(conv_states) + else: + self.intermediate_size = conv_states.shape[2] + self.state_size = ssm_states.shape[3] + self.conv_kernel = conv_states.shape[3] + self.num_hidden_layers = conv_states.shape[0] + + from transformers.cache_utils import MambaCache + + cache = MambaCache( + _config(), + max_batch_size=1, + dtype=values[-1][0].dtype, + device="cpu" if values[-1][0].get_device() < 0 else "cuda", + ) + values = dict(zip(context, values)) + for k, v in values.items(): + setattr(cache, k, v) + return cache + + +def flatten_with_keys_mamba_cache(d: Dict[Any, Any]) -> Tuple[ + List[Tuple[torch.utils._pytree.KeyEntry, Any]], + torch.utils._pytree.Context, +]: + """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" + import torch + + values, 
context = flatten_mamba_cache(d) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +############## +# DynamicCache +############## + + +def flatten_dynamic_cache( + dynamic_cache: transformers.cache_utils.DynamicCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + flat = [ + (k, getattr(dynamic_cache, k)) + for k in ["key_cache", "value_cache"] + if hasattr(dynamic_cache, k) + ] + return [f[1] for f in flat], [f[0] for f in flat] + + +def flatten_with_keys_dynamic_cache(d: Dict[Any, Any]) -> Tuple[ + List[Tuple[torch.utils._pytree.KeyEntry, Any]], + torch.utils._pytree.Context, +]: + """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + import torch + + values, context = flatten_dynamic_cache(d) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +def unflatten_dynamic_cache( + values: List[Any], + context: torch.utils._pytree.Context, + output_type=None, +) -> transformers.cache_utils.DynamicCache: + """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects.""" + from transformers.cache_utils import DynamicCache + + cache = DynamicCache() + values = dict(zip(context, values)) + for k, v in values.items(): + setattr(cache, k, v) + return cache + + +def unflatten_pached_dynamic_cache( + values: List[Any], + context: torch.utils._pytree.Context, + output_type=None, +) -> transformers.cache_utils.DynamicCache: + """Restores a :class:`patched_DynamicCache + ` + from python objects.""" + + from .patches.patch_transformers import patched_DynamicCache + + cache = patched_DynamicCache() + values = dict(zip(context, values)) + for k, v in values.items(): + setattr(cache, k, v) + return cache diff --git a/onnx_diagnostic/torch_export_patches/patches/__init__.py b/onnx_diagnostic/torch_export_patches/patches/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py new file mode 100644 index 00000000..586cfb3a --- /dev/null +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -0,0 +1,148 @@ +import inspect +import os +from typing import Any, Callable, Dict, List, Sequence, Tuple, Union +import torch +from torch._subclasses.fake_tensor import FakeTensorMode + + +def _catch_produce_guards_and_solve_constraints( + previous_function: Callable, + fake_mode: "FakeTensorMode", # noqa: F821 + gm: "torch.fx.GraphModule", # noqa: F821 + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], + equalities_inputs: "EqualityConstraint", # noqa: F821 + original_signature: inspect.Signature, + _is_torch_jit_trace: bool = False, + verbose: int = 0, +): + try: + return previous_function( + fake_mode=fake_mode, + gm=gm, + dynamic_shapes=dynamic_shapes, + equalities_inputs=equalities_inputs, + original_signature=original_signature, + _is_torch_jit_trace=_is_torch_jit_trace, + ) + except Exception as e: + if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")): + raise + if verbose: + print( + f"[_catch_produce_guards_and_solve_constraints] ERROR" + f"produce_guards_and_solve_constraints failed, " + f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n" + f"fake_mode={fake_mode}\n" + f"dynamic_shapes={dynamic_shapes}\n" + f"equalities_inputs={equalities_inputs}\n" + f"original_signature={original_signature}\n" + 
f"_is_torch_jit_trace={_is_torch_jit_trace}\n" + f"exc={e}\ngm={gm}" + ) + + +def patch__check_input_constraints_for_graph( + previous_function: Callable, + input_placeholders: list[torch.fx.Node], + flat_args_with_path, + range_constraints, + verbose: int = 0, +) -> None: + try: + return previous_function(input_placeholders, flat_args_with_path, range_constraints) + except Exception as e: + if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")): + raise + if verbose: + print( + f"[_check_input_constraints_for_graph] ERROR" + f"_check_input_constraints_for_graph failed, " + f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n" + f"input_placeholders={input_placeholders}\n" + f"range_constraints={range_constraints}\n" + f"exc={e}" + ) + + +def patched_infer_size(a, b): + """Patches ``torch._subclasses.fake_impls.infer_size``.""" + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + dimsA = len(a) + dimsB = len(b) + ndim = max(dimsA, dimsB) + expandedSizes = [0] * ndim + for i in range(ndim - 1, -1, -1): + offset = ndim - 1 - i + dimA = dimsA - 1 - offset + dimB = dimsB - 1 - offset + sizeA = a[dimA] if dimA >= 0 else 1 + sizeB = b[dimB] if dimB >= 0 else 1 + + # NB: It is very important to test for broadcasting, before testing + # sizeA == sizeB. This is because the broadcasting tests are likely + # to be statically known (in particular, if sizeA/sizeB is unbacked + # but size-like, we will unsoundly assume they never equal 1), but + # the sizeA == sizeB test may not be statically known. However, once + # we have established that no broadcasting is happening, the + # sizeA == sizeB is now expect_true and we can defer it as a runtime + # assert (this works because Python will return the terminal + # expression of an or statement as-is, without bool()'ing it; if this + # were not the case, we'd need to write this using torch.sym_or() or + # something like that). + try: + b1 = guard_size_oblivious(sizeA == 1) + except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: + b1 = False + try: + b2 = guard_size_oblivious(sizeB == 1) + except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: + b2 = False + try: + b3 = guard_size_oblivious(sizeA == sizeB) + except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: + b3 = False + if b1 or b2 or b3: + expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA + else: + # In this case, the current implementation of torch fails (17/12/2024). + # Try model SmolLM. + expandedSizes[i] = torch.sym_max(sizeA, sizeB) + return tuple(expandedSizes) + + +def patched__broadcast_shapes(*_shapes): + """Patches ``torch._refs._broadcast_shapes``.""" + from functools import reduce + from torch._prims_common import IntLike + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + shapes = tuple( + (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes) + ) + + # Short-circuits on no input + if len(shapes) == 0: + return None + + # Type checking + # TODO: make common validations available as utils + for shape in shapes: + assert isinstance(shape, Sequence) + + # Computes common shape + common_shape: List[Union[int, torch.SymInt]] = [ + 1, + ] * reduce(max, (len(shape) for shape in shapes)) + for _arg_idx, shape in enumerate(shapes): + for idx in range(-1, -1 - len(shape), -1): + if guard_size_oblivious(common_shape[idx] == 1): + if shape[idx] < 0: + raise ValueError( + "Attempting to broadcast a dimension with negative length!" 
+ ) + common_shape[idx] = shape[idx] + elif guard_size_oblivious(shape[idx] != 1): + common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx]) + + return common_shape diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py new file mode 100644 index 00000000..614abca3 --- /dev/null +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -0,0 +1,246 @@ +import sys +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple +import torch +import transformers + + +def _patch_make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, +): + """Patched method.""" + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), + mask, + ], + dim=-1, + ) + + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window - 1 + + context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) + # In this case, the current implementation of torch fails (17/12/2024). + # Try model Phi-3.5-Mini-Instruct. + mask = mask.masked_fill(context_mask, torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +if sys.version_info[:2] <= (3, 11): + + @dataclass + class patched_AttentionMaskConverter: + """ + Patches + ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. + """ + + @staticmethod + def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """Patched method.""" + return _patch_make_causal_mask( + input_ids_shape, dtype, device, past_key_values_length, sliding_window + ) + +else: + + @dataclass + class patched_AttentionMaskConverter: + """ + Patches + ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. + """ + + @staticmethod + def _make_causal_mask( + self, + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """Patched method.""" + return _patch_make_causal_mask( + input_ids_shape, dtype, device, past_key_values_length, sliding_window + ) + + +class patched_DynamicCache: + """ + Removes the dependency on :class:`torch.nn.Module` + from :class:`transformers.cache_utils.DynamicCache`. 
+ """ + + def __init__(self, num_hidden_layers: Optional[int] = None) -> None: + self._seen_tokens = 0 + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx]) + else: + raise KeyError( + f"Cache only has {len(self)} layers, " + f"attempted to access layer with index {layer_idx}" + ) + + def __iter__(self): + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + + def __len__(self): + return len(self.key_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if key_states is not None: + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append(torch.tensor([])) + self.value_cache.append(torch.tensor([])) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + # elif ( + # len(self.key_cache[layer_idx]) == 0 + # ): # fills previously skipped layers; checking for tensor causes errors + # self.key_cache[layer_idx] = key_states + # self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx], key_states], dim=-2 + ) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=-2 + ) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + if not self.key_cache: + return 0 + assert layer_idx < len( + self.key_cache + ), f"Unexpected layer_idx={layer_idx}, len(key_cache)={len(self.key_cache)}" + return self.key_cache[layer_idx].shape[-2] + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + return self.get_seq_length(layer_idx) + + def get_max_cache_shape(self) -> Optional[int]: + return None + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, + past_key_values: Optional[Tuple[Tuple["torch.Tensor"]]] = None, + num_hidden_layers: Optional[int] = None, + ) -> transformers.cache_utils.DynamicCache: + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def crop(self, max_length: int): + # In case it is negative + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + self._seen_tokens = max_length + for idx in range(len(self.key_cache)): + if self.key_cache[idx] != []: + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + + def batch_split( + self, full_batch_size: int, split_size: int, num_hidden_layers: Optional[int] = None + ) -> List[transformers.cache_utils.DynamicCache]: + out = [] + for i 
in range(0, full_batch_size, split_size): + current_split = patched_DynamicCache() + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [ + tensor[i : i + split_size] for tensor in self.value_cache + ] + out.append(current_split) + return out + + @classmethod + def from_batch_splits( + cls, + splits: List[transformers.cache_utils.DynamicCache], + num_hidden_layers: Optional[int] = None, + ) -> transformers.cache_utils.DynamicCache: + cache = cls() + for idx in range(len(splits[0])): + key_cache = [ + current.key_cache[idx] for current in splits if current.key_cache[idx] != [] + ] + value_cache = [ + current.value_cache[idx] + for current in splits + if current.value_cache[idx] != [] + ] + if key_cache != []: + layer_keys = torch.cat(key_cache, dim=0) + layer_values = torch.cat(value_cache, dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + def batch_repeat_interleave(self, repeats: int): + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave( + repeats, dim=0 + ) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave( + repeats, dim=0 + ) + + def batch_select_indices(self, indices: torch.Tensor): + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] diff --git a/onnx_diagnostic/torch_models/__init__.py b/onnx_diagnostic/torch_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/onnx_diagnostic/torch_models/llms.py b/onnx_diagnostic/torch_models/llms.py new file mode 100644 index 00000000..17207ddf --- /dev/null +++ b/onnx_diagnostic/torch_models/llms.py @@ -0,0 +1,96 @@ +from typing import Any, Dict +import torch +import transformers +from ..cache_helpers import make_dynamic_cache + + +def get_tiny_llm( + batch_size: int = 2, + input_cache: bool = True, + dynamic_rope: bool = False, + **kwargs, +) -> Dict[str, Any]: + """ + Gets a non initialized model. + + :param batch_size: batch size + :param input_cache: generate data for this iteration with or without cache + :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`) + :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1`` + :return: dictionary + + See :ref:`l-plot-tiny-llm-export` for an example. 
+ """ + config = { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 192, + "initializer_range": 0.02, + "intermediate_size": 1024, + "max_position_embeddings": 1024, + "model_type": "llama", + "num_attention_heads": 2, + "num_hidden_layers": 1, + "num_key_value_heads": 1, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": {"rope_type": "dynamic", "factor": 10.0} if dynamic_rope else None, + "tie_word_embeddings": False, + "torch_dtype": "float32", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + } + + config.update(**kwargs) + conf = transformers.LlamaConfig(**config) + model = transformers.LlamaForCausalLM(conf) + model.eval() + + # now the inputs + cache_last_dim = 96 + sequence_length = 30 + sequence_length2 = 3 + num_key_value_heads = 1 + max_token_id = config["vocab_size"] - 1 + n_layers = config["num_hidden_layers"] + + batch = torch.export.Dim("batch", min=1, max=1024) + seq_length = torch.export.Dim("seq_length", min=1, max=4096) + cache_length = torch.export.Dim("cache_length", min=1, max=4096) + + shapes = { + "input_ids": {0: batch, 1: seq_length}, + "attention_mask": { + 0: batch, + 1: torch.export.Dim.DYNAMIC, # cache_length + seq_length + }, + "past_key_values": [ + [{0: batch, 2: cache_length} for _ in range(n_layers)], + [{0: batch, 2: cache_length} for _ in range(n_layers)], + ], + } + inputs = dict( + input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to( + torch.int64 + ), + attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to( + torch.int64 + ), + past_key_values=make_dynamic_cache( + [ + ( + torch.randn( + batch_size, num_key_value_heads, sequence_length, cache_last_dim + ), + torch.randn( + batch_size, num_key_value_heads, sequence_length, cache_last_dim + ), + ) + for i in range(n_layers) + ] + ), + ) + return dict(inputs=inputs, model=model, dynamic_shapes=shapes) diff --git a/pyproject.toml b/pyproject.toml index a454d696..0c09a1d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,14 @@ disable_error_code = ["arg-type", "assignment", "import-untyped", "misc", "name- module = ["onnx_diagnostic.ort_session"] disable_error_code = ["union-attr"] +[[tool.mypy.overrides]] +module = ["onnx_diagnostic.torch_export_patches.*"] +disable_error_code = ["arg-type", "assignment", "attr-defined", "index", "misc", "name-defined", "operator", "return-value"] + +[[tool.mypy.overrides]] +module = ["onnx_diagnostic.torch_models.*"] +disable_error_code = ["attr-defined", "call-overload", "operator"] + [tool.ruff] # Exclude a variety of commonly ignored directories. @@ -86,3 +94,5 @@ select = [ "_doc/notebooks/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"] "_doc/recipes/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"] "_unittests/*/test*.py" = ["B008", "B904", "PIE808", "SIM117", "SIM105", "UP008"] +"onnx_diagnostic/torch_export_patches/__init__.py" = ["F401"] +"onnx_diagnostic/torch_export_patches/patches/__init__.py" = ["F401"]