In [1]:
import os

import tensorflow as tf
import numpy as npo

import autograd.numpy as np
from autograd import grad
from autograd.scipy.integrate import odeint
from autograd.builtins import tuple
from autograd.misc.optimizers import adam

# from scipy.integrate import odeint
# from scipy.interpolate import interp1d

import time
from tqdm import tqdm

In [2]:
def solveODE(nPat, num_iters=1, show_progress=False):
    """Time one forward ODE solve and a short Adam fit of its parameters.

    The system integrated is ``dN/dt = fj - rj * N - mj * N`` with three
    components per row and ``nPat`` identical rows, starting from all ones and
    sampled at 101 evenly spaced times on ``[0, 100]``.

    Parameters
    ----------
    nPat : int
        Number of replicated 3-component systems (presumably "patients" --
        TODO confirm). Must be at least 1.
    num_iters : int, optional
        Number of Adam iterations to time. Defaults to 1, the value that was
        previously hard-coded.
    show_progress : bool, optional
        When True, a tqdm bar reports the training loss after every
        iteration. This costs one extra ODE solve per iteration and that cost
        is included in the returned training time, so leave it False (the
        default) for benchmarking.

    Returns
    -------
    (ODEcost, TrainCost) : tuple of float
        Wall-clock seconds for the forward solve and for the Adam run.

    Raises
    ------
    ValueError
        If ``nPat`` is smaller than 1.
    """
    if nPat < 1:
        raise ValueError("nPat must be a positive integer, got {!r}".format(nPat))

    def replicate(values):
        # One row of `values` per system, as float32 with shape (nPat, 3).
        return np.tile(np.array(values), (nPat, 1)).astype(np.float32)

    fj = replicate([12, 7, 15])
    rj = replicate([6, 3, 8])
    mj = replicate([10, 17, 2])

    # Initial state and time grid are shared by every solve below.
    y0 = np.array([1, 1, 1] * nPat)
    t = np.linspace(0, 100, 101)

    def rhs(y, t, params):
        fj, rj, mj = params

        # The integrator works on a flat state vector; reshape to one row per
        # system so the (nPat, 3) parameters line up, then flatten back.
        Nnt = np.array(y).reshape(nPat, -1)
        results = fj - rj * Nnt - mj * Nnt
        return results.flatten()

    def integrate(params):
        # autograd's `tuple` keeps the parameter container differentiable.
        return odeint(rhs, y0, t, tuple((params,)))

    start = time.time()
    true_y = integrate([fj, rj, mj])
    ODEcost = time.time() - start

    def loss(params, iterations):
        # Mean squared error against the reference trajectory; `iterations`
        # is unused but required by the optimizer's gradient signature.
        return np.square(true_y - integrate(params)).mean()

    init_params = [np.zeros((nPat, 3)).astype(np.float32) for _ in range(3)]

    pbar = tqdm(total=num_iters) if show_progress else None

    def callback(params, iterations, g):
        description = "Iteration {:d} train loss {:.6f}".format(
            iterations, loss(params, iterations))
        pbar.set_description(description)
        pbar.update(1)

    start = time.time()
    try:
        adam(grad(loss), init_params,
             callback=callback if show_progress else None,
             num_iters=num_iters)
        TrainCost = time.time() - start
    finally:
        # Always release the progress bar, even if the optimizer raises.
        if pbar is not None:
            pbar.close()

    return ODEcost, TrainCost

In [3]:
# Run the benchmark for each system count and report total and per-system
# wall-clock cost of the forward solve and of the training step.
sizes = [1000]
for nPat in sizes:
    ODEcost, TrainCost = solveODE(nPat)
    ode_per_system = ODEcost / nPat
    train_per_system = TrainCost / nPat
    print(
        'N', nPat,
        'ODEcost', ODEcost, 'perPax', ode_per_system,
        'TrainCost', TrainCost, 'perPax', train_per_system,
    )

N 1000 ODEcost 1.4888520240783691 perPax 0.0014888520240783692 TrainCost 1.38493013381958 perPax 0.0013849301338195801
