In [None]:
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
# Enable 64-bit dtypes: the cells below build float64 / complex128 arrays,
# which JAX would otherwise truncate to 32-bit. Set this before creating arrays.
jax.config.update("jax_enable_x64", True)

In [None]:
def solver(y0, t_f, A, t_i=0., num_points=10):
  """Integrate the linear ODE dy/dt = A @ y from t_i to t_f and return y(t_f).

  Args:
    y0: state at time ``t_i``; any array that ``A`` can left-multiply with
      ``jnp.matmul`` (e.g. a vector, or a matrix whose columns are states).
    t_f: final time. Only ``jnp.real(t_f)`` is used, because the time grid
      handed to ``odeint`` must be real. Consequently ``solver`` is NOT a
      holomorphic function of a complex ``t_f``; interpret
      ``jax.jacrev(..., holomorphic=True)`` results w.r.t. ``t_f`` with care.
    A: matrix defining the time-independent dynamics ``dy/dt = A @ y``.
    t_i: initial time (default ``0.``, the previously hard-coded start).
      Only its real part is used.
    num_points: number of points in the time grid passed to ``odeint``
      (default 10). Only the value at the last point is returned. Must be >= 2.

  Returns:
    The integrated state at ``t_f``, with the same shape as ``y0``.

  Raises:
    ValueError: if ``num_points < 2`` (the grid would not reach ``t_f``).
  """
  if num_points < 2:
    raise ValueError(f"num_points must be >= 2, got {num_points}")
  t = jnp.linspace(jnp.real(t_i), jnp.real(t_f), num_points)

  def ode_fn(y, _t, B):
    # odeint calls func(state, time, *args); these dynamics ignore the time.
    return jnp.matmul(B, y)

  # A is passed through odeint's *args instead of being captured by the
  # closure, so it is an explicit argument of the integrated function.
  return odeint(ode_fn, y0, t, A)[-1]

In [None]:
# Forward solve of dy/dt = A @ y with A = [[0, 1], [-100, 0]], i.e. the
# oscillator x'' = -100 x, from t = 0 (solver's start time) to t_f = 2.
t_f = jnp.array(2., dtype=jnp.float64)

# Each column of y0 is one initial state (x, x'), since solver computes A @ y.
y0 = jnp.array([[1.0, 9.0],
               [1.0, 9.0]], dtype=jnp.float64)
A = jnp.array([[0, 1.0],
               [- 100.0, 0]], dtype=jnp.float64)

solver(y0, t_f, A)

In [None]:
# Reverse-mode Jacobian of the final state y(t_f) w.r.t. the real final time t_f.
t_f = jnp.array(2., dtype=jnp.float64)
y0 = jnp.array([[1.0, 9.0],
               [1.0, 9.0]], dtype=jnp.float64)
A = jnp.array([[0, 1.0],
               [- 100.0, 0]], dtype=jnp.float64)

# argnums=1 is a plain int (the original `(1)` was an int too, not a tuple), so
# `jac` is a single array of y0's shape rather than a tuple of Jacobians.
jacrev_fun = jax.jacrev(solver, argnums=1, holomorphic=False)
jac = jacrev_fun(y0, t_f, A)

# Sanity check: for dy/dt = A @ y, d y(t_f) / d t_f = A @ y(t_f).
assert jnp.allclose(jac, A @ solver(y0, t_f, A)), \
    "Jacobian w.r.t. t_f disagrees with A @ y(t_f)"
jac

In [None]:
# Holomorphic Jacobian w.r.t. a *real-dtype* t_f -- expected to fail.
# t_f is requested as float64 even though the literal is complex, and
# jax.jacrev(..., holomorphic=True) only accepts complex-dtype inputs, so a
# TypeError is expected. It is caught and displayed instead of aborting
# "Restart & Run All"; the following cell repeats this with a complex128 t_f.
try:
  t_f = jnp.array(2.+0j, dtype=jnp.float64)
  y0 = jnp.array([[1.0-1j, 9.0-2j],
                  [1.0-1j, 9.0-2j]], dtype=jnp.complex128)

  A = jnp.array([[0-1j, 1.0+2j],
                 [- 100.0+3j, 0+4j]], dtype=jnp.complex128)
  jacrev_fun = jax.jacrev(solver, argnums=1, holomorphic=True)
  jac = jacrev_fun(y0, t_f, A)
except TypeError as err:
  print(f"Expected failure: {type(err).__name__}: {err}")
else:
  print(jac)

In [None]:
# Holomorphic reverse-mode Jacobian of y(t_f) w.r.t. a complex128 t_f,
# with complex initial state y0 and complex system matrix A.
# NOTE(review): solver feeds jnp.real(t_f) into the time grid, which is not a
# holomorphic function of t_f, so this Jacobian need not equal A @ y(t_f);
# confirm which derivative is intended here.
t_f = jnp.array(2.+0j, dtype=jnp.complex128)

# Two identical rows [1-1j, 9-2j].
y0 = jnp.tile(jnp.array([1.0-1j, 9.0-2j], dtype=jnp.complex128), (2, 1))

A = jnp.array(
    [[0-1j, 1.0+2j],
     [-100.0+3j, 0+4j]],
    dtype=jnp.complex128,
)

jacrev_fun = jax.jacrev(solver, holomorphic=True, argnums=1)
jac = jacrev_fun(y0, t_f, A)
print(jac)

Reference: source of `jax.experimental.ode` (the `odeint` used by `solver` above): https://github.com/google/jax/blob/main/jax/experimental/ode.py