In [2]:
## Import the package and set up the device
import torch
import torch.nn as nn
import numpy as np 
import time

from __future__ import print_function
import argparse
import numpy as np
from inspect import getfullargspec

from typing import List, Tuple, Union, Callable, Dict, Iterable, Generator
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR, MultiStepLR
from torch import Tensor
from collections import namedtuple
from warnings import warn
import torch.utils.data as data


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Using device:", device)

Using device: cuda


In [3]:
## Some useful functions: 

def hairer_norm(tensor):
    return tensor.abs().pow(2).mean().sqrt()

def standardize_vf_call_signature(vector_field, order=1, defunc_wrap=False):
    "Ensures Callables or nn.Modules passed to `ODEProblems` and `NeuralODE` have consistent `__call__` signature (t, x)"
    
    if issubclass(type(vector_field), nn.Module):
        if 't' not in getfullargspec(vector_field.forward).args:
            print("Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, "
                "we've wrapped it for you.")
            vector_field = DEFuncBase(vector_field, has_time_arg=False)
    else: 
        # argspec for lambda functions needs to be done on the function itself
        if 't' not in getfullargspec(vector_field).args:
            print("Your vector field callable (lambda) should have both time `t` and state `x` as arguments, "
                "we've wrapped it for you.")
            vector_field = DEFuncBase(vector_field, has_time_arg=False)   
        else: vector_field = DEFuncBase(vector_field, has_time_arg=True) 
    if defunc_wrap: return DEFunc(vector_field, order)
    else: return vector_field
    


def init_step(f, f0, x0, t0, order, atol, rtol):
    scale = atol + torch.abs(x0) * rtol
    d0, d1 = hairer_norm(x0 / scale), hairer_norm(f0 / scale)

    if d0 < 1e-5 or d1 < 1e-5:
        h0 = torch.tensor(1e-6, dtype=t0.dtype, device=t0.device)
    else:
        h0 = 0.01 * d0 / d1

    x_new = x0 + h0 * f0
    f_new = f(t0 + h0, x_new)
    d2 = hairer_norm((f_new - f0) / scale) / h0
    if d1 <= 1e-15 and d2 <= 1e-15:
        h1 = torch.max(torch.tensor(1e-6, dtype=t0.dtype, device=t0.device), h0 * 1e-3)
    else:
        h1 = (0.01 / max(d1, d2)) ** (1. / float(order + 1))
    dt = torch.min(100 * h0, h1).to(t0)
    return dt

@torch.no_grad()
def adapt_step(dt, error_ratio, safety, min_factor, max_factor, order):
    if error_ratio == 0: return dt * max_factor
    if error_ratio < 1: min_factor = torch.ones_like(dt)
    exponent = torch.tensor(order, dtype=dt.dtype, device=dt.device).reciprocal()
    factor = torch.min(max_factor, torch.max(safety / error_ratio ** exponent, min_factor))
    return dt * factor




In [4]:
## Some useful class:

class DEFuncBase(nn.Module):
    def __init__(self, vector_field: Callable, has_time_arg: bool = True):
        """Basic wrapper to ensure call signature compatibility between generic torch Modules and vector fields.
        Args:
            vector_field (Callable): callable defining the dynamics / vector field / `dxdt` / forcing function
            has_time_arg (bool, optional): Internal arg. to indicate whether the callable has `t` in its `__call__'
                or `forward` method. Defaults to True.
        """
        super().__init__()
        self.nfe, self.vf, self.has_time_arg = 0.0, vector_field, has_time_arg

    def forward(self, t: Tensor, x: Tensor, args: Dict = {}) -> Tensor:
        self.nfe += 1
        if self.has_time_arg:
            return self.vf(t, x, args=args)
        else:
            return self.vf(x)
        

class DEFunc(nn.Module):
    def __init__(self, vector_field: Callable, order: int = 1):
        """Special vector field wrapper for Neural ODEs.

        Handles auxiliary tasks: time ("depth") concatenation, higher-order dynamics and forward propagated integral losses.

        Args:
            vector_field (Callable): callable defining the dynamics / vector field / `dxdt` / forcing function
            order (int, optional): order of the differential equation. Defaults to 1.

        Notes:
            Currently handles the following:
            (1) assigns time tensor to each submodule requiring it (e.g. `GalLinear`).
            (2) in case of integral losses + reverse-mode differentiation, propagates the loss in the first dimension of `x`
                and automatically splits the Tensor into `x[:, 0]` and `x[:, 1:]` for vector field computation
            (3) in case of higher-order dynamics, adjusts the vector field forward to recursively compute various orders.
        """
        super().__init__()
        self.vf, self.nfe, = vector_field, 0.0
        self.order, self.integral_loss, self.sensitivity = order, None, None
        # identify whether vector field already has time arg

    def forward(self, t: Tensor, x: Tensor, args: Dict = {}) -> Tensor:
        self.nfe += 1
        # set `t` depth-variable to DepthCat modules
        for _, module in self.vf.named_modules():
            if hasattr(module, "t"):
                module.t = t

        # if-else to handle autograd training with integral loss propagated in x[:, 0]
        if (self.integral_loss is not None) and self.sensitivity == "autograd":
            x_dyn = x[:, 1:]
            dlds = self.integral_loss(t, x_dyn)
            if len(dlds.shape) == 1:
                dlds = dlds[:, None]
            if self.order > 1:
                x_dyn = self.horder_forward(t, x_dyn, args)
            else:
                x_dyn = self.vf(t, x_dyn)
            return cat([dlds, x_dyn], 1).to(x_dyn)

        # regular forward
        else:
            if self.order > 1:
                x = self.higher_order_forward(t, x)
            else:
                x = self.vf(t, x)
            return x

In [5]:
## The table for DP45 and tsit5
ExplicitRKTableau = namedtuple('ExplicitRKTableau', 'c, A, b_sol, b_err')


def construct_dopri5(dtype):
    c = torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], dtype=dtype)
    a = [
        torch.tensor([1 / 5], dtype=dtype),
        torch.tensor([3 / 40, 9 / 40], dtype=dtype),
        torch.tensor([44 / 45, -56 / 15, 32 / 9], dtype=dtype),
        torch.tensor([19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], dtype=dtype),
        torch.tensor([9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], dtype=dtype),
        torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], dtype=dtype),
    ]
    bsol = torch.tensor([35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], dtype=dtype)
    berr = torch.tensor([1951 / 21600, 0, 22642 / 50085, 451 / 720, -12231 / 42400, 649 / 6300, 1 / 60.], dtype=dtype)

    dmid = torch.tensor([-1.1270175653862835, 0., 2.675424484351598, -5.685526961588504, 3.5219323679207912,
                         -1.7672812570757455, 2.382468931778144])
    return (c, a, bsol, bsol - berr)

