Skip to content

Commit

Permalink
ENH add SDE class (#1269)
Browse files Browse the repository at this point in the history
  • Loading branch information
maedoc authored and twiecki committed Sep 26, 2016
1 parent fc71f68 commit 0ca78be
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion pymc3/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .continuous import Normal, Flat
from .distribution import Continuous

__all__ = ['AR1', 'GaussianRandomWalk', 'GARCH11']
__all__ = ['AR1', 'GaussianRandomWalk', 'GARCH11', 'EulerMaruyama']


class AR1(Continuous):
Expand Down Expand Up @@ -124,3 +124,30 @@ def logp(self, x):
vol = self._get_volatility(x[:-1])
return (Normal.dist(0., sd=self.initial_vol).logp(x[0]) +
tt.sum(Normal.dist(0, sd=vol).logp(x[1:])))


class EulerMaruyama(Continuous):
"""
Stochastic differential equation discretized with the Euler-Maruyama method.
Parameters
----------
dt : float
time step of discretization
sde_fn : callable
function returning the drift and diffusion coefficients of SDE
sde_pars : tuple
parameters of the SDE, passed as *args to sde_fn
"""
def __init__(self, dt, sde_fn, sde_pars, *args, **kwds):
super(EulerMaruyama, self).__init__(*args, **kwds)
self.dt = dt
self.sde_fn = sde_fn
self.sde_pars = sde_pars

def logp(self, x):
xt = x[:-1]
f, g = self.sde_fn(x[:-1], *self.sde_pars)
mu = xt + self.dt * f
sd = tt.sqrt(self.dt) * g
return tt.sum(Normal.dist(mu=mu, sd=sd).logp(x[1:]))

0 comments on commit 0ca78be

Please sign in to comment.