Skip to content

Commit

Permalink
Merge branch 'master' into feat/roll_loop
Browse files Browse the repository at this point in the history
  • Loading branch information
terrorfisch committed Nov 15, 2021
2 parents 71327cc + cee8fc5 commit 7277ba8
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
1 change: 1 addition & 0 deletions changes.d/612.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`floor` will now return an integer in lambda expressions with numpy to allow usage in ForLoopPT range expression.
22 changes: 20 additions & 2 deletions qupulse/utils/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,30 @@ def numpy_compatible_add(*args) -> Union[sympy.Add, sympy.Array]:
}


def _float_arr_to_int_arr(float_arr):
"""Try to cast array to int64. Return original array if data is not representable."""
int_arr = float_arr.astype(numpy.int64)
if numpy.any(int_arr != float_arr):
# we either have a float that is too large or NaN
return float_arr
else:
return int_arr


def numpy_compatible_ceiling(input_value: Any) -> Any:
if isinstance(input_value, numpy.ndarray):
return numpy.ceil(input_value).astype(numpy.int64)
return _float_arr_to_int_arr(numpy.ceil(input_value))
else:
return sympy.ceiling(input_value)


def _floor_to_int(input_value: Any) -> Any:
if isinstance(input_value, numpy.ndarray):
return _float_arr_to_int_arr(numpy.floor(input_value))
else:
return sympy.floor(input_value)


def to_numpy(sympy_array: sympy.NDimArray) -> numpy.ndarray:
if isinstance(sympy_array, sympy.DenseNDimArray):
if len(sympy_array.shape) == 2:
Expand Down Expand Up @@ -327,7 +344,8 @@ def recursive_substitution(expression: sympy.Expr,
_numpy_environment = {**_base_environment, **numpy.__dict__}
_sympy_environment = {**_base_environment, **sympy.__dict__}

_lambdify_modules = [{'ceiling': numpy_compatible_ceiling, 'Broadcast': numpy.broadcast_to}, 'numpy', _special_functions]
_lambdify_modules = [{'ceiling': numpy_compatible_ceiling, 'floor': _floor_to_int,
'Broadcast': numpy.broadcast_to}, 'numpy', _special_functions]


def evaluate_compiled(expression: sympy.Expr,
Expand Down
10 changes: 10 additions & 0 deletions tests/pulses/bug_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from qupulse.pulses.repetition_pulse_template import RepetitionPulseTemplate
from qupulse.pulses.multi_channel_pulse_template import AtomicMultiChannelPulseTemplate
from qupulse.pulses.mapping_pulse_template import MappingPulseTemplate
from qupulse.pulses.loop_pulse_template import ForLoopPulseTemplate

from qupulse.pulses.plotting import plot

Expand Down Expand Up @@ -94,3 +95,12 @@ def test_issue_584_uninitialized_table_sample(self):

expected = np.full_like(times, fill_value=1.)
np.testing.assert_array_equal(expected, sampled)

def test_issue_612_for_loop_duration(self):
fpt = FunctionPulseTemplate('sin(2*pi*i*t*f)', '1/f')
pt = ForLoopPulseTemplate(fpt, 'i', 'floor(total_time*f)')
self.assertEqual(
(500 + 501) // 2,
pt.duration.evaluate_in_scope({'f': 1., 'total_time': 500})
)

0 comments on commit 7277ba8

Please sign in to comment.