In [1]:
import os

import tensorflow as tf
import numpy as np
from scipy.integrate import odeint
from scipy.interpolate import interp1d

import time

# Expose only the first CUDA device to this process. NOTE(review): this only
# takes effect if it runs before TensorFlow initialises its GPU devices.
os.environ.update({"CUDA_VISIBLE_DEVICES": '0'})

In [2]:
# Dose schedules as ((first_t, last_t), dose) pairs; both ends are inclusive.
# Windows are applied in order, so where two windows share a grid point the
# later one wins (t = 50 and t = 75 in schedule A).
_DOSE_WINDOWS_A = (((5, 15), 3), ((35, 50), 35), ((50, 60), 3), ((60, 75), 300), ((75, 80), 7.6))
# Inputs for the (not yet wired in) ``Bj`` dosing term of odeSolve's model.
_DOSE_WINDOWS_B = (((5, 15), 70), ((35, 50), 12.5), ((75, 80), 7.6))


def _dose_schedule(tspan, windows):
    """Return a piecewise-constant dose profile sampled on ``tspan``.

    Window bounds are used as array *indices*; they coincide with times only
    because ``tspan`` is the unit grid 0, 1, ..., 100 built in odeSolve.
    Points outside every window get dose 0.
    """
    doses = np.zeros(tspan.shape)
    for (first, last), dose in windows:
        doses[first:last + 1] = dose
    return doses


def _dose_interpolators(windows, tspan, nPat):
    """Build dose interpolators for one drug, identical for all ``nPat`` patients.

    Returns ``(per_patient, vectorised)``:
    - ``per_patient`` is a list of ``nPat`` scalar ``interp1d`` objects, one per row.
    - ``vectorised`` is a single ``interp1d`` over the stacked
      ``(nPat, len(tspan))`` table, so one call evaluates every patient.

    Both clamp to the first/last tabulated dose outside ``tspan``.
    """
    table = np.tile(_dose_schedule(tspan, windows), (nPat, 1))
    per_patient = [interp1d(tspan, row, bounds_error=False, fill_value=(row[0], row[-1]))
                   for row in table]
    vectorised = interp1d(tspan, table, bounds_error=False,
                          fill_value=(table[:, 0], table[:, -1]))
    return per_patient, vectorised


