15 changes: 15 additions & 0 deletions pymc3/tests/models.py
@@ -3,6 +3,7 @@
import pymc3 as pm
from itertools import product
import theano.tensor as tt
from theano.compile.ops import as_op


def simple_model():
@@ -34,6 +35,20 @@ def multidimensional_model():
    return model.test_point, model, (mu, tau ** -1)


def simple_arbitrary_det():
    @as_op(itypes=[tt.dscalar], otypes=[tt.dscalar])
    def arbitrary_det(value):
        return value

    with Model() as model:
        a = Normal('a')
        b = arbitrary_det(a)
        c = Normal('obs', mu=b.astype('float64'),
                   observed=np.array([1, 3, 5]))

    return model.test_point, model


def simple_init():
    start, model, moments = simple_model()
    step = Metropolis(model.vars, np.diag([1.]), model=model)
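For context: the new simple_arbitrary_det fixture relies on theano.compile.ops.as_op, which turns a plain Python function into a Theano Op with declared input and output types but no gradient implementation. A minimal standalone sketch of the same pattern (the function name and data values here are illustrative, not part of the diff):

import numpy as np
import pymc3 as pm
import theano.tensor as tt
from theano.compile.ops import as_op

# Wrap an arbitrary Python function as a Theano Op. Theano cannot
# differentiate through it, so a model containing it has no usable dlogp.
@as_op(itypes=[tt.dscalar], otypes=[tt.dscalar])
def blackbox(value):  # hypothetical name, mirrors arbitrary_det above
    return value      # identity here; could be any opaque computation

with pm.Model() as model:
    a = pm.Normal('a')
    b = blackbox(a)
    pm.Normal('obs', mu=b.astype('float64'), observed=np.array([1., 3., 5.]))

Asking such a model for its gradient fails (an AttributeError in practice), which is exactly the condition the find_MAP changes below detect and work around.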
9 changes: 7 additions & 2 deletions pymc3/tests/test_starting.py
@@ -2,8 +2,7 @@
import numpy as np
from pymc3.tuning import starting
from pymc3 import Model, Uniform, Normal, Beta, Binomial, find_MAP, Point
from .models import simple_model, non_normal, exponential_beta

from .models import simple_model, non_normal, exponential_beta, simple_arbitrary_det

def test_accuracy_normal():
    _, model, (mu, _) = simple_model()
@@ -53,6 +52,12 @@ def test_find_MAP_discrete():
    assert map_est2['ss'] == 14


def test_find_MAP_no_gradient():
    _, model = simple_arbitrary_det()
    with model:
        find_MAP()


def test_find_MAP():
    tol = 2.0**-11  # 16 bit machine epsilon, a low bar
    data = np.random.randn(100)
46 changes: 31 additions & 15 deletions pymc3/tuning/starting.py
@@ -17,8 +17,8 @@
__all__ = ['find_MAP']


def find_MAP(start=None, vars=None, fmin=None, return_raw=False,
             model=None, *args, **kwargs):
def find_MAP(start=None, vars=None, fmin=None,
             return_raw=False, model=None, *args, **kwargs):
    """
    Sets state to the local maximum a posteriori point given a model.
    Current default of fmin_Hessian does not deal well with optimizing close
@@ -55,8 +55,15 @@ def find_MAP(start=None, vars=None, fmin=None, return_raw=False,

    disc_vars = list(typefilter(vars, discrete_types))

    if disc_vars:
        pm._log.warning("Warning: vars contains discrete variables. MAP " +
    try:
        model.fastdlogp(vars)
        gradient_avail = True
    except AttributeError:
        gradient_avail = False

    if disc_vars or not gradient_avail:
        pm._log.warning("Warning: gradient not available. " +
                        "(E.g. vars contains discrete variables). MAP " +
                        "estimates may not be accurate for the default " +
                        "parameters. Defaulting to non-gradient minimization " +
                        "fmin_powell.")
@@ -74,19 +81,21 @@ def find_MAP(start=None, vars=None, fmin=None, return_raw=False,
    bij = DictToArrayBijection(ArrayOrdering(vars), start)

    logp = bij.mapf(model.fastlogp)
    dlogp = bij.mapf(model.fastdlogp(vars))

    def logp_o(point):
        return nan_to_high(-logp(point))

    def grad_logp_o(point):
        return nan_to_num(-dlogp(point))

    # Check to see if minimization function actually uses the gradient
    if 'fprime' in getargspec(fmin).args:
        dlogp = bij.mapf(model.fastdlogp(vars))
        def grad_logp_o(point):
            return nan_to_num(-dlogp(point))

        r = fmin(logp_o, bij.map(
            start), fprime=grad_logp_o, *args, **kwargs)
        compute_gradient = True
    else:
        compute_gradient = False

        # Check to see if minimization function uses a starting value
        if 'x0' in getargspec(fmin).args:
            r = fmin(logp_o, bij.map(start), *args, **kwargs)
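For context: the 'fprime' in getargspec(fmin).args check above is plain signature introspection. SciPy's gradient-based optimizers accept a gradient callable through an fprime argument, while fmin_powell does not. A small sketch of what the check sees for the two optimizers involved here:

from inspect import getargspec
from scipy import optimize

# fmin_bfgs exposes `fprime`, so the gradient branch is taken and dlogp is
# compiled; fmin_powell does not, so compute_gradient stays False.
print('fprime' in getargspec(optimize.fmin_bfgs).args)    # True
print('fprime' in getargspec(optimize.fmin_powell).args)  # False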
@@ -100,17 +109,24 @@ def grad_logp_o(point):

    mx = bij.rmap(mx0)

    if (not allfinite(mx0) or
        not allfinite(model.logp(mx)) or
        not allfinite(model.dlogp()(mx))):
    allfinite_mx0 = allfinite(mx0)
    allfinite_logp = allfinite(model.logp(mx))
    if compute_gradient:
        allfinite_dlogp = allfinite(model.dlogp()(mx))
    else:
        allfinite_dlogp = True

    if (not allfinite_mx0 or
        not allfinite_logp or
        not allfinite_dlogp):

        messages = []
        for var in vars:

            vals = {
                "value": mx[var.name],
                "logp": var.logp(mx),
                "dlogp": var.dlogp()(mx)}
                "logp": var.logp(mx)}
            if compute_gradient:
                vals["dlogp"] = var.dlogp()(mx)

            def message(name, values):
                if np.size(values) < 10:
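Taken together, the behaviour the new test exercises looks roughly like this sketch (using the fixture and warning added in this PR; the variable name map_estimate is illustrative):

import pymc3 as pm
from pymc3.tests.models import simple_arbitrary_det

_, model = simple_arbitrary_det()
with model:
    # No gradient is available for the as_op deterministic, so find_MAP
    # logs the "gradient not available" warning and falls back to
    # scipy.optimize.fmin_powell instead of the gradient-based default.
    map_estimate = pm.find_MAP()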