Skip to content

Commit

Permalink
Merge pull request #14 from Ericgig/JaxHigherClass
Browse files Browse the repository at this point in the history
Higher level classes with jax
  • Loading branch information
Ericgig committed Jun 6, 2023
2 parents 2a1fed5 + 1325111 commit 635af75
Show file tree
Hide file tree
Showing 10 changed files with 598 additions and 12 deletions.
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ install_requires =
qutip>=5.0.0.dev0
jax
jax[cpu]
diffrax
equinox
setup_requires =
packaging

Expand Down
3 changes: 3 additions & 0 deletions src/qutip_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@
from .properties import *
from .linalg import *
from .create import *
from .qobjevo import *
from .ode import *
from .qutip_trees import *
8 changes: 5 additions & 3 deletions src/qutip_jax/binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ def sub_jaxarray(left, right):

def mul_jaxarray(matrix, value):
"""Multiply a matrix element-wise by a scalar."""
return JaxArray._fast_constructor(matrix._jxa * value, matrix.shape)
# We don't want to check values type in case jax pass a tracer etc.
# But we want to ensure the output is a matrix, thus don't use the
# fast constructor.
return JaxArray(matrix._jxa * value)


def matmul_jaxarray(left, right, scale=1, out=None):
Expand All @@ -97,8 +100,7 @@ def matmul_jaxarray(left, right, scale=1, out=None):

result = left._jxa @ right._jxa

if scale != 1 or not isinstance(scale, int):
result *= scale
result *= scale

if out is None:
return JaxArray._fast_constructor(result, shape=shape)
Expand Down
38 changes: 30 additions & 8 deletions src/qutip_jax/jaxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@


class JaxArray(Data):
_jxa: jnp.ndarray
shape: tuple

def __init__(self, data, shape=None, copy=None):
jxa = jnp.array(data, dtype=jnp.complex128)

Expand All @@ -24,7 +27,6 @@ def __init__(self, data, shape=None, copy=None):
shape = (1, 1)
if len(shape) == 1:
shape = (shape[0], 1)

if not (
isinstance(shape, tuple)
and len(shape) == 2
Expand All @@ -37,15 +39,11 @@ def __init__(self, data, shape=None, copy=None):
"""Shape must be a 2-tuple of positive ints, but is """
+ repr(shape)
)

if np.prod(shape) != np.prod(data.shape):
raise ValueError("Shape of data does not match argument.")

# if copy:
# # Since jax's arrays are immutable, we could probably skip this.
# data = data.copy()
self._jxa = jxa.reshape(shape)
super().__init__(shape)
Data.__init__(self, shape)

def copy(self):
return self.__class__(self._jxa, copy=True)
Expand All @@ -65,6 +63,24 @@ def adjoint(self):
def trace(self):
return jnp.trace(self._jxa)

def __add__(self, other):
if isinstance(other, JaxArray):
out = self._jxa + other._jxa
return JaxArray._fast_constructor(out, out.shape)
return NotImplemented

def __sub__(self, other):
if isinstance(other, JaxArray):
out = self._jxa - other._jxa
return JaxArray._fast_constructor(out, out.shape)
return NotImplemented

def __matmul__(self, other):
if isinstance(other, JaxArray):
out = self._jxa @ other._jxa
return JaxArray._fast_constructor(out, out.shape)
return NotImplemented

@classmethod
def _fast_constructor(cls, array, shape):
out = cls.__new__(cls)
Expand All @@ -74,12 +90,18 @@ def _fast_constructor(cls, array, shape):

def _tree_flatten(self):
children = (self._jxa,) # arrays / dynamic values
aux_data = {"shape": self.shape} # static values
aux_data = {} # static values
return (children, aux_data)

@classmethod
def _tree_unflatten(cls, aux_data, children):
return cls(*children, **aux_data)
# unflatten should not check data validity
# jax can pass tracer, object, etc.
out = cls.__new__(cls)
out._jxa = children[0]
shape = getattr(out._jxa, "shape", (1,1))
Data.__init__(out, shape)
return out


