Skip to content

Commit

Permalink
Another attempt to remove whitespace
Browse files Browse the repository at this point in the history
  • Loading branch information
fonnesbeck committed May 23, 2017
1 parent 73e2773 commit ad63768
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 41 deletions.
10 changes: 5 additions & 5 deletions pymc3/distributions/continuous.py
Expand Up @@ -1592,22 +1592,22 @@ class Interpolated(Continuous):
R"""
Univariate probability distribution defined as a linear interpolation
of probability density function evaluated on some lattice of points.
The lattice can be uneven, so the steps between different points can have
different size and it is possible to vary the precision between regions
of the support.
The probability density function values do not have to be normalized, as the
interpolated density is normalized anyway to make the total probability
equal to $1$.
Both parameters `x_points` and values `pdf_points` are not variables, but
plain array-like objects, so they are constant and cannot be sampled.
======== ===========================================
Support :math:`x \in [x\_points[0], x\_points[-1]]`
======== ===========================================
Parameters
----------
x_points : array-like
Expand Down
64 changes: 32 additions & 32 deletions pymc3/distributions/mixture.py
Expand Up @@ -20,16 +20,16 @@ def all_discrete(comp_dists):
class Mixture(Distribution):
R"""
Mixture log-likelihood
Often used to model subpopulation heterogeneity
.. math:: f(x \mid w, \theta) = \sum_{i = 1}^n w_i f_i(x \mid \theta_i)
======== ============================================
Support :math:`\cap_{i = 1}^n \textrm{support}(f_i)`
Mean :math:`\sum_{i = 1}^n w_i \mu_i`
======== ============================================
Parameters
----------
w : array of floats
Expand All @@ -40,98 +40,98 @@ class Mixture(Distribution):
"""
def __init__(self, w, comp_dists, *args, **kwargs):
    # Store the mixture weights and the component distributions as given.
    # `w` is expected to broadcast against the component axis (last axis)
    # — TODO confirm against callers.
    shape = kwargs.pop('shape', ())

    self.w = w
    self.comp_dists = comp_dists

    defaults = kwargs.pop('defaults', [])

    # A mixture of exclusively discrete components is itself discrete;
    # otherwise fall back to a float dtype.
    if all_discrete(comp_dists):
        dtype = kwargs.pop('dtype', 'int64')
    else:
        dtype = kwargs.pop('dtype', 'float64')

    try:
        # Mixture mean is the weighted sum of the component means.
        # If any component lacks a `.mean` attribute, the AttributeError
        # skips registering 'mean' as a default test value.
        self.mean = (w * self._comp_means()).sum(axis=-1)

        if 'mean' not in defaults:
            defaults.append('mean')
    except AttributeError:
        pass

    try:
        # Approximate the mixture mode by the component mode with the
        # highest weighted log-probability; skipped entirely when any
        # component lacks a `.mode` attribute.
        comp_modes = self._comp_modes()
        comp_mode_logps = self.logp(comp_modes)
        self.mode = comp_modes[tt.argmax(w * comp_mode_logps, axis=-1)]

        if 'mode' not in defaults:
            defaults.append('mode')
    except AttributeError:
        pass

    super(Mixture, self).__init__(shape, dtype, defaults=defaults,
                                  *args, **kwargs)

def _comp_logp(self, value):
    """Return per-component log-probabilities of ``value``.

    If ``self.comp_dists`` is a single distribution object, a trailing
    axis is padded onto 1-d values and the vectorized ``logp`` is used;
    otherwise each component's ``logp`` is evaluated separately and the
    results are stacked along a new component axis.
    """
    dists = self.comp_dists

    try:
        # Keep both `value.ndim` and `dists.logp` inside the try: an
        # AttributeError from either selects the list-of-components path.
        padded = tt.shape_padright(value) if value.ndim <= 1 else value

        return dists.logp(padded)
    except AttributeError:
        per_component = [d.logp(value) for d in dists]
        return tt.stack(per_component, axis=1)

def _comp_means(self):
    """Return the component means stacked along a component axis.

    Works both when ``self.comp_dists`` is a single distribution object
    exposing a vectorized ``mean`` and when it is an iterable of
    component distributions.
    """
    try:
        return tt.as_tensor_variable(self.comp_dists.mean)
    except AttributeError:
        # Iterable of components: collect each mean and stack them.
        means = [d.mean for d in self.comp_dists]
        return tt.stack(means, axis=1)

def _comp_modes(self):
    """Return the component modes stacked along a component axis.

    Mirrors ``_comp_means``: a single distribution object provides a
    vectorized ``mode``; otherwise each component's mode is stacked.
    """
    try:
        return tt.as_tensor_variable(self.comp_dists.mode)
    except AttributeError:
        # Iterable of components: collect each mode and stack them.
        modes = [d.mode for d in self.comp_dists]
        return tt.stack(modes, axis=1)

def _comp_samples(self, point=None, size=None, repeat=None):
try:
samples = self.comp_dists.random(point=point, size=size, repeat=repeat)
except AttributeError:
samples = np.column_stack([comp_dist.random(point=point, size=size, repeat=repeat)
for comp_dist in self.comp_dists])

return np.squeeze(samples)

def logp(self, value):
    """Mixture log-likelihood of ``value``.

    Computes log(sum_i w_i f_i(value)) stably via ``logsumexp`` over the
    component axis, summed over all observations, subject to the weight
    constraints enforced through ``bound``.
    """
    w = self.w

    component_logps = self._comp_logp(value)
    total = logsumexp(tt.log(w) + component_logps, axis=-1).sum()

    # Weights must lie in [0, 1] and sum to one along the last axis.
    return bound(total,
                 w >= 0, w <= 1,
                 tt.allclose(w.sum(axis=-1), 1))

def random(self, point=None, size=None, repeat=None):
def random_choice(*args, **kwargs):
w = kwargs.pop('w')
w /= w.sum(axis=-1, keepdims=True)
k = w.shape[-1]

if w.ndim > 1:
return np.row_stack([np.random.choice(k, p=w_) for w_ in w])
else:
return np.random.choice(k, p=w, *args, **kwargs)

w = draw_values([self.w], point=point)

w_samples = generate_samples(random_choice,
w=w,
broadcast_shape=w.shape[:-1] or (1,),
dist_shape=self.shape,
size=size).squeeze()
comp_samples = self._comp_samples(point=point, size=size, repeat=repeat)

if comp_samples.ndim > 1:
return np.squeeze(comp_samples[np.arange(w_samples.size), w_samples])
else:
Expand All @@ -141,17 +141,17 @@ def random_choice(*args, **kwargs):
class NormalMixture(Mixture):
R"""
Normal mixture log-likelihood
.. math::
f(x \mid w, \mu, \sigma^2) = \sum_{i = 1}^n w_i N(x \mid \mu_i, \sigma^2_i)
======== =======================================
Support :math:`x \in \mathbb{R}`
Mean :math:`\sum_{i = 1}^n w_i \mu_i`
Variance :math:`\sum_{i = 1}^n w_i^2 \sigma^2_i`
======== =======================================
Parameters
----------
w : array of floats
Expand All @@ -167,7 +167,7 @@ class NormalMixture(Mixture):
def __init__(self, w, mu, *args, **kwargs):
    """Build a mixture of Normal components with weights ``w`` and
    means ``mu``.

    Exactly one of the keyword arguments ``tau`` or ``sd`` parameterizes
    the component scales; ``get_tau_sd`` converts whichever was given
    into a standard deviation.
    """
    # Pop in the same order as before: tau first, then sd.
    tau = kwargs.pop('tau', None)
    sd = kwargs.pop('sd', None)
    _, sd = get_tau_sd(tau=tau, sd=sd)

    super(NormalMixture, self).__init__(w, Normal.dist(mu, sd=sd),
                                        *args, **kwargs)

