diff --git a/changes.d/612.bugfix b/changes.d/612.bugfix new file mode 100644 index 000000000..65e80ec0a --- /dev/null +++ b/changes.d/612.bugfix @@ -0,0 +1 @@ +`floor` will now return an integer in lambda expressions with numpy to allow usage in ForLoopPT range expression. diff --git a/qupulse/utils/sympy.py b/qupulse/utils/sympy.py index b38d01b16..e27dc9032 100644 --- a/qupulse/utils/sympy.py +++ b/qupulse/utils/sympy.py @@ -215,13 +215,30 @@ def numpy_compatible_add(*args) -> Union[sympy.Add, sympy.Array]: } +def _float_arr_to_int_arr(float_arr): + """Try to cast array to int64. Return original array if data is not representable.""" + int_arr = float_arr.astype(numpy.int64) + if numpy.any(int_arr != float_arr): + # we either have a float that is too large or NaN + return float_arr + else: + return int_arr + + def numpy_compatible_ceiling(input_value: Any) -> Any: if isinstance(input_value, numpy.ndarray): - return numpy.ceil(input_value).astype(numpy.int64) + return _float_arr_to_int_arr(numpy.ceil(input_value)) else: return sympy.ceiling(input_value) +def _floor_to_int(input_value: Any) -> Any: + if isinstance(input_value, numpy.ndarray): + return _float_arr_to_int_arr(numpy.floor(input_value)) + else: + return sympy.floor(input_value) + + def to_numpy(sympy_array: sympy.NDimArray) -> numpy.ndarray: if isinstance(sympy_array, sympy.DenseNDimArray): if len(sympy_array.shape) == 2: @@ -327,7 +344,8 @@ def recursive_substitution(expression: sympy.Expr, _numpy_environment = {**_base_environment, **numpy.__dict__} _sympy_environment = {**_base_environment, **sympy.__dict__} -_lambdify_modules = [{'ceiling': numpy_compatible_ceiling, 'Broadcast': numpy.broadcast_to}, 'numpy', _special_functions] +_lambdify_modules = [{'ceiling': numpy_compatible_ceiling, 'floor': _floor_to_int, + 'Broadcast': numpy.broadcast_to}, 'numpy', _special_functions] def evaluate_compiled(expression: sympy.Expr, diff --git a/tests/pulses/bug_tests.py b/tests/pulses/bug_tests.py index fc1aead75..ef4035e9f 100644 --- a/tests/pulses/bug_tests.py +++ b/tests/pulses/bug_tests.py @@ -9,6 +9,7 @@ from qupulse.pulses.repetition_pulse_template import RepetitionPulseTemplate from qupulse.pulses.multi_channel_pulse_template import AtomicMultiChannelPulseTemplate from qupulse.pulses.mapping_pulse_template import MappingPulseTemplate +from qupulse.pulses.loop_pulse_template import ForLoopPulseTemplate from qupulse.pulses.plotting import plot @@ -94,3 +95,12 @@ def test_issue_584_uninitialized_table_sample(self): expected = np.full_like(times, fill_value=1.) np.testing.assert_array_equal(expected, sampled) + + def test_issue_612_for_loop_duration(self): + fpt = FunctionPulseTemplate('sin(2*pi*i*t*f)', '1/f') + pt = ForLoopPulseTemplate(fpt, 'i', 'floor(total_time*f)') + self.assertEqual( + (500 + 501) // 2, + pt.duration.evaluate_in_scope({'f': 1., 'total_time': 500}) + ) +