tree_util.register_pytree_node(
Expand Down
2 changes: 1 addition & 1 deletion src/qutip_jax/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def expect_jaxarray(op, state):
if state._jxa.shape[0] == state._jxa.shape[1]:
out = jnp.sum(op._jxa * state._jxa.T)
else:
out = state._jxa.T.conj() @ op._jxa @ state._jxa
out = (state._jxa.T.conj() @ op._jxa @ state._jxa)[0, 0]
return out


Expand Down
140 changes: 140 additions & 0 deletions src/qutip_jax/ode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import diffrax
from qutip.solver.integrator import Integrator
import jax
import jax.numpy as jnp
from qutip.solver.mesolve import MESolver
from qutip.solver.sesolve import SESolver
from qutip.core import data as _data
import numpy as np
from qutip_jax import JaxArray
from qutip_jax.qobjevo import JaxQobjEvo

__all__ = []


@jax.jit
def _cplx2float(arr):
return jnp.stack([arr.real, arr.imag])


@jax.jit
def _float2cplx(arr):
return arr[0] + 1j * arr[1]


class DiffraxIntegrator(Integrator):
method: str = "diffrax"
supports_blackbox: bool = False # No feedback support
support_time_dependant: bool = True
integrator_options: dict = {
"dt0": 0.0001,
"solver": diffrax.Tsit5(),
"stepsize_controller": diffrax.ConstantStepSize(),
"max_steps": 100000,
}

def __init__(self, system, options):
self.system = JaxQobjEvo(system)
self._is_set = False # get_state can be used and return a valid state.
self._options = self.integrator_options.copy()
self.options = options
self.ODEsystem = diffrax.ODETerm(self.dstate)
self.solver_state = None
self.name = f"{self.method}: {self.options['solver']}"

def _prepare(self):
pass

@staticmethod
def dstate(t, y, args):
state = _float2cplx(y)
H, kwargs = args
d_state = H.matmul_data(t, JaxArray(state), **kwargs)
return _cplx2float(d_state._jxa)

def set_state(self, t, state0):
self.solver_state = None
self.t = t
if not isinstance(state0, JaxArray):
state0 = _data.to(JaxArray, state0)
self.state = _cplx2float(state0._jxa)
self._is_set = True

def get_state(self, copy=False):
return self.t, JaxArray(_float2cplx(self.state))

def integrate(self, t, copy=False, **kwargs):
sol = diffrax.diffeqsolve(
self.ODEsystem,
t0=self.t,
t1=t,
y0=self.state,
saveat=diffrax.SaveAt(t1=True, solver_state=True),
solver_state=self.solver_state,
args=(self.system, kwargs),
**self._options,
)
self.t = t
self.state = sol.ys[0, :]
self.solver_state = sol.solver_state
return self.get_state()

def arguments(self, args):
self.system = self.system.arguments(args)
self.solver_state = None

def _flatten(self):
children = (
self.system,
self._options,
self.solver_state,
)
if self._is_set:
children += (self.t, self.state,)
aux_data = {
"_is_set": self._is_set,
}
return (children, aux_data)

@classmethod
def _unflatten(cls, aux_data, children):
out = cls.__new__(cls)
out.system = children[0]
out._options = children[1]
out.solver_state = children[2]
out._is_set = aux_data["_is_set"]
if out._is_set:
out.t = children[3]
out.state = children[4]
out.ODEsystem = diffrax.ODETerm(out.dstate)
return out

@property
def options(self):
"""
Supported options by diffrax method:
dt0 : float, default=0.0001
Initial step size.
solver: AbstractSolver, default=Tsit5(),
ODE solver instance from diffrax.
stepsize_controller: AbstractStepSizeController, default=ConstantStepSize()
Step size controller from diffrax.
max_steps: int, default=100000
Maximum number of steps for the integration.
"""
return self._options

@options.setter
def options(self, new_options):
Integrator.options.fset(self, new_options)


MESolver.add_integrator(DiffraxIntegrator, 'diffrax')
SESolver.add_integrator(DiffraxIntegrator, 'diffrax')
jax.tree_util.register_pytree_node(
DiffraxIntegrator, DiffraxIntegrator._flatten, DiffraxIntegrator._unflatten
)
Loading

0 comments on commit 635af75

Please sign in to comment.