
# Testing the use of adjoint_ode

In [6]:


# Standard libraries

import seaborn as sns
from matplotlib.colors import to_rgba, to_rgb
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
import imageio
import math
import torch
import torch.nn as nn


# Predefined ResNet structures


from models.neural_odes import NeuralODE, grad_loss_inputs
from models.training import doublebackTrainer, epsTrainer
from plots.plots import histories_plt

from torch.utils import data as data
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_moons, make_circles, make_blobs
from sklearn.preprocessing import StandardScaler



# Progress bar
from tqdm.notebook import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Imports for Jupyter display

# For export. Makes the plots size right for the screen
%config InlineBackend.figure_formats = ['svg'] 
%matplotlib inline


#random seed
torch.backends.cudnn.deterministic = True
torch.manual_seed(2)
torch.cuda.manual_seed(2)


#quite fixed variables

#as we look at ResNet in dim 2 this is fixed
hidden_dim, data_dim = 2, 2

# 20 before  #T is the end time, num_steps are the amount of discretization steps for the ODE solver
T, num_steps = 5.0, 5
dt = T/num_steps
turnpike = False

bound = 0.
fp = False
cross_entropy = True
adjoint = True

shuffle = False

if turnpike:
    weight_decay = 0 if bound > 0. else dt*0.01
else:
    weight_decay = dt*0.01  # 0.01 for fp, 0.1 else

In [10]:
from torchdiffeq import odeint_adjoint, odeint

torch.set_default_tensor_type(torch.DoubleTensor)

method = 'dopri5'

class f(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, t, y):
        return y**2

x0 = torch.tensor([1., 1.], requires_grad=True)
t = torch.linspace(0., 1., 10000, requires_grad=True)

y = odeint_adjoint(f(), x0, t, method=method).view(-1)

dydt_adj = torch.autograd.grad(torch.sum(y), t)[0]
# print(dydt[0][::100])

y = odeint(f(), x0, t, method=method).view(-1)

dydt = torch.autograd.grad(torch.sum(y), t)[0]
error = dydt - dydt_adj
error = error.mean()
print(error)

tensor(-0.0004)


In [14]:
from torchdiffeq import odeint_adjoint, odeint


class Dynamics(nn.Module):
    """
    The nonlinear, right hand side $f(u(t), x(t)) of the neural ODE.
    We distinguish the different structures defined in the dictionary "architectures" just above.
    """
    def __init__(self, T=10, time_steps=10):
        super(Dynamics, self).__init__()
        
        self.non_linearity = nn.Tanh()
        self.T = T
        self.time_steps = time_steps
        
        
        ##-- R^{d_hid} -> R^{d_hid} layer --
        blocks = [nn.Linear(2, 2) for _ in range(self.time_steps)]
        self.fc2_time = nn.Sequential(*blocks)
        
    def forward(self, t, x):
        """
        The output of the class -> f(x(t), u(t)).
        """
        
        # print(f'{t=}')
        dt = self.T/(self.time_steps - 1)
        k = int(t // dt)
        
        # print(f'{k=}')
        
        
        w_t = self.fc2_time[k].weight
        b_t = self.fc2_time[k].bias
                                # w(t)\sigma(x(t))+b(t)  inner
        out = self.non_linearity(x).matmul(w_t.t()) + b_t        
        
        return out

T, num_steps = 10., 11
# method = 'euler'
method = 'dopri5'

x0 = torch.tensor([0.5, 0.5], requires_grad=True)
t = torch.linspace(0., T, num_steps, requires_grad=True)
print(t)

torch.manual_seed(4)
y = odeint(Dynamics(T=T, time_steps = num_steps), x0, t, method=method).view(-1)
dydt = torch.autograd.grad(torch.sum(y), x0)[0]
print(dydt)

torch.manual_seed(4)
y = odeint_adjoint(Dynamics(T=T, time_steps = num_steps), x0, t, method=method).view(-1)
dydt_adj = torch.autograd.grad(torch.sum(y), x0)[0]
print(dydt_adj)

print((dydt - dydt_adj).mean())
# y = odeint_adjoint(f(), x0, t, method='dopri5').view(-1)

# dydt = torch.autograd.grad(torch.sum(y), t)
# print(dydt)


tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
       requires_grad=True)
tensor([ 3.8617, 10.5510])
tensor([ 3.8286, 10.5576])
tensor(0.0133)