Expand All @@ -180,4 +180,4 @@ def _repr_latex_(self, name=None, dist=None):
return r'${} \sim \text{{NormalMixture}}(\mathit{{w}}={}, \mathit{{mu}}={}, \mathit{{sigma}}={})$'.format(name,
get_variable_name(w),
get_variable_name(mu),
get_variable_name(sigma))
get_variable_name(sigma))
6 changes: 3 additions & 3 deletions pymc3/distributions/multivariate.py
Expand Up @@ -268,15 +268,15 @@ class MvStudentT(Continuous):
({\mathbf x}-{\mu})^T
{\Sigma}^{-1}({\mathbf x}-{\mu})
\right]^{(\nu+p)/2}}
======== =============================================
Support :math:`x \in \mathbb{R}^k`
Mean :math:`\mu` if :math:`\nu > 1` else undefined
Variance :math:`\frac{\nu}{\mu-2}\Sigma`
if :math:`\nu>2` else undefined
======== =============================================
Parameters
----------
Expand Down Expand Up @@ -926,7 +926,7 @@ class LKJCorr(Continuous):
[- - - 7 8]
[- - - - 9]
[- - - - -]]
References
----------
Expand Down
1 change: 0 additions & 1 deletion pymc3/distributions/timeseries.py
Expand Up @@ -296,4 +296,3 @@ def _repr_latex_(self, name=None, dist=None):
get_variable_name(nu),
get_variable_name(mu),
get_variable_name(cov))

0 comments on commit ad63768

Please sign in to comment.