@@ -88,7 +88,19 @@ def create_pulse_func(idx):
            n_iters=0,  # Number of iterations (episodes) until convergence
            iter_seconds=[],  # list containing the time taken for each iteration (episode) of the optimization
            var_time=True,  # Whether the optimization was performed with variable time
+            guess_params=[]
        )
+
+        self._backup_result = Result(  # used as a backup in case the algorithm with shorter_pulses does not find an episode with infid < target_infid
+            objectives=objectives,
+            time_interval=time_interval,
+            start_local_time=time.localtime(),
+            n_iters=0,
+            iter_seconds=[],
+            var_time=True,
+            guess_params=[]
+        )
+        self._use_backup_result = False  # if True, use self._backup_result as the final optimization result

        # for the reward
        self._step_penalty = 1
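For reference, a minimal standalone sketch of the backup pattern this hunk introduces: keep a second result container plus a flag, and snapshot the live result the moment an episode reaches the target infidelity, so later episodes cannot overwrite it. The _BackupDemo class, the end_of_episode method, and the dict fields below are hypothetical stand-ins for illustration, not part of the diff.

class _BackupDemo:
    """Hypothetical stand-in showing how a live result and a backup result cooperate."""

    def __init__(self, fid_err_targ=1e-3):
        self._result = {"infidelity": 1.0, "iter_seconds": []}
        self._backup_result = {"infidelity": 1.0, "iter_seconds": []}
        self._use_backup_result = False  # flips to True once a good-enough episode exists
        self._fid_err_targ = fid_err_targ

    def end_of_episode(self, infidelity, elapsed):
        # update the "live" result after every episode
        self._result["infidelity"] = infidelity
        self._result["iter_seconds"].append(elapsed)
        # snapshot it as soon as it beats the target, so later (possibly worse)
        # episodes cannot overwrite the best known outcome
        if infidelity <= self._fid_err_targ:
            self._use_backup_result = True
            self._backup_result = {
                "infidelity": self._result["infidelity"],
                "iter_seconds": self._result["iter_seconds"].copy(),
            }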
@@ -228,21 +240,39 @@ def reset(self, seed=None):
        self._state = self._initial
        return self._get_obs(), {}

-    def result(self):
+    def _save_result(self):
        """
-        Retrieve the results of the optimization process, including the optimized
+        Save the results of the optimization process, including the optimized
        pulse sequences, final states, and performance metrics.
        """
-        self._result.end_local_time = time.localtime()
-        self._result.n_iters = len(self._result.iter_seconds)
-        self._result.optimized_params = self._actions.copy() + [self._result.total_seconds]  # If var_time is True, the last parameter is the evolution time
-        self._result._optimized_controls = self._actions.copy()
-        self._result.start_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._result.start_local_time)  # Convert to a string
-        self._result.end_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._result.end_local_time)  # Convert to a string
-        self._result._guess_controls = []
-        self._result._optimized_H = [self._H]
-        self._result.guess_params = []
-        return self._result
+        result_obj = self._backup_result if self._use_backup_result else self._result
+
+        if self._use_backup_result:
+            self._backup_result.iter_seconds = self._result.iter_seconds.copy()
+            self._backup_result._final_states = self._result._final_states.copy()
+            self._backup_result.infidelity = self._result.infidelity
+
+        result_obj.end_local_time = time.localtime()
+        result_obj.n_iters = len(self._result.iter_seconds)
+        result_obj.optimized_params = self._actions.copy() + [self._result.total_seconds]  # If var_time is True, the last parameter is the evolution time
+        result_obj._optimized_controls = self._actions.copy()
+        result_obj._guess_controls = []
+        result_obj._optimized_H = [self._H]
+
+
+    def result(self):
+        """
+        Final conversions and return of the optimization results.
+        """
+        if self._use_backup_result:
+            self._backup_result.start_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._backup_result.start_local_time)  # Convert to a string
+            self._backup_result.end_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._backup_result.end_local_time)  # Convert to a string
+            return self._backup_result
+        else:
+            self._save_result()
+            self._result.start_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._result.start_local_time)  # Convert to a string
+            self._result.end_local_time = time.strftime("%Y-%m-%d %H:%M:%S", self._result.end_local_time)  # Convert to a string
+            return self._result

    def train(self):
        """
@@ -266,7 +296,6 @@ class EarlyStopTraining(BaseCallback):
    """
    def __init__(self, verbose: int = 0):
        super(EarlyStopTraining, self).__init__(verbose)
-        self.stop_train = False

    def _on_step(self) -> bool:
        """
@@ -279,12 +308,18 @@ def _on_step(self) -> bool:

        # Check if we need to stop training
        if env.current_episode >= env.max_episodes:
-            env._result.message = f"Reached {env.max_episodes} episodes, stopping training."
+            if env._use_backup_result:
+                env._backup_result.message = f"Reached {env.max_episodes} episodes, stopping training. Returning the last found episode with infid < target_infid."
+            else:
+                env._result.message = f"Reached {env.max_episodes} episodes, stopping training."
            return False  # Stop training
        elif (env._result.infidelity <= env._fid_err_targ) and not (env.shorter_pulses):
            env._result.message = f"Stop training because an episode with infidelity <= target infidelity was found"
            return False  # Stop training
        elif env.shorter_pulses:
+            if env._result.infidelity <= env._fid_err_targ:  # this episode beat the target infidelity; save it in the meantime as a backup
+                env._use_backup_result = True
+                env._save_result()
            if len(env._episode_info) >= 100:
                last_100_episodes = env._episode_info[-100:]
@@ -293,6 +328,7 @@ def _on_step(self) -> bool:
                infid_condition = all(ep['final_infidelity'] <= env._fid_err_targ for ep in last_100_episodes)

                if steps_condition and infid_condition:
+                    env._use_backup_result = False
                    env._result.message = "Training finished. No episode in the last 100 used fewer steps and infidelity was below target infid."
                    return False  # Stop training
        return True  # Continue training
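The shorter_pulses stopping rule checked above can be restated as a small standalone helper. should_stop and the 'steps' dictionary key are assumptions made for illustration; only the 'final_infidelity' key appears in the diff.

def should_stop(episode_info, fid_err_targ, window=100):
    """Hypothetical restatement of the callback's stopping rule: stop once every
    episode in the last `window` meets the target infidelity and none of them
    used fewer steps than the others."""
    if len(episode_info) < window:
        return False
    last = episode_info[-window:]
    min_steps = min(ep["steps"] for ep in last)                     # assumed key name
    steps_condition = all(ep["steps"] == min_steps for ep in last)  # no episode used fewer steps
    infid_condition = all(ep["final_infidelity"] <= fid_err_targ for ep in last)
    return steps_condition and infid_condition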