In [1]:
import os

import tensorflow as tf
import numpy as npo

import autograd.numpy as np
from autograd import grad
from autograd.scipy.integrate import odeint
from autograd.builtins import tuple
from autograd.misc.optimizers import adam

# from scipy.integrate import odeint
# from scipy.interpolate import interp1d

import time
from tqdm import tqdm

In [2]:
# Number of simulated patients; each patient gets an identical copy of the rows below.
nPat        = 10

# True parameters: (nPat, 3) float32 arrays, one row per patient, one column per
# state component (the state has 3 entries per patient, see the initial state below).
fj          = np.tile(np.array([12, 7, 15], dtype=np.float32), (nPat, 1))
rj          = np.tile(np.array([6, 3, 8], dtype=np.float32), (nPat, 1))
mj          = np.tile(np.array([10, 17, 2], dtype=np.float32), (nPat, 1))

def rhs(y, t, params):
    """Time derivative of the stacked per-patient state, in the form odeint calls.

    Per patient row N: dN/dt = (fj - rj * N - mj * N) / 100.

    Args:
        y: flat state vector; reshaped to ``(nPat, -1)`` so each row is one
            patient (relies on the notebook-global ``nPat``).
        t: current time. Unused -- the right-hand side does not depend on t.
        params: sequence ``(fj, rj, mj)``; each must broadcast against the
            ``(nPat, k)`` state.

    Returns:
        The derivative flattened to 1-D (``nPat * k`` entries when the
        parameters have the same shape as the state).

    Note:
        rj and mj enter only through ``rj * N + mj * N``, so fitting this model
        can determine ``rj + mj`` but not rj and mj individually.
    """
    f_j, r_j, m_j = params
    state = np.array(y).reshape(nPat, -1)   # one row per patient
    d_state = (f_j - r_j * state - m_j * state) / 100
    return d_state.flatten()

params    = [fj, rj, mj]

# Initial state (every component of every patient starts at 1) and the time
# grid 0, 1, ..., 100 used for the reference trajectory.
y0_ref    = np.array([1, 1, 1] * nPat)
t_ref     = np.linspace(0, 100, 101)

# perf_counter is monotonic and high-resolution, unlike the wall clock time.time().
start     = time.perf_counter()
true_y    = odeint(rhs, y0=y0_ref, t=t_ref, args=(params,))
timeCost  = time.perf_counter() - start
# A single odeint call integrates all patients together; 'per User' is the average share.
print('N', nPat, 'timeCost', timeCost, 'per User', timeCost / nPat)

# Sanity check: derivative at the initial state. rhs does not use t, so pass a
# scalar time instead of the whole grid.
dy = rhs(y0_ref, t=0.0, params=params)
print(dy.min(), dy.max())

N 10 timeCost 0.005349159240722656 per User 0.0005349159240722656
-0.13 0.05


In [3]:
def loss(params, iterations):
    """Mean squared error between ``true_y`` and a trajectory simulated from params.

    Args:
        params: candidate ``[fj, rj, mj]`` (plain arrays, or autograd boxes while
            being differentiated).
        iterations: iteration index passed in by adam; unused here, kept so the
            function matches the ``(params, i)`` call form adam uses.

    Returns:
        Scalar mean of the squared differences against the global ``true_y``.
    """
    initial_state = np.array([1, 1, 1] * nPat)   # same start point as the reference solve
    time_grid = np.linspace(0, 100, 101)
    # The args container is autograd.builtins.tuple, presumably so autograd can trace through it.
    predicted = odeint(rhs, initial_state, time_grid, tuple((params,)))
    squared_error = np.square(true_y - predicted)
    return squared_error.mean()

In [4]:
# Gradient of loss with respect to its first argument (params).
# NOTE(review): not referenced by the visible cells -- the training cell calls grad(loss) directly.
lossGrad = grad(loss, argnum=0)

In [5]:
# Starting point for adam: one all-zero (nPat, 3) float32 array per parameter,
# in the same [fj, rj, mj] order and dtype as the true params.
init_params = [np.zeros((nPat, 3), dtype=np.float32) for _ in range(3)]

In [7]:
# Number of adam iterations, and a progress bar that callback() advances
# manually via pbar.update(1) -- so create it with total= rather than wrapping
# an iterable that is never iterated.
nIterations = 100
pbar        = tqdm(total=nIterations)

