Skip to content

Commit

Permalink
Handle iter_stop=0 correctly
Browse files Browse the repository at this point in the history
Closes #58
  • Loading branch information
goerz committed Nov 21, 2019
1 parent 0569680 commit 5796f7f
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 1 deletion.
2 changes: 2 additions & 0 deletions HISTORY.rst
Expand Up @@ -8,6 +8,7 @@ History

* Added: Allow to pass `args` to time-dependent control functions (`#56`_, thanks to `@timohillmann`_)
* Changed: Renamed ``krotov.structural_conversions`` to ``krotov.conversions``
* Bugfix: Crash when ``krotov.optimize_pulses`` is called with ``iter_stop=0`` (`#58`_)

0.4.1 (2019-10-11)
------------------
Expand Down Expand Up @@ -89,3 +90,4 @@ History
.. _#52: https://github.com/qucontrol/krotov/issues/42
.. _#54: https://github.com/qucontrol/krotov/issues/54
.. _#56: https://github.com/qucontrol/krotov/issues/56
.. _#58: https://github.com/qucontrol/krotov/issues/58
3 changes: 2 additions & 1 deletion src/krotov/optimize.py
Expand Up @@ -312,7 +312,7 @@ def optimize_pulses(
forward_states0 = forward_states = None

info = None
optimized_pulses = guess_pulses
optimized_pulses = copy.deepcopy(guess_pulses)
if info_hook is not None:
info = info_hook(
objectives=objectives,
Expand All @@ -339,6 +339,7 @@ def optimize_pulses(
result.tlist = tlist
result.objectives = objectives
result.guess_controls = guess_controls
result.optimized_controls = optimized_pulses
result.controls_mapping = pulses_mapping
if continue_from is None:
# we only store information about the "0" iteration if we're starting a
Expand Down
67 changes: 67 additions & 0 deletions tests/test_krotov.py
@@ -1,5 +1,6 @@
"""High-level tests for `krotov` package."""

import io
import logging
import os
from copy import deepcopy
Expand Down Expand Up @@ -162,6 +163,42 @@ def S(t):
return objectives, pulse_options, tlist


@pytest.mark.parametrize('iter_stop', [0, -1])
def test_zero_iterations(iter_stop, simple_state_to_state_system):
objectives, pulse_options, tlist = simple_state_to_state_system

with io.StringIO() as log_fh:

result = krotov.optimize_pulses(
objectives,
pulse_options=pulse_options,
tlist=tlist,
propagator=krotov.propagators.expm,
chi_constructor=krotov.functionals.chis_re,
store_all_pulses=True,
info_hook=krotov.info_hooks.print_table(
J_T=krotov.functionals.J_T_re, out=log_fh
),
iter_stop=iter_stop,
skip_initial_forward_propagation=True,
)

log = log_fh.getvalue()

assert len(log.splitlines()) == 2
assert result.message == 'Reached 0 iterations'
assert len(result.guess_controls) == 1 # one control field
assert len(result.guess_controls) == len(result.optimized_controls)
assert len(result.guess_controls[0]) == len(result.tlist)
assert len(result.optimized_controls[0]) == len(result.tlist)
for (c1, c2) in zip(result.guess_controls, result.optimized_controls):
assert np.all(c1 == c2)
for pulses_for_iteration in result.all_pulses:
for pulse in pulses_for_iteration:
# the pulses are defined on the *intervals* of tlist
assert len(pulse) == len(result.tlist) - 1


def test_continue_optimization(
simple_state_to_state_system, request, tmpdir, caplog
):
Expand Down Expand Up @@ -203,12 +240,42 @@ def test_continue_optimization(
in caplog.text
)

# fmt: off
assert len(oct_result1.iters) == 4 # 0 ... 3
assert len(oct_result1.iter_seconds) == 4
assert len(oct_result1.info_vals) == 4
assert len(oct_result1.all_pulses) == 4
assert len(oct_result1.states) == 1
assert len(oct_result1.guess_controls) == 1 # one control field
assert len(oct_result1.guess_controls) == len(oct_result1.optimized_controls)
assert len(oct_result1.guess_controls[0]) == len(oct_result1.tlist)
assert len(oct_result1.optimized_controls[0]) == len(oct_result1.tlist)
for pulses_for_iteration in oct_result1.all_pulses:
for pulse in pulses_for_iteration:
# the pulses are defined on the *intervals* of tlist
assert len(pulse) == len(oct_result1.tlist) - 1
assert "3 iterations" in oct_result1.message
# fmt: on

# repeating the same optimization only propagates the guess pulse
# (we'll check this later while verifying the output of the log file)
krotov.optimize_pulses(
objectives,
pulse_options=pulse_options,
tlist=tlist,
propagator=krotov.propagators.expm,
chi_constructor=krotov.functionals.chis_re,
store_all_pulses=True,
info_hook=krotov.info_hooks.print_table(
J_T=krotov.functionals.J_T_re, out=log_fh
),
check_convergence=krotov.convergence.Or(
krotov.convergence.check_monotonic_error,
krotov.convergence.dump_result(dumpfile, every=2),
),
continue_from=oct_result1,
iter_stop=3,
)

# another 2 iterations
oct_result2 = krotov.optimize_pulses(
Expand Down
2 changes: 2 additions & 0 deletions tests/test_krotov/oct.log
Expand Up @@ -3,6 +3,8 @@
1 7.65e-01 2.33e-02 7.88e-01 -2.35e-01 -2.12e-01 1
2 5.56e-01 2.07e-02 5.77e-01 -2.09e-01 -1.88e-01 1
3 3.89e-01 1.66e-02 4.05e-01 -1.67e-01 -1.51e-01 1
iter. J_T ∫gₐ(t)dt J ΔJ_T ΔJ secs
0 3.89e-01 0.00e+00 3.89e-01 n/a n/a 0
iter. J_T ∫gₐ(t)dt J ΔJ_T ΔJ secs
0 3.89e-01 0.00e+00 3.89e-01 n/a n/a 0
4 2.65e-01 1.23e-02 2.77e-01 -1.24e-01 -1.11e-01 1
Expand Down

0 comments on commit 5796f7f

Please sign in to comment.