WIP: Implement opvi (#1694)
* migrate useful functions from previous PR

(cherry picked from commit 9f61ab4)

* opvi draft

(cherry picked from commit d0997ff)

* made some tests work

(cherry picked from commit b1a87d5)

* refactored approximation to support aevb (without test)

* refactor opvi

delete unnecessary methods from operator, change method order

* change log_q_local computation

* add full rank approximation

* add more_params argument to ObjectiveFunction.updates (aevb case)

* refactor density computation in full rank approximation

* typo: cast dict values to list

* typo: cast dict values to list

* typo: undefined T in dist_math

* refactor gradient scaling as suggested in approximateinference.org/accepted/RoederEtAl2016.pdf

* implement Langevin-Stein (LS) operator

* fix docstring

* add blank line in docs

* refactor ObjectiveFunction

* add not working LS Op test

* experiments with not working LS Op

* change activations

* refactor networks

* add step_function

* remove Langevin Stein, done refactoring

* remove Langevin Stein, done refactoring

* change optimizers

* refactor init params

* implement tests

* implement Inference

* code style

* test fix

* add minibatch test (fails now)

* add more tests for minibatch training

* add logdet to FullRank approximation

* add conversion of arrays to floatX

* tiny changes

* change number of iterations

* fix test and pylint check

* memoize functions in Objective function

* Optimize code a lot

* a bit more efficient pickling

* add docs

* Add MeanField -> FullRank parameter transfer

* refactor MeanField and FullRank a bit

* fix FullRank bug with shapes in random

* refactor Model.flatten (CC @taku-y)

* add `approximate` to inference

* rename approximate->fit

* change abbreviations

* Fix bug with scaling input variable in aevb

* fix theano bottleneck in graph

* more efficient scaling for local vars

* fix typo in local Q

* add aevb test

* refactor memoize to work with my objects

* add tests for numpy view usage

* pickle-hash fix

* pickle-hash fix again

* add node sampling + make up some code

* add notebook with example

* sample_proba explained
ferrine authored and twiecki committed Mar 15, 2017
1 parent 01e5aef commit 4a713dc
Showing 13 changed files with 2,745 additions and 28 deletions.
865 changes: 865 additions & 0 deletions docs/source/notebooks/bayesian_neural_network_opvi-advi.ipynb

Large diffs are not rendered by default.

117 changes: 117 additions & 0 deletions pymc3/distributions/dist_math.py
@@ -9,6 +9,9 @@
import theano.tensor as tt

from .special import gammaln
from ..math import logdet as _logdet

c = - 0.5 * np.log(2 * np.pi)


def bound(logp, *conditions, **kwargs):
@@ -96,3 +99,117 @@ def i1(x):
                     x**9 / 1474560 + x**11 / 176947200 + x**13 / 29727129600,
                     np.e**x / (2 * np.pi * x)**0.5 * (1 - 3 / (8 * x) + 15 / (128 * x**2) + 315 / (3072 * x**3)
                     + 14175 / (98304 * x**4)))


def sd2rho(sd):
    """
    `sd -> rho` theano converter
    :math:`mu + sd*e = mu + log(1+exp(rho))*e`"""
    return tt.log(tt.exp(sd) - 1)


def rho2sd(rho):
    """
    `rho -> sd` theano converter
    :math:`mu + sd*e = mu + log(1+exp(rho))*e`"""
    return tt.log1p(tt.exp(rho))
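Since these are softplus and its inverse, they round-trip exactly. A minimal check (a sketch, not part of the diff; assumes this commit's `dist_math` is importable):

# Sketch: rho2sd inverts sd2rho, since log1p(exp(log(exp(sd) - 1))) == sd.
import numpy as np
from pymc3.distributions.dist_math import sd2rho, rho2sd

sd = np.array([0.1, 1.0, 2.5])
np.testing.assert_allclose(rho2sd(sd2rho(sd)).eval(), sd, rtol=1e-6)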


def log_normal(x, mean, **kwargs):
    """
    Calculate logarithm of normal distribution at point `x`
    with given `mean` and `std`

    Parameters
    ----------
    x : Tensor
        point of evaluation
    mean : Tensor
        mean of normal distribution
    kwargs : one of parameters `{sd, tau, w, rho}`

    Notes
    -----
    There are four variants for density parametrization.
    They are:
        1) standard deviation - `std`
        2) `w`, logarithm of `std` :math:`w = log(std)`
        3) `rho` that follows this equation :math:`rho = log(exp(std) - 1)`
        4) `tau` that follows this equation :math:`tau = std^{-1}`
    ----
    """
    sd = kwargs.get('sd')
    w = kwargs.get('w')
    rho = kwargs.get('rho')
    tau = kwargs.get('tau')
    eps = kwargs.get('eps', 0.0)
    check = sum(map(lambda a: a is not None, [sd, w, rho, tau]))
    if check > 1:
        raise ValueError('more than one required kwarg is passed')
    if check == 0:
        raise ValueError('none of required kwarg is passed')
    if sd is not None:
        std = sd
    elif w is not None:
        std = tt.exp(w)
    elif rho is not None:
        std = rho2sd(rho)
    else:
        std = tau**(-1)
    std += eps
    return c - tt.log(tt.abs_(std)) - (x - mean) ** 2 / (2 * std ** 2)
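Note that `tau` here is defined as `std**-1` (matching the docstring), not the usual precision `std**-2`. A cross-check of the four parametrizations against scipy (an illustrative sketch, not part of the diff):

import numpy as np
from scipy import stats
from pymc3.distributions.dist_math import log_normal

x, mu, sd = 1.5, 0.5, 2.0
expected = stats.norm.logpdf(x, mu, sd)
# all four kwargs should give the same density
np.testing.assert_allclose(log_normal(x, mu, sd=sd).eval(), expected, rtol=1e-6)
np.testing.assert_allclose(log_normal(x, mu, w=np.log(sd)).eval(), expected, rtol=1e-6)
np.testing.assert_allclose(log_normal(x, mu, rho=np.log(np.exp(sd) - 1)).eval(), expected, rtol=1e-6)
np.testing.assert_allclose(log_normal(x, mu, tau=1.0 / sd).eval(), expected, rtol=1e-6)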


def log_normal_mv(x, mean, gpu_compat=False, **kwargs):
    """
    Calculate logarithm of normal distribution at point `x`
    with given `mean` and `sigma` matrix

    Parameters
    ----------
    x : Tensor
        point of evaluation
    mean : Tensor
        mean of normal distribution
    kwargs : one of parameters `{cov, tau, chol}`

    Flags
    ----------
    gpu_compat : False, because LogDet is not GPU compatible yet.
        If this is set as true, the GPU compatible (but numerically unstable) log(det) is used.

    Notes
    -----
    There are three variants for density parametrization.
    They are:
        1) covariance matrix - `cov`
        2) precision matrix - `tau`,
        3) cholesky decomposition matrix - `chol`
    ----
    """
    if gpu_compat:
        def logdet(m):
            return tt.log(tt.abs_(tt.nlinalg.det(m)))
    else:
        logdet = _logdet

    T = kwargs.get('tau')
    S = kwargs.get('cov')
    L = kwargs.get('chol')
    check = sum(map(lambda a: a is not None, [T, S, L]))
    if check > 1:
        raise ValueError('more than one required kwarg is passed')
    if check == 0:
        raise ValueError('none of required kwarg is passed')
    # avoid unnecessary computations
    if L is not None:
        S = L.dot(L.T)
        T = tt.nlinalg.matrix_inverse(S)
        log_det = -logdet(S)
    elif T is not None:
        log_det = logdet(T)
    else:
        T = tt.nlinalg.matrix_inverse(S)
        log_det = -logdet(S)
    delta = x - mean
    k = S.shape[0]
    result = k * tt.log(2 * np.pi) - log_det
    result += delta.dot(T).dot(delta)
    return -1 / 2. * result
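The three parametrizations should agree with scipy's multivariate density. A sketch (not part of the diff; `x` is wrapped as a Theano tensor so `.dot` dispatches through Theano rather than numpy):

import numpy as np
import theano.tensor as tt
from scipy import stats
from pymc3.distributions.dist_math import log_normal_mv

mean = np.array([1., 0., -1.])
cov = np.diag([1., 2., 3.])
x = tt.as_tensor_variable(np.zeros(3))
expected = stats.multivariate_normal.logpdf(np.zeros(3), mean, cov)
np.testing.assert_allclose(log_normal_mv(x, mean, cov=cov).eval(), expected, rtol=1e-6)
np.testing.assert_allclose(log_normal_mv(x, mean, tau=np.linalg.inv(cov)).eval(), expected, rtol=1e-6)
np.testing.assert_allclose(log_normal_mv(x, mean, chol=np.linalg.cholesky(cov)).eval(), expected, rtol=1e-6)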
6 changes: 5 additions & 1 deletion pymc3/math.py
@@ -37,13 +37,17 @@ def logit(p):
    return tt.log(p / (1 - p))


+def flatten_list(tensors):
+    return tt.concatenate([var.ravel() for var in tensors])
+
+
class LogDet(Op):
    """Computes the logarithm of absolute determinant of a square
    matrix M, log(abs(det(M))), on CPU. Avoids det(M) overflow/
    underflow.
    Note: Once PR #3959 (https://github.com/Theano/Theano/pull/3959/) by harpone is merged,
-    this must be removed.
+    this must be removed.
    """
    def make_node(self, x):
        x = theano.tensor.as_tensor_variable(x)
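For orientation (a sketch, not part of the diff): `flatten_list` concatenates raveled tensors into one vector, and `logdet` is the stable log-determinant Op that `log_normal_mv` relies on.

import numpy as np
import theano.tensor as tt
from pymc3.math import flatten_list, logdet

a = tt.as_tensor_variable(np.ones((2, 3)))
b = tt.as_tensor_variable(np.arange(4.0))
assert flatten_list([a, b]).eval().shape == (10,)  # 2*3 + 4 entries raveled into one vector

m = np.array([[2.0, 0.0], [0.0, 3.0]])
np.testing.assert_allclose(logdet(m).eval(), np.log(6.0))  # log|det(m)| without overflow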
17 changes: 13 additions & 4 deletions pymc3/memoize.py
@@ -1,4 +1,5 @@
import functools
+import pickle


def memoize(obj):
@@ -23,8 +24,16 @@ def hashable(a):
    Turn some unhashable objects into hashable ones.
    """
    if isinstance(a, dict):
-        return hashable(a.items())
+        return hashable(tuple((hashable(a1), hashable(a2)) for a1, a2 in a.items()))
    try:
-        return tuple(map(hashable, a))
-    except:
-        return a
+        return hash(a)
+    except TypeError:
+        pass
+    # Not hashable >>>
+    try:
+        return hash(pickle.dumps(a))
+    except Exception:
+        if hasattr(a, '__dict__'):
+            return hashable(a.__dict__)
+        else:
+            return id(a)
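The net effect: `hashable` now tries `hash` directly, falls back to a pickle digest, then to the object's `__dict__`, and finally to `id`, so dicts and numpy arrays can serve as memoization keys. A small illustration (a sketch, not from the diff):

import numpy as np
from pymc3.memoize import hashable

# a dict hashes via a tuple of hashed (key, value) pairs
print(hashable({'mu': 0.0, 'sd': 1.0}))
# a numpy array is unhashable, so it falls through to hash(pickle.dumps(...))
print(hashable(np.zeros(3)))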
40 changes: 36 additions & 4 deletions pymc3/model.py
@@ -1,14 +1,15 @@
import collections
import threading
+import six

import numpy as np
import scipy.sparse as sps
-import theano
-import theano.tensor as tt
import theano.sparse as sparse
+from theano import theano, tensor as tt
from theano.tensor.var import TensorVariable

import pymc3 as pm
+from pymc3.math import flatten_list
from .memoize import memoize
from .theanof import gradient, hessian, inputvars, generator
from .vartypes import typefilter, discrete_types, continuous_types, isgenerator
@@ -19,6 +20,8 @@
    'Point', 'Deterministic', 'Potential'
]

+FlatView = collections.namedtuple('FlatView', 'input, replacements, view')
+

class InstanceMethod(object):
    """Class for hiding references to instance methods so they can be pickled.
@@ -172,8 +175,10 @@ def fastd2logp(self, vars=None):
    @property
    def logpt(self):
        """Theano scalar of log-probability of the model"""
-
-        return tt.sum(self.logp_elemwiset) * self.scaling
+        if getattr(self, 'total_size', None) is not None:
+            return tt.sum(self.logp_elemwiset) * self.scaling
+        else:
+            return tt.sum(self.logp_elemwiset)

    @property
    def scaling(self):
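The `total_size` guard is what makes minibatch training work: when a variable carries a `total_size`, its summed log-probability is multiplied by `self.scaling` so a minibatch estimates the full-data likelihood; otherwise the plain sum is returned. Schematically (a sketch of the rule, assuming `scaling` is the total-to-minibatch ratio):

# hypothetical numbers: 10000 data points seen 100 at a time
total_size, batch_size = 10000, 100
scaling = total_size / float(batch_size)   # 100.0: each minibatch point stands for 100
# logpt ~= tt.sum(logp_elemwiset_on_minibatch) * scaling   (total_size set)
# logpt  = tt.sum(logp_elemwiset)                          (otherwise)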
@@ -659,6 +664,33 @@ def profile(self, outs, n=1000, point=None, profile=True, *args, **kwargs):

        return f.profile

    def flatten(self, vars=None):
        """Flattens model's input and returns:

        FlatView with
            * input vector variable
            * replacements `input_var -> vars`
            * view {variable: VarMap}

        Parameters
        ----------
        vars : list of variables or None
            if None, then all model.free_RVs are used for flattening input

        Returns
        -------
        flat_view
        """
        if vars is None:
            vars = self.free_RVs
        order = ArrayOrdering(vars)
        inputvar = tt.vector('flat_view', dtype=theano.config.floatX)
        inputvar.tag.test_value = flatten_list(vars).tag.test_value
        replacements = {self.named_vars[name]: inputvar[slc].reshape(shape).astype(dtype)
                        for name, slc, shape, dtype in order.vmap}
        view = {vm.var: vm for vm in order.vmap}
        flat_view = FlatView(inputvar, replacements, view)
        return flat_view


def fn(outs, mode=None, model=None, *args, **kwargs):
    """Compiles a Theano function which returns the values of `outs` and
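A usage sketch for the new method with a toy two-parameter model (hypothetical example, not from the diff): the returned replacements let any expression over the free RVs be rewired onto a single flat vector, which is what opvi needs to parametrize an approximation over all free RVs at once.

import numpy as np
import theano
import pymc3 as pm

with pm.Model() as model:
    mu = pm.Normal('mu', mu=0, sd=1)
    sd = pm.HalfNormal('sd', sd=1)

flat = model.flatten()
# rewire the model's joint log-probability onto the flat input vector
logp_flat = theano.clone(model.logpt, flat.replacements)
f = theano.function([flat.input], logp_flat)
print(f(np.zeros(2, dtype=theano.config.floatX)))  # logp at mu=0, log(sd)=0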
9 changes: 8 additions & 1 deletion pymc3/tests/helpers.py
@@ -1,6 +1,8 @@
import unittest
-import numpy.random as nr
from logging.handlers import BufferingHandler
+import numpy.random as nr
+from theano.sandbox.rng_mrg import MRG_RandomStreams
+from ..theanof import set_tt_rng, tt_rng


class SeededTest(unittest.TestCase):
@@ -12,6 +14,11 @@ def setUpClass(cls):

    def setUp(self):
        nr.seed(self.random_seed)
+        self.old_tt_rng = tt_rng()
+        set_tt_rng(MRG_RandomStreams(self.random_seed))
+
+    def tearDown(self):
+        set_tt_rng(self.old_tt_rng)

class TestHandler(BufferingHandler):
    def __init__(self, matcher):
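Isolated, the seeding pattern the fixture adds looks like this (a sketch using the same `pymc3.theanof` helpers the diff imports):

from theano.sandbox.rng_mrg import MRG_RandomStreams
from pymc3.theanof import set_tt_rng, tt_rng

old_rng = tt_rng()
set_tt_rng(MRG_RandomStreams(42))  # Theano-side draws now reproducible
try:
    pass  # ... run code that samples through pymc3's shared Theano RNG ...
finally:
    set_tt_rng(old_rng)  # restore, as tearDown does above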
1 change: 1 addition & 0 deletions pymc3/tests/test_examples.py
@@ -234,6 +234,7 @@ class TestLatentOccupancy(SeededTest):
    Copyright (c) 2008 University of Otago. All rights reserved.
    """
    def setUp(self):
+        super(TestLatentOccupancy, self).setUp()
        # Sample size
        n = 100
        # True mean count, given occupancy
4 changes: 3 additions & 1 deletion pymc3/tests/test_math.py
@@ -4,13 +4,15 @@
from pymc3.math import LogDet, logdet, probit, invprobit
from .helpers import SeededTest


def test_probit():
    p = np.array([0.01, 0.25, 0.5, 0.75, 0.99])
    np.testing.assert_allclose(invprobit(probit(p)).eval(), p, atol=1e-5)

-class TestLogDet(SeededTest):
+
+class TestLogDet(SeededTest):
    def setUp(self):
+        super(TestLogDet, self).setUp()
        utt.seed_rng()
        self.op_class = LogDet
        self.op = logdet
