From 794c98518d4655fa326ef505e6982546d2b18abe Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Fri, 28 Aug 2020 18:35:39 +0100 Subject: [PATCH 1/2] Fix sympy.testing.pytest.raises to behave like pytest.raises Code written as `assert raises(...)` would pass in pytest and fail in sympy, because sympy was not returning an appropriate object from `raises`. --- sympy/multipledispatch/tests/test_core.py | 4 +--- sympy/multipledispatch/tests/test_dispatcher.py | 8 +------- sympy/testing/pytest.py | 14 ++++++++++++-- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/sympy/multipledispatch/tests/test_core.py b/sympy/multipledispatch/tests/test_core.py index d5df0bd605a7..c8e60caf5f0d 100644 --- a/sympy/multipledispatch/tests/test_core.py +++ b/sympy/multipledispatch/tests/test_core.py @@ -2,7 +2,7 @@ from sympy.multipledispatch import dispatch from sympy.multipledispatch.conflict import AmbiguityWarning -from sympy.testing.pytest import raises, XFAIL, warns +from sympy.testing.pytest import raises, warns from functools import partial test_namespace = dict() # type: Dict[str, Any] @@ -11,7 +11,6 @@ dispatch = partial(dispatch, namespace=test_namespace) -@XFAIL def test_singledispatch(): @dispatch(int) def f(x): # noqa:F811 @@ -66,7 +65,6 @@ def f(x): # noqa:F811 assert f(C()) == 'a' -@XFAIL def test_inheritance_and_multiple_dispatch(): @dispatch(A, A) def f(x, y): # noqa:F811 diff --git a/sympy/multipledispatch/tests/test_dispatcher.py b/sympy/multipledispatch/tests/test_dispatcher.py index 28732cae068a..b980a8b19afc 100644 --- a/sympy/multipledispatch/tests/test_dispatcher.py +++ b/sympy/multipledispatch/tests/test_dispatcher.py @@ -1,7 +1,7 @@ from sympy.multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError, MethodDispatcher, halt_ordering, restart_ordering) -from sympy.testing.pytest import raises, XFAIL, warns +from sympy.testing.pytest import raises, warns def identity(x): @@ -88,7 +88,6 @@ def on_ambiguity(dispatcher, amb): assert ambiguities[0] -@XFAIL def test_raise_error_on_non_class(): f = Dispatcher('f') assert raises(TypeError, lambda: f.add((1,), inc)) @@ -165,7 +164,6 @@ def two(x, y): assert 'x - y' in f._source(1.0, 1.0) -@XFAIL def test_source_raises_on_missing_function(): f = Dispatcher('f') @@ -197,13 +195,11 @@ def func(*args): assert set(f.ordering) == {(int, object), (object, int)} -@XFAIL def test_no_implementations(): f = Dispatcher('f') assert raises(NotImplementedError, lambda: f('hello')) -@XFAIL def test_register_stacking(): f = Dispatcher('f') @@ -238,7 +234,6 @@ class MyList(list): assert f.dispatch(int, int) is add -@XFAIL def test_not_implemented(): f = Dispatcher('f') @@ -259,7 +254,6 @@ def _(x): assert raises(NotImplementedError, lambda: f(1, 2)) -@XFAIL def test_not_implemented_error(): f = Dispatcher('f') diff --git a/sympy/testing/pytest.py b/sympy/testing/pytest.py index 174fc590a035..2e8744ac4475 100644 --- a/sympy/testing/pytest.py +++ b/sympy/testing/pytest.py @@ -34,6 +34,15 @@ # Not using pytest so define the things that would have been imported from # there. + # _pytest._code.code.ExceptionInfo + class ExceptionInfo: + def __init__(self, value): + self.value = value + + def __repr__(self): + return "".format(self.value) + + def raises(expectedException, code=None): """ Tests that ``code`` raises the exception ``expectedException``. @@ -54,6 +63,7 @@ def raises(expectedException, code=None): >>> from sympy.testing.pytest import raises >>> raises(ZeroDivisionError, lambda: 1/0) + >>> raises(ZeroDivisionError, lambda: 1/2) Traceback (most recent call last): ... @@ -92,8 +102,8 @@ def raises(expectedException, code=None): elif callable(code): try: code() - except expectedException: - return + except expectedException as e: + return ExceptionInfo(e) raise Failed("DID NOT RAISE") elif isinstance(code, str): raise TypeError( From 5ff9c4b33d1f7f31c875ecac4f6ff26d0f749ec4 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Sun, 30 Aug 2020 13:40:44 +0100 Subject: [PATCH 2/2] Update sympy/testing/pytest.py --- sympy/testing/pytest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sympy/testing/pytest.py b/sympy/testing/pytest.py index 2e8744ac4475..11abbad3fe04 100644 --- a/sympy/testing/pytest.py +++ b/sympy/testing/pytest.py @@ -63,7 +63,7 @@ def raises(expectedException, code=None): >>> from sympy.testing.pytest import raises >>> raises(ZeroDivisionError, lambda: 1/0) - + >>> raises(ZeroDivisionError, lambda: 1/2) Traceback (most recent call last): ...