Skip to content

Commit 3889a63

Browse files
committed
added _backup_result to keep track of the result even if the algorithm continues to search for solutions with shorter pulses
1 parent 4485f9d commit 3889a63

File tree

2 files changed

+62
-14
lines changed

2 files changed

+62
-14
lines changed

src/qutip_qoc/_rl.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,19 @@ def create_pulse_func(idx):
8888
n_iters = 0, # Number of iterations(episodes) until convergence
8989
iter_seconds = [], # list containing the time taken for each iteration(episode) of the optimization
9090
var_time = True, # Whether the optimization was performed with variable time
91+
guess_params=[]
9192
)
93+
94+
self._backup_result = Result( # used as a backup in case the algorithm with shorter_pulses does not find an episode with infid<target_infid
95+
objectives = objectives,
96+
time_interval = time_interval,
97+
start_local_time = time.localtime(),
98+
n_iters = 0,
99+
iter_seconds = [],
100+
var_time = True,
101+
guess_params=[]
102+
)
103+
self._use_backup_result = False # if true, use self._backup_result as the final optimization result
92104

93105
#for the reward
94106
self._step_penalty = 1
@@ -228,21 +240,39 @@ def reset(self, seed=None):
228240
self._state = self._initial
229241
return self._get_obs(), {}
230242

231-
def result(self):
243+
def _save_result(self):
232244
"""
233-
Retrieve the results of the optimization process, including the optimized
245+
Save the results of the optimization process, including the optimized
234246
pulse sequences, final states, and performance metrics.
235247
"""
236-
self._result.end_local_time = time.localtime()
237-
self._result.n_iters = len(self._result.iter_seconds)
238-
self._result.optimized_params = self._actions.copy() + [self._result.total_seconds] # If var_time is True, the last parameter is the evolution time
239-
self._result._optimized_controls = self._actions.copy()
240-
self._result.start_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._result.start_local_time) # Convert to a string
241-
self._result.end_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._result.end_local_time) # Convert to a string
242-
self._result._guess_controls = []
243-
self._result._optimized_H = [self._H]
244-
self._result.guess_params = []
245-
return self._result
248+
result_obj = self._backup_result if self._use_backup_result else self._result
249+
250+
if(self._use_backup_result):
251+
self._backup_result.iter_seconds = self._result.iter_seconds.copy()
252+
self._backup_result._final_states = self._result._final_states.copy()
253+
self._backup_result.infidelity = self._result.infidelity
254+
255+
result_obj.end_local_time = time.localtime()
256+
result_obj.n_iters = len(self._result.iter_seconds)
257+
result_obj.optimized_params = self._actions.copy() + [self._result.total_seconds] # If var_time is True, the last parameter is the evolution time
258+
result_obj._optimized_controls = self._actions.copy()
259+
result_obj._guess_controls = []
260+
result_obj._optimized_H = [self._H]
261+
262+
263+
def result(self):
264+
"""
265+
Final conversions and return of optimization results
266+
"""
267+
if self._use_backup_result:
268+
self._backup_result.start_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._backup_result.start_local_time) # Convert to a string
269+
self._backup_result.end_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._backup_result.end_local_time) # Convert to a string
270+
return self._backup_result
271+
else:
272+
self._save_result()
273+
self._result.start_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._result.start_local_time) # Convert to a string
274+
self._result.end_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._result.end_local_time) # Convert to a string
275+
return self._result
246276

247277
def train(self):
248278
"""
@@ -266,7 +296,6 @@ class EarlyStopTraining(BaseCallback):
266296
"""
267297
def __init__(self, verbose: int = 0):
268298
super(EarlyStopTraining, self).__init__(verbose)
269-
self.stop_train = False
270299

271300
def _on_step(self) -> bool:
272301
"""
@@ -279,12 +308,18 @@ def _on_step(self) -> bool:
279308

280309
# Check if we need to stop training
281310
if env.current_episode >= env.max_episodes:
282-
env._result.message = f"Reached {env.max_episodes} episodes, stopping training."
311+
if env._use_backup_result == True:
312+
env._backup_result.message = f"Reached {env.max_episodes} episodes, stopping training. Returning the last found episode with infid < target_infid"
313+
else:
314+
env._result.message = f"Reached {env.max_episodes} episodes, stopping training."
283315
return False # Stop training
284316
elif (env._result.infidelity <= env._fid_err_targ) and not(env.shorter_pulses):
285317
env._result.message = f"Stop training because an episode with infidelity <= target infidelity was found"
286318
return False # Stop training
287319
elif env.shorter_pulses:
320+
if(env._result.infidelity <= env._fid_err_targ): # if it finds an episode with infidelity lower than target infidelity, I'll save it in the meantime
321+
env._use_backup_result = True
322+
env._save_result()
288323
if len(env._episode_info) >= 100:
289324
last_100_episodes = env._episode_info[-100:]
290325

@@ -293,6 +328,7 @@ def _on_step(self) -> bool:
293328
infid_condition = all(ep['final_infidelity'] <= env._fid_err_targ for ep in last_100_episodes)
294329

295330
if steps_condition and infid_condition:
331+
env._use_backup_result = False
296332
env._result.message = "Training finished. No episode in the last 100 used fewer steps and infidelity was below target infid."
297333
return False # Stop training
298334
return True # Continue training

tests/test_result.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,18 @@ def sin_z_jax(t, r, **kwargs):
190190

191191
unitary_rl = state2state_rl._replace(
192192
objectives=[Objective(initial, H, target)],
193+
control_parameters = {
194+
"p": {"bounds": [(-13, 13)]},
195+
"__time__": {
196+
"guess": np.array([0.0]), #dummy value
197+
"bounds": [(0.0, 0.0)] #dummy value
198+
}
199+
},
200+
algorithm_kwargs={
201+
"fid_err_targ": 0.01,
202+
"alg": "RL",
203+
"max_iter": 300,
204+
},
193205
)
194206

195207

0 commit comments

Comments
 (0)