Skip to content

Move gemm_to_matmul_add rule to ort fusion rules #2398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 19, 2025
Merged

Conversation

justinchuby
Copy link
Collaborator

Stop decomposing gemm to matmul add by default because it is a more compact representation. Move the ort fusion rules so it keeps functioning for ort.

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR moves the gemm_to_matmul_add rewrite rule out of the default optimization pipeline and into the ORT-specific fusion rules, ensuring it’s only applied when running optimize_for_ort.

  • Removed gemm_to_matmul_add from the default rewrites in onnxscript/rewriter/__init__.py.
  • Imported and applied gemm_to_matmul_add.rule in optimize_for_ort within onnxscript/rewriter/ort_fusions/_core.py.
  • Cleaned up the optimize tutorial docs, simplified the table format, and removed the default rule list.

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
onnxscript/rewriter/ort_fusions/_core.py Imported gemm_to_matmul_add, applied its rule in optimize_for_ort.
onnxscript/rewriter/init.py Removed gemm_to_matmul_add from default rewrite rules.
docs/tutorial/optimizer/optimize.md Reformatted optimization table and removed the list of default patterns.
Comments suppressed due to low confidence (1)

docs/tutorial/optimizer/optimize.md:28

  • [nitpick] The docs no longer mention that gemm_to_matmul_add has been removed from the default optimization pipeline. Consider adding a note under the API section to explain that this rule now only runs in ORT-specific optimizations.
| Optimization | Description |

Copy link

codecov bot commented Jun 18, 2025

❌ 10 Tests Failed:

Tests completed Failed Passed Skipped
16446 10 16436 2363
View the top 3 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0049_test_argmax_no_keepdims_example_select_last_index
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_argmax_no_keepdims_example_select_last_index'

The above exception was the direct cause of the following exception:
.nox\test_torch_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_argmax_no_keepdims_example_select_last_index' (e=No module named 'tests.onnx_backend_test_code.test_argmax_no_keepdims_example_select_last_index') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_argmax_no_keepdims_example_select_last_index.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_argmax_no_keepdims_example_select_last_index.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_argmax_no_keepdims_example_select_last_index(data: FLOAT[2,2]) -> (INT64[2]):
E       result = opset13.ArgMax(data, axis=1, keepdims=0, select_last_index=1)
E       return result
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1043_test_regex_full_match_email_domain
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_regex_full_match_email_domain'

The above exception was the direct cause of the following exception:
.nox\test_torch_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_regex_full_match_email_domain' (e=No module named 'tests.onnx_backend_test_code.test_regex_full_match_email_domain') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_regex_full_match_email_domain.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_regex_full_match_email_domain.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import BOOL, STRING
E   from onnxscript.onnx_opset import opset20
E   
E   @script()
E   def bck_test_regex_full_match_email_domain(X: STRING[2,2]) -> (BOOL[2,2]):
E       Y = opset20.RegexFullMatch(X, pattern='(\\W|^)[\\w.\\-]{0,25}@(yahoo|gmail)\\.com(\\W|$)')
E       return Y
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1160_test_scatternd_add
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_scatternd_add'

The above exception was the direct cause of the following exception:
.nox\test_torch_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_scatternd_add' (e=No module named 'tests.onnx_backend_test_code.test_scatternd_add') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_scatternd_add.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_scatternd_add.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset18
E   
E   @script()
E   def bck_test_scatternd_add(data: FLOAT[4,4,4], indices: INT64[2,1], updates: FLOAT[2,4,4]) -> (FLOAT[4,4,4]):
E       y = opset18.ScatterND(data, indices, updates, reduction='add')
E       return y

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

@justinchuby justinchuby merged commit e71c889 into main Jun 19, 2025
25 of 30 checks passed
@justinchuby justinchuby deleted the justinchu/gemm-rule branch June 19, 2025 04:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Development

Successfully merging this pull request may close these issues.

2 participants