/
optimize.py
872 lines (804 loc) · 34.2 KB
/
optimize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
import copy
import inspect
import logging
import time
from functools import partial
import numpy as np
from qutip import Qobj
from qutip.parallel import serial_map
from .conversions import (
control_onto_interval,
discretize,
extract_controls,
extract_controls_mapping,
plug_in_pulse_values,
pulse_onto_tlist,
pulse_options_dict_to_list,
)
from .info_hooks import chain
from .mu import derivative_wrt_pulse
from .propagators import Propagator, expm
from .result import Result
from .second_order import _overlap
from .shapes import one_shape, zero_shape
__all__ = ['optimize_pulses']
def optimize_pulses(
objectives,
pulse_options,
tlist,
*,
propagator,
chi_constructor,
mu=None,
sigma=None,
iter_start=0,
iter_stop=5000,
check_convergence=None,
info_hook=None,
modify_params_after_iter=None,
storage='array',
parallel_map=None,
store_all_pulses=False,
continue_from=None,
skip_initial_forward_propagation=False,
norm=None,
overlap=None
):
r"""Use Krotov's method to optimize towards the given `objectives`.
Optimize all time-dependent controls found in the Hamiltonians or
Liouvillians of the given `objectives`.
Args:
objectives (list[Objective]): List of objectives
pulse_options (dict): Mapping of time-dependent controls found in the
Hamiltonians of the objectives to a dictionary of options for that
control. There must be options given for *every* control. As numpy
arrays are unhashable and thus cannot be used as dict keys, the
options for a control that is an array must be set using the key
``id(control)`` (see the example below). The options of any
particular control *must* contain the following keys:
* ``'lambda_a'``: the Krotov step size (float value). This governs
the overall magnitude of the pulse update. Large values result in
small updates. Small values may lead to sharp spikes and
numerical instability.
* ``'update_shape'`` : Function S(t) in the range [0, 1] that
scales the pulse update for the pulse value at t. This can be
used to ensure boundary conditions (S(0) = S(T) = 0), and enforce
smooth switch-on and switch-off. This can be a callable that
takes a single argument `t`; or the values 1 or 0 for a constant
update-shape. The value 0 disables the optimization of that
particular control.
In addition, the following keys *may* occur:
* ``'args'``: If the control is a callable with arguments
``(t, args)`` (as required by QuTiP), a dict of argument values
to pass as `args`. If ``'args'`` is not specified via the
`pulse_options`, controls will be discretized using the default
``args=None``.
For example, for `objectives` that contain a Hamiltonian of the
form ``[H0, [H1, u], [H2, g]]``, where ``H0``, ``H1``, and ``H2``
are :class:`~qutip.Qobj` instances, ``u`` is a numpy array
.. doctest::
>>> u = numpy.zeros(1000)
and ``g`` is a control function
.. doctest::
>>> def g(t, args):
... E0 = args.get('E0', 0.0)
... return E0
then a possible value for `pulse_options` would look like this:
.. doctest::
>>> from krotov.shapes import flattop
>>> from functools import partial
>>> pulse_options = {
... id(u): {'lambda_a': 1.0, 'update_shape': 1},
... g: dict(
... lambda_a=1.0,
... update_shape=partial(
... flattop, t_start=0, t_stop=10, t_rise=1.5
... ),
... args=dict(E0=1.0)
... )
... }
The use of :class:`dict` and the ``{...}`` syntax are completely
equivalent, but :class:`dict` is better for nested indentation.
tlist (numpy.ndarray): Array of time grid values, cf.
:func:`~qutip.mesolve.mesolve`
propagator (callable or list[callable]): Function that propagates the
state backward or forwards in time by a single time step, between
two points in `tlist`. Alternatively, a list of functions, one for
each objective. If the propagator is stateful, it should be an
instance of :class:`krotov.propagators.Propagator`. See
:mod:`krotov.propagators` for details.
chi_constructor (callable): Function that calculates the boundary
condition for the backward propagation. This is where the
final-time functional (indirectly) enters the optimization.
See :mod:`krotov.functionals` for details.
mu (None or callable): Function that calculates the derivative
$\frac{\partial H}{\partial\epsilon}$ for an equation of motion
$\dot{\phi}(t) = -i H[\phi(t)]$ of an abstract operator $H$ and an
abstract state $\phi$. If None, defaults to
:func:`krotov.mu.derivative_wrt_pulse`, which covers the standard
Schrödinger and master equations. See :mod:`krotov.mu` for a
full explanation of the role of `mu` in the optimization, and the
required function signature.
sigma (None or krotov.second_order.Sigma): Function (instance of a
:class:`.Sigma` subclass) that calculates the
second-order contribution. If None, the first-order Krotov method
is used.
iter_start (int): The formal iteration number at which to start the
optimization
iter_stop (int): The iteration number after which to end the
optimization, whether or not convergence has been reached
check_convergence (None or callable): Function that determines whether
the optimization has converged. If None, the optimization will only
end when `iter_stop` is reached. See :mod:`krotov.convergence` for
details.
info_hook (None or callable): Function that is called after each
iteration of the optimization, for the purpose of analysis. Any
value returned by `info_hook` (e.g. an evaluated functional
:math:`J_T`) will be stored, for each iteration, in the `info_vals`
attribute of the returned :class:`.Result`. The `info_hook` must
have the same signature as
:func:`krotov.info_hooks.print_debug_information`. It should not
modify its arguments in any way, except for `shared_data`.
modify_params_after_iter (None or callable): Function that is called
after each iteration, which may modify its arguments for certain
advanced use cases, such as dynamically adjusting `lambda_vals`, or
applying spectral filters to the `optimized_pulses`. It has the
same interface as `info_hook` but should not return anything. The
`modify_params_after_iter` function is called immediately before
`info_hook`, and can transfer arbitrary data to any subsequent
`info_hook` via the `shared_data` argument.
storage (callable): Storage constructor for the storage of
propagated states. Must accept an integer parameter `N` and return
an empty array-like container of length `N`. The default value
'array' is equivalent to
``functools.partial(numpy.empty, dtype=object)``.
parallel_map (callable or tuple or None): Parallel function evaluator.
If given as a callable, the argument must have the same
specification as :func:`qutip.parallel.serial_map`.
A value of None is the same as passing
:func:`qutip.parallel.serial_map`. If given as a tuple, that tuple
must contain three callables, each of which has the same
specification as :func:`qutip.parallel.serial_map`. These three
callables are used to parallelize (1) the initial
forward-propagation, (2) the backward-propagation under the guess
pulses, and (3) the forward-propagation by a single time step under
the optimized pulses. See :mod:`krotov.parallelization` for
details.
store_all_pulses (bool): Whether or not to store the optimized pulses
from *all* iterations in :class:`.Result`.
continue_from (None or Result): If given, continue an optimization from
a previous :class:`.Result`. The result must have identical
`objectives`.
skip_initial_forward_propagation (bool): If given as `True` together
with `continue_from`, skip the initial forward propagation ("zeroth
iteration"), and take the forward-propagated states from
:attr:`.Result.states` instead.
norm (callable or None): A single-argument function to calculate the
norm of states. If None, delegate to the :meth:`~qutip.Qobj.norm`
method of the states.
overlap (callable or None): A two-argument function to calculate the
complex overlap of two states. If None, delegate to
:meth:`qutip.Qobj.overlap` for Hilbert space states and to the
Hilbert-Schmidt norm $\tr[\rho_1^\dagger \rho2]$ for density
matrices or operators.
Returns:
Result: The result of the optimization.
Raises:
ValueError: If any controls are not real-valued, or if any update
shape is not a real-valued function in the range [0, 1]; if using
`continue_from` with a :class:`.Result` with differing
`objectives`; if there are any required keys missing in
`pulse_options`.
"""
logger = logging.getLogger('krotov')
# Initialization
logger.info("Initializing optimization with Krotov's method")
if mu is None:
mu = derivative_wrt_pulse
second_order = sigma is not None
if norm is None:
norm = lambda state: state.norm()
if overlap is None:
overlap = _overlap
if modify_params_after_iter is not None:
# From a technical perspective, the `modify_params_after_iter` is
# really just another info_hook, the only difference is the
# convention that info_hooks shouldn't modify the parameters.
if info_hook is None:
info_hook = modify_params_after_iter
else:
info_hook = chain(modify_params_after_iter, info_hook)
if isinstance(propagator, list):
propagators = propagator
assert len(propagators) == len(objectives)
else:
propagators = [copy.deepcopy(propagator) for _ in objectives]
# copy.deepcopy will only do something on Propagator objects. For
# functions (even with closures), it just returns the same function.
_check_propagators_interface(propagators, logger)
adjoint_objectives = [obj.adjoint() for obj in objectives]
if storage == 'array':
storage = partial(np.empty, dtype=object)
if parallel_map is None:
parallel_map = serial_map
if not isinstance(parallel_map, (tuple, list)):
parallel_map = (parallel_map, parallel_map, parallel_map)
(
guess_controls, # "controls": sampled on the time grid
guess_pulses, # "pulses": sampled on the time grid intervals
pulses_mapping, # keep track of where to plug in pulse values
lambda_vals, # Krotov step width λₐ, for each control
shape_arrays, # update shape S(t), per control, sampled on intervals
) = _initialize_krotov_controls(objectives, pulse_options, tlist)
if continue_from is not None:
guess_controls, guess_pulses = _restore_from_previous_result(
continue_from, objectives, tlist, store_all_pulses
)
g_a_integrals = np.zeros(len(guess_pulses))
# ∫gₐ(t)dt is a very useful measure of whether λₐ is too small (large
# ∫gₐ(t)dt, relative to the pulse amplitude), and whether we're approaching
# convergence ("speeding up" for increasing values, "slowing down" for
# decreasing values)
if continue_from is None:
result = Result()
result.start_local_time = time.localtime()
else:
result = copy.deepcopy(continue_from)
# Initial forward-propagation
tic = time.time()
if skip_initial_forward_propagation:
forward_states = _skip_initial_forward_propagation(
objectives, continue_from, sigma, logger
)
else:
forward_states = parallel_map[0](
_forward_propagation,
list(range(len(objectives))),
(
objectives,
guess_pulses,
pulses_mapping,
tlist,
propagators,
storage,
),
)
toc = time.time()
fw_states_T = [states[-1] for states in forward_states]
tau_vals = np.array(
[
overlap(obj.target, state_T)
for (state_T, obj) in zip(fw_states_T, objectives)
]
)
if second_order:
forward_states0 = forward_states # ∀t: Δϕ=0, for iteration 0
else:
# the forward-propagated states only need to be stored for the second
# order update
forward_states0 = forward_states = None
info = None
optimized_pulses = copy.deepcopy(guess_pulses)
if info_hook is not None:
info = info_hook(
objectives=objectives,
adjoint_objectives=adjoint_objectives,
backward_states=None,
forward_states=forward_states,
forward_states0=forward_states0,
guess_pulses=guess_pulses,
optimized_pulses=optimized_pulses,
g_a_integrals=g_a_integrals,
lambda_vals=lambda_vals,
shape_arrays=shape_arrays,
fw_states_T=fw_states_T,
tlist=tlist,
tau_vals=tau_vals,
start_time=tic,
stop_time=toc,
iteration=0,
info_vals=[],
shared_data={},
)
# Initialize Result object
result.tlist = tlist
result.objectives = objectives
result.guess_controls = guess_controls
result.optimized_controls = optimized_pulses
result.controls_mapping = pulses_mapping
if continue_from is None:
# we only store information about the "0" iteration if we're starting a
# new optimization
if info is not None:
result.info_vals.append(info)
result.iters.append(0)
result.iter_seconds.append(int(toc - tic))
if not np.all(tau_vals == None): # noqa
result.tau_vals.append(tau_vals)
if store_all_pulses:
result.all_pulses.append(guess_pulses)
else:
iter_start = continue_from.iters[-1]
logger.info(
"Continuing from previous result, with iteration %d",
iter_start + 1,
)
result.states = fw_states_T
# Main optimization loop
for krotov_iteration in range(iter_start + 1, iter_stop + 1):
logger.info("Started Krotov iteration %d", krotov_iteration)
tic = time.time()
# Boundary condition for the backward propagation
# -- this is where the functional enters the optimization.
# `fw_states_T` are the states forward-propagated under the guess pulse
# of the current iteration, which is the optimized pulse of the
# previous iteration. This is how we get the `fw_states_T` here: they
# are left over from the forward-propagation in the previous iteration.
chi_states = chi_constructor(
fw_states_T=fw_states_T, objectives=objectives, tau_vals=tau_vals
)
chi_norms = [norm(chi) for chi in chi_states]
# normalizing χ improves numerical stability; the norm then has to be
# taken into account when calculating Δϵ
chi_states = [chi / nrm for (chi, nrm) in zip(chi_states, chi_norms)]
# Backward propagation
backward_states = parallel_map[1](
_backward_propagation,
list(range(len(objectives))),
(
chi_states,
adjoint_objectives,
guess_pulses,
pulses_mapping,
tlist,
propagators,
storage,
),
)
# Forward propagation and pulse update
logger.info("Started forward propagation/pulse update")
if second_order:
forward_states = [
storage(len(tlist)) for _ in range(len(objectives))
]
g_a_integrals[:] = 0.0
if second_order:
# In the update for the pulses in the first time interval, we use
# the states at t=0. Hence, Δϕ(t=0) = 0
delta_phis = [
Qobj(np.zeros(shape=chi_states[k].shape))
for k in range(len(objectives))
]
if second_order:
for i_obj in range(len(objectives)):
forward_states[i_obj][0] = objectives[i_obj].initial_state
delta_eps = [
np.zeros(len(tlist) - 1, dtype=np.complex128) for _ in guess_pulses
]
optimized_pulses = copy.deepcopy(guess_pulses)
fw_states = [obj.initial_state for obj in objectives]
for time_index in range(len(tlist) - 1): # iterate over time intervals
dt = tlist[time_index + 1] - tlist[time_index]
if second_order:
σ = sigma(tlist[time_index] + 0.5 * dt)
# pulse update
for (i_pulse, guess_pulse) in enumerate(guess_pulses):
for (i_obj, objective) in enumerate(objectives):
χ = backward_states[i_obj][time_index]
μ = mu(
objectives,
i_obj,
guess_pulses,
pulses_mapping,
i_pulse,
time_index,
)
Ψ = fw_states[i_obj]
update = overlap(χ, μ(Ψ)) # ⟨χ|μ|Ψ⟩ ∈ ℂ
update *= chi_norms[i_obj]
if second_order:
update += 0.5 * σ * overlap(delta_phis[i_obj], μ(Ψ))
delta_eps[i_pulse][time_index] += update
λₐ = lambda_vals[i_pulse]
S_t = shape_arrays[i_pulse][time_index]
Δϵ = (S_t / λₐ) * delta_eps[i_pulse][time_index].imag # ∈ ℝ
g_a_integrals[i_pulse] += abs(Δϵ) ** 2 * dt # dt may vary!
optimized_pulses[i_pulse][time_index] += Δϵ
# forward propagation
fw_states = parallel_map[2](
_forward_propagation_step,
list(range(len(objectives))),
(
fw_states,
objectives,
optimized_pulses,
pulses_mapping,
tlist,
time_index,
propagators,
),
)
if second_order:
# Δϕ(t + dt), to be used for the update in the next interval
delta_phis = [
fw_states[k] - forward_states0[k][time_index + 1]
for k in range(len(objectives))
]
# storage
for i_obj in range(len(objectives)):
forward_states[i_obj][time_index + 1] = fw_states[i_obj]
logger.info("Finished forward propagation/pulse update")
fw_states_T = fw_states
tau_vals = np.array(
[
overlap(obj.target, fw_state_T)
for (fw_state_T, obj) in zip(fw_states_T, objectives)
]
)
toc = time.time()
# Display information about iteration
if info_hook is not None:
info = info_hook(
objectives=objectives,
adjoint_objectives=adjoint_objectives,
backward_states=backward_states,
forward_states=forward_states,
forward_states0=forward_states0,
fw_states_T=fw_states_T,
tlist=tlist,
guess_pulses=guess_pulses,
optimized_pulses=optimized_pulses,
g_a_integrals=g_a_integrals,
lambda_vals=lambda_vals,
shape_arrays=shape_arrays,
tau_vals=tau_vals,
start_time=tic,
stop_time=toc,
info_vals=result.info_vals,
shared_data={},
iteration=krotov_iteration,
)
# Update optimization `result` with info from finished iteration
result.iters.append(krotov_iteration)
result.iter_seconds.append(int(toc - tic))
if info is not None:
result.info_vals.append(info)
if not np.all(tau_vals == None): # noqa
result.tau_vals.append(tau_vals)
result.optimized_controls = optimized_pulses
# pulses (time intervals) will be converted to controls (time grid
# points) farther below in "Finalize"
if store_all_pulses:
# we need to make a copy, so that the conversion in "Finalize"
# doesn't affect `all_pulses` as well.
result.all_pulses.append(copy.deepcopy(optimized_pulses))
result.states = fw_states_T
logger.info("Finished Krotov iteration %d", krotov_iteration)
# Convergence check
msg = None
if check_convergence is not None:
msg = check_convergence(result)
if bool(msg) is True: # this is not an anti-pattern!
result.message = "Reached convergence"
if isinstance(msg, str):
result.message += ": " + msg
break
else:
# prepare for next iteration
guess_pulses = optimized_pulses
if second_order:
sigma.refresh(
forward_states=forward_states,
forward_states0=forward_states0,
chi_states=chi_states,
chi_norms=chi_norms,
optimized_pulses=optimized_pulses,
guess_pulses=guess_pulses,
objectives=objectives,
result=result,
)
forward_states0 = forward_states
else: # optimization finished without `check_convergence` break
result.message = "Reached %d iterations" % max(iter_start, iter_stop)
# Finalize
result.end_local_time = time.localtime()
for i, pulse in enumerate(optimized_pulses):
result.optimized_controls[i] = pulse_onto_tlist(pulse)
return result
def _shape_val_to_callable(val):
if val == 1:
return one_shape
elif val == 0:
return zero_shape
else:
if callable(val):
return val
else:
raise ValueError("update_shape must be a callable")
def _enforce_shape_array_range(shape_array):
"""Enforce values ∈ [0, 1] in shape array, with some room for
rounding errors that will be clipped away.
"""
assert not np.iscomplexobj(shape_array) # `discretize` should catch this
# the rounding errors can be introduced by control_onto_interval, and
# result in values slightly below 0 or above 1. We allow a generous margin
# of ±0.01; if something nonsensical is passed as a shape, we can be pretty
# sure that it will deviate by a significantly larger error.
if np.min(shape_array) < -0.01 or np.max(shape_array) > 1.01:
raise ValueError(
"Update shapes ('update_shape' in pulse options-dict) must have "
"values in the range [0, 1], not [%s, %s]"
% (np.min(shape_array), np.max(shape_array))
)
return np.clip(shape_array, a_min=0.0, a_max=1.0)
def _check_propagators_interface(propagators, logger):
"""Warn if any of the propagators do not have the expected interface.
In order to pass muster, each propagator must either have the same
interface as :func:`krotov.propagators.expm`, or
:meth:`krotov.propagators.Propagator.__call__`
"""
for propagator in propagators:
spec = inspect.getfullargspec(propagator)
spec_func = inspect.getfullargspec(expm)
spec_obj = inspect.getfullargspec(Propagator.__call__)
if spec != spec_func and spec != spec_obj:
logger.warning(
"The propagator %s does not have the expected interface.",
propagator,
)
def _initialize_krotov_controls(objectives, pulse_options, tlist):
"""Extract discretized guess controls and pulses from `objectives`, and
return them with the associated mapping and option data"""
guess_controls = extract_controls(objectives)
pulses_mapping = extract_controls_mapping(objectives, guess_controls)
options_list = pulse_options_dict_to_list(pulse_options, guess_controls)
try:
guess_controls = [
discretize(
control,
tlist,
args=(pulse_options[control].get('args', None),),
)
for control in guess_controls
]
except (TypeError, np.ComplexWarning) as exc_info:
raise ValueError(
"Cannot discretize controls: %s. Note that "
"all controls must be real-valued. Complex controls must be "
"split into an independent real and imaginary part in the "
"objectives before passing them to the optimization" % exc_info
)
guess_pulses = [ # defined on the tlist intervals
control_onto_interval(control) for control in guess_controls
]
try:
lambda_vals = np.array(
[float(options['lambda_a']) for options in options_list]
)
except KeyError:
raise ValueError(
"Each value in pulse_options must be a dict that contains "
"the key 'lambda_a'."
)
shape_arrays = []
for options in options_list:
try:
S = discretize(
_shape_val_to_callable(options['update_shape']), tlist, args=()
)
except KeyError:
raise ValueError(
"Each value in pulse_options must be a dict that contains "
"the key 'update_shape'."
)
except (TypeError, np.ComplexWarning) as exc_info:
raise ValueError(
"Update shapes ('update_shape' in pulse options-dict) must be "
"real-valued: %s" % exc_info
)
shape_arrays.append(
_enforce_shape_array_range(control_onto_interval(S))
)
return (
guess_controls,
guess_pulses,
pulses_mapping,
lambda_vals,
shape_arrays,
)
def _restore_from_previous_result(result, objectives, tlist, store_all_pulses):
"""Load `guess_controls` and `guess_pulses` from the given Result object.
Raises:
ValueError: if `result` is incompatible with the given `objectives` and
`tlist`.
"""
if not isinstance(result, Result):
raise ValueError("Continuation is only possible from a Result object")
if len(objectives) != len(result.objectives):
raise ValueError(
"When continuing from a previous Result, the number of "
"objectives must be the same"
)
for (a, b) in zip(objectives, result.objectives):
if a != b:
raise ValueError(
"When continuing from a previous Result, the objectives must "
"remain unchanged"
)
if store_all_pulses:
if len(result.all_pulses) == 0:
raise ValueError(
"The store_all_pulses parameter cannot be changed when "
"continuing from a previous Result. Pass it as False."
)
else:
if len(result.all_pulses) > 0:
raise ValueError(
"The store_all_pulses parameter cannot be changed when "
"continuing from a previous Result. Pass it as True."
)
try:
if np.max(np.abs(np.array(tlist) - np.array(result.tlist))) > 1e-5:
# we're ok with pretty significant rounding errors: if someone is
# doing something stupid (like changing physical units), the error
# should be large
raise ValueError("tlist does not match")
except ValueError:
# if the size of the tlist has changed, the "if" statement will already
# raise a ValueError
raise ValueError(
"When continuing from a previous Result, the controls must be "
"defined on the same time grid"
)
nt = len(tlist)
guess_controls = []
for control in result.optimized_controls:
if len(control) == nt - 1:
# the Result was dumped before the optimization was complete
# (see krotov.konvergence.dump_result), so that the
# `optimized_controls` are actually pulses (defined on the
# intervals)
guess_controls.append(pulse_onto_tlist(control))
elif len(control) == nt:
# the Result was dumped after the optimization was complete, so
# that the `optmized_controls` are in fact controls (defined on the
# points of tlist)
guess_controls.append(control)
else:
# this should never happen
raise ValueError(
"Invalid Result: optimized_controls and tlist are incongruent"
)
guess_pulses = [ # defined on the tlist intervals
control_onto_interval(control) for control in guess_controls
]
return guess_controls, guess_pulses
def _skip_initial_forward_propagation(objectives, result, sigma, logger):
"""Extract forward_states from existing result"""
if sigma is not None:
raise ValueError(
"skip_initial_forward_propagation is incompatible with "
"second order Krotov (sigma is not None)"
)
if result is not None:
# forwards_states is normally a list of storage objects that store
# the *entire* forward propagation for each objective. The stored
# states are only needed for second order Krotov. When skipping the
# forward propagation, we'll use a fake "storage object" that
# contains only the final state. That'll be sufficient for setting
# fw_states_T and tau_vals below
forward_states = [[state] for state in result.states]
else:
logger.warning(
"You should not use `skip_initial_forward_propagation` unless "
"you are also passing `continue_from`"
)
# If the chi_constructor does not depend on the fw_states_T (like
# chis_re), and you're using first order Krotov, you don't actually
# have to do the initial forward propagation. It's not really worth
# it to skip that propagation, though, as we usually have to do
# hundreds of iterations. So, this is an undocumented feature.
forward_states = [[None] for _ in objectives]
return forward_states
def _forward_propagation(
i_objective,
objectives,
pulses,
pulses_mapping,
tlist,
propagators,
storage,
store_all=True,
):
"""Forward propagation of the initial state of a single objective over the
entire `tlist`"""
logger = logging.getLogger('krotov')
logger.info(
"Started initial forward propagation of objective %d", i_objective
)
obj = objectives[i_objective]
state = obj.initial_state
if store_all:
storage_array = storage(len(tlist))
storage_array[0] = state
mapping = pulses_mapping[i_objective]
for time_index in range(len(tlist) - 1): # index over intervals
H = plug_in_pulse_values(obj.H, pulses, mapping[0], time_index)
c_ops = [
plug_in_pulse_values(c_op, pulses, mapping[ic + 1], time_index)
for (ic, c_op) in enumerate(obj.c_ops)
]
dt = tlist[time_index + 1] - tlist[time_index]
state = propagators[i_objective](
H, state, dt, c_ops, initialize=(time_index == 0)
)
if store_all:
storage_array[time_index + 1] = state
logger.info(
"Finished initial forward propagation of objective %d", i_objective
)
if store_all:
return storage_array
else:
return state
def _backward_propagation(
i_state,
chi_states,
adjoint_objectives,
pulses,
pulses_mapping,
tlist,
propagators,
storage,
):
"""Backward propagation of chi_states[i_state] over the entire `tlist`"""
logger = logging.getLogger('krotov')
logger.info("Started backward propagation of state %d", i_state)
state = chi_states[i_state]
obj = adjoint_objectives[i_state]
storage_array = storage(len(tlist))
storage_array[-1] = state
mapping = pulses_mapping[i_state]
for time_index in range(len(tlist) - 2, -1, -1): # index bw over intervals
H = plug_in_pulse_values(
obj.H, pulses, mapping[0], time_index, conjugate=True
)
c_ops = [
plug_in_pulse_values(c_op, pulses, mapping[ic + 1], time_index)
for (ic, c_op) in enumerate(obj.c_ops)
]
dt = tlist[time_index + 1] - tlist[time_index]
state = propagators[i_state](
H,
state,
dt,
c_ops,
backwards=True,
initialize=(time_index == len(tlist) - 2),
)
storage_array[time_index] = state
logger.info("Finished backward propagation of state %d", i_state)
return storage_array
def _forward_propagation_step(
i_state,
states,
objectives,
pulses,
pulses_mapping,
tlist,
time_index,
propagators,
):
"""Forward-propagate states[i_state] by a single time step"""
state = states[i_state]
obj = objectives[i_state]
mapping = pulses_mapping[i_state]
H = plug_in_pulse_values(obj.H, pulses, mapping[0], time_index)
c_ops = [
plug_in_pulse_values(c_op, pulses, mapping[ic + 1], time_index)
for (ic, c_op) in enumerate(obj.c_ops)
]
dt = tlist[time_index + 1] - tlist[time_index]
return propagators[i_state](
H, state, dt, c_ops, initialize=(time_index == 0)
)