In [1]:
import os

import tensorflow as tf
import numpy as np
from scipy.integrate import odeint
from scipy.interpolate import interp1d

import time

In [2]:
def odeSolve(nPat, interpolate=0):

    tspan = np.linspace(0, 100, 101)
    
    fj    = np.array([12 , 7, 15] * nPat).reshape(nPat, -1)
    rj    = np.array([6  , 3,  8] * nPat).reshape(nPat, -1)
    mj    = np.array([10 , 17, 2] * nPat).reshape(nPat, -1)
    
    Aj_const = np.array([3, 3, 3] * nPat).reshape(nPat, -1)
    Bj_const = np.array([3, 3, 3] * nPat).reshape(nPat, -1)
    
    Atimesj  = []
    Btimesj  = []
    for i in range(nPat):
        
        tmp_doseA, tmp_doseB   = np.zeros(shape=tspan.shape), np.zeros(shape=tspan.shape)
        
        for trange, dose in [ ([  5, 15],    3 ),
                              ([ 35, 50],   35 ),
                              ([ 50, 60],    3 ),
                              ([ 60, 75],  300 ),
                              ([ 75, 80],  7.6 ) ]:
            twindow            = range(trange[0], trange[1] + 1)
            tmp_doseA[twindow] = dose

        for trange, dose in [ ([  5, 15], 70   ),
                              ([ 35, 50], 12.5 ),
                              ([ 75, 80], 7.6  ) ]:
            twindow            = range(trange[0], trange[1] + 1)
            tmp_doseB[twindow] = dose   

        Atimesj.append(tmp_doseA)
        Btimesj.append(tmp_doseB)
    
    Atimesj  = np.array(Atimesj).reshape(nPat, -1)
    Btimesj  = np.array(Btimesj).reshape(nPat, -1)
    AjInterp = interp1d(tspan, Atimesj, bounds_error=False, fill_value=(Atimesj[:,0], Atimesj[:,-1]))
    BjInterp = interp1d(tspan, Btimesj, bounds_error=False, fill_value=(Btimesj[:,0], Btimesj[:,-1]))
    
    np.random.seed(seed=1234)
    NNwts    = [ np.random.random(size=(5, 12)), 
                 np.random.random(size=(12, 3)), 
                 np.random.random(size=(3,  2))  ]
    NNb      = [ 0, 1, -1 ]
    Taus     = np.array([1, 4 ] * nPat).reshape(nPat, -1)
    NNact    = [ np.tanh, np.tanh, np.tanh ]

    upSizeToLeft = np.array([[1, 0, 0, 0, 0],
                             [0, 1, 0, 0, 0],
                             [0, 0, 1, 0, 0]])
    downSizeLeftPart = np.array([[1, 0, 0],
                                 [0, 1, 0],
                                 [0, 0, 1],
                                 [0, 0, 0],
                                 [0, 0, 0]])

    upSizeToRight = np.array([[0, 0, 0, 1, 0],
                              [0, 0, 0, 0, 1]])
    downSizeRightPart = np.array([[0, 0],
                                  [0, 0],
                                  [0, 0],
                                  [1, 0],
                                  [0, 1]])
    
    upSize3_2 = np.array([[1, 1, 1, 0, 0]])
    zerosMatrix = np.zeros(shape=(nPat, 2))
                        
    fj_padded = np.dot(fj, upSizeToLeft)
    rj_padded = np.dot(rj, upSizeToLeft)
    mj_padded = np.dot(mj, upSizeToLeft)
    Aj_padded = np.dot(Aj_const, upSizeToLeft)
    Bj_padded = np.dot(Bj_const, upSizeToLeft)

    def rhs(y, t, downSizeRightPart, Taus, upSizeToRight, upSize3_2,
                  fj_padded, rj_padded, mj_padded, Aj_padded, Bj_padded):

        try:
            Nnt_matrix = np.array(y).reshape(nPat, -1)

            for index, (w, b, a) in enumerate(zip(NNwts, NNb, NNact)):
                if index == 0:
                    nn_res = np.dot(Nnt_matrix, w) + b
                else:
                    nn_res = np.dot(nn_res, w) + b
                nn_res = a(nn_res)

            stressMatrix = np.dot(Nnt_matrix, downSizeRightPart)
            nn_res       = nn_res - stressMatrix / Taus
            nn_res       = np.dot(nn_res, upSizeToRight)
                        
            if   interpolate == 1:
                Aj        = AjInterp(t).reshape(nPat, -1)
                Aj        = np.dot(Aj, upSize3_2)
                Bj        = BjInterp(t).reshape(nPat, -1)
                Bj        = np.dot(Bj, upSize3_2)
                
                meds_res  = fj_padded - rj_padded * Nnt_matrix / (1 + Aj) \
                                      - mj_padded * Nnt_matrix / (1 + Bj)
                
            elif interpolate == 2:
                Aj        = AjInterp(t).reshape(nPat, -1)
                Aj        = np.concatenate([Aj, Aj, Aj, zerosMatrix], axis=1)
                Bj        = BjInterp(t).reshape(nPat, -1)
                Bj        = np.concatenate([Bj, Bj, Bj, zerosMatrix], axis=1)
                
                meds_res  = fj_padded - rj_padded * Nnt_matrix / (1 + Aj) \
                                      - mj_padded * Nnt_matrix / (1 + Bj)
                
            else:
                meds_res  = fj_padded - rj_padded * Nnt_matrix / (1 + Aj_padded) \
                                      - mj_padded * Nnt_matrix / (1 + Bj_padded)

            results   = nn_res + meds_res
            results   = results.flatten()

        except Exception as e:
            print(t, str(e))

        return results
    
    args      = (downSizeRightPart, Taus, upSizeToRight, upSize3_2, fj_padded, rj_padded, mj_padded, Aj_padded, Bj_padded)
    start     = time.time()
    y, report = odeint(rhs, y0   = np.array([1, 1, 1, 2, 2] * nPat), 
                            t    = np.linspace(0, 100, 101), 
                            args = args,
                            full_output = True)
    timeCost  = time.time() - start

    return y, timeCost, report

In [3]:
for intepolate in [1, 2]:
    print('-'*30)
    print('interpolate', intepolate)
    for n in [1, 10, 100]:
        y, timeCost, report = odeSolve(nPat=n, interpolate=intepolate)
        print(report['nst'][-1], report['nfe'][-1], report['nje'][-1])
        print('N', n, 'timeCost', timeCost, 'per User', timeCost / n)

------------------------------
interpolate 1
1702 4216 110
N 1 timeCost 0.48931241035461426 per User 0.48931241035461426
1810 11553 154
N 10 timeCost 2.1548264026641846 per User 0.21548264026641845
1736 68721 130
N 100 timeCost 16.745413064956665 per User 0.16745413064956666
------------------------------
interpolate 2
1702 4216 110
N 1 timeCost 0.5382077693939209 per User 0.5382077693939209
1810 11553 154
N 10 timeCost 2.199556827545166 per User 0.2199556827545166
1736 68721 130
N 100 timeCost 16.899113416671753 per User 0.16899113416671752


In [4]:
0.47612524032592773 / 3591, 22.872165203094482 / 97042

(0.00013258848240766576, 0.00023569346471728203)