diff --git a/changes.d/604.bugfix b/changes.d/604.bugfix new file mode 100644 index 00000000..5c89ec78 --- /dev/null +++ b/changes.d/604.bugfix @@ -0,0 +1 @@ +Fix `repr` of `ExpressionScalar` when constructed from a sympy expression. Also replace `Expression` with `ExpressionScalar` in `repr`. diff --git a/qupulse/expressions.py b/qupulse/expressions.py index 6c577309..bf0a8575 100644 --- a/qupulse/expressions.py +++ b/qupulse/expressions.py @@ -265,7 +265,10 @@ def __str__(self) -> str: return str(self._sympified_expression) def __repr__(self) -> str: - return 'Expression({})'.format(repr(self._original_expression)) + if self._original_expression is None: + return f"ExpressionScalar('{self._sympified_expression!r}')" + else: + return f"ExpressionScalar({self._original_expression!r})" def __format__(self, format_spec): if format_spec == '': diff --git a/tests/expression_tests.py b/tests/expression_tests.py index 2d2a9e05..a5cb1601 100644 --- a/tests/expression_tests.py +++ b/tests/expression_tests.py @@ -2,6 +2,7 @@ import sys import numpy as np +import sympy.abc from sympy import sympify, Eq from qupulse.expressions import Expression, ExpressionVariableMissingException, NonNumericEvaluation, ExpressionScalar, ExpressionVector @@ -257,7 +258,23 @@ def test_evaluate_variable_missing(self) -> None: def test_repr(self): s = 'a * b' e = ExpressionScalar(s) - self.assertEqual("Expression('a * b')", repr(e)) + self.assertEqual("ExpressionScalar('a * b')", repr(e)) + + def test_repr_original_expression_is_sympy(self): + # in this case we test that we get the original expression back if we do + # eval(repr(e)) + + org = sympy.sympify(3.1415) + e = ExpressionScalar(org) + self.assertEqual(e, eval(repr(e))) + + org = sympy.abc.a * sympy.abc.b + e = ExpressionScalar(org) + self.assertEqual(e, eval(repr(e))) + + org = sympy.sympify('3/17') + e = ExpressionScalar(org) + self.assertEqual(e, eval(repr(e))) def test_str(self): s = 'a * b' diff --git a/tests/pulses/arithmetic_pulse_template_tests.py b/tests/pulses/arithmetic_pulse_template_tests.py index 11b7d65d..7b5ec0e4 100644 --- a/tests/pulses/arithmetic_pulse_template_tests.py +++ b/tests/pulses/arithmetic_pulse_template_tests.py @@ -533,7 +533,7 @@ def test_repr(self): with mock.patch.object(DummyPulseTemplate, '__repr__', wraps=lambda *args: 'dummy'): r = repr(ArithmeticPulseTemplate(pt, '-', scalar)) - self.assertEqual("(dummy - Expression('x'))", r) + self.assertEqual("(dummy - ExpressionScalar('x'))", r) arith = ArithmeticPulseTemplate(pt, '-', scalar, identifier='id') self.assertEqual(super(ArithmeticPulseTemplate, arith).__repr__(), repr(arith))