
Import `odeint` from `jax.experimental.ode` and enable 64-bit (`float64`/`complex128`) precision, which JAX disables by default.

In [None]:
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
jax.config.update("jax_enable_x64", True)
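For readers new to `odeint`: it integrates `dy/dt = func(y, t, *args)` and returns the solution at every time in `t`, with `t[0]` as the initial time. Below is a minimal sketch on the decay equation $dy/dt = -k\,y$, whose exact solution is $y_0 e^{-kt}$. The function `decay` and the rate `k` are illustrative assumptions, not part of the original notebook.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

jax.config.update("jax_enable_x64", True)

def decay(y, t, k):
  # Right-hand side f(y, t, k) = -k * y; odeint passes extra args after t.
  return -k * y

y0 = jnp.array(1.0)
t = jnp.linspace(0., 3., 7)   # output times; t[0] is the initial time
k = jnp.array(0.5)            # illustrative decay rate

ys = odeint(decay, y0, t, k)  # shape (7,): one value per output time
exact = y0 * jnp.exp(-k * t)
# Maximum deviation from the exact solution (default rtol/atol are 1.4e-8).
print(jnp.max(jnp.abs(ys - exact)))
```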

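Define a `solver` wrapper around `odeint`. It integrates the linear system $dy/dt = A\,y$ from $t = 0$ to $\mathrm{Re}(t_f)$ at 10 output times and returns only the final state $y(t_f)$.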

In [None]:
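def solver(y0, t_f, A):
  # 10 output times on [0, Re(t_f)]; odeint integrates along the real time axis.
  t = jnp.linspace(0., jnp.real(t_f), 10)

  # Right-hand side of dy/dt = B[0] @ y; odeint calls it as ode_fn(y, t, *args).
  def ode_fn(y, _, B):
    return jnp.matmul(B[0], y)

  # A is passed as a one-element tuple, so B[0] is A. Keep only y at t[-1].
  return odeint(ode_fn, y0, t, (A,))[-1]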
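The one-element tuple `(A,)` is not required: `odeint` forwards any extra positional arguments (arrays or pytrees of arrays) to `func` unchanged. A sketch of an equivalent wrapper that passes `A` directly, checked against the tuple version; the name `solver_direct` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

jax.config.update("jax_enable_x64", True)

def solver_direct(y0, t_f, A):
  # Hypothetical variant: A goes straight through odeint's *args.
  t = jnp.linspace(0., jnp.real(t_f), 10)
  def ode_fn(y, _, A):
    return A @ y
  return odeint(ode_fn, y0, t, A)[-1]

def solver(y0, t_f, A):
  # The notebook's version, which wraps A in a one-element tuple.
  t = jnp.linspace(0., jnp.real(t_f), 10)
  def ode_fn(y, _, B):
    return jnp.matmul(B[0], y)
  return odeint(ode_fn, y0, t, (A,))[-1]

y0 = jnp.array([[1.0, 9.0], [1.0, 9.0]])
A = jnp.array([[0., 1.0], [-100.0, 0.]])
print(jnp.allclose(solver_direct(y0, 2., A), solver(y0, 2., A)))
```

Passing arrays directly avoids the `B[0]` indexing; the tuple form is equally valid because `odeint` treats its extra arguments as pytrees.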

Call the solver with `float64` inputs.

In [None]:
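t_f = jnp.array(2., dtype=jnp.float64)   # final time; integration starts at t = 0

# Each column of y0 is an independent initial state [position, velocity].
y0 = jnp.array([[1.0, 9.0],
                [1.0, 9.0]], dtype=jnp.float64)
# dy/dt = A y is a harmonic oscillator with angular frequency sqrt(100) = 10.
A = jnp.array([[0., 1.0],
               [-100.0, 0.]], dtype=jnp.float64)

print(solver(y0, t_f, A))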
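Because the ODE is linear with constant coefficients, the exact final state is $y(t_f) = e^{A t_f}\,y_0$, which gives an independent check of the printed result. A sketch using `jax.scipy.linalg.expm`; the comparison tolerances are assumptions chosen loosely relative to `odeint`'s defaults.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)

def solver(y0, t_f, A):
  # The notebook's wrapper: integrate dy/dt = A y on [0, Re(t_f)], return y(t_f).
  t = jnp.linspace(0., jnp.real(t_f), 10)
  def ode_fn(y, _, B):
    return jnp.matmul(B[0], y)
  return odeint(ode_fn, y0, t, (A,))[-1]

t_f = jnp.array(2., dtype=jnp.float64)
y0 = jnp.array([[1.0, 9.0],
                [1.0, 9.0]], dtype=jnp.float64)
A = jnp.array([[0., 1.0],
               [-100.0, 0.]], dtype=jnp.float64)

numeric = solver(y0, t_f, A)
exact = expm(A * t_f) @ y0   # matrix exponential applied to every column of y0
print(jnp.max(jnp.abs(numeric - exact)))
```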

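Call `jax.jacrev` to compute the Jacobian of the final state with respect to `t_f`. JAX also provides `jax.vjp`, the reverse-mode primitive on which `jax.jacrev` is built. Forward-mode routines (`jax.jvp`, `jax.jacfwd`) do not work here, because `odeint` defines its derivative with `jax.custom_vjp`, which supports reverse mode only.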

In [None]:
t_f = jnp.array(2., dtype=jnp.float64)
y0 = jnp.array([[1.0, 9.0],
                [1.0, 9.0]], dtype=jnp.float64)
A = jnp.array([[0., 1.0],
               [-100.0, 0.]], dtype=jnp.float64)
# Reverse-mode Jacobian of y(t_f) (shape 2x2) with respect to t_f (argument index 1).
jacrev_fun = jax.jacrev(solver, argnums=1, holomorphic=False)
jac = jacrev_fun(y0, t_f, A)
print(jac)
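Because $y'(t) = A\,y(t)$, the derivative of the final state with respect to the end time is $\partial y(t_f)/\partial t_f = A\,y(t_f)$, which gives a simple check of `jac`. The sketch below redefines the notebook's `solver` so it runs on its own, compares the two, and also recovers one entry through `jax.vjp` with a one-hot cotangent; the tolerances are the `jnp.allclose` defaults.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

jax.config.update("jax_enable_x64", True)

def solver(y0, t_f, A):
  # Same wrapper as above: integrate dy/dt = A y on [0, Re(t_f)], return y(t_f).
  t = jnp.linspace(0., jnp.real(t_f), 10)
  def ode_fn(y, _, B):
    return jnp.matmul(B[0], y)
  return odeint(ode_fn, y0, t, (A,))[-1]

t_f = jnp.array(2., dtype=jnp.float64)
y0 = jnp.array([[1.0, 9.0],
                [1.0, 9.0]], dtype=jnp.float64)
A = jnp.array([[0., 1.0],
               [-100.0, 0.]], dtype=jnp.float64)

# Reverse-mode Jacobian of y(t_f) with respect to the scalar t_f: shape (2, 2).
jac = jax.jacrev(solver, argnums=1)(y0, t_f, A)

# Since y'(t) = A y(t), the derivative of the final state w.r.t. t_f is A y(t_f).
expected = A @ solver(y0, t_f, A)
print(jnp.allclose(jac, expected))

# The same number via jax.vjp: a one-hot cotangent selects entry [0, 1].
y_final, pullback = jax.vjp(lambda tf: solver(y0, tf, A), t_f)
(entry,) = pullback(jnp.zeros_like(y_final).at[0, 1].set(1.0))
print(jnp.allclose(entry, jac[0, 1]))
```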

Let us now try the code with complex values.

When a function has complex outputs, `jax.jacrev` needs `holomorphic=True` (otherwise it raises an error), and with that flag the arguments being differentiated must also have a complex dtype. A [holomorphic function](https://en.wikipedia.org/wiki/Holomorphic_function) is a complex-valued function of one or more complex variables that is complex differentiable in a neighbourhood of each point of its domain in complex coordinate space $\mathbb{C}^n$.
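As a small illustration of `holomorphic=True` on a function unrelated to the ODE: $f(z) = z^2$ is holomorphic with complex derivative $2z$, and `jax.jacrev` returns exactly that. The function name `square` and the sample point are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def square(z):
  # z**2 is holomorphic: its complex derivative is 2 z.
  return z ** 2

z = jnp.array(1.0 + 2.0j, dtype=jnp.complex128)

# holomorphic=True asks JAX for the complex derivative d(z**2)/dz.
dsq = jax.jacrev(square, holomorphic=True)(z)
print(dsq)  # equals 2 * z, i.e. 2+4j
```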

The following cell fails because `t_f` is real (`float64`), while `holomorphic=True` requires the differentiated argument to have a complex dtype.

In [None]:
t_f = jnp.array(2., dtype=jnp.float64)   # real-valued final time
y0 = jnp.array([[1.0-1j, 9.0-2j],
                [1.0-1j, 9.0-2j]], dtype=jnp.complex128)

A = jnp.array([[0-1j, 1.0+2j],
               [-100.0+3j, 0+4j]], dtype=jnp.complex128)
jacrev_fun = jax.jacrev(solver, argnums=1, holomorphic=True)
# Raises TypeError: holomorphic=True requires t_f to have a complex dtype.
jac = jacrev_fun(y0, t_f, A)
print(jac)
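To observe this failure without stopping the notebook, the call can be wrapped in `try`/`except`. A sketch, assuming the error type is `TypeError` (which is what `jax.jacrev` raises for a non-complex input under `holomorphic=True`):

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

jax.config.update("jax_enable_x64", True)

def solver(y0, t_f, A):
  # The notebook's wrapper: integrate dy/dt = A y on [0, Re(t_f)], return y(t_f).
  t = jnp.linspace(0., jnp.real(t_f), 10)
  def ode_fn(y, _, B):
    return jnp.matmul(B[0], y)
  return odeint(ode_fn, y0, t, (A,))[-1]

t_f = jnp.array(2., dtype=jnp.float64)   # real time: rejected with holomorphic=True
y0 = jnp.array([[1.0-1j, 9.0-2j],
                [1.0-1j, 9.0-2j]], dtype=jnp.complex128)
A = jnp.array([[0-1j, 1.0+2j],
               [-100.0+3j, 0+4j]], dtype=jnp.complex128)

try:
  jax.jacrev(solver, argnums=1, holomorphic=True)(y0, t_f, A)
  message = None
except TypeError as err:
  # JAX checks the input dtype before tracing the solver.
  message = str(err)
print(message)
```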

When `t_f` is complex valued, the call succeeds, but the imaginary part of the derivative with respect to time [is set to zero](https://github.com/google/jax/blob/main/jax/experimental/ode.py#L245) inside JAX's `odeint`.

In [None]:
t_f = jnp.array(2.+2j, dtype=jnp.complex128)   # complex final time
y0 = jnp.array([[1.0-1j, 9.0-2j],
                [1.0-1j, 9.0-2j]], dtype=jnp.complex128)

A = jnp.array([[0-1j, 1.0+2j],
               [-100.0+3j, 0+4j]], dtype=jnp.complex128)
# solver integrates over [0, Re(t_f)]; the Jacobian is taken w.r.t. complex t_f.
jacrev_fun = jax.jacrev(solver, argnums=1, holomorphic=True)
jac = jacrev_fun(y0, t_f, A)
print(jac)

Uncommenting that line gives the correct values. This is the small fix that we suggest!
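To make the effect concrete: in `odeint`'s reverse pass, the cotangent of an output time is the dot product of the vector field at that time with the output cotangent, and in the linked JAX source this value is passed through `.real`. The sketch below is a simplified, hypothetical re-creation of that single step (the helper `time_cotangent` and the sample values are illustrative, not JAX code). It only shows how taking the real part drops the imaginary component.

```python
import jax.numpy as jnp

def time_cotangent(f_y, g, keep_imag):
  # Hypothetical helper: cotangent of one output time t_i, given the vector
  # field f(y(t_i)) and the output cotangent g_i (both flattened).
  t_bar = jnp.dot(f_y, g)
  # keep_imag=False mimics taking .real, which drops the imaginary part.
  return t_bar if keep_imag else t_bar.real

# Illustrative complex values for f(y(t_f)) and a one-hot output cotangent.
f_y = jnp.array([1.0 + 2.0j, -3.0 + 0.5j])
g = jnp.array([1.0 + 0.0j, 0.0 + 0.0j])

print(time_cotangent(f_y, g, keep_imag=True))   # full complex cotangent, 1+2j
print(time_cotangent(f_y, g, keep_imag=False))  # real part only, 1.0
```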