Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes.d/612.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`floor` will now return an integer in lambda expressions with numpy to allow usage in ForLoopPT range expression.
22 changes: 20 additions & 2 deletions qupulse/utils/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,30 @@ def numpy_compatible_add(*args) -> Union[sympy.Add, sympy.Array]:
}


def _float_arr_to_int_arr(float_arr):
"""Try to cast array to int64. Return original array if data is not representable."""
int_arr = float_arr.astype(numpy.int64)
if numpy.any(int_arr != float_arr):
# we either have a float that is too large or NaN
return float_arr
else:
return int_arr


def numpy_compatible_ceiling(input_value: Any) -> Any:
if isinstance(input_value, numpy.ndarray):
return numpy.ceil(input_value).astype(numpy.int64)
return _float_arr_to_int_arr(numpy.ceil(input_value))
else:
return sympy.ceiling(input_value)


def _floor_to_int(input_value: Any) -> Any:
if isinstance(input_value, numpy.ndarray):
return _float_arr_to_int_arr(numpy.floor(input_value))
else:
return sympy.floor(input_value)


def to_numpy(sympy_array: sympy.NDimArray) -> numpy.ndarray:
if isinstance(sympy_array, sympy.DenseNDimArray):
if len(sympy_array.shape) == 2:
Expand Down Expand Up @@ -327,7 +344,8 @@ def recursive_substitution(expression: sympy.Expr,
_numpy_environment = {**_base_environment, **numpy.__dict__}
_sympy_environment = {**_base_environment, **sympy.__dict__}

_lambdify_modules = [{'ceiling': numpy_compatible_ceiling, 'Broadcast': numpy.broadcast_to}, 'numpy', _special_functions]
_lambdify_modules = [{'ceiling': numpy_compatible_ceiling, 'floor': _floor_to_int,
'Broadcast': numpy.broadcast_to}, 'numpy', _special_functions]


def evaluate_compiled(expression: sympy.Expr,
Expand Down
10 changes: 10 additions & 0 deletions tests/pulses/bug_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from qupulse.pulses.repetition_pulse_template import RepetitionPulseTemplate
from qupulse.pulses.multi_channel_pulse_template import AtomicMultiChannelPulseTemplate
from qupulse.pulses.mapping_pulse_template import MappingPulseTemplate
from qupulse.pulses.loop_pulse_template import ForLoopPulseTemplate

from qupulse.pulses.plotting import plot

Expand Down Expand Up @@ -94,3 +95,12 @@ def test_issue_584_uninitialized_table_sample(self):

expected = np.full_like(times, fill_value=1.)
np.testing.assert_array_equal(expected, sampled)

def test_issue_612_for_loop_duration(self):
fpt = FunctionPulseTemplate('sin(2*pi*i*t*f)', '1/f')
pt = ForLoopPulseTemplate(fpt, 'i', 'floor(total_time*f)')
self.assertEqual(
(500 + 501) // 2,
pt.duration.evaluate_in_scope({'f': 1., 'total_time': 500})
)