def callback(params, iterations, g):
    """Per-iteration hook for adam: show the current training loss on ``pbar``.

    Args:
        params: current ``[fj, rj, mj]`` estimate.
        iterations: iteration index, shown in the progress-bar description.
        g: presumably the current gradient passed by adam; unused.
    """
    # Reuse the objective rather than duplicating its ODE solve and MSE, so the
    # reported number is always exactly the quantity being minimised.
    train_loss = loss(params, iterations)
    pbar.set_description("Iteration {:d} train loss {:.6f}".format(iterations, train_loss))
    pbar.update(1)

  0%|          | 0/100 [00:00<?, ?it/s]

In [8]:
# Fit the parameters with adam starting from init_params. perf_counter is a
# monotonic, high-resolution clock, better suited to timing than time.time().
start            = time.perf_counter()
try:
    optimized_params = adam(grad(loss), init_params, num_iters=nIterations,
                            callback=callback, step_size=0.01)
    timeCost     = time.perf_counter() - start
finally:
    # Finalise the manually driven progress bar even if training is interrupted.
    pbar.close()
print('N', nPat, 'timeCost', timeCost, 'per User', timeCost / nPat)

Iteration 99 train loss 0.034520: 100%|██████████| 100/100 [01:50<00:00,  1.09s/it]

N 10 timeCost 110.12117099761963 per User 11.012117099761962


In [9]:
# Parameters [fj, rj, mj] returned by adam, for comparison with the true values in `params`.
# NOTE(review): rhs uses rj and mj only via rj*N + mj*N, so only rj + mj is identifiable and the
# individual learned rj/mj need not match the true ones. fj is also still far from its true value
# (train loss ~0.0345 after 100 iterations); more iterations or a larger step may be needed.
optimized_params

[array([[-0.10825356, -0.45906344,  0.2397624 ],
        [-0.10825356, -0.45906344,  0.2397624 ],
        [-0.10825356, -0.45906344,  0.2397624 ],
        [-0.10825356, -0.45906344,  0.2397624 ],
        [-0.10825356, -0.45906344,  0.2397624 ],
        [-0.10825356, -0.45906344,  0.2397624 ],
        [-0.10825356, -0.45906344,  0.2397624 ],
        [-0.10825356, -0.45906344,  0.2397624 ],
        [-0.10825356, -0.45906344,  0.2397624 ],
        [-0.10825356, -0.45906344,  0.2397624 ]], dtype=float32),
 array([[ 0.15736288,  0.4554216 , -0.18560188],
        [ 0.15736288,  0.4554216 , -0.18560188],
        [ 0.15736288,  0.4554216 , -0.18560188],
        [ 0.15736288,  0.4554216 , -0.18560188],
        [ 0.15736288,  0.4554216 , -0.18560188],
        [ 0.15736288,  0.4554216 , -0.18560188],
        [ 0.15736288,  0.4554216 , -0.18560188],
        [ 0.15736288,  0.4554216 , -0.18560188],
        [ 0.15736288,  0.4554216 , -0.18560188],
        [ 0.15736288,  0.4554216 , -0.18560188]], dt

In [10]:
# True parameters [fj, rj, mj] that generated true_y; every patient row is identical by construction.
params

[array([[ 12.,   7.,  15.],
        [ 12.,   7.,  15.],
        [ 12.,   7.,  15.],
        [ 12.,   7.,  15.],
        [ 12.,   7.,  15.],
        [ 12.,   7.,  15.],
        [ 12.,   7.,  15.],
        [ 12.,   7.,  15.],
        [ 12.,   7.,  15.],
        [ 12.,   7.,  15.]], dtype=float32), array([[ 6.,  3.,  8.],
        [ 6.,  3.,  8.],
        [ 6.,  3.,  8.],
        [ 6.,  3.,  8.],
        [ 6.,  3.,  8.],
        [ 6.,  3.,  8.],
        [ 6.,  3.,  8.],
        [ 6.,  3.,  8.],
        [ 6.,  3.,  8.],
        [ 6.,  3.,  8.]], dtype=float32), array([[ 10.,  17.,   2.],
        [ 10.,  17.,   2.],
        [ 10.,  17.,   2.],
        [ 10.,  17.,   2.],
        [ 10.,  17.,   2.],
        [ 10.,  17.,   2.],
        [ 10.,  17.,   2.],
        [ 10.,  17.,   2.],
        [ 10.,  17.,   2.],
        [ 10.,  17.,   2.]], dtype=float32)]