In [1]:
import os

import tensorflow as tf
import numpy as np
from scipy.integrate import odeint
from scipy.interpolate import interp1d

import time

In [2]:
def odeSolve(nPat):
    """Integrate the same 3-state linear ODE for ``nPat`` patients in one call.

    Each patient has three state variables that start at 1 and evolve as
    ``dN/dt = f - r * N - m * N`` with the hard-coded coefficient rows below.
    All patients are stacked into a single flat state vector of length
    ``3 * nPat`` and solved with one ``odeint`` call on the grid
    ``t = 0, 1, ..., 100``.

    The dose schedules A and B (and their interpolators) are built but are
    not used by the active right-hand side; they are kept for the
    dose-dependent variant noted inside ``rhs``.

    Parameters
    ----------
    nPat : int
        Number of patients to solve for; must be at least 1.

    Returns
    -------
    y : numpy.ndarray
        Solution array of shape ``(101, 3 * nPat)``.
    timeCost : float
        Seconds spent inside the ``odeint`` call only (setup is excluded).
    report : dict
        The info dictionary returned by ``odeint(..., full_output=True)``.

    Raises
    ------
    ValueError
        If ``nPat`` is smaller than 1.
    """
    if nPat < 1:
        raise ValueError("nPat must be a positive integer, got %r" % (nPat,))

    tspan = np.linspace(0, 100, 101)

    def doseProfile(windows):
        # Piecewise-constant dose on the integer grid. Windows are inclusive
        # on both ends and applied in order, so a later window overwrites an
        # earlier one on a shared endpoint.
        profile = np.zeros(shape=tspan.shape)
        for (tStart, tStop), dose in windows:
            profile[tStart:tStop + 1] = dose
        return profile

    # Every patient shares the same schedule, so build it once and repeat it.
    doseA = doseProfile([([  5, 15],    3 ),
                         ([ 35, 50],   35 ),
                         ([ 50, 60],    3 ),
                         ([ 60, 75],  300 ),
                         ([ 75, 80],  7.6 )])
    doseB = doseProfile([([  5, 15], 70   ),
                         ([ 35, 50], 12.5 ),
                         ([ 75, 80], 7.6  )])

    Atimesj = np.tile(doseA, (nPat, 1))
    Btimesj = np.tile(doseB, (nPat, 1))

    # Outside [0, 100] the interpolators hold the first / last grid value.
    AjInterp = [interp1d(tspan, a_vec, bounds_error=False, fill_value=(a_vec[0], a_vec[-1])) for a_vec in Atimesj]
    BjInterp = [interp1d(tspan, b_vec, bounds_error=False, fill_value=(b_vec[0], b_vec[-1])) for b_vec in Btimesj]

    # One coefficient row per patient, shape (nPat, 3).
    fj = np.tile(np.array([12,  7, 15]), (nPat, 1))
    rj = np.tile(np.array([ 6,  3,  8]), (nPat, 1))
    mj = np.tile(np.array([10, 17,  2]), (nPat, 1))

    def rhs(y, t, fj, rj, mj):
        # odeint passes a flat state vector; view it as one row per patient.
        Nnt = np.asarray(y).reshape(nPat, -1)

        # Dose-dependent variant (currently disabled):
        #   Aj = np.array([interp(t) for interp in AjInterp]).reshape(nPat, -1)
        #   Bj = np.array([interp(t) for interp in BjInterp]).reshape(nPat, -1)
        #   results = fj - rj * Nnt / (1 + Aj) - mj * Nnt / (1 + Bj)
        #
        # Errors are deliberately not caught here: a failure in the
        # right-hand side should surface as-is instead of being masked.
        return (fj - rj * Nnt - mj * Nnt).flatten()

    args = (fj, rj, mj)

    # perf_counter is a monotonic high-resolution clock meant for intervals.
    start = time.perf_counter()
    y, report = odeint(rhs, y0=np.ones(3 * nPat), t=tspan, args=args, full_output=True)
    timeCost = time.perf_counter() - start

    return y, timeCost, report

In [3]:
# Benchmark the solver for growing cohort sizes. For each size, print the
# final entries of odeint's step ('nst'), function-evaluation ('nfe') and
# Jacobian-evaluation ('nje') counters, then the total and per-patient time.
for n in (1, 10, 100, 1000):
    y, timeCost, report = odeSolve(nPat=n)
    print(*(report[key][-1] for key in ('nst', 'nfe', 'nje')))
    print('N', n, 'timeCost', timeCost, 'per User', timeCost / n)

120 249 4
N 1 timeCost 0.002711772918701172 per User 0.002711772918701172
120 357 4
N 10 timeCost 0.005485057830810547 per User 0.0005485057830810547
120 1437 4
N 100 timeCost 0.023775100708007812 per User 0.00023775100708007814
120 12237 4
N 1000 timeCost 1.2202341556549072 per User 0.0012202341556549073