def construct_tsit5(dtype):

    c = torch.tensor([
        161 / 1000,
        327 / 1000,
        9 / 10,
        .9800255409045096857298102862870245954942137979563024768854764293221195950761080302604,
        1.,
        1.
    ], dtype=dtype)
    a = [
        torch.tensor([
            161 / 1000
        ], dtype=dtype),
        torch.tensor([
            -.8480655492356988544426874250230774675121177393430391537369234245294192976164141156943e-2,
            .3354806554923569885444268742502307746751211773934303915373692342452941929761641411569
        ], dtype=dtype),
        torch.tensor([
            2.897153057105493432130432594192938764924887287701866490314866693455023795137503079289,
            -6.359448489975074843148159912383825625952700647415626703305928850207288721235210244366,
            4.362295432869581411017727318190886861027813359713760212991062156752264926097707165077
        ], dtype=dtype),
        torch.tensor([
            5.325864828439256604428877920840511317836476253097040101202360397727981648835607691791,
            -11.74888356406282787774717033978577296188744178259862899288666928009020615663593781589,
            7.495539342889836208304604784564358155658679161518186721010132816213648793440552049753,
            -.9249506636175524925650207933207191611349983406029535244034750452930469056411389539635e-1
        ], dtype=dtype),
        torch.tensor([
            5.861455442946420028659251486982647890394337666164814434818157239052507339770711679748,
            -12.92096931784710929170611868178335939541780751955743459166312250439928519268343184452,
            8.159367898576158643180400794539253485181918321135053305748355423955009222648673734986,
            -.7158497328140099722453054252582973869127213147363544882721139659546372402303777878835e-1,
            -.2826905039406838290900305721271224146717633626879770007617876201276764571291579142206e-1
        ], dtype=dtype),
        torch.tensor([
            .9646076681806522951816731316512876333711995238157997181903319145764851595234062815396e-1,
            1 / 100,
            .4798896504144995747752495322905965199130404621990332488332634944254542060153074523509,
            1.379008574103741893192274821856872770756462643091360525934940067397245698027561293331,
            -3.290069515436080679901047585711363850115683290894936158531296799594813811049925401677,
            2.324710524099773982415355918398765796109060233222962411944060046314465391054716027841
        ], dtype=dtype),
    ]
    bsol = torch.tensor([
        .9646076681806522951816731316512876333711995238157997181903319145764851595234062815396e-1,
        1 / 100,
        .4798896504144995747752495322905965199130404621990332488332634944254542060153074523509,
        1.379008574103741893192274821856872770756462643091360525934940067397245698027561293331,
        -3.290069515436080679901047585711363850115683290894936158531296799594813811049925401677,
        2.324710524099773982415355918398765796109060233222962411944060046314465391054716027841,
        0.
    ], dtype=dtype)
    berr = torch.tensor([
        .9468075576583945807478876255758922856117527357724631226139574065785592789071067303271e-1,
        .9183565540343253096776363936645313759813746240984095238905939532922955247253608687270e-2,
        .4877705284247615707855642599631228241516691959761363774365216240304071651579571959813,
        1.234297566930478985655109673884237654035539930748192848315425833500484878378061439761,
        -2.707712349983525454881109975059321670689605166938197378763992255714444407154902012702,
        1.866628418170587035753719399566211498666255505244122593996591602841258328965767580089,
        1 / 66.,
    ], dtype=dtype)
    return (c, a, bsol, bsol - berr)

In [6]:
## Differential solver class
class DiffEqSolver(nn.Module):
    def __init__(
            self, 
            order, 
            stepping_class:str="fixed", 
            min_factor:float=0.2, 
            max_factor:float=10, 
            safety:float=0.9
        ):

        super(DiffEqSolver, self).__init__()
        self.order = order
        self.min_factor = torch.tensor([min_factor])
        self.max_factor = torch.tensor([max_factor])
        self.safety = torch.tensor([safety])
        self.tableau = None
        self.stepping_class = stepping_class

    def sync_device_dtype(self, x, t_span):
        "Ensures `x`, `t_span`, `tableau` and other solver tensors are on the same device with compatible dtypes"
        device = x.device
        if self.tableau is not None:
            c, a, bsol, berr = self.tableau
            self.tableau = c.to(x), [a.to(x) for a in a], bsol.to(x), berr.to(x)
        t_span = t_span.to(device)
        self.safety = self.safety.to(device)
        self.min_factor = self.min_factor.to(device)
        self.max_factor = self.max_factor.to(device)
        return x, t_span

    def step(self, f, x, t, dt, k1=None, args=None):
        raise NotImplementedError("Stepping rule not implemented for the solver")

In [7]:
## Copy the hardcoded DP45 method
class DormandPrince45(DiffEqSolver):
    def __init__(self, dtype=torch.float32):
        super().__init__(order=5)
        self.dtype = dtype
        self.stepping_class = 'fixed'
        self.tableau = construct_dopri5(self.dtype)

    def step(self, f, x, t, dt, k1=None, args=None):
        c, a, bsol, berr = self.tableau
        if k1 == None: k1 = f(t, x)
        k2 = f(t + c[0] * dt, x + dt * a[0] * k1)
        k3 = f(t + c[1] * dt, x + dt * (a[1][0] * k1 + a[1][1] * k2))
        k4 = f(t + c[2] * dt, x + dt * a[2][0] * k1 + dt * a[2][1] * k2 + dt * a[2][2] * k3)
        k5 = f(t + c[3] * dt, x + dt * a[3][0] * k1 + dt * a[3][1] * k2 + dt * a[3][2] * k3 + dt * a[3][3] * k4)
        k6 = f(t + c[4] * dt, x + dt * a[4][0] * k1 + dt * a[4][1] * k2 + dt * a[4][2] * k3 + dt * a[4][3] * k4 + dt * a[4][4] * k5)
        k7 = f(t + c[5] * dt, x + dt * a[5][0] * k1 + dt * a[5][1] * k2 + dt * a[5][2] * k3 + dt * a[5][3] * k4 + dt * a[5][4] * k5 + dt * a[5][5] * k6)
        x_sol = x + dt * (bsol[0] * k1 + bsol[1] * k2 + bsol[2] * k3 + bsol[3] * k4 + bsol[4] * k5 + bsol[5] * k6)
        err = dt * (berr[0] * k1 + berr[1] * k2 + berr[2] * k3 + berr[3] * k4 + berr[4] * k5 + berr[5] * k6 + berr[6] * k7)
        return k7, x_sol, err, (k1, k2, k3, k4, k5, k6, k7)

