Refactor NormalMixture
ricardoV94 committed Feb 28, 2022
1 parent b7840b2 commit 14e06dc
Showing 3 changed files with 48 additions and 64 deletions.
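
This refactor turns NormalMixture from a Mixture subclass into a thin dispatcher: both entry points now normalize tau/sigma/sd through get_tau_sigma and return a Mixture of Normal.dist components, so the user-facing API is unchanged. A minimal sketch of the result, assuming PyMC v4-era imports (variable names are illustrative):

import numpy as np
import pymc as pm

with pm.Model():
    w = pm.Dirichlet("w", a=np.ones(2))
    # After this commit, the call below builds effectively the same graph as
    # pm.Mixture("y", w=w, comp_dists=pm.Normal.dist([0.0, 5.0], sigma=1.0))
    y = pm.NormalMixture("y", w=w, mu=np.array([0.0, 5.0]), sigma=1.0)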
18 changes: 10 additions & 8 deletions pymc/distributions/mixture.py
@@ -24,7 +24,7 @@
 from aesara.tensor.random.op import RandomVariable
 
 from pymc.aesaraf import take_along_axis
-from pymc.distributions.continuous import Normal
+from pymc.distributions.continuous import Normal, get_tau_sigma
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import Discrete, Distribution, SymbolicDistribution
 from pymc.distributions.logprob import logp
@@ -391,7 +391,7 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
     return mix_logp
 
 
-class NormalMixture(Mixture):
+class NormalMixture:
     R"""
     Normal mixture log-likelihood
@@ -446,18 +446,20 @@ class NormalMixture(Mixture):
         pm.NormalMixture("y", w=weights, mu=μ, sigma=σ, observed=data)
     """
 
-    def __init__(self, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), *args, **kwargs):
+    def __new__(cls, name, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), **kwargs):
         if sd is not None:
             sigma = sd
+        _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
 
-        self.mu = mu = at.as_tensor_variable(mu)
-        self.sigma = self.sd = sigma = at.as_tensor_variable(sigma)
+        return Mixture(name, w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs)
 
-        super().__init__(w, Normal.dist(mu, sigma=sigma, shape=comp_shape), *args, **kwargs)
+    @classmethod
+    def dist(cls, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), **kwargs):
+        if sd is not None:
+            sigma = sd
+        _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
 
-    def _distr_parameters_for_repr(self):
-        return ["w", "mu", "sigma"]
+        return Mixture.dist(w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs)
 
 
 class MixtureSameFamily(Distribution):
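A sketch of the two entry points above, assuming the v4 pm namespace: NormalMixture(name, ...) registers a named model variable via Mixture, while NormalMixture.dist(...) returns an unnamed variable. Both routes reconcile precision and scale through get_tau_sigma, so tau=4.0 and sigma=0.5 parametrize the same components:

import numpy as np
import pymc as pm

w = np.array([0.5, 0.5])
mu = np.array([-1.0, 1.0])

# Equivalent parametrizations: get_tau_sigma converts tau into sigma = tau**-0.5
a = pm.NormalMixture.dist(w=w, mu=mu, tau=4.0)
b = pm.NormalMixture.dist(w=w, mu=mu, sigma=0.5)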
3 changes: 2 additions & 1 deletion pymc/tests/test_distributions_random.py
@@ -73,6 +73,7 @@ def pymc_random(
    fails=10,
    extra_args=None,
    model_args=None,
+    change_rv_size_fn=change_rv_size,
 ):
    if valuedomain is None:
        valuedomain = Domain([0], edges=(None, None))
@@ -81,7 +82,7 @@ def pymc_random(
         model_args = {}
 
     model, param_vars = build_model(dist, valuedomain, paramdomains, extra_args)
-    model_dist = change_rv_size(model.named_vars["value"], size, expand=True)
+    model_dist = change_rv_size_fn(model.named_vars["value"], size, expand=True)
     pymc_rand = aesara.function([], model_dist)
 
     domains = paramdomains.copy()
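The new change_rv_size_fn hook lets callers swap in a distribution-specific resize function; the call shape is taken from the diff above, while the import paths and rationale below are assumptions based on the v4-era codebase. Presumably Mixture produces a symbolic graph rather than a single RandomVariable, so the generic helper cannot resize it:

import numpy as np
from pymc import Mixture, Normal
from pymc.aesaraf import change_rv_size

# Ordinary RandomVariable: the generic helper works.
x = Normal.dist(0.0, 1.0)
x5 = change_rv_size(x, 5, expand=True)

# Mixture variables are resized through their own classmethod, which the
# tests below inject; it is called with the same signature.
m = Mixture.dist(w=np.array([0.5, 0.5]), comp_dists=Normal.dist([0.0, 5.0]))
m5 = Mixture.change_size(m, 5, expand=True)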
91 changes: 36 additions & 55 deletions pymc/tests/test_mixture.py
@@ -569,28 +569,33 @@ def mixmixlogp(value, point):
         assert_allclose(priorlogp + mixmixlogpg.sum(), model.logp(test_point), rtol=rtol)
 
 
-@pytest.mark.xfail(reason="NormalMixture not refactored yet")
 class TestNormalMixture(SeededTest):
-    @classmethod
-    def setup_class(cls):
-        TestMixture.setup_class()
+    def test_normal_mixture_sampling(self):
+        norm_w = np.array([0.75, 0.25])
+        norm_mu = np.array([0.0, 5.0])
+        norm_sd = np.ones_like(norm_mu)
+        norm_x = generate_normal_mixture_data(norm_w, norm_mu, norm_sd, size=1000)
 
-    def test_normal_mixture(self):
         with Model() as model:
-            w = Dirichlet("w", floatX(np.ones_like(self.norm_w)), shape=self.norm_w.size)
-            mu = Normal("mu", 0.0, 10.0, shape=self.norm_w.size)
-            tau = Gamma("tau", 1.0, 1.0, shape=self.norm_w.size)
-            NormalMixture("x_obs", w, mu, tau=tau, observed=self.norm_x)
+            w = Dirichlet("w", floatX(np.ones_like(norm_w)), shape=norm_w.size)
+            mu = Normal("mu", 0.0, 10.0, shape=norm_w.size)
+            tau = Gamma("tau", 1.0, 1.0, shape=norm_w.size)
+            NormalMixture("x_obs", w, mu, tau=tau, observed=norm_x)
             step = Metropolis()
-            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False, chains=1)
+            trace = sample(
+                5000,
+                step,
+                random_seed=self.random_seed,
+                progressbar=False,
+                chains=1,
+                return_inferencedata=False,
+            )
 
-        assert_allclose(np.sort(trace["w"].mean(axis=0)), np.sort(self.norm_w), rtol=0.1, atol=0.1)
-        assert_allclose(
-            np.sort(trace["mu"].mean(axis=0)), np.sort(self.norm_mu), rtol=0.1, atol=0.1
-        )
+        assert_allclose(np.sort(trace["w"].mean(axis=0)), np.sort(norm_w), rtol=0.1, atol=0.1)
+        assert_allclose(np.sort(trace["mu"].mean(axis=0)), np.sort(norm_mu), rtol=0.1, atol=0.1)
 
     @pytest.mark.parametrize(
-        "nd,ncomp", [(tuple(), 5), (1, 5), (3, 5), ((3, 3), 5), (3, 3), ((3, 3), 3)], ids=str
+        "nd, ncomp", [(tuple(), 5), (1, 5), (3, 5), ((3, 3), 5), (3, 3), ((3, 3), 3)], ids=str
     )
     def test_normal_mixture_nd(self, nd, ncomp):
         nd = to_tuple(nd)
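
generate_normal_mixture_data is a test helper defined elsewhere in test_mixture.py and not shown in this diff; the following is a hypothetical reconstruction of its behavior, for orientation only:

import numpy as np

def generate_normal_mixture_data(w, mu, sigma, size=1000):
    # Hypothetical sketch: draw one component index per observation with
    # probabilities w, then sample from that component's Normal.
    component = np.random.choice(w.size, size=size, p=w)
    return np.random.normal(mu[component], sigma[component])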
@@ -608,7 +613,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
             ws = Dirichlet("ws", np.ones(ncomp), shape=(ncomp,))
             mixture0 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd, comp_shape=comp_shape)
             obs0 = NormalMixture(
-                "obs", w=ws, mu=mus, tau=taus, shape=nd, comp_shape=comp_shape, observed=observed
+                "obs", w=ws, mu=mus, tau=taus, comp_shape=comp_shape, observed=observed
             )
 
         with Model() as model1:
@@ -619,53 +624,27 @@ def test_normal_mixture_nd(self, nd, ncomp):
                 Normal.dist(mu=mus[..., i], tau=taus[..., i], shape=nd) for i in range(ncomp)
             ]
             mixture1 = Mixture("m", w=ws, comp_dists=comp_dist, shape=nd)
-            obs1 = Mixture("obs", w=ws, comp_dists=comp_dist, shape=nd, observed=observed)
+            obs1 = Mixture("obs", w=ws, comp_dists=comp_dist, observed=observed)
 
         with Model() as model2:
-            # Expected to fail if comp_shape is not provided,
-            # nd is multidim and it does not broadcast with ncomp. If by chance
-            # it does broadcast, an error is raised if the mixture is given
-            # observed data.
-            # Furthermore, the Mixture will also raise errors when the observed
-            # data is multidimensional but it does not broadcast well with
-            # comp_dists.
+            # Test that results are correct without comp_shape being passed to the Mixture.
+            # This used to fail in V3
             mus = Normal("mus", shape=comp_shape)
             taus = Gamma("taus", alpha=1, beta=1, shape=comp_shape)
             ws = Dirichlet("ws", np.ones(ncomp), shape=(ncomp,))
-            if len(nd) > 1:
-                if nd[-1] != ncomp:
-                    with pytest.raises(ValueError):
-                        NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd)
-                    mixture2 = None
-                else:
-                    mixture2 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd)
-            else:
-                mixture2 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd)
-            observed_fails = False
-            if len(nd) >= 1 and nd != (1,):
-                try:
-                    np.broadcast(np.empty(comp_shape), observed)
-                except Exception:
-                    observed_fails = True
-            if observed_fails:
-                with pytest.raises(ValueError):
-                    NormalMixture("obs", w=ws, mu=mus, tau=taus, shape=nd, observed=observed)
-                obs2 = None
-            else:
-                obs2 = NormalMixture("obs", w=ws, mu=mus, tau=taus, shape=nd, observed=observed)
+            mixture2 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd)
+            obs2 = NormalMixture("obs", w=ws, mu=mus, tau=taus, observed=observed)
 
         testpoint = model0.compute_initial_point()
         testpoint["mus"] = test_mus
-        testpoint["taus"] = test_taus
-        assert_allclose(model0.logp(testpoint), model1.logp(testpoint))
-        assert_allclose(mixture0.logp(testpoint), mixture1.logp(testpoint))
-        assert_allclose(obs0.logp(testpoint), obs1.logp(testpoint))
-        if mixture2 is not None and obs2 is not None:
-            assert_allclose(model0.logp(testpoint), model2.logp(testpoint))
-        if mixture2 is not None:
-            assert_allclose(mixture0.logp(testpoint), mixture2.logp(testpoint))
-        if obs2 is not None:
-            assert_allclose(obs0.logp(testpoint), obs2.logp(testpoint))
+        testpoint["taus_log__"] = np.log(test_taus)
+        for logp0, logp1, logp2 in zip(
+            model0.compile_logp(vars=[mixture0, obs0], sum=False)(testpoint),
+            model1.compile_logp(vars=[mixture1, obs1], sum=False)(testpoint),
+            model2.compile_logp(vars=[mixture2, obs2], sum=False)(testpoint),
+        ):
+            assert_allclose(logp0, logp1)
+            assert_allclose(logp0, logp2)
 
     def test_random(self):
         def ref_rand(size, w, mu, sigma):
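
For orientation, the shapes in test_normal_mixture_nd relate as follows; this is inferred from mus being given shape=comp_shape and components being indexed along the last axis with mus[..., i]:

# e.g. for the (nd, ncomp) = ((3, 3), 5) case of the parametrize list above
nd, ncomp = (3, 3), 5
comp_shape = nd + (ncomp,)  # component axis appended last -> (3, 3, 5)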
@@ -682,6 +661,7 @@ def ref_rand(size, w, mu, sigma):
             extra_args={"comp_shape": 2},
             size=1000,
             ref_rand=ref_rand,
+            change_rv_size_fn=Mixture.change_size,
         )
         pymc_random(
             NormalMixture,
@@ -693,6 +673,7 @@ def ref_rand(size, w, mu, sigma):
             extra_args={"comp_shape": 3},
             size=1000,
             ref_rand=ref_rand,
+            change_rv_size_fn=Mixture.change_size,
         )


