# qutip-jax: JAX backend for qutip

JAX is a NumPy-like library that runs on CPUs, GPUs, and TPUs and supports automatic differentiation.
qutip-jax lets a `Qobj` store its data as a JAX array, which allows qutip to run on a GPU.

This backend works with all qutip functions, but some of them may silently convert the data to another format. For example, the scipy ODE solvers convert the state to a numpy array.

Support for `jit` and `grad` with qutip's functions is experimental. When using the right options, it is possible to run `mesolve` and `sesolve` on GPU with both compilation and auto-differentiation working. Many `Qobj` operations are also supported.

In [None]:
import jax
import qutip
import qutip_jax  # noqa: F401

The JAX backend is activated by importing the `qutip_jax` module.
This adds the `jax` and `jaxdia` formats to the data types known to qutip:
- `"jax"` stores the data as a dense JAX array.
- `"jaxdia"` stores sparse arrays in diagonal (DIA) format.
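
To make the diagonal layout concrete, here is a minimal sketch in plain NumPy. It is not how qutip-jax implements `jaxdia`; the helper names `dense_to_dia` and `dia_matvec` are hypothetical and only illustrate the general idea of keeping the non-zero diagonals together with their offsets.

```python
import numpy as np


def dense_to_dia(mat):
    """Collect the non-zero diagonals of a square matrix with their offsets."""
    n = mat.shape[0]
    offsets, diags = [], []
    for k in range(-n + 1, n):
        d = np.diagonal(mat, offset=k)
        if np.any(d != 0):
            offsets.append(k)
            diags.append(d.copy())
    return offsets, diags


def dia_matvec(offsets, diags, vec):
    """Multiply a matrix stored as (offsets, diags) by a vector."""
    out = np.zeros(len(vec), dtype=np.result_type(vec, *diags))
    for k, d in zip(offsets, diags):
        if k >= 0:
            # Entry (i, i + k) multiplies vec[i + k] and adds to out[i].
            out[: len(d)] += d * vec[k:]
        else:
            # Entry (j - k, j) multiplies vec[j] and adds to out[j - k].
            out[-k:] += d * vec[: len(d)]
    return out


# Truncated lowering operator: only the first superdiagonal is non-zero.
a = np.diag(np.sqrt(np.arange(1.0, 3.0)), k=1)
offsets, diags = dense_to_dia(a)
vec = np.array([1.0, 2.0, 3.0])
result = dia_matvec(offsets, diags, vec)
```

Operators such as `destroy` or `num` have a single non-zero diagonal, which is why a diagonal format can be much more compact than a dense array for them.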

In [None]:
# Creating jax Qobj using the dtype argument
id_jax = qutip.qeye(3, dtype="jax")
id_jax.data_as("JaxArray")

In [None]:
# Creating jax Qobj using a context manager
with qutip.CoreOptions(default_dtype="jaxdia"):
    identity = qutip.qeye(3)  # named `identity` to avoid shadowing the builtin `id`
    a = qutip.destroy(3)

# Creating jax Qobj using manual conversion
sz = qutip.sigmaz().to("jaxdia")
sx = qutip.sigmax().to("jaxdia")

# Once created, most operations will preserve the data format
op = (sz & a) + (sx & identity)
op

In [None]:
# Many functions operate on the data without converting their output to numpy
qutip.expect(op, qutip.rand_dm([2, 3], dtype="jax"))

`jit` can be used with most linear algebra functions:

In [None]:
op = qutip.num(3, dtype="jaxdia")
state = qutip.rand_dm(3, dtype="jax")


@jax.jit
def f(op, state):
    return op @ state @ op.dag()


print(f(op, state))
%timeit op @ state @ op.dag()
%timeit f(op, state)
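
When timing jitted code, two general JAX behaviours are worth keeping in mind: the first call to a jitted function includes tracing and compilation, and JAX dispatches work asynchronously. The sketch below uses raw JAX arrays rather than `Qobj`, and the function name `sandwich` is hypothetical; whether the same care is needed for the `Qobj` version above is an assumption, not something verified here.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def sandwich(op, state):
    # Same pattern as `op @ state @ op.dag()`, written for plain arrays.
    return op @ state @ op.conj().T


op = jnp.diag(jnp.arange(3.0))  # number operator on 3 levels
state = jnp.eye(3) / 3          # maximally mixed state

# Warm-up call: triggers tracing and compilation for these shapes and dtypes.
warm = sandwich(op, state).block_until_ready()

# Timed call: block_until_ready() waits for the asynchronous computation.
start = time.perf_counter()
out = sandwich(op, state).block_until_ready()
elapsed = time.perf_counter() - start
```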

JAX can be used with `mesolve` and `sesolve` in a way that supports `jax.jit` and `jax.grad`, but specific options must be used:
- The ODE solver from diffrax must be used instead of those provided by scipy.
- `normalize_output` must be `False`.
- Coefficients for `QobjEvo` must be jitted functions.
- The `isherm` flag of the `e_ops` must be pre-set.
- The solver class interface (such as `MESolver`) must be used for `jit`.
- `e_data` must be used instead of `expect` for auto-differentiation.
- All operators and states must use the `jax` or `jaxdia` format.

In [None]:
@jax.jit
def fp(t, w):
    return jax.numpy.exp(1j * t * w)


@jax.jit
def fm(t, w):
    return jax.numpy.exp(-1j * t * w)


@jax.jit
def cte(t, A):
    return A


with qutip.CoreOptions(default_dtype="jax"):
    H = qutip.num(10)
    c_ops = [qutip.QobjEvo([qutip.destroy(10), fm], args={"w": 1.0})]

H.isherm  # Precompute the `isherm` flag, since H is also used as an e_op below

solver = qutip.MESolver(
    H, c_ops, options={"method": "diffrax", "normalize_output": False}
)


def final_expect(solver, rho0, t, w):
    result = solver.run(rho0, [0, t], args={"w": w}, e_ops=H)
    return result.e_data[0][-1].real


dfinal_expect_dt = jax.jit(
    jax.grad(final_expect, argnums=[2]), static_argnames=["solver"]
)

# TODO: use dfinal_expect_dt instead of the un-jitted gradient once the qutip-jax bug is fixed
# dfinal_expect_dt(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)
jax.grad(final_expect, argnums=[2])(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)
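
Because `argnums` is given as a list, `jax.grad` returns a tuple with one gradient per listed argument. The pure-JAX sketch below shows this on a toy function; `decayed_population` is a hypothetical closed form for a population decaying at rate `gamma`, and it does not call the solver.

```python
import jax
import jax.numpy as jnp


def decayed_population(n0, gamma, t):
    # Illustrative closed form n0 * exp(-gamma * t); not computed by qutip.
    return n0 * jnp.exp(-gamma * t)


# Differentiate with respect to the third positional argument (t), then jit.
d_dt = jax.jit(jax.grad(decayed_population, argnums=[2]))

# A list for argnums gives a one-element tuple, so unpack it.
(grad_t,) = d_dt(8.0, 1.0, 0.1)

# Analytic derivative for comparison: -gamma * n0 * exp(-gamma * t).
expected = -1.0 * 8.0 * jnp.exp(-0.1)
```

If the cell above is read as unit-rate damping starting from the Fock state with 8 excitations, this closed form is a plausible sanity check for the solver gradient, though that reading is an assumption rather than a result from the notebook.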

In [None]:
qutip.about()