In [8]:
## Copy the hardcoded Tsitouras45 method
class Tsitouras45(DiffEqSolver):
    def __init__(self, dtype=torch.float32):
        super().__init__(order=5)
        self.dtype = dtype
        self.stepping_class = 'adaptive'
        self.tableau = construct_tsit5(self.dtype)

    def step(self, f, x, t, dt, k1=None, args=None) -> Tuple:
        c, a, bsol, berr = self.tableau
        if k1 == None: k1 = f(t, x)
        k2 = f(t + c[0] * dt, x + dt * a[0][0] * k1)
        k3 = f(t + c[1] * dt, x + dt * (a[1][0] * k1 + a[1][1] * k2))
        k4 = f(t + c[2] * dt, x + dt * a[2][0] * k1 + dt * a[2][1] * k2 + dt * a[2][2] * k3)
        k5 = f(t + c[3] * dt, x + dt * a[3][0] * k1 + dt * a[3][1] * k2 + dt * a[3][2] * k3 + dt * a[3][3] * k4)
        k6 = f(t + c[4] * dt, x + dt * a[4][0] * k1 + dt * a[4][1] * k2 + dt * a[4][2] * k3 + dt * a[4][3] * k4 + dt * a[4][4] * k5)
        k7 = f(t + c[5] * dt, x + dt * a[5][0] * k1 + dt * a[5][1] * k2 + dt * a[5][2] * k3 + dt * a[5][3] * k4 + dt * a[5][4] * k5 + dt * a[5][5] * k6)
        x_sol = x + dt * (bsol[0] * k1 + bsol[1] * k2 + bsol[2] * k3 + bsol[3] * k4 + bsol[4] * k5 + bsol[5] * k6)
        err = dt * (berr[0] * k1 + berr[1] * k2 + berr[2] * k3 + berr[3] * k4 + berr[4] * k5 + berr[5] * k6 + berr[6] * k7)
        return k7, x_sol, err, (k1, k2, k3, k4, k5, k6, k7)


In [9]:
## Define the fixed time solver:
def _fixed_odeint(f, x, t_span, solver, save_at=(), args={}):
    """Solves IVPs with same `t_span`, using fixed-step methods"""
    if len(save_at) == 0: save_at = t_span
    if not isinstance(save_at, torch.Tensor):
        save_at = torch.tensor(save_at)

    assert all(torch.isclose(t, save_at).sum() == 1 for t in save_at),\
        "each element of save_at [torch.Tensor] must be contained in t_span [torch.Tensor] once and only once"

    t, T, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

    sol = []
    if torch.isclose(t, save_at).sum():
        sol = [x]

    steps = 1
    while steps <= len(t_span) - 1:
        _, x, _,_ = solver.step(f, x.squeeze(0), t.squeeze(0), dt, k1=None, args=args)
        t = t + dt

        if torch.isclose(t, save_at).sum():
            sol.append(x)
        if steps < len(t_span) - 1: dt = t_span[steps+1] - t
        steps += 1

    if isinstance(sol[0], dict):
        final_out = {k: [v] for k, v in sol[0].items()}
        _ = [final_out[k].append(x[k]) for k in x.keys() for x in sol[1:]]
        final_out = {k: torch.stack(v) for k, v in final_out.items()}
    elif isinstance(sol[0], torch.Tensor):
        final_out = torch.stack(sol)
    else:
        raise NotImplementedError(f"{type(x)} is not supported as the state variable")

    return save_at, final_out

In [21]:
## Introduce the adaptive method
def _adaptive_odeint(f, k1, x, dt, t_span, solver, atol=1e-4, rtol=1e-4, args=None, interpolator=None, return_all_eval=False, seminorm=(False, None)):
    """Adaptive ODE solve routine, called by `odeint`.

    Args:
        f ([type]):
        k1 ([type]):
        x ([type]):
        dt ([type]):
        t_span ([type]):
        solver ([type]):
        atol ([type], optional): Defaults to 1e-4.
        rtol ([type], optional): Defaults to 1e-4.
        args (Dict):
        use_interp (bool, optional):
        return_all_eval (bool, optional): Defaults to False.


    Notes:
        (1) We check if the user wants all evaluated solution points, not only those
        corresponding to times in `t_span`. This is automatically set to `True` when `odeint`
        is called for interpolated adjoints
    """
    x, t_span = solver.sync_device_dtype(x, t_span)
    t_eval, t, T = t_span[1:], t_span[:1], t_span[-1]
    ckpt_counter, ckpt_flag = 0, False
    eval_times, sol = [t], [x]
    while t < T:
        if t + dt > T:
            dt = T - t
        ############### checkpointing ###############################
        if t_eval is not None:
            # satisfy checkpointing by using interpolation scheme or resetting `dt`
            if (ckpt_counter < len(t_eval)) and (t + dt > t_eval[ckpt_counter]):
                if interpolator == None:
                    # save old dt, raise "checkpoint" flag and repeat step
                    dt_old, ckpt_flag = dt, True
                    dt = t_eval[ckpt_counter] - t

        f_new, x_new, x_err, stages = solver.step(f, x, t, dt, k1=k1, args=args)
        ################# compute error #############################
        if seminorm[0] == True:
            state_dim = seminorm[1]
            error = x_err[:state_dim]
            error_scaled = error / (atol + rtol * torch.max(x[:state_dim].abs(), x_new[:state_dim].abs()))
        else:
            error = x_err
            error_scaled = error / (atol + rtol * torch.max(x.abs(), x_new.abs()))
        error_ratio = hairer_norm(error_scaled)
        accept_step = error_ratio <= 1

        if accept_step:
            ############### checkpointing via interpolation ###############################
            if t_eval is not None and interpolator is not None:
                coefs = None
                while (ckpt_counter < len(t_eval)) and (t + dt > t_eval[ckpt_counter]):
                    t0, t1 = t, t + dt
                    x_mid = x + dt * sum([interpolator.bmid[i] * stages[i] for i in range(len(stages))])
                    f0, f1, x0, x1 = k1, f_new, x, x_new
                    if coefs == None: coefs = interpolator.fit(dt, f0, f1, x0, x1, x_mid)
                    x_in = interpolator.evaluate(coefs, t0, t1, t_eval[ckpt_counter])
                    sol.append(x_in)
                    eval_times.append(t_eval[ckpt_counter][None])
                    ckpt_counter += 1

            if t + dt == t_eval[ckpt_counter] or return_all_eval: # note (1)
                sol.append(x_new)
                eval_times.append(t + dt)
                # we only increment the ckpt counter if the solution points corresponds to a time point in `t_span`
                if t + dt == t_eval[ckpt_counter]: ckpt_counter += 1
            t, x = t + dt, x_new
            k1 = f_new

        ################ stepsize control ###########################
        # reset "dt" in case of checkpoint without interp
        if ckpt_flag:
            dt = dt_old - dt
            ckpt_flag = False
        dt = adapt_step(dt, error_ratio,
                        solver.safety,
                        solver.min_factor,
                        solver.max_factor,
                        solver.order)
    return torch.cat(eval_times), torch.stack(sol)

