Skip to content

Commit

Permalink
migrate useful functions from previous PR
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Jan 22, 2017
1 parent 424bec0 commit 9f61ab4
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 4 deletions.
57 changes: 57 additions & 0 deletions pymc3/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,60 @@ def i1(x):
x**9 / 1474560 + x**11 / 176947200 + x**13 / 29727129600,
np.e**x / (2 * np.pi * x)**0.5 * (1 - 3 / (8 * x) + 15 / (128 * x**2) + 315 / (3072 * x**3)
+ 14175 / (98304 * x**4)))


def sd2rho(sd):
"""
`sd -> rho` theano converter
:math:`mu + sd*e = mu + log(1+exp(rho))*e`"""
return tt.log(tt.exp(sd) - 1)


def rho2sd(rho):
"""
`rho -> sd` theano converter
:math:`mu + sd*e = mu + log(1+exp(rho))*e`"""
return tt.log1p(tt.exp(rho))


def log_normal(x, mean, **kwargs):
"""
Calculate logarithm of normal distribution at point `x`
with given `mean` and `std`
Parameters
----------
x : Tensor
point of evaluation
mean : Tensor
mean of normal distribution
kwargs : one of parameters `{sd, tau, w, rho}`
Notes
-----
There are four variants for density parametrization.
They are:
1) standard deviation - `std`
2) `w`, logarithm of `std` :math:`w = log(std)`
3) `rho` that follows this equation :math:`rho = log(exp(std) - 1)`
4) `tau` that follows this equation :math:`tau = std^{-1}`
----
"""
sd = kwargs.get('sd')
w = kwargs.get('w')
rho = kwargs.get('rho')
tau = kwargs.get('tau')
eps = kwargs.get('eps', 0.0)
check = sum(map(lambda a: a is not None, [sd, w, rho, tau]))
if check > 1:
raise ValueError('more than one required kwarg is passed')
if check == 0:
raise ValueError('none of required kwarg is passed')
if sd is not None:
std = sd
elif w is not None:
std = tt.exp(w)
elif rho is not None:
std = rho2sd(rho)
else:
std = tau**(-1)
std += eps
return c - tt.log(tt.abs_(std)) - (x - mean) ** 2 / (2 * std ** 2)
3 changes: 3 additions & 0 deletions pymc3/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ def invlogit(x, eps=sys.float_info.epsilon):

def logit(p):
return tt.log(p / (1 - p))

def flatten_list(tensors):
return tt.concatenate([var.ravel() for var in tensors])
54 changes: 50 additions & 4 deletions pymc3/theanof.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,24 @@
from theano import theano, scalar, tensor as tt
from theano.gof.graph import inputs
from theano.tensor import TensorVariable
from theano.sandbox.rng_mrg import MRG_RandomStreams
from theano.gof import Op
from theano.configparser import change_flags
from .memoize import memoize
from .blocking import ArrayOrdering
from .data import DataGenerator

__all__ = ['gradient', 'hessian', 'hessian_diag', 'inputvars',
'cont_inputs', 'floatX', 'jacobian',
'CallableTensor', 'join_nonshared_inputs',
'make_shared_replacements', 'generator']
__all__ = ['gradient',
'hessian',
'hessian_diag',
'inputvars',
'cont_inputs',
'floatX',
'jacobian',
'CallableTensor',
'join_nonshared_inputs',
'make_shared_replacements',
'generator']


def inputvars(a):
Expand Down Expand Up @@ -307,3 +315,41 @@ def set_gen(self, gen):
def generator(gen):
"""shortcut for `GeneratorOp`"""
return GeneratorOp(gen)()

@change_flags(compute_test_value='off')
def launch_rng(rng):
"""Helper function for safe launch of rng.
If not launched, there will be problems with test_value
Parameters
----------
rng : theano.sandbox.rng_mrg.MRG_RandomStreams` instance
"""
state = rng.rstate
rng.inc_rstate()
rng.set_rstate(state)

_tt_rng = MRG_RandomStreams()
launch_rng(_tt_rng)


def tt_rng():
"""Get the package-level random number generator.
Returns
-------
`theano.sandbox.rng_mrg.MRG_RandomStreams` instance
`theano.sandbox.rng_mrg.MRG_RandomStreams`
instance passed to the most recent call of `set_tt_rng`
"""
return _tt_rng


def set_tt_rng(new_rng):
"""Set the package-level random number generator.
Parameters
----------
new_rng : `theano.sandbox.rng_mrg.MRG_RandomStreams` instance
The random number generator to use.
"""
global _tt_rng
_tt_rng = new_rng
launch_rng(_tt_rng)

0 comments on commit 9f61ab4

Please sign in to comment.