diff --git a/qupulse/utils/sympy.py b/qupulse/utils/sympy.py index 076f7dc90..763e23a20 100644 --- a/qupulse/utils/sympy.py +++ b/qupulse/utils/sympy.py @@ -15,15 +15,22 @@ import numpy try: - from sympy.printing.numpy import NumPyPrinter + import scipy +except ImportError: + scipy = None + +try: + from sympy.printing.numpy import NumPyPrinter, SciPyPrinter except ImportError: # sympy moved NumPyPrinter in release 1.8 from sympy.printing.pycode import NumPyPrinter + SciPyPrinter = None warnings.warn("Please update sympy.", DeprecationWarning) -try: + +if scipy: import scipy.special as _special_functions -except ImportError: +else: _special_functions = {fname: numpy.vectorize(fobject) for fname, fobject in math.__dict__.items() if callable(fobject) and not fname.startswith('_') and fname not in numpy.__dict__} @@ -31,6 +38,12 @@ 'manually vectorized functions in math.') +if scipy and SciPyPrinter: + PrinterBase = SciPyPrinter +else: + PrinterBase = NumPyPrinter + + __all__ = ["sympify", "substitute_with_eval", "to_numpy", "get_variables", "get_free_symbols", "recursive_substitution", "evaluate_lambdified", "get_most_simple_representation"] @@ -378,6 +391,10 @@ def recursive_substitution(expression: sympy.Expr, _lambdify_modules = [{'ceiling': numpy_compatible_ceiling, 'floor': _floor_to_int, 'Broadcast': numpy.broadcast_to}, 'numpy', _special_functions] +if scipy: + # this is required for Integral lambdification + _lambdify_modules.append("scipy") + def evaluate_compiled(expression: sympy.Expr, parameters: Dict[str, Union[numpy.ndarray, Number]], @@ -404,7 +421,7 @@ def evaluate_lambdified(expression: Union[sympy.Expr, numpy.ndarray], return lambdified(**parameters), lambdified -class HighPrecPrinter(NumPyPrinter): +class HighPrecPrinter(PrinterBase): """Custom printer that translates sympy.Rational into TimeType""" def _print_Rational(self, expr): return f'TimeType.from_fraction({expr.p}, {expr.q})' diff --git a/tests/utils/sympy_tests.py b/tests/utils/sympy_tests.py index b1ff6b1e6..38f7e9e10 100644 --- a/tests/utils/sympy_tests.py +++ b/tests/utils/sympy_tests.py @@ -14,8 +14,13 @@ import sympy import numpy as np +try: + import scipy +except ImportError: + scipy = None + from sympy.abc import a, b, c, d, e, f, k, l, m, n, i, j -from sympy import sin, Sum, IndexedBase, Rational +from sympy import sin, Sum, IndexedBase, Rational, Integral a_ = IndexedBase(a) b_ = IndexedBase(b) @@ -120,6 +125,9 @@ # TODO: this fails # (np.array([a, Rational(1, 3)]), {'a': 2}, np.array([2, TimeType.from_fraction(1, 3)])) ] +eval_integral = [ + (Integral(sin(b * a ** 2 + c * a) / a, (a, 0, c))/b, {'b': 5, 'c': 100.}, 0.26302083739430604) +] class TestCase(unittest.TestCase): @@ -305,6 +313,24 @@ def test_eval_exact_rational(self): except ValueError: np.testing.assert_equal(result, expected) + def test_integral(self): + if type(self) is CompiledEvaluationTest: + raise unittest.SkipTest("Integrals are not representable in pure repr lambdas.") + + if scipy is None: + # printer based evaluate requires scipy to print integrals + for expr, parameters, _ in eval_integral: + with self.assertRaises(NotImplementedError): + self.evaluate(expr, parameters) + return + + for expr, parameters, expected in eval_integral: + result = self.evaluate(expr, parameters) + try: + self.assertEqual(result, expected) + except ValueError: + np.testing.assert_equal(result, expected) + class LamdifiedEvaluationTest(EvaluationTestsBase, unittest.TestCase): def evaluate(self, expression: Union[sympy.Expr, np.ndarray], parameters):