# qutip-jax: JAX backend for QuTiP

JAX is a NumPy-like library that can run on CPUs, GPUs and TPUs and supports automatic differentiation.
qutip-jax allows JAX arrays to be used to store a `Qobj`'s data, allowing QuTiP to run on GPUs.

This backend works with all QuTiP functions, but some may convert data to another format without warning. For example, using a scipy ODE solver converts the state to a NumPy array.

Support for `jax.jit` and `jax.grad` in QuTiP's functions is experimental. With the right options, `mesolve` and `sesolve` can run on a GPU with both compilation and automatic differentiation working. Many `Qobj` operations are also supported.

In [1]:
import jax
import qutip
import qutip_jax  # noqa: F401

The JAX backend is activated by importing the `qutip_jax` module, which adds the formats `jax` and `jaxdia` to QuTiP's known data types:
- `"jax"` stores the data as a dense JAX array.
- `"jaxdia"` represents sparse arrays in DIAgonal format.
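As a rough illustration of the DIAgonal idea, a square matrix can be stored as a list of diagonal offsets plus one row of values per non-zero diagonal. The NumPy sketch below follows the `scipy.sparse.dia_matrix` indexing convention; the helpers `to_dia` and `from_dia` are hypothetical, and the actual memory layout used by `jaxdia` is not shown here and may differ.

```python
import numpy as np


def to_dia(dense):
    """Hypothetical helper: return (offsets, data) for each non-zero diagonal.

    data[k, j] holds dense[j - offsets[k], j], i.e. values are indexed by column,
    as in scipy.sparse.dia_matrix.
    """
    n = dense.shape[0]
    offsets, rows = [], []
    for off in range(-(n - 1), n):
        diag = np.diagonal(dense, offset=off)
        if np.any(diag != 0):
            row = np.zeros(n, dtype=dense.dtype)
            # The first element of a superdiagonal sits in column `off`
            start = max(off, 0)
            row[start:start + diag.size] = diag
            offsets.append(off)
            rows.append(row)
    return np.array(offsets), np.array(rows)


def from_dia(offsets, data, n):
    """Hypothetical helper: rebuild the dense n x n matrix from (offsets, data)."""
    dense = np.zeros((n, n), dtype=data.dtype)
    for off, row in zip(offsets, data):
        for col in range(max(off, 0), min(n + off, n)):
            dense[col - off, col] = row[col]
    return dense


a = np.diag(np.sqrt([1.0, 2.0]), k=1)  # matrix of qutip.destroy(3)
offsets, data = to_dia(a)
print(offsets)  # [1]: only the first superdiagonal is stored
assert np.allclose(from_dia(offsets, data, 3), a)
```

Operators such as `destroy` and `num` have a single non-zero diagonal, which is why a diagonal format can be far more compact than a dense array for them.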

In [2]:
# Creating jax Qobj using the dtype argument
id_jax = qutip.qeye(3, dtype="jax")
id_jax.data_as("JaxArray")

Array([[1.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 1.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 1.+0.j]], dtype=complex128)

In [3]:
# Creating jax Qobj using a context manager
with qutip.CoreOptions(default_dtype="jaxdia"):
    id = qutip.qeye(3)
    a = qutip.destroy(3)

# Creating jax Qobj using manual conversion
sz = qutip.sigmaz().to("jaxdia")
sx = qutip.sigmax().to("jaxdia")

# Once created, most operations preserve the data format
op = (sz & a) + (sx & id)
op

Quantum object: dims=[[2, 3], [2, 3]], shape=(6, 6), type='oper', dtype=JaxDia, isherm=False
Qobj data =
[[ 0.          1.          0.          1.          0.          0.        ]
 [ 0.          0.          1.41421356  0.          1.          0.        ]
 [ 0.          0.          0.         -0.          0.          1.        ]
 [ 1.          0.          0.          0.         -1.          0.        ]
 [ 0.          1.          0.          0.          0.         -1.41421356]
 [ 0.          0.          1.          0.          0.          0.        ]]
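The printed matrix can be cross-checked with plain NumPy. In QuTiP, `&` is the tensor product, so `op` is sigmaz ⊗ a + sigmax ⊗ I, with the two-level system as the first (outer) factor. This sketch rebuilds it with `np.kron` and does not depend on qutip-jax.

```python
import numpy as np

sz = np.array([[1, 0], [0, -1]], dtype=complex)        # qutip.sigmaz()
sx = np.array([[0, 1], [1, 0]], dtype=complex)         # qutip.sigmax()
a = np.diag(np.sqrt([1.0, 2.0]), k=1).astype(complex)  # qutip.destroy(3)
iden = np.eye(3, dtype=complex)                        # qutip.qeye(3)

# np.kron(A, B) puts A on the outer index, like A & B in QuTiP
op = np.kron(sz, a) + np.kron(sx, iden)
print(op.real)
```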

In [4]:
# Many functions operate without converting their output to NumPy
qutip.expect(op, qutip.rand_dm([2, 3], dtype="jax"))

Array(-0.21467461-0.14428622j, dtype=complex128)
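For a density matrix ρ, `qutip.expect(op, rho)` computes Tr(op ρ). The value above is complex because `op` is not Hermitian; for a Hermitian operator the imaginary part vanishes. A NumPy sketch of that formula, where `random_density_matrix` is a hypothetical helper that does not reproduce how `qutip.rand_dm` samples states:

```python
import numpy as np


def random_density_matrix(n, seed=0):
    """Hypothetical helper: a random n x n density matrix (Hermitian, PSD, unit trace)."""
    rng = np.random.default_rng(seed)
    g = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))
    rho = g @ g.conj().T  # Hermitian and positive semi-definite by construction
    return rho / np.trace(rho)


def expect_dm(op, rho):
    """Expectation value <op> = Tr(op rho) of an operator in the state rho."""
    return np.trace(op @ rho)


rho = random_density_matrix(3)
a = np.diag(np.sqrt([1.0, 2.0]), k=1)  # annihilation operator, not Hermitian
n_op = a.conj().T @ a                  # number operator, Hermitian

print(expect_dm(a, rho))     # complex in general
print(expect_dm(n_op, rho))  # imaginary part is zero up to rounding
```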

`jax.jit` can be used with most linear algebra functions:

In [5]:
op = qutip.num(3, dtype="jaxdia")
state = qutip.rand_dm(3, dtype="jax")


@jax.jit
def f(op, state):
    return op @ state @ op.dag()


print(f(op, state))
%timeit op @ state @ op.dag()
%timeit f(op, state)

Quantum object: dims=[[3], [3]], shape=(3, 3), type='oper', dtype=JaxArray, isherm=True
Qobj data =
[[ 0.        +0.j          0.        +0.j          0.        +0.j        ]
 [ 0.        +0.j          0.59646831+0.j         -0.02138352-0.31528847j]
 [ 0.        +0.j         -0.02138352+0.31528847j  1.06888957+0.j        ]]


74.7 μs ± 266 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


21.5 μs ± 56.1 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
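Because `qutip.num(3)` is diagonal with entries 0, 1, 2, the product `op @ state @ op.dag()` only rescales each element: entry (i, j) of the result is i * j * rho[i, j]. This is why the first row and column of the output above vanish. A NumPy check of that identity, independent of qutip-jax:

```python
import numpy as np

n = np.arange(3.0)  # eigenvalues of the number operator num(3)
num = np.diag(n)

# An arbitrary valid density matrix
rng = np.random.default_rng(42)
g = rng.normal(size=(3, 3)) + 1j * rng.normal(size=(3, 3))
rho = g @ g.conj().T
rho /= np.trace(rho)

full = num @ rho @ num.conj().T        # what f(op, state) computes
elementwise = np.outer(n, n) * rho     # i * j * rho[i, j]
print(np.allclose(full, elementwise))  # True
```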


JAX can be used with `mesolve` and `sesolve` in a way that supports `jax.jit` and `jax.grad`, but only with specific settings:
- The ODE solver from diffrax (`"method": "diffrax"`) must be used instead of those provided by scipy.
- `normalize_output` must be `False`.
- Coefficients for `QobjEvo` must be jitted functions.
- The `isherm` flag of the `e_ops` must be computed in advance.
- The class interface (e.g. `MESolver`) must be used for `jit`.
- `e_data` must be used instead of `expect` for auto-differentiation.
- All operators and states must use `jax` or `jaxdia` format.
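The coefficient requirement can be tried out with JAX alone. This sketch jit-compiles a coefficient of the same form as `fm` in the next cell and differentiates it with respect to the frequency `w`. The wrapper `re_fm` is a hypothetical helper: `jax.grad` needs a real scalar output, so it takes the real part, cos(t w). Turning on 64-bit mode is an assumption made to match the `complex128` arrays shown in this notebook.

```python
import math

import jax
import jax.numpy as jnp

# Use double precision, like the complex128 arrays in this notebook
jax.config.update("jax_enable_x64", True)


@jax.jit
def fm(t, w):
    # Same form as the collapse-operator coefficient in the next cell
    return jnp.exp(-1j * t * w)


def re_fm(t, w):
    # Hypothetical helper: jax.grad needs a real scalar, so use Re[fm] = cos(t * w)
    return fm(t, w).real


# Derivative with respect to the second argument (w), also compiled
d_re_fm_dw = jax.jit(jax.grad(re_fm, argnums=1))

t, w = 0.5, 2.0
print(d_re_fm_dw(t, w))      # autodiff result
print(-t * math.sin(t * w))  # analytic value: d/dw cos(t w) = -t sin(t w)
```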

In [6]:
@jax.jit
def fp(t, w):
    return jax.numpy.exp(1j * t * w)


@jax.jit
def fm(t, w):
    return jax.numpy.exp(-1j * t * w)


@jax.jit
def cte(t, A):
    return A


with qutip.CoreOptions(default_dtype="jax"):
    H = qutip.num(10)
    c_ops = [qutip.QobjEvo([qutip.destroy(10), fm], args={"w": 1.0})]

H.isherm  # Precomputing the `isherm` flag

solver = qutip.MESolver(
    H, c_ops, options={"method": "diffrax", "normalize_output": False}
)


def final_expect(solver, rho0, t, w):
    result = solver.run(rho0, [0, t], args={"w": w}, e_ops=H)
    return result.e_data[0][-1].real


dfinal_expect_dt = jax.jit(
    jax.grad(final_expect, argnums=[2]), static_argnames=["solver"]
)

# TODO: switch to the jitted dfinal_expect_dt once the qutip-jax bug is fixed
# dfinal_expect_dt(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)
jax.grad(final_expect, argnums=[2])(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)

(Array(-9.6348726, dtype=float64, weak_type=True),)
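In the cell above, `static_argnames=["solver"]` tells `jax.jit` to treat the solver as a compile-time Python constant instead of tracing it as an array. The same mechanism in plain JAX, with a hypothetical frozen dataclass `Config` standing in for the solver (static arguments must be hashable):

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class Config:
    # Hypothetical stand-in for a non-array object such as a solver;
    # frozen dataclasses are hashable, as jax.jit requires for static arguments
    power: int


def energy(config, x):
    # config.power is an ordinary Python int at trace time
    return jnp.sum(x ** config.power)


energy_jit = jax.jit(energy, static_argnames=["config"])

x = jnp.array([1.0, 2.0, 3.0])
print(energy_jit(Config(2), x))  # 14.0
print(energy_jit(Config(3), x))  # 36.0, compiled separately for the new static value
```

Each distinct static value triggers a new compilation, so reusing the same solver object across calls lets JAX reuse the compiled function.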

In [7]:
qutip.about()


QuTiP: Quantum Toolbox in Python
Copyright (c) QuTiP team 2011 and later.
Current admin team: Alexander Pitchford, Nathan Shammah, Shahnawaz Ahmed, Neill Lambert, Eric Giguère, Boxi Li, Simon Cross, Asier Galicia, Paul Menczel, and Patrick Hopf.
Board members: Daniel Burgarth, Robert Johansson, Anton F. Kockum, Franco Nori and Will Zeng.
Original developers: R. J. Johansson & P. D. Nation.
Previous lead developers: Chris Granade & A. Grimsmo.
Currently developed through wide collaboration. See https://github.com/qutip for details.

QuTiP Version:      5.3.0.dev0+d849c94
Numpy Version:      2.3.2
Scipy Version:      1.16.1
Cython Version:     3.1.3
Matplotlib Version: 3.10.5
Python Version:     3.12.0
Number of CPUs:     4
BLAS Info:          Generic
INTEL MKL Ext:      None
Platform Info:      Linux (x86_64)
Installation path:  /home/runner/miniconda3/envs/test-environment-v5/lib/python3.12/site-packages/qutip

Installed QuTiP family packages
-------------------------------

qutip-qtr