diff --git a/pymc3/distributions/mixture.py b/pymc3/distributions/mixture.py index 215667ec1b..b0640542d6 100644 --- a/pymc3/distributions/mixture.py +++ b/pymc3/distributions/mixture.py @@ -138,7 +138,10 @@ def random_choice(*args, **kwargs): comp_samples = self._comp_samples(point=point, size=size, repeat=repeat) if comp_samples.ndim > 1: - return np.squeeze(comp_samples[np.arange(w_samples.size), w_samples]) + row_ix = (np.arange(w_samples.shape[0]) + .reshape([w_samples.shape[0]] + [1 for _ in w_samples.shape[1:]])) + + return np.squeeze(comp_samples[row_ix, w_samples]) else: return np.squeeze(comp_samples[w_samples])