
Commit af7ea76

Fixed mixture random method (and also the multivariate normal when using a Cholesky covariance matrix).
Parent: 6da77ac

4 files changed: 33 additions & 40 deletions

pymc3/distributions/distribution.py

Lines changed: 23 additions & 34 deletions
@@ -262,27 +262,6 @@ def __init__(self):
         self.drawn_vars = dict()
 
 
-class _DrawValuesContextDetacher(_DrawValuesContext,
-                                 metaclass=InitContextMeta):
-    """
-    Context manager that starts a new drawn variables context copying the
-    parent's context drawn_vars dict. The following changes do not affect the
-    parent contexts but do affect the subsequent calls. This can be used to
-    iterate the same random method many times to get different results, while
-    respecting the drawn variables from previous contexts.
-    """
-    def __new__(cls, *args, **kwargs):
-        return super().__new__(cls)
-
-    def __init__(self):
-        self.drawn_vars = self.drawn_vars.copy()
-
-    def update_parent(self):
-        parent = self.parent
-        if parent is not None:
-            parent.drawn_vars.update(self.drawn_vars)
-
-
 def is_fast_drawable(var):
     return isinstance(var, (numbers.Number,
                             np.ndarray,
@@ -660,20 +639,30 @@ def generate_samples(generator, *args, **kwargs):
         samples = generator(size=broadcast_shape, *args, **kwargs)
     elif dist_shape == broadcast_shape:
         samples = generator(size=size_tup + dist_shape, *args, **kwargs)
-    elif len(dist_shape) == 0 and size_tup and broadcast_shape[:len(size_tup)] == size_tup:
-        # Input's dist_shape is scalar, but it has size repetitions.
-        # So now the size matches but we have to manually broadcast to
-        # the right dist_shape
-        samples = [generator(*args, **kwargs)]
-        if samples[0].shape == broadcast_shape:
-            samples = samples[0]
+    elif len(dist_shape) == 0 and size_tup and broadcast_shape:
+        # There is no dist_shape (scalar distribution) but the parameters
+        # broadcast shape and size_tup determine the size to provide to
+        # the generator
+        if broadcast_shape[:len(size_tup)] == size_tup:
+            # Input's dist_shape is scalar, but it has size repetitions.
+            # So now the size matches but we have to manually broadcast to
+            # the right dist_shape
+            samples = [generator(*args, **kwargs)]
+            if samples[0].shape == broadcast_shape:
+                samples = samples[0]
+            else:
+                suffix = broadcast_shape[len(size_tup):] + dist_shape
+                samples.extend([generator(*args, **kwargs).
+                                reshape(broadcast_shape)[..., np.newaxis]
+                                for _ in range(np.prod(suffix,
+                                                       dtype=int) - 1)])
+                samples = np.hstack(samples).reshape(size_tup + suffix)
         else:
-            suffix = broadcast_shape[len(size_tup):] + dist_shape
-            samples.extend([generator(*args, **kwargs).
-                            reshape(broadcast_shape)[..., np.newaxis]
-                            for _ in range(np.prod(suffix,
-                                                   dtype=int) - 1)])
-            samples = np.hstack(samples).reshape(size_tup + suffix)
+            # The parameter shape is given, but we have to concatenate it
+            # with the size tuple
+            samples = generator(size=size_tup + broadcast_shape,
+                                *args,
+                                **kwargs)
     else:
         samples = None
     # Args have been broadcast correctly, can just ask for the right shape out
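
To make the restructured branching concrete, here is a small standalone sketch (not from the commit) of the shape arithmetic the new else branch performs. The shapes are made up for illustration, and numpy's normal sampler stands in for the wrapped generator:

import numpy as np

# Hypothetical case: a scalar distribution (dist_shape == ()) whose
# parameters broadcast to (3,), with a requested size of (1000,).
size_tup = (1000,)
broadcast_shape = (3,)

# broadcast_shape[:len(size_tup)] == (3,) != size_tup, so the fixed code
# falls through to the new else branch and simply concatenates the size
# tuple with the broadcast shape when calling the underlying sampler.
samples = np.random.normal(loc=np.zeros(broadcast_shape),
                           size=size_tup + broadcast_shape)
print(samples.shape)  # (1000, 3)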

pymc3/distributions/mixture.py

Lines changed: 6 additions & 3 deletions
@@ -9,8 +9,7 @@
 from .distribution import (Discrete, Distribution, draw_values,
                            generate_samples, _DrawValuesContext,
                            _DrawValuesContextBlocker, to_tuple,
-                           broadcast_distribution_samples,
-                           _DrawValuesContextDetacher)
+                           broadcast_distribution_samples)
 from .continuous import get_tau_sigma, Normal
 from ..theanof import _conversion_map
 
@@ -192,7 +191,11 @@ def generator(*args, **kwargs):
     # differently from scipy.*.rvs, and generate_samples
     # follows the latter usage pattern. For this reason we
     # decorate (horribly hack) the size kwarg of
-    # comp_dist.random
+    # comp_dist.random. We also have to disable pylint W0640
+    # because comp_dist is changed at each iteration of the
+    # for loop, and this generator function must be defined
+    # for each comp_dist.
+    # pylint: disable=W0640
     if len(args) > 2:
         args[1] = size
     else:
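
The pylint W0640 ("cell-var-from-loop") warning that the new comment silences comes from Python's late-binding closures: a function defined inside a loop captures the loop variable itself, not its value at definition time. A minimal sketch of that behaviour, independent of the pymc3 code:

# Every lambda closes over the same variable, so all of them see its
# final value once the loop has finished.
fns = [lambda: i for i in range(3)]
print([f() for f in fns])      # [2, 2, 2]

# Binding the current value at definition time (e.g. via a default
# argument) avoids the surprise.
fns = [lambda i=i: i for i in range(3)]
print([f() for f in fns])      # [0, 1, 2]

In the mixture code the warning is a false positive: each generator is redefined for its own comp_dist, as the added comment points out.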

pymc3/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
@@ -267,7 +267,7 @@ def random(self, point=None, size=None):
             else:
                 std_norm_shape = mu.shape
             standard_normal = np.random.standard_normal(std_norm_shape)
-            return mu + np.tensordot(standard_normal, chol, axes=[[-1], [-1]])
+            return mu + np.einsum('...ij,...j->...i', chol, standard_normal)
         else:
             mu, tau = draw_values([self.mu, self.tau], point=point, size=size)
             if mu.shape[-1] != tau[0].shape[-1]:
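
A quick numpy check, outside the commit, of why the einsum form matters here: with a single k-by-k Cholesky factor the two expressions agree, but when chol carries batch dimensions (presumably the situation this fix targets) tensordot contracts only the last axes and takes an outer product over the rest, whereas the '...ij,...j->...i' einsum broadcasts the batch and pairs each standard-normal vector with its own factor. The shapes below are invented for illustration:

import numpy as np

rng = np.random.RandomState(0)
k = 3
standard_normal = rng.standard_normal((5, k))    # batch of 5 standard-normal vectors
chol = np.tril(rng.standard_normal((5, k, k)))   # hypothetical batch of Cholesky factors

new = np.einsum('...ij,...j->...i', chol, standard_normal)
old = np.tensordot(standard_normal, chol, axes=[[-1], [-1]])

print(new.shape)  # (5, 3): one transformed vector per batch entry
print(old.shape)  # (5, 5, 3): cross terms between every vector and every factor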

pymc3/tests/test_sampling.py

Lines changed: 3 additions & 2 deletions
@@ -223,8 +223,9 @@ def test_normal_scalar(self):
         ppc = pm.sample_posterior_predictive(trace, samples=1000, vars=[a])
         assert 'a' in ppc
         assert ppc['a'].shape == (1000,)
-        _, pval = stats.kstest(ppc['a'],
-                               stats.norm(loc=0, scale=np.sqrt(2)).cdf)
+        # mu's standard deviation may have changed thanks to a's observed
+        _, pval = stats.kstest(ppc['a'] - trace['mu'],
+                               stats.norm(loc=0, scale=1).cdf)
         assert pval > 0.001
 
         with model:
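
The reasoning behind the updated assertion, assuming (as the original comparison against a scale of sqrt(2) suggests) that a is modelled as Normal(mu, 1) in this test: each posterior predictive draw of a is centred on its own posterior draw of mu, so the residual a - mu is standard normal no matter how the posterior of mu has been sharpened by the observed data. A self-contained check of that fact with synthetic draws (not pymc3 output):

import numpy as np
from scipy import stats

rng = np.random.RandomState(0)
mu = rng.normal(loc=0.3, scale=0.2, size=1000)   # stand-in for trace['mu']
a = rng.normal(loc=mu, scale=1.0)                # stand-in for ppc['a']

# The residuals are standard normal by construction, so the KS test
# should not reject at the 0.001 level used in the test.
_, pval = stats.kstest(a - mu, stats.norm(loc=0, scale=1).cdf)
assert pval > 0.001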
