In [1]:
import os

import tensorflow as tf
import numpy as np
from scipy.integrate import odeint
from scipy.interpolate import interp1d
from scipy import signal

import time

In [2]:
def odeSolve(nPat, tEval=None):
    """Integrate the hard-coded 6-state ODE system for ``nPat`` identical patients.

    Every patient carries six state variables: the first ``Nnt`` (= 3) are fed,
    together with three square-wave stress channels, through a small fixed
    tanh network; the remaining three are damped by the time constants
    ``Taus``. Two piecewise-constant dose schedules (A and B) scale the loss
    terms of the first three states. All patients share the same schedules,
    parameters and network weights, so the function is mainly useful for
    measuring how the solve time scales with ``nPat``.

    Parameters
    ----------
    nPat : int
        Number of patients stacked into one ODE system (must be >= 1).
    tEval : array_like, optional
        Time points passed to ``odeint``; the first entry is the time of the
        initial condition. Defaults to ``np.linspace(0, 100, 101)``.

    Returns
    -------
    y : ndarray, shape (len(tEval), 6 * nPat)
        Solution at the requested times, six consecutive columns per patient.
    timeCost : float
        Seconds spent inside ``odeint``.
    report : dict
        The info dictionary returned by ``odeint(..., full_output=True)``.

    Raises
    ------
    ValueError
        If ``nPat`` is smaller than 1.
    """
    if nPat < 1:
        raise ValueError('nPat must be a positive integer, got %r' % (nPat,))

    Nnt   = 3
    tspan = np.linspace(0, 100, 101)
    if tEval is None:
        tEval = tspan

    def doseSchedule(windows):
        # Windows are inclusive index ranges into tspan and are applied in
        # order, so a later window overwrites a shared boundary sample.
        schedule = np.zeros(shape=tspan.shape)
        for (first, last), dose in windows:
            schedule[first:last + 1] = dose
        return schedule

    doseA = doseSchedule([ ((  5, 15),    3 ),
                           (( 35, 50),   35 ),
                           (( 50, 60),    3 ),
                           (( 60, 75),  300 ),
                           (( 75, 80),  7.6 ) ])
    doseB = doseSchedule([ ((  5, 15), 70   ),
                           (( 35, 50), 12.5 ),
                           (( 75, 80), 7.6  ) ])

    # Three square-wave stress channels (periods 20, 15 and 10; amplitude 3).
    stress = np.stack([ signal.square(2 * np.pi * tspan / period) * 3
                        for period in (20.0, 15.0, 10.0) ])

    # Every patient is identical, so build each array once and replicate it
    # to shape (nPat, 3, len(tspan)) or (nPat, 3).
    Atimesj  = np.tile(doseA,  (nPat, Nnt, 1))
    Btimesj  = np.tile(doseB,  (nPat, Nnt, 1))
    stress_v = np.tile(stress, (nPat, 1,   1))
    fj       = np.tile([12,  7, 15], (nPat, 1))
    rj       = np.tile([ 6,  3,  8], (nPat, 1))
    mj       = np.tile([10, 17,  2], (nPat, 1))
    Taus     = np.tile([ 1,  4, 12], (nPat, 1))

    def holdEnds(values):
        # Linear interpolation along the time axis; outside tspan the first /
        # last sample is held so rhs stays defined if it is evaluated there.
        return interp1d(tspan, values, bounds_error=False,
                        fill_value=(values[:, :, 0], values[:, :, -1]))

    AjInterp     = holdEnds(Atimesj)
    BjInterp     = holdEnds(Btimesj)
    stressInterp = holdEnds(stress_v)

    # A private generator yields the same weights as np.random.seed(1234)
    # followed by np.random.random(...), without reseeding the global RNG.
    rng   = np.random.RandomState(1234)
    NNwts = [ rng.random_sample(size=(6, 12)),
              rng.random_sample(size=(12, 3)),
              rng.random_sample(size=(3,  3)) ]
    NNb   = [ 0, 1, -1 ]
    NNact = [ np.tanh, np.tanh, np.tanh ]

    def rhs(y, t):
        # Errors are deliberately not caught here: swallowing them would only
        # replace the real failure with an unbound-variable error on return.
        y       = np.asarray(y).reshape(nPat, -1)
        Nnt_val = y[:, :Nnt]

        hidden = np.concatenate([Nnt_val, stressInterp(t)], axis=1)
        for w, b, act in zip(NNwts, NNb, NNact):
            hidden = act(np.dot(hidden, w) + b)

        nn_res   = hidden - y[:, Nnt:] / Taus
        meds_res = fj - rj * Nnt_val / (1 + AjInterp(t)) - mj * Nnt_val / (1 + BjInterp(t))

        # NOTE(review): nn_res (which damps y[:, Nnt:]) is returned as the
        # derivative of the first Nnt states and meds_res (which damps
        # y[:, :Nnt]) as the derivative of the rest. This looks swapped but is
        # kept as in the original so results are unchanged -- please confirm.
        return np.concatenate([nn_res, meds_res], axis=1).flatten()

    y0        = np.tile([1.0, 1.0, 1.0, 2.0, 2.0, 2.0], nPat)
    start     = time.perf_counter()
    y, report = odeint(rhs, y0=y0, t=tEval, full_output=True)
    timeCost  = time.perf_counter() - start

    return y, timeCost, report

In [3]:
# Time the solver for growing patient counts. For each run, print the last
# 'nst' / 'nfe' / 'nje' entries of the solver report, then the elapsed time
# in total and per patient.
for n in (1, 10, 100):
    y, timeCost, report = odeSolve(nPat=n)
    print(*(report[field][-1] for field in ('nst', 'nfe', 'nje')))
    print('N', n, 'timeCost', timeCost, 'per User', timeCost / n)

2625 5589 0
N 1 timeCost 0.8537545204162598 per User 0.8537545204162598
2614 5584 0
N 10 timeCost 0.909440279006958 per User 0.0909440279006958
2614 5584 0
N 100 timeCost 1.2137401103973389 per User 0.012137401103973388
