In [1]:
import os

import tensorflow as tf
import numpy as np
from scipy.integrate import odeint
from scipy.interpolate import interp1d

import time

# os.environ["CUDA_VISIBLE_DEVICES"] = ''

In [2]:
# Dose windows as (first index, last index inclusive, dose). The indices address the
# integer time grid np.linspace(0, 100, 101) used by odeSolve, so index == time.
# Windows are applied in list order: where two windows share an endpoint, the later dose wins.
_DOSE_WINDOWS_A = [(5, 15, 3), (35, 50, 35), (50, 60, 3), (60, 75, 300), (75, 80, 7.6)]
_DOSE_WINDOWS_B = [(5, 15, 70), (35, 50, 12.5), (75, 80, 7.6)]


def _dose_schedule(windows, nPoints):
    """Return a float array of length nPoints: `dose` on each inclusive window, 0 elsewhere."""
    schedule = np.zeros(nPoints)
    for first, last, dose in windows:
        schedule[first:last + 1] = dose
    return schedule


def odeSolve(nPat, interpolate=0, dtype=tf.float32):
    """Integrate a coupled ODE system for `nPat` independent units, using a TF1 graph as the RHS.

    Each unit has 5 state variables, held as one row of an (nPat, 5) matrix:

    * columns 0-2:  dN_j/dt = f_j - r_j * N_j / (1 + A_j) - m_j * N_j / (1 + B_j)
    * columns 3-4:  dN_k/dt = NN(N)_k - N_k / tau_k
      where NN is a 5 -> 12 -> 3 -> 2 tanh network with fixed random weights
      applied to the unit's full 5-element state.

    Parameters
    ----------
    nPat : int
        Number of units (rows) integrated simultaneously; must be >= 1.
    interpolate : int, optional
        0 -> A_j and B_j are the constant 3 (`Aj_const`, `Bj_const`).
        1 -> A_j and B_j are the time-varying dose schedules A(t), B(t), linearly
             interpolated at the solver's t and broadcast to columns 0-2 via matmul.
        2 -> same as 1, but broadcast via tf.concat.
        Any other value behaves like 0.
    dtype : tf.DType, optional
        Floating dtype of every graph tensor.

    Returns
    -------
    y : np.ndarray, shape (101, 5 * nPat)
        Solution at t = 0, 1, ..., 100 (row-major flattening of the (nPat, 5) state).
    timeCost : float
        Seconds spent inside `odeint` (graph construction is excluded).
    report : dict
        `odeint`'s infodict (called with full_output=True).

    Raises
    ------
    ValueError
        If nPat < 1.
    """
    if nPat < 1:
        raise ValueError('nPat must be >= 1, got {}'.format(nPat))

    tspan    = np.linspace(0, 100, 101)

    # Every unit receives the same schedule, so build it once and tile it to (nPat, 101).
    Atimesj  = np.tile(_dose_schedule(_DOSE_WINDOWS_A, tspan.size), (nPat, 1))
    Btimesj  = np.tile(_dose_schedule(_DOSE_WINDOWS_B, tspan.size), (nPat, 1))
    # Interpolate along time (last axis), so AjInterp(t) has shape (nPat,). Queries outside
    # [0, 100] return the first/last sample instead of raising.
    AjInterp = interp1d(tspan, Atimesj, bounds_error=False, fill_value=(Atimesj[:, 0], Atimesj[:, -1]))
    BjInterp = interp1d(tspan, Btimesj, bounds_error=False, fill_value=(Btimesj[:, 0], Btimesj[:, -1]))

    tf.reset_default_graph()

    with tf.device('/device:CPU:0'):

        with tf.variable_scope('parameters'):
            # Per-unit coefficients for state columns 0-2, each of shape (nPat, 3).
            fj_tf       = tf.Variable(np.array([12 , 7, 15] * nPat).reshape(nPat, -1), dtype=dtype, name='fj')
            rj_tf       = tf.Variable(np.array([6  , 3,  8] * nPat).reshape(nPat, -1), dtype=dtype, name='rj')
            mj_tf       = tf.Variable(np.array([10 , 17, 2] * nPat).reshape(nPat, -1), dtype=dtype, name='mj')
            Aj_const_tf = tf.Variable(np.array([3, 3, 3] * nPat).reshape(nPat, -1), dtype=dtype, name='Aj_const')
            Bj_const_tf = tf.Variable(np.array([3, 3, 3] * nPat).reshape(nPat, -1), dtype=dtype, name='Bj_const')

            # NOTE: reseeds NumPy's *global* RNG so the network weights are identical on every call.
            np.random.seed(seed=1234)
            NNwts_tf    = [ tf.Variable(np.random.random(size=(5, 12)), dtype=dtype, name='w1'),
                            tf.Variable(np.random.random(size=(12, 3)), dtype=dtype, name='w2'),
                            tf.Variable(np.random.random(size=(3,  2)), dtype=dtype, name='w3')   ]
            NNb_tf      = [ tf.Variable(bias, dtype=dtype, name='bias{}'.format(i))
                                                                       for i, bias in enumerate([ 0, 1, -1 ]) ]
            # Per-unit time constants for state columns 3-4, shape (nPat, 2).
            Taus_tf     = tf.Variable(np.array([1, 4 ] * nPat).reshape(nPat, -1), dtype=dtype, name='Taus')
            NNact_tf    = [ tf.nn.tanh, tf.nn.tanh, tf.nn.tanh ]

            # Constant 0/1 selection matrices. Right-multiplying by them scatters a block into,
            # or gathers a block out of, the (nPat, 5) state: "left" = columns 0-2,
            # "right" = columns 3-4.
            upSizeToLeft = tf.Variable(np.array([[1, 0, 0, 0, 0],
                                                 [0, 1, 0, 0, 0],
                                                 [0, 0, 1, 0, 0]]),
                                       dtype=dtype,
                                       name='upSizeToLeft')
            downSizeLeftPart = tf.Variable(np.array([[1, 0, 0],
                                                     [0, 1, 0],
                                                     [0, 0, 1],
                                                     [0, 0, 0],
                                                     [0, 0, 0]]),
                                            dtype=dtype,
                                            name='downSizeLeftPart')

            upSizeToRight = tf.Variable(np.array([[0, 0, 0, 1, 0],
                                                  [0, 0, 0, 0, 1]]),
                                        dtype=dtype,
                                        name='upSizeToRight')
            downSizeRightPart = tf.Variable(np.array([[0, 0],
                                                      [0, 0],
                                                      [0, 0],
                                                      [1, 0],
                                                      [0, 1]]),
                                           dtype=dtype,
                                           name='downSizeRightPart')
            # Broadcasts an (nPat, 1) column into columns 0-2 of an (nPat, 5) matrix.
            upSize3_2 = tf.Variable(np.array([[1, 1, 1, 0, 0]]), dtype=dtype, name='upSize3_2')

            zerosMatrix = tf.Variable(np.zeros(shape=(nPat, 2)), dtype=dtype, name='zeroM1')


        with tf.variable_scope('input'):
            # Current state reshaped to (nPat, 5); fed on every RHS evaluation.
            Nnt_matrix  = tf.placeholder(dtype=dtype, name='Nnt_matrix')


        with tf.variable_scope('neuralnetwork'):
            for index, (w, b, a) in enumerate(zip(NNwts_tf, NNb_tf, NNact_tf)):
                if index == 0:
                    nn_res = tf.matmul(Nnt_matrix, w) + b
                else:
                    nn_res = tf.matmul(nn_res, w) + b
                nn_res = a(nn_res)

            # Decay term -N_k / tau_k for the right-hand columns, then scatter into columns 3-4.
            stressMatrix = tf.matmul(Nnt_matrix, downSizeRightPart, name='downSize_Nntmatrix')
            nn_res       = nn_res - stressMatrix / Taus_tf
            nn_res       = tf.matmul(nn_res, upSizeToRight, name='upSize_nnRes')

        with tf.variable_scope('definedEq_meds'):
            # Pad the (nPat, 3) coefficients to (nPat, 5) with zeros in columns 3-4, so the
            # medication term contributes nothing there.
            fj_padded = tf.matmul(fj_tf,       upSizeToLeft, name='upSize_fj')
            rj_padded = tf.matmul(rj_tf,       upSizeToLeft, name='upSize_rj')
            mj_padded = tf.matmul(mj_tf,       upSizeToLeft, name='upSize_mj')
            Aj_padded = tf.matmul(Aj_const_tf, upSizeToLeft, name='upSize_Aj')
            Bj_padded = tf.matmul(Bj_const_tf, upSizeToLeft, name='upSize_Bj')

            if   interpolate == 1:
                Aj_interpolate  = tf.placeholder(dtype=dtype, name='Aj_interpolate')
                Bj_interpolate  = tf.placeholder(dtype=dtype, name='Bj_interpolate')
                Aj_matrix       = tf.matmul(Aj_interpolate, upSize3_2, name='upSizeInterpolatedAj')
                Bj_matrix       = tf.matmul(Bj_interpolate, upSize3_2, name='upSizeInterpolatedBj')

                meds_res        = fj_padded - rj_padded * Nnt_matrix / (1 + Aj_matrix) \
                                            - mj_padded * Nnt_matrix / (1 + Bj_matrix)
            elif interpolate == 2:
                Aj_interpolate  = tf.placeholder(dtype=dtype, name='Aj_interpolate')
                Bj_interpolate  = tf.placeholder(dtype=dtype, name='Bj_interpolate')
                Aj_matrix       = tf.concat( [Aj_interpolate, Aj_interpolate, Aj_interpolate, zerosMatrix], axis=1 )
                Bj_matrix       = tf.concat( [Bj_interpolate, Bj_interpolate, Bj_interpolate, zerosMatrix], axis=1 )

                meds_res        = fj_padded - rj_padded * Nnt_matrix / (1 + Aj_matrix) \
                                            - mj_padded * Nnt_matrix / (1 + Bj_matrix)
            else:
                meds_res        = fj_padded - rj_padded * Nnt_matrix / (1 + Aj_padded) \
                                            - mj_padded * Nnt_matrix / (1 + Bj_padded)

        with tf.variable_scope('combineResults'):
            # The two terms occupy disjoint columns, so a sum assembles the full (nPat, 5) derivative.
            results     = nn_res + meds_res
            results     = tf.reshape(results, shape=[-1], name='flattenOps')

        init        = tf.global_variables_initializer()
        config      = tf.ConfigProto(
                           device_count = {'GPU': 0}
                      )
        sess        = tf.Session(config=config)
        sess.run( init )

    def rhs(y, t):
        """Evaluate dy/dt at time t for odeint; y is the flat (5 * nPat,) state."""
        try:
            feed = {Nnt_matrix: np.asarray(y).reshape(nPat, -1)}
            if interpolate in [1, 2]:
                feed[Aj_interpolate] = AjInterp(t).reshape(nPat, -1)
                feed[Bj_interpolate] = BjInterp(t).reshape(nPat, -1)
            return sess.run(results, feed_dict=feed)
        except Exception as e:
            # Log the failing time point, then re-raise: there is no valid derivative
            # to hand back to odeint.
            print(t, str(e))
            raise

    y0 = np.array([1, 1, 1, 2, 2] * nPat)
    try:
        start     = time.perf_counter()
        y, report = odeint(rhs, y0=y0, t=tspan, full_output=True)
        timeCost  = time.perf_counter() - start
    finally:
        # The session is only needed by rhs during integration; release its resources.
        sess.close()

    return y, timeCost, report

In [3]:
# Benchmark: how total and per-unit solve time scale with the number of units integrated together.
for n in (1, 10, 100):
    y, timeCost, report = odeSolve(nPat=n, interpolate=1, dtype=tf.float64)
    # Final cumulative odeint counters: time steps, RHS evaluations, Jacobian evaluations.
    print(*(report[key][-1] for key in ('nst', 'nfe', 'nje')))
    print('N', n,
          'timeCost', timeCost,
          'per User', timeCost / n)

1810 4623 154
N 1 timeCost 4.62121057510376 per User 4.62121057510376
1810 11553 154
N 10 timeCost 9.35519552230835 per User 0.9355195522308349
1736 68721 130
N 100 timeCost 71.95516705513 per User 0.7195516705513001
