<a href="https://colab.research.google.com/github/sriharikrishna/EuroAD26/blob/main/EuroAD_pt_odetest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

PyTorch itself has no `odeint`; it is provided by the third-party package [`torchdiffeq`](https://github.com/rtqichen/torchdiffeq). We first install the latest release from PyPI and later switch to the development version from GitHub.

In [None]:
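# Remove any previously installed copy (-y skips the confirmation prompt)
!pip uninstall -y torchdiffeq
# Released version from PyPI: comment out this line to use the development version
!pip install torchdiffeq
# Development version from GitHub: uncomment this line to use it
#!pip install git+https://github.com/rtqichen/torchdiffeq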

Import the relevant packages.

In [None]:
import torch
from torchdiffeq import odeint

In [None]:
COMPLEX_DTYPE = torch.complex128  # double-precision complex dtype used for y0 and A

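The following wrapper integrates the linear system `dy/dt = A y` from `t = 0` to `t_f` with the adaptive `dopri5` (Dormand-Prince) solver and returns only the final state.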

In [None]:
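def solver(y0, t_f, A):
  # 10 output times from 0 to t_f; t_f.item() returns a plain Python float
  t = torch.linspace(0., t_f.item(), 10)
  # Right-hand side of the linear ODE dy/dt = A y (the time argument is unused)
  def ode_fn(_: float, y: torch.Tensor):
    return torch.matmul(A, y)
  # Integrate with dopri5 and return only the state at the final time t_f
  return odeint(ode_fn, y0, t, method='dopri5')[-1]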
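Because the right-hand side is linear with a constant matrix, the exact solution is `y(t_f) = exp(A t_f) y0`, where `exp` is the matrix exponential. The sketch below is a reference implementation built on `torch.linalg.matrix_exp`; the name `reference_solver` is hypothetical and not part of the notebook. It is checked here on a diagonal `A`, for which the matrix exponential reduces to the elementwise exponential of the diagonal.

```python
import torch

def reference_solver(y0: torch.Tensor, t_f: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    """Exact final state exp(A * t_f) @ y0 of dy/dt = A y (hypothetical helper)."""
    # A * t_f promotes the real t_f to A's complex dtype; matrix_exp accepts complex input
    return torch.linalg.matrix_exp(A * t_f) @ y0

# Check on a diagonal A, where exp(A t) is the elementwise exp of the diagonal
A = torch.diag(torch.tensor([-1.0 + 2.0j, 0.5 - 1.0j], dtype=torch.complex128))
y0 = torch.tensor([[1.0 + 0.0j], [2.0 - 1.0j]], dtype=torch.complex128)
t_f = torch.tensor(2.0, dtype=torch.float64)
expected = torch.exp(torch.diagonal(A) * t_f).unsqueeze(1) * y0
print(torch.allclose(reference_solver(y0, t_f, A), expected))  # True
```

Once complex inputs are accepted, comparing `solver(y0, t_f, A)` with `reference_solver(y0, t_f, A)` under a loose relative tolerance is a quick sanity check of the numerical solve.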

Let us first run the code with the `torch.complex128` data type. With the released package this call fails, because `torchdiffeq` contains a check that rejects complex inputs.

In [None]:
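t_i = 0.  # initial time (solver() always starts at 0)
t_f = torch.tensor(2.0, dtype=torch.float64)  # final time

# 2x2 complex initial state; each column evolves independently under A
y0 = torch.tensor([[1.0-1j, 9.0-2j],
                   [1.0-1j, 9.0-2j]], dtype=COMPLEX_DTYPE)

# Complex system matrix A in dy/dt = A y
A = torch.tensor([[0-1j, 1.0+2j],
                  [-100.0+3j, 0+4j]], dtype=COMPLEX_DTYPE)

print(solver(y0, t_f, A))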
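If installing the development version is not an option, one common workaround (an assumption here, not what this notebook does) is to rewrite the complex system as a real one of twice the size. With `y = u + i v` and `A = Ar + i Ai`, the equation `dy/dt = A y` becomes `du/dt = Ar u - Ai v` and `dv/dt = Ai u + Ar v`. The sketch below builds that real right-hand side (the name `real_rhs` is hypothetical) and checks it against the complex one; a function like this could be passed to `odeint` in place of `ode_fn`, with the real stacked state `torch.cat([y0.real, y0.imag])` as the initial value.

```python
import torch

def real_rhs(A: torch.Tensor):
    """Return f(t, z) for the real form of dy/dt = A y, where z stacks (Re y, Im y) row-wise."""
    Ar, Ai = A.real, A.imag
    n = A.shape[0]
    def f(t, z):
        u, v = z[:n], z[n:]       # real and imaginary parts of y
        du = Ar @ u - Ai @ v      # Re(A y)
        dv = Ai @ u + Ar @ v      # Im(A y)
        return torch.cat([du, dv])
    return f

A = torch.tensor([[0-1j, 1.0+2j],
                  [-100.0+3j, 0+4j]], dtype=torch.complex128)
y = torch.tensor([[1.0-1j, 9.0-2j],
                  [1.0-1j, 9.0-2j]], dtype=torch.complex128)

z = torch.cat([y.real, y.imag])   # real float64 state of shape (4, 2)
dz = real_rhs(A)(0.0, z)
dy = A @ y                        # complex right-hand side for comparison
print(torch.allclose(dz, torch.cat([dy.real, dy.imag])))  # True
```

After integrating the real system, `torch.complex(z[:2], z[2:])` reassembles the complex state.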


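We asked the developer to remove this check; at the time of writing, the change has not yet been released.
To use it, switch `torchdiffeq` to the development version by editing the installation cell at the top and rerunning it. Since Python does not reload a module that has already been imported, restart the runtime, then rerun the import cells and the preceding cell.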
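To confirm what is installed after switching versions, a small check such as the following can help. The helper name `package_status` is hypothetical; it reads the installed distribution version with the standard-library `importlib.metadata` and reports whether the module is already loaded in `sys.modules`. A `True` in the second field means a copy of the module is already in memory, and a `pip` reinstall does not replace it until the runtime restarts.

```python
import importlib.metadata
import sys

def package_status(dist_name: str, module_name: str):
    """Return (installed version or None, whether the module is already imported)."""
    try:
        version = importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        version = None
    return version, module_name in sys.modules

# In the notebook: installed torchdiffeq version and whether it is already loaded
print(package_status("torchdiffeq", "torchdiffeq"))
```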

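Now that the solver accepts complex data types, we can differentiate through it. Because `final_state` is a 2x2 complex tensor rather than a scalar, we call `backward` with a seed tensor that selects individual entries of `final_state`. It looks a little clunky, but it works.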

In [None]:
t_i = 0.  # initial time (solver() always starts at 0)
t_f = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)

y0 = torch.tensor([[1.0-1j, 9.0-2j],
                   [1.0-1j, 9.0-2j]], dtype=COMPLEX_DTYPE, requires_grad=True)

A = torch.tensor([[0-1j, 1.0+2j],
                  [-100.0+3j, 0+4j]], dtype=COMPLEX_DTYPE, requires_grad=True)

final_state = solver(y0, t_f, A)

for i in range(2):
  for j in range(2):
    # Clear gradients left over from the previous backward call
    for leaf in (y0, A, t_f):
      leaf.grad = None
    # Fresh one-hot seed (same dtype as final_state) selecting entry (i, j)
    seed = torch.zeros_like(final_state)
    seed[i, j] = 1.0 + 0.0j
    # Keep the graph alive so it can be reused for the next entry
    final_state.backward(seed, retain_graph=True)
    print("d(finalstate[", i, ",", j, "])/d(y0)= ", y0.grad)
    print("d(finalstate[", i, ",", j, "])/d(A)= ", A.grad)
    print("d(finalstate[", i, ",", j, "])/d(t_f)= ", t_f.grad)
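Two details help when reading these numbers. First, a `backward` call with a one-hot seed returns a vector-Jacobian product, so looping over all entries assembles the Jacobian one row at a time, provided `.grad` is cleared between calls. Second, for complex tensors PyTorch's autograd returns the conjugate Wirtinger gradient, so for a holomorphic map such as `y0 -> exp(A t_f) y0` the entries of `y0.grad` are the complex conjugates of the ordinary derivatives. The sketch below illustrates both points on the exact solution (used instead of `odeint` so the expected derivative is known in closed form): since `y[i, j] = sum_k E[i, k] y0[k, j]` with `E = exp(A t_f)`, seeding entry `(i, j)` should give `conj(E[i, :])` in column `j` of `y0.grad` and zeros elsewhere.

```python
import torch

A = torch.tensor([[0-1j, 1.0+2j],
                  [-100.0+3j, 0+4j]], dtype=torch.complex128)
t_f = torch.tensor(2.0, dtype=torch.float64)
y0 = torch.tensor([[1.0-1j, 9.0-2j],
                   [1.0-1j, 9.0-2j]], dtype=torch.complex128, requires_grad=True)

E = torch.linalg.matrix_exp(A * t_f)  # exact propagator exp(A t_f); constant here
y = E @ y0                            # exact final state, differentiable in y0

jac = {}  # jac[(i, j)] holds y0.grad for the seed that selects y[i, j]
for i in range(2):
    for j in range(2):
        y0.grad = None                # do not accumulate across backward calls
        seed = torch.zeros_like(y)    # one-hot seed with the same dtype as y
        seed[i, j] = 1.0
        y.backward(seed, retain_graph=True)
        jac[(i, j)] = y0.grad.clone()

# Column j of jac[(i, j)] should be conj(E[i, :]); the other column should be zero
i, j = 1, 0
expected = torch.zeros_like(E)
expected[:, j] = E[i, :].conj()
print(torch.allclose(jac[(i, j)], expected))  # True
```

By the same reasoning, taking `.conj()` of `y0.grad` in the notebook's loop should approximate the ordinary derivative of the selected entry of `final_state`, up to solver error.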

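Note that the derivatives with respect to the final time `t_f` are all `None`. This is because `solver` builds its time grid from `t_f.item()`, a plain Python float, so `t_f` never enters the autograd graph. We will therefore have to obtain time derivatives in a different way.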
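One option, sketched here under the assumption that only the final time `t_f` varies, uses the ODE itself: because `y(t)` satisfies `dy/dt = A y`, the derivative of the final state with respect to the final time is `dy(t_f)/dt_f = A y(t_f)`, which requires only the final state and no gradient through the time grid. The sketch checks this formula on the exact matrix-exponential solution against a central finite difference. The helpers `exact_final_state` and `d_final_state_d_tf` are hypothetical; in the notebook, `solver(y0, t_f, A)` would stand in for `y(t_f)`.

```python
import torch

def exact_final_state(y0, t_f, A):
    """y(t_f) = exp(A t_f) y0 for the linear ODE dy/dt = A y (hypothetical helper)."""
    return torch.linalg.matrix_exp(A * t_f) @ y0

def d_final_state_d_tf(y_final, A):
    """dy(t_f)/dt_f: the right-hand side A y evaluated at the final state."""
    return A @ y_final

A = torch.tensor([[0-1j, 1.0+2j],
                  [-100.0+3j, 0+4j]], dtype=torch.complex128)
y0 = torch.tensor([[1.0-1j, 9.0-2j],
                   [1.0-1j, 9.0-2j]], dtype=torch.complex128)
t_f = 2.0
h = 1e-5  # finite-difference step in time

analytic = d_final_state_d_tf(exact_final_state(y0, t_f, A), A)
# Central difference of the exact solution around t_f
fd = (exact_final_state(y0, t_f + h, A) - exact_final_state(y0, t_f - h, A)) / (2 * h)
print(torch.allclose(analytic, fd, rtol=1e-5))  # True
```

For a scalar loss built from the final state, the chain rule then combines this time derivative with the seed in the same way as for the other inputs.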