diff --git a/_doc/patches.rst b/_doc/patches.rst index 287b38d7..8495732f 100644 --- a/_doc/patches.rst +++ b/_doc/patches.rst @@ -8,6 +8,7 @@ Function :func:`onnx_diagnostic.torch_export_patches.torch_export_patches` implements four kinds of patches to make it easier to export a model, usually coming from :epkg:`transformers`. All patches takes place in :mod:`onnx_diagnostic.torch_export_patches`. + .. code-block:: python with torch_export_patches(...) as f: @@ -121,13 +122,19 @@ requires the following value for parameter ``rewrite``: .. runpython:: :showcode: + import pprint from onnx_diagnostic.torch_export_patches.patch_module_helper import ( code_needing_rewriting, ) - print(code_needing_rewriting("BartForConditionalGeneration")) + pprint.pprint(code_needing_rewriting("BartForConditionalGeneration")) -And that produces: +This method has two tests. Only the first one needs to be rewritten. +The second one manipulates tuple and the automated rewritten does not handle +that because it cannot detect types. That explains why the parameter +``filter_node`` is filled. Then, the first test includes a condition relying on ``or`` +which must be replaced by ``|``. That explains the parameter ``pre_rewriter``. +We finally get: .. code-block:: diff diff --git a/onnx_diagnostic/torch_export_patches/patch_module_helper.py b/onnx_diagnostic/torch_export_patches/patch_module_helper.py index 3f9fd0c8..fbedb2a2 100644 --- a/onnx_diagnostic/torch_export_patches/patch_module_helper.py +++ b/onnx_diagnostic/torch_export_patches/patch_module_helper.py @@ -30,11 +30,12 @@ def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]: .. runpython:: :showcode: + import pprint from onnx_diagnostic.torch_export_patches.patch_module_helper import ( code_needing_rewriting, ) - print(code_needing_rewriting("BartForConditionalGeneration")) + pprint.pprint(code_needing_rewriting("BartForConditionalGeneration")) """ if cls_name in { "BartEncoderLayer",