Allow disabling assertion rewriting at a function level #12610

chaudhary1337 opened this issue Jul 15, 2024 · 0 comments
Labels
topic: rewrite (related to the assertion rewrite mechanism)
type: proposal (proposal for a new feature, often to gather opinions or design the API around the new feature)

Background

I was working on a function decorated with numba.jit(nopython=True) that contains an assert statement. When run under pytest, the test fails with numba errors that point to the assert statement, even though the logic is correct.

Example:

# test_numba.py

from numba import jit
import numpy as np


@jit(nopython=True)
def go_fast(a):
    trace = 0.0
    for i in range(a.shape[0]):
        trace += np.tanh(a[i, i])

    assert (a + trace).all()
    return a + trace


def test_go_fast():
    x = np.arange(4).reshape(2, 2)
    ans = go_fast(x).round(3)
    assert (
        ans.tolist() == [[0.995, 1.995], [2.995, 3.995]]
    )

What's the problem this feature will solve?

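pytest's assertion rewriting transforms the assert statements in a module before it is compiled to bytecode, and caches the result. numba also works from a function's bytecode, so a jitted function defined in a rewritten module ends up compiling the rewritten assert rather than the one in the source. This causes problems when running numba-decorated functions.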

Describe the solution you'd like

The current solutions are:

  1. Turn off rich pytest asserts for the entirety of the execution with --assert=plain.
  2. Turn off rich pytest asserts per module with PYTEST_DONT_REWRITE in the module docstring.
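
As a rough illustration of workaround (2), the sketch below shows what a module opting out looks like and approximates the check with the standard library only. It is not pytest's actual code: `module_opts_out`, `OPTED_OUT`, and `REWRITTEN` are hypothetical names introduced here, and the only assumption taken from pytest is that the marker is looked up in the module docstring.

```python
import ast

MARKER = "PYTEST_DONT_REWRITE"


def module_opts_out(source: str) -> bool:
    """Return True if the module docstring contains the opt-out marker."""
    tree = ast.parse(source)
    docstring = ast.get_docstring(tree)
    return docstring is not None and MARKER in docstring


# A test module that opts out: the marker sits in the module docstring.
OPTED_OUT = '''"""Tests for jitted helpers.

PYTEST_DONT_REWRITE
"""

def test_something():
    assert 1 + 1 == 2
'''

# A test module without the marker keeps rich asserts.
REWRITTEN = '''"""Ordinary tests."""

def test_something():
    assert 1 + 1 == 2
'''

print(module_opts_out(OPTED_OUT))  # True
print(module_opts_out(REWRITTEN))  # False
```

The opt-out is all-or-nothing for the module, which is the limitation this proposal is about.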

While (2) is better than (1), it requires manually editing every module that contains jitted functions.

The ideal solution would be to auto-disable pytest rewriting at a function level.
That is, on finding a jitted function, pytest skips rewriting only that function. All remaining tests in the module can still have rich asserts.

Alternative Solutions

My current solution has been to patch rewrite.AssertionRewriter.run. I modified the part where it handles ast.FunctionDef nodes so that it checks node.decorator_list to see whether the function uses any numba decorators. If a nopython=True function is found, I skip rewriting all asserts inside that function. The remaining logic stays the same.
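
The decorator check described above could look roughly like the following. This is a standalone sketch using only the `ast` module, not the actual patch to `AssertionRewriter.run`; `is_nopython_jit` and `functions_to_skip` are hypothetical helpers, and treating `njit` as equivalent to `jit(nopython=True)` is an assumption added here for illustration.

```python
import ast


def is_nopython_jit(decorator: ast.expr) -> bool:
    """True for @jit(nopython=True) / @numba.jit(nopython=True).

    Also accepts @njit / @numba.njit (an assumption of this sketch).
    """
    call = decorator if isinstance(decorator, ast.Call) else None
    func = call.func if call is not None else decorator
    if isinstance(func, ast.Name):
        name = func.id
    elif isinstance(func, ast.Attribute):
        name = func.attr
    else:
        return False
    if name == "njit":
        return True
    if name == "jit" and call is not None:
        # Look for a literal nopython=True keyword argument.
        return any(
            kw.arg == "nopython"
            and isinstance(kw.value, ast.Constant)
            and kw.value.value is True
            for kw in call.keywords
        )
    return False


def functions_to_skip(source: str) -> set[str]:
    """Names of functions whose asserts a rewriter could leave untouched."""
    tree = ast.parse(source)
    return {
        node.name
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
        and any(is_nopython_jit(d) for d in node.decorator_list)
    }


SAMPLE = '''
from numba import jit

@jit(nopython=True)
def go_fast(a):
    assert a.all()
    return a

def test_go_fast():
    assert True
'''

print(functions_to_skip(SAMPLE))  # {'go_fast'}
```

Matching on decorator names is purely syntactic, so an aliased import such as `from numba import jit as fast` would be missed; a real implementation would likely need a more general opt-out mechanism.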

@The-Compiler changed the title from "Allow disabling asserts at a function level" to "Allow disabling assertion rewriting at a function level" on Jul 15, 2024
@Zac-HD added the "type: proposal" and "topic: rewrite" labels on Jul 21, 2024