In [5]:
from jax import vmap, jit
import jax.numpy as jnp
from jax.scipy.linalg import expm as jexpm

from functools import partial

from qopt.solver_algorithms import SchroedingerSolver

class JaxSchroedingerSolver(SchroedingerSolver):
    """Schroedinger solver that builds its propagators with JAX.

    For every time slice the Hamiltonian ``h_drift + sum_k c_k * h_ctrl[k]``
    is assembled and exponentiated as ``expm(-1j * H * dt)``. The sequence of
    slices is processed with ``vmap`` and compiled with ``jit``.

    NOTE(review): assumes ``self.h_drift`` is a single (d, d) matrix,
    ``self.h_ctrl`` is stackable into an (n_ctrl, d, d) array and
    ``self._ctrl_amps`` has one row of n_ctrl amplitudes per time slice --
    TODO confirm against how ``SchroedingerSolver`` stores these attributes.
    """

    @staticmethod
    def _expm(matrix):
        """Return the matrix exponential of ``matrix`` (jax.scipy.linalg.expm)."""
        return jexpm(matrix)

    @staticmethod
    def _ham_from_arrays(h_drift, h_ctrl, control_values):
        """Return ``h_drift + sum_k control_values[k] * h_ctrl[k]``.

        Args:
            h_drift: Drift Hamiltonian.
            h_ctrl: Control Hamiltonians stacked along axis 0.
            control_values: One amplitude per control Hamiltonian.

        Returns:
            The total Hamiltonian as a JAX array.
        """
        # Contract the control axis in one call instead of materialising the
        # stack of scaled control Hamiltonians and summing it afterwards.
        return h_drift + jnp.tensordot(control_values, h_ctrl, axes=1)

    def ham(self, control_values):
        """Return the total Hamiltonian for one set of control amplitudes.

        Args:
            control_values: One amplitude per control Hamiltonian.
        """
        return self._ham_from_arrays(
            self.h_drift, jnp.asarray(self.h_ctrl), control_values)

    def _ham_seq(self, control_seq):
        """Return the Hamiltonian of every time slice (mapped over axis 0)."""
        return vmap(self.ham)(control_seq)

    def _time_step_unitary(self, ham_matrix, dt):
        """Return the propagator ``expm(-1j * ham_matrix * dt)`` of one slice."""
        return self._expm(-1j * ham_matrix * dt)

    def _unitary_seq_from_ham_seq(self, hamiltonians, dt_s):
        """Return one propagator per (Hamiltonian, time step) pair."""
        return vmap(self._time_step_unitary)(hamiltonians, dt_s)

    @partial(jit, static_argnums=0)
    def _traced_unitary_seq(self, h_drift, h_ctrl, control_seq, dt_s):
        """Compiled kernel behind ``_unitary_seq``.

        ``self`` is a static argument, so the jit cache is keyed on the
        instance's hash/equality. Any instance attribute read here would be
        frozen into the first compilation. ``self`` is therefore used only to
        dispatch to methods that read no instance state. All data (including
        ``h_drift`` and ``h_ctrl``) enters as traced arguments, so later
        changes to the solver's Hamiltonians are respected.
        """
        hamiltonians = vmap(
            lambda values: self._ham_from_arrays(h_drift, h_ctrl, values)
        )(control_seq)
        return self._unitary_seq_from_ham_seq(hamiltonians, dt_s)

    def _unitary_seq(self, control_seq, dt_s):
        """Return the propagators of all time slices.

        Args:
            control_seq: Control amplitudes, one row per time slice.
            dt_s: Duration of each time slice.

        Returns:
            A JAX array holding one propagator per time slice.
        """
        # jnp.asarray lets plain Python lists (e.g. of time steps) be mapped
        # over by vmap; it is a no-op for arrays that are already JAX arrays.
        return self._traced_unitary_seq(
            jnp.asarray(self.h_drift),
            jnp.asarray(self.h_ctrl),
            jnp.asarray(control_seq),
            jnp.asarray(dt_s),
        )

    def _compute_propagation(self):
        """Store the propagators for the current controls in ``self._prop``."""
        self._prop = self._unitary_seq(self._ctrl_amps, self.transferred_time)

ModuleNotFoundError: No module named 'jax'

In [None]:
import jax