<a href="https://colab.research.google.com/github/sriharikrishna/EuroAD26/blob/main/EuroAD_pt_odetest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

`odeint` in PyTorch is provided by [`torchdiffeq`](https://github.com/rtqichen/torchdiffeq). We will initially install the default release from PyPI and switch to the development version from GitHub later.

In [1]:
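!pip uninstall torchdiffeq
# Default PyPI release (its dtype check rejects complex tensors):
#!pip install torchdiffeq
# Development version from GitHub (check removed, not yet released):
!pip install git+https://github.com/rtqichen/torchdiffeq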

Found existing installation: torchdiffeq 0.2.4
Uninstalling torchdiffeq-0.2.4:
  Would remove:
    /usr/local/lib/python3.10/dist-packages/torchdiffeq-0.2.4.dist-info/*
    /usr/local/lib/python3.10/dist-packages/torchdiffeq/*
Proceed (Y/n)? y
  Successfully uninstalled torchdiffeq-0.2.4
Collecting git+https://github.com/rtqichen/torchdiffeq
  Cloning https://github.com/rtqichen/torchdiffeq to /tmp/pip-req-build-bmyrenmi
  Running command git clone --filter=blob:none --quiet https://github.com/rtqichen/torchdiffeq /tmp/pip-req-build-bmyrenmi
  Resolved https://github.com/rtqichen/torchdiffeq to commit cae73789b929d4dbe8ce955dace0089cf981c1a0
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: torchdiffeq
  Building wheel for torchdiffeq (setup.py) ... done
  Created wheel for torchdiffeq: filename=torchdiffeq-0.2.4-py3-none-any.whl size=32834 sha256=152acc4982ba56bf364509f4cc517b3e4fc1a2069ae766c8282e7e406c8cd939
  Stored in director

Import relevant packages

In [1]:
import torch
from torchdiffeq import odeint as odeint

In [2]:
COMPLEX_DTYPE = torch.complex128

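This defines a solver wrapper: it integrates dy/dt = A y from 0 to `t_f` with the adaptive `dopri5` method and returns only the final state.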

In [3]:
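def solver(y0, t_f, A):
  # Output times from 0 to t_f. t_f.item() returns a plain Python float,
  # so the time grid is not connected to t_f in the autograd graph.
  t = torch.linspace(0., t_f.item(), 10)
  def ode_fn(_: torch.Tensor, y: torch.Tensor):
    # Right-hand side of the linear ODE dy/dt = A y (it does not depend on time)
    return torch.matmul(A, y)
  # odeint returns the solution at every time in t; keep only the state at t_f
  return odeint(ode_fn, y0, t, method='dopri5')[-1]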
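Because the right-hand side is linear with a constant matrix, the exact solution is known: y(t_f) = exp(t_f A) y0, applied column by column to the 2×2 `y0`. A minimal cross-check sketch, assuming only `torch.linalg.matrix_exp`; the name `solver_expm` is hypothetical and not part of this notebook:

```python
import torch

COMPLEX_DTYPE = torch.complex128

def solver_expm(y0, t_f, A):
    # Hypothetical closed-form stand-in for `solver`:
    # exact solution of dy/dt = A y with constant A, y(t_f) = exp(t_f * A) @ y0.
    # t_f (real float64) times A (complex128) promotes to complex128.
    return torch.linalg.matrix_exp(t_f * A) @ y0

y0 = torch.tensor([[1.0-1j, 9.0-2j],
                   [1.0-1j, 9.0-2j]], dtype=COMPLEX_DTYPE)
A = torch.tensor([[0-1j, 1.0+2j],
                  [-100.0+3j, 0+4j]], dtype=COMPLEX_DTYPE)
t_f = torch.tensor(2.0, dtype=torch.float64)

reference = solver_expm(y0, t_f, A)
print(reference)
```

Comparing `reference` with `solver(y0, t_f, A)`, for example via `torch.allclose` with a tolerance loose enough for an adaptive integrator, is a quick way to check that the `odeint` call is set up correctly.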

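Let us first run the code with the `torch.complex128` datatype. With the default release of the package, this fails because of a dtype check inside the package (the output below was produced with the GitHub version installed in the first box).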
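We do not reproduce the package's source here, but assuming the check is a floating-point dtype assertion, the sketch below shows why complex tensors trip it: `torch.is_floating_point` returns `False` for complex dtypes. The helper `assert_floating` is hypothetical.

```python
import torch

def assert_floating(name, t):
    # Hypothetical stand-in for a dtype check that only accepts real floating-point tensors
    if not torch.is_floating_point(t):
        raise TypeError(f"`{name}` must be a floating point Tensor but is a {t.dtype}")

y0 = torch.zeros(2, 2, dtype=torch.complex128)
print(torch.is_floating_point(y0))  # False: complex dtypes do not count as floating point
print(torch.is_complex(y0))         # True

try:
    assert_floating("y0", y0)
except TypeError as err:
    print(err)
```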

In [4]:
t_i = 0.
t_f = torch.tensor(2.0, dtype=torch.float64)

y0 = torch.tensor([[1.0-1j, 9.0-2j],
                [1.0-1j, 9.0-2j]], dtype=COMPLEX_DTYPE)

A = torch.tensor([[0-1j, 1.0+2j],
               [- 100.0+3j, 0+4j]], dtype=COMPLEX_DTYPE)

print(solver(y0, t_f, A))


tensor([[ 6.1990e+05+2733558.9655j, -6.1580e+06+17204207.0352j],
        [ 1.2353e+07-10545445.5707j,  1.0485e+08-14763334.4577j]],
       dtype=torch.complex128)


We asked the developer to remove the check; this change has not been released yet.
So let us change the version of `torchdiffeq` that is used by editing the top box (installing from GitHub instead of PyPI). Then we can rerun the preceding box.

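Now that the solver call works with complex data types, we can differentiate through it. It looks a little clunky, since `backward` is called once per entry of the output with a seed tensor, but the backward pass runs through `odeint` with complex tensors.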
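When reading the printed gradients, keep PyTorch's convention for complex numbers in mind: `backward` propagates the conjugate Wirtinger derivative, so for a holomorphic map the value that lands in `.grad` is the *conjugate* of the ordinary complex derivative multiplied by the seed. A one-line check on f(z) = c z:

```python
import torch

c = torch.tensor(2j, dtype=torch.complex128)
z = torch.tensor(1 + 1j, dtype=torch.complex128, requires_grad=True)

out = c * z                      # holomorphic, df/dz = c = 2j
out.backward(torch.tensor(1 + 0j, dtype=torch.complex128))
print(z.grad)                    # holds conj(c), i.e. -2j, not 2j
```

Keep this in mind when comparing the printed values with derivatives worked out by hand.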

In [5]:
t_i = 0.
t_f = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)

y0 = torch.tensor([[1.0-1j, 9.0-2j],
                [1.0-1j, 9.0-2j]], dtype=COMPLEX_DTYPE, requires_grad=True)

A = torch.tensor([[0-1j, 1.0+2j],
               [- 100.0+3j, 0+4j]], dtype=COMPLEX_DTYPE, requires_grad=True)

final_state = solver(y0, t_f, A)

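final_state.grad = None  # has no effect on y0.grad, A.grad or t_f.grad
# Seed for the vector-Jacobian product computed by backward; same dtype as final_state
seed = torch.zeros((2,2), dtype=COMPLEX_DTYPE)
for i in range(2):
  for j in range(2):
    # Caution: seed is never reset to zero and y0.grad, A.grad, t_f.grad are never
    # cleared, so from the second iteration on both the seed and the gradients accumulate.
    seed[i,j] = 1.0+0.j
    final_state.backward(seed,retain_graph=True)
    print("d(finalstate[",i,",",j,"])/d(y0)= ", y0.grad)
    print("d(finalstate[",i,",",j,"])/d(A)= ", A.grad)
    print("d(finalstate[",i,",",j,"])/d(t_f)= ", t_f.grad)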

d(finalstate[ 0 , 0 ])/d(y0)=  tensor([[-9.3240e+05-1.9241e+06j,  1.6603e-04-8.1419e-05j],
        [-1.2443e+05+2.4739e+05j, -1.6563e-04+8.1330e-05j]],
       dtype=torch.complex128)
d(finalstate[ 0 , 0 ])/d(A)=  tensor([[  875946.4151-3135587.1573j, 13093469.8333+12996827.7377j],
        [ -390244.2989+144838.0280j,   363842.7668-2331530.9311j]],
       dtype=torch.complex128)
d(finalstate[ 0 , 0 ])/d(t_f)=  None
d(finalstate[ 0 , 1 ])/d(y0)=  tensor([[-1864803.2054-3848230.6711j,  -932401.6035-1924115.3333j],
        [ -248860.6958+494776.6413j,  -124430.3472+247388.3184j]],
       dtype=torch.complex128)
d(finalstate[ 0 , 1 ])/d(A)=  tensor([[-4.4050e+06-26582716.1339j,  1.4369e+08+51649063.6144j],
        [-2.4199e+06+2452140.2558j, -5.4315e+06-18759931.6687j]],
       dtype=torch.complex128)
d(finalstate[ 0 , 1 ])/d(t_f)=  None
d(finalstate[ 1 , 0 ])/d(y0)=  tensor([[ 9586051.4121-5371374.3273j, -1864803.2070-3848230.6665j],
        [-1307165.0240-562743.3866j,  -248860.6942+49477
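The loop above never resets `seed` and never clears `y0.grad`, `A.grad` and `t_f.grad`, so every print after the first shows a running sum: for instance, the first column of the `[0, 1]` gradient with respect to `y0` is twice the first column of the `[0, 0]` one. A minimal sketch of a per-entry version follows. To keep it self-contained it replaces `odeint` with the closed-form solution exp(t_f A) y0 (the helper `solver_expm` is hypothetical); the reset logic should carry over unchanged to `solver`. Because `solver_expm` keeps `t_f` in the graph, `t_f.grad` is populated here too.

```python
import torch

COMPLEX_DTYPE = torch.complex128

def solver_expm(y0, t_f, A):
    # Hypothetical closed-form stand-in for `solver`: y(t_f) = exp(t_f * A) @ y0
    return torch.linalg.matrix_exp(t_f * A) @ y0

t_f = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)
y0 = torch.tensor([[1.0-1j, 9.0-2j],
                   [1.0-1j, 9.0-2j]], dtype=COMPLEX_DTYPE, requires_grad=True)
A = torch.tensor([[0-1j, 1.0+2j],
                  [-100.0+3j, 0+4j]], dtype=COMPLEX_DTYPE, requires_grad=True)

final_state = solver_expm(y0, t_f, A)

grads = {}
for i in range(2):
    for j in range(2):
        # Fresh one-hot seed for entry (i, j), with the same dtype as the output
        seed = torch.zeros_like(final_state)
        seed[i, j] = 1.0
        # Clear gradients left over from the previous backward call
        for leaf in (y0, A, t_f):
            leaf.grad = None
        final_state.backward(seed, retain_graph=True)
        grads[(i, j)] = (y0.grad.clone(), A.grad.clone(), t_f.grad.clone())
        print(f"entry [{i},{j}]: d/dy0 =", grads[(i, j)][0])
```

With the resets in place, each seed picks out a single output entry, so the gradient with respect to `y0` is nonzero only in column `j`.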

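Note that the derivatives with respect to the final time `t_f` are all `None`. This is because `solver` builds its time grid from `t_f.item()`, which returns a plain Python float, so `t_f` never enters the autograd graph. So we will have to use something else.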
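One option that needs no derivative through the time grid at all: for any ODE dy/dt = f(t, y), the sensitivity of the final state to the final time is the right-hand side evaluated there, ∂y(t_f)/∂t_f = f(t_f, y(t_f)), which for this problem is A y(t_f). A minimal sketch, checked against a central finite difference of the closed-form solution; `solver_expm` is a hypothetical stand-in for `solver`, and this is not necessarily the approach the notebook takes next.

```python
import torch

COMPLEX_DTYPE = torch.complex128

def solver_expm(y0, t_f, A):
    # Hypothetical closed-form stand-in for `solver`: y(t_f) = exp(t_f * A) @ y0
    return torch.linalg.matrix_exp(t_f * A) @ y0

y0 = torch.tensor([[1.0-1j, 9.0-2j],
                   [1.0-1j, 9.0-2j]], dtype=COMPLEX_DTYPE)
A = torch.tensor([[0-1j, 1.0+2j],
                  [-100.0+3j, 0+4j]], dtype=COMPLEX_DTYPE)
t_f = torch.tensor(2.0, dtype=torch.float64)

y_final = solver_expm(y0, t_f, A)

# d y(t_f) / d t_f = f(t_f, y(t_f)) = A @ y(t_f): one extra right-hand-side evaluation
dy_dtf = A @ y_final

# Sanity check with a central finite difference in t_f
h = 1e-5
dy_dtf_fd = (solver_expm(y0, t_f + h, A) - solver_expm(y0, t_f - h, A)) / (2 * h)
rel_err = torch.linalg.norm(dy_dtf - dy_dtf_fd) / torch.linalg.norm(dy_dtf)
print(dy_dtf)
print(f"relative difference to finite differences: {rel_err.item():.1e}")
```

With the notebook's own wrapper, the same identity would read `torch.matmul(A, solver(y0, t_f, A))`.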