diff --git a/qopt/__init__.py b/qopt/__init__.py index 42f4f8b..48e362e 100644 --- a/qopt/__init__.py +++ b/qopt/__init__.py @@ -82,12 +82,3 @@ __version__ = '1.3' __license__ = 'GNU GPLv3+' __author__ = 'Julian Teske, Forschungszentrum Juelich' - - -try: - from jax.config import config - config.update("jax_enable_x64", True) - #TODO: add new objects here/ import other stuff? - # __all__ += [] -except ImportError: - pass \ No newline at end of file diff --git a/qopt/amplitude_functions.py b/qopt/amplitude_functions.py index f0b378a..832a94c 100644 --- a/qopt/amplitude_functions.py +++ b/qopt/amplitude_functions.py @@ -64,11 +64,10 @@ """ from abc import ABC, abstractmethod -from typing import Callable, Optional +from typing import Callable import numpy as np -from typing import Union class AmplitudeFunction(ABC): """Abstract Base class of the amplitude function. """ @@ -219,125 +218,3 @@ def derivative_by_chain_rule(self, deriv_by_ctrl_amps: np.ndarray, # return: shape (time, func, par) return np.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps) - - -############################################################################### - -try: - import jax.numpy as jnp - from jax import jit,vmap,jacfwd - _HAS_JAX = True -except ImportError: - from unittest import mock - jit, vmap, jacfwd = mock.Mock(), mock.Mock(), mock.Mock() - jnp = mock.Mock() - _HAS_JAX = False - - -class IdentityAmpFuncJAX(AmplitudeFunction): - """See docstring of class without JAX. - Designed to return jax-numpy-arrays. - """ - - def __init__(self): - if not _HAS_JAX: - raise ImportError("JAX not available") - - def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: - """See base class. """ - return jnp.asarray(x) - - def derivative_by_chain_rule( - self, - deriv_by_ctrl_amps: Union[np.ndarray,jnp.ndarray], - x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: - """See base class. """ - return jnp.asarray(deriv_by_ctrl_amps) - - -class UnaryAnalyticAmpFuncJAX(AmplitudeFunction): - """See docstring of class without JAX. - Designed to return jax-numpy-arrays. - Functions need to be compatible with jit. - (Includes that functions need to be pure - (i.e. output solely depends on input)). - """ - - def __init__(self, - value_function: Callable[[float, ], float], - derivative_function: [Callable[[float, ], float]]): - if not _HAS_JAX: - raise ImportError("JAX not available") - self.value_function = jit(jnp.vectorize(value_function)) - self.derivative_function = jit(jnp.vectorize(derivative_function)) - - def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: - """See base class. """ - return jnp.asarray(self.value_function(x)) - - def derivative_by_chain_rule( - self, - deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], x): - """See base class. """ - du_by_dx = self.derivative_function(x) - # du_by_dx shape: (n_time, n_ctrl) - # deriv_by_ctrl_amps shape: (n_time, n_func, n_ctrl) - # deriv_by_opt_par shape: (n_time, n_func, n_ctrl - # since the function is unary we have n_ctrl = n_amps - return jnp.einsum('ij,ikj->ikj', du_by_dx, deriv_by_ctrl_amps) - - -class CustomAmpFuncJAX(AmplitudeFunction): - """See docstring of class without JAX. - Designed to return jax-numpy-arrays. - Functions need to be compatible with jit. - (Includes that functions need to be pure - (i.e. output solely depends on input)). - If derivative_function=None, autodiff is used. - t_to_vectorize: if value_function/derivative_function not yet - vectorized for num_t - """ - - def __init__( - self, - value_function: Callable[[Union[np.ndarray, jnp.ndarray],], - Union[np.ndarray, jnp.ndarray]], - derivative_function: Callable[[Union[np.ndarray, jnp.ndarray],], - Union[np.ndarray, jnp.ndarray]], - t_to_vectorize: bool = False - ): - if not _HAS_JAX: - raise ImportError("JAX not available") - if t_to_vectorize == True: - self.value_function = jit(vmap(value_function),in_axes=(0,)) - else: - self.value_function = jit(value_function) - if derivative_function is not None: - if t_to_vectorize == True: - self.derivative_function = jit(vmap(derivative_function),in_axes=(0,)) - else: - self.derivative_function = jit(derivative_function) - else: - if t_to_vectorize == True: - def der_wrapper(x): - return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(x)),in_axes=(0,))(x),1,2) - else: - def der_wrapper(x): - return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(jnp.expand_dims(x,axis=0))[0,:]),in_axes=(0,))(x),1,2) - self.derivative_function = jit(der_wrapper) - - def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: - """See base class. """ - return jnp.asarray(self.value_function(x)) - - def derivative_by_chain_rule( - self, - deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], - x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: - """See base class. """ - du_by_dx = self.derivative_function(x) - # du_by_dx: shape (time, par, ctrl) - # deriv_by_ctrl_amps: shape (time, func, ctrl) - # return: shape (time, func, par) - - return jnp.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps) diff --git a/qopt/cost_functions.py b/qopt/cost_functions.py index cdb8c36..0bd8bff 100644 --- a/qopt/cost_functions.py +++ b/qopt/cost_functions.py @@ -104,9 +104,7 @@ from qopt.util import needs_refactoring, deprecated from qopt.matrix import ket_vectorize_density_matrix, \ convert_ket_vectorized_density_matrix_to_square, \ - convert_unitary_to_super_operator, DenseOperator - -from functools import partial + convert_unitary_to_super_operator class CostFunction(ABC): @@ -124,6 +122,7 @@ class CostFunction(ABC): storing the data. """ + def __init__(self, solver: solver_algorithms.Solver, label: Optional[List[str]] = None): self.solver = solver @@ -831,51 +830,6 @@ def derivative_entanglement_fidelity_with_du( return derivative_fidelity -def derivative_entanglement_fidelity_with_dfreq( - target: matrix.OperatorMatrix, - target_der: matrix.OperatorMatrix, - forward_propagators: List[matrix.OperatorMatrix], - computational_states: Optional[List[int]] = None, - map_to_closest_unitary: bool = False -) -> np.ndarray: - - target_unitary_dag = target.dag(do_copy=True) - if computational_states: - trace = np.conj( - ((forward_propagators[-1].truncate_to_subspace( - computational_states, - map_to_closest_unitary=map_to_closest_unitary) - * target_unitary_dag).tr()) - ) - else: - trace = np.conj(((forward_propagators[-1] * target_unitary_dag).tr())) - num_ctrls = 1 - num_time_steps = 1 - d = target.shape[0] - - derivative_fidelity = np.zeros(shape=(num_time_steps, num_ctrls), - dtype=float) - - target_unitary_dag = target_der.dag(do_copy=True) - - ctrl=0 - t=-1 - # here we need to take the real part. - if computational_states: - derivative_fidelity[t, ctrl] = 2 / d / d * np.real( - trace * (forward_propagators[t].truncate_to_subspace( - subspace_indices=computational_states, - map_to_closest_unitary=map_to_closest_unitary - ) - * target_unitary_dag).tr()) - else: - derivative_fidelity[t, ctrl] = 2 / d / d * np.real( - trace * (forward_propagators[t] - * target_unitary_dag).tr()) - - return derivative_fidelity - - def entanglement_fidelity_super_operator( target: Union[np.ndarray, matrix.OperatorMatrix], propagator: Union[np.ndarray, matrix.OperatorMatrix], @@ -1032,30 +986,6 @@ def deriv_entanglement_fid_sup_op_with_du( return derivative_fidelity -def deriv_entanglement_fid_sup_op_with_dfreq( - target: matrix.OperatorMatrix, - target_der: matrix.OperatorMatrix, - forward_propagators: List[matrix.OperatorMatrix], - computational_states: Optional[List[int]] = None, - map_to_closest_unitary: bool = False -): - - num_ctrls = 1 - num_time_steps = 1 - - derivative_fidelity = np.zeros(shape=(num_time_steps, num_ctrls), - dtype=float) - ctrl=0 - t=-1 - # here we need to take the real part. - derivative_fidelity[t, ctrl] = \ - entanglement_fidelity_super_operator( - target=target_der, - propagator=forward_propagators[t], - computational_states=computational_states) - return derivative_fidelity - - class StateInfidelity(CostFunction): """Quantum state infidelity. @@ -1108,65 +1038,6 @@ def grad(self) -> np.ndarray: return -1. * np.real(derivative_fid) -class StateInfidelity2(CostFunction): - """Quantum state infidelity. - - TODO: - * support super operator formalism - * handle leakage states? - """ - - def __init__(self, - solver: solver_algorithms.Solver, - target: matrix.OperatorMatrix, - initial: matrix.OperatorMatrix, - label: Optional[List[str]] = None, - computational_states: Optional[List[int]] = None, - rescale_propagated_state: bool = False - ): - if label is None: - label = ['State Infidelity', ] - super().__init__(solver=solver, label=label) - # assure target is a bra vector - - if target.shape[0] > target.shape[1]: - self.target = target.dag() - else: - self.target = target - - #1D - self.initial = initial - - self.computational_states = computational_states - self.rescale_propagated_state = rescale_propagated_state - - def costs(self) -> np.float64: - """See base class. """ - - final = DenseOperator((self.solver.forward_propagators[-1]*self.initial).data[:,np.newaxis]) - - infid = 1. - state_fidelity( - target=self.target, - propagated_state=final, - computational_states=self.computational_states, - rescale_propagated_state=self.rescale_propagated_state - ) - return infid - - def grad(self) -> np.ndarray: - """See base class. """ - derivative_fid = derivative_state_fidelity( - forward_propagators=[DenseOperator((p*self.initial).data[:,np.newaxis]) for p in self.solver.forward_propagators], - target=self.target, - reversed_propagators=self.solver.reversed_propagators, - propagator_derivatives=self.solver.frechet_deriv_propagators, - computational_states=self.computational_states, - rescale_propagated_state=self.rescale_propagated_state - ) - return -1. * np.real(derivative_fid) - - - class StateInfidelitySubspace(CostFunction): """ Quantum state infidelity on a subspace. @@ -1376,8 +1247,7 @@ def __init__(self, super_operator_formalism: bool = False, label: Optional[List[str]] = None, computational_states: Optional[List[int]] = None, - map_to_closest_unitary: bool = False, - total_ang_time = None, + map_to_closest_unitary: bool = False ): if label is None: if fidelity_measure == 'entanglement': @@ -1397,17 +1267,8 @@ def __init__(self, 'currently supported.') self.super_operator = super_operator_formalism - - - - if total_ang_time is None: - self.total_ang_time = 0 - elif total_ang_time <0: - self.total_ang_time = sum(solver.transferred_time)-0.5*solver.transferred_time[-1] - else: - self.total_ang_time = total_ang_time - - def costs_original(self) -> float: + + def costs(self) -> float: """Calculates the costs by the selected fidelity measure. """ final = self.solver.forward_propagators[-1] @@ -1429,12 +1290,9 @@ def costs_original(self) -> float: 'implemented in this version.') return np.real(infid) - def grad_original(self) -> np.ndarray: + def grad(self) -> np.ndarray: """Calculates the derivatives of the selected fidelity measure with respect to the control amplitudes. """ - - - if self.fidelity_measure == 'entanglement' and self.super_operator: derivative_fid = deriv_entanglement_fid_sup_op_with_du( forward_propagators=self.solver.forward_propagators, @@ -1457,90 +1315,6 @@ def grad_original(self) -> np.ndarray: 'version.') return -1 * np.real(derivative_fid) - - - def costs(self,freq=0) -> float: - """Calculates the costs by the selected fidelity measure. """ - final = self.solver.forward_propagators[-1] - - - r = DenseOperator(np.array([[np.exp(1j*freq/2*self.total_ang_time),0],[0,np.exp(-1j*freq/2*self.total_ang_time)]])) - - if self.fidelity_measure == 'entanglement' and self.super_operator: - infid = 1 - entanglement_fidelity_super_operator( - propagator=final, - target=r.dag()*self.target, - computational_states=self.computational_states, - ) - elif self.fidelity_measure == 'entanglement': - infid = 1 - entanglement_fidelity( - propagator=final, - target=r.dag()*self.target, - computational_states=self.computational_states, - map_to_closest_unitary=self.map_to_closest_unitary - ) - else: - raise NotImplementedError('Only the entanglement fidelity is ' - 'implemented in this version.') - return np.real(infid) - - def grad(self,freq=0) -> np.ndarray: - """Calculates the derivatives of the selected fidelity measure with - respect to the control amplitudes. """ - - - r = DenseOperator(np.array([[np.exp(1j*freq/2*self.total_ang_time),0],[0,np.exp(-1j*freq/2*self.total_ang_time)]])) - - if self.fidelity_measure == 'entanglement' and self.super_operator: - derivative_fid = deriv_entanglement_fid_sup_op_with_du( - forward_propagators=self.solver.forward_propagators, - target=r.dag()*self.target, - reversed_propagators=self.solver.reversed_propagators, - unitary_derivatives=self.solver.frechet_deriv_propagators, - computational_states=self.computational_states, - ) - elif self.fidelity_measure == 'entanglement': - derivative_fid = derivative_entanglement_fidelity_with_du( - forward_propagators=self.solver.forward_propagators, - target=r.dag()*self.target, - reversed_propagators=self.solver.reversed_propagators, - propagator_derivatives=self.solver.frechet_deriv_propagators, - computational_states=self.computational_states, - ) - else: - raise NotImplementedError('Only the average and entanglement' - 'fidelity is implemented in this ' - 'version.') - return -1 * np.real(derivative_fid) - - def der_freq_test(self,freq): - - - r_der = 1j*self.total_ang_time/2*DenseOperator(np.array([[np.exp(1j*freq/2*self.total_ang_time),0],[0,-np.exp(-1j*freq/2*self.total_ang_time)]])) - r = DenseOperator(np.array([[np.exp(1j*freq/2*self.total_ang_time),0],[0,np.exp(-1j*freq/2*self.total_ang_time)]])) - - if self.fidelity_measure == 'entanglement' and self.super_operator: - derivative_fid = deriv_entanglement_fid_sup_op_with_dfreq( - forward_propagators=self.solver.forward_propagators, - target_der = r_der.dag()*self.target, - target=r.dag()*self.target, - computational_states=self.computational_states, - ) - elif self.fidelity_measure == 'entanglement': - derivative_fid = derivative_entanglement_fidelity_with_dfreq( - forward_propagators=self.solver.forward_propagators, - target_der = r_der.dag()*self.target, - target=r.dag()*self.target, - computational_states=self.computational_states, - ) - - else: - raise NotImplementedError('Only the average and entanglement' - 'fidelity is implemented in this ' - 'version.') - return -1 * np.real(derivative_fid) - - class OperationNoiseInfidelity(CostFunction): """ @@ -1591,9 +1365,7 @@ def __init__(self, fidelity_measure: str = 'entanglement', computational_states: Optional[List[int]] = None, map_to_closest_unitary: bool = False, - neglect_systematic_errors: bool = True, - total_ang_time = None, - ): + neglect_systematic_errors: bool = True): if label is None: label = ['Operator Noise Infidelity'] super().__init__(solver=solver, label=label) @@ -1609,16 +1381,6 @@ def __init__(self, print('The systematic errors must be neglected if no target is ' 'set!') self.neglect_systematic_errors = True - - - - if total_ang_time is None: - self.total_ang_time = 0 - elif total_ang_time <0: - self.total_ang_time = sum(solver.transferred_time)-0.5*solver.transferred_time[-1] - else: - self.total_ang_time = total_ang_time - def _to_comp_space(self, dynamic_target: matrix.OperatorMatrix) -> matrix.OperatorMatrix: @@ -1631,28 +1393,18 @@ def _to_comp_space(self, else: return dynamic_target - def _effective_target(self,freq=0) -> matrix.OperatorMatrix: + def _effective_target(self) -> matrix.OperatorMatrix: if self.neglect_systematic_errors: return self._to_comp_space(self.solver.forward_propagators[-1]) else: - - r = DenseOperator(np.array([[np.exp(1j*freq/2*self.total_ang_time),0],[0,np.exp(-1j*freq/2*self.total_ang_time)]])) - return r.dag()*self.target - - def _effective_target_der(self,freq=0) -> matrix.OperatorMatrix: - if self.neglect_systematic_errors: - return 0*self.target - else: - - r = 1j*self.total_ang_time/2*DenseOperator(np.array([[np.exp(1j*freq/2*self.total_ang_time),0],[0,-np.exp(-1j*freq/2*self.total_ang_time)]])) - return r.dag()*self.target - - def costs(self,freq=0): + return self.target + + def costs(self): """See base class. """ n_traces = self.solver.noise_trace_generator.n_traces infidelities = np.zeros((n_traces,)) - target = self._effective_target(freq=freq) + target = self._effective_target() if self.fidelity_measure == 'entanglement': for i in range(n_traces): @@ -1669,9 +1421,9 @@ def costs(self,freq=0): return np.mean(np.real(infidelities)) - def grad(self,freq=0): + def grad(self): """See base class. """ - target = self._effective_target(freq) + target = self._effective_target() n_traces = self.solver.noise_trace_generator.n_traces num_t = len(self.solver.transferred_time) @@ -1700,28 +1452,7 @@ def grad(self,freq=0): ) derivative[:, :, i] = np.real(temp) return np.mean(-derivative, axis=2) - - def der_freq_test(self,freq): - - - - target_der = self._effective_target_der(freq) - target = self._effective_target(freq) - n_traces = self.solver.noise_trace_generator.n_traces - num_t=1 - num_ctrl = 1 - derivative = np.zeros((num_t, num_ctrl, n_traces, )) - for i in range(n_traces): - temp = derivative_entanglement_fidelity_with_dfreq( - target_der=target_der, - target=target, - forward_propagators=self.solver.forward_propagators_noise[i], - computational_states=self.computational_states - ) - derivative[:, :, i] = np.real(temp) - return np.mean(-derivative, axis=2) - class LiouvilleMonteCarloEntanglementInfidelity(CostFunction): """ @@ -2133,1135 +1864,66 @@ def costs(self): # the result should always be positive within numerical accuracy return leakage.data.real[0] - def grad(self): """See base class. """ - - num_ctrls = len(self.solver.frechet_deriv_propagators) - num_time_steps = len(self.solver.frechet_deriv_propagators[0]) - - derivative_leakage = np.zeros(shape=(num_time_steps, num_ctrls), - dtype=np.float64) - - for ctrl in range(num_ctrls): - for t in range(num_time_steps): - derivative_leakage[t, ctrl] = (1 / self.dim_comp) * ( - self.projector_leakage_bra - * self.solver.reversed_propagators[::-1][t + 1] \ - * self.solver.frechet_deriv_propagators[ctrl][t] - * self.solver.forward_propagators[t] - * self.projector_comp_ket - ).data.real[0] - - return derivative_leakage - - -############################################################################### - -try: - import jax.numpy as jnp - from jax import jit, vmap - import jax - _HAS_JAX = True -except ImportError: - from unittest import mock - jit = mock.Mock() - jnp = mock.Mock() - vmap = mock.Mock() - jax = mock.Mock() - _HAS_JAX = False - - -@jit -def _closest_unitary_jnp(matrix: jnp.ndarray) -> jnp.ndarray: - """Return the closest unitary to the matrix.""" - - left_singular_vec, __, right_singular_vec_h = jnp.linalg.svd( - matrix) - return left_singular_vec.dot(right_singular_vec_h) - - -@partial(jit,static_argnums=(1,)) -def _truncate_to_subspace_jnp_unmapped(arr: jnp.ndarray, - subspace_indices: Optional[tuple], - ) -> jnp.ndarray: - """Return the truncated jnp array""" - # subspace_indices = jnp.asarray(subspace_indices) - if subspace_indices is None: - return arr - elif arr.shape[0] == arr.shape[1]: - # square matrix - subspace_indices = jnp.asarray(subspace_indices) - out = arr[jnp.ix_(subspace_indices, subspace_indices)] - - elif arr.shape[0] == 1: - # bra-vector - subspace_indices = jnp.asarray(subspace_indices) - out = arr[jnp.ix_(jnp.array([0]), subspace_indices)] - - elif arr.shape[0] == 1: - # ket-vector - subspace_indices = jnp.asarray(subspace_indices) - out = arr[jnp.ix_(subspace_indices, jnp.array([0]))] - - else: - subspace_indices = jnp.asarray(subspace_indices) - out = arr[jnp.ix_(subspace_indices)] - - return out + raise NotImplementedError('The derivative of the cost function ' + 'LeakageLiouville has not been implemented' + 'yet.') -@partial(jit,static_argnums=(1,)) -def _truncate_to_subspace_jnp_mapped(arr: jnp.ndarray, - subspace_indices: Optional[tuple], - ) -> jnp.ndarray: - """Return the truncated jnp array mapped to the closest unitary (matrix) / - renormalized (vector) - """ - # subspace_indices = jnp.asarray(subspace_indices) - if subspace_indices is None: - return arr - elif arr.shape[0] == arr.shape[1]: - # square matrix - subspace_indices = jnp.asarray(subspace_indices) - out = arr[jnp.ix_(subspace_indices, subspace_indices)] - out = _closest_unitary_jnp(out) - elif arr.shape[0] == 1: - # bra-vector - subspace_indices = jnp.asarray(subspace_indices) - out = arr[jnp.ix_(jnp.array([0]), subspace_indices)] - out *= 1 / jnp.linalg.norm(out,'fro') - elif arr.shape[0] == 1: - # ket-vector - subspace_indices = jnp.asarray(subspace_indices) - out = arr[jnp.ix_(subspace_indices, jnp.array([0]))] - out *= 1 / jnp.linalg.norm(out,'fro') - else: - subspace_indices = jnp.asarray(subspace_indices) - out = arr[jnp.ix_(subspace_indices)] - return out - -@partial(jit,static_argnums=(1,2)) -def _truncate_to_subspace_jnp(arr,subspace_indices,map_to_closest_unitary): - """Return the truncated jnp array, either mapped to the - closest unitary (matrix) / renormalized (vector) or not +@deprecated +def derivative_entanglement_fidelity( + control_hamiltonians: List[matrix.OperatorMatrix], + forward_propagators: List[matrix.OperatorMatrix], + reversed_propagators: List[matrix.OperatorMatrix], + delta_t: List[float], + target_unitary: matrix.OperatorMatrix) -> np.ndarray: """ - if map_to_closest_unitary==True: - return _truncate_to_subspace_jnp_mapped(arr,subspace_indices) - else: - return _truncate_to_subspace_jnp_unmapped(arr,subspace_indices) - - -@partial(jit,static_argnums=(2,3)) -def _entanglement_fidelity_jnp( - target: jnp.ndarray, - propagator: jnp.ndarray, - computational_states: Optional[tuple] = None, - map_to_closest_unitary: bool = False -) -> jnp.float64: - """Return the entanglement fidelity of target and propagator""" + Derivative of the entanglement fidelity using the grape approximation. - d = target.shape[0] - if computational_states is None: - trace = (jnp.conj(target).T @ propagator).trace() - else: - trace = (jnp.conj(target).T @ _truncate_to_subspace_jnp(propagator, - computational_states, - map_to_closest_unitary)).trace() - return (jnp.abs(trace) ** 2) / d / d - - -@partial(jit,static_argnums=(2,3)) -def _entanglement_fidelity_super_operator_jnp( - target: jnp.ndarray, - propagator: jnp.ndarray, - dim_prop: int, - computational_states: Optional[tuple] = None, -) -> jnp.float64: - """Return the entanglement fidelity of target and propagator in super- - operator formalism - """ + dU / du = -i delta_t H_ctrl U - dim_comp = target.shape[0] + Parameters + ---------- + control_hamiltonians: List[ControlMatrix], len: num_ctrl + The control hamiltonians of the simulation. - if computational_states is None: - target_super_operator_inv = \ - jnp.kron(target.T, jnp.conj(target.T)) - trace = (target_super_operator_inv @ propagator).trace().real - else: - # Here we assume that the full Hilbertspace is the outer sum of a - # computational and a leakage space. + forward_propagators: List[ControlMatrix], len: num_t +1 + The forward propagators calculated in the systems simulation. - # Thus the dimension of the propagator is (d_comp + d_leak) ** 2 - d_leakage = dim_prop - dim_comp + reversed_propagators: List[ControlMatrix] + The reversed propagators calculated in the systems simulation. - # We fill zeros to the target on the leakage space. We will project - # onto the computational space anyway. + delta_t: List[float], len: num_t + The durations of the time steps. - target_inv = jnp.conj(target.T) - target_inv_full_space = jnp.zeros((d_leakage + dim_comp, - d_leakage + dim_comp),dtype=complex) - - clist = jnp.array(computational_states) + target_unitary: ControlMatrix + The target unitary evolution. - for i, row in enumerate(computational_states): - for k, column in enumerate(computational_states): - target_inv_full_space = target_inv_full_space.at[row, column].set(target_inv[i, k]) - - # Then convert the target unitary into Liouville space. - - target_super_operator_inv = jnp.kron(jnp.conj(target_inv_full_space), - target_inv_full_space) + Returns + ------- + derivative_fidelity: np.ndarray, shape: (num_t, num_ctrl) + The derivatives of the entanglement fidelity. - # We start the projector with a zero matrix of dimension - # (d_comp + d_leak). - projector_comp_state = 0 * jnp.identity(target_inv_full_space.shape[0]) - - # for state in computational_states: - projector_comp_state = projector_comp_state.at[clist, - clist].set(1) + """ + target_unitary_dag = target_unitary.dag(do_copy=True) + trace = np.conj(((forward_propagators[-1] * target_unitary_dag).tr())) + num_ctrls = len(control_hamiltonians) + num_time_steps = len(delta_t) + d = target_unitary.shape[0] - # Then convert the projector into liouville space. - projector_comp_state=jnp.kron(jnp.conj(projector_comp_state), - projector_comp_state) + derivative_fidelity = np.zeros(shape=(num_time_steps, num_ctrls), + dtype=complex) - trace = ( - projector_comp_state @ target_super_operator_inv @ propagator - ).trace().real - return trace / dim_comp / dim_comp - - -@partial(jit,static_argnums=(4,5)) -def _derivative_entanglement_fidelity_with_du_jnp( - target: jnp.ndarray, - forward_propagators_jnp: jnp.ndarray, - propagator_derivatives_jnp: jnp.ndarray, - reversed_propagators_jnp: jnp.ndarray, - computational_states: Optional[tuple] = None, - map_to_closest_unitary: bool = False -) -> jnp.ndarray: - """Return the derivative of the entanglement fidelity of target and - propagator - """ - target_unitary_dag = jnp.conj(target).T - if computational_states is not None: - trace = jnp.conj( - ((_truncate_to_subspace_jnp(forward_propagators_jnp[-1], - computational_states, - map_to_closest_unitary=map_to_closest_unitary) - @ target_unitary_dag).trace()) - ) - else: - trace = jnp.conj(((forward_propagators_jnp[-1]@ - target_unitary_dag).trace())) - d = target.shape[0] - - # here we need to take the real part. - if computational_states: - derivative_fidelity = 2/d/d * jnp.real(trace*_der_fid_comp_states( - propagator_derivatives_jnp, - reversed_propagators_jnp[::-1][1:], - forward_propagators_jnp[:-1],computational_states, - map_to_closest_unitary,target_unitary_dag)).T - - else: - derivative_fidelity = 2/d/d * jnp.real(trace*_der_fid( - propagator_derivatives_jnp, - reversed_propagators_jnp[::-1][1:], - forward_propagators_jnp[:-1],target_unitary_dag)).T - - return derivative_fidelity - - -def _der_fid_comp_states_loop(prop_der,rev_prop_rev,fwd_prop,comp_states, - map_to_closest_unitary,target_unitary_dag): - """Internal loop of derivative of entanglement fidelity w/ truncation""" - return (_truncate_to_subspace_jnp( - rev_prop_rev @ prop_der @ fwd_prop, - subspace_indices=comp_states, - map_to_closest_unitary=map_to_closest_unitary) - @ target_unitary_dag).trace() - - -#(to be used with additional .T for previously used shape) -@partial(jit,static_argnums=(3,4)) -def _der_fid_comp_states(prop_der,rev_prop_rev,fwd_prop,comp_states, - map_to_closest_unitary,target_unitary_dag): - """Derivative of entanglement fidelity w/ truncation, n_ctrl&n_timesteps on - first two axes - """ - return vmap(vmap(_der_fid_comp_states_loop,in_axes=(0,0,0,None,None,None)), - in_axes=(0,None,None,None,None,None))( - prop_der,rev_prop_rev,fwd_prop,comp_states, - map_to_closest_unitary,target_unitary_dag) - -def _der_fid_loop(prop_der,rev_prop_rev,fwd_prop,target_unitary_dag): - """Internal loop of derivative of entanglement fidelity w/o truncation""" - return (rev_prop_rev @ prop_der @ fwd_prop @ target_unitary_dag).trace() - -#(to be used with additional .T for previous shape) -@jit -def _der_fid(prop_der,rev_prop_rev,fwd_prop,target_unitary_dag): - """Derivative of entanglement fidelity w/o truncation""" - return vmap(vmap(_der_fid_loop,in_axes=(0,0,0,None)), - in_axes=(0,None,None,None))( - prop_der,rev_prop_rev,fwd_prop,target_unitary_dag) - - -@partial(jit,static_argnums=(4,5)) -def _deriv_entanglement_fid_sup_op_with_du_jnp( - target: jnp.ndarray, - forward_propagators: jnp.ndarray, - unitary_derivatives: jnp.ndarray, - reversed_propagators: jnp.ndarray, - dim_prop: int, - computational_states: Optional[tuple] = None -): - """Return the derivative of the entanglement fidelity of target and - propagator in super-operator formalism - """ - - derivative_fidelity = _der_entanglement_fidelity_super_operator_jnp( - target, - reversed_propagators[::-1][1:] @ unitary_derivatives @ - forward_propagators[:-1], - dim_prop, - computational_states).T - - return derivative_fidelity - - -#(to be used with additional .T for previous shape) -@partial(jit,static_argnums=(2,3)) -def _der_entanglement_fidelity_super_operator_jnp(target,propagators,dim_prop, - computational_states): - """Unnecessarily nested function for the derivative of the - entanglement fidelity of target and propagator in super-operator formalism - """ - return vmap(vmap(_entanglement_fidelity_super_operator_jnp, - in_axes=(None,0,None,None)),in_axes=(None,0,None,None))( - target,propagators,dim_prop,computational_states) - - -class StateInfidelityJAX(CostFunction): - """See docstring of class w/o JAX. Requires solver with JAX""" - def __init__(self, - solver: solver_algorithms.SolverJAX, - target: matrix.OperatorMatrix, - label: Optional[List[str]] = None, - computational_states: Optional[List[int]] = None, - rescale_propagated_state: bool = False - ): - if not _HAS_JAX: - raise ImportError("JAX not available") - if label is None: - label = ['State Infidelity', ] - super().__init__(solver=solver, label=label) - # assure target is a bra vector - - if target.shape[0] > target.shape[1]: - self.target = target.dag() - else: - self.target = target - - self._target_jnp = jnp.array(target.data) - if computational_states is None: - self.computational_states = None - else: - self.computational_states = tuple(computational_states) - self.rescale_propagated_state = rescale_propagated_state - - def costs(self) -> jnp.float64: - """See base class. """ - final = self.solver.forward_propagators_jnp[-1] - infid = 1. - _state_fidelity_jnp( - target=self._target_jnp, - propagated_state=final, - computational_states=self.computational_states, - rescale_propagated_state=self.rescale_propagated_state - ) - return jnp.real(infid) - - def grad(self) -> jnp.ndarray: - """See base class. """ - derivative_fid = _derivative_state_fidelity_jnp( - forward_propagators=self.solver.forward_propagators_jnp, - target=self._target_jnp, - reversed_propagators=self.solver.reversed_propagators_jnp, - propagator_derivatives=self.solver.frechet_deriv_propagators_jnp, - computational_states=self.computational_states, - rescale_propagated_state=self.rescale_propagated_state - ) - return -1. * jnp.real(derivative_fid) - - -@partial(jit,static_argnums=(2,3)) -def _state_fidelity_jnp( - target: jnp.ndarray, - propagated_state: jnp.ndarray, - computational_states: Optional[tuple] = None, - rescale_propagated_state: bool = False -) -> jnp.float64: - """Quantum state fidelity of target and propagated_state""" - - if computational_states is not None: - scalar_prod = jnp.dot( - target, - _truncate_to_subspace_jnp( - propagated_state, - computational_states, - map_to_closest_unitary=rescale_propagated_state - )) - else: - scalar_prod = jnp.dot(target, propagated_state) - - if scalar_prod.shape != (1, 1): - raise ValueError('The scalar product is not a scalar. This means that' - 'either the target is not a bra vector or the the ' - 'propagated state not a ket, or both!') - scalar_prod = scalar_prod[0, 0] - return jnp.abs(scalar_prod)**2 - - -@partial(jit,static_argnums=(4,5)) -def _derivative_state_fidelity_jnp( - target: jnp.ndarray, - forward_propagators: jnp.ndarray, - propagator_derivatives: jnp.ndarray, - reversed_propagators: jnp.ndarray, - computational_states: Optional[tuple] = None, - rescale_propagated_state: bool = False -) -> jnp.ndarray: - """Derivative of the state fidelity.""" - - if computational_states is not None: - scalar_prod = jnp.dot( - target, - _truncate_to_subspace_jnp( - forward_propagators[-1],subspace_indices=computational_states, - map_to_closest_unitary=rescale_propagated_state - )) - else: - scalar_prod = jnp.dot(target,forward_propagators[-1]) - - scalar_prod = jnp.conj(scalar_prod) - - if computational_states: - derivative_fidelity = 2 * jnp.real(scalar_prod*_der_fid_comp_states( - propagator_derivatives, - reversed_propagators[::-1][1:], - forward_propagators[:-1],computational_states, - rescale_propagated_state,target)).T - - else: - derivative_fidelity = 2 * jnp.real(scalar_prod*_der_fid( - propagator_derivatives, - reversed_propagators[::-1][1:], - forward_propagators[:-1],target)).T - - return derivative_fidelity - - -def _der_state_fid_comp_states_loop(prop_der,rev_prop_rev,fwd_prop,comp_states, - map_to_closest_unitary,target): - """Internal loop of derivative of state fidelity w/ truncation""" - return (target@_truncate_to_subspace_jnp( - rev_prop_rev@prop_der@fwd_prop, - subspace_indices=comp_states, - map_to_closest_unitary=map_to_closest_unitary))[0,0] - -#(to be used with additional .T for previous shape) -@partial(jit,static_argnums=(3,4)) -def _der_state_fid_comp_states(prop_der,rev_prop_rev,fwd_prop,comp_states, - map_to_closest_unitary,target): - """Derivative of state fidelity w/ truncation, - n_ctrl&n_time_steps on first two axes - """ - return vmap(vmap( - _der_state_fid_comp_states_loop,in_axes=(0,0,0,None,None,None)), - in_axes=(0,None,None,None,None,None))( - prop_der,rev_prop_rev,fwd_prop, - comp_states,map_to_closest_unitary,target) - -def _der_state_fid_loop(prop_der,rev_prop_rev,fwd_prop,target): - """Internal loop of derivative of state fidelity w/o truncation""" - - return (target @ rev_prop_rev @ prop_der @ fwd_prop)[0,0] - -#(to be used with additional .T for previous shape) -@jit -def _der_state_fid(prop_der,rev_prop_rev,fwd_prop,target): - """Derivative of state fidelity w/o truncation, - n_ctrl&n_time_steps on first two axes - """ - return vmap(vmap( - _der_state_fid_loop,in_axes=(0,0,0,None)),in_axes=(0,None,None,None))( - prop_der,rev_prop_rev,fwd_prop,target) - - -class StateInfidelitySubspaceJAX(CostFunction): - """See docstring of class w/o JAX. Requires solver with JAX""" - def __init__(self, - solver: solver_algorithms.SolverJAX, - target: matrix.OperatorMatrix, - dims: List[int], - remove: List[int], - label: Optional[List[str]] = None - ): - if not _HAS_JAX: - raise ImportError("JAX not available") - if label is None: - label = ['State Infidelity', ] - super().__init__(solver=solver, label=label) - # assure target is a bra vector - - if target.shape[0] > target.shape[1]: - self.target = target.dag() - else: - self.target = target - - self._target_jnp = jnp.asarray(self.target.data) - self.dims = tuple(dims) - self.remove = tuple(remove) - - def costs(self) -> jnp.float64: - """See base class. """ - final = self.solver.forward_propagators_jnp[-1] - infid = 1. - _state_fidelity_subspace_jnp( - target=self._target_jnp, - propagated_state=final, - dims=self.dims, - remove=self.remove - ) - return infid - - def grad(self) -> jnp.ndarray: - """See base class. """ - derivative_fid = _derivative_state_fidelity_subspace_jnp( - forward_propagators=self.solver.forward_propagators_jnp, - target=self._target_jnp, - reversed_propagators=self.solver.reversed_propagators_jnp, - propagator_derivatives=self.solver.frechet_deriv_propagators_jnp, - dims=self.dims, - remove=self.remove - ) - return -1. * derivative_fid - - -# @partial(jit,static_argnums=(2,3)) -def _state_fidelity_subspace_jnp( - target: jnp.ndarray, - propagated_state: jnp.ndarray, - dims: tuple, - remove: tuple -) -> jnp.float64: - r"""Derivative of the state fidelity on a subspace. - The unused subspace is traced out. - TODO: DID NOT include changes of last master commit -> WONT work with - vectorized density matrices. not as benefitial to have if statements in jax - functions; better create new func for it - """ - - rho = _ptrace_jnp(propagated_state,dims,remove) - - scalar_prod = target @ rho @ jnp.conj(target).T - - if scalar_prod.shape != (1, 1): - raise ValueError('The scalar product is not a scalar. This means that' - 'either the target is not a bra vector or the the ' - 'propagated state not a ket, or both!') - scalar_prod = scalar_prod[0, 0] - scalar_prod_real = scalar_prod.real - assert jnp.abs(scalar_prod - scalar_prod_real) < 1e-5 - return scalar_prod_real - - - -def _ptrace_jnp(mat: jnp.ndarray, - dims: tuple, - remove: tuple) -> jnp.ndarray: - """Partial trace of the matrix""" - - if mat.shape[1] == 1: - mat = (mat @ jnp.conj(mat).T) - - n_dim = len(dims) # number of subspaces - dims = jnp.asarray(dims, dtype=int) - - remove = jnp.sort(jnp.asarray(remove)) - - # indices of subspace that are kept - keep = jnp.array(jnp.where(jnp.arange(n_dim)!=remove)) - - keep=keep[0] - - dims_rm = dims[remove] - dims_keep = dims[keep] - dims = dims - - # 1. Reshape: Split matrix into subspaces - # 2. Transpose: Change subspace/index ordering such that the subspaces - # over which is traced correspond to the first axes - # 3. Reshape: Merge each, subspaces to be removed (A) and to be kept - # (B), common spaces/axes. - # The trace of the merged spaces (A \otimes B) can then be - # calculated as Tr_A(mat) using np.trace for input with - # more than two axes effectively resulting in - # pmat[j,k] = Sum_i mat[i,i,j,k] for all j,k = 0..prod(dims_keep) - pmat = jnp.trace(mat.reshape(jnp.hstack((dims,dims))) - .transpose(jnp.hstack((remove,n_dim + remove, - keep,n_dim +keep))) - .reshape(jnp.hstack((jnp.prod(dims_rm), - jnp.prod(dims_rm), - jnp.prod(dims_keep), - jnp.prod(dims_keep)))) - ) - - return pmat - - -def _derivative_state_fidelity_subspace_jnp( - target: jnp.ndarray, - forward_propagators: jnp.ndarray, - propagator_derivatives: jnp.ndarray, - reversed_propagators: jnp.ndarray, - dims: tuple, - remove: tuple -) -> jnp.ndarray: - """Derivative of the state fidelity on a subspace. - The unused subspace is traced out. - """ - - num_ctrls = len(propagator_derivatives) - num_time_steps = len(propagator_derivatives[0]) - - derivative_fidelity = np.zeros(shape=(num_time_steps, num_ctrls), - dtype=float) - - derivative_fidelity = 2 * jnp.real(_der_state_sub_fid_comp_states( - propagator_derivatives, - reversed_propagators[::-1][1:], - forward_propagators[:-1],dims, - remove,target)).T - - return derivative_fidelity - - -def _der_state_sub_fid_comp_states_loop(prop_der,rev_prop_rev,fwd_prop, - dims,remove,target): - """Internal loop of derivative of state fidelity on subspace""" - return (target @ _ptrace_jnp( - rev_prop_rev@prop_der@fwd_prop@ jnp.conj(fwd_prop[-1]).T,dims,remove)@ - jnp.conj(target).T)[0,0] - -#(to be used with additional .T for previous shape) -# @partial(jit,static_argnums=(3,4)) -def _der_state_sub_fid_comp_states(prop_der,rev_prop_rev,fwd_prop, - dims,remove,target): - """Derivative of state fidelity on subspace, n_ctrl&n_timesteps on first - two axes - """ - return vmap(vmap( - _der_state_sub_fid_comp_states_loop,in_axes=(0,0,0,None,None,None)), - in_axes=(0,None,None,None,None,None))( - prop_der,rev_prop_rev,fwd_prop,dims,remove,target) - - -class StateNoiseInfidelityJAX(CostFunction): - """See docstring of class w/o JAX. Requires solver with JAX""" - - def __init__(self, - solver: solver_algorithms.SchroedingerSMonteCarloJAX, - target: matrix.OperatorMatrix, - label: Optional[List[str]] = None, - computational_states: Optional[List[int]] = None, - rescale_propagated_state: bool = False, - neglect_systematic_errors: bool = True - ): - if not _HAS_JAX: - raise ImportError("JAX not available") - if label is None: - label = ['State Infidelity', ] - super().__init__(solver=solver, label=label) - self.solver = solver - - # assure target is a bra vector - if target.shape[0] > target.shape[1]: - self.target = target.dag() - else: - self.target = target - - self._target_jnp = jnp.array(target.data) - if computational_states is None: - self.computational_states = None - else: - self.computational_states = tuple(computational_states) - self.rescale_propagated_state = rescale_propagated_state - - self.neglect_systematic_errors = neglect_systematic_errors - if target is None and not neglect_systematic_errors: - print('The systematic errors must be neglected if no target is ' - 'set!') - self.neglect_systematic_errors = True - - def costs(self) -> jnp.float64: - """See base class. """ - n_traces = self.solver.noise_trace_generator.n_traces - infidelities = np.zeros((n_traces,)) - - if self.neglect_systematic_errors: - if self.computational_states is None: - target = self.solver.forward_propagators_jnp[-1] - else: - target = _truncate_to_subspace_jnp( - self.solver.forward_propagators_jnp[-1], - self.computational_states, - map_to_closest_unitary=self.rescale_propagated_state - ) - target = jnp.conj(target).T - else: - target = self._target_jnp - - # for i in range(n_traces): - final = self.solver.forward_propagators_noise_jnp[:,-1] - infidelities = 1. - jit(vmap( - _state_fidelity_jnp, - in_axes=(None,0,None,None)),static_argnums=(2,))( - target, - final, - self.computational_states, - self.rescale_propagated_state - ) - - return jnp.mean(jnp.real(infidelities)) - - def grad(self) -> jnp.ndarray: - """See base class. """ - raise NotImplementedError - - -class OperationInfidelityJAX(CostFunction): - """See docstring of class w/o JAX. Requires solver with JAX""" - - def __init__(self, - solver: solver_algorithms.SolverJAX, - target: matrix.OperatorMatrix, - fidelity_measure: str = 'entanglement', - super_operator_formalism: bool = False, - label: Optional[List[str]] = None, - computational_states: Optional[List[int]] = None, - map_to_closest_unitary: bool = False - ): - if not _HAS_JAX: - raise ImportError("JAX not available") - if label is None: - if fidelity_measure == 'entanglement': - label = ['Entanglement Infidelity', ] - else: - label = ['Operator Infidelity', ] - - super().__init__(solver=solver, label=label) - self.target = target - self._target_jnp = jnp.array(target.data) - if computational_states is None: - self.computational_states = None - else: - self.computational_states = tuple(computational_states) - self.map_to_closest_unitary = map_to_closest_unitary - - if fidelity_measure == 'entanglement': - self.fidelity_measure = fidelity_measure - else: - raise NotImplementedError('Only the entanglement fidelity is ' - 'currently supported.') - - self.super_operator = super_operator_formalism - - def costs(self) -> float: - """Calculates the costs by the selected fidelity measure. """ - final = self.solver.forward_propagators_jnp[-1] - - if self.fidelity_measure == 'entanglement' and self.super_operator: - infid = 1 - _entanglement_fidelity_super_operator_jnp( - self._target_jnp, - final, - jnp.sqrt(final.shape[0]).astype(int), - self.computational_states, - ) - elif self.fidelity_measure == 'entanglement': - infid = 1 - _entanglement_fidelity_jnp( - self._target_jnp, - final, - self.computational_states, - self.map_to_closest_unitary - ) - else: - raise NotImplementedError('Only the entanglement fidelity is ' - 'implemented in this version.') - return jnp.real(infid) - - - def grad(self) -> jnp.ndarray: - """Calculates the derivatives of the selected fidelity measure with - respect to the control amplitudes. """ - if self.fidelity_measure == 'entanglement' and self.super_operator: - derivative_fid = _deriv_entanglement_fid_sup_op_with_du_jnp( - self._target_jnp, - self.solver.forward_propagators_jnp, - self.solver.frechet_deriv_propagators_jnp, - self.solver.reversed_propagators_jnp, - jnp.sqrt(self.solver.forward_propagators_jnp.shape[1]).astype(int), - self.computational_states, - ) - elif self.fidelity_measure == 'entanglement': - derivative_fid = _derivative_entanglement_fidelity_with_du_jnp( - self._target_jnp, - self.solver.forward_propagators_jnp, - self.solver.frechet_deriv_propagators_jnp, - self.solver.reversed_propagators_jnp, - self.computational_states, - self.map_to_closest_unitary - ) - else: - raise NotImplementedError('Only the average and entanglement' - 'fidelity is implemented in this ' - 'version.') - return -1 * jnp.real(derivative_fid) - - -class OperationNoiseInfidelityJAX(CostFunction): - """See docstring of class w/o JAX. Requires solver with JAX""" - - def __init__(self, - solver: solver_algorithms.SchroedingerSMonteCarloJAX, - target: Optional[matrix.OperatorMatrix], - label: Optional[List[str]] = None, - fidelity_measure: str = 'entanglement', - computational_states: Optional[List[int]] = None, - map_to_closest_unitary: bool = False, - neglect_systematic_errors: bool = True): - if not _HAS_JAX: - raise ImportError("JAX not available") - if label is None: - label = ['Operator Noise Infidelity'] - super().__init__(solver=solver, label=label) - self.solver = solver - self.target = target - - self._target_jnp = jnp.array(target.data) - if computational_states is None: - self.computational_states = None - else: - self.computational_states = tuple(computational_states) - self.map_to_closest_unitary = map_to_closest_unitary - self.fidelity_measure = fidelity_measure - - self.neglect_systematic_errors = neglect_systematic_errors - if target is None and not neglect_systematic_errors: - print('The systematic errors must be neglected if no target is ' - 'set!') - self.neglect_systematic_errors = True - - def _to_comp_space(self, dynamic_target: jnp.ndarray) -> jnp.ndarray: - """Map an operator to the computational space""" - if self.computational_states is not None: - return _truncate_to_subspace_jnp(dynamic_target, - subspace_indices=self.computational_states, - map_to_closest_unitary=self.map_to_closest_unitary, - ) - else: - return dynamic_target - - def _effective_target(self) -> jnp.ndarray: - if self.neglect_systematic_errors: - return self._to_comp_space(self.solver.forward_propagators_jnp[-1]) - else: - return self._target_jnp - - def costs(self): - """See base class. """ - n_traces = self.solver.noise_trace_generator.n_traces - infidelities = np.zeros((n_traces,)) - - target = self._effective_target() - - if self.fidelity_measure == 'entanglement': - # for i in range(n_traces): - final = self.solver.forward_propagators_noise_jnp[:,-1] - - infidelities = 1 - jit(vmap( - _entanglement_fidelity_jnp, - in_axes=(None,0,None,None)),static_argnums=(2,))( - target,final, - self.computational_states, - self.map_to_closest_unitary - ) - else: - raise NotImplementedError('Only the entanglement fidelity is ' - 'currently implemented in this class.') - - return jnp.mean(jnp.real(infidelities)) - - def grad(self): - """See base class. """ - target = self._effective_target() - - temp = _derivative_entanglement_fidelity_with_du_noise_jnp( - target, - self.solver.forward_propagators_noise_jnp, - self.solver.frechet_deriv_propagators_noise_jnp, - self.solver.reversed_propagators_noise_jnp, - self.computational_states, - self.map_to_closest_unitary - ) - - if self.neglect_systematic_errors: - temp_target = vmap(self._to_comp_space,in_axes=(0,))( - self.solver.forward_propagators_noise_jnp[:,-1]) - - temp += _derivative_entanglement_fidelity_with_du_noise_sys_jnp( - temp_target, - self.solver.forward_propagators_jnp, - self.solver.frechet_deriv_propagators_jnp, - self.solver.reversed_propagators_jnp, - self.computational_states, - self.map_to_closest_unitary - ) - - return jnp.mean(-jnp.real(temp), axis=0) - - -@partial(jit,static_argnums=(4,5)) -def _derivative_entanglement_fidelity_with_du_noise_jnp( - target,fwd_props,prop_der,reversed_props,comp_states,map_to_closest): - """Return derivative of entanglement fidelity with vmap over traces""" - return vmap(_derivative_entanglement_fidelity_with_du_jnp, - in_axes=(None,0,0,0,None,None))( - target,fwd_props,prop_der,reversed_props, - comp_states,map_to_closest) - - -@partial(jit,static_argnums=(4,5)) -def _derivative_entanglement_fidelity_with_du_noise_sys_jnp( - target,fwd_props,prop_der,reversed_props,comp_states,map_to_closest): - """Return additional product rule part of derivative of entanglement - fidelity if systematic errors neglected""" - return vmap(_derivative_entanglement_fidelity_with_du_jnp, - in_axes=(0,None,None,None,None,None))( - target,fwd_props,prop_der,reversed_props, - comp_states,map_to_closest) - - -class LeakageErrorJAX(CostFunction): - """See docstring of class w/o JAX. Requires solver with JAX""" - - def __init__(self, solver: solver_algorithms.SolverJAX, - computational_states: List[int], - label: Optional[List[str]] = None): - if not _HAS_JAX: - raise ImportError("JAX not available") - if label is None: - label = ["Leakage Error", ] - super().__init__(solver=solver, label=label) - if computational_states is None: - self.computational_states = None - else: - self.computational_states = tuple(computational_states) - - def costs(self): - """See base class. """ - final_prop = self.solver.forward_propagators_jnp[-1] - clipped_prop = _truncate_to_subspace_jnp(final_prop, - self.computational_states,map_to_closest_unitary=False) - temp = jnp.conj(clipped_prop).T @ clipped_prop - - # the result should always be positive within numerical accuracy - return max(0, 1 - temp.trace().real / clipped_prop.shape[0]) - - def grad(self): - """See base class. """ - final = self.solver.forward_propagators_jnp[-1] - final_leak_dag = _truncate_to_subspace_jnp(jnp.conj(final).T, - self.computational_states,map_to_closest_unitary=False) - d = final_leak_dag.shape[0] - - derivative_fidelity = -2./d*jnp.real( - _der_leak_comp_states( - self.solver.frechet_deriv_propagators_jnp, - self.solver.reversed_propagators_jnp[::-1][1:], - self.solver.forward_propagators_jnp[:-1], - self.computational_states, - final_leak_dag).T) - - return derivative_fidelity - - -def _der_leak_comp_states_loop(prop_der,rev_prop_rev,fwd_prop,comp_states, - final_leak_dag): - """Internal loop of derivative of leakage""" - return (_truncate_to_subspace_jnp( - rev_prop_rev @ prop_der @ fwd_prop,subspace_indices=comp_states, - map_to_closest_unitary=False) @ final_leak_dag).trace() - -#(to be used with additional .T for previous shape) -@partial(jit,static_argnums=3) -def _der_leak_comp_states(prop_der,rev_prop_rev,fwd_prop,comp_states, - final_leak_dag): - """Derivative of leakage, n_ctrl&n_timesteps on first two axes""" - return vmap(vmap(_der_leak_comp_states_loop,in_axes=(0,0,0,None,None)), - in_axes=(0,None,None,None,None))( - prop_der,rev_prop_rev,fwd_prop,comp_states,final_leak_dag) - - -class IncoherentLeakageErrorJAX(CostFunction): - """See docstring of class w/o JAX. Requires solver with JAX""" - - def __init__(self, solver: solver_algorithms.SchroedingerSMonteCarloJAX, - computational_states: List[int], - label: Optional[List[str]] = None): - if not _HAS_JAX: - raise ImportError("JAX not available") - if label is None: - label = ["Incoherent Leakage Error", ] - super().__init__(solver=solver, label=label) - self.solver = solver - if computational_states is None: - self.computational_states = None - else: - self.computational_states = tuple(computational_states) - - def costs(self): - """See base class. """ - final_props = self.solver.forward_propagators_noise_jnp[:,-1] - - clipped_props = vmap(_truncate_to_subspace_jnp,in_axes=(0,None,None))( - final_props,self.computational_states,False) - - result = 1-jnp.real( - jnp.trace(jnp.transpose(jnp.conj(clipped_props),axes=(0,2,1))@ - clipped_props,axis1=1,axis2=2))/len( - self.computational_states) - - return jnp.mean(result) - - def grad(self): - """See base class. """ - raise NotImplementedError('Derivatives only implemented for the ' - 'coherent leakage.') - - -class LeakageLiouvilleJAX(CostFunction): - """See docstring of class w/o JAX. Requires solver with JAX""" - - def __init__(self, solver: solver_algorithms.SolverJAX, - computational_states: List[int], - label: Optional[List[str]] = None, - verbose: int = 0): - if not _HAS_JAX: - raise ImportError("JAX not available") - if label is None: - label = ["Leakage Error Lindblad", ] - super().__init__(solver=solver, label=label) - - self.computational_states = tuple(computational_states) - dim = self.solver.h_ctrl[0].shape[0] - self.dim_comp = len(self.computational_states) - self.verbose = verbose - # operator_class = type(self.solver.h_ctrl[0]) - - # create projectors - projector_comp = np.diag(np.ones([dim, ], dtype=complex)) - projector_leakage = np.diag(np.ones([dim, ], dtype=complex)) - - for state in computational_states: - projector_leakage[state, state] = 0 - projector_comp -= projector_leakage - - # vectorize projectors - self.projector_leakage_bra = jnp.asarray(ket_vectorize_density_matrix( - projector_leakage).transpose()) - - self.projector_comp_ket = jnp.asarray( - ket_vectorize_density_matrix(projector_comp)) - - - def costs(self): - """See base class. """ - leakage = (1 / self.dim_comp) * ( - self.projector_leakage_bra - @ self.solver.forward_propagators_jnp[-1] - @ self.projector_comp_ket - ) - - if self.verbose > 0: - print('leakage:') - print(leakage[0, 0]) - - # the result should always be positive within numerical accuracy - return leakage.real[0] - - def grad(self): - """See base class. """ - raise NotImplementedError('The derivative of the cost function ' - 'LeakageLiouville has not been implemented' - 'yet.') - - - - -@deprecated -def derivative_entanglement_fidelity( - control_hamiltonians: List[matrix.OperatorMatrix], - forward_propagators: List[matrix.OperatorMatrix], - reversed_propagators: List[matrix.OperatorMatrix], - delta_t: List[float], - target_unitary: matrix.OperatorMatrix) -> np.ndarray: - """ - Derivative of the entanglement fidelity using the grape approximation. - - dU / du = -i delta_t H_ctrl U - - Parameters - ---------- - control_hamiltonians: List[ControlMatrix], len: num_ctrl - The control hamiltonians of the simulation. - - forward_propagators: List[ControlMatrix], len: num_t +1 - The forward propagators calculated in the systems simulation. - - reversed_propagators: List[ControlMatrix] - The reversed propagators calculated in the systems simulation. - - delta_t: List[float], len: num_t - The durations of the time steps. - - target_unitary: ControlMatrix - The target unitary evolution. - - Returns - ------- - derivative_fidelity: np.ndarray, shape: (num_t, num_ctrl) - The derivatives of the entanglement fidelity. - - """ - target_unitary_dag = target_unitary.dag(do_copy=True) - trace = np.conj(((forward_propagators[-1] * target_unitary_dag).tr())) - num_ctrls = len(control_hamiltonians) - num_time_steps = len(delta_t) - d = target_unitary.shape[0] - - derivative_fidelity = np.zeros(shape=(num_time_steps, num_ctrls), - dtype=complex) - - for ctrl in range(num_ctrls): - for t in range(num_time_steps): - # we take the imaginary part because we took a factor of i out - derivative_fidelity[t, ctrl] = 2 / d / d * delta_t * np.imag( - trace * (reversed_propagators[::-1][t + 1] - * control_hamiltonians[ctrl] - * forward_propagators[t + 1] - * target_unitary_dag).tr()) - return derivative_fidelity + for ctrl in range(num_ctrls): + for t in range(num_time_steps): + # we take the imaginary part because we took a factor of i out + derivative_fidelity[t, ctrl] = 2 / d / d * delta_t * np.imag( + trace * (reversed_propagators[::-1][t + 1] + * control_hamiltonians[ctrl] + * forward_propagators[t + 1] + * target_unitary_dag).tr()) + return derivative_fidelity @needs_refactoring @@ -3336,7 +1998,8 @@ def default_set_orthorgonal(dim: int) -> List[matrix.OperatorMatrix]: @deprecated def derivative_average_gate_fidelity(control_hamiltonians, propagators, - propagators_past, delta_t, target_unitary): + propagators_past, delta_t, + target_unitary): """ The derivative of the average gate fidelity. """ @@ -3355,13 +2018,13 @@ def derivative_average_gate_fidelity(control_hamiltonians, propagators, dtype=complex) for ctrl in range(num_ctrls): for t in range(num_time_steps): - bkwd_prop_target = propagators_future[t+1].dag() * target_unitary + bkwd_prop_target = propagators_future[t + 1].dag() * target_unitary temp = 0 for ort in orthogonal_operators: lambda_ = bkwd_prop_target * ort.dag(do_copy=True) lambda_ *= bkwd_prop_target.dag() - rho = propagators_past[t+1] * ort - rho *= propagators_past[t+1].dag() + rho = propagators_past[t + 1] * ort + rho *= propagators_past[t + 1].dag() # everything rewritten to operate in place temp_mat2 = control_hamiltonians[t, ctrl] * rho temp_mat2 -= rho * control_hamiltonians[t, ctrl] @@ -3370,7 +2033,9 @@ def derivative_average_gate_fidelity(control_hamiltonians, propagators, temp_mat *= delta_t temp_mat *= temp_mat2 temp += temp_mat.tr() - + # temp += (lambda_ * -1j * delta_t * ( + # control_hamiltonians[t, ctrl] * rho + # - rho * control_hamiltonians[t, ctrl])).tr() derivative_fidelity[t, ctrl] = temp / (dim ** 2 * (dim + 1)) return derivative_fidelity @@ -3409,545 +2074,3 @@ def derivative_average_gate_fid_with_du(propagators, propagators_past, temp += lambda_.tr() derivative_fidelity[t, ctrl] = temp / (dim ** 2 * (dim + 1)) return derivative_fidelity - - -############################################################################### - -class OperationInfidelityJAXSpecial(OperationInfidelityJAX): - """ - """ - def __init__(self, - solver: solver_algorithms.Solver, - target: matrix.OperatorMatrix, - rot_frame_ang_freq: float, - fidelity_measure: str = 'entanglement', - super_operator_formalism: bool = False, - label: Optional[List[str]] = None, - computational_states: Optional[List[int]] = None, - map_to_closest_unitary: bool = False - ): - - super().__init__(solver=solver, - target=target, - fidelity_measure=fidelity_measure, - super_operator_formalism=super_operator_formalism, - label=label, - computational_states=computational_states, - map_to_closest_unitary=map_to_closest_unitary) - - - self.end_time = sum(solver.transferred_time)-0.5*solver.transferred_time[-1] - self.freq = rot_frame_ang_freq - - def rot_op_4(self,time): - return jnp.array([[np.exp(-1j*2*self.freq/2*time),0,0,0], - [0,np.exp(0*self.freq/2*time),0,0], - [0,0,np.exp(0*self.freq/2*time),0], - [0,0,0,np.exp(1j*2*self.freq/2*time)]]) - - def rot_op_4_der_t(self,time): - return 1j*2*self.freq/2*jnp.array([[-np.exp(-1j*2*self.freq/2*time),0,0,0], - [0,np.exp(0*self.freq/2*time),0,0], - [0,0,np.exp(0*self.freq/2*time),0], - [0,0,0,np.exp(1j*2*self.freq/2*time)]]) - - def costs(self,time_fact) -> float: - """Calculates the costs by the selected fidelity measure. """ - final = self.solver.forward_propagators_jnp[-1] - - if self.fidelity_measure == 'entanglement' and self.super_operator: - # raise NotImplementedError - infid = 1 - _entanglement_fidelity_super_operator_jnp( - self._target_jnp, - final, - jnp.sqrt(final.shape[0]).astype(int), - self.computational_states, - - ) - elif self.fidelity_measure == 'entanglement': - infid = 1 - _entanglement_fidelity_jnp( - self.rot_op_4(time_fact*self.end_time)@self._target_jnp, - final, - self.computational_states, - self.map_to_closest_unitary - ) - else: - raise NotImplementedError('Only the entanglement fidelity is ' - 'implemented in this version.') - return jnp.real(infid) - - - def grad(self, time_fact) -> jnp.ndarray: - """Calculates the derivatives of the selected fidelity measure with - respect to the control amplitudes. """ - if self.fidelity_measure == 'entanglement' and self.super_operator: - raise NotImplementedError - derivative_fid = _deriv_entanglement_fid_sup_op_with_du_jnp( - self._target_jnp, - self.solver.forward_propagators_jnp, - self.solver.frechet_deriv_propagators_jnp, - self.solver.reversed_propagators_jnp, - jnp.sqrt(self.solver.forward_propagators_jnp.shape[1]).astype(int), - self.computational_states, - ) - elif self.fidelity_measure == 'entanglement': - # raise NotImplementedError - derivative_fid = _derivative_entanglement_fidelity_with_du_jnp( - self.rot_op_4(time_fact*self.end_time)@self._target_jnp, - self.solver.forward_propagators_jnp, - self.solver.frechet_deriv_propagators_jnp, - self.solver.reversed_propagators_jnp, - self.computational_states, - self.map_to_closest_unitary - ) - else: - raise NotImplementedError('Only the average and entanglement' - 'fidelity is implemented in this ' - 'version.') - return -1 * jnp.real(derivative_fid) - - def der_time_fact(self,time_fact): - - - if self.fidelity_measure == 'entanglement' and self.super_operator: - raise NotImplementedError - - elif self.fidelity_measure == 'entanglement': - derivative_fid = _derivative_entanglement_fidelity_with_dtf_jnp( - self.rot_op_4(time_fact*self.end_time)@self._target_jnp, - self.end_time*self.rot_op_4_der_t(time_fact*self.end_time)@self._target_jnp, - self.solver.forward_propagators_jnp, - self.computational_states, - self.map_to_closest_unitary - ) - - else: - raise NotImplementedError('Only the average and entanglement' - 'fidelity is implemented in this ' - 'version.') - return -1 * np.real(derivative_fid) - - -@partial(jit,static_argnums=(3,4)) -def _derivative_entanglement_fidelity_with_dtf_jnp( - target: jnp.ndarray, - target_der: jnp.ndarray, - forward_propagators_jnp: jnp.ndarray, - computational_states: Optional[tuple] = None, - map_to_closest_unitary: bool = False -) -> jnp.ndarray: - """ - - """ - target_unitary_dag = jnp.conj(target).T - if computational_states is not None: - trace = jnp.conj( - ((_truncate_to_subspace_jnp(forward_propagators_jnp[-1], - computational_states, - map_to_closest_unitary=map_to_closest_unitary) - @ target_unitary_dag).trace()) - ) - else: - trace = jnp.conj(((forward_propagators_jnp[-1] @ target_unitary_dag).trace())) - # num_ctrls,num_time_steps = propagator_derivatives_jnp.shape[:2] - d = target.shape[0] - - # here we need to take the real part. - if computational_states: - derivative_fidelity = 2/d/d * jnp.real(trace*( - jnp.conj(target_der).T @ _truncate_to_subspace_jnp(forward_propagators_jnp[-1], - computational_states, - map_to_closest_unitary)).trace()) - - else: - derivative_fidelity = 2/d/d * jnp.real(trace*( - jnp.conj(target_der).T @ forward_propagators_jnp[-1]).trace()) - - return derivative_fidelity - - -class OperationInfidelityJAXSpecial2(OperationInfidelityJAX): - """ - """ - def __init__(self, - solver: solver_algorithms.Solver, - target: matrix.OperatorMatrix, - # rot_frame_ang_freq: float, - fidelity_measure: str = 'entanglement', - super_operator_formalism: bool = False, - label: Optional[List[str]] = None, - computational_states: Optional[List[int]] = None, - map_to_closest_unitary: bool = False - ): - - super().__init__(solver=solver, - target=target, - fidelity_measure=fidelity_measure, - super_operator_formalism=super_operator_formalism, - label=label, - computational_states=computational_states, - map_to_closest_unitary=map_to_closest_unitary) - - - # self.end_time = sum(solver.transferred_time)-0.5*solver.transferred_time[-1] - # self.freq = rot_frame_ang_freq - - # def rot_op_4(self,time): - # return jnp.array([[np.exp(-1j*2*self.freq/2*time),0,0,0], - # [0,np.exp(0*self.freq/2*time),0,0], - # [0,0,np.exp(0*self.freq/2*time),0], - # [0,0,0,np.exp(1j*2*self.freq/2*time)]]) - - # def rot_op_4_der_t(self,time): - # return 1j*2*self.freq/2*jnp.array([[-np.exp(-1j*2*self.freq/2*time),0,0,0], - # [0,np.exp(0*self.freq/2*time),0,0], - # [0,0,np.exp(0*self.freq/2*time),0], - # [0,0,0,np.exp(1j*2*self.freq/2*time)]]) - - def costs(self) -> float: - """Calculates the costs by the selected fidelity measure. """ - final = self.solver.forward_propagators_jnp[-1] - - if self.fidelity_measure == 'entanglement' and self.super_operator: - raise NotImplementedError - infid = 1 - _entanglement_fidelity_super_op_jnp_zphase( - self._target_jnp, - final, - jnp.sqrt(final.shape[0]).astype(int), - self.computational_states, - ) - elif self.fidelity_measure == 'entanglement': - infid = 1 - _entanglement_fidelity_jnp_zphase( - self._target_jnp, - final, - self.computational_states, - self.map_to_closest_unitary - ) - else: - raise NotImplementedError('Only the entanglement fidelity is ' - 'implemented in this version.') - return jnp.real(infid) - - - def grad(self) -> jnp.ndarray: - raise NotImplementedError - - -@jit -def _rot_op_p(ph_arr): - return jnp.diagflat(jnp.exp(1j*(ph_arr[0].real*jnp.array([1,1,-1,-1])+ph_arr[1].real*jnp.array([1,-1,1,-1])))) - -@partial(jit,static_argnums=(3,4)) -def _entanglement_infidelity_jnp_zphase_wrapper(ph_arr,target,prop,comp_states,to_closest): - return 1-_entanglement_fidelity_jnp(_rot_op_p(ph_arr)@target,prop,comp_states,to_closest) - -import jax.scipy.optimize as jsco - -@partial(jit,static_argnums=(2,3)) -def _entanglement_fidelity_jnp_zphase(target,prop,comp_states,to_closest): - res = jsco.minimize(_entanglement_infidelity_jnp_zphase_wrapper, - x0=jnp.array([0.,0.],dtype=jnp.float64),args=(target,prop,comp_states,to_closest), - method="BFGS") - return 1-res.fun - -@partial(jit,static_argnums=(2,3)) -def _entanglement_fidelity_jnp_zphase_returnopt(target,prop,comp_states,to_closest): - res = jsco.minimize(_entanglement_infidelity_jnp_zphase_wrapper, - x0=jnp.array([0.,0.],dtype=jnp.float64),args=(target,prop,comp_states,to_closest), - method="BFGS") - return 1-res.fun, res.x - -@partial(jit,static_argnums=(3,4,5)) -def _entanglement_infidelity_super_op_jnp_zphase_wrapper(ph_arr,target,prop,dim_prop,comp_states): - return 1-_entanglement_fidelity_super_operator_jnp(_rot_op_p(ph_arr)@target,prop,dim_prop,comp_states) - -@partial(jit,static_argnums=(2,3,4)) -def _entanglement_fidelity_super_op_jnp_zphase(target,prop,dim_prop,comp_states): - res = jsco.minimize(_entanglement_infidelity_super_op_jnp_zphase_wrapper, - x0=jnp.array([0.,0.],dtype=jnp.float64),args=(target,prop,dim_prop,comp_states), - method="BFGS") - return 1-res.fun - - -class TwoQubitEquivalenceClass(CostFunction): - """ - - """ - def __init__(self, - solver: solver_algorithms.Solver, - local_invariants: np.ndarray, - super_operator_formalism: bool = False, - label: Optional[List[str]] = None, - computational_states: Optional[List[int]] = None, - map_to_closest_unitary: bool = False - ): - if label is None: - label = ['Two Qubit Equivalence Class', ] - - super().__init__(solver=solver, label=label) - self.target_g = local_invariants - self._target_g_jnp = jnp.array(self.target_g) - self._target_g_c_jnp = jnp.array([self.target_g[0]+1j*self.target_g[1],self.target_g[2]]) - if computational_states is None: - self.computational_states = None - else: - self.computational_states = tuple(computational_states) - self.map_to_closest_unitary = map_to_closest_unitary - - # if fidelity_measure == 'entanglement': - # self.fidelity_measure = fidelity_measure - # else: - # raise NotImplementedError('Only the entanglement fidelity is ' - # 'currently supported.') - - self.super_operator = super_operator_formalism - - self._q_mat = 1/2**0.5*jnp.array([[1,0,0,1j], - [0,1j,1,0], - [0,1j,-1,0], - [1,0,0,-1j]]) - self._qq = jnp.conj(self._q_mat)@jnp.conj(self._q_mat).T - - - def costs(self) -> float: - """Calculates the costs by the selected fidelity measure. """ - final = self.solver.forward_propagators_jnp[-1] - - if self.computational_states is not None: - final = _truncate_to_subspace_jnp(final,self.computational_states,self.map_to_closest_unitary) - - m = _calc_m(final,self._q_mat) - g_arr_c = _calc_g_c(m,final) - l_sq_abs = jnp.sum(jnp.abs(g_arr_c-self._target_g_c_jnp)**2)**0.5 - return l_sq_abs - - - def grad(self) -> jnp.ndarray: - """Calculates the derivatives of the selected fidelity measure with - respect to the control amplitudes. """ - - final = self.solver.forward_propagators_jnp[-1] - - rev_prop_rev = self.solver.reversed_propagators_jnp[::-1][1:] - prop_der = self.solver.frechet_deriv_propagators_jnp - fwd_props = self.solver.forward_propagators_jnp[:-1] - - if self.computational_states is not None: - final = _truncate_to_subspace_jnp(final,self.computational_states,self.map_to_closest_unitary) - rpr_pd_fp = _truncate_to_subspace_jnp_dvmap(rev_prop_rev@prop_der@fwd_props,self.computational_states,self.map_to_closest_unitary) - - m = _calc_m(final,self._q_mat) - g_arr_c = _calc_g_c(m,final) - l_sq_abs = jnp.sum(jnp.abs(g_arr_c-self._target_g_c_jnp)**2)**0.5 - - derivative_lsq = _dlsq_du_c(m,self._q_mat,self._qq, - rpr_pd_fp, - final, - self._target_g_c_jnp, - g_arr_c, - l_sq_abs).T - - # should be shape: (num_t, num_ctrl) - return jnp.real(derivative_lsq) - -@jit -def _calc_m(arr,q): - ub = (jnp.conj(q).T)@arr@q - return (ub.T)@ub - -@jit -def _g_to_s_d(g_arr): - z_arr = jnp.roots([1,-g_arr[2],(4*(g_arr[0]**2+g_arr[1]**2)**0.5-1),(g_arr[2]-4*g_arr[0])]) - return jnp.pi-jnp.arccos(z_arr[0])-jnp.arccos(z_arr[2]), g_arr[2]*(g_arr[0]**2+g_arr[1]**2)**0.5-g_arr[0] - -@jit -def _calc_g_c(m,u): - g1 = 1/16 * jnp.trace(m)**2 - g3 = 1/4 * (jnp.trace(m)**2-jnp.trace(m@m)) - return jnp.asarray([g1,g3]) * jnp.linalg.det(jnp.conj(u).T) - -@jit -def _dm_dukj(q,qq,rpr_pd_fp,final): - return q.T@(rpr_pd_fp).T@qq@final@q+\ - q.T@final.T@qq@rpr_pd_fp@q - -@jit -def _ddetU_dukj(U,dUdukj): - return jnp.linalg.det(U)*jnp.trace(jnp.linalg.inv(U)@dUdukj) - -@jit -def _dg12_dukj(m,q,qq,rpr_pd_fp,final): - return 1/16*(2*m.trace()*_dm_dukj(q,qq,rpr_pd_fp,final).trace()*jnp.linalg.det(jnp.conj(final).T) - +m.trace()**2*_ddetU_dukj(jnp.conj(final).T,jnp.conj(rpr_pd_fp).T)) - -@jit -def _dg3_dukj(m,q,qq,rpr_pd_fp,final): - return 0.25*(2*(m.trace()*_dm_dukj(q,qq,rpr_pd_fp,final).trace()- - (m@_dm_dukj(q,qq,rpr_pd_fp,final)).trace())*jnp.linalg.det(jnp.conj(final).T) - +(m.trace()**2-(m@m).trace())*_ddetU_dukj(jnp.conj(final).T,jnp.conj(rpr_pd_fp).T)) - -@jit -def _dlsq_dukj_c(m,q,qq,rpr_pd_fp,final,g0_arr_c,g_arr_c,l_sq_abs): - dg12 = _dg12_dukj(m,q,qq,rpr_pd_fp,final) - dg3 = _dg3_dukj(m,q,qq,rpr_pd_fp,final) - return 1/l_sq_abs*jnp.sum(jnp.real((g_arr_c-g0_arr_c)*jnp.conj(jnp.array([dg12,dg3])))) - - -#(to be used with additional .T for previous shape) -@jit -def _dlsq_du_c(m,q,qq,rpr_pd_fp,final,g0_arr_c,g_arr_c,l_sq_abs): - return vmap(vmap(_dlsq_dukj_c,in_axes=(None,None,None,0,None,None,None,None)), - in_axes=(None,None,None,0,None,None,None,None))( - m,q,qq,rpr_pd_fp,final,g0_arr_c,g_arr_c,l_sq_abs) - -@partial(jit,static_argnums=(1,2)) -def _truncate_to_subspace_jnp_vmap(arr,subspace_indices,map_to_closest_unitary): - return vmap(_truncate_to_subspace_jnp,in_axes=(0,None,None))(arr,subspace_indices,map_to_closest_unitary) - -@partial(jit,static_argnums=(1,2)) -def _truncate_to_subspace_jnp_dvmap(arr,subspace_indices,map_to_closest_unitary): - return vmap(_truncate_to_subspace_jnp_vmap,in_axes=(0,None,None))(arr,subspace_indices,map_to_closest_unitary) - - - -############################################################################### - -class OperationInfidelityJAXzphase1Q(OperationInfidelityJAX): - """ - """ - def __init__(self, - solver: solver_algorithms.Solver, - target: matrix.OperatorMatrix, - # rot_frame_ang_freq: float, - fidelity_measure: str = 'entanglement', - super_operator_formalism: bool = False, - label: Optional[List[str]] = None, - computational_states: Optional[List[int]] = None, - map_to_closest_unitary: bool = False, - basis_change_op = None - ): - - super().__init__(solver=solver, - target=target, - fidelity_measure=fidelity_measure, - super_operator_formalism=super_operator_formalism, - label=label, - computational_states=computational_states, - map_to_closest_unitary=map_to_closest_unitary) - - self.basis_change_op = basis_change_op - - def costs(self) -> float: - """Calculates the costs by the selected fidelity measure. """ - if self.basis_change_op is not None: - final = self.basis_change_op @ self.solver.forward_propagators_jnp[-1] - else: - final = self.solver.forward_propagators_jnp[-1] - - if self.fidelity_measure == 'entanglement' and self.super_operator: - raise NotImplementedError - # infid = 1 - _entanglement_fidelity_super_op_jnp_zphase_1q( - # self._target_jnp, - # final, - # jnp.sqrt(final.shape[0]).astype(int), - # self.computational_states, - # ) - elif self.fidelity_measure == 'entanglement': - infid = 1 - _entanglement_fidelity_jnp_zphase_1q( - self._target_jnp, - final, - self.computational_states, - self.map_to_closest_unitary - ) - else: - raise NotImplementedError('Only the entanglement fidelity is ' - 'implemented in this version.') - return jnp.real(infid) - - - def grad(self) -> jnp.ndarray: - raise NotImplementedError - - -@jit -def _rot_op_p_1q(ph_arr): - return jnp.diagflat(jnp.exp(1j*(ph_arr[0].real*jnp.array([1,-1])))) - -@partial(jit,static_argnums=(3,4)) -def _entanglement_infidelity_jnp_zphase_wrapper_1q(ph_arr,target,prop,comp_states,to_closest): - return 1-_entanglement_fidelity_jnp(_rot_op_p_1q(ph_arr)@target,prop,comp_states,to_closest) - -@partial(jit,static_argnums=(2,3)) -def _entanglement_fidelity_jnp_zphase_1q(target,prop,comp_states,to_closest): - res = jsco.minimize(_entanglement_infidelity_jnp_zphase_wrapper_1q, - x0=jnp.array([0.,],dtype=jnp.float64),args=(target,prop,comp_states,to_closest), - method="BFGS") - return 1-res.fun - - - -class LeakageErrorBaseChangeJAX(CostFunction): - """See docstring of class w/o JAX. Requires solver with JAX""" - - def __init__(self, solver: solver_algorithms.SolverJAX, - computational_states: List[int], - label: Optional[List[str]] = None, - basis_change_op = None - ): - - if not _HAS_JAX: - raise ImportError("JAX not available") - if label is None: - label = ["Leakage Error", ] - super().__init__(solver=solver, label=label) - if computational_states is None: - self.computational_states = None - else: - self.computational_states = tuple(computational_states) - - self.basis_change_op = basis_change_op - - def costs(self): - """See base class. """ - if self.basis_change_op is not None: - final_prop = self.basis_change_op @ self.solver.forward_propagators_jnp[-1] - else: - final_prop = self.solver.forward_propagators_jnp[-1] - - clipped_prop = _truncate_to_subspace_jnp(final_prop, - self.computational_states,map_to_closest_unitary=False) - temp = jnp.conj(clipped_prop).T @ clipped_prop - - # the result should always be positive within numerical accuracy - return max(0, 1 - temp.trace().real / clipped_prop.shape[0]) - - def grad(self): - """See base class. """ - if self.basis_change_op is not None: - final = self.basis_change_op @ self.solver.forward_propagators_jnp[-1] - else: - final = self.solver.forward_propagators_jnp[-1] - - final_leak_dag = _truncate_to_subspace_jnp(jnp.conj(final).T, - self.computational_states,map_to_closest_unitary=False) - d = final_leak_dag.shape[0] - - if self.basis_change_op is not None: - derivative_fidelity = -2./d*jnp.real( - _der_leak_comp_states( - self.basis_change_op @ self.solver.frechet_deriv_propagators_jnp, - self.basis_change_op @ self.solver.reversed_propagators_jnp[::-1][1:], - self.basis_change_op @ self.solver.forward_propagators_jnp[:-1], - self.computational_states, - final_leak_dag).T) - - else: - derivative_fidelity = -2./d*jnp.real( - _der_leak_comp_states( - self.solver.frechet_deriv_propagators_jnp, - self.solver.reversed_propagators_jnp[::-1][1:], - self.solver.forward_propagators_jnp[:-1], - self.computational_states, - final_leak_dag).T) - - return derivative_fidelity diff --git a/qopt/matrix.py b/qopt/matrix.py index 88bb5b9..c77968d 100644 --- a/qopt/matrix.py +++ b/qopt/matrix.py @@ -1451,668 +1451,3 @@ def closest_unitary(matrix: OperatorMatrix): left_singular_vec, __, right_singular_vec_h = scipy.linalg.svd( matrix.data) return type(matrix)(left_singular_vec.dot(right_singular_vec_h)) - -############################################################################### - -try: - import jax.numpy as jnp - from jax import jit, vmap - import jax - _HAS_JAX = True -except ImportError: - from unittest import mock - jit = mock.Mock() - jnp = mock.Mock() - vmap = mock.Mock() - jax = mock.Mock() - _HAS_JAX = False - - -class DenseOperatorJAX(OperatorMatrix): - """See docstring of class w/o JAX. Works with jnp arrays""" - - __slots__ = ("data",) - - def __init__( - self, - obj: Union[Qobj, np.ndarray, jnp.ndarray, - sp.csr_matrix, 'DenseOperator']) \ - -> None: - if not _HAS_JAX: - raise ImportError("JAX not available") - super().__init__() - self.data = None - if isinstance(obj,jnp.ndarray): - self.data = obj.astype(jnp.complex128) - elif type(obj) is DenseOperatorJAX: - self.data = obj.data - elif type(obj) is DenseOperator: - self.data = obj.data.astype(jnp.complex128) - elif type(obj) is np.ndarray: - self.data = obj.astype(np.complex128) - elif type(obj) is Qobj: - self.data = jnp.array(obj.data.todense(),dtype=jnp.complex128) - elif type(obj) is sp.csr_matrix: - self.data = obj.toarray() - self.data = jnp.array(self.data,dtype=jnp.complex128) - else: - raise ValueError("Data of this type can not be broadcasted into a " - "dense control matrix. Type: " + str(type(obj))) - - def copy(self): - """See base class. """ - copy_ = DenseOperatorJAX(jnp.array(self.data,copy=True)) - # numpy copies are deep - return copy_ - - def __imul__( - self, - other: Union['DenseOperatorJAX', 'DenseOperator', complex, float, - int, np.generic, jnp.ndarray] - ) -> 'DenseOperatorJAX': - """See base class. """ - - if type(other) == DenseOperatorJAX or type(other) == DenseOperator: - jnp.matmul(self.data, other.data, out=self.data) - elif isinstance(other,jnp.ndarray) or isinstance(other,np.ndarray): - jnp.matmul(self.data, other, out=self.data) - elif type(other) in VALID_SCALARS: - self.data *= other - else: - raise NotImplementedError(str(type(other))) - return self - - def __mul__( - self, - other: Union['DenseOperatorJAX', 'DenseOperator', complex, float, - int, np.generic, jnp.ndarray] - ) -> 'DenseOperatorJAX': - """See base class. """ - - if type(other) in VALID_SCALARS: - out = self.copy() - out *= other - if type(other) == DenseOperatorJAX or type(other) == DenseOperator: - out = DenseOperatorJAX(jnp.matmul(self.data, other.data)) - elif type(other) == np.ndarray: - out = DenseOperatorJAX(jnp.matmul(self.data, jnp.array(other))) - elif isinstance(other,jnp.ndarray): - if other.shape==(): - out = DenseOperatorJAX(self.data*other) - else: - out = DenseOperatorJAX(jnp.matmul(self.data, jnp.array(other))) - else: - raise NotImplementedError(str(type(other))) - return out - - def __rmul__( - self, - other: Union['DenseOperatorJAX', 'DenseOperator', complex, float, - int, np.generic, jnp.ndarray] - ) -> 'DenseOperatorJAX': - """See base class. """ - - if isinstance(other,jnp.ndarray) or isinstance(other,np.ndarray): - out = DenseOperatorJAX(jnp.matmul(other, self.data)) - elif type(other) in VALID_SCALARS: - out = self.copy() - out *= other - else: - raise NotImplementedError(str(type(other))) - return out - - def __iadd__(self, other: 'DenseOperatorJAX') -> 'DenseOperatorJAX': - """See base class. """ - if type(other) is DenseOperatorJAX: - self.data += other.data - elif isinstance(other,jnp.ndarray) or isinstance(other,np.ndarray): - self.data += other - elif type(other) in VALID_SCALARS: - self.data += other - else: - raise NotImplementedError(str(type(other))) - return self - - def __isub__(self, other: 'DenseOperatorJAX') -> 'DenseOperatorJAX': - """See base class. """ - - if type(other) is DenseOperatorJAX: - self.data -= other.data - elif isinstance(other,jnp.ndarray) or isinstance(other,np.ndarray): - self.data -= other - elif type(other) in VALID_SCALARS: - self.data -= other - else: - raise NotImplementedError(str(type(other))) - return self - - def __truediv__(self, other: 'DenseOperatorJAX') -> 'DenseOperatorJAX': - if isinstance(other, (np.ndarray,jnp.ndarray, *VALID_SCALARS)): - return DenseOperatorJAX(self.data / other) - raise NotImplementedError(str(type(other))) - - def __itruediv__(self, other: 'DenseOperatorJAX') -> 'DenseOperatorJAX': - if isinstance(other, (np.ndarray,jnp.ndarray, *VALID_SCALARS)): - self.data /= other - return self - raise NotImplementedError(str(type(other))) - - def __getitem__(self, index: tuple) -> jnp.complex128: - """See base class. """ - return self.data[index] - - def __setitem__(self, key, value) -> None: - """See base class. """ - self.data = self.data.at[key].set(value) - - def __repr__(self): - """Representation as numpy array. """ - return 'DenseOperatorJAX with data: \n' + self.data.__repr__() - - def dag(self, do_copy: bool = True) -> Optional['DenseOperatorJAX']: - """See base class. """ - if do_copy: - cp = self.copy() - #was additional statement with "out" before, not in jnp? - - cp.data = jnp.conj(cp.data).T - return cp - else: - self.data = jnp.conj(self.data).T - return self - - def conj(self, do_copy: bool = True) -> Optional['DenseOperatorJAX']: - """See base class. """ - if do_copy: - copy = self.copy() - copy.data = jnp.conj(copy.data) - return copy - else: - self.data = jnp.conj(self.data) - return self - - def transpose(self, do_copy: bool = True) -> Optional['DenseOperatorJAX']: - """See base class. """ - if do_copy: - out = self.copy() - else: - out = self - out.data = out.data.transpose() - return out - - def flatten(self) -> jnp.ndarray: - """See base class. """ - return self.data.flatten() - - def norm(self, ord: Union[str, None, int] = 'fro') -> jnp.float64: - """ - Calulates the norm of the matrix. - - Uses the implementation of numpy.linalg.norm. - - Parameters - ---------- - ord: string - Defines the norm which is calculated. Defaults to the Frobenius norm - 'fro'. - - Returns - ------- - norm: float - Norm of the Matrix. - - """ - return jnp.linalg.norm(self.data, ord=ord) - - def tr(self) -> complex: - """See base class. """ - return self.data.trace() - - def ptrace(self, - dims: Sequence[int], - remove: Sequence[int], - do_copy: bool = True) -> 'DenseOperatorJAX': - """ - Partial trace of the matrix. - - If the matrix describes a ket, the corresponding density matrix is - calculated and used for the partial trace. - - This implementation closely follows that of QuTip's qobj._ptrace_dense. - Parameters - ---------- - dims : list of int - Dimensions of the subspaces making up the total space on which - the matrix operates. The product of elements in 'dims' must be - equal to the matrix' dimension. - remove : list of int - The selected subspaces as indices over which the partial trace is - formed. The given indices correspond to the ordering of - subspaces specified in the 'dim' argument. - do_copy : bool, optional - If false, the operation is executed inplace. Otherwise returns - a new instance. Defaults to True. - - Returns - ------- - pmat : OperatorMatrix - The partially traced OperatorMatrix. - - Raises - ------ - AssertionError: - If matrix dimension does not match specified dimensions. - - Examples - -------- - ghz_ket = DenseOperator(np.array([[1,0,0,0,0,0,0,1]]).T) / np.sqrt(2) - ghz_rho = ghz_ket * ghz_ket.dag() - ghz_rho.ptrace(dims=[2,2,2], remove=[0,2]) - DenseOperator with data: - array([[0.5+0.j, 0. +0.j], - [0. +0.j, 0.5+0.j]]) - """ - - if self.shape[1] == 1: - mat = (self * self.dag()).data - else: - mat = self.data - if mat.shape[0] != jnp.prod(dims): - raise AssertionError("Specified dimensions do not match " - "matrix dimension.") - n_dim = len(dims) # number of subspaces - dims = jnp.asarray(dims, dtype=int) - - remove = list(jnp.sort(remove)) - # indices of subspace that are kept - keep = list(set(np.arange(n_dim)) - set(remove)) - - dims_rm = (dims[remove]).tolist() - dims_keep = (dims[keep]).tolist() - dims = list(dims) - - # 1. Reshape: Split matrix into subspaces - # 2. Transpose: Change subspace/index ordering such that the subspaces - # over which is traced correspond to the first axes - # 3. Reshape: Merge each, subspaces to be removed (A) and to be kept - # (B), common spaces/axes. - # The trace of the merged spaces (A \otimes B) can then be - # calculated as Tr_A(mat) using np.trace for input with - # more than two axes effectively resulting in - # pmat[j,k] = Sum_i mat[i,i,j,k] for all j,k = 0..prod(dims_keep) - pmat = jnp.trace(mat.reshape(dims + dims) - .transpose(remove + [n_dim + q for q in remove] + - keep + [n_dim + q for q in keep]) - .reshape([jnp.prod(dims_rm), - jnp.prod(dims_rm), - jnp.prod(dims_keep), - jnp.prod(dims_keep)]) - ) - - if do_copy: - return DenseOperatorJAX(pmat) - else: - self.data = pmat - return self - - def kron(self, other: 'DenseOperatorJAX') -> 'DenseOperatorJAX': - """See base class. """ - if type(other) == DenseOperatorJAX: - out = jnp.kron(self.data, other.data) - elif isinstance(other,jnp.ndarray) or isinstance(other,np.ndarray): - out = jnp.kron(self.data, other) - else: - raise ValueError('The kronecker product of dense control matrices' - 'is not defined for: ' + str(type(other))) - return DenseOperatorJAX(out) - - def _exp_diagonalize(self, tau: complex = 1, - is_skew_hermitian: bool = False) -> 'DenseOperatorJAX': - """ Calculates the matrix exponential by spectral decomposition. - - Refactored version of _spectral_decomp. - - Parameters - ---------- - tau : complex - The matrix is multiplied by tau. - - is_skew_hermitian : bool - If True, the matrix is expected to be skew hermitian. - - Returns - ------- - exp: DenseOperator - Dense operator matrix containing the matrix exponential. - - """ - if is_skew_hermitian: - eig_val, eig_vec = jnp.linalg.eigh(-1j * self.data) - eig_val = 1j * eig_val - else: - eig_val, eig_vec = jnp.linalg.eig(self.data) - - # apply the exponential function to the eigenvalues and invert the - # diagonalization transformation - exp = jnp.einsum('ij,j,kj->ik', eig_vec, jnp.exp(tau * eig_val), - eig_vec.conj()) - - return DenseOperatorJAX(exp) - - def _dexp_diagonalization(self, - direction: 'DenseOperatorJAX', tau: complex = 1, - is_skew_hermitian: bool = False, - compute_expm: bool = False): - """ Calculates the matrix exponential by spectral decomposition. - - Refactored version of _spectral_decomp. - - Parameters - ---------- - direction: DenseOperator - Direction in which the frechet derivative is calculated. Must be of - the same shape as self. - - tau : complex - The matrix is multiplied by tau. - - is_skew_hermitian : bool - If True, the matrix is expected to be skew hermitian. - - compute_expm : bool - If True, the matrix exponential is calculated as well. - - Returns - ------- - exp: DenseOperator - The matrix exponential. Only returned if compute_expm is set to - True. - - dexp: DenseOperator - Frechet derivative of the matrix exponential. - - """ - if is_skew_hermitian: - eig_val, eig_vec = jnp.linalg.eigh(-1j * self.data) - eig_val = 1j * eig_val - else: - eig_val, eig_vec = jnp.linalg.eig(self.data) - - eig_vec_dag = eig_vec.conj().T - - eig_val_cols = eig_val * jnp.ones(self.shape) - eig_val_diffs = eig_val_cols - eig_val_cols.T - - # avoid devision by zero - eig_val_diffs += jnp.eye(self.data.shape[0]) - - omega = (jnp.exp(eig_val_diffs * tau) - 1.) / eig_val_diffs - - # override the false diagonal elements. - np.fill_diagonal(omega, tau) - - direction_transformed = eig_vec @ direction.data @ eig_vec_dag - dk_dalpha = direction_transformed * omega - - exp = jnp.einsum('ij,j,jk->ik', eig_vec, jnp.exp(tau * eig_val), - eig_vec_dag) - # einsum might be less accurate than the @ operator - dv_dalpha = eig_vec_dag @ dk_dalpha @ eig_vec - du_dalpha = exp @ dv_dalpha - - if compute_expm: - return exp, du_dalpha - else: - return du_dalpha - - def spectral_decomposition(self, hermitian: bool = False): - """See base class. """ - if hermitian is False: - eig_val, eig_vec = jax.scipy.linalg.eig(self.data) - else: - eig_val, eig_vec = jax.scipy.linalg.eigh(self.data) - - return eig_val, eig_vec - - def exp(self, tau: complex = 1, - method: str = "spectral", - is_skew_hermitian: bool = False) -> 'DenseOperatorJAX': - """ - Matrix exponential. - - Parameters - ---------- - tau: complex - The matrix is multiplied by tau before calculating the exponential. - - method: string - Numerical method used for the calculation of the matrix - exponential. - Currently the following are implemented: - - 'approx', 'Frechet': use the scipy linalg matrix exponential - - 'first_order': First order taylor approximation - - 'second_order': Second order taylor approximation - - 'third_order': Third order taylor approximation - - 'spectral': Use the self implemented spectral decomposition - - is_skew_hermitian: bool - Only important for the method 'spectral'. If set to True then the - matrix is assumed to be skew hermitian in the spectral - decomposition. - - Returns - ------- - prop: DenseOperator - The matrix exponential. - - Raises - ------ - NotImplementedError: - If the method given as parameter is not implemented. - - """ - - if method == "spectral": - prop = self._exp_diagonalize(tau=tau, - is_skew_hermitian=is_skew_hermitian) - - elif method in ["approx", "Frechet"]: - prop = jax.scipy.linalg.expm(self.data * tau) - - elif method == "first_order": - prop = jnp.eye(self.data.shape[0]) + self.data * tau - - elif method == "second_order": - prop = jnp.eye(self.data.shape[0]) + self.data * tau - prop += self.data @ self.data * (tau * tau * 0.5) - - elif method == "third_order": - b = self.data * tau - prop = jnp.eye(self.data.shape[0]) + b - bb = b @ b * 0.5 - prop += bb - prop += bb @ b * 0.3333333333333333333 - else: - raise ValueError("Unknown or not specified method for the " - "calculation of the matrix exponential:" - + str(method)) - return DenseOperatorJAX(prop) - - def prop(self, tau: complex = 1) -> 'DenseOperatorJAX': - """See base class. """ - return DenseOperatorJAX(self.exp(tau)) - - def dexp(self, - direction: 'DenseOperatorJAX', - tau: complex = 1, - compute_expm: bool = False, - method: str = "spectral", - is_skew_hermitian: bool = False, - epsilon: float = 1e-10, - ) \ - -> Union['DenseOperatorJAX', Tuple['DenseOperatorJAX']]: - """ - Frechet derivative of the matrix exponential. - - Parameters - ---------- - direction: DenseOperator - Direction in which the frechet derivative is calculated. Must be of - the same shape as self. - - tau: complex - The matrix is multiplied by tau before calculating the exponential. - - compute_expm: bool - If true, then the matrix exponential is calculated and returned as - well. - - method: string - Numerical method used for the calculation of the matrix - exponential. - Currently the following are implemented: - - 'Frechet': Uses the scipy linalg matrix exponential for - simultaniously calculation of the frechet derivative expm_frechet - - 'approx': Approximates the Derivative by finite differences. - - 'first_order': First order taylor approximation - - 'second_order': Second order taylor approximation - - 'third_order': Third order taylor approximation - - 'spectral': Use the self implemented spectral decomposition - - is_skew_hermitian: bool - Only required, for the method 'spectral'. If set to True, then the - matrix is assumed to be skew hermitian in the spectral - decomposition. - - epsilon: float - Width of the finite difference. Only relevant for the method - 'approx'. - - Returns - ------- - prop: DenseOperator - The matrix exponential. Only returned if compute_expm is True! - prop_grad: DenseOperator - The frechet derivative d exp(Ax + B)/dx at x=0 where A is the - direction and B is the matrix stored in self. - - Raises - ------ - NotImplementedError: - If the method given as parameter is not implemented. - - """ - prop = None - - if type(direction) != DenseOperatorJAX: - direction = DenseOperatorJAX(direction) - - if method == "Frechet": - a = self.data * tau - e = direction.data * tau - if compute_expm: - prop, prop_grad = jax.scipy.linalg.expm_frechet( - a, e, compute_expm=True) - prop_grad = DenseOperatorJAX(prop_grad) - prop = DenseOperatorJAX(prop) - - else: - prop_grad = jax.scipy.linalg.expm_frechet( - a, e, compute_expm=False) - prop_grad = DenseOperatorJAX(prop_grad) - - - elif method == "spectral": - if compute_expm: - prop, prop_grad = self._dexp_diagonalization( - direction=direction, tau=tau, - is_skew_hermitian=is_skew_hermitian, - compute_expm=compute_expm - ) - else: - prop_grad = self._dexp_diagonalization( - direction=direction, tau=tau, - is_skew_hermitian=is_skew_hermitian, - compute_expm=compute_expm - ) - - elif method == "approx": - d_m = (self.data + epsilon * direction.data) * tau - dprop = jax.scipy.linalg.expm(d_m) - prop = self.exp(tau) - prop_grad = (dprop - prop) * (1 / epsilon) - - elif method == "first_order": - if compute_expm: - prop = self.exp(tau) - prop_grad = direction.data * tau - - elif method == "second_order": - if compute_expm: - prop = self.exp(tau) - prop_grad = direction.data * tau - prop_grad += (self.data @ direction.data - + direction.data @ self.data) * (tau * tau * 0.5) - - elif method == "third_order": - if compute_expm: - prop = self.exp(tau) - prop_grad = direction.data * tau - prop_grad += (self.data @ direction.data - + direction.data @ self.data) * tau * tau * 0.5 - prop_grad += ( - self.data @ self.data @ direction.data - + direction.data @ self.data @ self.data - + self.data @ direction.data @ self.data - ) * (tau * tau * tau * 0.16666666666666666) - else: - raise NotImplementedError( - 'The specified method ' + method + "is not implemented!") - if compute_expm: - if type(prop) != DenseOperatorJAX: - prop = DenseOperatorJAX(prop) - if type(prop_grad) != DenseOperatorJAX: - prop_grad = DenseOperatorJAX(prop_grad) - if compute_expm: - return prop, prop_grad - else: - return prop_grad - - def identity_like(self) -> 'DenseOperatorJAX': - """See base class. """ - assert self.shape[0] == self.shape[1] - return DenseOperatorJAX(jnp.eye(self.shape[0], dtype=complex)) - - def truncate_to_subspace( - self, subspace_indices: Optional[Sequence[int]], - map_to_closest_unitary: bool = False - ) -> 'DenseOperatorJAX': - """See base class. """ - if subspace_indices is None: - return self - elif self.shape[0] == self.shape[1]: - # square matrix - out = type(self)( - self.data[jnp.ix_(jnp.array(subspace_indices), - jnp.array(subspace_indices))]) - if map_to_closest_unitary: - out = closest_unitary(out) - elif self.shape[0] == 1: - # bra-vector - out = type(self)(self.data[jnp.ix_(jnp.array([0]), - jnp.array(subspace_indices))]) - if map_to_closest_unitary: - out *= 1 / out.norm('fro') - elif self.shape[0] == 1: - # ket-vector - out = type(self)(self.data[jnp.ix_(jnp.array(subspace_indices), - jnp.array([0]))]) - if map_to_closest_unitary: - out *= 1 / out.norm('fro') - else: - out = type(self)(self.data[jnp.ix_(jnp.array(subspace_indices))]) - - return out - - - diff --git a/qopt/noise.py b/qopt/noise.py index cbe77f3..bc81527 100644 --- a/qopt/noise.py +++ b/qopt/noise.py @@ -77,8 +77,6 @@ from qopt.util import deprecated -import random -from functools import partial def bell_curve_1dim(x: Union[np.ndarray, float], stdx: float) -> Union[np.ndarray, float]: @@ -693,370 +691,3 @@ def plot_periodogram(self, n_average: int, scaling: str = 'density', np.mean(spectral_density_or_spectrum, axis=0)[1:-1] - self.noise_spectral_density(sample_frequencies)[1:-1]) return deviation_norm - - -############################################################################### - -try: - import jax.numpy as jnp - from jax import jit, vmap - import jax - _HAS_JAX = True -except ImportError: - from unittest import mock - jit = mock.Mock() - jnp = mock.Mock() - vmap = mock.Mock() - jax = mock.Mock() - _HAS_JAX = False - - -@jit -def _inverse_cumulative_gaussian_distribution_function_jnp( - z: Union[float, np.array, jnp.ndarray], std: float, mean: float): - """ - Calculates the inverse cumulative function for the gaussian distribution. - - Parameters - ---------- - z: Union[float, np.array, jnp.array] - Function value. - - std: float - Standard deviation of the bell curve. - - mean: float - Mean value of the gaussian distribution. Defaults to 0. - - Returns - ------- - selected_x: list of float - Noise samples. - - """ - return std * jnp.sqrt(2) * jax.scipy.special.erfinv(2 * z - 1) + mean - - -@partial(jit,static_argnums=1) -def _sample_1dim_gaussian_distribution_jnp(std: float, n_samples: int, mean: float = 0)\ - -> jnp.ndarray: - """ - Returns 'n_samples' samples from the one dimensional bell curve. - - The samples are chosen such, that the integral over the bell curve between - two adjacent samples is always the same. The samples reproduce the correct - standard deviation only in the limit n_samples -> inf due to the - discreteness of the approximation. The error is to good approximation - 1/n_samples. - - Parameters - ---------- - std: float - Standard deviation of the bell curve. - - n_samples: int - Number of samples returned. - - mean: float - Mean value of the gaussian distribution. Defaults to 0. - - Returns - ------- - selected_x: numpy array of shape:(n_samples, ) - Noise samples. - - """ - z = jnp.linspace(start=0, stop=1, num=n_samples, endpoint=False) - z += 1 / (2 * n_samples) - # we distribute the total probability of 1 into n_samples equal parts. - # The z-values are in the center of each part. - - x = _inverse_cumulative_gaussian_distribution_function_jnp( - z=jnp.expand_dims(z,0), std=jnp.expand_dims(std,1), mean=mean - ) - # We use the inverse cumulative gaussian distribution to find the values x. - # The integral over the Gaussian distribution between x[i] and x[i+1] - # now always equals 1/n_samples. - return x - - -class NTGQuasiStaticJAX(NoiseTraceGenerator): - """See docstring of class w/o JAX. - - Additional parameter: seed: int, optional: seed for jax.random.PRNGKey - """ - - - def __init__(self, standard_deviation: List[float], - n_samples_per_trace: int, - n_traces: int = 1, - noise_samples: Optional[np.ndarray] = None, - always_redraw_samples: bool = True, - correct_std_for_discrete_sampling: bool = True, - sampling_mode: str = 'uncorrelated_deterministic', - seed: Optional[int] = None): - if not _HAS_JAX: - raise ImportError("JAX not available") - n_noise_operators = len(standard_deviation) - super().__init__(noise_samples=noise_samples, - n_samples_per_trace=n_samples_per_trace, - n_traces=n_traces, - n_noise_operators=n_noise_operators, - always_redraw_samples=always_redraw_samples) - self.standard_deviation = jnp.asarray(standard_deviation) - - self.sampling_mode = sampling_mode - self.seed = seed if seed is not None else random.randint(0,2**32-1) - self.rnd_key_first = jax.random.PRNGKey(self.seed) - self.rnd_key_arr = [self.rnd_key_first] - - if correct_std_for_discrete_sampling: - if self.n_traces == 1: - raise RuntimeWarning('Standard deviation cannot be estimated' - 'for a single trace!') - elif self.sampling_mode == 'uncorrelated_deterministic': - - - n_std_dev = len(self.standard_deviation) - _noise_samples = _sample_1dim_gaussian_distribution_jnp( - self.standard_deviation, self._n_traces) - _noise_samples = jnp.broadcast_to( - jnp.expand_dims(jnp.tile(_noise_samples,n_std_dev)* - jnp.repeat(jnp.eye(n_std_dev),self._n_traces,axis=1),2), - (n_std_dev,self._n_traces*n_std_dev,self.n_samples_per_trace)) - - actual_std = jnp.std(_noise_samples,axis=(1,2)) - if jnp.any(actual_std < 1e-20): - raise RuntimeError('The standard deviation was ' - 'estimated close to 0!') - self.standard_deviation *= \ - self.standard_deviation / actual_std - - @property - def n_traces(self) -> int: - """Number of traces. - - The number of requested traces must be multiplied with the number of - standard deviations because if standard deviation is sampled - separately. - - """ - if self._n_traces: - if self.sampling_mode == 'uncorrelated_deterministic': - return self._n_traces * len(self.standard_deviation) - elif self.sampling_mode == 'monte_carlo': - return self._n_traces - else: - raise ValueError('Unsupported sampling mode!') - else: - return self.noise_samples.shape[1] - - def _sample_noise(self) -> None: - """ - Draws quasi static noise samples from a normal distribution. - - Each noise contribution (corresponding to one noise operator) is - sampled separately. For each standard deviation n_traces traces are - calculated. - - """ - if self.sampling_mode == 'uncorrelated_deterministic': - - n_std_dev = len(self.standard_deviation) - _noise_samples = _sample_1dim_gaussian_distribution_jnp( - self.standard_deviation, self._n_traces) - self._noise_samples = jnp.broadcast_to( - jnp.expand_dims(jnp.tile(_noise_samples,n_std_dev)* - jnp.repeat(jnp.eye(n_std_dev),self._n_traces,axis=1),2), - (n_std_dev,self._n_traces*n_std_dev,self.n_samples_per_trace)) - - elif self.sampling_mode == 'monte_carlo': - - self._noise_samples = jnp.einsum( - 'i,ijk->ijk', - self.standard_deviation, - jax.random.normal( - key=self.rnd_key_arr[-1], - shape=(len(self.standard_deviation),self.n_traces,1)) - ) - self._noise_samples = jnp.repeat( - self._noise_samples, self.n_samples_per_trace, axis=2) - - self.rnd_key_arr.append( - jax.random.split(self.rnd_key_arr[-1],num=2)[1]) - - else: - raise ValueError('Unsupported sampling mode!') - - -def _fast_colored_noise_jnp(spectral_density: Callable, dt: float, - n_samples: int, output_shape: tuple, key, - r_power_of_two=False - ) -> jnp.ndarray: - """See docstring of function without _jnp""" - f_max = 1 / dt - f_nyquist = f_max / 2 - s0 = 1 / f_nyquist - if r_power_of_two: - actual_n_samples = int(2 ** jnp.ceil(jnp.log2(n_samples))) - else: - actual_n_samples = int(n_samples) - - delta_white = jax.random.normal(key,(*output_shape, actual_n_samples)) - delta_white_ft = jnp.fft.rfft(delta_white, axis=-1) - # Only positive frequencies since FFT is real and therefore symmetric - f = jnp.linspace(0, f_nyquist, actual_n_samples // 2 + 1) - f = spectral_density(f[1:]) - f = jnp.pad(f,((1, 0),)) - delta_colored = jnp.fft.irfft(delta_white_ft * jnp.sqrt(f / s0), - n=actual_n_samples, axis=-1) - # the ifft takes r//2 + 1 inputs to generate r outputs - - return delta_colored - - -class NTGColoredNoiseJAX(NoiseTraceGenerator): - """See docstring of class w/o JAX. - - Additional parameter: seed: int, optional: seed for jax.random.PRNGKey - """ - - def __init__(self, - n_samples_per_trace: int, - noise_spectral_density: Callable, - dt: float, - n_traces: int = 1, - n_noise_operators: int = 1, - always_redraw_samples: bool = True, - low_frequency_extension_ratio: int = 1, - seed: Optional[int] = None): - if not _HAS_JAX: - raise ImportError("JAX not available") - super().__init__(n_traces=n_traces, - n_samples_per_trace=n_samples_per_trace, - noise_samples=None, - n_noise_operators=n_noise_operators, - always_redraw_samples=always_redraw_samples) - self.noise_spectral_density = noise_spectral_density - self.dt = dt - if low_frequency_extension_ratio < 1: - raise ValueError("The low frequency extension ratio must be " - "greater or equal to 1.") - self.low_frequency_extension_ratio = low_frequency_extension_ratio - if hasattr(dt, "__len__"): - raise ValueError('dt is supposed to be a scalar value!') - - self.seed = seed if seed is not None else random.randint(0,2**32-1) - self.rnd_key_first = jax.random.PRNGKey(self.seed) - self.rnd_key_arr = [self.rnd_key_first] - - def _sample_noise(self, **kwargs) -> None: - """Samples noise from an arbitrary colored spectrum. """ - if self._n_noise_operators is None: - raise ValueError('Please specify the number of noise operators!') - if self._n_traces is None: - raise ValueError('Please specify the number of noise traces!') - if self._n_samples_per_trace is None: - raise ValueError('Please specify the number of noise samples per' - 'trace!') - - - noise_samples = _fast_colored_noise_jnp( - spectral_density=self.noise_spectral_density, - n_samples= - self.n_samples_per_trace * self.low_frequency_extension_ratio, - output_shape=(self.n_noise_operators, self.n_traces), - r_power_of_two=False, - dt=self.dt, - key=self.rnd_key_arr[-1]) - self._noise_samples = noise_samples[:, :, :self.n_samples_per_trace] - - self.rnd_key_arr.append( - jax.random.split(self.rnd_key_arr[-1],num=2)[1]) - - def plot_periodogram(self, n_average: int, scaling: str = 'density', - log_plot: Optional[str] = None, draw_plot=True): - """Creates noise samples and plots the corresponding periodogram. - - Parameters - ---------- - n_average: int - Number of Periodograms which are averaged. - - scaling: {'density', 'spectrum'}, optional - If 'density' then the power spectral density in units of V**2/Hz is - plotted. - If 'spectral' then the power spectrum in units of V**2 is plotted. - Defaults to 'density'. - - log_plot: {None, 'semilogy', 'semilogx', 'loglog'}, optional - If None, then the plot is not plotted logarithmically. If - 'semilogy' only the y-axis is plotted logarithmically, if - 'semilogx' only the x-axis is plotted logarithmically, if 'loglog' - both axis are plotted logarithmically. Defaults to None. - - draw_plot: bool, optional - If true, then the periodogram is plotted. Defaults to True. - - Returns - ------- - deviation_norm: float - The vector norm of the deviation between the actual power spectral - density and the power spectral densitry found in the periodogram. - - """ - - noise_samples = fast_colored_noise( - spectral_density=self.noise_spectral_density, - n_samples=self.n_samples_per_trace, - output_shape=(n_average,), - r_power_of_two=False, - dt=self.dt - ) - - sample_frequencies, spectral_density_or_spectrum = signal.periodogram( - x=noise_samples, - fs=1 / self.dt, - return_onesided=True, - scaling=scaling, - axis=-1 - ) - - if scaling == 'density': - y_label = 'Power Spectral Density (V**2/Hz)' - elif scaling == 'spectrum': - y_label = 'Power Spectrum (V**2)' - else: - raise ValueError('Unexpected scaling argument.') - - if draw_plot: - plt.figure() - - if log_plot is None: - plot_function = plt.plot - elif log_plot == 'semilogy': - plot_function = plt.semilogy - elif log_plot == 'semilogx': - plot_function = plt.semilogx - elif log_plot == 'loglog': - plot_function = plt.loglog - else: - raise ValueError('Unexpected plotting mode') - - plot_function(sample_frequencies, - np.mean(spectral_density_or_spectrum, axis=0), - label='Periodogram') - plot_function(sample_frequencies, - self.noise_spectral_density(sample_frequencies), - label='Spectral Noise Density') - - plt.ylabel(y_label) - plt.xlabel('Frequency (Hz)') - plt.legend(['Periodogram', 'Spectral Noise Density']) - plt.show() - - deviation_norm = np.linalg.norm( - np.mean(spectral_density_or_spectrum, axis=0)[1:-1] - - self.noise_spectral_density(sample_frequencies)[1:-1]) - return deviation_norm - diff --git a/qopt/optimize.py b/qopt/optimize.py index dbfe309..c191a1a 100644 --- a/qopt/optimize.py +++ b/qopt/optimize.py @@ -125,7 +125,7 @@ class Optimizer(ABC): use_jacobian_function: bool, optional If set to true, then the jacobians are calculated analytically. - Defaults to False. + Defaults to True. store_optimizer: bool, optional If True, then the optimizer stores itself in the result class. @@ -266,86 +266,6 @@ def cost_jacobian_wrapper(self, optimization_parameters): self._n_jac_fkt_eval += 1 return jacobian - def cost_func_wrapper_global(self, optimization_parameters): - """Wraps the cost function given by the simulator class. - - The relevant information for the analysis is saved. - - Parameters - ---------- - optimization_parameters: np.array - Raw optimization parameters in a linear array. - - Returns - ------- - costs: np.array, shape (n_fun) - Cost values. - - """ - if (time.time() - self._opt_start_time) \ - > self.termination_conditions['max_wall_time']: - raise WallTimeExceeded - - costs = self.system_simulator.wrapped_cost_functions_test( - optimization_parameters.reshape(self.pulse_shape[::-1]).T) - - if self.save_intermediary_steps: - self.optim_iter_summary.iter_num += 1 - self.optim_iter_summary.costs.append(costs) - self.optim_iter_summary.parameters.append( - optimization_parameters.reshape(self.pulse_shape[::-1]).T - ) - if np.linalg.norm(costs) < np.linalg.norm(self._min_costs): - self._min_costs = costs - self._min_costs_par = optimization_parameters.reshape( - self.pulse_shape[::-1]).T - - # apply the cost function weights after saving the values. - if self.cost_func_weights is not None: - costs *= self.cost_func_weights - - self._n_cost_fkt_eval += 1 - return costs - - def cost_jacobian_wrapper_global(self, optimization_parameters, scale_ind=[]): - """Wraps the cost Jacobian function given by the simulator class. - - The relevant information for the analysis is saved. - - Parameters - ---------- - optimization_parameters: np.array - Raw optimization parameters in a linear array. - - Returns - ------- - jacobian: np.array, shape (num_func, num_t * num_amp) - Jacobian of the cost functions. - - """ - jacobian = self.system_simulator.wrapped_jac_function_test( - optimization_parameters.reshape(self.pulse_shape[::-1]).T) - - if self.save_intermediary_steps: - self.optim_iter_summary.gradients.append(jacobian) - - jacobian[:,:,scale_ind] = jacobian[:,:,scale_ind]/(1+self._n_jac_fkt_eval) - - # jacobian shape (num_t, num_f, num_ctrl) -> (num_f, num_t * num_ctrl) - jacobian = jacobian.transpose([1, 2, 0]) - jacobian = jacobian.reshape( - (jacobian.shape[0], jacobian.shape[1] * jacobian.shape[2])) - - # apply the cost function weights after saving the values. - if self.cost_func_weights is not None: - jacobian = np.einsum('ab, a -> ab', jacobian, - self.cost_func_weights) - - self._n_jac_fkt_eval += 1 - return jacobian - - ########################### - @abstractmethod def run_optimization(self, initial_control_amplitudes: np.ndarray, verbose) \ @@ -382,11 +302,7 @@ def prepare_optimization(self, self._min_costs_par = None self._n_cost_fkt_eval = 0 self._n_jac_fkt_eval = 0 - try: - self.pulse_shape = initial_optimization_parameters.shape - except: - self.pulse_shape = len(initial_optimization_parameters) - + self.pulse_shape = initial_optimization_parameters.shape if self.save_intermediary_steps: self.optim_iter_summary = \ optimization_data.OptimizationSummary( @@ -543,111 +459,6 @@ def run_optimization(self, initial_control_amplitudes: np.array, return optim_result -class LeastSquaresOptimizerGlobal(Optimizer): - """ - Uses the scipy least squares method for optimization. - - Parameters - ---------- - system_simulator: `Simulator` - The systems simulator. - - termination_cond: dict - Termination conditions. - - save_intermediary_steps: bool, optional - If False, only the simulation result is stored. Defaults to False. - - method: str, optional - The optimization method used. Currently implemented are: - - 'trf': A trust region optimization algorithm. This is the default. - - bounds: array or list of boundaries, optional - The boundary conditions for the pulse optimizations. If none are given - then the pulse is assumed to take any real value. - - """ - - def __init__( - self, - n_time_steps_ctrl: int, - system_simulator: Optional[simulator.Simulator] = None, - termination_cond: Optional[Dict] = None, - save_intermediary_steps: bool = True, - method: str = 'trf', - bounds: Union[np.ndarray, List, None] = None, - use_jacobian_function=True, - cost_func_weights: Optional[Sequence[float]] = None, - store_optimizer: bool = False, - scale_down_grad_ind = []): - super().__init__(system_simulator=system_simulator, - termination_cond=termination_cond, - save_intermediary_steps=save_intermediary_steps, - cost_func_weights=cost_func_weights, - use_jacobian_function=use_jacobian_function, - store_optimizer=store_optimizer) - self.method = method - self.bounds = bounds - self.n_time_steps_ctrl = n_time_steps_ctrl - - self.scale_down_grad_ind = scale_down_grad_ind - - def cost_jacobian_wrapper_test(self,optimization_parameters): - - return super().cost_jacobian_wrapper_global(optimization_parameters,self.scale_down_grad_ind) - - def run_optimization(self, initial_control_amplitudes: np.array, - verbose: int = 0) -> optimization_data.OptimizationResult: - """See base class. - """ - super().prepare_optimization( - initial_optimization_parameters=initial_control_amplitudes) - - if self.use_jacobian_function: - jac = self.cost_jacobian_wrapper_test - else: - jac = '2-point' - - try: - result = scipy.optimize.least_squares( - fun=super().cost_func_wrapper_global, - x0=initial_control_amplitudes.T.flatten(), - jac=jac, - bounds=self.bounds, - method=self.method, - ftol=self.termination_conditions["min_cost_gain"], - xtol=self.termination_conditions["min_amplitude_change"], - gtol=self.termination_conditions["min_gradient_norm"], - max_nfev=self.termination_conditions["max_iterations"], - verbose=verbose, - x_scale="jac" - ) - - if self.system_simulator.stats is not None: - self.system_simulator.stats.end_t_opt = time.time() - - if self.store_optimizer: - storage_opt = self - else: - storage_opt = None - - optim_result = optimization_data.OptimizationResult( - final_cost=result.fun, - indices=self.system_simulator.cost_indices, - final_parameters=[result.x]*self.n_time_steps_ctrl, - final_grad_norm=np.linalg.norm(result.grad), - num_iter=result.nfev, - termination_reason=result.message, - status=result.status, - optimizer=storage_opt, - optim_summary=self.optim_iter_summary, - optimization_stats=self.system_simulator.stats - ) - except WallTimeExceeded: - optim_result = self.write_state_to_result() - - return optim_result - class ScalarMinimizingOptimizer(Optimizer): """ Interfaces to the minimize functions of the optimization package in @@ -1068,624 +879,3 @@ def prepare_optimization(self, super().prepare_optimization( initial_optimization_parameters=initial_optimization_parameters) self.annealer.state = initial_optimization_parameters - - -class SimulatedAnnealingScipy(Optimizer): - """ - This class uses simulated annealing for discrete optimization. - - Parameters - ---------- - temperature: float - Initial temperature for the annealing algorithm. - - step_size: int - Initial stepsize. - - interval: int - Number of optimization iterations before the step size is reduced. - - bounds: array of boundaries, shape: (2, num_t, num_ctrl) - The boundary conditions for the pulse optimizations. bounds[0] should - be the lower bounds, and bounds[1] the upper ones. - - """ - - def __init__( - self, - system_simulator: Optional[simulator.Simulator] = None, - termination_cond: Optional[Dict] = None, - save_intermediary_steps: bool = False, - store_optimizer: bool = False, - temperature: float = 1., - step_size: int = 1, - interval: int = 50, - bounds: Optional[np.ndarray] = None - ): - super().__init__( - system_simulator=system_simulator, - termination_cond=termination_cond, - save_intermediary_steps=save_intermediary_steps, - store_optimizer=store_optimizer - ) - self.temperature = temperature - self.step_size = step_size - self.interval = interval - self.bounds = bounds - - def run_optimization(self, initial_control_amplitudes: np.ndarray, - verbose: bool = False): - """See base class. """ - - super().prepare_optimization( - initial_optimization_parameters=initial_control_amplitudes) - - if self.store_optimizer: - storage_opt = self - else: - storage_opt = None - - try: - result = scipy.optimize.basinhopping( - func=self.cost_func_wrapper, - x0=initial_control_amplitudes.T.flatten(), - niter=self.termination_conditions["max_iterations"], - T=self.temperature, - stepsize=self.step_size, - take_step=self._take_step, - callback=None, - interval=self.interval, - disp=verbose - ) - - if self.system_simulator.stats is not None: - self.system_simulator.stats.end_t_opt = time.time() - - optim_result = optimization_data.OptimizationResult( - final_cost=result.fun, - indices=self.system_simulator.cost_indices, - final_parameters=result.x.reshape(self.pulse_shape[::-1]).T, - num_iter=result.nfev, - termination_reason=result.message, - status=result.status, - optimizer=storage_opt, - optim_summary=self.optim_iter_summary, - optimization_stats=self.system_simulator.stats - ) - - except WallTimeExceeded: - if self.system_simulator.stats is not None: - self.system_simulator.stats.end_t_opt = time.time() - - optim_result = optimization_data.OptimizationResult( - final_cost=self._min_costs, - indices=self.system_simulator.cost_indices, - final_parameters=self._min_costs_par, - num_iter=self._n_cost_fkt_eval, - termination_reason='Maximum Wall Time Exceeded', - status=5, - optimizer=storage_opt, - optim_summary=self.optim_iter_summary, - optimization_stats=self.system_simulator.stats - ) - - return optim_result - - def _take_step(self, current_pulse: np.ndarray) -> np.ndarray: - """ - This function applies a random discrete variation to the pulse. - - Parameters - ---------- - current_pulse: array of int - The pulse before the application of the take step function. - - Returns - ------- - new_pulse: array of int - The pulse initial pulse plus a random variation. - - """ - pulse = current_pulse.reshape(self.pulse_shape[::-1]).T - - if type(self.step_size) != int: - raise ValueError("The step size must be integer! But it is: " - + str(self.step_size)) - - if self.step_size == 0: - raise ValueError("The step size has been set to 0.") - - random_step = np.random.randint( - low=-1 * self.step_size, - high=self.step_size + 1, - size=pulse.shape - ) - - new_pulse = pulse + random_step - - # if a limit is exceeded, set the value to the limit - lower_limit_exceeded = new_pulse < self.bounds[0] - upper_limit_exceeded = new_pulse > self.bounds[1] - - new_pulse[lower_limit_exceeded] = self.bounds[0][lower_limit_exceeded] - new_pulse[upper_limit_exceeded] = self.bounds[1][upper_limit_exceeded] - - return new_pulse.T.flatten() - - -############################################################################### - -try: - import jax.numpy as jnp - from jax import jit, vmap - import jax - _HAS_JAX = True -except ImportError: - from unittest import mock - jit = mock.Mock() - jnp = mock.Mock() - vmap = mock.Mock() - jax = mock.Mock() - _HAS_JAX = False - - -class OptimizerJAX(ABC): - """See docstring of class w/o JAX. Requires simulator with JAX""" - - def __init__( - self, - system_simulator: Optional[simulator.SimulatorJAX] = None, - termination_cond: Optional[Dict] = None, - save_intermediary_steps: bool = True, - cost_func_weights: Optional[Sequence[float]] = None, - use_jacobian_function=True, - store_optimizer: bool = False - ): - if not _HAS_JAX: - raise ImportError("JAX not available") - self.system_simulator = system_simulator - self.use_jacobian_function = use_jacobian_function - self.termination_conditions = default_termination_conditions - if termination_cond is not None: - self.termination_conditions.update(**termination_cond) - - self.optim_iter_summary = None - self.pulse_shape = () - - self._opt_start_time = 0 - self._min_costs = jnp.inf - self._min_costs_par = None - self._n_cost_fkt_eval = 0 - self._n_jac_fkt_eval = 0 - - # flags: - self.save_intermediary_steps = save_intermediary_steps - self.store_optimizer = store_optimizer - - self.cost_func_weights = cost_func_weights - - if self.cost_func_weights is not None: - self.cost_func_weights = jnp.asarray( - self.cost_func_weights).flatten() - if len(self.cost_func_weights) == 0: - self.cost_func_weights = None - elif not len(self.system_simulator.cost_funcs) == len( - self.cost_func_weights): - raise ValueError('A cost function weight must be specified for' - 'each cost function or for none at all.') - - def cost_func_wrapper(self, optimization_parameters): - """Wraps the cost function given by the simulator class. - - The relevant information for the analysis is saved. - - Parameters - ---------- - optimization_parameters: Union[np.array, jnp.ndarray] - Raw optimization parameters in a linear array. - - Returns - ------- - costs: jnp.array, shape (n_fun) - Cost values. - - """ - if (time.time() - self._opt_start_time) \ - > self.termination_conditions['max_wall_time']: - raise WallTimeExceeded - - costs = self.system_simulator.wrapped_cost_functions( - optimization_parameters.reshape(self.pulse_shape[::-1]).T) - - if self.save_intermediary_steps: - self.optim_iter_summary.iter_num += 1 - self.optim_iter_summary.costs.append(costs) - self.optim_iter_summary.parameters.append( - optimization_parameters.reshape(self.pulse_shape[::-1]).T - ) - if jnp.linalg.norm(costs) < jnp.linalg.norm(self._min_costs): - self._min_costs = costs - self._min_costs_par = optimization_parameters.reshape( - self.pulse_shape[::-1]).T - - # apply the cost function weights after saving the values. - if self.cost_func_weights is not None: - costs *= self.cost_func_weights - - self._n_cost_fkt_eval += 1 - return costs - - def cost_jacobian_wrapper(self, optimization_parameters): - """Wraps the cost Jacobian function given by the simulator class. - - The relevant information for the analysis is saved. - - Parameters - ---------- - optimization_parameters: Union[np.array, jnp.ndarray] - Raw optimization parameters in a linear array. - - Returns - ------- - jacobian: jnp.array, shape (num_func, num_t * num_amp) - Jacobian of the cost functions. - - """ - jacobian = self.system_simulator.wrapped_jac_function( - optimization_parameters.reshape(self.pulse_shape[::-1]).T) - - if self.save_intermediary_steps: - self.optim_iter_summary.gradients.append(jacobian) - - # jacobian shape (num_t, num_f, num_ctrl) -> (num_f, num_t * num_ctrl) - jacobian = jacobian.transpose([1, 2, 0]) - jacobian = jacobian.reshape( - (jacobian.shape[0], jacobian.shape[1] * jacobian.shape[2])) - - # apply the cost function weights after saving the values. - if self.cost_func_weights is not None: - jacobian = jnp.einsum('ab, a -> ab', jacobian, - self.cost_func_weights) - - self._n_jac_fkt_eval += 1 - return jacobian - - @abstractmethod - def run_optimization( - self, - initial_control_amplitudes: Union[np.ndarray,jnp.ndarray], - verbose) \ - -> optimization_data.OptimizationResult: - """Runs the optimization of the control amplitudes. - - Parameters - ---------- - initial_control_amplitudes : array - shape (num_t, num_ctrl) - verbose - Verbosity of the run. Depends on which optimizer is used. - - Returns - ------- - optimization_result : `OptimizationResult` - The resulting data of the simulation. - - """ - pass - - def prepare_optimization( - self, - initial_optimization_parameters: Union[np.ndarray,jnp.ndarray]): - """Prepare for the next optimization. - - Parameters - ---------- - initial_optimization_parameters : array - shape (num_t, num_ctrl) - - Data stored in this class might be overwritten. - """ - self._min_costs = jnp.inf - self._min_costs_par = None - self._n_cost_fkt_eval = 0 - self._n_jac_fkt_eval = 0 - self.pulse_shape = initial_optimization_parameters.shape - if self.save_intermediary_steps: - self.optim_iter_summary = \ - optimization_data.OptimizationSummary( - indices=self.system_simulator.cost_indices - ) - self._opt_start_time = time.time() - if self.system_simulator.stats is not None: - # If the system simulator wants to write down statistics, then - # initialise a fresh instance - self.system_simulator.stats = \ - performance_statistics.PerformanceStatistics() - self.system_simulator.stats.start_t_opt = float( - self._opt_start_time) - self.system_simulator.stats.indices = \ - self.system_simulator.cost_indices - - def write_state_to_result(self): - """ Writes the current state into an instance of 'OptimizationResult'. - - Intended for saving progress when terminating the optimization in an - unexpected way. - - Returns - ------- - result: optimization_data.OptimizationResult - The current result of the optimization. - - """ - if self.system_simulator.stats is not None: - self.system_simulator.stats.end_t_opt = time.time() - - if self.use_jacobian_function: - jac_norm = jnp.linalg.norm( - self.cost_jacobian_wrapper(self._min_costs_par)) - else: - jac_norm = 0 - - if self.store_optimizer: - storage_opt = self - else: - storage_opt = None - - optim_result = optimization_data.OptimizationResult( - final_cost=self._min_costs, - indices=self.system_simulator.cost_indices, - final_parameters=self._min_costs_par, - final_grad_norm=jac_norm, - num_iter=self._n_cost_fkt_eval, - termination_reason='Maximum Wall Time Exceeded', - status=5, - optimizer=storage_opt, - optim_summary=self.optim_iter_summary, - optimization_stats=self.system_simulator.stats - ) - return optim_result - - -#only changes are np.array() on jax arrays in the end to be picklable, -#jax.scipy.optimize not usable in qopt workflow (?) -class LeastSquaresOptimizerJAX(OptimizerJAX): - """See docstring of class w/o JAX.""" - - def __init__( - self, - system_simulator: Optional[simulator.SimulatorJAX] = None, - termination_cond: Optional[Dict] = None, - save_intermediary_steps: bool = True, - method: str = 'trf', - bounds: Union[np.ndarray, jnp.array, List, None] = None, - use_jacobian_function=True, - cost_func_weights: Optional[Sequence[float]] = None, - store_optimizer: bool = False, - x_scale = 1.): - super().__init__(system_simulator=system_simulator, - termination_cond=termination_cond, - save_intermediary_steps=save_intermediary_steps, - cost_func_weights=cost_func_weights, - use_jacobian_function=use_jacobian_function, - store_optimizer=store_optimizer) - self.method = method - self.bounds = bounds - self.x_scale = x_scale - - def run_optimization(self, - initial_control_amplitudes: Union[np.array,jnp.array], - verbose: int = 0 - ) -> optimization_data.OptimizationResult: - """See base class. """ - super().prepare_optimization( - initial_optimization_parameters=initial_control_amplitudes) - - if self.use_jacobian_function: - jac = super().cost_jacobian_wrapper - else: - jac = '2-point' - - try: - result = scipy.optimize.least_squares( - fun=super().cost_func_wrapper, - x0=initial_control_amplitudes.T.flatten(), - jac=jac, - bounds=self.bounds, - method=self.method, - ftol=self.termination_conditions["min_cost_gain"], - xtol=self.termination_conditions["min_amplitude_change"], - gtol=self.termination_conditions["min_gradient_norm"], - max_nfev=self.termination_conditions["max_iterations"], - verbose=verbose, - x_scale=self.x_scale - ) - - if self.system_simulator.stats is not None: - self.system_simulator.stats.end_t_opt = time.time() - - if self.store_optimizer: - storage_opt = self - else: - storage_opt = None - - optim_result = optimization_data.OptimizationResult( - final_cost=np.array(result.fun), - indices=self.system_simulator.cost_indices, - final_parameters=np.array(result.x.reshape( - self.pulse_shape[::-1]).T), - final_grad_norm=np.linalg.norm(np.array(result.grad)), - num_iter=result.nfev, - termination_reason=result.message, - status=result.status, - optimizer=storage_opt, - optim_summary=self.optim_iter_summary, - optimization_stats=self.system_simulator.stats - ) - except WallTimeExceeded: - optim_result = self.write_state_to_result() - - return optim_result - - -class ScalarMinimizingOptimizerJAX(OptimizerJAX): - """See docstring of class w/o JAX.""" - - def __init__( - self, - system_simulator: Optional[simulator.SimulatorJAX] = None, - termination_cond: Optional[Dict] = None, - save_intermediary_steps: bool = True, - method: str = 'L-BFGS-B', - bounds: Union[np.ndarray, List, None] = None, - use_jacobian_function=True, - cost_func_weights: Optional[Sequence[float]] = None, - store_optimizer: bool = False, - ): - super().__init__(system_simulator=system_simulator, - termination_cond=termination_cond, - save_intermediary_steps=save_intermediary_steps, - cost_func_weights=cost_func_weights, - use_jacobian_function=use_jacobian_function, - store_optimizer=store_optimizer) - self.method = method - self.bounds = bounds - - - def cost_func_wrapper(self, optimization_parameters): - """ Evalutes the cost function. - - The total cost function is defined as the sum of cost functions. - - """ - costs = super().cost_func_wrapper(optimization_parameters) - scalar_costs = jnp.sum(costs) - #need to convert devicearray to float (?) - return float(scalar_costs) - - def cost_jacobian_wrapper(self, optimization_parameters): - """ The Jacobian reduced to the gradient. - - The gradient is calculated by summation over the Jacobian along the - function axis, because the total cost function is defined as the sum - of cost functions. - - Returns - ------- - gradient: numpy array, shape (num_t * num_amp) - The gradient of the costs in the 2 norm. - - """ - jac = super().cost_jacobian_wrapper(optimization_parameters) - grad = (jnp.sum(jac, axis=0)) - return np.array(grad,copy=True) - - def run_optimization(self, - initial_control_amplitudes: Union[np.array,jnp.array], - verbose: bool = False - ) -> optimization_data.OptimizationResult: - super().prepare_optimization( - initial_optimization_parameters=initial_control_amplitudes) - - if self.use_jacobian_function: - jac = self.cost_jacobian_wrapper - else: - jac = None - - if self.method == 'L-BFGS-B': - try: - result = scipy.optimize.minimize( - fun=self.cost_func_wrapper, - x0=initial_control_amplitudes.T.flatten(), - jac=jac, - bounds=self.bounds, - method=self.method, - options={ - 'ftol': self.termination_conditions["min_cost_gain"], - 'gtol': self.termination_conditions["min_gradient_norm"], - 'maxiter': self.termination_conditions["max_iterations"], - 'disp': verbose - } - ) - - if self.store_optimizer: - storage_opt = self - else: - storage_opt = None - - optim_result = optimization_data.OptimizationResult( - final_cost=np.array(result.fun), - indices=self.system_simulator.cost_indices, - final_parameters=np.array(result.x.reshape( - self.pulse_shape[::-1]).T), - final_grad_norm=np.linalg.norm(np.array(result.jac)), - num_iter=result.nfev, - termination_reason=result.status, - status=result.status, - optimizer=storage_opt, - optim_summary=self.optim_iter_summary, - optimization_stats=self.system_simulator.stats - ) - except WallTimeExceeded: - optim_result = self.write_state_to_result() - - elif self.method == 'Nelder-Mead': - try: - result = scipy.optimize.minimize( - fun=self.cost_func_wrapper, - x0=initial_control_amplitudes.T.flatten(), - bounds=self.bounds, - method=self.method, - options={ - 'maxiter': self.termination_conditions[ - "max_iterations"]}, - ) - - if self.store_optimizer: - storage_opt = self - else: - storage_opt = None - - optim_result = optimization_data.OptimizationResult( - final_cost=np.array(result.fun), - indices=self.system_simulator.cost_indices, - final_parameters=np.array(result.x.reshape( - self.pulse_shape[::-1]).T), - num_iter=result.nfev, - termination_reason=result.message, - status=result.status, - optimizer=storage_opt, - optim_summary=self.optim_iter_summary, - optimization_stats=self.system_simulator.stats - ) - except WallTimeExceeded: - optim_result = self.write_state_to_result() - - else: - try: - result = scipy.optimize.minimize( - fun=self.cost_func_wrapper, - x0=initial_control_amplitudes.T.flatten(), - bounds=self.bounds, - method=self.method - ) - - optim_result = optimization_data.OptimizationResult( - final_cost=np.array(result.fun), - indices=self.system_simulator.cost_indices, - final_parameters=np.array(result.x.reshape( - self.pulse_shape[::-1]).T), - num_iter=result.nfev, - termination_reason=result.message, - status=result.status, - optimizer=self, - optim_summary=self.optim_iter_summary, - optimization_stats=self.system_simulator.stats - ) - except WallTimeExceeded: - optim_result = self.write_state_to_result() - - if self.system_simulator.stats is not None: - self.system_simulator.stats.end_t_opt = time.time() - - return optim_result diff --git a/qopt/plotting.py b/qopt/plotting.py index 2435bdd..ff39a5e 100644 --- a/qopt/plotting.py +++ b/qopt/plotting.py @@ -112,7 +112,7 @@ def plot_bloch_vector_evolution( states = [ qt.Qobj((prop * initial_state).data) for prop in forward_propagators ] - a = np.empty((3, len(states)),dtype=complex) # for numerical integrity + a = np.empty((3, len(states))) x, y, z = qt.sigmax(), qt.sigmay(), qt.sigmaz() for i, state in enumerate(states): a[:, i] = [qt.expect(x, state), diff --git a/qopt/simulator.py b/qopt/simulator.py index 1331675..f248460 100644 --- a/qopt/simulator.py +++ b/qopt/simulator.py @@ -59,6 +59,7 @@ from qopt.util import needs_refactoring + class Simulator(object): """ The Dynamics class provides the interface for the Optimizer class. @@ -291,155 +292,6 @@ def wrapped_jac_function(self, pulse=None): return total_jac - def wrapped_cost_functions_test(self, pulse=None): - """ - Wraps the cost functions of the fidelity computer. - - This function coordinates the complete simulation including the - application of the transfer function, the execution of the time - slot computer and the evaluation of the actual cost functions. - - Parameters - ---------- - pulse: numpy array optional - If no pulse is specified the cost function is evaluated for the - attribute pulse. - - Returns - ------- - costs: numpy array, shape (n_fun) - Array of costs (i.e. infidelities). - - costs_indices: list of str - Names of the costs. - - """ - if pulse is None: - pulse = self.pulse - - for solver in self.solvers: - solver.set_optimization_parameters(pulse) - - costs = [] - - if self.stats: - self.stats.cost_func_eval_times.append([]) - for i, cost_func in enumerate(self.cost_funcs): - t_start = time.time() - #second argument is frequency [[amp,freq,phase],...,] - if type(cost_func).__name__ == "OperationInfidelity" or type(cost_func).__name__ == "OperationNoiseInfidelity" : - cost = cost_func.costs(pulse[0][1]) - elif type(cost_func).__name__ == "LeakageError" or type(cost_func).__name__ == "LeakageLiouville" or type(cost_func).__name__ == "StateInfidelity2": - cost = cost_func.costs() - else: - raise RuntimeError - t_end = time.time() - self.stats.cost_func_eval_times[-1].append(t_end - t_start) - - # reimplement the block below - costs.append(np.asarray(cost).flatten()) - - """ - I do not understand this block anymore. The cost can be an - array or a scalar, but the scalar can not be reshaped. - if hasattr(cost, "__len__"): - costs.append(cost) - else: - costs.append(cost.reshape(1)) - """ - costs = np.concatenate(costs, axis=0) - else: - for i, cost_func in enumerate(self.cost_funcs): - - if cost_func.__name__ == "OperationInfidelity" or cost_func.__name__ == "OperationNoiseInfidelity" : - cost = cost_func.costs(pulse[0][1]) - elif cost_func.__name__ == "LeakageError" or type(cost_func).__name__ == "LeakageLiouville" or type(cost_func).__name__ == "StateInfidelity2": - cost = cost_func.costs() - else: - raise RuntimeError - - costs.append(np.asarray(cost).flatten()) - """ - if hasattr(cost, "__len__"): - costs.append(cost) - else: - costs.append(cost.reshape(1)) - """ - costs = np.concatenate(costs, axis=0) - - return np.asarray(costs) - - def wrapped_jac_function_test(self, pulse=None): - """ - Wraps the gradient calculation functions of the fidelity computer. - - Parameters - ---------- - pulse: numpy array, optional - shape: (num_t, num_ctrl) If no pulse is specified the cost function - is evaluated for the attribute pulse. - - Returns - ------- - jac: numpy array - Array of gradients of shape (num_t, num_func, num_amp). - """ - - if self.numeric_jacobian: - return self.numeric_gradient(pulse=pulse) - - if pulse is None: - pulse = self.pulse - - for solver in self.solvers: - solver.set_optimization_parameters(pulse) - - jacobians = [] - - record_evaluation_times = bool(self.stats) - - if record_evaluation_times: - self.stats.grad_func_eval_times.append([]) - - for i, cost_func in enumerate(self.cost_funcs): - if record_evaluation_times: - t_start = time.time() - - if type(cost_func).__name__ == "OperationInfidelity" or type(cost_func).__name__ == "OperationNoiseInfidelity" : - jac_u = cost_func.grad(pulse[0][1]) - elif type(cost_func).__name__ == "LeakageError" or type(cost_func).__name__ == "LeakageLiouville" or type(cost_func).__name__ == "StateInfidelity2": - jac_u = cost_func.grad() - else: - raise RuntimeError - - # if the cost function is scalar, an extra dimension is inserted - if len(jac_u.shape) == 2: - jac_u = np.expand_dims(jac_u, axis=1) - - # apply the chain rule to the derivatives - jac_x = cost_func.solver.amplitude_function.derivative_by_chain_rule( - jac_u, cost_func.solver.transfer_function(pulse)) - jac_x_transferred = \ - cost_func.solver.transfer_function.gradient_chain_rule( - jac_x - ) - - if type(cost_func).__name__ == "OperationInfidelity" or type(cost_func).__name__ == "OperationNoiseInfidelity" : - jac_x_transferred[0,0,1] += cost_func.der_freq_test(pulse[0][1])[0,0] - elif type(cost_func).__name__ != "LeakageError" and type(cost_func).__name__ != "LeakageLiouville" and type(cost_func).__name__ != "StateInfidelity2": - raise RuntimeWarning - - jacobians.append(jac_x_transferred) - if record_evaluation_times: - t_end = time.time() - self.stats.grad_func_eval_times[-1].append(t_end - t_start) - - # two dimensional form as required by scipy solvers - total_jac = np.concatenate(jacobians, axis=1) - - return total_jac - - def compare_numeric_to_analytic_gradient( self, pulse: Optional[np.ndarray] = None, delta_eps: float = 1e-8, @@ -517,14 +369,14 @@ def numeric_gradient( central_costs = self.wrapped_cost_functions(pulse=test_pulse) - n_times, n_operators = np.asarray(test_pulse).shape + n_times, n_operators = test_pulse.shape n_cost_funcs = len(central_costs) gradients = np.zeros((n_times, n_cost_funcs, n_operators)) for n_time in range(n_times): for n_operator in range(n_operators): - delta = np.zeros_like(test_pulse) + delta = np.zeros_like(test_pulse, dtype=float) delta[n_time, n_operator] = delta_eps fwd_val = self.wrapped_cost_functions(test_pulse + delta) if symmetric: @@ -536,320 +388,3 @@ def numeric_gradient( (fwd_val - central_costs) / delta_eps return gradients - - -############################################################################### - -try: - import jax.numpy as jnp - _HAS_JAX = True -except ImportError: - from unittest import mock - jnp = mock.Mock() - _HAS_JAX = False - -class SimulatorJAX(Simulator): - """See docstring of class w/o JAX. Requires solver with JAX""" - - def __init__( - self, - solvers: Optional[Sequence[solver_algorithms.SolverJAX]], - cost_funcs: Optional[Sequence[cost_functions.CostFunction]], - optimization_parameters=None, - num_ctrl=None, - times=None, - num_times=None, - record_performance_statistics: bool = True, - numeric_jacobian: bool = False - ): - if not _HAS_JAX: - raise ImportError("JAX not available") - super().__init__(solvers,cost_funcs,optimization_parameters,num_ctrl, - times,num_times,record_performance_statistics, - numeric_jacobian) - - def wrapped_cost_functions(self, pulse=None): - """ - Wraps the cost functions of the fidelity computer. - - This function coordinates the complete simulation including the - application of the transfer function, the execution of the time - slot computer and the evaluation of the actual cost functions. - - Parameters - ---------- - pulse: (j)np array optional - If no pulse is specified the cost function is evaluated for the - attribute pulse. - - Returns - ------- - costs: jnp array, shape (n_fun) - Array of costs (i.e. infidelities). - - costs_indices: list of str - Names of the costs. - - """ - if pulse is None: - pulse = self.pulse - - for solver in self.solvers: - solver.set_optimization_parameters(pulse) - - costs = [] - - if self.stats: - self.stats.cost_func_eval_times.append([]) - for i, cost_func in enumerate(self.cost_funcs): - t_start = time.time() - cost = cost_func.costs() - t_end = time.time() - self.stats.cost_func_eval_times[-1].append(t_end - t_start) - - # reimplement the block below - costs.append(jnp.asarray(cost).flatten()) - - """ - I do not understand this block anymore. The cost can be an - array or a scalar, but the scalar can not be reshaped. - if hasattr(cost, "__len__"): - costs.append(cost) - else: - costs.append(cost.reshape(1)) - """ - costs = jnp.concatenate(costs, axis=0) - else: - for i, cost_func in enumerate(self.cost_funcs): - cost = cost_func.costs() - - costs.append(jnp.asarray(cost).flatten()) - """ - if hasattr(cost, "__len__"): - costs.append(cost) - else: - costs.append(cost.reshape(1)) - """ - costs = jnp.concatenate(costs, axis=0) - - return jnp.asarray(costs) - - def wrapped_jac_function(self, pulse=None): - """ - Wraps the gradient calculation functions of the fidelity computer. - - Parameters - ---------- - pulse: (j)np array, optional - shape: (num_t, num_ctrl) If no pulse is specified the cost function - is evaluated for the attribute pulse. - - Returns - ------- - jac: jnp array - Array of gradients of shape (num_t, num_func, num_amp). - """ - - if self.numeric_jacobian: - return self.numeric_gradient(pulse=pulse) - - if pulse is None: - pulse = self.pulse - - for solver in self.solvers: - solver.set_optimization_parameters(pulse) - - jacobians = [] - - record_evaluation_times = bool(self.stats) - - if record_evaluation_times: - self.stats.grad_func_eval_times.append([]) - - for i, cost_func in enumerate(self.cost_funcs): - if record_evaluation_times: - t_start = time.time() - jac_u = cost_func.grad() - - # if the cost function is scalar, an extra dimension is inserted - if len(jac_u.shape) == 2: - jac_u = jnp.expand_dims(jac_u, axis=1) - - # apply the chain rule to the derivatives - jac_x = cost_func.solver.amplitude_function.derivative_by_chain_rule( - jac_u, cost_func.solver.transfer_function(pulse)) - jac_x_transferred = \ - cost_func.solver.transfer_function.gradient_chain_rule( - jac_x - ) - jacobians.append(jac_x_transferred) - if record_evaluation_times: - t_end = time.time() - self.stats.grad_func_eval_times[-1].append(t_end - t_start) - - # two dimensional form as required by scipy solvers - total_jac = jnp.concatenate(jacobians, axis=1) - - return total_jac - -############################################################################### - -class SimulatorJAXSpecial(SimulatorJAX): - """ - - - """ - - def __init__( - self, - solvers: Optional[Sequence[solver_algorithms.Solver]], - cost_funcs: Optional[Sequence[cost_functions.CostFunction]], - optimization_parameters=None, - num_ctrl=None, - times=None, - num_times=None, - record_performance_statistics: bool = True, - numeric_jacobian: bool = False - ): - super().__init__(solvers,cost_funcs,optimization_parameters,num_ctrl,times,num_times,record_performance_statistics,numeric_jacobian) - - def wrapped_cost_functions(self, pulse=None): - """ - Wraps the cost functions of the fidelity computer. - - This function coordinates the complete simulation including the - application of the transfer function, the execution of the time - slot computer and the evaluation of the actual cost functions. - - Parameters - ---------- - pulse: numpy array optional - If no pulse is specified the cost function is evaluated for the - attribute pulse. - - Returns - ------- - costs: numpy array, shape (n_fun) - Array of costs (i.e. infidelities). - - costs_indices: list of str - Names of the costs. - - """ - if pulse is None: - pulse = self.pulse - - for solver in self.solvers: - solver.set_optimization_parameters(pulse) - - costs = [] - - if self.stats: - self.stats.cost_func_eval_times.append([]) - for i, cost_func in enumerate(self.cost_funcs): - t_start = time.time() - if type(cost_func).__name__ == "OperationInfidelityJAXSpecial" or type(cost_func).__name__ == "OperationNoiseInfidelityJAXSpecial" : - cost = cost_func.costs(pulse[0][-1]) - else: - raise RuntimeError - - t_end = time.time() - self.stats.cost_func_eval_times[-1].append(t_end - t_start) - - # reimplement the block below - costs.append(jnp.asarray(cost).flatten()) - - """ - I do not understand this block anymore. The cost can be an - array or a scalar, but the scalar can not be reshaped. - if hasattr(cost, "__len__"): - costs.append(cost) - else: - costs.append(cost.reshape(1)) - """ - costs = jnp.concatenate(costs, axis=0) - else: - for i, cost_func in enumerate(self.cost_funcs): - if type(cost_func).__name__ == "OperationInfidelityJAXSpecial" or type(cost_func).__name__ == "OperationNoiseInfidelityJAXSpecial" : - cost = cost_func.costs(pulse[0][-1]) - else: - raise RuntimeError - - costs.append(jnp.asarray(cost).flatten()) - """ - if hasattr(cost, "__len__"): - costs.append(cost) - else: - costs.append(cost.reshape(1)) - """ - costs = jnp.concatenate(costs, axis=0) - - return jnp.asarray(costs) - - def wrapped_jac_function(self, pulse=None): - """ - Wraps the gradient calculation functions of the fidelity computer. - - Parameters - ---------- - pulse: numpy array, optional - shape: (num_t, num_ctrl) If no pulse is specified the cost function - is evaluated for the attribute pulse. - - Returns - ------- - jac: numpy array - Array of gradients of shape (num_t, num_func, num_amp). - """ - - if self.numeric_jacobian: - return self.numeric_gradient(pulse=pulse) - - if pulse is None: - pulse = self.pulse - - for solver in self.solvers: - solver.set_optimization_parameters(pulse) - - jacobians = [] - - record_evaluation_times = bool(self.stats) - - if record_evaluation_times: - self.stats.grad_func_eval_times.append([]) - - for i, cost_func in enumerate(self.cost_funcs): - if record_evaluation_times: - t_start = time.time() - - if type(cost_func).__name__ == "OperationInfidelityJAXSpecial": - jac_u = cost_func.grad(pulse[0][-1]) - else: - raise RuntimeError - - - # if the cost function is scalar, an extra dimension is inserted - if len(jac_u.shape) == 2: - jac_u = jnp.expand_dims(jac_u, axis=1) - - # apply the chain rule to the derivatives - jac_x = cost_func.solver.amplitude_function.derivative_by_chain_rule( - jac_u, cost_func.solver.transfer_function(pulse)) - jac_x_transferred = \ - cost_func.solver.transfer_function.gradient_chain_rule( - jac_x - ) - - if type(cost_func).__name__ == "OperationInfidelityJAXSpecial": - ### jac_x_transferred=jac_x_transferred.at[0,0,-1].set(jac_x_transferred.at[0,0,-1] + cost_func.der_time_fact(pulse[0][-1])) - jac_x_transferred[0,0,-1] += cost_func.der_time_fact(pulse[0][-1]) - - jacobians.append(jac_x_transferred) - if record_evaluation_times: - t_end = time.time() - self.stats.grad_func_eval_times[-1].append(t_end - t_start) - - # two dimensional form as required by scipy solvers - total_jac = jnp.concatenate(jacobians, axis=1) - - return total_jac diff --git a/qopt/solver_algorithms.py b/qopt/solver_algorithms.py index 7f3eaf0..5ce5f68 100644 --- a/qopt/solver_algorithms.py +++ b/qopt/solver_algorithms.py @@ -78,7 +78,6 @@ from qopt.amplitude_functions import AmplitudeFunction, IdentityAmpFunc from qopt.util import needs_refactoring -from jax import grad, jit class Solver(ABC): r""" @@ -495,8 +494,7 @@ def forward_propagators(self) -> List[q_mat.OperatorMatrix]: """ if self._fwd_prop is None: - # self._compute_forward_propagation() - jit(self._compute_forward_propagation)() + self._compute_forward_propagation() return self._fwd_prop @property @@ -512,8 +510,7 @@ def frechet_deriv_propagators(self) -> List[List[q_mat.OperatorMatrix]]: """ if self._derivative_prop is None: - # self._compute_propagation_derivatives() - jit(self._compute_propagation_derivatives)() + self._compute_propagation_derivatives() return self._derivative_prop @property @@ -873,7 +870,7 @@ def __init__(self, self.frechet_deriv_approx_method = frechet_deriv_approx_method self._dyn_gen = None - + def set_optimization_parameters(self, y: np.array) -> None: """See base class. """ if not np.array_equal(self._opt_pars, y): @@ -1213,8 +1210,7 @@ def forward_propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: """ if self._fwd_prop_noise is None: - # self._compute_forward_propagation() - jit(self._compute_forward_propagation)() + self._compute_forward_propagation() return self._fwd_prop_noise @property @@ -1232,8 +1228,7 @@ def frechet_deriv_propagators_noise(self) \ """ if self._derivative_prop_noise is None: - # self._compute_propagation_derivatives() - jit(self._compute_propagation_derivatives)() + self._compute_propagation_derivatives() return self._derivative_prop_noise @property @@ -1530,7 +1525,7 @@ def noise_amplitude_function(noise_samples: np.array, Parameters ---------- - noise_samples: np.array + noise_samples: np.array, shape() Noise samples calculated by the noise trace generator. transferred_parameters: np.array @@ -1540,7 +1535,11 @@ def noise_amplitude_function(noise_samples: np.array, Control amplitudes. """ - noise_amplitudes = np.zeros((noise_samples.shape[0],noise_samples.shape[1],control_amplitudes.shape[1]), dtype=complex) + # noise_amplitudes = np.zeros_like(noise_samples, dtype=complex) + noise_amplitudes = np.zeros( + (noise_samples.shape[0], noise_samples.shape[1], + control_amplitudes.shape[1]), dtype=complex) + # complex values were requested. for trace_num in range(noise_samples.shape[1]): noise_amplitudes[:, trace_num, :] = self.amplitude_function( @@ -1873,7 +1872,6 @@ def reset_cached_propagators(self): self._diss_sup_op = None self._diss_sup_op_deriv = None - def _calc_diss_sup_op(self) -> List[q_mat.OperatorMatrix]: r""" Calculates the dissipative super operator as described in the class @@ -2197,1377 +2195,3 @@ def _compute_propagation(self): self._fwd.append(self._prop[t] * self._fwd[t]) self.prop_calculated = True - - -############################################################################### - -try: - import jax.numpy as jnp - from jax import jit, vmap - import jax - _HAS_JAX = True -except ImportError: - from unittest import mock - jit = mock.Mock() - jnp = mock.Mock() - vmap = mock.Mock() - jax = mock.Mock() - _HAS_JAX = False - -def _compute_propagation_expm_both_loop(transferred_time,dyn_gen, - derivative_directions): - """Internal loop of exponentiation of propagator and derivative""" - return jax.scipy.linalg.expm_frechet( - dyn_gen*transferred_time, - derivative_directions*transferred_time, - compute_expm=True) - -#from profiling with simple optimization example -#(could be different for complex problems): -#here all the runtime is (probably) in, but no faster way seems available -@jit -def _compute_propagation_expm_both(transferred_time,dyn_gen, - derivative_directions): - """Exponentiation of propagator and derivative, n_ctrl&n_timesteps on - first two axes - """ - return vmap(vmap(_compute_propagation_expm_both_loop,in_axes=(0,0,None)), - in_axes=(None,None,0))( - transferred_time,dyn_gen,derivative_directions) - -@jit -def _compute_propagation_expm_both_lind(transferred_time,dyn_gen, - derivative_directions): - """Exponentiation of propagator and derivative in super-operator formalism, - n_ctrl&n_timesteps on first two axes - """ - return vmap(vmap(_compute_propagation_expm_both_loop,in_axes=(0,0,0)), - in_axes=(None,None,1))( - transferred_time,dyn_gen,derivative_directions) - -@jit -def _compute_propagation_expm_both_noise(transferred_time,dyn_gen_noise, - derivative_directions): - """Exponentiation of propagator and derivative for Monte-Carlo, - n_traces on first axis - """ - return vmap(_compute_propagation_expm_both,in_axes=(None,0,None))( - transferred_time,dyn_gen_noise,derivative_directions) - -def _compute_propagation_expm_loop(transferred_time,dyn_gen): - """Internal loop of exponentiation of propagator""" - return jax.scipy.linalg.expm(dyn_gen*transferred_time) - -#if no derivatives runtime also here -@jit -def _compute_propagation_expm(transferred_time,dyn_gen): - """Exponentiation of propagator, n_ctrl&n_timesteps on first two axes""" - return vmap(_compute_propagation_expm_loop,in_axes=(0,0))( - transferred_time,dyn_gen) - -@jit -def _compute_propagation_expm_noise(transferred_time,dyn_gen_noise): - """Exponentiation of propagator for Monte-Carlo, n_traces on first axis""" - return vmap(_compute_propagation_expm,in_axes=(None,0))( - transferred_time,dyn_gen_noise) - -def _cumprod_loop(res,el): - """Internal loop of cumulative product of propagators""" - res = jnp.dot(el,res) - return res,res - -@jit -def _cumprod(init,prop): - """Cumulative product of propagators of single timesteps""" - _, cum_prod = jax.lax.scan(_cumprod_loop,init,prop) - return cum_prod - -@jit -def _cumprod_noise(init,prop_noise): - """Cumulative product of propagators of single timesteps for Monte-Carlo""" - return vmap(_cumprod,in_axes=(None,0))(init,prop_noise) - -def _cumprod_reversed_loop(res,el): - """Internal loop of reversed cumulative product of propagators""" - res = jnp.dot(res,el) - return res,res - -@jit -def _cumprod_reversed(init,prop): - """Reversed cumulative product of propagators of single timesteps""" - _, cum_prod = jax.lax.scan(_cumprod_reversed_loop,init,prop) - return cum_prod - -@jit -def _cumprod_reversed_noise(init,prop_noise): - """Reversed cumulative product of propagators of single timesteps for MC""" - return vmap(_cumprod_reversed,in_axes=(None,0))(init,prop_noise) - - -class SolverJAX(Solver): - """See docstring of class w/o JAX.""" - def __init__( - self, - h_ctrl: List[q_mat.OperatorMatrix], - h_drift: List[q_mat.OperatorMatrix], - tau: np.array, - initial_state: q_mat.OperatorMatrix = None, - opt_pars: Optional[Union[jnp.ndarray,np.ndarray]] = None, - ctrl_amps: Optional[Union[jnp.ndarray,np.ndarray]] = None, - filter_function_h_n: Union[ - Callable, List[List], None] = None, - filter_function_basis: Optional[basis.Basis] = None, - filter_function_n_coeffs_deriv: Optional[ - Callable[[np.ndarray], np.ndarray]] = None, - exponential_method: Optional[str] = None, - is_skew_hermitian: bool = True, - transfer_function: Optional[TransferFunction] = None, - amplitude_function: Optional[AmplitudeFunction] = None, - paranoia_level: int = 2 - ): - if not _HAS_JAX: - raise ImportError("JAX not available") - super().__init__( - h_ctrl, - h_drift, - tau, - initial_state, - opt_pars, - ctrl_amps, - filter_function_h_n, - filter_function_basis, - filter_function_n_coeffs_deriv, - exponential_method, - is_skew_hermitian, - transfer_function, - amplitude_function, - paranoia_level) - - if type(h_drift) in [matrix.DenseOperator, matrix.SparseOperator, - matrix.DenseOperatorJAX]: - self._h_drift_jnp = jnp.expand_dims(h_drift.data,0) - self.h_drift = [h_drift, ] * self.transfer_function.num_x - elif len(h_drift) == 1: - self._h_drift_jnp = jnp.expand_dims(h_drift[0].data,0) - self.h_drift = h_drift * self.transfer_function.num_x - else: - self._h_drift_jnp = jnp.array([h.data for h in h_drift]) - self.h_drift = h_drift - - self._h_ctrl_jnp = jnp.array([h.data for h in h_ctrl]) - self._transferred_time_jnp = jnp.array(self.transferred_time) - if initial_state is None: - dim = self.h_ctrl[0].shape[0] - self.initial_state = matrix.DenseOperatorJAX(jnp.eye(dim)) - else: - self.initial_state = matrix.DenseOperatorJAX(initial_state) - self._initial_state_jnp = self.initial_state.data - - self._prop_jnp = None - self._reversed_prop_jnp = None - self._fwd_prop_jnp = None - self._derivative_prop_jnp = None - - - def set_optimization_parameters(self, y: Union[jnp.ndarray,np.ndarray] - ) -> None: - """ - Set the control amplitudes. - - All computation flags are set to false. - - The new control amplitudes u are calculated: - u: np.array, shape (num_t, num_ctrl) - - Parameters - ---------- - y: Union[jnp.ndarray,np.ndarray], shape (num_x, num_ctrl) - Raw optimization parameters. - - """ - - if jnp.array_equal(self._opt_pars, y): - return - else: - #previously with copy (?) - self._opt_pars = y - - if self.transfer_function is not None: - self.transferred_parameters = self.transfer_function(y) - else: - #previously with copy (?) - self.transferred_parameters = y - - if self.amplitude_function is not None: - u = self.amplitude_function( - self.transferred_parameters) - else: - u = self.transferred_parameters - - if len(u.shape) != 2: - raise ValueError('The new control amplitudes must have two ' - 'dimensions! ' - '(time, control operator)') - - if u.shape[0] != len(self.transferred_time): - raise ValueError('The new control amplitudes do not have the ' - 'correct number of entries on the time axis!'+ - str(u.shape[0])+" "+str(len(self.transferred_time))) - - if u.shape[1] != len(self.h_ctrl): - raise ValueError('The new control amplitudes do not have the ' - 'correnct number of entries on the control axis!') - - self._ctrl_amps = u - self.reset_cached_propagators() - - def reset_cached_propagators(self): - """ Resets all cached propagators. """ - - self._prop = None #perhaps nonexistent? - self._fwd_prop = None - self._derivative_prop = None - self._reversed_prop = None - self.pulse_sequence = None - - self._prop_jnp = None - self._reversed_prop_jnp = None - self._fwd_prop_jnp = None - self._derivative_prop_jnp = None - - @property - def forward_propagators_jnp(self) -> jnp.ndarray: - - if self._fwd_prop_jnp is None: - self._compute_forward_propagation_jnp() - return self._fwd_prop_jnp - - @property - def reversed_propagators_jnp(self) -> jnp.ndarray: - - if self._reversed_prop_jnp is None: - self._compute_reversed_propagation_jnp() - return self._reversed_prop_jnp - - @property - def frechet_deriv_propagators_jnp(self) -> jnp.ndarray: - - if self._derivative_prop_jnp is None: - self._compute_propagation_derivatives_jnp() - return self._derivative_prop_jnp - - @abstractmethod - def _compute_propagation(self) -> None: - if self._prop_jnp is None: - self._compute_propagation_jnp() - - self._prop = [matrix.DenseOperatorJAX(p) for p in self._prop_jnp] - - def _compute_forward_propagation(self) -> None: - """Computes the forward propagators. """ - - - self._fwd_prop = [matrix.DenseOperatorJAX(p) - for p in self.forward_propagators_jnp] - - def _compute_reversed_propagation(self) -> None: - """Compute the reversed propagation. """ - - self._reversed_prop = [matrix.DenseOperatorJAX(p) - for p in self.reversed_propagators_jnp] - - @abstractmethod - def _compute_propagation_derivatives(self) -> None: - - if self._derivative_prop_jnp is None: - self._compute_propagation_derivatives_jnp() - - self._derivative_prop = [[matrix.DenseOperatorJAX(p) for p in der_t] - for der_t in self._derivative_prop_jnp] - - def _compute_forward_propagation_jnp(self) -> None: - - if self._prop_jnp is None: - self._compute_propagation_jnp() - - self._fwd_prop_jnp = jnp.append( - jnp.expand_dims(self._initial_state_jnp.copy(),0), - _cumprod(self._initial_state_jnp.copy(),self._prop_jnp),axis=0) - - def _compute_reversed_propagation_jnp(self) -> None: - - if self._prop_jnp is None: - self._compute_propagation_jnp() - - _initial_state_rev_jnp = jnp.eye(self._prop_jnp[0].shape[0]) * (1+0j) - - self._reversed_prop_jnp = jnp.append( - jnp.expand_dims(_initial_state_rev_jnp.copy(),0), - _cumprod_reversed(_initial_state_rev_jnp.copy(), - self._prop_jnp[::-1]),axis=0) - - @abstractmethod - def _compute_propagation_jnp(self) -> None: - """ - Computes the propagators. Must set self._prop! - - Raises - ------ - ValueError - If the control amplitudes are not set. - - """ - if self._ctrl_amps is None: - raise ValueError("The control amplitudes must be set to calculate " - "the propagation!") - - @abstractmethod - def _compute_propagation_derivatives_jnp(self) -> None: - """Compute the derivatives of the propagators by the control - amplitudes. - """ - pass - - def _calc_error(self): - - if self._dyn_gen is None: - self._dyn_gen = self._compute_dyn_gen() - - return (self._transferred_time_jnp[0])**2/2*jnp.linalg.norm( - self._dyn_gen[1:]@self._dyn_gen[:-1] - -self._dyn_gen[:-1]@self._dyn_gen[1:],axis=(1,2)) - - -class SchroedingerSolverJAX(SolverJAX): - """See docstring of class w/o JAX.""" - - def __init__(self, - h_drift: List[q_mat.OperatorMatrix], - h_ctrl: List[q_mat.OperatorMatrix], - tau: Union[jnp.array,np.array], - initial_state: q_mat.OperatorMatrix = None, - ctrl_amps: Optional[Union[jnp.array,np.array]] = None, - calculate_propagator_derivatives: bool = True, - filter_function_h_n: Union[ - Callable, List[List], None] = None, - filter_function_basis: Optional[basis.Basis] = None, - filter_function_n_coeffs_deriv: Optional[ - Callable[[np.ndarray], Union[jnp.array,np.array]]] = None, - exponential_method: Optional[str] = None, - frechet_deriv_approx_method: Optional[str] = None, - is_skew_hermitian: bool = True, - transfer_function: Optional[TransferFunction] = None, - amplitude_function: Optional[AmplitudeFunction] = None): - super().__init__( - h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, - tau=tau, ctrl_amps=ctrl_amps, - filter_function_h_n=filter_function_h_n, - filter_function_basis=filter_function_basis, - filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, - exponential_method=exponential_method, - is_skew_hermitian=is_skew_hermitian, - transfer_function=transfer_function, - amplitude_function=amplitude_function - ) - - - if self.exponential_method != "Frechet": - print("Other than Frechet ignored") - - self.id_text = 'ALL' - self.cache_text = 'Save' - self.calculate_propagator_derivatives = \ - calculate_propagator_derivatives - self.frechet_deriv_approx_method = frechet_deriv_approx_method - - self._dyn_gen = None - - - def set_optimization_parameters(self, y: Union[np.ndarray,jnp.ndarray]) -> None: - """See base class. """ - if not jnp.array_equal(self._opt_pars, y): - self.reset_cached_propagators() - super().set_optimization_parameters(y) - - def reset_cached_propagators(self): - """See base class. """ - self._dyn_gen = None - super().reset_cached_propagators() - - def _compute_dyn_gen(self) -> jnp.ndarray: - """ - Computes the dynamics generators. - - Returns - ------- - dyn_gen: List[ControlMatrix], len num_t - This is basically the total Hamiltonian. - - """ - - self._dyn_gen = -1j*(self._h_drift_jnp+jnp.einsum("ij,jkl->ikl", - self._ctrl_amps, - self._h_ctrl_jnp)) - #internally now only jax tensors? - return self._dyn_gen - - def _compute_derivative_directions( - self) -> jnp.ndarray: - """ - The directions of the frechet derivatives are the control operators. - - No deep copy is required because the result is not used for in-place - operations. - - """ - # The list is multiplied (copied by reference) because the elements - # will not be manipulated in place. (only as copy) - return -1j*jnp.expand_dims(self._h_ctrl_jnp,0) - - def _compute_propagation(self) -> None: - super()._compute_propagation() - - def _compute_propagation_derivatives(self) -> None: - super()._compute_propagation_derivatives() - - def _compute_propagation_jnp( - self, calculate_propagator_derivatives: Optional[bool] = None) \ - -> None: - """See base class. """ - super()._compute_propagation_jnp() - - if self._dyn_gen is None: - self._dyn_gen = self._compute_dyn_gen() - - if calculate_propagator_derivatives is None: - calculate_propagator_derivatives = \ - self.calculate_propagator_derivatives - - if calculate_propagator_derivatives: - derivative_directions = self._compute_derivative_directions() - - #TODO: behavior is not exactly reproduced as now - # derivative_directions[0] is taken; however only relevant for - #LindbladSolver (in special cases) (?) - self._prop_jnp,self._derivative_prop_jnp = \ - _compute_propagation_expm_both( - self._transferred_time_jnp, - self._dyn_gen,derivative_directions[0]) - self._prop_jnp = self._prop_jnp[0,:,:,:] - - else: - self._prop_jnp = _compute_propagation_expm( - self._transferred_time_jnp,self._dyn_gen) - - - def _compute_propagation_derivatives_jnp(self) -> None: - """ - Computes the frechet derivatives of the propagators. - - The derivatives are not returned but cached. Since the function is only - called when no derivatives are cached, the approximation is - prioritised. - """ - if not self.frechet_deriv_approx_method: - self._compute_propagation_jnp( - calculate_propagator_derivatives=True) - elif self.frechet_deriv_approx_method == 'grape': - if self._prop_jnp is None: - self._compute_propagation_jnp( - calculate_propagator_derivatives=False) - - self._derivative_prop_jnp = jnp.swapaxes( - jnp.expand_dims(self._transferred_time_jnp,(1,2,3))* - self._compute_derivative_directions()@ - jnp.expand_dims(self._prop_jnp,axis=1),0,1) - - else: - raise ValueError('Unknown gradient derivative approximation ' - 'method:' - + str(self.frechet_deriv_approx_method)) - - -class SchroedingerSMonteCarloJAX(SchroedingerSolverJAX): - """See docstring of class w/o JAX.""" - def __init__( - self, h_drift: List[q_mat.OperatorMatrix], - h_ctrl: List[q_mat.OperatorMatrix], - tau: Union[jnp.array,np.array], - h_noise: List[q_mat.OperatorMatrix], - noise_trace_generator: - Optional[noise.NoiseTraceGenerator], - initial_state: q_mat.OperatorMatrix = None, - ctrl_amps: Optional[Union[jnp.array,np.array]] = None, - calculate_propagator_derivatives: bool = False, - processes: Optional[int] = 1, - filter_function_h_n: Union[ - Callable, List[List], None] = None, - filter_function_basis: Optional[basis.Basis] = None, - filter_function_n_coeffs_deriv: Optional[ - Callable[[np.ndarray], np.ndarray]] = None, - exponential_method: Optional[str] = None, - frechet_deriv_approx_method: Optional[str] = None, - is_skew_hermitian: bool = True, - transfer_function: Optional[TransferFunction] = None, - amplitude_function: Optional[AmplitudeFunction] = None, - noise_amplitude_function: Optional[Callable[ - [np.array, np.array, np.array, - np.array], np.array]] = None - ): - - super().__init__( - h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, - tau=tau, ctrl_amps=ctrl_amps, - filter_function_h_n=filter_function_h_n, - filter_function_basis=filter_function_basis, - filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, - exponential_method=exponential_method, - calculate_propagator_derivatives=calculate_propagator_derivatives, - frechet_deriv_approx_method=frechet_deriv_approx_method, - is_skew_hermitian=is_skew_hermitian, - transfer_function=transfer_function, - amplitude_function=amplitude_function) - - self.h_noise = h_noise - self._h_noise_jnp = jnp.array([h.data for h in h_noise]) - self.noise_trace_generator = noise_trace_generator - self.noise_amplitude_function = noise_amplitude_function - self.processes = processes - - self._dyn_gen_noise = None - self._prop_noise = None - self._derivative_prop_noise = None - self._fwd_prop_noise = None - self._reversed_prop_noise = None - - self._prop_noise_jnp = None - self._derivative_prop_noise_jnp = None - self._fwd_prop_noise_jnp = None - self._reversed_prop_noise_jnp = None - - def set_optimization_parameters(self, - y: Union[np.ndarray,jnp.ndarray] - ) -> None: - """See base class. """ - if not jnp.array_equal(self._opt_pars, y): - self.reset_cached_propagators() - super().set_optimization_parameters(y) - - def reset_cached_propagators(self): - """See base class. """ - super().reset_cached_propagators() - self._dyn_gen_noise = None - self._prop_noise = None - self._prop_noise_jnp = None - self._derivative_prop_noise = None - self._derivative_prop_noise_jnp = None - self._fwd_prop_noise = None - self._reversed_prop_noise = None - self._fwd_prop_noise_jnp = None - self._reversed_prop_noise_jnp = None - - - @property - def propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: - """ - Returns the propagators of the system for each noise trace and - calculates them if necessary. - - Returns - ------- - propagators_noise: List[List[ControlMatrix]], - shape [[] * num_t] * num_noise_traces - Propagators of the system for each noise trace. - - """ - if self._prop_noise is None: - self._compute_propagation() - return self._prop_noise - - @property - def propagators_noise_jnp(self) -> jnp.ndarray: - """See docstring of function without _jnp. Now as jnp-array.""" - if self._prop_noise_jnp is None: - self._compute_propagation_jnp() - return self._prop_noise_jnp - - @property - def forward_propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: - """ - Returns the forward propagation of the initial state for every time - slice and every noise trace and calculate it if necessary. If the - initial state is the identity matrix, then the cumulative propagators - are given. The element forward_propagators[k][i] propagates a state by - the first i time steps under the kth noise trace, if the initial state - is the identity matrix. - - Returns - ------- - forward_propagation:List[List[ControlMatrix]], - shape [[] * (num_t + 1)] * num_noise_traces - Propagation of the initial state of the system. fwd[0] gives the - initial state itself. - - """ - if self._fwd_prop_noise is None: - self._compute_forward_propagation() - return self._fwd_prop_noise - - @property - def forward_propagators_noise_jnp(self) -> jnp.ndarray: - """See docstring of function without _jnp. Now as jnp-array.""" - if self._fwd_prop_noise_jnp is None: - self._compute_forward_propagation_jnp() - return self._fwd_prop_noise_jnp - - @property - def frechet_deriv_propagators_noise(self) \ - -> List[List[List[q_mat.OperatorMatrix]]]: - """ - Returns the frechet derivatives of the propagators with respect to the - control amplitudes for each noise trace. - - Returns - ------- - derivative_prop_noise: List[List[List[ControlMatrix]]], - shape [[[] * num_t] * num_ctrl] * num_noise_traces - Frechet derivatives of the propagators by the control amplitudes. - - """ - if self._derivative_prop_noise is None: - self._compute_propagation_derivatives() - return self._derivative_prop_noise - - @property - def frechet_deriv_propagators_noise_jnp(self) -> jnp.ndarray: - """See docstring of function without _jnp. Now as jnp-array.""" - if self._derivative_prop_noise_jnp is None: - self._compute_propagation_derivatives_jnp() - return self._derivative_prop_noise_jnp - - @property - def reversed_propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: - """ - Returns the reversed propagation of the initial state for every noise - trace and calculate it if necessary. If the initial state is the - identity matrix, then the reversed cumulative propagators are given. - The element forward_propagators[k][i] propagates a state by the first i - time steps under the kth noise trace, if the initial state is the - identity matrix. - - Returns - ------- - reversed_propagation_noise: List[List[ControlMatrix]], - shape [[] * (num_t + 1)] * num_noise_traces - Propagation of the initial state of the system. reversed[k][0] - gives the initial state itself. - - """ - if self._reversed_prop_noise is None: - self._compute_reversed_propagation() - return self._reversed_prop_noise - - @property - def reversed_propagators_noise_jnp(self) -> jnp.ndarray: - """See docstring of function without _jnp. Now as jnp-array.""" - if self._reversed_prop_noise_jnp is None: - self._compute_reversed_propagation_jnp() - return self._reversed_prop_noise_jnp - - def _compute_dyn_gen_noise(self) -> jnp.ndarray: - """ - Computes the dynamics generators for the perturbed and unperturbed - Schroedinger equation. - - Returns - ------- - dyn_gen_noise: List[List[q_mat.ControlMatrix]], - shape [[] * num_t] * num_noise_traces - Dynamics generators for each noise trace. - - """ - # compute the generators of the unperturbed dynamics - self._dyn_gen = super()._compute_dyn_gen() - - # compute the generators for the noise traces. - # n_noise_traces = self.noise_trace_generator.n_traces - - noise_samples = jnp.array(self.noise_trace_generator.noise_samples) - # we transpose, so we iterate over the time last - noise_samples = jnp.transpose(noise_samples, (2, 1, 0)) - - if self.noise_amplitude_function: - noise_samples = self.noise_amplitude_function( - noise_samples=noise_samples, - optimization_parameters=self._opt_pars, - transferred_parameters=self.transferred_parameters, - control_amplitudes=self._ctrl_amps - ) - - # i: n_samples_per_trace, j: n_traces, k: n_noise_ops, - # l: first dim of ham, m: second dim of ham - self._dyn_gen_noise = jnp.expand_dims(self._dyn_gen,axis=0) \ - - 1j*(jnp.einsum("ijk,klm->jilm",noise_samples,self._h_noise_jnp)) - - # -> (n_traces,n_samples_per_trace || t ?,d,d) - return self._dyn_gen_noise - - def _compute_propagation(self) -> None: - """ - Computes the propagators for the perturbed Schroedinger equation and - the derivatives on demand. - - Parameters - ---------- - calculate_propagator_derivatives: bool, optional - Calculate the derivatives of the propagators with respect to the - control amplitudes if true. - - """ - super()._compute_propagation() - self._prop_noise = [[matrix.DenseOperatorJAX(p) for p in trace] - for trace in self.propagators_noise_jnp] - - def _compute_propagation_derivatives(self) -> None: - """Computes propagator derivatives.""" - super()._compute_propagation_derivatives() - if self._derivative_prop_noise_jnp is None: - self._compute_propagation_derivatives_jnp() - - self._derivative_prop_noise_jnp = \ - [[[matrix.DenseOperatorJAX(p) for p in ctrl] for ctrl in der_t] - for der_t in self._derivative_prop_noise_jnp] - - - def _compute_propagation_jnp( - self, calculate_propagator_derivatives: Optional[bool] = None - ) -> None: - """See docstring of function without _jnp. Now as jnp-array.""" - - if self._dyn_gen_noise is None: - self._dyn_gen_noise = self._compute_dyn_gen_noise() - - if calculate_propagator_derivatives is None: - calculate_propagator_derivatives = \ - self.calculate_propagator_derivatives - - # parallelization of following code probably unnecessary - if calculate_propagator_derivatives: - - derivative_directions = self._compute_derivative_directions() - - # call the parent method for the noiseless propagators - super()._compute_propagation_jnp( - calculate_propagator_derivatives=calculate_propagator_derivatives) - - if self.processes == 1: - if calculate_propagator_derivatives: - - self._prop_noise_jnp, self._derivative_prop_noise_jnp = \ - _compute_propagation_expm_both_noise( - self._transferred_time_jnp, - self._dyn_gen_noise, - derivative_directions[0]) - self._prop_noise_jnp = self._prop_noise_jnp[:,0,:,:,:] - else: - - self._prop_noise_jnp = _compute_propagation_expm_noise( - self._transferred_time_jnp,self._dyn_gen_noise) - - elif (type(self.processes) == int and self.processes > 0) \ - or self.processes is None: - - raise NotImplementedError("No pool-multiprocess with jax calc, \ - (TODO) perhaps add with pmap (?)") - - - else: - raise ValueError('Invalid number of processes for parallel ' - 'computation!') - - def _compute_forward_propagation_jnp(self) -> None: - """Computes the forward propagators. """ - super()._compute_forward_propagation_jnp() - if self._prop_noise_jnp is None: - self._compute_propagation_jnp() - - cum_prop_noise = _cumprod_noise(self._initial_state_jnp.copy(), - self._prop_noise_jnp) - sh = cum_prop_noise.shape - - self._fwd_prop_noise_jnp = jnp.append(jnp.broadcast_to( - self._initial_state_jnp.copy(),(sh[0],1,*sh[2:])), - cum_prop_noise,axis=1) - - def _compute_forward_propagation(self) -> None: - """Computes the forward propagators. """ - super()._compute_forward_propagation() - - self._fwd_prop_noise = [[matrix.DenseOperatorJAX(p) for p in trace] - for trace in self.forward_propagators_noise_jnp] - - def _compute_reversed_propagation_jnp(self) -> None: - """Compute the reversed propagation. For the perturbed and unperturbed - Schroedinger equation. """ - super()._compute_reversed_propagation_jnp() - if self._prop_noise_jnp is None: - self._compute_propagation_jnp() - - _initial_state_rev_jnp = jnp.eye(self._prop_jnp[0].shape[0]) * (1+0j) - - cum_prop_reversed_noise = _cumprod_reversed_noise( - _initial_state_rev_jnp,self._prop_noise_jnp[::-1]) - - sh = cum_prop_reversed_noise.shape - - self._reversed_prop_noise_jnp = jnp.append( - jnp.broadcast_to(_initial_state_rev_jnp,(sh[0],1,*sh[2:])), - cum_prop_reversed_noise,axis=1) - - - def _compute_reversed_propagation(self) -> None: - """Compute the reversed propagation. For the perturbed and unperturbed - Schroedinger equation. """ - super()._compute_reversed_propagation() - - self._reversed_prop_noise = \ - [[matrix.DenseOperatorJAX(p) for p in trace] - for trace in self.reversed_propagators_noise_jnp] - - - def _compute_propagation_derivatives_jnp(self) -> None: - """ - Computes the frechet derivatives of the propagators. - - The derivatives are not returned but cached. Since the function is only - called when no derivatives are cached, the approximation is - prioritised. - """ - if not self.frechet_deriv_approx_method: - self._compute_propagation_jnp(calculate_propagator_derivatives=True) - - elif self.frechet_deriv_approx_method == 'grape': - super()._compute_propagation_derivatives_jnp() - - if self._prop_noise_jnp is None: - self._compute_propagation_jnp( - calculate_propagator_derivatives=False) - - derivative_directions = self._compute_derivative_directions() - - #broadcasting explicitly - self._derivative_prop_noise_jnp = \ - jnp.swapaxes( - jnp.expand_dims(self._transferred_time_jnp,(0,2,3,4))* - jnp.expand_dims(derivative_directions,0)@ - jnp.expand_dims(self._prop_noise_jnp,axis=2),1,2) - - else: - raise ValueError('Unknown gradient derivative approximation ' - 'method:' - + str(self.frechet_deriv_approx_method)) - - -class SchroedingerSMCControlNoiseJAX(SchroedingerSMonteCarloJAX): - """See docstring of class w/o JAX.""" - - def __init__( - self, - h_drift: List[q_mat.OperatorMatrix], - h_ctrl: List[q_mat.OperatorMatrix], - tau: Union[jnp.array,np.array], - noise_trace_generator: - Optional[noise.NoiseTraceGenerator], - initial_state: q_mat.OperatorMatrix = None, - ctrl_amps: Optional[np.array] = None, - calculate_propagator_derivatives: bool = False, - processes: Optional[int] = 1, - filter_function_h_n: Union[ - Callable, List[List], None] = None, - filter_function_basis: Optional[basis.Basis] = None, - filter_function_n_coeffs_deriv: Optional[ - Callable[[np.ndarray], np.ndarray]] = None, - exponential_method: Optional[str] = None, - frechet_deriv_approx_method: Optional[str] = None, - is_skew_hermitian: bool = True, - transfer_function: Optional[TransferFunction] = None, - amplitude_function: Optional[AmplitudeFunction] = None): - - def noise_amplitude_function( - noise_samples: Union[np.array,jnp.array], - transferred_parameters: Union[np.array,jnp.array], - control_amplitudes: Union[np.array,jnp.array], - **_): - """Calculates the noise amplitudes. - - Takes into account the actual optimization parameters and random - variations. - - Parameters - ---------- - noise_samples: np.array, shape() - Noise samples calculated by the noise trace generator. - - transferred_parameters: np.array - Transferred optimization parameters. - - control_amplitudes: np.array - Control amplitudes. - - """ - noise_amplitudes = jnp.zeros( - (noise_samples.shape[0], noise_samples.shape[1], - control_amplitudes.shape[1]), dtype=complex) - - - for trace_num in range(noise_samples.shape[1]): - #jnp cannot be updated in place - #->copy every time; inefficient in for loop? - noise_amplitudes = noise_amplitudes.at[:,trace_num,:].set(self.amplitude_function( - transferred_parameters + noise_samples[:, trace_num, :]) \ - - control_amplitudes) - return noise_amplitudes - - super().__init__( - h_drift=h_drift, - h_ctrl=h_ctrl, - initial_state=initial_state, - tau=tau, - h_noise=h_ctrl, - noise_trace_generator=noise_trace_generator, - ctrl_amps=ctrl_amps, - calculate_propagator_derivatives=calculate_propagator_derivatives, - processes=processes, - filter_function_h_n=filter_function_h_n, - filter_function_basis=filter_function_basis, - filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, - exponential_method=exponential_method, - frechet_deriv_approx_method=frechet_deriv_approx_method, - is_skew_hermitian=is_skew_hermitian, - transfer_function=transfer_function, - amplitude_function=amplitude_function,) - - -class LindbladSolverJAX(SchroedingerSolverJAX): - """See docstring of class w/o JAX.""" - - def __init__( - self, - h_drift: List[q_mat.OperatorMatrix], - h_ctrl: List[q_mat.OperatorMatrix], - tau: np.array, - initial_state: q_mat.OperatorMatrix = None, - ctrl_amps: Optional[np.array] = None, - calculate_unitary_derivatives: bool = False, - filter_function_h_n: Union[ - Callable, List[List], None] = None, - filter_function_basis: Optional[basis.Basis] = None, - filter_function_n_coeffs_deriv: Optional[ - Callable[[np.ndarray], np.ndarray]] = None, - exponential_method: Optional[str] = None, - frechet_deriv_approx_method: Optional[str] = None, - initial_diss_super_op: List[q_mat.OperatorMatrix] = None, - lindblad_operators: List[q_mat.OperatorMatrix] = None, - prefactor_function: Callable[[np.array,np.array],np.array] = None, - prefactor_derivative_function: - Callable[[np.array, np.array], np.array] = None, - super_operator_function: - Callable[[np.array, np.array], List[q_mat.OperatorMatrix]] = None, - super_operator_derivative_function: - Callable[[np.array, np.array], - List[List[q_mat.OperatorMatrix]]] = None, - is_skew_hermitian: bool = False, - transfer_function: Optional[TransferFunction] = None, - amplitude_function: Optional[AmplitudeFunction] = None) \ - -> None: - - if initial_state is None: - dim = h_ctrl[0].shape[0] - initial_state = type(h_ctrl[0])(np.eye(dim ** 2)) - - self._diss_sup_op_jnp = None - self._diss_sup_op_deriv_jnp = None - - # we do not throw away any operators or functions, just in case - self._initial_diss_super_op = initial_diss_super_op - self._lindblad_operators = lindblad_operators - - self._prefactor_function = prefactor_function - self._prefactor_deriv_function = prefactor_derivative_function - self._sup_op_func = super_operator_function - self._sup_op_deriv_func = super_operator_derivative_function - self._is_hermitian = is_skew_hermitian - - super().__init__( - h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, - tau=tau, ctrl_amps=ctrl_amps, - calculate_propagator_derivatives=calculate_unitary_derivatives, - filter_function_h_n=filter_function_h_n, - filter_function_basis=filter_function_basis, - filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, - exponential_method=exponential_method, - frechet_deriv_approx_method=frechet_deriv_approx_method, - is_skew_hermitian=is_skew_hermitian, - transfer_function=transfer_function, - amplitude_function=amplitude_function) - - def set_optimization_parameters(self, y: Union[jnp.array,np.array] - ) -> None: - """See base class. """ - if not np.array_equal(self._opt_pars, y): - super().set_optimization_parameters(y) - self.reset_cached_propagators() - - def reset_cached_propagators(self): - """ See base class. """ - super().reset_cached_propagators() - if self._prefactor_function is not None \ - or self._sup_op_func is not None: - self._diss_sup_op_jnp = None - self._diss_sup_op_deriv_jnp = None - - def _calc_diss_sup_op_jnp(self) -> jnp.ndarray: - r""" - Calculates the dissipative super operator as described in the class - doc string. - - Returns - ------- - diss_sup_op: jnp.ndarray, len num_t - Dissipation super operator; Where num_t is the number of timesteps - """ - if self._sup_op_func is None: - # use Lindblad operators - if self._lindblad_operators is None: - # use dissipation_sup_op - const_diss_sup_op = self._initial_diss_super_op - else: - # Calculate the time constant dissipation super operators - # without time dependence - const_diss_sup_op = [] - identity = self._lindblad_operators[0].identity_like() - - for lindblad in self._lindblad_operators: - const_diss_sup_op.append( - (lindblad.conj(do_copy=True)).kron(lindblad)) - const_diss_sup_op[-1] -= .5 * identity.kron( - lindblad.dag(do_copy=True) * lindblad) - const_diss_sup_op[-1] -= .5 * ( - lindblad.transpose(do_copy=True) - * lindblad.conj(do_copy=True)).kron(identity) - - # Add the time dependence - if self._prefactor_function is not None: - self._diss_sup_op = [] - prefactors = self._prefactor_function( - copy.deepcopy(self._ctrl_amps), - copy.deepcopy(self.transferred_parameters)) - for factor_at_time_t in prefactors: - self._diss_sup_op.append( - const_diss_sup_op[0] * factor_at_time_t[0]) - for sup_op, factor \ - in zip(const_diss_sup_op[1:], - factor_at_time_t[1:]): - self._diss_sup_op[-1] += sup_op * factor - else: - self._diss_sup_op = [const_diss_sup_op[0], ] - for sup_op in const_diss_sup_op[1:]: - self._diss_sup_op[0] += sup_op - self._diss_sup_op *= len(self.transferred_time) - else: - self._diss_sup_op = self._sup_op_func( - copy.deepcopy(self._ctrl_amps), - copy.deepcopy(self.transferred_parameters)) - - if isinstance(self._diss_sup_op,jnp.ndarray) or isinstance(self._diss_sup_op,np.ndarray): - self._diss_sup_op_jnp = self._diss_sup_op - else: - self._diss_sup_op_jnp = jnp.array([l.data for l in self._diss_sup_op]) - del self._diss_sup_op - #would be complicated to rewrite as jnp cause many in-place assignments? - #not the most efficient, but ok if not insane amounts of lindblad ops? - return self._diss_sup_op_jnp - - def _calc_diss_sup_op_deriv_jnp(self) \ - -> Optional[jnp.ndarray]: - r""" - Calculates the derivatives of the dissipation super operator with - respect to the control amplitudes. - - If the dissipation super operator is given as constant (1.) or as - lindblad operators (2.) they are assumed not to depend on the control - parameters and only the derivative of the prefactor is to be taken into - account. In order to do so, a function handle containing the - derivatives must be given. This function receives the control - amplitudes as num_t x num_ctrl numpy array and returns the derivatives - as num_t x num_l x num_ctrl array. - - If the dissipation super operator is given as function handle (3.), - then the derivatives must also be given as function handle receiving - the control amplitudes and returning a nested list of super operators - as control matrices. - - If the requested derivative functions are not provided (None), then - the dissipation super operator is considered constant in the control - amplitudes and the function returns None. - - Returns - ------- - diss_sup_op_deriv: jnp.array - The derivatives of the dissipation super operator with respect to - the control variables. - - """ - if self._sup_op_deriv_func is not None: - self._diss_sup_op_deriv = \ - self._sup_op_deriv_func( - copy.deepcopy(self._ctrl_amps), - copy.deepcopy(self.transferred_parameters)) - - if isinstance(self._diss_sup_op_deriv,jnp.ndarray) or isinstance(self._diss_sup_op_deriv,np.ndarray): - self._diss_sup_op_deriv_jnp = self._diss_sup_op_deriv - else: - self._diss_sup_op_deriv_jnp = \ - jnp.array([[l.data for l in lm] - for lm in self._diss_sup_op_deriv]) - del self._diss_sup_op_deriv - return self._diss_sup_op_deriv_jnp - - elif self._prefactor_deriv_function is not None: - if self._lindblad_operators is None: - # use dissipation_sup_op - const_diss_sup_op = self._initial_diss_super_op - else: - # Calculate the time constant dissipation super operators - # without time dependence - const_diss_sup_op = [] - identity = self._lindblad_operators[0].identity_like() - - for lindblad in self._lindblad_operators: - const_diss_sup_op.append( - (lindblad.conj(do_copy=True)).kron(lindblad)) - const_diss_sup_op[-1] -= .5 * identity.kron( - lindblad.dag(do_copy=True) * lindblad) - const_diss_sup_op[-1] -= .5 * ( - lindblad.transpose(do_copy=True) - * lindblad.conj(do_copy=True)).kron(identity) - - prefactor_derivatives = \ - self._prefactor_deriv_function( - copy.deepcopy(self._ctrl_amps), - copy.deepcopy(self.transferred_parameters)) - - # Todo: Assert that the prefactor returns the right dimension - - # prefactor_derivatives: shape (num_t, num_ctrl, num_l) - diss_sup_op_deriv = [] - for factor_per_ctrl_lind in prefactor_derivatives: - # create new sub list for eacht time step - diss_sup_op_deriv.append([]) - for factor_per_lind in factor_per_ctrl_lind: - # add the first term for each control direction - diss_sup_op_deriv[-1].append( - const_diss_sup_op[0] * factor_per_lind[0]) - for diss_sup_op, factor in zip( - const_diss_sup_op[1:], factor_per_lind[1:]): - # add the remaining terms - diss_sup_op_deriv[-1][-1] += diss_sup_op * factor - - if isinstance(diss_sup_op_deriv,jnp.ndarray) or isinstance(diss_sup_op_deriv,np.ndarray): - self._diss_sup_op_deriv_jnp = diss_sup_op_deriv - else: - self._diss_sup_op_deriv_jnp = \ - jnp.array([[l.data for l in lm] for lm in diss_sup_op_deriv]) - - return self._diss_sup_op_deriv_jnp - - else: - return None - - def _compute_derivative_directions( - self) -> jnp.ndarray: - r""" - Computes the derivative directions of the total dynamics generator. - - Returns - ------- - deriv_directions: jnp.array - """ - - identity_times_i = -1j*jnp.identity(self._h_ctrl_jnp[0].shape[0]) - h_ctrl_sup_op_jnp = jnp.kron(identity_times_i,self._h_ctrl_jnp) \ - -jnp.kron(jnp.transpose(self._h_ctrl_jnp,(0,2,1)),identity_times_i) - - # add derivative of the dissipation part - if self._diss_sup_op_deriv_jnp is None: - self._diss_sup_op_deriv_jnp = self._calc_diss_sup_op_deriv_jnp() - if self._diss_sup_op_deriv_jnp is not None: - dh_by_ctrl = self._diss_sup_op_deriv_jnp + h_ctrl_sup_op_jnp - else: - dh_by_ctrl = jnp.broadcast_to(h_ctrl_sup_op_jnp, - self._transferred_time_jnp.shape \ - +h_ctrl_sup_op_jnp.shape) - - return dh_by_ctrl - - def _parse_dissipative_super_operator(self) -> None: - r""" - check the dissipative super operator for dimensional consistency - (maybe even physical properties) - - not implemented yet - - """ - pass - - def _compute_dyn_gen(self) -> jnp.ndarray: - r""" - Computes the dynamics generator for the Lindblad master equation. - - The Hamiltonian is translated into the master equation formalism as - - .. math:: - - \mathcal{H} = I \otimes H - H^\ast \otimes I - - Then the dissipation super operator is added. - - Returns - ------- - dyn_gen: jnp.array, len num_t - Dynamics generators for the master equation. - - Raises - ------ - ValueError: - The computation is only defined for the use of dense control - matrices. - - """ - self._dyn_gen = super()._compute_dyn_gen() - - if self._diss_sup_op_jnp is None: - self._diss_sup_op_jnp = self._calc_diss_sup_op_jnp() - - identity_operator = jnp.identity(self._dyn_gen[0].shape[0]) - sup_op_dyn_gen = [] - - assert(len(self._dyn_gen) == len(self._diss_sup_op_jnp)) - - sup_op_dyn_gen = jnp.kron(identity_operator,self._dyn_gen) \ - + jnp.kron(jnp.conj(self._dyn_gen),identity_operator) \ - + self._diss_sup_op_jnp - - self._dyn_gen = sup_op_dyn_gen - return sup_op_dyn_gen - - def _compute_propagation_jnp( - self, calculate_propagator_derivatives: Optional[bool] = None) \ - -> None: - """See base class. """ - super(SchroedingerSolverJAX,self)._compute_propagation_jnp() - - if self._dyn_gen is None: - self._dyn_gen = self._compute_dyn_gen() - - if calculate_propagator_derivatives is None: - calculate_propagator_derivatives = \ - self.calculate_propagator_derivatives - - if calculate_propagator_derivatives: - derivative_directions = self._compute_derivative_directions() - - #previously with derivative_directions[0] due to being - #time-constant in normal SchroedingerSolver; however in - #LindbladSolver is maybe not(?) - self._prop_jnp, self._derivative_prop_jnp = \ - _compute_propagation_expm_both_lind(self._transferred_time_jnp, - self._dyn_gen, - derivative_directions) - self._prop_jnp = self._prop_jnp[0,:,:,:] - - else: - self._prop_jnp = _compute_propagation_expm( - self._transferred_time_jnp, - self._dyn_gen) - - -class LindbladSControlNoiseJAX(LindbladSolverJAX): - """See docstring of class w/o JAX.""" - - @needs_refactoring - def __init__(self, h_drift, h_ctrl, initial_state, tau, - ctrl_amps, transfer_function=None, - calculate_unitary_derivatives=True, filter_function_h_n=None, - exponential_method=None, lindblad_operators=None, - constant_lindblad_operators=False, noise_psd=1): - super().__init__( - h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, - tau=tau, ctrl_amps=ctrl_amps, - calculate_unitary_derivatives=calculate_unitary_derivatives, - filter_function_h_n=filter_function_h_n, - exponential_method=exponential_method) - - if lindblad_operators is None: - self.lindblad_super_operator = None - else: - d = lindblad_operators[0].shape[0] - self.lindblad_super_operator = np.zeros( - (len(lindblad_operators), d**2, d**2)) - for i, l in enumerate(lindblad_operators): - self.lindblad_super_operator[i, :, :] += np.kron(np.conj(l), l) - self.lindblad_super_operator[i, :, :] += -.5 * np.kron( - np.eye(d), l.T.conj() @ l) - self.lindblad_super_operator[i, :, :] += -.5 * np.kron( - l.T @ l.conj(), np.eye(d)) - - self.transfer_function = transfer_function - # if no transfer function is given it might be consider to be identity - # its not necessarily required - - self.constant_lindblad_operators = constant_lindblad_operators - self.noise_psd = noise_psd - self.incoherent_dyn_gen = None - - def _compute_propagation(self): - """Computes propagators.""" - # Compute and cache all dyn_gen (basically the total hamiltonian) - self._dyn_gen = self._h_drift_jnp - self._dyn_gen += jnp.sum(self._ctrl_amps * self._h_ctrl_jnp, axis=1) - - # super operator calculation - # this is the special case for charge noise on the control parameters - # the required filter function contains - if not self.constant_lindblad_operators or \ - self.incoherent_dyn_gen is None: - transfer_matrix = self.transfer_function.transfer_matrix - self.incoherent_dyn_gen = jnp.einsum('ijk,klm,k->ilm', - transfer_matrix, - self.lindblad_super_operator, - self.noise_psd) - dim = self._dyn_gen[0].shape[0] - identity_operator = jnp.identity(dim) - - self._dyn_gen = -1j*jnp.kron(identity_operator,self._dyn_gen) \ - -jnp.kron(self._dyn_gen,identity_operator) - self._dyn_gen += self.incoherent_dyn_gen - - # calculation of the propagators - # for t in range(len(self.num_t)): - if self.calculate_propagator_derivatives: - derivative_directions = jnp.kron( - identity_operator,self._h_ctrl_jnp) \ - -jnp.kron(self._h_ctrl_jnp,identity_operator) - self._prop_jnp, _derivative_prop_jnp = \ - _compute_propagation_expm_both_lind(self._transferred_time_jnp, - self._dyn_gen, - derivative_directions) - self._prop_jnp = self._prop_jnp[0,:,:,:] - #why this convention now? - self._dU = jnp.swapaxes(_derivative_prop_jnp,0,1) - - else: - self._prop_jnp = _compute_propagation_expm( - self._transferred_time_jnp,self._dyn_gen) - - self.prop_calculated = True - - \ No newline at end of file diff --git a/qopt/solver_algorithms_copy_original.py b/qopt/solver_algorithms_copy_original.py deleted file mode 100644 index f138c77..0000000 --- a/qopt/solver_algorithms_copy_original.py +++ /dev/null @@ -1,2105 +0,0 @@ -# -*- coding: utf-8 -*- -# ============================================================================= -# qopt -# Copyright (C) 2020 Julian Teske, Forschungszentrum Juelich -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -# -# Contact email: j.teske@fz-juelich.de -# ============================================================================= -""" Implements the algorithms to solve differential equations like -Schroedinger's equation or a master equation. - -The `Solver` class is the central piece of the actual simulation. It calculates -propagators from the differential equations describing the quantum dynamics. -The abstract base class inherits among other things an interface to the -`PulseSequence` class of the filter_functions package. - -The `Solver` classes can have an amplitude and a transfer function as attribute -and automate their use. The Monte Carlo solvers also hold an instance of a -noise trace generator. - -If requested, also derivatives of the propagators by the control amplitudes are -calculated or approximated. - -Classes -------- -:class:`Solver` - Abstract base class of the time slot computers. - -:class:`SchroedingerSolver` - Solver for the the unperturbed Schroedinger equation. - -:class:`SchroedingerSMonteCarlo` - Solver for the Schroedinger equation under the influence of noise. - -:class:`SchroedingerSMCControlNoise` - Solver for the Schroedinger equation under the influence of noise affecting - the control terms. - -:class:`LindbladSolver` - Solves the master equation in Lindblad form. - -Notes ------ -The implementation was inspired by the optimal control package of QuTiP [1]_ -(Quantum Toolbox in Python) - -References ----------- -.. [1] J. R. Johansson, P. D. Nation, and F. Nori: "QuTiP 2: A Python framework - for the dynamics of open quantum systems.", Comp. Phys. Comm. 184, 1234 - (2013) [DOI: 10.1016/j.cpc.2012.11.019]. - -""" - -import numpy as np -import copy -from typing import Optional, List, Callable, Union -from abc import ABC, abstractmethod -from multiprocessing import Pool - -from filter_functions import pulse_sequence, plotting, basis, numeric - -from qopt import noise, matrix, matrix as q_mat -from qopt.transfer_function import TransferFunction, IdentityTF -from qopt.amplitude_functions import AmplitudeFunction, IdentityAmpFunc -from qopt.util import needs_refactoring - - -class Solver(ABC): - r""" - Abstract base class for Solvers. - - Parameters - ---------- - h_ctrl: List[ControlMatrix], len num_ctrl - Control operators in the Hamiltonian as nested list of - shape n_t, num_ctrl. - - h_drift: List[ControlMatrix], len num_t or 1 - Drift operators in the Hamiltonian. You can either give a single element - or one for each transferred time step. - - initial_state : ControlMatrix - Initial state of the system as state vector. Can also be set to the - identity matrix. Then the forward propagation gives the total - propagator of the system. - - tau: array of float, shape (num_t, ) - Durations of the time slices. - - opt_pars: np.array, shape (num_y, num_par), optional - Raw optimization parameters. - - ctrl_amps: np.array, shape (num_t, num_ctrl), optional - The initial control amplitudes. - - filter_function_h_n: List[List[np.array]] or List[List[Qobj]] or callable - Nested list of noise Operators. Used in the filter function - formalism. _filter_function_h_n should look something like this: - - >>> H = [[n_oper1, n_coeff1, n_oper_identifier1], - >>> [n_oper2, n_coeff2, n_oper_identifier2], ...] - - The operators may be given either as NumPy arrays or QuTiP Qobjs - and each coefficient array should have the same number of elements - as *dt*, and should be given in units of :math:`\hbar`. If not every - sublist (read operator) was given a identifier, they are automatically - filled up with 'A_i' where i is the position of the operator. - Alternatively the create_ff_h_n may be a function handle creating - such an object when called with the optimization parameters. - - filter_function_basis: Basis, shape (d**2, d, d), optional - The operator basis in which to calculate. If a Generalized Gell-Mann - basis (see :meth:`~basis.Basis.ggm`) is chosen, some calculations will - be faster for large dimensions due to a simpler basis expansion. - However, when extending the pulse sequence to larger qubit registers, - cached filter functions cannot be retained since the GGM basis does not - factor into tensor products. In this case a Pauli basis is preferable. - - filter_function_n_coeffs_deriv: Callable numpy array to numpy array - This function calculates the derivatives of the noise susceptibility in - the filter function formalism. It receives the optimization parameters - as array of shape (num_opt, num_t) and returns the derivatives as array - of shape (num_noise_op, n_ctrl, num_t). - - exponential_method: string, optional - Method used by the ControlMatrix class for the calculation of the - matrix exponential. The default is 'Frechet'. See also the Docstring of - the file 'qopt.matrix'. - - is_skew_hermitian: bool - Only important for the exponential_method 'spectral'. If set to true, - the dynamical generator is assumed to be skew hermitian during the - spectral decomposition. - - transfer_function: TransferFunction - The transfer function for reshaping the optimization parameters. - - amplitude_function: AmplitudeFunction - The amplitude function connecting the transferred optimization - parameters to the control amplitudes. - - paranoia_level: int - The paranoia_level determines how many checks are conducted. - 0 No tests - 1 Some tests - 2 Exhaustive tests, dimension checks - - Attributes - ---------- - h_ctrl : List[ControlMatrix], len num_ctrl - Control operators in the Hamiltonian as list of length num_ctrl. - - h_drift : List[ControlMatrix], len num_t - Drift operators in the Hamiltonian. - - initial_state : ControlMatrix - Initial state of the system as state vector. Can also be set to the - identity matrix. Then the forward propagation gives the total - propagator of the system. - - transferred_time: List[float] - Durations of the time slices. - - filter_function_h_n: List[List[np.array]] or List[List[Qobj]] - Nested list of noise Operators. Used in the filter function - formalism. - - filter_function_basis: Basis - The filter function pulse sequence will be expressed in this basis. - See documentation of the filter function package. - - exponential_method: string, optional - Method used by the ControlMatrix class for the calculation of the - matrix exponential. The default is 'Frechet'. See also the Docstring of - the file 'qopt.matrix'. - - transfer_function: TransferFunction - The transfer function for reshaping the optimization parameters. - - amplitude_function: AmplitudeFunction - The amplitude function connecting the transferred optimization - parameters to the control amplitudes. - - _prop: List[ControlMatrix], len num_t - Propagators of the system. - - _fwd_prop: List[ControlMatrix], len num_t + 1 - Ordered product of the propagators. They describe the forward - propagation of the systems state. - - _reversed_prop: List[ControlMatrix], len num_t + 1 - Ordered product of propagators in reversed order. - - _derivative_prop: List[List[ControlMatrix]], shape [[] * num_t] * num_ctrl - Frechet derivatives of the propagators by the control amplitudes. - - Methods - ------- - propagators: List[ControlMatrix], len num_t - Returns the propagators of the system. - - forward_propagators: List[ControlMatrix], len num_t + 1 - Returns the forward propagation of the initial state. The element - forward_propagators[i] propagates a state by the first i time steps, if - the initial state is the identity matrix. - - frechet_deriv_propagators: List[List[ControlMatrix]], - shape [[] * num_t] * num_ctrl - Returns the frechet derivatives of the propagators by the control - amplitudes. - - reversed_propagators: List[ControlMatrix], len num_t + 1 - Returns the reversed propagation of the initial state. The element - reversed_propagators[i] propagates a state by the last i time steps, if - the initial state is the identity matrix. - - _compute_propagation: abstract method - Computes the propagators. - - _compute_forward_propagation - Compute the forward propagation of the initial state / system. - - _compute_reversed_propagation - Compute the reversed propagation of the initial state / system. - - _compute_propagation_derivatives: abstract method - Compute the derivatives of the propagators by the control amplitudes. - - create_pulse_sequence(new_amps): PulseSequence - Creates a pulse sequence instance corresponding to the current control - amplitudes. - - `Todo` - * Write parser - * setter for new hamiltonians - * make hamiltonians private - * also for the initial state - * extend constant drift hamiltonian - * Implement the drift operator with an amplitude. Right now, - * the operator is already multiplied with the amplitude, which is - * not coherent with the pulse sequence interface. Alternatively - * amplitude=1? - * transferred_time should be taken from the transfer function - * Use own plotting for the plotting - * Consequent try catches for the computation of the matrix exponential - - """ - - def __init__( - self, - h_ctrl: List[q_mat.OperatorMatrix], - h_drift: List[q_mat.OperatorMatrix], - tau: np.array, - initial_state: q_mat.OperatorMatrix = None, - opt_pars: Optional[np.array] = None, - ctrl_amps: Optional[np.array] = None, - filter_function_h_n: Union[ - Callable, List[List], None] = None, - filter_function_basis: Optional[basis.Basis] = None, - filter_function_n_coeffs_deriv: Optional[ - Callable[[np.ndarray], np.ndarray]] = None, - exponential_method: Optional[str] = None, - is_skew_hermitian: bool = True, - transfer_function: Optional[TransferFunction] = None, - amplitude_function: Optional[AmplitudeFunction] = None, - paranoia_level: int = 2 - ): - - self.h_ctrl = h_ctrl - self._ctrl_amps = ctrl_amps - self._opt_pars = opt_pars - - if initial_state is None: - dim = self.h_ctrl[0].shape[0] - self.initial_state = type(self.h_ctrl[0])(np.eye(dim)) - else: - self.initial_state = initial_state - - if exponential_method is None: - self.exponential_method = 'Frechet' - else: - self.exponential_method = exponential_method - - self._prop = None - self._fwd_prop = None - self._reversed_prop = None - self._derivative_prop = None - - self.pulse_sequence = None - - if filter_function_h_n is None: - self._filter_function_h_n = [] - else: - self._filter_function_h_n = filter_function_h_n - self.filter_function_basis = filter_function_basis - self.filter_function_n_coeffs_deriv = filter_function_n_coeffs_deriv - - self._is_skew_hermitian = is_skew_hermitian - - if transfer_function is None: - self.transfer_function = IdentityTF(num_ctrls=len(h_ctrl)) - else: - self.transfer_function = transfer_function - - self.transferred_time = None - self.set_times(tau=tau) - - if type(h_drift) in [matrix.DenseOperator, matrix.SparseOperator]: - self.h_drift = [h_drift, ] * self.transfer_function.num_x - elif len(h_drift) == 1: - self.h_drift = h_drift * self.transfer_function.num_x - else: - self.h_drift = h_drift - - if amplitude_function is None: - self.amplitude_function = IdentityAmpFunc() - else: - self.amplitude_function = amplitude_function - - self.transferred_parameters = None - - self.consistency_checks(paranoia_level=paranoia_level) - - def set_times(self, tau): - """ Set time values by passing them to the transfer function. - - Parameters - ---------- - tau: array of float, shape (num_t, ) - Durations of the time slices. - - """ - self.transfer_function.set_times(tau) - self.transferred_time = self.transfer_function.x_times - self.reset_cached_propagators() - - def set_optimization_parameters(self, y: np.array) -> None: - """ - Set the control amplitudes. - - All computation flags are set to false. - - The new control amplitudes u are calculated: - u: np.array, shape (num_t, num_ctrl) - - Parameters - ---------- - y: np.array, shape (num_x, num_ctrl) - Raw optimization parameters. - - """ - - if np.array_equal(self._opt_pars, y): - return - else: - self._opt_pars = np.copy(y) - - if self.transfer_function is not None: - self.transferred_parameters = self.transfer_function(y) - else: - self.transferred_parameters = np.copy(y) - - if self.amplitude_function is not None: - u = self.amplitude_function( - self.transferred_parameters) - else: - u = self.transferred_parameters - - if len(u.shape) != 2: - raise ValueError('The new control amplitudes must have two ' - 'dimensions! ' - '(time, control operator)') - - if u.shape[0] != len(self.transferred_time): - raise ValueError('The new control amplitudes do not have the ' - 'correct number of entries on the time axis!') - - if u.shape[1] != len(self.h_ctrl): - raise ValueError('The new control amplitudes do not have the ' - 'correnct number of entries on the control axis!') - - self._ctrl_amps = u - self.reset_cached_propagators() - - def reset_cached_propagators(self): - """ Resets all cached propagators. """ - self._prop = None - self._fwd_prop = None - self._derivative_prop = None - self._reversed_prop = None - self.pulse_sequence = None - - def consistency_checks(self, paranoia_level: int): - """Checks attributes for inner consistency. - - Parameters - ---------- - paranoia_level: int - The paranoia_level determines how many checks are conducted. - 0: No tests - 1: Some tests - 2: Exhaustive tests, dimension checks - - """ - if paranoia_level == 0: - return - - elif paranoia_level >= 1: - # check whether the hamiltonian is correct for the number of time - # steps - if isinstance(self.transferred_time, List): - self.transferred_time = np.asarray(self.transferred_time) - if len(self.transferred_time.shape) > 1: - raise ValueError("Tau must be a one dimensional numpy array or" - "a list.") - n_time_steps = self.transferred_time.shape[0] - - if len(self.h_drift) == 1: - self.h_drift = self.h_drift * n_time_steps - - if not (n_time_steps == len(self.h_drift) - or len(self.h_drift) == 0): - raise ValueError("The drift hamiltonian must have exactly one " - "entry for each transferred time step or no " - "entry at all or a single entry.") - if paranoia_level >= 2: - # check whether the Hamiltonian has the correct dimensions - dim = self.h_ctrl[0].shape[0] - - for ctrl_matrix in self.h_ctrl: - assert(dim == ctrl_matrix.shape[0]) - assert(dim == ctrl_matrix.shape[1]) - - for drift_matrx in self.h_drift: - assert(dim == drift_matrx.shape[0]) - assert(dim == drift_matrx.shape[1]) - - else: - raise ValueError("The paranoia level must be a positive integer.") - - @property - def propagators(self) -> List[q_mat.OperatorMatrix]: - """ - Returns the propagators of the system and calculates them if necessary. - - Returns - ------- - propagators: List[ControlMatrix], len num_t - Propagators of the system. - - """ - if self._prop is None: - self._compute_propagation() - return self._prop - - @property - def forward_propagators(self) -> List[q_mat.OperatorMatrix]: - """ - Returns the forward propagation of the initial state for every time - slice and calculate it if necessary. If the initial state is the - identity matrix, then the cumulative propagators are given. The element - forward_propagators[i] propagates a state by the first i time steps, if - the initial state is the identity matrix. - - Returns - ------- - forward_propagation: List[ControlMatrix], len num_t + 1 - Propagation of the initial state of the system. fwd[0] gives the - initial state itself. - - """ - if self._fwd_prop is None: - self._compute_forward_propagation() - return self._fwd_prop - - @property - def frechet_deriv_propagators(self) -> List[List[q_mat.OperatorMatrix]]: - """ - Returns the frechet derivatives of the propagators. - - Returns - ------- - derivative_prop: List[List[ControlMatrix]], - shape [[] * num_t] * num_ctrl - Frechet derivatives of the propagators by the control amplitudes - - """ - if self._derivative_prop is None: - self._compute_propagation_derivatives() - return self._derivative_prop - - @property - def reversed_propagators(self) -> List[q_mat.OperatorMatrix]: - """ - Returns the reversed propagation of the initial state for every time - slice and calculate it if necessary. If the initial state is the - identity matrix, then the reversed cumulative propagators are given. - The element forward_propagators[i] propagates a state by the first i - time steps, if the initial state is the identity matrix. - - Returns - ------- - reversed_propagation: List[ControlMatrix], len num_t + 1 - Propagation of the initial state of the system. reversed[0] gives - the initial state itself. - - """ - if self._reversed_prop is None: - self._compute_reversed_propagation() - return self._reversed_prop - - @property - def filter_function_n_coeffs_deriv_vals(self) -> Optional[np.ndarray]: - """ - Calculates the derivatives of the noise susceptibilities from the filter - function formalism. - - Returns - ------- - n_coeffs_deriv: numpy array of shape (num_noise_op, n_ctrl, num_t) - Derivatives of the noise susceptibilities by the control amplitudes. - - """ - if self.filter_function_n_coeffs_deriv is None: - return None - else: - return self.filter_function_n_coeffs_deriv(self._ctrl_amps) - - @property - def create_ff_h_n(self) -> list: - """Creates the noise hamiltonian of the filter function formalism. - - Returns - ------- - create_ff_h_n: nested list - Noise Hamiltonian of the filter function formalism. - - """ - if type(self._filter_function_h_n) == list: - h_n = self._filter_function_h_n - else: - h_n = self._filter_function_h_n(self._ctrl_amps) - - if not h_n: - h_n = [[np.zeros(self.h_ctrl[0].shape), - np.zeros((len(self.transferred_time),))]] - - return h_n - - @abstractmethod - def _compute_propagation(self) -> None: - """ - Computes the propagators. Must set self._prop! - - Raises - ------ - ValueError - If the control amplitudes are not set. - - """ - if self._ctrl_amps is None: - raise ValueError("The control amplitudes must be set to calculate " - "the propagation!") - - def _compute_forward_propagation(self) -> None: - """Computes the forward propagators. """ - if self._prop is None: - self._compute_propagation() - self._fwd_prop = [self.initial_state.copy(), ] - for prop in self._prop: - self._fwd_prop.append(prop * self._fwd_prop[-1]) - - def _compute_reversed_propagation(self) -> None: - """Compute the reversed propagation. """ - if self._prop is None: - self._compute_propagation() - - if type(self.initial_state) == matrix.DenseOperator: - self._reversed_prop = [matrix.DenseOperator( - np.eye(self._prop[0].shape[0])) * (1 + 0j), ] - elif type(self.initial_state) == matrix.SparseOperator: - raise NotImplementedError - # self._reversed_prop = [matrix.SparseOperator( - # np.eye(self._prop[0].shape[0])) * (1 + 0j), ] - else: - raise TypeError("The initial state should be either a dense or " - "sparse control matrix.") - - for prop in self._prop[::-1]: - self._reversed_prop.append(self._reversed_prop[-1] * prop) - - @abstractmethod - def _compute_propagation_derivatives(self) -> None: - """Compute the derivatives of the propagators by the control - amplitudes. - """ - pass - - def _diagonalize_and_propagate_pulse_sequence(self) -> None: - """Manually set eigendecomposition of the PulseSequence. - - Work around incompatibility of drift Hamiltonian - representations.""" - ps = self.pulse_sequence - drift_hamiltonian = np.array([h.data for h in self.h_drift]) - control_hamiltonian = np.einsum('ijk,il->ljk', ps.c_opers, ps.c_coeffs) - ps.eigvals, ps.eigvecs, ps.propagators = numeric.diagonalize( - drift_hamiltonian + control_hamiltonian, ps.dt - ) - ps.total_propagator = ps.propagators[-1] - - def create_pulse_sequence( - self, new_amps: Optional[np.array] = None, - ff_basis: Optional[basis.Basis] = None - ) -> pulse_sequence.PulseSequence: - """ - Create a pulse sequence of the filter function package written by - Tobias Hangleiter. - - See the documentation of the filter function package. - - Parameters - ---------- - new_amps: np.array, shape (num_t, num_ctrl), optional - New control amplitudes can be set before the pulse sequence is - initialized. - - ff_basis: Basis - The pulse sequence will be expanded in this basis. See - documentation of the filter function package. - - Returns - ------- - pulse_sequence: filter_functions.pulse_sequence.PulseSequence - The pulse sequence corresponding to the control model and the - control amplitudes set. - - """ - if new_amps is not None: - self.set_optimization_parameters(new_amps) - else: - if self._ctrl_amps is None: - raise ValueError('No optimization parameters set. ' - 'Please supply new_amps argument') - - if ff_basis is not None: - basis = ff_basis - elif self.filter_function_basis is not None: - basis = self.filter_function_basis - else: - basis = None - - # We have to work around different interfaces for the drift - # operators. Since in qopt the drift can be arbitrary (incl. - # nonlinear coupling), but in filter_functions the form H = - # a(t) A is imposed, we don't tell the PulseSequence object - # about H_drift and set the eigendecomposition after the fact. - if self.pulse_sequence is None: - h_c = list(zip( - self.h_ctrl, - self._ctrl_amps.T, - [f'Control{i}' for i in range(len(self.h_ctrl))] - )) - self.pulse_sequence = pulse_sequence.PulseSequence( - h_c, self.create_ff_h_n, self.transferred_time, basis - ) - else: - # Clean up the caches and update coefficients - self.pulse_sequence.cleanup('all') - self.pulse_sequence.c_coeffs = self._ctrl_amps.T - # Not the most elegant, but necessary for the current - # implementation. - self.pulse_sequence.n_coeffs = pulse_sequence._parse_Hamiltonian( - self.create_ff_h_n, - len(self.transferred_time), 'H_n')[2] - - if basis is not None: - self.pulse_sequence.basis = basis - - self._diagonalize_and_propagate_pulse_sequence() - return self.pulse_sequence - - def plot_bloch_sphere( - self, new_amps=None, return_Bloch: bool = False) -> None: - """ - Uses the pulse sequence to plot the systems evolution on the bloch - sphere. - - Only available for two dimensional systems. - - Parameters - ---------- - new_amps: np.array, shape (num_t, num_ctrl), optional - New control amplitudes can be set before the pulse sequence is - initialized. - - return_Bloch: bool - If True, then qutips Bloch object is returned. - - Returns - ------- - b: Bloch - Qutips Bloch object. Only returned if return_Bloch is set to True. - - """ - # Already takes care of updating and cleaning the PulseSequence object - pulse_sequence = self.create_pulse_sequence(new_amps=new_amps) - return plotting.plot_bloch_vector_evolution(pulse_sequence, - n_samples=500, - return_Bloch=return_Bloch) - - -class SchroedingerSolver(Solver): - """ - This time slot computer solves the unperturbed Schroedinger equation. - - All intermediary propagators are calculated and cached. Takes also input - parameters of the base class. - - Parameters - ---------- - calculate_propagator_derivatives: bool - If true, the derivatives of the propagators by the control amplitudes - are always calculated. Otherwise only on demand. - - frechet_deriv_approx_method: Optional[str] - Method for the approximation of the derivatives of the propagators, if - they are not calculated analytically. Note that this method is never - used if calculate_propagator_derivatives is set to True! - Methods: - None: The derivatives are not approximated by calculated by the control - matrix class. - 'grape': use the approximation given in the original grape paper. - - Attributes - ---------- - _dyn_gen: List[ControlMatrix], len num_t - The generators of the systems dynamics - - calculate_propagator_derivatives: bool - If true, the derivatives of the propagators by the control amplitudes - are always calculated. Otherwise only on demand. - - frechet_deriv_approx_method: Optional[str] - Method for the approximation of the derivatives of the propagators, if - they are not calculated analytically. Note that this method is never - used if calculate_propagator_derivatives is set to True! - Methods: - 'grape': use the approximation given in the original grape paper. - - Methods - ------- - _compute_derivative_directions: List[List[q_mat.ControlMatrix]], - shape [[] * num_ctrl] * num_t - Computes the directions of change with respect to the control - parameters. - - _compute_dyn_gen: List[ControlMatrix], len num_t - Computes the dynamics generators. - - `Todo` - * raise a warning if the approximation method although the gradient - is always calculated. - * raise a warning if the grape approximation is chosen but its - requirement of small time steps is not met. - - """ - - def __init__(self, - h_drift: List[q_mat.OperatorMatrix], - h_ctrl: List[q_mat.OperatorMatrix], - tau: np.array, - initial_state: q_mat.OperatorMatrix = None, - ctrl_amps: Optional[np.array] = None, - calculate_propagator_derivatives: bool = True, - filter_function_h_n: Union[ - Callable, List[List], None] = None, - filter_function_basis: Optional[basis.Basis] = None, - filter_function_n_coeffs_deriv: Optional[ - Callable[[np.ndarray], np.ndarray]] = None, - exponential_method: Optional[str] = None, - frechet_deriv_approx_method: Optional[str] = None, - is_skew_hermitian: bool = True, - transfer_function: Optional[TransferFunction] = None, - amplitude_function: Optional[AmplitudeFunction] = None): - super().__init__( - h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, - tau=tau, ctrl_amps=ctrl_amps, - filter_function_h_n=filter_function_h_n, - filter_function_basis=filter_function_basis, - filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, - exponential_method=exponential_method, - is_skew_hermitian=is_skew_hermitian, - transfer_function=transfer_function, - amplitude_function=amplitude_function - ) - self.id_text = 'ALL' - self.cache_text = 'Save' - self.calculate_propagator_derivatives = \ - calculate_propagator_derivatives - self.frechet_deriv_approx_method = frechet_deriv_approx_method - - self._dyn_gen = None - - def set_optimization_parameters(self, y: np.array) -> None: - """See base class. """ - if not np.array_equal(self._opt_pars, y): - self.reset_cached_propagators() - super().set_optimization_parameters(y) - - def reset_cached_propagators(self): - """See base class. """ - self._dyn_gen = None - super().reset_cached_propagators() - - def _compute_dyn_gen(self) -> List[q_mat.OperatorMatrix]: - """ - Computes the dynamics generators. - - Returns - ------- - dyn_gen: List[ControlMatrix], len num_t - This is basically the total Hamiltonian. - - """ - self._dyn_gen = [-1j * h for h in self.h_drift] - for ctrl, ctrl_op in enumerate(self.h_ctrl): - for dyn_gen, ctrl_amp in \ - zip(self._dyn_gen, self._ctrl_amps[:, ctrl]): - dyn_gen += -1j * ctrl_amp * ctrl_op - return self._dyn_gen - - def _compute_derivative_directions( - self) -> List[List[q_mat.OperatorMatrix]]: - """ - The directions of the frechet derivatives are the control operators. - - No deep copy is required because the result is not used for in-place - operations. - - """ - # The list is multiplied (copied by reference) because the elements - # will not be manipulated in place. (only as copy) - return [[operator * -1j for operator in self.h_ctrl], ] * len(self.transferred_time) - - def _compute_propagation( - self, calculate_propagator_derivatives: Optional[bool] = None) \ - -> None: - """See base class. """ - super()._compute_propagation() - - if self._dyn_gen is None: - self._dyn_gen = self._compute_dyn_gen() - - if calculate_propagator_derivatives is None: - calculate_propagator_derivatives = \ - self.calculate_propagator_derivatives - - # initialize the attributes - self._prop = [None for _ in range(len(self.transferred_time))] - - if calculate_propagator_derivatives: - derivative_directions = self._compute_derivative_directions() - self._derivative_prop = [ - [None for _ in range(len(self.transferred_time))] - for _2 in range(len(self.h_ctrl))] - for t in range(len(self.transferred_time)): - for ctrl in range(len(self.h_ctrl)): - try: - self._prop[t], self._derivative_prop[ctrl][t] \ - = self._dyn_gen[t].dexp( - derivative_directions[t][ctrl], - self.transferred_time[t], - compute_expm=True, method=self.exponential_method, - is_skew_hermitian=self._is_skew_hermitian) - except ValueError: - raise ValueError('The computation has failed with ' - 'a value error. Try another ' - 'exponentiation method.') - else: - for t in range(len(self.transferred_time)): - self._prop[t] = self._dyn_gen[t].exp( - tau=self.transferred_time[t], method=self.exponential_method, - is_skew_hermitian=self._is_skew_hermitian) - - def _compute_propagation_derivatives(self) -> None: - """ - Computes the frechet derivatives of the propagators. - - The derivatives are not returned but cached. Since the function is only - called when no derivatives are cached, the approximation is - prioritised. - """ - if not self.frechet_deriv_approx_method: - self._compute_propagation(calculate_propagator_derivatives=True) - elif self.frechet_deriv_approx_method == 'grape': - if self._prop is None: - self._compute_propagation( - calculate_propagator_derivatives=False) - self._derivative_prop = [[None for _ in range(len(self.h_ctrl))] - for _2 in range(len(self.transferred_time))] - derivative_directions = self._compute_derivative_directions() - for t in range(len(self.transferred_time)): - for ctrl in range(len(self.h_ctrl)): - self._derivative_prop[t][ctrl] = \ - self.transferred_time[t] * derivative_directions[t][ctrl] \ - * self._prop[t] - else: - raise ValueError('Unknown gradient derivative approximation ' - 'method:' - + str(self.frechet_deriv_approx_method)) - - -def _compute_matrix_exponentials(input_dict): - """Computes the propagator of the Schroedinger equation by evaluation of - a matrix exponential. - - Parameters - ---------- - input_dict: dict - Holds the parameters in a single dict, because the function - multiprocessing.Pool.map requires a single input argument. The dict - has the fields time, matrices, method and is_skew_hermitian. See also - _compute_propagator. - - Returns - ------- - exponentials: list of ControlMatrix - A list of the propagators. - - """ - time = input_dict['time'] - matrices = input_dict['matrices'] - method = input_dict['method'] - is_skew_hermitian = input_dict['is_skew_hermitian'] - - exponentials = [None, ] * len(time) - for i, m, t in zip(range(len(matrices)), matrices, time): - exponentials[i] = m.exp( - tau=t, - method=method, - is_skew_hermitian=is_skew_hermitian) - return exponentials - - -class SchroedingerSMonteCarlo(SchroedingerSolver): - r""" - Solves Schroedinger's equation for explicit noise realisations as Monte - Carlo experiment. - - This time slot computer solves the Schroedinger equation explicitly for - concrete noise realizations. The noise traces are generated by an instance - of the Noise Trace Generator Class. Then they can be processed by the - noise amplitude function, before they are multiplied by the noise - hamiltionians. - - Parameters - ---------- - h_noise: List[ControlMatrix], len num_noise_operators - List of noise operators occurring in the Hamiltonian. - - noise_trace_generator: noise.NoiseTraceGenerator - Noise trace generator object. - - processes: int, optional - If an integer is given, then the propagation is calculated in - this number of parallel processes. If 1 then no parallel - computing is applied. If None then cpu_count() is called to use - all cores available. Defaults to 1. - - noise_amplitude_function: Callable[[noise_samples: np.array, - optimization_parameters: np.array, - transferred_parameters: np.array, - control_amplitudes: np.array], np.array] - The noise amplitude function calculated the noisy control amplitudes - corresponding to the noise samples. They recieve 4 keyword arguments - being the noise samples, the optimization parameters, the transferred - optimization parameters and the control amplitudes in this order. - The noise samples are given with the shape (n_samples_per_trace, - n_traces, n_noise_operators), the optimization parameters - (num_x, num_ctrl), the transferred parameters (num_t, num_ctrl) and - the control amplitudes (num_t, num_ctrl). The returned noise amplitudes - should be of the shape (num_t, n_traces, n_noise_operators). - - Attributes - ---------- - h_noise: List[ControlMatrix], len num_noise_operators - List of noise operators occurring in the Hamiltonian. - - noise_trace_generator: noise.NoiseTraceGenerator - Noise trace generator object. - - _dyn_gen_noise: List[List[ControlMatrix]], - shape [[] * num_t] * num_noise_traces - Dynamics generators for the individual noise traces. - - _prop_noise: List[List[ControlMatrix]], - shape [[] * num_t] * num_noise_traces - Propagators for the individual noise traces. - - _fwd_prop_noise: List[List[ControlMatrix]], - shape [[] * (num_t + 1)] * num_noise_traces - Cumulation of the propagators for the individual noise traces. They - describe the forward propagation of the systems state. - - _reversed_prop_noise: List[List[ControlMatrix]], - shape [[] * (num_t + 1)] * num_noise_traces - Cumulation of propagators in reversed order for the individual noise - traces. - - _derivative_prop_noise: List[List[List[ControlMatrix]]], - shape [[[] * num_t] * num_ctrl] * num_noise_traces - Frechet derivatives of the propagators by the control amplitudes for - the individual noise traces. - - Methods - ------- - propagators_noise: List[List[ControlMatrix]], - shape [[] * num_t] * num_noise_traces - Propagators for the individual noise traces. - - forward_propagators_noise: List[List[ControlMatrix]], - shape [[] * (num_t + 1)] * num_noise_traces - Cumulation of the propagators for the individual noise traces. They - describe the forward propagation of the systems state. - - reversed_propagators_noise: List[List[ControlMatrix]], - shape [[] * (num_t + 1)] * num_noise_traces - Cumulation of propagators in reversed order for the individual noise - traces. - - frechet_deriv_propagators_noise: List[List[List[ControlMatrix]]], - shape [[[] * num_t] * num_ctrl] * num_noise_traces - Frechet derivatives of the propagators by the control amplitudes for - the individual noise traces. - - """ - def __init__( - self, h_drift: List[q_mat.OperatorMatrix], - h_ctrl: List[q_mat.OperatorMatrix], - tau: np.array, - h_noise: List[q_mat.OperatorMatrix], - noise_trace_generator: - Optional[noise.NoiseTraceGenerator], - initial_state: q_mat.OperatorMatrix = None, - ctrl_amps: Optional[np.array] = None, - calculate_propagator_derivatives: bool = False, - processes: Optional[int] = 1, - filter_function_h_n: Union[ - Callable, List[List], None] = None, - filter_function_basis: Optional[basis.Basis] = None, - filter_function_n_coeffs_deriv: Optional[ - Callable[[np.ndarray], np.ndarray]] = None, - exponential_method: Optional[str] = None, - frechet_deriv_approx_method: Optional[str] = None, - is_skew_hermitian: bool = True, - transfer_function: Optional[TransferFunction] = None, - amplitude_function: Optional[AmplitudeFunction] = None, - noise_amplitude_function: Optional[Callable[ - [np.array, np.array, np.array, - np.array], np.array]] = None - ): - - super().__init__( - h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, - tau=tau, ctrl_amps=ctrl_amps, - filter_function_h_n=filter_function_h_n, - filter_function_basis=filter_function_basis, - filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, - exponential_method=exponential_method, - calculate_propagator_derivatives=calculate_propagator_derivatives, - frechet_deriv_approx_method=frechet_deriv_approx_method, - is_skew_hermitian=is_skew_hermitian, - transfer_function=transfer_function, - amplitude_function=amplitude_function) - - self.h_noise = h_noise - self.noise_trace_generator = noise_trace_generator - self.noise_amplitude_function = noise_amplitude_function - self.processes = processes - - self._dyn_gen_noise = None - self._prop_noise = None - self._derivative_prop_noise = None - self._fwd_prop_noise = None - self._reversed_prop_noise = None - - def set_optimization_parameters(self, y: np.array) -> None: - """See base class. """ - if not np.array_equal(self._opt_pars, y): - self.reset_cached_propagators() - super().set_optimization_parameters(y) - - def reset_cached_propagators(self): - """See base class. """ - super().reset_cached_propagators() - self._dyn_gen_noise = None - self._prop_noise = None - self._derivative_prop_noise = None - self._fwd_prop_noise = None - self._reversed_prop_noise = None - - - @property - def propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: - """ - Returns the propagators of the system for each noise trace and - calculates them if necessary. - - Returns - ------- - propagators_noise: List[List[ControlMatrix]], - shape [[] * num_t] * num_noise_traces - Propagators of the system for each noise trace. - - """ - if self._prop_noise is None: - self._compute_propagation() - return self._prop_noise - - @property - def forward_propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: - """ - Returns the forward propagation of the initial state for every time - slice and every noise trace and calculate it if necessary. If the - initial state is the identity matrix, then the cumulative propagators - are given. The element forward_propagators[k][i] propagates a state by - the first i time steps under the kth noise trace, if the initial state - is the identity matrix. - - Returns - ------- - forward_propagation:List[List[ControlMatrix]], - shape [[] * (num_t + 1)] * num_noise_traces - Propagation of the initial state of the system. fwd[0] gives the - initial state itself. - - """ - if self._fwd_prop_noise is None: - self._compute_forward_propagation() - return self._fwd_prop_noise - - @property - def frechet_deriv_propagators_noise(self) \ - -> List[List[List[q_mat.OperatorMatrix]]]: - """ - Returns the frechet derivatives of the propagators with respect to the - control amplitudes for each noise trace. - - Returns - ------- - derivative_prop_noise: List[List[List[ControlMatrix]]], - shape [[[] * num_t] * num_ctrl] * num_noise_traces - Frechet derivatives of the propagators by the control amplitudes. - - """ - if self._derivative_prop_noise is None: - self._compute_propagation_derivatives() - return self._derivative_prop_noise - - @property - def reversed_propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: - """ - Returns the reversed propagation of the initial state for every noise - trace and calculate it if necessary. If the initial state is the - identity matrix, then the reversed cumulative propagators are given. - The element forward_propagators[k][i] propagates a state by the first i - time steps under the kth noise trace, if the initial state is the - identity matrix. - - Returns - ------- - reversed_propagation_noise: List[List[ControlMatrix]], - shape [[] * (num_t + 1)] * num_noise_traces - Propagation of the initial state of the system. reversed[k][0] - gives the initial state itself. - - """ - if self._reversed_prop_noise is None: - self._compute_reversed_propagation() - return self._reversed_prop_noise - - def _compute_dyn_gen_noise(self) -> List[List[q_mat.OperatorMatrix]]: - """ - Computes the dynamics generators for the perturbed and unperturbed - Schroedinger equation. - - Returns - ------- - dyn_gen_noise: List[List[q_mat.ControlMatrix]], - shape [[] * num_t] * num_noise_traces - Dynamics generators for each noise trace. - - """ - # compute the generators of the unperturbed dynamics - self._dyn_gen = super()._compute_dyn_gen() - - # compute the generators for the noise traces. - n_noise_traces = self.noise_trace_generator.n_traces - - noise_samples = self.noise_trace_generator.noise_samples - # we transpose, so we iterate over the time last - noise_samples = np.transpose(noise_samples, (2, 1, 0)) - - if self.noise_amplitude_function: - noise_samples = self.noise_amplitude_function( - noise_samples=noise_samples, - optimization_parameters=self._opt_pars, - transferred_parameters=self.transferred_parameters, - control_amplitudes=self._ctrl_amps - ) - - self._dyn_gen_noise = [[dyn_gen.copy() for dyn_gen in self._dyn_gen] - for _ in range(n_noise_traces)] - - for t, sample_stack in enumerate(noise_samples): - for n_trace, trace in enumerate(sample_stack): - for operator_sample, operator in zip(trace, self.h_noise): - self._dyn_gen_noise[n_trace][t] += \ - (-1j * operator_sample) * operator - return self._dyn_gen_noise - - def _compute_propagation( - self, calculate_propagator_derivatives: Optional[bool] = None - ) -> None: - """ - Computes the propagators for the perturbed Schroedinger equation and - the derivatives on demand. - - Parameters - ---------- - calculate_propagator_derivatives: bool, optional - Calculate the derivatives of the propagators with respect to the - control amplitudes if true. - - """ - - if self._dyn_gen_noise is None: - self._dyn_gen_noise = self._compute_dyn_gen_noise() - - n_noise_traces = self.noise_trace_generator.n_traces - num_t = len(self.transferred_time) - num_ctrl = len(self.h_ctrl) - - self._prop_noise = [[None for _ in range(num_t)] - for _2 in range(n_noise_traces)] - - if calculate_propagator_derivatives is None: - calculate_propagator_derivatives = \ - self.calculate_propagator_derivatives - - # parallelization of following code probably unnecessary - if calculate_propagator_derivatives: - self._derivative_prop_noise = \ - [[[None for _ in range(num_t)] - for _2 in range(num_ctrl)] - for _3 in range(n_noise_traces)] - derivative_directions = self._compute_derivative_directions() - - # call the parent method for the noiseless propagators - super()._compute_propagation( - calculate_propagator_derivatives=calculate_propagator_derivatives) - - if self.processes == 1: - if calculate_propagator_derivatives: - for k in range(n_noise_traces): - for t in range(num_t): - for ctrl in range(len(self.h_ctrl)): - self._prop_noise[k][t], \ - self._derivative_prop_noise[k][ctrl][t] \ - = self._dyn_gen_noise[k][t].dexp( - derivative_directions[t][ctrl], - self.transferred_time[t], - compute_expm=True, - method=self.exponential_method, - is_skew_hermitian=self._is_skew_hermitian) - else: - for k in range(n_noise_traces): - for t in range(num_t): - self._prop_noise[k][t] = self._dyn_gen_noise[k][t].exp( - tau=self.transferred_time[t], - method=self.exponential_method, - is_skew_hermitian=self._is_skew_hermitian) - - elif (type(self.processes) == int and self.processes > 0) \ - or self.processes is None: - - if calculate_propagator_derivatives: - raise NotImplementedError - else: - input_dicts = [] - for k in range(n_noise_traces): - input_dicts.append(dict()) - input_dicts[-1]['time'] = self.transferred_time - input_dicts[-1]['matrices'] = self._dyn_gen_noise[k] - input_dicts[-1]['method'] = self.exponential_method - input_dicts[-1][ - 'is_skew_hermitian'] = self._is_skew_hermitian - - with Pool(processes=self.processes) as pool: - self._prop_noise = pool.map( - _compute_matrix_exponentials, input_dicts) - - else: - raise ValueError('Invalid number of processes for parallel ' - 'computation!') - - def _compute_forward_propagation(self) -> None: - """Computes the forward propagators. """ - super()._compute_forward_propagation() - if self._prop_noise is None: - self._compute_propagation() - - self._fwd_prop_noise = [ - [self.initial_state.copy(), ] - for _ in range(self.noise_trace_generator.n_traces)] - - for fwd_per_trace, prop_per_trace in zip(self._fwd_prop_noise, - self._prop_noise): - for prop in prop_per_trace: - fwd_per_trace.append(prop * fwd_per_trace[-1]) - - def _compute_reversed_propagation(self) -> None: - """Compute the reversed propagation. For the perturbed and unperturbed - Schroedinger equation. """ - super()._compute_reversed_propagation() - if self._prop_noise is None: - self._compute_propagation() - - self._reversed_prop_noise = [ - [self._prop[0].identity_like(), ] - for _ in range(self.noise_trace_generator.n_traces)] - - for rev_per_trace, prop_per_trace in zip(self._reversed_prop_noise, - self._prop_noise): - for prop in prop_per_trace[::-1]: - rev_per_trace.append(rev_per_trace[-1] * prop) - - def _compute_propagation_derivatives(self) -> None: - """ - Computes the frechet derivatives of the propagators. - - The derivatives are not returned but cached. Since the function is only - called when no derivatives are cached, the approximation is - prioritised. - """ - if not self.frechet_deriv_approx_method: - self._compute_propagation(calculate_propagator_derivatives=True) - elif self.frechet_deriv_approx_method == 'grape': - super()._compute_propagation_derivatives() - - if self._prop_noise is None: - self._compute_propagation( - calculate_propagator_derivatives=False) - - n_noise_traces = self.noise_trace_generator.n_traces - num_t = len(self.transferred_time) - num_ctrl = len(self.h_ctrl) - - self._derivative_prop_noise = [ - [[None for _ in range(num_t)] - for _2 in range(num_ctrl)] - for _3 in range(n_noise_traces)] - - derivative_directions = self._compute_derivative_directions() - - for k in range(n_noise_traces): - for t in range(len(self.transferred_time)): - for ctrl in range(num_ctrl): - self._derivative_prop_noise[k][ctrl][t] = \ - self.transferred_time[t] * derivative_directions[t][ctrl] \ - * self._prop_noise[k][t] - else: - raise ValueError('Unknown gradient derivative approximation ' - 'method:' - + str(self.frechet_deriv_approx_method)) - - -class SchroedingerSMCControlNoise(SchroedingerSMonteCarlo): - """ - Convenience class like `SchroedingerSMonteCarlo` but with noise on the - optimization parameters. - - This time slot computer solves the Schroedinger equation explicitly for - concrete control noise realizations. This time slot computer assumes, - that the noise is sampled on the time scale of the already transferred - optimization parameters. The control Hamiltionians are also used as noise - Hamiltionians and the noise amplitude function adds the noise samples to - the unperturbed transferred optimization parameters and applies the - amplitude function of the control amplitudes. - - """ - def __init__( - self, - h_drift: List[q_mat.OperatorMatrix], - h_ctrl: List[q_mat.OperatorMatrix], - tau: np.array, - noise_trace_generator: - Optional[noise.NoiseTraceGenerator], - initial_state: q_mat.OperatorMatrix = None, - ctrl_amps: Optional[np.array] = None, - calculate_propagator_derivatives: bool = False, - processes: Optional[int] = 1, - filter_function_h_n: Union[ - Callable, List[List], None] = None, - filter_function_basis: Optional[basis.Basis] = None, - filter_function_n_coeffs_deriv: Optional[ - Callable[[np.ndarray], np.ndarray]] = None, - exponential_method: Optional[str] = None, - frechet_deriv_approx_method: Optional[str] = None, - is_skew_hermitian: bool = True, - transfer_function: Optional[TransferFunction] = None, - amplitude_function: Optional[AmplitudeFunction] = None): - - def noise_amplitude_function(noise_samples: np.array, - transferred_parameters: np.array, - control_amplitudes: np.array, - **_): - """Calculates the noise amplitudes. - - Takes into account the actual optimization parameters and random - variations. - - Parameters - ---------- - noise_samples: np.array - Noise samples calculated by the noise trace generator. - - transferred_parameters: np.array - Transferred optimization parameters. - - control_amplitudes: np.array - Control amplitudes. - - """ - noise_amplitudes = np.zeros((noise_samples.shape[0],noise_samples.shape[1],control_amplitudes.shape[1]), dtype=complex) - # complex values were requested. - for trace_num in range(noise_samples.shape[1]): - noise_amplitudes[:, trace_num, :] = self.amplitude_function( - transferred_parameters + noise_samples[:, trace_num, :]) \ - - control_amplitudes - return noise_amplitudes - - super().__init__( - h_drift=h_drift, - h_ctrl=h_ctrl, - initial_state=initial_state, - tau=tau, - h_noise=h_ctrl, - noise_trace_generator=noise_trace_generator, - ctrl_amps=ctrl_amps, - calculate_propagator_derivatives=calculate_propagator_derivatives, - processes=processes, - filter_function_h_n=filter_function_h_n, - filter_function_basis=filter_function_basis, - filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, - exponential_method=exponential_method, - frechet_deriv_approx_method=frechet_deriv_approx_method, - is_skew_hermitian=is_skew_hermitian, - transfer_function=transfer_function, - amplitude_function=amplitude_function, - noise_amplitude_function=noise_amplitude_function - ) - - -class LindbladSolver(SchroedingerSolver): - r""" - Solves a master equation for an open quantum system in the Markov - approximation using the Lindblad super operator formalism. - - The master equation to be solved is - - .. math:: - - d \rho / dt = i [\rho, H] + \sum_k (L_k \rho L_k^\dagger - - .5 L_k^\dagger L_k \rho - .5 \rho L_k^\dagger L_k) - - - with the Lindblad operators L_k. The solution is calculated as - - .. math:: - - \rho(t) = exp[(-i \mathcal{H} + \mathcal{G})t] \rho(0) - - with the dissipative super operator - - .. math:: - - \mathcal{G} = \sum_k D(L_k) - - .. math:: - - D(L) = L^\ast \otimes L - .5 I \otimes (L^\dagger L) - - .5 (L^T L^\ast) \otimes I - - The dissipation super operator can be given in three different ways. - - 1. A nested list of dissipation super operators D(L_k) as control - matrices. - 2. A nested list of Lindblad operators L as control matrices. - 3. A function handle receiving the control amplitudes as sole argument and - returning a dissipation super operator as list of control matrices. - - Optionally a prefactor function can be given for 1. and 2. This function - receives the control parameters and returns an array of the shape - num_t x num_l where num_t is the number of time steps in the control and - num_l is the number of Lindblad operators or dissipation super operators. - - If multiple construction arguments are given, the implementation - prioritises the function (3.) over the Lindblad operators (2.) over the - dissipation super operator (1.). - - Parameters - ---------- - initial_diss_super_op: List[ControlMatrix], len num_l - Initial dissipation super operator; num_l is the number of - Lindbladians. Set if you want to use (1.) (See documentation above!). - The control matrices are expected to be of shape (dim, dim) where dim - is the dimension of the system. - - lindblad_operators: List[ControlMatrix], len num_l - Lindblad operators; num_l is the number of Lindbladians. Set if you - want to use (2.) (See documentation above!). The Lindblad operators are - assumend to be of shape (dim, dim) where dim is the dimension of the - system. - - prefactor_function: Callable[[np.array, np.array], np.array] - Receives the control amplitudes u (as numpy array of shape - (num_t, num_ctrl)) and the transferred optimization parameters (as - numpy array of shape (num_t, num_opt)) and returns prefactors as numpy - array of shape (num_t, num_l). The prefactors a_k are used as weights in - the sum of the total dissipation operator. - - .. math:: - - \mathcal{G} = \sum_k a_k * D(L_k) - - If the Lindblad operator is for example given by a complex number b_k - times a constant (in time) matrix C_k. - - .. math:: - - L_k = b_k * C_k - - Then the prefactor is the squared absolute value of this number: - - .. math:: - - a_k = |b_k|^2 - - Set if you want to use method (1.) or (2.). (See class documentation.) - - prefactor_derivative_function: Callable[[np.array, np.array], np.array] - Receives the control amplitudes u (as numpy array of shape - (num_t, num_ctrl)) and the transferred optimization parameters (as - numpy array of shape (num_t, num_opt)) and returns the derivatives of - the prefactors as numpy array of shape (num_t, num_ctrl, num_l). The - derivatives d_k are used as weights in the sum of the derivative of the - total dissipation operator. - - .. math:: - - d \mathcal{G} / d u_k = \sum_k d_k * D(L_k) - - If the Lindblad operator is for example given by a complex number b_k - times a constant (in time) matrix C_k. And this number depends on the - control amplitudes u_k - - .. math:: - - L_k = b_k (u_k) * C_k - - Then the derivative of the prefactor is the derivative of the squared - absolute value of this number: - - .. math:: - - d_k = d |b_k|^2 / d u_k - - Set if you want to use method (1.) or (2.). (See class documentation.) - - super_operator_function: Callable[[np.array, np.array], List[ControlMatrix]] - Receives the control amlitudes u (as numpy array of shape - (num_t, num_ctrl)) and the transferred optimization parameters (as - numpy array of shape (num_t, num_opt)) and returns the total dissipation - operators as list of length num_t. Set if you want to use method (3.). - (See class documentation.) - - super_operator_derivative_function: Callable[[np.array, np.array], - List[List[ControlMatrix]]] - Receives the control amlitudes u (as numpy array of shape - (num_t, num_ctrl)) and the transferred optimization parameters (as - numpy array of shape (num_t, num_opt)) and returns the derivatives of - the total dissipation operators as nested list of - shape [[] * num_ctrl] * num_t. Set if you - want to use method (3.). (See class documentation.) - - is_skew_hermitian: bool - If True, then the total dynamics generator is assumed to be skew - hermitian. - - Attributes - ---------- - _diss_sup_op: List[ControlMatrix], len num_t - Total dissipaton super operator. - - _diss_sup_op_deriv: List[List[ControlMatrix]], - shape [[] * num_ctrl] * num_t - Derivative of the total dissipation operator with respect to the - control amplitudes. - - _initial_diss_super_op: List[ControlMatrix], len num_l - Initial dissipation super operator; num_l is the number of - Lindbladians. - - _lindblad_operatorsList[ControlMatrix], len num_l - Lindblad operators; num_l is the number of Lindbladians. - - _prefactor_function: Callable[[np.array], np.array] - Receives the control amplitudes u (as numpy array of shape - (num_t, num_ctrl)) and returns prefactors as numpy array - of shape (num_t, num_l). The prefactors a_k are used as weights in the - sum of the total dissipation operator. - - .. math:: - - \mathcal{G} = \sum_k a_k * D(L_k) - - If the Lindblad operator is for example given by a complex number b_k - times a constant (in time) matrix C_k. - - .. math:: - - L_k = b_k * C_k - - Then the prefactor is the squared absolute value of this number: - - .. math:: - - a_k = |b_k|^2 - - Set if you want to use method (1.) or (2.). (See class documentation.) - - _prefactor_deriv_function: Callable[[np.array], np.array] - Receives the control amplitudes u (as numpy array of shape - (num_t, num_ctrl)) and returns the derivatives of the - prefactors as numpy array of shape (num_t, num_ctrl, num_l). The - derivatives d_k are used as weights in the sum of the derivative of the - total dissipation operator. - - .. math:: - - d \mathcal{G} / d u_k = \sum_k d_k * D(L_k) - - If the Lindblad operator is for example given by a complex number b_k - times a constant (in time) matrix C_k. And this number depends on the - control amplitudes u_k - - .. math:: - - L_k = b_k (u_k) * C_k - - Then the derivative of the prefactor is the derivative of the squared - absolute value of this number: - - .. math:: - - d_k = d |b_k|^2 / d u_k - - _sup_op_func: Callable[[np.array], List[ControlMatrix]] - Receives the control amplitudes u (as numpy array of shape - (num_t, num_ctrl)) and returns the total dissipation - operators as list of length num_t. - - _sup_op_deriv_func: Callable[[np.array], List[List[ControlMatrix]]] - Receives the control amplitudes u (as numpy array of shape - (num_t, num_ctrl)) and returns the derivatives of the total dissipation - operators as nested list of shape [[] * num_ctrl] * num_t. - - Methods - ------- - _parse_dissipative_super_operator: None - - _calc_diss_sup_op: List[ControlMatrix] - Calculates the total dissipation super operator. - - _calc_diss_sup_op_deriv: Optional[List[List[ControlMatrix]]] - Calculates the derivatives of the total dissipation super operators - with respect to the control amplitudes. - - `Todo` - * Write parser - - """ - - def __init__( - self, - h_drift: List[q_mat.OperatorMatrix], - h_ctrl: List[q_mat.OperatorMatrix], - tau: np.array, - initial_state: q_mat.OperatorMatrix = None, - ctrl_amps: Optional[np.array] = None, - calculate_unitary_derivatives: bool = False, - filter_function_h_n: Union[ - Callable, List[List], None] = None, - filter_function_basis: Optional[basis.Basis] = None, - filter_function_n_coeffs_deriv: Optional[ - Callable[[np.ndarray], np.ndarray]] = None, - exponential_method: Optional[str] = None, - frechet_deriv_approx_method: Optional[str] = None, - initial_diss_super_op: List[q_mat.OperatorMatrix] = None, - lindblad_operators: List[q_mat.OperatorMatrix] = None, - prefactor_function: Callable[[np.array, np.array], np.array] = None, - prefactor_derivative_function: - Callable[[np.array, np.array], np.array] = None, - super_operator_function: - Callable[[np.array, np.array], List[q_mat.OperatorMatrix]] = None, - super_operator_derivative_function: - Callable[[np.array, np.array], - List[List[q_mat.OperatorMatrix]]] = None, - is_skew_hermitian: bool = False, - transfer_function: Optional[TransferFunction] = None, - amplitude_function: Optional[AmplitudeFunction] = None) \ - -> None: - - if initial_state is None: - dim = h_ctrl[0].shape[0] - initial_state = type(h_ctrl[0])(np.eye(dim ** 2)) - - self._diss_sup_op = None - self._diss_sup_op_deriv = None - - # we do not throw away any operators or functions, just in case - self._initial_diss_super_op = initial_diss_super_op - self._lindblad_operators = lindblad_operators - self._prefactor_function = prefactor_function - self._prefactor_deriv_function = prefactor_derivative_function - self._sup_op_func = super_operator_function - self._sup_op_deriv_func = super_operator_derivative_function - self._is_hermitian = is_skew_hermitian - - super().__init__( - h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, - tau=tau, ctrl_amps=ctrl_amps, - calculate_propagator_derivatives=calculate_unitary_derivatives, - filter_function_h_n=filter_function_h_n, - filter_function_basis=filter_function_basis, - filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, - exponential_method=exponential_method, - frechet_deriv_approx_method=frechet_deriv_approx_method, - is_skew_hermitian=is_skew_hermitian, - transfer_function=transfer_function, - amplitude_function=amplitude_function) - - def set_optimization_parameters(self, y: np.array) -> None: - """See base class. """ - if not np.array_equal(self._opt_pars, y): - super().set_optimization_parameters(y) - self.reset_cached_propagators() - - def reset_cached_propagators(self): - """ See base class. """ - super().reset_cached_propagators() - if self._prefactor_function is not None \ - or self._sup_op_func is not None: - self._diss_sup_op = None - self._diss_sup_op_deriv = None - - - def _calc_diss_sup_op(self) -> List[q_mat.OperatorMatrix]: - r""" - Calculates the dissipative super operator as described in the class - doc string. - - Returns - ------- - diss_sup_op: List[ControlMatrix], len num_l - Dissipation super operator; Where num_l is the number of Lindblad - terms. - - """ - if self._sup_op_func is None: - # use Lindblad operators - if self._lindblad_operators is None: - # use dissipation_sup_op - const_diss_sup_op = self._initial_diss_super_op - else: - # Calculate the time constant dissipation super operators - # without time dependence - const_diss_sup_op = [] - identity = self._lindblad_operators[0].identity_like() - - for lindblad in self._lindblad_operators: - const_diss_sup_op.append( - (lindblad.conj(do_copy=True)).kron(lindblad)) - const_diss_sup_op[-1] -= .5 * identity.kron( - lindblad.dag(do_copy=True) * lindblad) - const_diss_sup_op[-1] -= .5 * ( - lindblad.transpose(do_copy=True) - * lindblad.conj(do_copy=True)).kron(identity) - - # Add the time dependence - if self._prefactor_function is not None: - self._diss_sup_op = [] - prefactors = self._prefactor_function( - copy.deepcopy(self._ctrl_amps), - copy.deepcopy(self.transferred_parameters)) - for factor_at_time_t in prefactors: - self._diss_sup_op.append( - const_diss_sup_op[0] * factor_at_time_t[0]) - for sup_op, factor \ - in zip(const_diss_sup_op[1:], - factor_at_time_t[1:]): - self._diss_sup_op[-1] += sup_op * factor - else: - self._diss_sup_op = [const_diss_sup_op[0], ] - for sup_op in const_diss_sup_op[1:]: - self._diss_sup_op[0] += sup_op - self._diss_sup_op *= len(self.transferred_time) - else: - self._diss_sup_op = self._sup_op_func( - copy.deepcopy(self._ctrl_amps), - copy.deepcopy(self.transferred_parameters)) - return self._diss_sup_op - - def _calc_diss_sup_op_deriv(self) \ - -> Optional[List[List[q_mat.OperatorMatrix]]]: - r""" - Calculates the derivatives of the dissipation super operator with - respect to the control amplitudes. - - If the dissipation super operator is given as constant (1.) or as - lindblad operators (2.) they are assumed not to depend on the control - parameters and only the derivative of the prefactor is to be taken into - account. In order to do so, a function handle containing the - derivatives must be given. This function receives the control - amplitudes as num_t x num_ctrl numpy array and returns the derivatives - as num_t x num_l x num_ctrl array. - - If the dissipation super operator is given as function handle (3.), - then the derivatives must also be given as function handle receiving - the control amplitudes and returning a nested list of super operators - as control matrices. - - If the requested derivative functions are not provided (None), then - the dissipation super operator is considered constant in the control - amplitudes and the function returns None. - - Returns - ------- - diss_sup_op_deriv: Optional[List[List[q_mat.ControlMatrix]]], - shape [[] * num_ctrl] * num_t - The derivatives of the dissipation super operator with respect to - the control variables. - - """ - - if self._sup_op_deriv_func is not None: - self._diss_sup_op_deriv = \ - self._sup_op_deriv_func( - copy.deepcopy(self._ctrl_amps), - copy.deepcopy(self.transferred_parameters)) - return self._diss_sup_op_deriv - - elif self._prefactor_deriv_function is not None: - if self._lindblad_operators is None: - # use dissipation_sup_op - const_diss_sup_op = self._initial_diss_super_op - else: - # Calculate the time constant dissipation super operators - # without time dependence - const_diss_sup_op = [] - identity = self._lindblad_operators[0].identity_like() - - for lindblad in self._lindblad_operators: - const_diss_sup_op.append( - (lindblad.conj(do_copy=True)).kron(lindblad)) - const_diss_sup_op[-1] -= .5 * identity.kron( - lindblad.dag(do_copy=True) * lindblad) - const_diss_sup_op[-1] -= .5 * ( - lindblad.transpose(do_copy=True) - * lindblad.conj(do_copy=True)).kron(identity) - - prefactor_derivatives = \ - self._prefactor_deriv_function( - copy.deepcopy(self._ctrl_amps), - copy.deepcopy(self.transferred_parameters)) - - # Todo: Assert that the prefactor returns the right dimension - - # prefactor_derivatives: shape (num_t, num_ctrl, num_l) - diss_sup_op_deriv = [] - for factor_per_ctrl_lind in prefactor_derivatives: - # create new sub list for eacht time step - diss_sup_op_deriv.append([]) - for factor_per_lind in factor_per_ctrl_lind: - # add the first term for each control direction - diss_sup_op_deriv[-1].append( - const_diss_sup_op[0] * factor_per_lind[0]) - for diss_sup_op, factor in zip( - const_diss_sup_op[1:], factor_per_lind[1:]): - # add the remaining terms - diss_sup_op_deriv[-1][-1] += diss_sup_op * factor - self._diss_sup_op_deriv = diss_sup_op_deriv - return diss_sup_op_deriv - else: - return None - - def _compute_derivative_directions( - self) -> List[List[q_mat.OperatorMatrix]]: - r""" - Computes the derivative directions of the total dynamics generator. - - Returns - ------- - deriv_directions: List[List[q_mat.ControlMatrix]], - shape [[] * num_ctrl] * num_t - Derivative directions given by - - .. math:: - - -1j * (I \otimes H_k - H_k \otimes I) + d \mathcal{G} / d u_k - - """ - # derivative of the coherent part - identity_times_i = self.h_ctrl[0].identity_like() - identity_times_i *= -1j - h_ctrl_sup_op = [] - for ctrl_op in self.h_ctrl: - h_ctrl_sup_op.append(identity_times_i.kron(ctrl_op)) - h_ctrl_sup_op[-1] -= (ctrl_op.transpose(do_copy=True)).kron( - identity_times_i) - - # add derivative of the dissipation part - if self._diss_sup_op_deriv is None: - self._diss_sup_op_deriv = self._calc_diss_sup_op_deriv() - if self._diss_sup_op_deriv is not None: - dh_by_ctrl = [] - for diss_sup_op_deriv_at_t in self._diss_sup_op_deriv: - dh_by_ctrl.append([]) - for diss_sup_op_deriv, ctrl_sup_op \ - in zip(diss_sup_op_deriv_at_t, h_ctrl_sup_op): - dh_by_ctrl[-1].append(diss_sup_op_deriv + ctrl_sup_op) - else: - dh_by_ctrl = [h_ctrl_sup_op, ] * len(self.transferred_time) - - return dh_by_ctrl - - def _parse_dissipative_super_operator(self) -> None: - r""" - check the dissipative super operator for dimensional consistency - (maybe even physical properties) - - not implemented yet - - """ - pass - - def _compute_dyn_gen(self) -> List[q_mat.OperatorMatrix]: - r""" - Computes the dynamics generator for the Lindblad master equation. - - The Hamiltonian is translated into the master equation formalism as - - .. math:: - - \mathcal{H} = I \otimes H - H^\ast \otimes I - - Then the dissipation super operator is added. - - Returns - ------- - dyn_gen: List[ControlMatrix], len num_t - Dynamics generators for the master equation. - - Raises - ------ - ValueError: - The computation is only defined for the use of dense control - matrices. - - """ - self._dyn_gen = super()._compute_dyn_gen() - - if self._diss_sup_op is None: - self._diss_sup_op = self._calc_diss_sup_op() - - identiy_operator = self._dyn_gen[0].identity_like() - sup_op_dyn_gen = [] - - assert(len(self._dyn_gen) == len(self._diss_sup_op)) - - for dyn_gen, diss_sup_op in zip(self._dyn_gen, self._diss_sup_op): - sup_op_dyn_gen.append(identiy_operator.kron(dyn_gen)) - # the cancelling minus sign accounts for the -i factor, which is - # also conjugated (included in the dyn gen) - sup_op_dyn_gen[-1] += dyn_gen.conj(do_copy=True).kron( - identiy_operator) - sup_op_dyn_gen[-1] += diss_sup_op - - self._dyn_gen = sup_op_dyn_gen - return sup_op_dyn_gen - - -class LindbladSControlNoise(LindbladSolver): - """ - Special case of the Lindblad master equation. It considers white noise on - the control parameters. The same functionality should be implementable - with the parent class, but less convenient. - """ - - @needs_refactoring - def __init__(self, h_drift, h_ctrl, initial_state, tau, - ctrl_amps, transfer_function=None, - calculate_unitary_derivatives=True, filter_function_h_n=None, - exponential_method=None, lindblad_operators=None, - constant_lindblad_operators=False, noise_psd=1): - super().__init__( - h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, - tau=tau, ctrl_amps=ctrl_amps, - calculate_unitary_derivatives=calculate_unitary_derivatives, - filter_function_h_n=filter_function_h_n, - exponential_method=exponential_method) - - if lindblad_operators is None: - self.lindblad_super_operator = None - else: - d = lindblad_operators[0].shape[0] - self.lindblad_super_operator = np.zeros( - (len(lindblad_operators), d**2, d**2)) - for i, l in enumerate(lindblad_operators): - self.lindblad_super_operator[i, :, :] += np.kron(np.conj(l), l) - self.lindblad_super_operator[i, :, :] += -.5 * np.kron( - np.eye(d), l.T.conj() @ l) - self.lindblad_super_operator[i, :, :] += -.5 * np.kron( - l.T @ l.conj(), np.eye(d)) - - self.transfer_function = transfer_function - # if no transfer function is given it might be consider to be identity - # its not necessarily required - - self.constant_lindblad_operators = constant_lindblad_operators - self.noise_psd = noise_psd - self.incoherent_dyn_gen = None - - def _compute_propagation(self): - """ - - """ - # Compute and cache all dyn_gen (basically the total hamiltonian) - self._dyn_gen = copy.deepcopy(self.h_drift) - self._dyn_gen += np.sum(self._ctrl_amps * self.h_ctrl, axis=1) - - # initialize the attributes - self._prop = [None] * self.num_t - self._dU = np.array(shape=(self.num_t, self.num_ctrl), - dtype=matrix.DenseOperator) - self._fwd = [self.initial_state] - - # super operator calculation - # this is the special case for charge noise on the control parameters - # the required filter function contains - if not self.constant_lindblad_operators or \ - self.incoherent_dyn_gen is None: - transfer_matrix = self.transfer_function.transfer_matrix - self.incoherent_dyn_gen = np.einsum('ijk,klm,k->ilm', - transfer_matrix, - self.lindblad_super_operator, - self.noise_psd) - dim = self._dyn_gen[0].shape[0] - for i, gen in enumerate(self._dyn_gen): - gen = -1j * np.kron( - np.eye(dim), gen.data) - np.kron(gen.data, np.eye(dim)) - gen += self.incoherent_dyn_gen[i, :, :] - gen = matrix.DenseOperator(gen) - - # calculation of the propagators - for t in range(len(self.num_t)): - if self.calculate_propagator_derivatives: - for ctrl in range(self.num_ctrl): - direction = np.kron( - np.eye(dim), self.h_ctrl[t][ctrl]) - np.kron( - self.h_ctrl[t][ctrl], np.eye(dim)) - self._prop[t], self._dU[t, ctrl] = self._dyn_gen[t].dexp( - direction=direction, tau=self.transferred_time[t], - compute_expm=True, method=self.exponential_method) - - else: - self._prop[t] = self._dyn_gen[t].exp( - tau=self.transferred_time[t], method=self.exponential_method) - - self._fwd.append(self._prop[t] * self._fwd[t]) - - self.prop_calculated = True diff --git a/qopt/transfer_function.py b/qopt/transfer_function.py index ff8a21c..d1ccc38 100644 --- a/qopt/transfer_function.py +++ b/qopt/transfer_function.py @@ -116,6 +116,7 @@ from qopt.util import deprecated, needs_refactoring + class TransferFunction(ABC): """ A class for representing transfer functions, between optimization @@ -1713,467 +1714,3 @@ def set_times(self, times): super().set_times(times) # TODO: properly implement 'w' - - -############################################################################### - -try: - import jax.numpy as jnp - from jax import vmap - _HAS_JAX = True -except ImportError: - from unittest import mock - jnp = mock.Mock() - vmap = mock.Mock() - _HAS_JAX = False - -class TransferFunctionJAX(TransferFunction): - """See docstring of class w/o JAX.""" - - def __init__(self, - num_ctrls: int = 1, - bound_type: Optional[Tuple[str, int]] = None, - oversampling: int = 1, - offset: Optional[float] = None - ): - if not _HAS_JAX: - raise ImportError("JAX not available") - super().__init__(num_ctrls,bound_type,oversampling,offset) - - @abstractmethod - def __call__(self, y: Union[np.array,jnp.array]) -> jnp.array: - """Calculate the transferred optimization parameters (x). - - Evaluates the transfer function at the raw optimization parameters (y) - to calculate the transferred optimization parameters (x). - - Parameters - ---------- - y: Union[np.array,jnp.array], shape (num_y, num_par) - Raw optimization variables; num_y is the number of time slices of - the raw optimization parameters and num_par is the number of - distinct raw optimization parameters. - - Returns - ------- - u: jnp.array, shape (num_x, num_par) - Control parameters; num_u is the number of times slices for the - transferred optimization parameters. - - """ - pass - - @property - def num_padding_elements(self) -> (int, int): - """ - Convenience function. Returns the number of elements padded to the - beginning and the end of the control amplitude times. - - Returns - ------- - num_padding_elements: (int, int) - (elements padded to the beginning, elements padded to the end) - - """ - if self.bound_type is None: - return 0, 0 - elif self.bound_type[0] == 'n': - return self.bound_type[1], self.bound_type[1] - elif self.bound_type[0] == 'x': - return self.bound_type[1] * self.oversampling, \ - self.bound_type[1] * self.oversampling - elif self.bound_type[0] == 'right_n': - return 0, self.bound_type[1] - else: - raise ValueError('Unknown bound type ' + str(self.bound_type[0])) - - @abstractmethod - def gradient_chain_rule( - self, deriv_by_transferred_par: Union[np.array,jnp.array] - ) -> jnp.array: - """ - Obtain the derivatives of a quantity a i.e. da/dy by the optimization - variables from the derivatives by the amplitude of the control fields. - - The chain rule applies: df/dy = df/dx * dx/dy. - - Parameters - ---------- - deriv_by_transferred_par: Union[np.array,jnp.array], - shape (num_x, num_f, num_par) - The gradients of num_f functions by num_par optimization parameters - at num_x different time steps. - - Returns - ------- - deriv_by_opt_par: np.array, shape: (num_y, num_f, num_par) - The derivatives by the optimization parameters at num_y time steps. - - """ - pass - - def set_times(self, y_times: Union[np.array,jnp.array]) -> None: - """ - Generate the time_slot duration array 'transferred_time' - (here: x_times). - - The time slices depend on the oversampling of the control variables - and the boundary conditions. The times are for the intended use cases - only set once. - - Parameters - ---------- - y_times: Union[np.ndarray, jnp.ndarray, list], shape (num_y) - The time steps / durations of constant optimization variables. - num_y is the number of time steps for the raw optimization - variables. - - """ - if isinstance(y_times, list): - y_times = jnp.array(y_times) - if not isinstance(y_times, (np.ndarray,jnp.ndarray)): - raise Exception("times must be a list or (j)np.array") - - y_times = jnp.atleast_1d(jnp.squeeze(y_times)) - - if len(y_times.shape) > 1: - raise ValueError('The x_times should not have more than one ' - 'dimension!') - - self._num_y = y_times.size - self._y_times = y_times - - if self.bound_type is None: - self.num_x = self.oversampling * self._num_y - self.x_times = jnp.repeat( - self._y_times, self.oversampling) / self.oversampling - - elif self.bound_type[0] == 'n': - self.num_x = self.oversampling * self._num_y + 2 \ - * self.bound_type[1] - self.x_times = jnp.concatenate(( - self._y_times[0] / self.oversampling - * jnp.ones(self.bound_type[1]), - jnp.repeat( - self._y_times / self.oversampling, self.oversampling), - self._y_times[-1] / self.oversampling - * jnp.ones(self.bound_type[1]))) - - elif self.bound_type[0] == 'x': - self.num_x = self.oversampling * (self._num_y - + 2 * self.bound_type[1]) - self.x_times = jnp.concatenate(( - self._y_times[0] / self.oversampling - * jnp.ones(self.bound_type[1] * self.oversampling), - jnp.repeat(self._y_times / self.oversampling, - self.oversampling), - self._y_times[-1] / self.oversampling - * jnp.ones(self.bound_type[1] * self.oversampling))) - - elif self.bound_type[0] == 'right_n': - self.num_x = self.oversampling * self._num_y + self.bound_type[1] - self.x_times = np.concatenate(( - jnp.repeat(self._y_times / self.oversampling, - self.oversampling), - self._y_times[-1] / self.oversampling - * jnp.ones(self.bound_type[1]))) - - else: - raise ValueError('The boundary type ' + str(self.bound_type[0]) - + ' is not implemented!') - - def set_absolute_times(self, - absolute_y_times: Union[np.array,jnp.array,list] - ) -> None: - """ - Generate the time_slot duration array 'transferred_time' - (here: x_times) - - This time slices depend on the oversampling of the control variables - and the boundary conditions. The differences of the absolute times - give the time steps x_times. - - Parameters - ---------- - absolute_y_times: Union[np.array,jnp.array,list] - Absolute times of the start / end of each time segment for the raw - optimization parameters. - - """ - if isinstance(absolute_y_times, list): - absolute_y_times = jnp.array(absolute_y_times) - if not isinstance(absolute_y_times, Union[np.array,jnp.array]): - raise Exception("times must be a list or (j)np.array") - if not jnp.all(jnp.diff(absolute_y_times) >= 0): - raise Exception("times must be sorted") - - self._absolute_y_times = absolute_y_times - self.set_times(jnp.diff(absolute_y_times)) - - def plot_pulse(self, y: Union[np.array,jnp.array]) -> None: - """ - - Plot the control amplitudes corresponding to the given optimisation - variables. - - Parameters - ---------- - y: array, shape (num_y, num_par) - Raw optimization parameters. - - """ - - x = self(y) - #plotting not good with jnp(?) - x, y = np.array(x), np.array(y) - n_padding_start, n_padding_end = self.num_padding_elements - for y_per_control, x_per_control in zip(y.T, x.T): - plt.figure() - plt.bar(np.cumsum(self.x_times) - .5 * self.x_times[0], - x_per_control, self.x_times[0]) - plt.bar(np.cumsum(self._y_times) - .5 * self._y_times[0] - + np.cumsum(self._y_times)[n_padding_start] - - self._y_times[n_padding_start], - y_per_control, self._y_times[0], - fill=False) - plt.show() - - -class IdentityTFJAX(TransferFunctionJAX): - """See docstring of class w/o JAX.""" - - def __init__(self, num_ctrls=1): - super().__init__( - bound_type=None, - oversampling=1, - num_ctrls=num_ctrls, - offset=0. - ) - self.name = 'Identity' - - def __call__(self, y: Union[np.array,jnp.array]) -> jnp.array: - """See base class. """ - return jnp.asarray(y) - - def gradient_chain_rule( - self, deriv_by_transferred_par: Union[np.array,jnp.array] - ) -> jnp.array: - """See base class. """ - return jnp.asarray(deriv_by_transferred_par) - - -class OversamplingTFJAX(TransferFunctionJAX): - """See docstring of class w/o JAX.""" - - def __init__(self, - num_ctrls: int = 1, - bound_type: Optional[Tuple[str, int]] = None, - oversampling: int = 1 - ): - super().__init__( - num_ctrls=num_ctrls, - bound_type=bound_type, - oversampling=oversampling - ) - - def _calculate_transfer_matrix(self): - """Overrides the base class method. """ - raise NotImplementedError - - def __call__(self, y: Union[np.array,jnp.array]) -> jnp.array: - """Calculate the transferred optimization parameters (x). - - Only the oversampling and boundaries are taken into account. - - Parameters - ---------- - y: Union[np.array,jnp.array], shape (num_y, num_par) - Raw optimization variables; num_y is the number of time slices of - the raw optimization parameters and num_par is the number of - distinct raw optimization parameters. - - Returns - ------- - u: jnp.array, shape (num_x, num_par) - Control parameters; num_u is the number of times slices for the - transferred optimization parameters. - - """ - # oversample pulse by repetition - u = jnp.repeat(y, self.oversampling, axis=0) - - # add the padding elements - padding_start, padding_end = self.num_padding_elements - - u = jnp.concatenate( - (jnp.zeros((padding_start, self.num_ctrls)), - u, - jnp.zeros((padding_end, self.num_ctrls))), axis=0) - - return u - - def gradient_chain_rule( - self, deriv_by_transferred_par: Union[np.array,jnp.array] - ) -> jnp.array: - """ - See base class. - - Processing without transfer matrix. - - Parameters - ---------- - deriv_by_transferred_par: Union[np.array,jnp.array], - shape (num_x, num_f, num_par) - The gradients of num_f functions by num_par optimization parameters - at num_x different time steps. - - Returns - ------- - deriv_by_opt_par: jnp.array, shape: (num_y, num_f, num_par) - The derivatives by the optimization parameters at num_y time steps. - - """ - - shape = deriv_by_transferred_par.shape - assert len(shape) == 3 - assert shape[0] == self.num_x - assert shape[2] == self.num_ctrls - - # delete the padding elements - padding_start, padding_end = self.num_padding_elements - - # deriv_by_ctrl_amps: shape (num_x, num_f, num_par) - if padding_end > 0: - cropped_derivs = deriv_by_transferred_par[ - padding_start:-padding_end, :, :] - else: - cropped_derivs = deriv_by_transferred_par[ - padding_start:, :, :] - - cropped_derivs = jnp.expand_dims(cropped_derivs, axis=1) - cropped_derivs = jnp.reshape( - cropped_derivs, ( - self._num_y, - self.oversampling, - cropped_derivs.shape[2], - cropped_derivs.shape[3] - ) - ) - deriv_by_opt_par = jnp.sum(cropped_derivs, axis=1) - return deriv_by_opt_par - - -#### - -class LinearInterpTFJAX(TransferFunctionJAX): - """See docstring of class w/o JAX.""" - - def __init__(self, - num_ctrls: int = 1, - bound_type: Optional[Tuple[str, int]] = None, - oversampling: int = 1 - ): - super().__init__( - num_ctrls=num_ctrls, - bound_type=bound_type, - oversampling=oversampling - ) - - def _calculate_transfer_matrix(self): - """Overrides the base class method. """ - raise NotImplementedError - - def __call__(self, y: Union[np.array,jnp.array]) -> jnp.array: - """Calculate the transferred optimization parameters (x). - - Only the oversampling and boundaries are taken into account. - - Parameters - ---------- - y: Union[np.array,jnp.array], shape (num_y, num_par) - Raw optimization variables; num_y is the number of time slices of - the raw optimization parameters and num_par is the number of - distinct raw optimization parameters. - - Returns - ------- - u: jnp.array, shape (num_x, num_par) - Control parameters; num_u is the number of times slices for the - transferred optimization parameters. - - """ - # oversample pulse by repetition - # u = jnp.repeat(y, self.oversampling, axis=0) - - x_arr_old, x_arr_new = \ - jnp.linspace(0,y.shape[0],y.shape[0],endpoint=False), \ - jnp.linspace(0,y.shape[0],y.shape[0]*self.oversampling,endpoint=False) - #as coded now has base at beginning of time interval - u = jnp.moveaxis(vmap(jnp.interp,in_axes=(None,None,1))(x_arr_new,x_arr_old,y),0,1) - - # add the padding elements - #TODO: not implemented as not used so far - if self.num_padding_elements[0] != 0 or self.num_padding_elements[1] != 0: - raise NotImplementedError - # padding_start, padding_end = self.num_padding_elements - - # u = jnp.concatenate( - # (jnp.zeros((padding_start, self.num_ctrls)), - # u, - # jnp.zeros((padding_end, self.num_ctrls))), axis=0) - - return u - - def gradient_chain_rule( - self, deriv_by_transferred_par: Union[np.array,jnp.array] - ) -> jnp.array: - """ - See base class. - - Processing without transfer matrix. - - Parameters - ---------- - deriv_by_transferred_par: Union[np.array,jnp.array], - shape (num_x, num_f, num_par) - The gradients of num_f functions by num_par optimization parameters - at num_x different time steps. - - Returns - ------- - deriv_by_opt_par: jnp.array, shape: (num_y, num_f, num_par) - The derivatives by the optimization parameters at num_y time steps. - - """ - - shape = deriv_by_transferred_par.shape - - assert len(shape) == 3 - assert shape[0] == self.num_x - assert shape[2] == self.num_ctrls - # assert self.num_x//self.oversampling > 3 #to avoid complications - # assert self.x//self.oversampling == - - # delete the padding elements - if self.num_padding_elements[0] != 0 or self.num_padding_elements[1] != 0: - raise NotImplementedError - # padding_start, padding_end = self.num_padding_elements - m_arr = jnp.arange(0,self.oversampling)/self.oversampling - len_m = len(m_arr) - - deriv_by_opt_par = np.empty((self.num_x//self.oversampling,shape[1],shape[2])) - - - deriv_by_opt_par[0,:,:] = jnp.sum(deriv_by_transferred_par[0:self.oversampling]*(1-m_arr[:,np.newaxis,np.newaxis]),axis=0) - - deriv_by_opt_par[self.num_x//self.oversampling-1,:,:] = jnp.sum(deriv_by_transferred_par[self.oversampling*(self.num_x//self.oversampling-2):self.oversampling*(self.num_x//self.oversampling -1)]*m_arr[:,np.newaxis,np.newaxis],axis=0) - - - #slow but less memory consumption to avoid y*x shape - for i in range(1,self.num_x//self.oversampling -1): - deriv_by_opt_par[i,:,:] = jnp.sum(deriv_by_transferred_par[self.oversampling*(i-1):self.oversampling*i]*m_arr[:,np.newaxis,np.newaxis],axis=0) +\ - jnp.sum(deriv_by_transferred_par[self.oversampling*i:self.oversampling*(i+1)]*(1-m_arr[:,np.newaxis,np.newaxis]),axis=0) - - - # deriv_by_opt_par = jnp.sum(cropped_derivs, axis=1) - return jnp.asarray(deriv_by_opt_par) \ No newline at end of file