-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #14 from Ericgig/JaxHigherClass
Higher level classes with jax
- Loading branch information
Showing
10 changed files
with
598 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,8 @@ install_requires = | |
qutip>=5.0.0.dev0 | ||
jax | ||
jax[cpu] | ||
diffrax | ||
equinox | ||
setup_requires = | ||
packaging | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import diffrax | ||
from qutip.solver.integrator import Integrator | ||
import jax | ||
import jax.numpy as jnp | ||
from qutip.solver.mesolve import MESolver | ||
from qutip.solver.sesolve import SESolver | ||
from qutip.core import data as _data | ||
import numpy as np | ||
from qutip_jax import JaxArray | ||
from qutip_jax.qobjevo import JaxQobjEvo | ||
|
||
__all__ = [] | ||
|
||
|
||
@jax.jit | ||
def _cplx2float(arr): | ||
return jnp.stack([arr.real, arr.imag]) | ||
|
||
|
||
@jax.jit | ||
def _float2cplx(arr): | ||
return arr[0] + 1j * arr[1] | ||
|
||
|
||
class DiffraxIntegrator(Integrator): | ||
method: str = "diffrax" | ||
supports_blackbox: bool = False # No feedback support | ||
support_time_dependant: bool = True | ||
integrator_options: dict = { | ||
"dt0": 0.0001, | ||
"solver": diffrax.Tsit5(), | ||
"stepsize_controller": diffrax.ConstantStepSize(), | ||
"max_steps": 100000, | ||
} | ||
|
||
def __init__(self, system, options): | ||
self.system = JaxQobjEvo(system) | ||
self._is_set = False # get_state can be used and return a valid state. | ||
self._options = self.integrator_options.copy() | ||
self.options = options | ||
self.ODEsystem = diffrax.ODETerm(self.dstate) | ||
self.solver_state = None | ||
self.name = f"{self.method}: {self.options['solver']}" | ||
|
||
def _prepare(self): | ||
pass | ||
|
||
@staticmethod | ||
def dstate(t, y, args): | ||
state = _float2cplx(y) | ||
H, kwargs = args | ||
d_state = H.matmul_data(t, JaxArray(state), **kwargs) | ||
return _cplx2float(d_state._jxa) | ||
|
||
def set_state(self, t, state0): | ||
self.solver_state = None | ||
self.t = t | ||
if not isinstance(state0, JaxArray): | ||
state0 = _data.to(JaxArray, state0) | ||
self.state = _cplx2float(state0._jxa) | ||
self._is_set = True | ||
|
||
def get_state(self, copy=False): | ||
return self.t, JaxArray(_float2cplx(self.state)) | ||
|
||
def integrate(self, t, copy=False, **kwargs): | ||
sol = diffrax.diffeqsolve( | ||
self.ODEsystem, | ||
t0=self.t, | ||
t1=t, | ||
y0=self.state, | ||
saveat=diffrax.SaveAt(t1=True, solver_state=True), | ||
solver_state=self.solver_state, | ||
args=(self.system, kwargs), | ||
**self._options, | ||
) | ||
self.t = t | ||
self.state = sol.ys[0, :] | ||
self.solver_state = sol.solver_state | ||
return self.get_state() | ||
|
||
def arguments(self, args): | ||
self.system = self.system.arguments(args) | ||
self.solver_state = None | ||
|
||
def _flatten(self): | ||
children = ( | ||
self.system, | ||
self._options, | ||
self.solver_state, | ||
) | ||
if self._is_set: | ||
children += (self.t, self.state,) | ||
aux_data = { | ||
"_is_set": self._is_set, | ||
} | ||
return (children, aux_data) | ||
|
||
@classmethod | ||
def _unflatten(cls, aux_data, children): | ||
out = cls.__new__(cls) | ||
out.system = children[0] | ||
out._options = children[1] | ||
out.solver_state = children[2] | ||
out._is_set = aux_data["_is_set"] | ||
if out._is_set: | ||
out.t = children[3] | ||
out.state = children[4] | ||
out.ODEsystem = diffrax.ODETerm(out.dstate) | ||
return out | ||
|
||
@property | ||
def options(self): | ||
""" | ||
Supported options by diffrax method: | ||
dt0 : float, default=0.0001 | ||
Initial step size. | ||
solver: AbstractSolver, default=Tsit5(), | ||
ODE solver instance from diffrax. | ||
stepsize_controller: AbstractStepSizeController, default=ConstantStepSize() | ||
Step size controller from diffrax. | ||
max_steps: int, default=100000 | ||
Maximum number of steps for the integration. | ||
""" | ||
return self._options | ||
|
||
@options.setter | ||
def options(self, new_options): | ||
Integrator.options.fset(self, new_options) | ||
|
||
|
||
MESolver.add_integrator(DiffraxIntegrator, 'diffrax') | ||
SESolver.add_integrator(DiffraxIntegrator, 'diffrax') | ||
jax.tree_util.register_pytree_node( | ||
DiffraxIntegrator, DiffraxIntegrator._flatten, DiffraxIntegrator._unflatten | ||
) |
Oops, something went wrong.