## Test fixed step EPI2 with one step DP45

In [11]:
## Next, Implement the EPI2 method by only one step forward using DP45 as a sub-class of DP45:
class EPI2_one_step(DormandPrince45):
    def __init__(self, dtype=torch.float32):
        super().__init__()
        self.dtype = dtype
        self.stepping_class = 'fixed'

    def step(self, f, x_n, t, dt, k1=None, args=None):
        
        init_cond_epi=torch.zeros_like(x_n)
        constants=f(t,x_n).detach()
        
        ## Recall that this step of EPI2 is going to update x to x+dt using EPI2 step under dynamics f
        def f_combined(t, x):
            x.requires_grad_(True)
            # Compute the original dynamics f(x, t)
            f_x_t = f(t, x).detach()

            # Compute the Jacobian-vector product (A_n * x) using torch.autograd.functional.jacobian
            _, A_x_product = torch.autograd.functional.jvp(lambda x: f(t, x), x, x) 
            return A_x_product + constants
        
        _, increment, _, _=super().step(f_combined,init_cond_epi,t,dt)
        x_sol=x_n+increment
        return None, x_sol, None, None

## Test fixed step EPI2 with adaptive step DP45

In [12]:
## Next, Implement the EPI2 method by applying a adaptive method using DP45 as a sub-class of DP45:
class EPI2_adaptive(DormandPrince45):
    def __init__(self, dtype=torch.float32):
        super().__init__(dtype=dtype)
        self.dtype = dtype
        self.stepping_class = 'fixed'


    def step(self, f, x_n, t, dt, k1=None, args=None):
        
        ## Specify what is solver we will be using
        solver=DormandPrince45(dtype=torch.float32)
        constants=f(t,x_n).detach()
        
        
        
        
        ## Recall that this step of EPI2 is going to update x to x+dt using EPI2 step under dynamics f
        def f_combined(t, x):
            # Compute the original dynamics f(x, t)
            x.requires_grad_(True)
            f_x_t = f(t, x)

            # Compute the Jacobian-vector product (A_n * x) using torch.autograd.functional.jacobian
            _, A_x_product = torch.autograd.functional.jvp(lambda x: f(t, x), x, x)
            return A_x_product + constants
        
        x0=torch.zeros_like(x_n).to(device)
        t0=torch.zeros_like(t).to(device)
        k0=f_combined(t0,x0).to(device)
        dt_substep=init_step(f_combined, k0, x0, t0, solver.order, atol=1e-4, rtol=1e-4).to(device)
        
        t_span=torch.linspace(t0.item(), dt.item(), steps=5, dtype=torch.float32).to(device)

        record_time, increment=_adaptive_odeint(f_combined, k0, x0, dt_substep, t_span, solver , atol=1e-4, rtol=1e-4)
        
        x_sol=x_n+increment[-1]
        return None, x_sol, None, None

In [13]:
## introduce the odeint function:
def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, nn.Module], atol:float=1e-3, rtol:float=1e-3,
           t_stops:Union[List, Tensor, None]=None, verbose:bool=False, interpolator:Union[str, Callable, None]=None, return_all_eval:bool=False,
           save_at:Union[Iterable, Tensor]=(), args:Dict={}, seminorm:Tuple[bool, Union[int, None]]=(False, None)) -> Tuple[Tensor, Tensor]:
    """Solve an initial value problem (IVP) determined by function `f` and initial condition `x`.

       Functional `odeint` API of the `torchdyn` package.

    Args:
        f (Callable):
        x (Tensor):
        t_span (Union[List, Tensor]):
        solver (Union[str, nn.Module]):
        atol (float, optional): Defaults to 1e-3.
        rtol (float, optional): Defaults to 1e-3.
        t_stops (Union[List, Tensor, None], optional): Defaults to None.
        verbose (bool, optional): Defaults to False.
        interpolator (bool, optional): Defaults to False.
        return_all_eval (bool, optional): Defaults to False.
        save_at (Union[List, Tensor], optional): Defaults to t_span
        args (Dict): Arbitrary parameters used in step
        seminorm (Tuple[bool, Union[int, None]], optional): Whether to use seminorms in local error computation.

    Returns:
        Tuple[Tensor, Tensor]: returns a Tuple (t_eval, solution).
    """
    if t_span[1] < t_span[0]: # time is reversed
        if verbose: warn("You are integrating on a reversed time domain, adjusting the vector field automatically")
        f_ = lambda t, x: -f(-t, x)
        t_span = -t_span
    else: f_ = f

    if type(t_span) == list: t_span = torch.cat(t_span)
    # instantiate the solver in case the user has specified preference via a `str` and ensure compatibility of device ~ dtype
    if type(solver) == str:
        solver = str_to_solver(solver, x.dtype)
    x, t_span = solver.sync_device_dtype(x, t_span)
    stepping_class = solver.stepping_class

    # instantiate the interpolator similar to the solver steps above
    if isinstance(solver, Tsitouras45):
        if verbose: warn("Running interpolation not yet implemented for `tsit5`")
        interpolator = None

    if type(interpolator) == str:
        interpolator = str_to_interp(interpolator, x.dtype)
        x, t_span = interpolator.sync_device_dtype(x, t_span)

    # access parallel integration routines with different t_spans for each sample in `x`.
    if len(t_span.shape) > 1:
        raise NotImplementedError("Parallel routines not implemented yet, check experimental versions of `torchdyn`")
    # odeint routine with a single t_span for all samples
    elif len(t_span.shape) == 1:
        if stepping_class == 'fixed':
            if atol != odeint.__defaults__[0] or rtol != odeint.__defaults__[1]:
                warn("Setting tolerances has no effect on fixed-step methods")
            # instantiate save_at tensor
            return _fixed_odeint(f_, x, t_span, solver, save_at=save_at, args=args)
        elif stepping_class == 'adaptive':
            t = t_span[0]
            k1 = f_(t, x)
            dt = init_step(f, k1, x, t, solver.order, atol, rtol)
            if len(save_at) > 0: warn("Setting save_at has no effect on adaptive-step methods")
            return _adaptive_odeint(f_, k1, x, dt, t_span, solver, atol, rtol, args, interpolator, return_all_eval, seminorm)

