Skip to content

Commit

Permalink
Add MAP estimates for transformed and untransformed variables (#2523)
Browse files: browse the repository at this point in the history
* Add MAP estimates for transformed and untransformed variables

* Fix docstring

* Fix lint
  • Loading branch information
kyleabeauchamp authored and Junpeng Lao committed Aug 22, 2017
1 parent d01aaf0 commit 1f34a5f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 137 deletions.
16 changes: 16 additions & 0 deletions pymc3/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,19 @@ def beta_bernoulli(n=2):
pm.Beta('x', 3, 1, shape=n, transform=None)
pm.Bernoulli('y', 0.5)
return model.test_point, model, None


def simple_normal(bounded_prior=False):
    """Build a one-observation normal model for MLE / MAP tests (issue #2482).

    Parameters
    ----------
    bounded_prior : bool
        When true, the mean `mu_i` gets a Uniform(9, 12) prior; otherwise it
        gets a Flat prior.

    Returns
    -------
    tuple
        `(model.test_point, model, None)`.
    """
    observed_value = 10.0
    noise_sd = 1.0
    # Bounds for the uniform RV; they need to be non-symmetric to reproduce
    # the issue.
    lower, upper = 9, 12

    with pm.Model() as model:
        mean = pm.Uniform("mu_i", lower, upper) if bounded_prior else pm.Flat("mu_i")
        pm.Normal("X_obs", mu=mean, sd=noise_sd, observed=observed_value)

    start = model.test_point
    return start, model, None
20 changes: 19 additions & 1 deletion pymc3/tests/test_tuning.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from numpy import inf
from pymc3.tuning import scaling
from pymc3.tuning import scaling, find_MAP
from . import models


Expand All @@ -14,3 +14,21 @@ def test_guess_scaling():
start, model, _ = models.non_normal(n=5)
a1 = scaling.guess_scaling(start, model=model)
assert all((a1 > 0) & (a1 < 1e200))


def test_mle_jacobian():
    """Check MAP / MLE estimation of `mu_i` under flat and bounded priors."""
    expected_mu = 10.0  # the simple normal model should give mu=10.0
    tolerance = 1E-5  # this rtol should work on both floatX precisions

    # Same check for both variants: the flat prior first, then the bounded
    # (uniform) prior.
    for bounded in (False, True):
        _, model, _ = models.simple_normal(bounded_prior=bounded)
        with model:
            estimate = find_MAP(model=model)

        np.testing.assert_allclose(estimate["mu_i"], expected_mu, rtol=tolerance)
145 changes: 9 additions & 136 deletions pymc3/tuning/starting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np
from numpy import isfinite, nan_to_num, logical_not
import pymc3 as pm
import time
from ..vartypes import discrete_types, typefilter
from ..model import modelcontext, Point
from ..theanof import inputvars
Expand All @@ -20,7 +19,7 @@


def find_MAP(start=None, vars=None, fmin=None,
return_raw=False, model=None, live_disp=False, callback=None,
return_raw=False, model=None, callback=None,
*args, **kwargs):
"""
Sets state to the local maximum a posteriori point given a model.
Expand All @@ -31,20 +30,16 @@ def find_MAP(start=None, vars=None, fmin=None,
----------
start : `dict` of parameter values (Defaults to `model.test_point`)
vars : list
List of variables to set to MAP point (Defaults to all continuous).
List of variables to optimize and set to optimum (Defaults to all continuous).
fmin : function
Optimization algorithm (Defaults to `scipy.optimize.fmin_bfgs` unless
discrete variables are specified in `vars`, then
`scipy.optimize.fmin_powell` which will perform better).
return_raw : Bool
Whether to return extra value returned by fmin (Defaults to `False`)
model : Model (optional if in `with` context)
live_disp : Bool
Display table tracking optimization progress when run from within
an IPython notebook.
callback : callable
Callback function to pass to scipy optimization routine. Overrides
live_disp if callback is given.
Callback function to pass to scipy optimization routine.
*args, **kwargs
Extra args passed to fmin
"""
Expand Down Expand Up @@ -89,7 +84,8 @@ def find_MAP(start=None, vars=None, fmin=None,

start = Point(start, model=model)
bij = DictToArrayBijection(ArrayOrdering(vars), start)

logp_func = bij.mapf(model.fastlogp)
x0 = bij.map(start)
logp = bij.mapf(model.fastlogp_nojac)
def logp_o(point):
return nan_to_high(-logp(point))
Expand All @@ -100,15 +96,9 @@ def logp_o(point):
def grad_logp_o(point):
return nan_to_num(-dlogp(point))

if live_disp and callback is None:
callback = Monitor(bij, logp_o, model, grad_logp_o)

r = fmin(logp_o, bij.map(start), fprime=grad_logp_o, callback=callback, *args, **kwargs)
compute_gradient = True
else:
if live_disp and callback is None:
callback = Monitor(bij, logp_o, dlogp=None)

# Check to see if minimization function uses a starting value
if 'x0' in getargspec(fmin).args:
r = fmin(logp_o, bij.map(start), callback=callback, *args, **kwargs)
Expand All @@ -121,12 +111,6 @@ def grad_logp_o(point):
else:
mx0 = r

if live_disp:
try:
callback.update(mx0)
except:
pass

mx = bij.rmap(mx0)

allfinite_mx0 = allfinite(mx0)
Expand Down Expand Up @@ -171,13 +155,16 @@ def message(name, values):
"density. 2) your distribution logp's are " +
"properly specified. Specific issues: \n" +
specific_errors)
mx = {v.name: mx[v.name].astype(v.dtype) for v in model.vars}

vars = model.unobserved_RVs
mx = {var.name: value for var, value in zip(vars, model.fastfn(vars)(mx))}

if return_raw:
return mx, r
else:
return mx


def allfinite(x):
    """Return whether every element of `x` is finite (no NaN, no +/-inf)."""
    finite_mask = isfinite(x)
    return np.all(finite_mask)

Expand All @@ -192,120 +179,6 @@ def allinmodel(vars, model):
raise ValueError("Some variables not in the model: " + str(notin))



class Monitor(object):
    """Optimizer callback that records progress and renders it as HTML tables.

    An instance is callable with the current parameter array `x`. Each call
    increments the iteration counter; the tables are refreshed on the first
    call and whenever more than `update_interval` seconds have passed since
    the last refresh.

    Parameters
    ----------
    bij : object
        Provides `rmap(x)`, whose result is passed to the compiled function
        of the model's unobserved random variables.
    logp : callable
        Called as `logp(x)`; the stored log posterior is the negated result.
    model : object
        Provides `fastfn(...)` and `unobserved_RVs`.
    dlogp : callable, optional
        Called as `dlogp(x)`; its norm is shown as the gradient norm. When
        `None`, no gradient norm is computed or displayed.
    """

    def __init__(self, bij, logp, model, dlogp=None):
        try:
            from IPython.display import display
            from ipywidgets import HTML, VBox, HBox, FlexBox
            self.prog_table = HTML(width='100%')
            self.param_table = HTML(width='100%')
            r_col = VBox(children=[self.param_table], padding=3, width='100%')
            l_col = HBox(children=[self.prog_table], padding=3, width='25%')
            self.hor_align = FlexBox(children=[l_col, r_col], width='100%', orientation='vertical')
            display(self.hor_align)
            self.using_notebook = True
            self.update_interval = 1
        except Exception:
            # Best effort: any failure to import or build the widgets falls
            # back to non-notebook mode. `Exception` (rather than a bare
            # `except`) lets KeyboardInterrupt / SystemExit propagate.
            self.using_notebook = False
            self.update_interval = 2

        self.iters = 0
        self.bij = bij
        self.model = model
        self.fn = model.fastfn(model.unobserved_RVs)
        self.logp = logp
        self.dlogp = dlogp
        self.t_initial = time.time()
        self.t0 = self.t_initial
        self.paramtable = {}

    def __call__(self, x):
        """Count one iteration and refresh the tables if an update is due."""
        self.iters += 1
        if time.time() - self.t0 > self.update_interval or self.iters == 1:
            self.update(x)

    def update(self, x):
        """Recompute both tables for `x` and redraw them when in a notebook."""
        self._update_progtable(x)
        self._update_paramtable(x)
        if self.using_notebook:
            self._display_notebook()
        self.t0 = time.time()

    def _update_progtable(self, x):
        """Store elapsed time, log posterior and (if available) gradient norm."""
        s = time.time() - self.t_initial
        hours, remainder = divmod(int(s), 3600)
        minutes, seconds = divmod(remainder, 60)
        self.t_elapsed = "{:2d}h{:2d}m{:2d}s".format(hours, minutes, seconds)
        # `np.float` was only an alias of the builtin and has been removed
        # from NumPy; the builtin gives the same result.
        self.logpost = -1.0 * float(self.logp(x))
        # `dlogp` is optional; calling it unconditionally raised TypeError
        # when no gradient function was supplied.
        if self.dlogp is not None:
            self.dlogpost = np.linalg.norm(self.dlogp(x))
        else:
            self.dlogpost = None

    def _update_paramtable(self, x):
        """Store size and formatted value of each unobserved variable."""
        var_state = self.fn(self.bij.rmap(x))
        for var, val in zip(self.model.unobserved_RVs, var_state):
            # Variables whose name ends with "_" are left out of the table.
            if not var.name.endswith("_"):
                valstr = format_values(val)
                self.paramtable[var.name] = {"size": val.size, "valstr": valstr}

    def _display_notebook(self):
        """Render the progress and parameter tables into the HTML widgets."""
        ## Progress table
        html = r"""<style type="text/css">
table { border-collapse:collapse }
.tg {border-collapse:collapse;border-spacing:0;border:none;}
.tg td{font-family:Arial, sans-serif;font-size:14px;padding:3px 3px;border-style:solid;border-width:0px;overflow:hidden;word-break:normal;}
.tg th{Impact, Charcoal, sans-serif;font-size:13px;font-weight:bold;padding:3px 3px;border-style:solid;border-width:0px;overflow:hidden;word-break:normal; background-color:#0E688A;color:#ffffff;}
.tg .tg-vkoh{white-space:pre;font-weight:normal;font-family:"Lucida Console", Monaco, monospace !important; background-color:#ffffff;color:#000000}
.tg .tg-suao{font-weight:bold;font-family:"Lucida Console", Monaco, monospace !important;background-color:#0E688A;color:#ffffff;}
"""
        html += r"""
</style>
<table class="tg" style="undefined;">
<col width="400px" />
<tr>
<th class= "tg-vkoh">Time Elapsed: {:s}</th>
</tr>
<tr>
<th class= "tg-vkoh">Iteration: {:d}</th>
</tr>
<tr>
<th class= "tg-vkoh">Log Posterior: {:.3f}</th>
</tr>
""".format(self.t_elapsed, self.iters, self.logpost)
        # The gradient row is only emitted when a gradient function exists.
        if self.dlogp is not None:
            html += r"""
<tr>
<th class= "tg-vkoh">||grad||: {:.3f}</th>
</tr>""".format(self.dlogpost)
        html += "</table>"
        self.prog_table.value = html
        ## Parameter table
        html = r"""<style type="text/css">
.tg .tg-bgft{font-weight:normal;font-family:"Lucida Console", Monaco, monospace !important;background-color:#0E688A;color:#ffffff;}
.tg td{font-family:Arial, sans-serif;font-size:12px;padding:3px 3px;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#504A4E;color:#333;background-color:#fff;word-wrap: break-word;}
.tg th{Impact, Charcoal, sans-serif;font-size:13px;font-weight:bold;padding:3px 3px;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#504A4E;background-color:#0E688A;color:#ffffff;}
</style>
<table class="tg" style="undefined;">
<col width="130px" />
<col width="50px" />
<col width="600px" />
<tr>
<th class="tg">Parameter</th>
<th class="tg">Size</th>
<th class="tg">Current Value</th>
</tr>
"""
        for var, values in self.paramtable.items():
            html += r"""
<tr>
<td class="tg-bgft">{:s}</td>
<td class="tg-vkoh">{:d}</td>
<td class="tg-vkoh">{:s}</td>
</tr>
""".format(var, values["size"], values["valstr"])
        html += "</table>"
        self.param_table.value = html


def format_values(val):
fmt = "{:8.3f}"
if val.size == 1:
Expand Down

0 comments on commit 1f34a5f

Please sign in to comment.