<a href="https://colab.research.google.com/github/sriharikrishna/EuroAD26/blob/main/diffrax_odetest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Install the latest `diffrax` directly from GitHub, removing any preinstalled version first.

In [None]:
!pip uninstall -y diffrax
!pip install git+https://github.com/patrick-kidger/diffrax.git

Import the required modules and enable 64-bit precision in JAX.

In [None]:
from diffrax import diffeqsolve, ODETerm, Dopri5, Tsit5
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

Define a `solver` wrapper around `diffeqsolve`. You can choose between `Dopri5()` and `Tsit5()`.

In [None]:
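def solver(y0, t_i, t_f, A):
  # Right-hand side of the linear ODE dy/dt = A y; `B` is the `args` tuple.
  def ode_fn(t, y, B):
    return jnp.matmul(B[0], y)
  term = ODETerm(ode_fn)
  # Pick one of the two explicit Runge-Kutta solvers.
  #ODEsolver = Dopri5()
  ODEsolver = Tsit5()
  solution = diffeqsolve(term, ODEsolver, t0=t_i, t1=t_f, dt0=0.2, y0=y0, args=(A,))
  # With the default `saveat`, `ys` holds only the state at `t_f`.
  return solution.ys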


Call the solver with the `float64` data type.

In [None]:
t_i = 0.
t_f = jnp.array(2., dtype=jnp.float64)

y0 = jnp.array([[1.0, 9.0],
                [1.0, 9.0]], dtype=jnp.float64)
A = jnp.array([[0.0, 1.0],
               [-100.0, 0.0]], dtype=jnp.float64)

solution = solver(y0, t_i, t_f, A)

print(solution)
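As a sanity check on the real-valued run, the linear ODE dy/dt = A y has the exact solution y(t_f) = exp(A (t_f - t_i)) y0. The sketch below is not part of the original notebook: it evaluates that expression with `jax.scipy.linalg.expm` and, separately, with the closed form for this particular `A` (a harmonic oscillator with angular frequency 10). The names `y_expm`, `y_closed`, and `omega` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)

# Same data as the float64 cell above, with t_i = 0.
t_f = 2.0
y0 = jnp.array([[1.0, 9.0],
                [1.0, 9.0]], dtype=jnp.float64)
A = jnp.array([[0.0, 1.0],
               [-100.0, 0.0]], dtype=jnp.float64)

# Exact solution via the matrix exponential: y(t_f) = expm(A * t_f) @ y0.
y_expm = expm(A * t_f) @ y0

# Closed form of expm(A * t) for A = [[0, 1], [-omega**2, 0]] (illustrative check).
omega = 10.0
c, s = jnp.cos(omega * t_f), jnp.sin(omega * t_f)
propagator = jnp.array([[c, s / omega],
                        [-omega * s, c]])
y_closed = propagator @ y0

print(y_expm)
print(y_closed)
```

The two results should agree closely with each other. How closely `solution[0]` from the solver matches them depends on the step size, since `solver` passes only `dt0=0.2` and no step size controller.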

Let us now try the code with complex values.

In [None]:
t_i = 0.
t_f = jnp.array(2., dtype=jnp.float64)
y0 = jnp.array([[1.0-1j, 9.0-2j],
                [1.0-1j, 9.0-2j]], dtype=jnp.complex128)
A = jnp.array([[0.0-1j, 1.0+2j],
               [-100.0+3j, 0.0+4j]], dtype=jnp.complex128)

solution = solver(y0, t_i, t_f, A)

print(solution)

When we compute the Jacobian of a function with complex outputs using `jax.jacrev`, we must set `holomorphic=True`. A [holomorphic function](https://en.wikipedia.org/wiki/Holomorphic_function) is a complex-valued function of one or more complex variables that is complex differentiable in a neighbourhood of each point of a domain in the complex coordinate space $\mathbb{C}^n$.

We must also set `t_i` and `t_f` to be complex-valued. With `holomorphic=True`, JAX requires the argument being differentiated, here `t_f` (`argnums=2`), to have a complex dtype.

In [None]:
t_i = jnp.array(0.+0j, dtype=jnp.complex128)
t_f = jnp.array(2.+0j, dtype=jnp.complex128)
y0 = jnp.array([[1.0-1j, 9.0-2j],
                [1.0-1j, 9.0-2j]], dtype=jnp.complex128)

A = jnp.array([[0.0-1j, 1.0+2j],
               [-100.0+3j, 0.0+4j]], dtype=jnp.complex128)
# Differentiate with respect to t_f (positional argument 2 of `solver`).
jacrev_fun = jax.jacrev(solver, argnums=2, holomorphic=True)
jac = jacrev_fun(y0, t_i, t_f, A)
print(jac)
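To see the `holomorphic=True` requirement in isolation, here is a minimal sketch that does not involve `diffrax`. The function `f` (an elementwise square) and the input `z` are illustrative assumptions, chosen because the expected Jacobian, a diagonal matrix with entries `2 * z`, is easy to verify by hand.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def f(z):
    # Elementwise square: holomorphic, with derivative 2 * z.
    return z ** 2

z = jnp.array([1.0 + 2.0j, 3.0 - 1.0j], dtype=jnp.complex128)

# Without holomorphic=True, jacrev rejects the complex-valued output.
try:
    jax.jacrev(f)(z)
    raised = False
except TypeError:
    raised = True
print("TypeError without holomorphic=True:", raised)

# With holomorphic=True and a complex input, the Jacobian is diag(2 * z).
jac_f = jax.jacrev(f, holomorphic=True)(z)
print(jac_f)
```

The same two conditions apply to the `solver` call above: the output is complex, so `holomorphic=True` is needed, and the differentiated input must itself be complex.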