In [32]:
## Define the customized forward and backward method under autograd.function class:
def generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B0=None, 
                  return_all_eval=False, maxiter=4, fine_steps=4, save_at=()):
    "Dispatches to appropriate `odeint` function depending on `Problem` class (ODEProblem, MultipleShootingProblem)"
    if problem_type == 'standard':
        return odeint(vf, x, t_span, solver, atol=atol, rtol=rtol, interpolator=interpolator, return_all_eval=return_all_eval,
                      save_at=save_at)
    elif problem_type == 'multiple_shooting':
        return odeint_mshooting(vf, x, t_span, solver, B0=B0, fine_steps=fine_steps, maxiter=maxiter)


# TODO: optimize and make conditional gradient computations w.r.t end times
# TODO: link `seminorm` arg from `ODEProblem`
def _gather_odefunc_adjoint(vf, vf_params, solver, atol, rtol, interpolator, solver_adjoint, 
                            atol_adjoint, rtol_adjoint, integral_loss, problem_type, maxiter=4, fine_steps=4):
    "Prepares definition of autograd.Function for adjoint sensitivity analysis of the above `ODEProblem`"
    class _ODEProblemFunc(torch.autograd.Function):
        @staticmethod
        def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
            t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B, 
                                        False, maxiter, fine_steps, save_at)
            ctx.save_for_backward(sol, t_sol)
            return t_sol, sol

        @staticmethod
        def backward(ctx, *grad_output):
            sol, t_sol = ctx.saved_tensors
            vf_params = torch.cat([p.contiguous().flatten() for p in vf.parameters()])
            # initialize flattened adjoint state
            xT, λT, μT = sol[-1], grad_output[-1][-1], torch.zeros_like(vf_params)
            xT_nel, λT_nel, μT_nel = xT.numel(), λT.numel(), μT.numel()
            xT_shape, λT_shape, μT_shape = xT.shape, λT.shape, μT.shape

            λT_flat = λT.flatten()
            λtT = λT_flat @ vf(t_sol[-1], xT).flatten()
            # concatenate all states of adjoint system
            A = torch.cat([xT.flatten(), λT_flat, μT.flatten(), λtT[None]])

            def adjoint_dynamics(t, A):
                if len(t.shape) > 0: t = t[0]
                x, λ, μ = A[:xT_nel], A[xT_nel:xT_nel+λT_nel], A[-μT_nel-1:-1]
                x, λ, μ = x.reshape(xT.shape), λ.reshape(λT.shape), μ.reshape(μT.shape)
                with torch.set_grad_enabled(True):
                    x, t = x.requires_grad_(True), t.requires_grad_(True)
                    dx = vf(t, x)
                    dλ, dt, *dμ = tuple(torch.autograd.grad(dx, (x, t) + tuple(vf.parameters()), -λ,
                                    allow_unused=True, retain_graph=True))

                    if integral_loss:
                        dg = torch.autograd.grad(integral_loss(t, x).sum(), x, allow_unused=True, retain_graph=True)[0]
                        dλ = dλ - dg

                    dμ = torch.cat([el.flatten() if el is not None else torch.zeros(1) 
                                    for el in dμ], dim=-1)
                    if dt == None: dt = torch.zeros(1).to(t)
                    if len(t.shape) == 0: dt = dt.unsqueeze(0)
                return torch.cat([dx.flatten(), dλ.flatten(), dμ.flatten(), dt.flatten()])

            # solve the adjoint equation
            n_elements = (xT_nel, λT_nel, μT_nel)
            dLdt = torch.zeros(len(t_sol)).to(xT)
            dLdt[-1] = λtT
            for i in range(len(t_sol) - 1, 0, -1):
                t_adj_sol, A = odeint(adjoint_dynamics, A, t_sol[i - 1:i + 1].flip(0), 
                                      solver_adjoint, atol=atol_adjoint, rtol=rtol_adjoint,
                                      seminorm=(True, xT_nel+λT_nel))
                # prepare adjoint state for next interval
                #TODO: reuse vf_eval for dLdt calculations
                xt = A[-1, :xT_nel].reshape(xT_shape)
                dLdt_ = A[-1, xT_nel:xT_nel + λT_nel]@vf(t_sol[i], xt).flatten()
                A[-1, -1:] -= grad_output[0][i - 1]
                dLdt[i-1] = dLdt_

                A = torch.cat([A[-1, :xT_nel], A[-1, xT_nel:xT_nel + λT_nel], A[-1, -μT_nel-1:-1], A[-1, -1:]])
                A[xT_nel:xT_nel + λT_nel] += grad_output[-1][i - 1].flatten()

            λ, μ = A[xT_nel:xT_nel + λT_nel], A[-μT_nel-1:-1]
            λ, μ = λ.reshape(λT.shape), μ.reshape(μT.shape)
            λ_tspan = torch.stack([dLdt[0], dLdt[-1]])
            return (μ, λ, λ_tspan, None, None, None)

    return _ODEProblemFunc

