Skip to content

Commit

Permalink
Merge 0194816 into c0edddd
Browse files Browse the repository at this point in the history
  • Loading branch information
fonnesbeck committed Jul 26, 2019
2 parents c0edddd + 0194816 commit 8743f60
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Expand Up @@ -8,6 +8,7 @@
- Sequential Monte Carlo - Approximate Bayesian Computation step method is now available. The implementation is in an experimental stage and will be further improved.
- Added `Matern12` covariance function for Gaussian processes. This is the Matern kernel with nu=1/2.
- Progressbar reports number of divergences in real time, when available [#3547](https://github.com/pymc-devs/pymc3/pull/3547).
- Sampling from variational approximation now allows for alternative trace backends [#3557].

### Maintenance
- Moved math operations out of `Rice`, `TruncatedNormal`, `Triangular` and `ZeroInflatedNegativeBinomial` `random` methods. Math operations on values returned by `draw_values` might not broadcast well, and all the `size` aware broadcasting is left to `generate_samples`. Fixes [#3481](https://github.com/pymc-devs/pymc3/issues/3481) and [#3508](https://github.com/pymc-devs/pymc3/issues/3508)
Expand Down
16 changes: 10 additions & 6 deletions pymc3/tests/test_variational_inference.py
Expand Up @@ -153,12 +153,16 @@ def three_var_approx_single_group_mf(three_var_model):


def test_sample_simple(three_var_approx):
trace = three_var_approx.sample(500)
assert set(trace.varnames) == {'one', 'one_log__', 'three', 'two'}
assert len(trace) == 500
assert trace[0]['one'].shape == (10, 2)
assert trace[0]['two'].shape == (10, )
assert trace[0]['three'].shape == (10, 1, 2)
for backend,name in ((None, None),
('text', 'test'),
('sqlite', 'test.sqlite'),
('hdf5', 'test.h5')):
trace = three_var_approx.sample(100, backend=backend, name=name)
assert set(trace.varnames) == {'one', 'one_log__', 'three', 'two'}
assert len(trace) == 100
assert trace[0]['one'].shape == (10, 2)
assert trace[0]['two'].shape == (10, )
assert trace[0]['three'].shape == (10, 1, 2)


@pytest.fixture
Expand Down
15 changes: 12 additions & 3 deletions pymc3/variational/opvi.py
Expand Up @@ -45,6 +45,7 @@
from ..blocking import (
ArrayOrdering, DictToArrayBijection, VarMap
)
from ..backends import NDArray, Text, SQLite, HDF5
from ..model import modelcontext
from ..theanof import tt_rng, change_flags, identity
from ..util import get_default_varnames
Expand Down Expand Up @@ -1569,7 +1570,8 @@ def inner(draws=100):

return inner

def sample(self, draws=500, include_transformed=True):
def sample(self, draws=500, include_transformed=True, backend='ndarray',
name=None):
"""Draw samples from variational posterior.
Parameters
Expand All @@ -1578,6 +1580,11 @@ def sample(self, draws=500, include_transformed=True):
Number of random samples.
include_transformed : `bool`
If True, transformed variables are also sampled. Default is False.
backend : `str`
Trace backend type to use. Valid entries include: 'ndarray' (default),
'text', 'sqlite', 'hdf5'.
name : `str`
Name for backend (required for non-NDArray backends). Default is None.
Returns
-------
Expand All @@ -1588,8 +1595,10 @@ def sample(self, draws=500, include_transformed=True):
include_transformed=include_transformed)
samples = self.sample_dict_fn(draws) # type: dict
points = ({name: records[i] for name, records in samples.items()} for i in range(draws))
trace = pm.sampling.NDArray(model=self.model, vars=vars_sampled, test_point={
name: records[0] for name, records in samples.items()
_backends = dict(ndarray=NDArray, text=Text, hdf5=HDF5, sqlite=SQLite)

trace = _backends[backend](name=name, model=self.model, vars=vars_sampled, test_point={
name: records[0] for name, records in samples.items()
})
try:
trace.setup(draws=draws, chain=0)
Expand Down

0 comments on commit 8743f60

Please sign in to comment.