diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0bcf23a..0c027ee 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -92,3 +92,5 @@ repos:
           - jaxlib
           - diffrax
           - pytest
+          - gymnasium
+          - stable-baselines3
diff --git a/requirements.txt b/requirements.txt
index a08565e..50da21d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,5 @@ qutip>=5.0.1
 qutip-qtrl
 qutip-jax
 pre-commit
+gymnasium>=0.29.1
+stable-baselines3>=2.3.2
diff --git a/setup.cfg b/setup.cfg
index 3903195..7872023 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -35,6 +35,8 @@ install_requires =
     qutip-qtrl
     qutip-jax
     numpy>=1.16.6,<2.0
+    gymnasium>=0.29.1
+    stable-baselines3>=2.3.2
 setup_requires =
     cython>=1.0
     packaging
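The two new dependencies supply the building blocks used by src/qutip_qoc/_rl.py below: gymnasium provides the Env and spaces API, while stable-baselines3 provides PPO, check_env and the BaseCallback hook. For orientation, a minimal custom environment and training loop in that API looks roughly like the sketch below; ToyEnv and its dynamics are made up for illustration and are not code from this patch.

# Generic Gymnasium + Stable-Baselines3 usage pattern; ToyEnv is purely
# illustrative and only mirrors the structure the _RL environment follows.
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env


class ToyEnv(gym.Env):
    def __init__(self):
        super().__init__()
        # one continuous action (a pulse amplitude) and a small real-valued observation
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
        self._obs = np.zeros(2, dtype=np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self._obs = np.zeros(2, dtype=np.float32)
        return self._obs, {}  # Gymnasium reset() returns (observation, info)

    def step(self, action):
        # toy dynamics: the observation drifts towards the chosen action
        self._obs = np.clip(self._obs + 0.1 * action[0], -1.0, 1.0).astype(np.float32)
        reward = -float(np.abs(self._obs - 0.5).sum())
        terminated = reward > -0.05  # goal reached
        truncated = False            # no step limit in this sketch
        return self._obs, reward, terminated, truncated, {}


env = ToyEnv()
check_env(env, warn=True)  # the same compatibility check _RL.train() performs
PPO("MlpPolicy", env, verbose=0).learn(total_timesteps=1_000)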
+ """ + return lambda t, args: self._pulse(t, args, idx + 1) + + # create the QobjEvo with Hd, Hc and controls(args) + self._H_lst = [self._Hd_lst[0]] + dummy_args = {f"alpha{i+1}": 1.0 for i in range(len(self._Hc_lst[0]))} + for i, Hc in enumerate(self._Hc_lst[0]): + self._H_lst.append([Hc, create_pulse_func(i)]) + self._H = qt.QobjEvo(self._H_lst, args=dummy_args) - # ------------------------------- copied from _GOAT class ------------------------------- - - # TODO: you dont have to use (or keep them) if you don't need the following attributes - # this is just an inspiration how to extract information from the input + self.shorter_pulses = alg_kwargs.get( + "shorter_pulses", False + ) # lengthen the training to look for pulses of shorter duration, therefore episodes with fewer steps - self._Hd = objective.H[0] - self._Hc_lst = objective.H[1:] + # extract bounds for control_parameters + bounds = [] + for key in control_parameters.keys(): + bounds.append(control_parameters[key].get("bounds")) + self._lbound = [b[0][0] for b in bounds] + self._ubound = [b[0][1] for b in bounds] - self._control_parameters = control_parameters - self._guess_params = guess_params - self._H = self._prepare_generator() + self._alg_kwargs = alg_kwargs - self._initial = objective.initial - self._target = objective.target + self._initial = objectives[0].initial + self._target = objectives[0].target + self._state = None + self._dim = self._initial.shape[0] - self._evo_time = time_interval.evo_time + self._result = Result( + objectives=objectives, + time_interval=time_interval, + start_local_time=time.localtime(), # initial optimization time + n_iters=0, # Number of iterations(episodes) until convergence + iter_seconds=[], # list containing the time taken for each iteration(episode) of the optimization + var_time=True, # Whether the optimization was performed with variable time + guess_params=[], + ) + + self._backup_result = Result( # used as a backup in case the algorithm with shorter_pulses does not find an episode with infid= self.max_steps + ) # if the episode ended without reaching the goal + + observation = self._get_obs() + return observation, reward, bool(self.terminated), bool(self.truncated), {} + + def _get_obs(self): + """ + Get the current state observation for the RL agent. Converts the system's + quantum state or matrix into a real-valued NumPy array suitable for RL algorithms. + """ + rho = self._state.full().flatten() + obs = np.concatenate((np.real(rho), np.imag(rho))) + return obs.astype( + np.float32 + ) # Gymnasium expects the observation to be of type float32 + + def reset(self, seed=None): + """ + Reset the environment to the initial state, preparing for a new episode. + """ + self._save_episode_info() + + time_diff = self._episode_info[-1]["elapsed_time"] - ( + self._episode_info[-2]["elapsed_time"] + if len(self._episode_info) > 1 + else time.mktime(self._result.start_local_time) + ) + self._result.iter_seconds.append(time_diff) + self._current_step = 0 # Reset the step counter + self.current_episode += 1 # Increment episode counter + self._actions = self._temp_actions.copy() + self.terminated = False + self.truncated = False + self._temp_actions = [] + self._result._final_states = [self._state] + self._state = self._initial + return self._get_obs(), {} + + def _save_result(self): + """ + Save the results of the optimization process, including the optimized + pulse sequences, final states, and performance metrics. 
+ """ + result_obj = self._backup_result if self._use_backup_result else self._result + + if self._use_backup_result: + self._backup_result.iter_seconds = self._result.iter_seconds.copy() + self._backup_result._final_states = self._result._final_states.copy() + self._backup_result.infidelity = self._result.infidelity + + result_obj.end_local_time = time.localtime() + result_obj.n_iters = len(self._result.iter_seconds) + result_obj.optimized_params = self._actions.copy() + [ + self._result.total_seconds + ] # If var_time is True, the last parameter is the evolution time + result_obj._optimized_controls = self._actions.copy() + result_obj._guess_controls = [] + result_obj._optimized_H = [self._H] def result(self): - # TODO: return qoc.Result object with the optimized pulse amplitudes - ... \ No newline at end of file + """ + Final conversions and return of optimization results + """ + if self._use_backup_result: + self._backup_result.start_local_time = time.strftime( + "%Y-%m-%d %H:%M:%S", self._backup_result.start_local_time + ) # Convert to a string + self._backup_result.end_local_time = time.strftime( + "%Y-%m-%d %H:%M:%S", self._backup_result.end_local_time + ) # Convert to a string + return self._backup_result + else: + self._save_result() + self._result.start_local_time = time.strftime( + "%Y-%m-%d %H:%M:%S", self._result.start_local_time + ) # Convert to a string + self._result.end_local_time = time.strftime( + "%Y-%m-%d %H:%M:%S", self._result.end_local_time + ) # Convert to a string + return self._result + + def train(self): + """ + Train the RL agent on the defined quantum control problem using the specified + reinforcement learning algorithm. Checks environment compatibility with Gym API. + """ + # Check if the environment follows Gym API + check_env(self, warn=True) + + # Create the model + model = PPO( + "MlpPolicy", self, verbose=1 + ) # verbose = 1 to display training progress and statistics in the terminal + + stop_callback = EarlyStopTraining(verbose=1) + + # Train the model + model.learn(total_timesteps=self._total_timesteps, callback=stop_callback) + + +class EarlyStopTraining(BaseCallback): + """ + A callback to stop training based on specific conditions (steps, infidelity, max iterations) + """ + + def __init__(self, verbose: int = 0): + super(EarlyStopTraining, self).__init__(verbose) + + def _on_step(self) -> bool: + """ + This method is required by the BaseCallback class. We use it to stop the training. + - Stop training if the maximum number of episodes is reached. + - Stop training if it finds an episode with infidelity <= than target infidelity + - If all of the last 100 episodes have infidelity below the target and use the same number of steps, stop training. + """ + env = self.training_env.get_attr("unwrapped")[0] + + # Check if we need to stop training + if env.current_episode >= env.max_episodes: + if env._use_backup_result is True: + env._backup_result.message = f"Reached {env.max_episodes} episodes, stopping training. Return the last founded episode with infid < target_infid" + else: + env._result.message = ( + f"Reached {env.max_episodes} episodes, stopping training." 
diff --git a/src/qutip_qoc/pulse_optim.py b/src/qutip_qoc/pulse_optim.py
index 661c577..9e8ecc3 100644
--- a/src/qutip_qoc/pulse_optim.py
+++ b/src/qutip_qoc/pulse_optim.py
@@ -1,7 +1,7 @@
 """
 This module is the entry point for the optimization of control pulses.
 It provides the function `optimize_pulses` which prepares and runs the
-GOAT, JOPT, GRAPE or CRAB optimization.
+GOAT, JOPT, GRAPE, CRAB or RL optimization.
 """
 
 import numpy as np
@@ -25,7 +25,7 @@ def optimize_pulses(
     integrator_kwargs=None,
 ):
     """
-    Run GOAT, JOPT, GRAPE or CRAB optimization.
+    Run GOAT, JOPT, GRAPE, CRAB or RL optimization.
 
     Parameters
     ----------
@@ -41,6 +41,7 @@ def optimize_pulses(
         control_id : dict
 
             - guess: ndarray, shape (n,)
+                For RL, the guess does not need to be specified.
                 Initial guess. Array of real elements of size (n,),
                 where ``n`` is the number of independent variables.
 
@@ -49,7 +50,7 @@ def optimize_pulses(
                `guess`. None is used to specify no bound.
 
     __time__ : dict, optional
-        Only supported by GOAT and JOPT.
+        Only supported by GOAT and JOPT (for RL, use `algorithm_kwargs: 'shorter_pulses'` instead).
         If given the pulse duration is treated as optimization parameter.
        It must specify both:
 
@@ -71,14 +72,15 @@ def optimize_pulses(
 
        - alg : str
            Algorithm to use for the optimization.
-            Supported are: "GRAPE", "CRAB", "GOAT", "JOPT".
+            Supported are: "GRAPE", "CRAB", "GOAT", "JOPT" and "RL".
 
        - fid_err_targ : float, optional
            Fidelity error target for the optimization.
 
        - max_iter : int, optional
            Maximum number of iterations to perform.
-            Referes to local minimizer steps.
+            Refers to local minimizer steps or, in the context of
+            `alg: "RL"`, to the maximum number of episodes.
            Global steps default to 0 (no global optimization).
            Can be overridden by specifying in minimizer_kwargs.
 
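Concretely, the RL-specific options documented above end up in algorithm_kwargs; the values below are the ones used by the RL test case added in tests/test_result.py later in this diff.

algorithm_kwargs = {
    "alg": "RL",             # select the reinforcement-learning optimizer
    "fid_err_targ": 0.01,    # fidelity error target
    "max_iter": 20000,       # for RL: maximum number of training episodes
    "shorter_pulses": True,  # RL counterpart of the __time__ option above
}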
@@ -349,8 +351,7 @@ def optimize_pulses(
 
             qtrl_optimizers.append(qtrl_optimizer)
 
-    # TODO: we can deal with proper handling later
-    if alg == "RL":
+    elif alg == "RL":
         rl_env = _RL(
             objectives,
             control_parameters,
diff --git a/tests/test_result.py b/tests/test_result.py
index 084aa74..7146b04 100644
--- a/tests/test_result.py
+++ b/tests/test_result.py
@@ -153,38 +153,49 @@ def sin_z_jax(t, r, **kwargs):
 )
 
 # ----------------------- RL --------------------
-# TODO: this is the input for optimiz_pulses() function
-# you can use this routine to test your implementation
 
 # state to state transfer
-init = qt.basis(2, 0)
-target = qt.basis(2, 1)
+initial = qt.basis(2, 0)
+target = (qt.basis(2, 0) + qt.basis(2, 1)).unit()  # |+⟩
 
-H_c = [qt.sigmax(), qt.sigmay(), qt.sigmaz()] # control Hamiltonians
+H_c = [qt.sigmax(), qt.sigmay(), qt.sigmaz()]  # control Hamiltonians
 
 w, d, y = 0.1, 1.0, 0.1
-H_d = 1 / 2 * (w * qt.sigmaz() + d * qt.sigmax()) # drift Hamiltonian
+H_d = 1 / 2 * (w * qt.sigmaz() + d * qt.sigmax())  # drift Hamiltonian
 
-H = [H_d] + H_c # total Hamiltonian
+H = [H_d] + H_c  # total Hamiltonian
 
 state2state_rl = Case(
     objectives=[Objective(initial, H, target)],
-    control_parameters={"bounds": [-13, 13]},  # TODO: for now only consider bounds
-    tlist=np.linspace(0, 10, 100),  # TODO: derive single step duration and max evo time / max num steps from this
+    control_parameters={
+        "p": {"bounds": [(-13, 13)]},
+    },
+    tlist=np.linspace(0, 10, 100),
     algorithm_kwargs={
         "fid_err_targ": 0.01,
         "alg": "RL",
-        "max_iter": 100,
-    }
+        "max_iter": 20000,
+        "shorter_pulses": True,
+    },
+    optimizer_kwargs={},
 )
 
-# TODO: no big difference for unitary evolution
+# no big difference for unitary evolution
 
-initial = qt.qeye(2) # Identity
-target = qt.gates.hadamard_transform()
+initial = qt.qeye(2)  # Identity
+target = qt.gates.hadamard_transform()
 
 unitary_rl = state2state_rl._replace(
     objectives=[Objective(initial, H, target)],
+    control_parameters={
+        "p": {"bounds": [(-13, 13)]},
+    },
+    algorithm_kwargs={
+        "fid_err_targ": 0.01,
+        "alg": "RL",
+        "max_iter": 300,
+        "shorter_pulses": True,
+    },
 )
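For reference, the state2state_rl case above corresponds to roughly the following end-to-end call. This is a sketch: it assumes the keyword names of optimize_pulses match the Case fields used in the test (objectives, control_parameters, tlist, algorithm_kwargs) and that the returned Result exposes the message and infidelity attributes set in _rl.py.

import numpy as np
import qutip as qt
from qutip_qoc import Objective, optimize_pulses

# state-to-state transfer |0> -> |+>, as in the test case above
initial = qt.basis(2, 0)
target = (qt.basis(2, 0) + qt.basis(2, 1)).unit()

w, d = 0.1, 1.0
H_d = 1 / 2 * (w * qt.sigmaz() + d * qt.sigmax())  # drift Hamiltonian
H = [H_d, qt.sigmax(), qt.sigmay(), qt.sigmaz()]   # drift + control Hamiltonians

result = optimize_pulses(
    objectives=[Objective(initial, H, target)],
    control_parameters={"p": {"bounds": [(-13, 13)]}},
    tlist=np.linspace(0, 10, 100),
    algorithm_kwargs={
        "fid_err_targ": 0.01,
        "alg": "RL",
        "max_iter": 20000,
        "shorter_pulses": True,
    },
)
print(result.message)
print(result.infidelity)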