Merge pull request #25
merge (undocumented) conveyor simulation changes
Nomos11 committed Apr 30, 2024
2 parents 7902a45 + 92cbcd5 commit b9c4046
Showing 11 changed files with 8,361 additions and 99 deletions.
9 changes: 9 additions & 0 deletions qopt/__init__.py
@@ -82,3 +82,12 @@
__version__ = '1.3'
__license__ = 'GNU GPLv3+'
__author__ = 'Julian Teske, Forschungszentrum Juelich'


try:
    from jax.config import config
    config.update("jax_enable_x64", True)
    # TODO: add new objects here / import other stuff?
    # __all__ += []
except ImportError:
    pass
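As an aside, a minimal illustration (assuming JAX is installed) of what the jax_enable_x64 flag changes; the array values are arbitrary:

import jax.numpy as jnp

# With jax_enable_x64 active, 64-bit dtypes are preserved; without it,
# JAX silently downcasts to 32-bit types.
x = jnp.arange(3, dtype=jnp.float64)
print(x.dtype)  # float64 if the flag was applied, float32 otherwise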
125 changes: 124 additions & 1 deletion qopt/amplitude_functions.py
@@ -64,10 +64,11 @@
"""

from abc import ABC, abstractmethod
from typing import Callable
from typing import Callable, Optional

import numpy as np

from typing import Union

class AmplitudeFunction(ABC):
    """Abstract Base class of the amplitude function. """
@@ -218,3 +219,125 @@ def derivative_by_chain_rule(self, deriv_by_ctrl_amps: np.ndarray,
        # return: shape (time, func, par)

        return np.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps)
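A small shape check of this chain-rule contraction (random values, purely illustrative; the variable names follow the comments above):

import numpy as np

n_time, n_par, n_ctrl, n_func = 4, 3, 2, 5
du_by_dx = np.random.rand(n_time, n_par, n_ctrl)             # du/dx
deriv_by_ctrl_amps = np.random.rand(n_time, n_func, n_ctrl)  # dJ/du
deriv_by_opt_pars = np.einsum(
    'imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps)
assert deriv_by_opt_pars.shape == (n_time, n_func, n_par)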


###############################################################################

try:
    import jax.numpy as jnp
    from jax import jit, vmap, jacfwd
    _HAS_JAX = True
except ImportError:
    from unittest import mock
    jit, vmap, jacfwd = mock.Mock(), mock.Mock(), mock.Mock()
    jnp = mock.Mock()
    _HAS_JAX = False


class IdentityAmpFuncJAX(AmplitudeFunction):
    """See docstring of class without JAX.
    Designed to return jax-numpy-arrays.
    """

    def __init__(self):
        if not _HAS_JAX:
            raise ImportError("JAX not available")

    def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """See base class. """
        return jnp.asarray(x)

    def derivative_by_chain_rule(
            self,
            deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray],
            x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """See base class. """
        return jnp.asarray(deriv_by_ctrl_amps)


class UnaryAnalyticAmpFuncJAX(AmplitudeFunction):
    """See docstring of class without JAX.
    Designed to return jax-numpy-arrays.
    The functions need to be jit-compatible, which requires them to be
    pure (i.e. their output depends solely on their input).
    """

    def __init__(self,
                 value_function: Callable[[float, ], float],
                 derivative_function: Callable[[float, ], float]):
        if not _HAS_JAX:
            raise ImportError("JAX not available")
        self.value_function = jit(jnp.vectorize(value_function))
        self.derivative_function = jit(jnp.vectorize(derivative_function))

    def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """See base class. """
        return jnp.asarray(self.value_function(x))

    def derivative_by_chain_rule(
            self,
            deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], x):
        """See base class. """
        du_by_dx = self.derivative_function(x)
        # du_by_dx shape: (n_time, n_ctrl)
        # deriv_by_ctrl_amps shape: (n_time, n_func, n_ctrl)
        # deriv_by_opt_par shape: (n_time, n_func, n_ctrl)
        # since the function is unary we have n_ctrl = n_amps
        return jnp.einsum('ij,ikj->ikj', du_by_dx, deriv_by_ctrl_amps)
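A minimal usage sketch of UnaryAnalyticAmpFuncJAX as added above; the sine amplitude function, its analytic derivative, and the array shapes are purely illustrative:

import jax.numpy as jnp

amp_func = UnaryAnalyticAmpFuncJAX(
    value_function=lambda x: jnp.sin(x),
    derivative_function=lambda x: jnp.cos(x))

x = jnp.linspace(0., 1., 10).reshape(10, 1)   # (n_time, n_ctrl)
u = amp_func(x)                               # element-wise sine, (10, 1)
dJ_du = jnp.ones((10, 1, 1))                  # (n_time, n_func, n_ctrl)
dJ_dx = amp_func.derivative_by_chain_rule(dJ_du, x)   # (10, 1, 1)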


class CustomAmpFuncJAX(AmplitudeFunction):
    """See docstring of class without JAX.
    Designed to return jax-numpy-arrays.
    The functions need to be jit-compatible, which requires them to be
    pure (i.e. their output depends solely on their input).
    If derivative_function=None, the derivative is computed via automatic
    differentiation (jax.jacfwd).
    t_to_vectorize: set to True if value_function / derivative_function are
    not yet vectorized over the time dimension (num_t).
    """

    def __init__(
            self,
            value_function: Callable[[Union[np.ndarray, jnp.ndarray], ],
                                     Union[np.ndarray, jnp.ndarray]],
            derivative_function: Optional[
                Callable[[Union[np.ndarray, jnp.ndarray], ],
                         Union[np.ndarray, jnp.ndarray]]],
            t_to_vectorize: bool = False
    ):
        if not _HAS_JAX:
            raise ImportError("JAX not available")
        if t_to_vectorize:
            self.value_function = jit(vmap(value_function, in_axes=(0,)))
        else:
            self.value_function = jit(value_function)
        if derivative_function is not None:
            if t_to_vectorize:
                self.derivative_function = jit(
                    vmap(derivative_function, in_axes=(0,)))
            else:
                self.derivative_function = jit(derivative_function)
        else:
            # Fall back to forward-mode automatic differentiation.
            if t_to_vectorize:
                def der_wrapper(x):
                    return jnp.swapaxes(
                        vmap(jacfwd(lambda x: value_function(x)),
                             in_axes=(0,))(x), 1, 2)
            else:
                def der_wrapper(x):
                    return jnp.swapaxes(
                        vmap(jacfwd(lambda x: value_function(
                            jnp.expand_dims(x, axis=0))[0, :]),
                             in_axes=(0,))(x), 1, 2)
            self.derivative_function = jit(der_wrapper)

    def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """See base class. """
        return jnp.asarray(self.value_function(x))

    def derivative_by_chain_rule(
            self,
            deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray],
            x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
        """See base class. """
        du_by_dx = self.derivative_function(x)
        # du_by_dx: shape (time, par, ctrl)
        # deriv_by_ctrl_amps: shape (time, func, ctrl)
        # return: shape (time, func, par)

        return jnp.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps)
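A minimal usage sketch of CustomAmpFuncJAX exercising the autodiff path (derivative_function=None); the quadratic mapping, the name my_value_function, and the shapes are purely illustrative:

import jax.numpy as jnp

def my_value_function(x):   # x: (n_time, n_par) -> u: (n_time, n_amps)
    return jnp.stack([x[:, 0] ** 2, x[:, 0] * x[:, 1]], axis=1)

amp_func = CustomAmpFuncJAX(
    value_function=my_value_function,
    derivative_function=None)

x = jnp.ones((5, 2))                          # (n_time, n_par)
u = amp_func(x)                               # (5, 2)
dJ_du = jnp.ones((5, 3, 2))                   # (n_time, n_func, n_amps)
dJ_dx = amp_func.derivative_by_chain_rule(dJ_du, x)   # (5, 3, 2)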