def odeSolve(nPat, device='/device:GPU:0'):
    """Benchmark integrating ``nPat`` identical copies of a 3-compartment linear ODE.

    For each patient the 3-vector ``N`` obeys ``dN/dt = f - r*N - m*N``, where
    f = (12, 7, 15), r = (6, 3, 8) and m = (10, 17, 2). The right-hand side is
    evaluated by a TensorFlow graph placed on ``device``. The system is
    integrated with ``scipy.integrate.odeint`` on t = 0, 1, ..., 100 starting
    from N(0) = 1.

    The dosed model ``f - r*N/(1 + Aj(t)) - m*N/(1 + Bj(t))`` is not wired in
    yet. Instead, the cost of evaluating the ``Aj`` dose inputs is measured in
    two ways, so it can be compared with the solve time:
    - one ``interp1d`` per patient;
    - a single vectorised ``interp1d``.

    Parameters
    ----------
    nPat : int
        Number of patients (independent 3-compartment systems); must be >= 1.
    device : str, optional
        TensorFlow device string used for the RHS graph.

    Returns
    -------
    y : ndarray, shape (101, 3 * nPat)
        Solution at each time point (odeint's output layout).
    timeCost : float
        Seconds spent inside ``odeint``.
    report : dict
        ``odeint`` ``full_output`` info dict (e.g. 'nst', 'nfe', 'nje').
    timeCost_Aj : float
        Seconds for one evaluation of all per-patient interpolators.
    timeCost_Aj2 : float
        Seconds for one evaluation of the vectorised interpolator.

    Raises
    ------
    ValueError
        If ``nPat < 1``.
    """
    if nPat < 1:
        raise ValueError('nPat must be >= 1, got %r' % (nPat,))

    tspan = np.linspace(0, 100, 101)
    f_j, r_j, m_j = [12, 7, 15], [6, 3, 8], [10, 17, 2]

    AjInterp, AjInterp2 = _dose_interpolators(_DOSE_WINDOWS_A, tspan, nPat)

    # Time one dose lookup at an off-grid time, so the interpolation itself
    # is exercised: the per-patient list versus a single call on the stacked table.
    probe_t = 1.231
    start = time.perf_counter()
    np.array([interp(probe_t) for interp in AjInterp]).reshape(nPat, -1)
    timeCost_Aj = time.perf_counter() - start

    start = time.perf_counter()
    AjInterp2(probe_t).reshape(nPat, -1)
    timeCost_Aj2 = time.perf_counter() - start

    tf.reset_default_graph()
    with tf.device(device):
        fj_tf = tf.Variable(np.tile(f_j, (nPat, 1)), dtype=tf.float32, name='fj')
        rj_tf = tf.Variable(np.tile(r_j, (nPat, 1)), dtype=tf.float32, name='rj')
        mj_tf = tf.Variable(np.tile(m_j, (nPat, 1)), dtype=tf.float32, name='mj')

        Nnt_list = tf.placeholder(dtype=tf.float32, name='Nnt_list')
        results = fj_tf - rj_tf * Nnt_list - mj_tf * Nnt_list
        # odeint works on a flat state vector, so flatten (nPat, 3) -> (3*nPat,).
        results = tf.reshape(results, shape=[-1], name='flattenOps')
        init = tf.global_variables_initializer()

    # The context manager closes the session (freeing its device memory) even
    # if odeint fails, so repeated calls do not accumulate open sessions.
    with tf.Session() as sess:
        sess.run(init)

        def rhs(y, t):
            """Return dy/dt for the flat state ``y`` at time ``t`` (t unused by the model)."""
            try:
                return sess.run(results,
                                feed_dict={Nnt_list: np.asarray(y).reshape(nPat, -1)})
            except Exception as e:
                # Report where evaluation failed, then propagate the real
                # error instead of returning an undefined value to odeint.
                print(t, str(e))
                raise

        start = time.perf_counter()
        y, report = odeint(rhs, y0=np.ones(nPat * len(f_j)), t=tspan, full_output=True)
        timeCost = time.perf_counter() - start

    return y, timeCost, report, timeCost_Aj, timeCost_Aj2

In [3]:
# Scaling benchmark over cohorts of increasing size. Each size prints two lines:
#   line 1 - odeint's cumulative step, RHS-evaluation and Jacobian-evaluation
#            counts at the final time point ('nst', 'nfe', 'nje');
#   line 2 - total solve time, time per patient, and the single-call cost of
#            the per-patient and the vectorised dose interpolation, each scaled
#            by the number of RHS evaluations (presumably an estimate of what
#            feeding doses into the RHS would add — confirm).
for n in (1, 10, 100, 1000):
    y, timeCost, report, timeCost_Aj, timeCost_Aj2 = odeSolve(nPat=n)
    print(*(report[key][-1] for key in ('nst', 'nfe', 'nje')))
    print('N', n, 'timeCost', timeCost, 'per User', timeCost / n,
          timeCost_Aj * report['nfe'][-1], timeCost_Aj2 * report['nfe'][-1])

163 386 25
N 1 timeCost 0.20180559158325195 per User 0.20180559158325195 0.06690549850463867 0.06432867050170898
163 1061 25
N 10 timeCost 0.6502184867858887 per User 0.06502184867858887 0.284076452255249 0.08727192878723145
163 7811 25
N 100 timeCost 3.5580973625183105 per User 0.03558097362518311 19.904129028320312 0.696495532989502
163 75311 25
N 1000 timeCost 55.17753314971924 per User 0.05517753314971924 1127.0693469047546 7.128350019454956
