From 6d9b8c46f4ab466718d9f3e78135dd20ab917c59 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Thu, 8 Jul 2021 12:51:27 +0200 Subject: [PATCH 1/2] Fix repr of ExpressionScalar --- changes.d/604.bugfix | 1 + qupulse/expressions.py | 5 ++++- tests/expression_tests.py | 19 ++++++++++++++++++- 3 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 changes.d/604.bugfix diff --git a/changes.d/604.bugfix b/changes.d/604.bugfix new file mode 100644 index 00000000..5c89ec78 --- /dev/null +++ b/changes.d/604.bugfix @@ -0,0 +1 @@ +Fix `repr` of `ExpressionScalar` when constructed from a sympy expression. Also replace `Expression` with `ExpressionScalar` in `repr`. diff --git a/qupulse/expressions.py b/qupulse/expressions.py index 6c577309..bf0a8575 100644 --- a/qupulse/expressions.py +++ b/qupulse/expressions.py @@ -265,7 +265,10 @@ def __str__(self) -> str: return str(self._sympified_expression) def __repr__(self) -> str: - return 'Expression({})'.format(repr(self._original_expression)) + if self._original_expression is None: + return f"ExpressionScalar('{self._sympified_expression!r}')" + else: + return f"ExpressionScalar({self._original_expression!r})" def __format__(self, format_spec): if format_spec == '': diff --git a/tests/expression_tests.py b/tests/expression_tests.py index 2d2a9e05..a5cb1601 100644 --- a/tests/expression_tests.py +++ b/tests/expression_tests.py @@ -2,6 +2,7 @@ import sys import numpy as np +import sympy.abc from sympy import sympify, Eq from qupulse.expressions import Expression, ExpressionVariableMissingException, NonNumericEvaluation, ExpressionScalar, ExpressionVector @@ -257,7 +258,23 @@ def test_evaluate_variable_missing(self) -> None: def test_repr(self): s = 'a * b' e = ExpressionScalar(s) - self.assertEqual("Expression('a * b')", repr(e)) + self.assertEqual("ExpressionScalar('a * b')", repr(e)) + + def test_repr_original_expression_is_sympy(self): + # in this case we test that we get the original expression back if we do + # eval(repr(e)) + + org = sympy.sympify(3.1415) + e = ExpressionScalar(org) + self.assertEqual(e, eval(repr(e))) + + org = sympy.abc.a * sympy.abc.b + e = ExpressionScalar(org) + self.assertEqual(e, eval(repr(e))) + + org = sympy.sympify('3/17') + e = ExpressionScalar(org) + self.assertEqual(e, eval(repr(e))) def test_str(self): s = 'a * b' From d42b2c71adcf62ef070eb0b3b31815afaf2a60ad Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Thu, 8 Jul 2021 13:05:55 +0200 Subject: [PATCH 2/2] Fix ArithmeticPT test --- tests/pulses/arithmetic_pulse_template_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pulses/arithmetic_pulse_template_tests.py b/tests/pulses/arithmetic_pulse_template_tests.py index 11b7d65d..7b5ec0e4 100644 --- a/tests/pulses/arithmetic_pulse_template_tests.py +++ b/tests/pulses/arithmetic_pulse_template_tests.py @@ -533,7 +533,7 @@ def test_repr(self): with mock.patch.object(DummyPulseTemplate, '__repr__', wraps=lambda *args: 'dummy'): r = repr(ArithmeticPulseTemplate(pt, '-', scalar)) - self.assertEqual("(dummy - Expression('x'))", r) + self.assertEqual("(dummy - ExpressionScalar('x'))", r) arith = ArithmeticPulseTemplate(pt, '-', scalar, identifier='id') self.assertEqual(super(ArithmeticPulseTemplate, arith).__repr__(), repr(arith))