From c5cb5f11dc7b759e41332ceb2348d01f82b83475 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Fri, 28 Aug 2020 18:35:39 +0100 Subject: [PATCH] Fix sympy.testing.pytest.raises to behave like pytest.raises Code written as `assert raises(...)` would pass in pytest and fail in sympy, because sympy was not returning an appropriate object from `raises`. --- sympy/multipledispatch/tests/test_core.py | 4 +--- sympy/multipledispatch/tests/test_dispatcher.py | 8 +------- sympy/testing/pytest.py | 6 +++++- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/sympy/multipledispatch/tests/test_core.py b/sympy/multipledispatch/tests/test_core.py index d5df0bd605a7..c8e60caf5f0d 100644 --- a/sympy/multipledispatch/tests/test_core.py +++ b/sympy/multipledispatch/tests/test_core.py @@ -2,7 +2,7 @@ from sympy.multipledispatch import dispatch from sympy.multipledispatch.conflict import AmbiguityWarning -from sympy.testing.pytest import raises, XFAIL, warns +from sympy.testing.pytest import raises, warns from functools import partial test_namespace = dict() # type: Dict[str, Any] @@ -11,7 +11,6 @@ dispatch = partial(dispatch, namespace=test_namespace) -@XFAIL def test_singledispatch(): @dispatch(int) def f(x): # noqa:F811 @@ -66,7 +65,6 @@ def f(x): # noqa:F811 assert f(C()) == 'a' -@XFAIL def test_inheritance_and_multiple_dispatch(): @dispatch(A, A) def f(x, y): # noqa:F811 diff --git a/sympy/multipledispatch/tests/test_dispatcher.py b/sympy/multipledispatch/tests/test_dispatcher.py index 28732cae068a..b980a8b19afc 100644 --- a/sympy/multipledispatch/tests/test_dispatcher.py +++ b/sympy/multipledispatch/tests/test_dispatcher.py @@ -1,7 +1,7 @@ from sympy.multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError, MethodDispatcher, halt_ordering, restart_ordering) -from sympy.testing.pytest import raises, XFAIL, warns +from sympy.testing.pytest import raises, warns def identity(x): @@ -88,7 +88,6 @@ def on_ambiguity(dispatcher, amb): assert ambiguities[0] -@XFAIL def test_raise_error_on_non_class(): f = Dispatcher('f') assert raises(TypeError, lambda: f.add((1,), inc)) @@ -165,7 +164,6 @@ def two(x, y): assert 'x - y' in f._source(1.0, 1.0) -@XFAIL def test_source_raises_on_missing_function(): f = Dispatcher('f') @@ -197,13 +195,11 @@ def func(*args): assert set(f.ordering) == {(int, object), (object, int)} -@XFAIL def test_no_implementations(): f = Dispatcher('f') assert raises(NotImplementedError, lambda: f('hello')) -@XFAIL def test_register_stacking(): f = Dispatcher('f') @@ -238,7 +234,6 @@ class MyList(list): assert f.dispatch(int, int) is add -@XFAIL def test_not_implemented(): f = Dispatcher('f') @@ -259,7 +254,6 @@ def _(x): assert raises(NotImplementedError, lambda: f(1, 2)) -@XFAIL def test_not_implemented_error(): f = Dispatcher('f') diff --git a/sympy/testing/pytest.py b/sympy/testing/pytest.py index 174fc590a035..d681775bb342 100644 --- a/sympy/testing/pytest.py +++ b/sympy/testing/pytest.py @@ -34,6 +34,10 @@ # Not using pytest so define the things that would have been imported from # there. + class _ExceptionInfo: + # pytest would implement this + pass + def raises(expectedException, code=None): """ Tests that ``code`` raises the exception ``expectedException``. @@ -93,7 +97,7 @@ def raises(expectedException, code=None): try: code() except expectedException: - return + return _ExceptionInfo() raise Failed("DID NOT RAISE") elif isinstance(code, str): raise TypeError(