In [15]:
## Define the ODEproblem class for bridge
class ODEProblem(nn.Module):
    def __init__(
        self,
        vector_field: Union[Callable, nn.Module],
        solver: Union[str, nn.Module],
        interpolator: Union[str, Callable, None] = None,
        order: int = 1,
        atol: float = 1e-4,
        rtol: float = 1e-4,
        sensitivity: str = "autograd",
        solver_adjoint: Union[str, nn.Module, None] = None,
        atol_adjoint: float = 1e-6,
        rtol_adjoint: float = 1e-6,
        seminorm: bool = False,
        integral_loss: Union[Callable, None] = None,
        optimizable_params: Union[Iterable, Generator] = (),
    ):
        """An ODE Problem coupling a given vector field with solver and sensitivity algorithm to compute gradients w.r.t different quantities.

        Args:
            vector_field ([Callable]): the vector field, called with `vector_field(t, x)` for `vector_field(x)`.
                                       In the second case, the Callable is automatically wrapped for consistency
            solver (Union[str, nn.Module]):
            order (int, optional): Order of the ODE. Defaults to 1.
            atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-4.
            rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-4.
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'.
            solver_adjoint (Union[str, nn.Module, None], optional): ODE solver for the adjoint. Defaults to None.
            atol_adjoint (float, optional): Defaults to 1e-6.
            rtol_adjoint (float, optional): Defaults to 1e-6.
            seminorm (bool, optional): Indicates whether the a seminorm should be used for error estimation during adjoint backsolves. Defaults to False.
            integral_loss (Union[Callable, None]): Integral loss to optimize for. Defaults to None.
            optimizable_parameters (Union[Iterable, Generator]): parameters to calculate sensitivies for. Defaults to ().
        Notes:
            Integral losses can be passed as generic function or `nn.Modules`.
        """
        super().__init__()
        # instantiate solver at initialization
        if type(solver) == str:
            solver = str_to_solver(solver)
        if solver_adjoint is None:
            solver_adjoint = solver
        else:
            solver_adjoint = str_to_solver(solver_adjoint)

        self.solver, self.interpolator, self.atol, self.rtol = (
            solver,
            interpolator,
            atol,
            rtol,
        )
        self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint = (
            solver_adjoint,
            atol_adjoint,
            rtol_adjoint,
        )
        self.sensitivity, self.integral_loss = sensitivity, integral_loss

        # wrap vector field if `t, x` is not the call signature
        vector_field = standardize_vf_call_signature(vector_field)

        self.vf, self.order, self.sensalg = vector_field, order, sensitivity
        optimizable_params = tuple(optimizable_params)

        if len(tuple(self.vf.parameters())) > 0:
            self.vf_params = torch.cat(
                [p.contiguous().flatten() for p in self.vf.parameters()]
            )

        elif len(optimizable_params) > 0:
            # use `optimizable_parameters` if f itself does not have a .parameters() iterable
            # TODO: advanced logic to retain naming in case `state_dicts()` are passed
            for k, p in enumerate(optimizable_params):
                self.vf.register_parameter(f"optimizable_parameter_{k}", p)
            self.vf_params = torch.cat(
                [p.contiguous().flatten() for p in optimizable_params]
            )

        else:
            print("Your vector field does not have `nn.Parameters` to optimize.")
            dummy_parameter = nn.Parameter(torch.zeros(1))
            self.vf.register_parameter("dummy_parameter", dummy_parameter)
            self.vf_params = torch.cat(
                [p.contiguous().flatten() for p in self.vf.parameters()]
            )

    def _autograd_func(self):
        "create autograd functions for backward pass"
        self.vf_params = torch.cat(
            [p.contiguous().flatten() for p in self.vf.parameters()]
        )
        if (self.sensalg == "adjoint"):  # alias .apply as direct call to preserve consistency of call signature
            return _gather_odefunc_adjoint(
                self.vf,
                self.vf_params,
                self.solver,
                self.atol,
                self.rtol,
                self.interpolator,
                self.solver_adjoint,
                self.atol_adjoint,
                self.rtol_adjoint,
                self.integral_loss,
                problem_type="standard",
            ).apply
        elif self.sensalg == "interpolated_adjoint":
            return _gather_odefunc_interp_adjoint(
                self.vf,
                self.vf_params,
                self.solver,
                self.atol,
                self.rtol,
                self.interpolator,
                self.solver_adjoint,
                self.atol_adjoint,
                self.rtol_adjoint,
                self.integral_loss,
                problem_type="standard",
            ).apply

    def odeint(self, x: Tensor, t_span: Tensor, save_at: Tensor = (), args={}):
        "Returns Tuple(`t_eval`, `solution`)"
        if self.sensalg == "autograd":
            return odeint(
                self.vf,
                x,
                t_span,
                self.solver,
                self.atol,
                self.rtol,
                interpolator=self.interpolator,
                save_at=save_at,
                args=args,
            )
        else:
            return self._autograd_func()(self.vf_params, x, t_span, save_at, args)

    def forward(self, x: Tensor, t_span: Tensor, save_at: Tensor = (), args={}):
        "For safety redirects to intended method `odeint`"
        return self.odeint(x, t_span, save_at, args)

In [16]:
## Define the NeuralODE class
class NeuralODE(ODEProblem, pl.LightningModule):
    def __init__(
        self,
        vector_field: Union[Callable, nn.Module],
        solver: Union[str, nn.Module] = "tsit5",
        order: int = 1,
        atol: float = 1e-3,
        rtol: float = 1e-3,
        sensitivity="autograd",
        solver_adjoint: Union[str, nn.Module, None] = None,
        atol_adjoint: float = 1e-4,
        rtol_adjoint: float = 1e-4,
        interpolator: Union[str, Callable, None] = None,
        integral_loss: Union[Callable, None] = None,
        seminorm: bool = False,
        return_t_eval: bool = True,
        optimizable_params: Union[Iterable, Generator] = (),
    ):
        """Generic Neural Ordinary Differential Equation.

        Args:
            vector_field ([Callable]): the vector field, called with `vector_field(t, x)` for `vector_field(x)`. 
                                       In the second case, the Callable is automatically wrapped for consistency
            solver (Union[str, nn.Module]): 
            order (int, optional): Order of the ODE. Defaults to 1.
            atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-4.
            rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-4.
            sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'.
            solver_adjoint (Union[str, nn.Module, None], optional): ODE solver for the adjoint. Defaults to None.
            atol_adjoint (float, optional): Defaults to 1e-6.
            rtol_adjoint (float, optional): Defaults to 1e-6.
            integral_loss (Union[Callable, None], optional): Defaults to None.
            seminorm (bool, optional): Whether to use seminorms for adaptive stepping in backsolve adjoints. Defaults to False.
            return_t_eval (bool): Whether to return (t_eval, sol) or only sol. Useful for chaining NeuralODEs in `nn.Sequential`.
            optimizable_parameters (Union[Iterable, Generator]): parameters to calculate sensitivies for. Defaults to ().
        Notes:
            In `torchdyn`-style, forward calls to a Neural ODE return both a tensor `t_eval` of time points at which the solution is evaluated
            as well as the solution itself. This behavior can be controlled by setting `return_t_eval` to False. Calling `trajectory` also returns
            the solution only. 

            The Neural ODE class automates certain delicate steps that must be done depending on the solver and model used. 
            The `prep_odeint` method carries out such steps. Neural ODEs wrap `ODEProblem`.
        """
        super().__init__(
            vector_field=standardize_vf_call_signature(
                vector_field, order, defunc_wrap=True
            ),
            order=order,
            sensitivity=sensitivity,
            solver=solver,
            atol=atol,
            rtol=rtol,
            solver_adjoint=solver_adjoint,
            atol_adjoint=atol_adjoint,
            rtol_adjoint=rtol_adjoint,
            seminorm=seminorm,
            interpolator=interpolator,
            integral_loss=integral_loss,
            optimizable_params=optimizable_params,
        )
        # data-control conditioning
        self._control, self.controlled, self.t_span = None, False, None
        self.return_t_eval = return_t_eval
        if integral_loss is not None:
            self.vf.integral_loss = integral_loss
        self.vf.sensitivity = sensitivity

    def _prep_integration(self, x: Tensor, t_span: Tensor) -> Tensor:
        "Performs generic checks before integration. Assigns data control inputs and augments state for CNFs"

        # assign a basic value to `t_span` for `forward` calls that do no explicitly pass an integration interval
        if t_span is None and self.t_span is None:
            t_span = torch.linspace(0, 1, 2)
        elif t_span is None:
            t_span = self.t_span

        # loss dimension detection routine; for CNF div propagation and integral losses w/ autograd
        excess_dims = 0
        if (not self.integral_loss is None) and self.sensitivity == "autograd":
            excess_dims += 1

        # handle aux. operations required for some jacobian trace CNF estimators e.g Hutchinson's
        # as well as datasets-control set to DataControl module
        for _, module in self.vf.named_modules():
            if hasattr(module, "trace_estimator"):
                if module.noise_dist is not None:
                    module.noise = module.noise_dist.sample((x.shape[0],))
                excess_dims += 1

            # data-control set routine. Is performed once at the beginning of odeint since the control is fixed to IC
            if hasattr(module, "_control"):
                self.controlled = True
                module._control = x[:, excess_dims:].detach()
        return x, t_span

    def forward(
        self,
        x: Union[Tensor, Dict],
        t_span: Tensor = None,
        save_at: Iterable = (),
        args={},
    ):
        x, t_span = self._prep_integration(x, t_span)
        t_eval, sol = super().forward(x, t_span, save_at, args)
        if self.return_t_eval:
            return t_eval, sol
        else:
            return sol

    def trajectory(self, x: torch.Tensor, t_span: Tensor):
        x, t_span = self._prep_integration(x, t_span)
        _, sol = odeint(
            self.vf, x, t_span, solver=self.solver, atol=self.atol, rtol=self.rtol
        )
        return sol

    def __repr__(self):
        npar = sum([p.numel() for p in self.vf.parameters()])
        return f"Neural ODE:\n\t- order: {self.order}\
        \n\t- solver: {self.solver}\n\t- adjoint solver: {self.solver_adjoint}\
        \n\t- tolerances: relative {self.rtol} absolute {self.atol}\
        \n\t- adjoint tolerances: relative {self.rtol_adjoint} absolute {self.atol_adjoint}\
        \n\t- num_parameters: {npar}\
        \n\t- NFE: {self.vf.nfe}"

