In [1]:
from itertools import islice

from scipy.integrate import odeint
from theano import tensor as T
from theano.ifelse import ifelse

import theano
import numpy as np

# Module-level tally of dpds1 calls; later cells reset it to 0 and read it
# back to compare how many right-hand-side evaluations each integrator makes.
counter = 0

def dpds1(p, v, u, t, g=9.81, m=1):
    """NumPy right-hand side of the car-on-the-hill ODE.

    The state is ``(p, v, u)``: position, velocity and a control input that
    is held constant (its derivative is always 0). ``t`` is accepted so the
    signature fits the integrators but is not used.

    Each call increments the module-level ``counter``.

    Returns ``[dp/dt, dv/dt, du/dt]``.
    """
    global counter
    counter += 1

    # slope / curvature are the first and second derivative of the hill
    # height: p**2 + p on the left (p < 0), p / sqrt(1 + 5 p**2) otherwise.
    negative = p < 0.
    if negative:
        slope, curvature = 2 * p + 1, 2
    else:
        base = 1 + 5 * p ** 2
        slope = 1 / base ** 1.5
        curvature = (-15 * p) / base ** 2.5

    numerator = u / m - g * slope - v ** 2 * slope * curvature
    accel = numerator / (1 + slope ** 2)
    return [v, accel, 0]

def dpds_odeint(sa, t):
    """Adapter for ``scipy.integrate.odeint``, which calls ``func(y, t)``.

    Unpacks the state vector ``sa`` into ``dpds1``'s positional arguments.
    """
    return dpds1(*sa, t)

# Shorthand for Theano's pretty-printer, for inspecting symbolic graphs.
pp = theano.printing.pp


def c(x):
    """Wrap ``x`` as a float64 Theano constant."""
    return T.constant(x, dtype='float64')


def diff_hills(p):
    """Symbolic slope and curvature of the hill at position ``p`` (Theano).

    Returns ``[d1, d2]``, the first and second derivative of the hill height,
    chosen with a lazy ``ifelse`` on ``p < 0``. This is the symbolic
    counterpart of the branch inside ``dpds1``.
    """
    base = 1 + 5 * p ** 2
    # p < 0: derivatives of p**2 + p
    left = [2 * p + 1, c(2)]
    # otherwise: derivatives of p / sqrt(1 + 5 p**2)
    right = [1 / (base ** 1.5), (-15 * p) / (base ** 2.5)]
    return ifelse(p < 0, left, right)

def dpds_t(p, v, u, t, g=c(9.81), m=1):
    """Symbolic (Theano) right-hand side of the car-on-the-hill ODE.

    Same formula as ``dpds1`` but without the call counter. Returns
    ``[dp/dt, dv/dt, du/dt]`` as Theano expressions, with ``u`` held
    constant. ``t`` is accepted for signature compatibility and is unused.
    """
    slope, curvature = diff_hills(p)
    numerator = u / m - g * slope - v ** 2 * slope * curvature
    accel = numerator / (1 + slope ** 2)
    return [v, accel, c(0)]


def euler(func, y0, t0=0, step=0.001):
    """Integrate ``dy/dt = func(*y, t)`` with the explicit (forward) Euler method.

    Yields the initial state ``y0`` unchanged, then successive states
    ``y_{n+1} = y_n + step * func(*y_n, t_n)`` with ``t_n = t0 + n * step``,
    forever. The caller decides how many to take, e.g. with
    ``itertools.islice``.

    Parameters
    ----------
    func : callable
        Called as ``func(*y, t)``. Must return a sequence of derivatives
        with the same length as ``y``.
    y0 : array_like
        Initial state. It is yielded as-is and never modified in place.
    t0 : float, optional
        Time associated with ``y0``.
    step : float, optional
        Fixed integration step size.
    """
    yield y0
    y = y0
    n = 0
    while True:
        # Forward Euler evaluates the derivative at the *current* point
        # (t_n, y_n) before time advances. Deriving t from n instead of
        # repeatedly adding `step` avoids accumulating rounding error.
        t = t0 + n * step
        # Rebind rather than `+=` so the caller's y0 is never updated in place.
        y = y + np.multiply(func(*y, t), step)
        n += 1
        yield y

In [2]:
# Symbolic inputs for the Theano version of the ODE right-hand side.
# `u` is float64 rather than int32: it travels in the same float state vector
# as `p` and `v` (see `x` and `euler`). An int32 input would make Theano
# reject any non-integral control value, since it refuses lossy input
# conversion by default.
u = T.dscalar('u')
p, v, t = T.dscalars('p', 'v', 't')

# `t` is not used by dpds_t. on_unused_input='ignore' keeps the call
# signature identical to dpds1 so both can be handed to `euler`.
dpds2 = theano.function([p, v, u, t], dpds_t(p, v, u, t), on_unused_input='ignore')

In [3]:
# Initial state (p, v, u): p = -0.5, v = 0 (at rest), constant control u = -4.
x = np.array([-0.5, 0, -4])

In [4]:
# NumPy right-hand side evaluated at the initial state, t = 0.
dpds1(*x, 0)

[0.0, -4.0, 0]

In [5]:
# Theano right-hand side at the same state; compare with the dpds1 result.
dpds2(*x, 0)

[array(0.0), array(-4.0), array(0.0)]

In [6]:
# One Euler integrator per right-hand-side implementation (default step 0.001),
# both starting from the same initial state x.
f1 = euler(dpds1, x)
f2 = euler(dpds2, x)

In [7]:
# Reset the dpds1 call counter before running the NumPy Euler integration.
counter = 0

In [8]:
# islice(..., 0, 101, 100) keeps items 0 and 100: the initial state and the
# state after 100 Euler steps (t = 0.1 at step 0.001).
list(islice(f1, 0, 101, 100))

[array([-0.5,  0. , -4. ]), array([-0.5194885 , -0.38708314, -4.        ])]

In [9]:
# dpds1 evaluations made by the Euler run above.
counter

100

In [10]:
# The same 100 Euler steps using the compiled Theano right-hand side.
list(islice(f2, 0, 101, 100))

[array([-0.5,  0. , -4. ]), array([-0.5194885 , -0.38708314, -4.        ])]

In [11]:
# Reset the dpds1 call counter before the odeint run.
counter = 0

In [12]:
# Reference solution at t = 0.1 using odeint's default adaptive step control.
odeint(dpds_odeint, x, [.0, .1])

array([[-0.5       ,  0.        , -4.        ],
       [-0.51966909, -0.3866922 , -4.        ]])

In [13]:
# dpds1 evaluations made by the adaptive odeint call.
counter

49

In [14]:
# Reset the dpds1 call counter before the fixed-step odeint run.
counter = 0

In [15]:
# hmin = hmax = 0.001 constrains odeint's internal step size to 0.001,
# the same step the Euler runs use.
odeint(dpds_odeint, x, [.0, .1], hmin=0.001, hmax=0.001)

array([[-0.5       ,  0.        , -4.        ],
       [-0.51966909, -0.38669218, -4.        ]])

In [16]:
# dpds1 evaluations made by odeint with its step size pinned to 0.001.
counter

204