Skip to content

Commit

Permalink
Merge pull request #45 from pymc-devs/master
Browse files Browse the repository at this point in the history
Sync Fork from Upstream Repo
  • Loading branch information
sthagen committed Mar 12, 2020
2 parents 76003bb + 6c5254f commit ecf6c90
Show file tree
Hide file tree
Showing 20 changed files with 1,128 additions and 72 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ benchmarks/results/
# VSCode
.vscode/
.mypy_cache

pytestdebug.log
.dir-locals.el
.pycheckers
2 changes: 2 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
- `DEMetropolis` can now tune both `lambda` and `scaling` parameters, but by default neither of them is tuned. See [#3743](https://github.com/pymc-devs/pymc3/pull/3743) for more info.
- `DEMetropolisZ`, an improved variant of `DEMetropolis`, brings better parallelization and higher efficiency with fewer chains, at the cost of slower initial convergence. This implementation is experimental. See [#3784](https://github.com/pymc-devs/pymc3/pull/3784) for more info.
- Notebooks that give insight into `DEMetropolis`, `DEMetropolisZ` and the `DifferentialEquation` interface are now located in the [Tutorials/Deep Dive](https://docs.pymc.io/nb_tutorials/index.html) section.
- Add `fast_sample_posterior_predictive`, a vectorized alternative to `sample_posterior_predictive`. This alternative is substantially faster for large models.
- `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827)).

### Maintenance
- Remove `sample_ppc` and `sample_ppc_w` that were deprecated in 3.6.
Expand Down
2 changes: 1 addition & 1 deletion pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from . import sampling

from .backends.tracetab import *
from .backends import save_trace, load_trace
from .backends import save_trace, load_trace, point_list_to_multitrace

from .plots import *
from .tests import test
Expand Down
2 changes: 1 addition & 1 deletion pymc3/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
For specific examples, see pymc3.backends.{ndarray,text,sqlite}.py.
"""
from ..backends.ndarray import NDArray, save_trace, load_trace
from ..backends.ndarray import NDArray, save_trace, load_trace, point_list_to_multitrace
from ..backends.text import Text
from ..backends.sqlite import SQLite
from ..backends.hdf5 import HDF5
Expand Down
7 changes: 4 additions & 3 deletions pymc3/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
"""
import itertools as itl
import logging
from typing import List
from typing import Dict, List, Optional
from abc import ABC

import numpy as np
import warnings
import theano.tensor as tt

from ..model import modelcontext
from ..model import modelcontext, Model
from .report import SamplerReport, merge_reports

logger = logging.getLogger('pymc3')
Expand All @@ -35,7 +36,7 @@ class BackendError(Exception):
pass


class BaseTrace:
class BaseTrace(ABC):
"""Base trace object
Parameters
Expand Down
20 changes: 18 additions & 2 deletions pymc3/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
import json
import os
import shutil
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, List

import numpy as np
from pymc3.backends import base
from pymc3.backends.base import MultiTrace
from pymc3.model import Model
from pymc3.model import Model, modelcontext
from pymc3.exceptions import TraceDirectoryError


Expand Down Expand Up @@ -366,3 +366,19 @@ def _slice_as_ndarray(strace, idx):
sliced.draw_idx = (stop - start) // step

return sliced

def point_list_to_multitrace(point_list: List[Dict[str, np.ndarray]], model: Optional[Model]=None) -> MultiTrace:
    """Build a single-chain ``MultiTrace`` from a list of points.

    Parameters
    ----------
    point_list : list of dict
        One dict per draw, mapping variable name to value. The variable
        names are taken from the keys of the first point.
    model : Model, optional
        Passed to ``modelcontext`` to resolve the model to use.

    Returns
    -------
    MultiTrace
        A trace wrapping one ``NDArray`` chain (chain index 0) into which
        every point of ``point_list`` has been recorded.
    """
    resolved_model = modelcontext(model)
    names = list(point_list[0].keys())

    def values_in_order(point):
        # The points are supplied by hand, so record() only needs a trivial
        # lookup function; this replaces the chain's default ``fn``.
        return [point[name] for name in names]

    with resolved_model:
        strace = NDArray(
            model=resolved_model,
            vars=[resolved_model[name] for name in names],
        )
        strace.setup(draws=len(point_list), chain=0)
        strace.fn = values_in_order
        for point in point_list:
            strace.record(point)
    return MultiTrace([strace])
25 changes: 24 additions & 1 deletion pymc3/backends/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from collections import namedtuple
import logging
import enum
import typing
from ..util import is_transformed_name, get_untransformed_name


Expand Down Expand Up @@ -51,11 +52,15 @@ class WarningType(enum.Enum):


class SamplerReport:
"""This object bundles warnings, convergence statistics and metadata of a sampling run."""
def __init__(self):
self._chain_warnings = {}
self._global_warnings = []
self._ess = None
self._rhat = None
self._n_tune = None
self._n_draws = None
self._t_sampling = None

@property
def _warnings(self):
Expand All @@ -68,6 +73,25 @@ def ok(self):
return all(_LEVELS[warn.level] < _LEVELS['warn']
for warn in self._warnings)

@property
def n_tune(self) -> typing.Optional[int]:
    """Number of tuning iterations.

    These iterations are not necessarily kept in the trace.
    """
    return self._n_tune

@property
def n_draws(self) -> typing.Optional[int]:
    """Number of draw iterations of the sampling run."""
    return self._n_draws

@property
def t_sampling(self) -> typing.Optional[float]:
    """Duration of the sampling procedure in seconds.

    The value includes parallelization overhead.
    """
    return self._t_sampling

def raise_ok(self, level='error'):
errors = [warn for warn in self._warnings
if _LEVELS[warn.level] >= _LEVELS[level]]
Expand Down Expand Up @@ -151,7 +175,6 @@ def _add_warnings(self, warnings, chain=None):
warn_list.extend(warnings)

def _log_summary(self):

def log_warning(warn):
level = _LEVELS[warn.level]
logger.log(level, warn.message)
Expand Down
5 changes: 4 additions & 1 deletion pymc3/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from . import transforms
from . import shape_utils

from .posterior_predictive import fast_sample_posterior_predictive

from .continuous import Uniform
from .continuous import Flat
from .continuous import HalfFlat
Expand Down Expand Up @@ -168,5 +170,6 @@
'Interpolated',
'Bound',
'Rice',
'Simulator'
'Simulator',
'fast_sample_posterior_predictive'
]
16 changes: 14 additions & 2 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
# limitations under the License.

import numbers
import contextvars
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Callable

import numpy as np
import theano.tensor as tt
Expand All @@ -33,6 +37,7 @@
__all__ = ['DensityDist', 'Distribution', 'Continuous', 'Discrete',
'NoDistribution', 'TensorType', 'draw_values', 'generate_samples']

vectorized_ppc = contextvars.ContextVar('vectorized_ppc', default=None) # type: contextvars.ContextVar[Optional[Callable]]

class _Unpickling:
pass
Expand Down Expand Up @@ -530,11 +535,18 @@ def draw_values(params, point=None, size=None):
a) are named parameters in the point
b) are RVs with a random method
"""
# The following check intercepts and redirects calls to
# draw_values in the context of sample_posterior_predictive
ppc_sampler = vectorized_ppc.get(None)
if ppc_sampler is not None:
# this is being done inside new, vectorized sample_posterior_predictive
return ppc_sampler(params, trace=point, samples=size)

if point is None:
point = {}
# Get fast drawable values (i.e. things in point or numbers, arrays,
# constants or shares, or things that were already drawn in related
# contexts)
if point is None:
point = {}
with _DrawValuesContext() as context:
params = dict(enumerate(params))
drawn = context.drawn_vars
Expand Down

0 comments on commit ecf6c90

Please sign in to comment.