## Define some simple neural network

In [35]:
## Notice all the f below takes data with 2 dimensions and ouput 2 dimensions (2 classes)

In [22]:
## Define neural network f:
f = nn.Sequential(
        nn.Linear(2, 16),
        nn.Linear(16, 2),
        nn.ReLU()
    )

In [60]:
## Define neural network f:
f = nn.Sequential(
        nn.Linear(2, 16),
        nn.Linear(16, 32),
        nn.Linear(32, 32),
        nn.Linear(32, 16),
        nn.Linear(16, 2),
        nn.ReLU()
    )

In [91]:
class FCNet(nn.Module):
    def __init__(self, in_features, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 128)  # First fully connected layer
        self.bn1 = nn.BatchNorm1d(128)
        self.relu1 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(128, 64)  # Second fully connected layer
        self.bn2 = nn.BatchNorm1d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.fc3 = nn.Linear(64, 32)
        self.relu3 = nn.ReLU(inplace=True)
        self.fc4 = nn.Linear(32, num_classes) # Output layer
        
    def forward(self, xb):
        xb = self.fc1(xb)
        xb = self.bn1(xb)
        xb = self.relu1(xb)
        xb = self.fc2(xb)
        xb = self.bn2(xb)
        xb = self.relu2(xb)
        xb = self.fc3(xb)
        xb = self.relu3(xb)
        xb = self.fc4(xb)
        return xb

f = FCNet(in_features=2, num_classes=2)

## Define the model in the cutomized forward and backward flow

