Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions qupulse/utils/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,35 @@
import numpy

try:
from sympy.printing.numpy import NumPyPrinter
import scipy
except ImportError:
scipy = None

try:
from sympy.printing.numpy import NumPyPrinter, SciPyPrinter
except ImportError:
# sympy moved NumPyPrinter in release 1.8
from sympy.printing.pycode import NumPyPrinter
SciPyPrinter = None
warnings.warn("Please update sympy.", DeprecationWarning)

try:

if scipy:
import scipy.special as _special_functions
except ImportError:
else:
_special_functions = {fname: numpy.vectorize(fobject)
for fname, fobject in math.__dict__.items()
if callable(fobject) and not fname.startswith('_') and fname not in numpy.__dict__}
warnings.warn('scipy is not installed. This reduces the set of available functions to those present in numpy + '
'manually vectorized functions in math.')


if scipy and SciPyPrinter:
PrinterBase = SciPyPrinter
else:
PrinterBase = NumPyPrinter


__all__ = ["sympify", "substitute_with_eval", "to_numpy", "get_variables", "get_free_symbols", "recursive_substitution",
"evaluate_lambdified", "get_most_simple_representation"]

Expand Down Expand Up @@ -378,6 +391,10 @@ def recursive_substitution(expression: sympy.Expr,
_lambdify_modules = [{'ceiling': numpy_compatible_ceiling, 'floor': _floor_to_int,
'Broadcast': numpy.broadcast_to}, 'numpy', _special_functions]

if scipy:
# this is required for Integral lambdification
_lambdify_modules.append("scipy")


def evaluate_compiled(expression: sympy.Expr,
parameters: Dict[str, Union[numpy.ndarray, Number]],
Expand All @@ -404,7 +421,7 @@ def evaluate_lambdified(expression: Union[sympy.Expr, numpy.ndarray],
return lambdified(**parameters), lambdified


class HighPrecPrinter(NumPyPrinter):
class HighPrecPrinter(PrinterBase):
"""Custom printer that translates sympy.Rational into TimeType"""
def _print_Rational(self, expr):
return f'TimeType.from_fraction({expr.p}, {expr.q})'
Expand Down
28 changes: 27 additions & 1 deletion tests/utils/sympy_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
import sympy
import numpy as np

try:
import scipy
except ImportError:
scipy = None

from sympy.abc import a, b, c, d, e, f, k, l, m, n, i, j
from sympy import sin, Sum, IndexedBase, Rational
from sympy import sin, Sum, IndexedBase, Rational, Integral

a_ = IndexedBase(a)
b_ = IndexedBase(b)
Expand Down Expand Up @@ -120,6 +125,9 @@
# TODO: this fails
# (np.array([a, Rational(1, 3)]), {'a': 2}, np.array([2, TimeType.from_fraction(1, 3)]))
]
eval_integral = [
(Integral(sin(b * a ** 2 + c * a) / a, (a, 0, c))/b, {'b': 5, 'c': 100.}, 0.26302083739430604)
]


class TestCase(unittest.TestCase):
Expand Down Expand Up @@ -305,6 +313,24 @@ def test_eval_exact_rational(self):
except ValueError:
np.testing.assert_equal(result, expected)

def test_integral(self):
if type(self) is CompiledEvaluationTest:
raise unittest.SkipTest("Integrals are not representable in pure repr lambdas.")

if scipy is None:
# printer based evaluate requires scipy to print integrals
for expr, parameters, _ in eval_integral:
with self.assertRaises(NotImplementedError):
self.evaluate(expr, parameters)
return

for expr, parameters, expected in eval_integral:
result = self.evaluate(expr, parameters)
try:
self.assertEqual(result, expected)
except ValueError:
np.testing.assert_equal(result, expected)


class LamdifiedEvaluationTest(EvaluationTestsBase, unittest.TestCase):
def evaluate(self, expression: Union[sympy.Expr, np.ndarray], parameters):
Expand Down
Loading