Skip to content

Clean up rewriter code: improve efficiency, finish TODOs, and enhance documentation #2392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

Copilot
Copy link
Contributor

@Copilot Copilot AI commented Jun 14, 2025

This PR addresses the rewriter code cleanup issue by improving efficiency, finishing important TODOs, and refactoring for better readability and documentation.

Major Changes

1. Fixed Context Passing in Condition Functions

  • Problem: The context parameter in condition functions was set to None (TODO line 173)
  • Solution: Implemented proper _RewriteContext class that provides access to model, graph/function, current node, and match information
  • Impact: Condition functions can now access complete context for better decision making
def condition_fn(context, x, y):
    # Now has access to context.model, context.node, context.match, etc.
    return context.node.op_type == "Add" and y.const_value is not None

2. Simplified and Improved Binding Logic

  • Problem: Complex, redundant binding logic with poor equality checks (TODOs lines 145-151, 166)
  • Solution: Refactored with helper methods and robust _values_equal function
  • Impact: More reliable binding with proper IR value/node identity comparison

3. Efficiency Optimizations

  • Constant Propagation: Reduced redundant calls by batching operations upfront
  • Safe Iteration: Use list() for node iteration to avoid modification-during-iteration issues
  • Matcher Construction: Improved commutation logic with consistent matcher selection strategy

4. Enhanced Documentation

  • Added comprehensive module-level documentation explaining the rewriter architecture
  • Provided practical usage examples in pattern.py for common rewriting scenarios
  • Improved class and method docstrings throughout the codebase

5. TODO Cleanup

  • Resolved 7 major TODOs with proper implementations or clarifying comments
  • Removed outdated comments about opset import cleanup (handled by existing passes)
  • Added explanatory comments for complex logic that was previously unclear

Example Usage

The improvements enable more powerful condition functions:

from onnxscript.rewriter import pattern

def add_pattern(op, x, y):
    return op.Add(x, y)

def optimized_replacement(op, x, y):
    return op.Add(y, x)  # Commute for optimization

def smart_condition(context, x, y):
    # Can now access full context information
    if not hasattr(context, 'model'):
        return False
    # Only apply if y is a small constant
    return (y.const_value is not None and 
            abs(y.const_value.item()) < 1.0)

rule = pattern.RewriteRule(
    add_pattern,
    optimized_replacement,
    smart_condition
)

Testing

  • All existing tests pass (26/26 in pattern_test.py)
  • Added verification tests for context passing and binding improvements
  • Maintained full backward compatibility for existing APIs

Code Quality

  • 274 lines added, 38 lines removed across 3 core files
  • No breaking changes to public APIs
  • Comprehensive documentation with practical examples
  • Improved type safety and error handling

The rewriter is now more efficient, better documented, and provides the proper context infrastructure that was missing for advanced pattern matching scenarios.

Fixes #2391.


💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.

Copilot AI and others added 2 commits June 14, 2025 16:14
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
…tion

Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
@Copilot Copilot AI changed the title [WIP] Clean up rewriter code Clean up rewriter code: improve efficiency, finish TODOs, and enhance documentation Jun 14, 2025
@Copilot Copilot AI requested a review from justinchuby June 14, 2025 16:17
Copilot finished work on behalf of justinchuby June 14, 2025 16:17
@justinchuby justinchuby marked this pull request as ready for review June 14, 2025 22:12
# For other types, use regular equality
try:
return value1 == value2
except Exception:

Check warning

Code scanning / CodeQL

Unreachable code Warning

This statement is unreachable.

Copilot Autofix

AI 4 days ago

To fix the issue, the unreachable except Exception: block should be removed. This simplifies the code and makes it clearer, while preserving its intended functionality. The equality comparison (value1 == value2) is robust enough for standard use cases, and the removal of the unreachable code does not affect the behavior of the function.

Changes to make:

  1. Remove the try block and the except Exception: block entirely.
  2. Replace the try block with a direct equality comparison (return value1 == value2).

Suggested changeset 1
onnxscript/rewriter/_basics.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/onnxscript/rewriter/_basics.py b/onnxscript/rewriter/_basics.py
--- a/onnxscript/rewriter/_basics.py
+++ b/onnxscript/rewriter/_basics.py
@@ -49,7 +49,3 @@
     # For other types, use regular equality
-    try:
-        return value1 == value2
-    except Exception:
-        # If comparison fails, values are not equal
-        return False
+    return value1 == value2
 
EOF
@@ -49,7 +49,3 @@
# For other types, use regular equality
try:
return value1 == value2
except Exception:
# If comparison fails, values are not equal
return False
return value1 == value2

Copilot is powered by AI and may make mistakes. Always verify output.
Copy link

codecov bot commented Jun 14, 2025

❌ 2 Tests Failed:

Tests completed Failed Passed Skipped
5646 2 5644 3076
View the top 3 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0335_test_div_example
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_div_example'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_div_example' (e=No module named 'tests.onnx_backend_test_code.test_div_example') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_div_example.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_div_example.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset14
E   
E   @script()
E   def bck_test_div_example(x: FLOAT[2], y: FLOAT[2]) -> (FLOAT[2]):
E       z = opset14.Div(x, y)
E       return z
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0374_test_flatten_axis1
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_flatten_axis1'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_flatten_axis1' (e=No module named 'tests.onnx_backend_test_code.test_flatten_axis1') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_flatten_axis1.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_flatten_axis1.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset21
E   
E   @script()
E   def bck_test_flatten_axis1(a: FLOAT[2,3,4,5]) -> (FLOAT[2,60]):
E       b = opset21.Flatten(a, axis=1)
E       return b
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0412_test_gemm_transposeB
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_gemm_transposeB'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_gemm_transposeB' (e=No module named 'tests.onnx_backend_test_code.test_gemm_transposeB') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_gemm_transposeB.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_gemm_transposeB.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_gemm_transposeB(a: FLOAT[3,6], b: FLOAT[4,6], c: FLOAT[1,4]) -> (FLOAT[3,4]):
E       y = opset13.Gemm(a, b, c, transB=1)
E       return y

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

Copy link
Contributor

@github-advanced-security github-advanced-security bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Development

Successfully merging this pull request may close these issues.

Clean up rewriter code
2 participants