In [89]:
## Model using EPI2 solver
model = NeuralODE(f, sensitivity='adjoint', solver=EPI2_one_step(dtype=torch.float32)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


In [33]:
## Model using EPI2 (adaptive) solver
model = NeuralODE(f, sensitivity='adjoint', solver=EPI2_adaptive(dtype=torch.float32)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


In [92]:
## Model using DP45 solver
model = NeuralODE(f, sensitivity='adjoint', solver=DormandPrince45(dtype=torch.float32)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.012)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


## Prepare the training and Validation data loader:

In [24]:
from torchdyn.datasets import *
d = ToyDataset()
## Prepare the training data loader
X, yn = d.generate(n_samples=1024, noise=1e-1, dataset_type='moons')

X_train = torch.Tensor(X).to(device)
y_train = torch.LongTensor(yn.long()).to(device)
train = data.TensorDataset(X_train, y_train)
trainloader = data.DataLoader(train, batch_size=256, shuffle=True)

In [25]:
## Prepare the Test data loader
X_test, yn_test = d.generate(n_samples=128, noise=1e-1, dataset_type='moons')

X_test = torch.Tensor(X_test).to(device)
y_test = torch.LongTensor(yn_test.long()).to(device)
Test = data.TensorDataset(X_test, y_test)
Testloader = data.DataLoader(Test, batch_size=128, shuffle=True)

In [27]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    total=0
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        
        _, trajectory = model(data, t_span)
        output = trajectory[-1]
        loss = nn.CrossEntropyLoss()(output, target)

        loss.backward()
        optimizer.step()
        # Calculate the number of correct predictions
        pred = output.argmax(dim=1)  # Get the index of the max log-probability
        correct += (pred==target.view_as(pred)).sum().item()

        total += target.size(0)
        
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    print(f'# of correct in each batch: {correct}')
    print(f'# of trials in each batch: {total}')
    accuracy = 100. * correct / total
    print(f'Train Epoch: {epoch} Accuracy: {accuracy:.2f}%')
    
    
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            
            _, trajectory = model(data, t_span)
            output = trajectory[-1]
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [None]:
## Specify the time span
t_span = torch.linspace(0, 1, 15)

## Start with Training and Validation

In [37]:
## simple model with EPI2 solver:
for epoch in range(1, 160+1):
        train(model, device, trainloader, optimizer, epoch)
        test(model, device, Testloader)
        scheduler.step()

  warn("Setting tolerances has no effect on fixed-step methods")


# of correct in each batch: 197
# of trials in each batch: 1024
Train Epoch: 1 Accuracy: 19.24%

Test set: Average loss: -0.1451, Accuracy: 24/128 (19%)

# of correct in each batch: 194
# of trials in each batch: 1024
Train Epoch: 2 Accuracy: 18.95%

Test set: Average loss: -0.2671, Accuracy: 25/128 (20%)

# of correct in each batch: 211
# of trials in each batch: 1024
Train Epoch: 3 Accuracy: 20.61%

Test set: Average loss: -0.3449, Accuracy: 29/128 (23%)

# of correct in each batch: 245
# of trials in each batch: 1024
Train Epoch: 4 Accuracy: 23.93%

Test set: Average loss: -0.3870, Accuracy: 33/128 (26%)

# of correct in each batch: 262
# of trials in each batch: 1024
Train Epoch: 5 Accuracy: 25.59%

Test set: Average loss: -0.4114, Accuracy: 35/128 (27%)

# of correct in each batch: 282
# of trials in each batch: 1024
Train Epoch: 6 Accuracy: 27.54%

Test set: Average loss: -0.4276, Accuracy: 36/128 (28%)

# of correct in each batch: 293
# of trials in each batch: 1024
Train Epoch:

KeyboardInterrupt: 

In [59]:
## complex f with EPI2 solver:
for epoch in range(1, 120+1):
        train(model, device, trainloader, optimizer, epoch)
        test(model, device, Testloader)
        scheduler.step()

  warn("Setting tolerances has no effect on fixed-step methods")


# of correct in each batch: 205
# of trials in each batch: 1024
Train Epoch: 1 Accuracy: 20.02%

Test set: Average loss: -0.6464, Accuracy: 64/128 (50%)

# of correct in each batch: 512
# of trials in each batch: 1024
Train Epoch: 2 Accuracy: 50.00%

Test set: Average loss: -0.9326, Accuracy: 64/128 (50%)

# of correct in each batch: 512
# of trials in each batch: 1024
Train Epoch: 3 Accuracy: 50.00%

Test set: Average loss: -1.0374, Accuracy: 64/128 (50%)

# of correct in each batch: 512
# of trials in each batch: 1024
Train Epoch: 4 Accuracy: 50.00%

Test set: Average loss: -1.2826, Accuracy: 64/128 (50%)

# of correct in each batch: 512
# of trials in each batch: 1024
Train Epoch: 5 Accuracy: 50.00%

Test set: Average loss: -1.3310, Accuracy: 64/128 (50%)

# of correct in each batch: 512
# of trials in each batch: 1024
Train Epoch: 6 Accuracy: 50.00%

Test set: Average loss: -1.3016, Accuracy: 64/128 (50%)

# of correct in each batch: 512
# of trials in each batch: 1024
Train Epoch:

In [62]:
## complex f with EPI2 solver:
for epoch in range(1, 120+1):
        train(model, device, trainloader, optimizer, epoch)
        test(model, device, Testloader)
        scheduler.step()

  warn("Setting tolerances has no effect on fixed-step methods")


# of correct in each batch: 169
# of trials in each batch: 1024
Train Epoch: 1 Accuracy: 16.50%

Test set: Average loss: -0.9219, Accuracy: 26/128 (20%)

# of correct in each batch: 432
# of trials in each batch: 1024
Train Epoch: 2 Accuracy: 42.19%

Test set: Average loss: -1.8157, Accuracy: 72/128 (56%)

# of correct in each batch: 557
# of trials in each batch: 1024
Train Epoch: 3 Accuracy: 54.39%

Test set: Average loss: -1.9957, Accuracy: 66/128 (52%)

# of correct in each batch: 545
# of trials in each batch: 1024
Train Epoch: 4 Accuracy: 53.22%

Test set: Average loss: -2.2955, Accuracy: 71/128 (55%)

# of correct in each batch: 600
# of trials in each batch: 1024
Train Epoch: 5 Accuracy: 58.59%

Test set: Average loss: -2.7562, Accuracy: 84/128 (66%)

# of correct in each batch: 687
# of trials in each batch: 1024
Train Epoch: 6 Accuracy: 67.09%

Test set: Average loss: -3.1630, Accuracy: 89/128 (70%)

# of correct in each batch: 742
# of trials in each batch: 1024
Train Epoch:

In [90]:
## complex f with EPI2 solver:
for epoch in range(1, 20+1):
        train(model, device, trainloader, optimizer, epoch)
        test(model, device, Testloader)
        scheduler.step()

  warn("Setting tolerances has no effect on fixed-step methods")


# of correct in each batch: 356
# of trials in each batch: 1024
Train Epoch: 1 Accuracy: 34.77%

Test set: Average loss: -1.2051, Accuracy: 72/128 (56%)

# of correct in each batch: 614
# of trials in each batch: 1024
Train Epoch: 2 Accuracy: 59.96%

Test set: Average loss: -3.7382, Accuracy: 79/128 (62%)

# of correct in each batch: 670
# of trials in each batch: 1024
Train Epoch: 3 Accuracy: 65.43%

Test set: Average loss: -29.3953, Accuracy: 81/128 (63%)

# of correct in each batch: 630
# of trials in each batch: 1024
Train Epoch: 4 Accuracy: 61.52%

Test set: Average loss: -122.1381, Accuracy: 79/128 (62%)

# of correct in each batch: 619
# of trials in each batch: 1024
Train Epoch: 5 Accuracy: 60.45%

Test set: Average loss: -590.3235, Accuracy: 76/128 (59%)

# of correct in each batch: 624
# of trials in each batch: 1024
Train Epoch: 6 Accuracy: 60.94%

Test set: Average loss: -863.5206, Accuracy: 75/128 (59%)

# of correct in each batch: 610
# of trials in each batch: 1024
Train

In [93]:
## complex f with EPI2 solver:
for epoch in range(1, 20+1):
        train(model, device, trainloader, optimizer, epoch)
        test(model, device, Testloader)
        scheduler.step()

  warn("Setting tolerances has no effect on fixed-step methods")


# of correct in each batch: 388
# of trials in each batch: 1024
Train Epoch: 1 Accuracy: 37.89%

Test set: Average loss: -0.2897, Accuracy: 35/128 (27%)

# of correct in each batch: 505
# of trials in each batch: 1024
Train Epoch: 2 Accuracy: 49.32%

Test set: Average loss: -0.2464, Accuracy: 42/128 (33%)

# of correct in each batch: 507
# of trials in each batch: 1024
Train Epoch: 3 Accuracy: 49.51%

Test set: Average loss: -0.2519, Accuracy: 43/128 (34%)

# of correct in each batch: 492
# of trials in each batch: 1024
Train Epoch: 4 Accuracy: 48.05%

Test set: Average loss: -0.2413, Accuracy: 41/128 (32%)

# of correct in each batch: 570
# of trials in each batch: 1024
Train Epoch: 5 Accuracy: 55.66%

Test set: Average loss: -0.2146, Accuracy: 39/128 (30%)

# of correct in each batch: 469
# of trials in each batch: 1024
Train Epoch: 6 Accuracy: 45.80%

Test set: Average loss: -0.1825, Accuracy: 41/128 (32%)

# of correct in each batch: 472
# of trials in each batch: 1024
Train Epoch: