From 9b818b820b7b8686220a318ff81803d71e2415b2 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Tue, 7 Dec 2021 07:52:37 +0100 Subject: [PATCH 01/10] Remove non-aeppl _logcdf singledispatch --- pymc/distributions/logprob.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py index a0a9eb7e9..5b12dd37e 100644 --- a/pymc/distributions/logprob.py +++ b/pymc/distributions/logprob.py @@ -276,18 +276,6 @@ def logcdf(rv, value): return logcdf_aeppl(rv, value) -@singledispatch -def _logcdf(op, values, *args, **kwargs): - """Create a log-CDF graph. - - This function dispatches on the type of `op`, which should be a subclass - of `RandomVariable`. If you want to implement new log-CDF graphs - for a `RandomVariable`, register a new function on this dispatcher. - - """ - raise NotImplementedError() - - def logpt_sum(*args, **kwargs): """Return the sum of the logp values for the given observations. From 3ed512c38597234b923e52f865b8613149749885 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Tue, 7 Dec 2021 07:52:57 +0100 Subject: [PATCH 02/10] Add deprecation warning to logpt_sum --- pymc/distributions/logprob.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py index 5b12dd37e..0a7ee30ce 100644 --- a/pymc/distributions/logprob.py +++ b/pymc/distributions/logprob.py @@ -282,5 +282,9 @@ def logpt_sum(*args, **kwargs): Subclasses can use this to improve the speed of logp evaluations if only the sum of the logp values is needed. """ - # TODO: Deprecate this + warnings.warn( + "logpt_sum has been deprecated, you can use logpt instead, which now defaults " + "to the same behavior as logpt_sum", + DeprecationWarning, + ) return logpt(*args, sum=True, **kwargs) From fd12b83ed73095d8c9dcad6daf44264ac7239910 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Tue, 7 Dec 2021 11:27:24 +0100 Subject: [PATCH 03/10] Use observation as value_var of observed rvs --- pymc/model.py | 25 ++++++++++++------------- pymc/tests/test_missing.py | 9 ++++----- pymc/tests/test_smc.py | 4 ++-- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index 4dcb20f16..2cff1c589 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1275,7 +1275,7 @@ def make_obs_var( observed_rv_var.tag.observations = nonmissing_data - self.create_value_var(observed_rv_var, transform) + self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data) self.add_random_variable(observed_rv_var, dims) self.observed_RVs.append(observed_rv_var) @@ -1285,22 +1285,21 @@ def make_obs_var( rv_var = at.set_subtensor(rv_var[antimask_idx], observed_rv_var) rv_var = Deterministic(name, rv_var, self, dims, auto=True) - elif sps.issparse(data): - data = sparse.basic.as_sparse(data, name=name) - rv_var.tag.observations = data - self.create_value_var(rv_var, transform) - self.add_random_variable(rv_var, dims) - self.observed_RVs.append(rv_var) else: - data = at.as_tensor_variable(data, name=name) + if sps.issparse(data): + data = sparse.basic.as_sparse(data, name=name) + else: + data = at.as_tensor_variable(data, name=name) rv_var.tag.observations = data - self.create_value_var(rv_var, transform) + self.create_value_var(rv_var, transform=None, value_var=data) self.add_random_variable(rv_var, dims) self.observed_RVs.append(rv_var) return rv_var - def create_value_var(self, rv_var: TensorVariable, transform: Any) -> TensorVariable: + def create_value_var( + self, rv_var:
TensorVariable, transform: Any, value_var: Optional[Variable] = None + ) -> TensorVariable: """Create a ``TensorVariable`` that will be used as the random variable's "value" in log-likelihood graphs. @@ -1311,13 +1310,13 @@ def create_value_var(self, rv_var: TensorVariable, transform: Any) -> TensorVari this branch of the conditional. """ - value_var = rv_var.type() + if value_var is None: + value_var = rv_var.type() + + value_var.name = rv_var.name if aesara.config.compute_test_value != "off": value_var.tag.test_value = rv_var.tag.test_value - value_var.name = rv_var.name - rv_var.tag.value_var = value_var # Make the value variable a transformed value variable, diff --git a/pymc/tests/test_missing.py b/pymc/tests/test_missing.py index 2769b8862..050ba8ef5 100644 --- a/pymc/tests/test_missing.py +++ b/pymc/tests/test_missing.py @@ -19,8 +19,7 @@ from numpy import array, ma -from pymc.distributions.continuous import Gamma, Normal, Uniform -from pymc.distributions.transforms import interval +from pymc.distributions import Gamma, Normal, Uniform from pymc.exceptions import ImputationWarning from pymc.model import Model from pymc.sampling import sample, sample_posterior_predictive, sample_prior_predictive @@ -94,10 +93,10 @@ def test_interval_missing_observations(): with pytest.warns(ImputationWarning): theta2 = Normal("theta2", mu=theta1, observed=obs2, rng=rng) - assert "theta1_observed_interval__" in model.named_vars + assert "theta1_observed" in model.named_vars assert "theta1_missing_interval__" in model.named_vars - assert isinstance( - model.rvs_to_values[model.named_vars["theta1_observed"]].tag.transform, interval + assert not hasattr( + model.rvs_to_values[model.named_vars["theta1_observed"]].tag, "transform" ) prior_trace = sample_prior_predictive(return_inferencedata=False) diff --git a/pymc/tests/test_smc.py b/pymc/tests/test_smc.py index e862018ba..2d4a6e060 100644 --- a/pymc/tests/test_smc.py +++ b/pymc/tests/test_smc.py @@ -409,12 +409,12 @@ def test_multiple_simulators(self): a_val = m.rvs_to_values[a] sim1_val = m.rvs_to_values[sim1] logp_sim1 = pm.logpt(sim1, sim1_val) - logp_sim1_fn = aesara.function([sim1_val, a_val], logp_sim1) + logp_sim1_fn = aesara.function([a_val], logp_sim1) b_val = m.rvs_to_values[b] sim2_val = m.rvs_to_values[sim2] logp_sim2 = pm.logpt(sim2, sim2_val) - logp_sim2_fn = aesara.function([sim2_val, b_val], logp_sim2) + logp_sim2_fn = aesara.function([b_val], logp_sim2) assert any( node for node in logp_sim1_fn.maker.fgraph.toposort() if isinstance(node.op, SortOp) ) From 022cefa3e6b64f73a8e91f927c4db3d7454c4418 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Tue, 7 Dec 2021 11:30:13 +0100 Subject: [PATCH 04/10] Add NotImplementedError for partially observed multivariate variables --- pymc/model.py | 5 +++++ pymc/tests/test_missing.py | 27 ++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/pymc/model.py b/pymc/model.py index 2cff1c589..c6e81ce1c 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1239,6 +1239,11 @@ def make_obs_var( ) warnings.warn(impute_message, ImputationWarning) + if rv_var.owner.op.ndim_supp > 0: + raise NotImplementedError( + f"Automatic imputation is only supported for univariate RandomVariables, but {rv_var} is multivariate" + ) + # We can get a random variable comprised of only the unobserved # entries by lifting the indices through the `RandomVariable` `Op`.
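For orientation, a minimal sketch of the univariate imputation that this new guard protects, assuming the v4-era API exercised by these patches (the `*_observed`/`*_missing` names follow the tests below):

import numpy as np
import pymc as pm

# A masked observation triggers automatic imputation (with an ImputationWarning):
# the RV is split into an observed part and a free "missing" part.
data = np.ma.masked_values([1.0, -1.0, 3.0], value=-1.0)
with pm.Model() as m:
    x = pm.Normal("x", 0.0, 1.0, observed=data)
# m["x_observed"] wraps the two observed entries and m["x_missing"] is a free
# univariate RV for the masked entry; multivariate RVs (ndim_supp > 0) now
# raise NotImplementedError instead.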
diff --git a/pymc/tests/test_missing.py b/pymc/tests/test_missing.py index 050ba8ef5..de5b2a540 100644 --- a/pymc/tests/test_missing.py +++ b/pymc/tests/test_missing.py @@ -19,7 +19,7 @@ from numpy import array, ma -from pymc.distributions import Gamma, Normal, Uniform +from pymc.distributions import Dirichlet, Gamma, Normal, Uniform from pymc.exceptions import ImputationWarning from pymc.model import Model from pymc.sampling import sample, sample_posterior_predictive, sample_prior_predictive @@ -163,3 +163,28 @@ def test_missing_logp(): m_missing_logp = m_missing.logp({"theta1_missing": [2, 4], "theta2_missing": [0, 1, 3]}) assert m_logp == m_missing_logp + + +def test_missing_multivariate(): + """Test that automatic imputation of multivariate random variables raises NotImplementedError""" + + with Model() as m_miss: + with pytest.raises( + NotImplementedError, + match="Automatic imputation is only supported for univariate RandomVariables", + ): + x = Dirichlet( + "x", a=[1, 2, 3], observed=np.array([[0.3, 0.3, 0.4], [np.nan, np.nan, np.nan]]) + ) + + # TODO: Test can be used when local_subtensor_rv_lift supports multivariate distributions + # from pymc.distributions.transforms import simplex + # + # with Model() as m_unobs: + # x = Dirichlet("x", a=[1, 2, 3]) + # + # inp_vals = simplex.forward(np.array([0.3, 0.3, 0.4])).eval() + # assert np.isclose( + # m_miss.logp({"x_missing_simplex__": inp_vals}), + # m_unobs.logp_nojac({"x_simplex__": inp_vals}) * 2, + # ) From a5240aec107dd3607669a06226dc740041e4b0cf Mon Sep 17 00:00:00 2001 From: Ricardo Date: Sun, 12 Dec 2021 20:11:39 +0100 Subject: [PATCH 05/10] Test that missing values work as expected in distributions with vector parameters --- pymc/tests/test_missing.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/pymc/tests/test_missing.py b/pymc/tests/test_missing.py index de5b2a540..27b593d62 100644 --- a/pymc/tests/test_missing.py +++ b/pymc/tests/test_missing.py @@ -16,6 +16,7 @@ import numpy as np import pandas as pd import pytest +import scipy.stats from numpy import array, ma @@ -188,3 +189,21 @@ def test_missing_multivariate(): # m_miss.logp({"x_missing_simplex__": inp_vals}), # m_unobs.logp_nojac({"x_simplex__": inp_vals}) * 2, # ) + + +def test_missing_vector_parameter(): + with Model() as m: + x = Normal( + "x", + np.array([-10, 10]), + 0.1, + observed=np.array([[np.nan, 10], [-10, np.nan], [np.nan, np.nan]]), + ) + x_draws = x.eval() + assert x_draws.shape == (3, 2) + assert np.all(x_draws[:, 0] < 0) + assert np.all(x_draws[:, 1] > 0) + assert np.isclose( + m.logp({"x_missing": np.array([-10, 10, -10, 10])}), + scipy.stats.norm(scale=0.1).logpdf(0) * 6, + ) From c28b9c84a016c66d3e1158da522a1976f9e4c5ee Mon Sep 17 00:00:00 2001 From: Ricardo Date: Wed, 8 Dec 2021 13:06:53 +0100 Subject: [PATCH 06/10] Change rng of partially observed RVs --- pymc/model.py | 10 ++++++++++ pymc/tests/test_missing.py | 25 +++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/pymc/model.py b/pymc/model.py index c6e81ce1c..371d61aff 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1276,6 +1276,16 @@ def make_obs_var( clone=False, ) (observed_rv_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner) + # Make a clone of the RV, but change the rng so that observed and missing + # are not treated as equivalent nodes by aesara.
This would happen if the + # size of the masked and unmasked arrays happened to coincide. + _, size, _, *inps = observed_rv_var.owner.inputs + rng = self.model.next_rng() + observed_rv_var = observed_rv_var.owner.op(*inps, size=size, rng=rng) + # Add default_update to new rng + new_rng = observed_rv_var.owner.outputs[0] + observed_rv_var.update = (rng, new_rng) + rng.default_update = new_rng observed_rv_var.name = f"{name}_observed" observed_rv_var.tag.observations = nonmissing_data diff --git a/pymc/tests/test_missing.py b/pymc/tests/test_missing.py index 27b593d62..115175df2 100644 --- a/pymc/tests/test_missing.py +++ b/pymc/tests/test_missing.py @@ -18,8 +18,10 @@ import pytest import scipy.stats +from aesara.graph import graph_inputs from numpy import array, ma +from pymc import logpt from pymc.distributions import Dirichlet, Gamma, Normal, Uniform from pymc.exceptions import ImputationWarning from pymc.model import Model @@ -207,3 +209,26 @@ def test_missing_vector_parameter(): m.logp({"x_missing": np.array([-10, 10, -10, 10])}), scipy.stats.norm(scale=0.1).logpdf(0) * 6, ) + + +def test_missing_symmetric(): + """Check that logpt works when partially observed variables have equal observed and + unobserved dimensions. + + This would fail in a previous implementation because the two variables would be + equivalent and one of them would be discarded during MergeOptimization while + building the logpt graph. + """ + with Model() as m: + x = Gamma("x", alpha=3, beta=10, observed=np.array([1, np.nan])) + + x_obs_rv = m["x_observed"] + x_obs_vv = m.rvs_to_values[x_obs_rv] + + x_unobs_rv = m["x_missing"] + x_unobs_vv = m.rvs_to_values[x_unobs_rv] + + logp = logpt([x_obs_rv, x_unobs_rv], {x_obs_rv: x_obs_vv, x_unobs_rv: x_unobs_vv}) + logp_inputs = list(graph_inputs([logp])) + assert x_obs_vv in logp_inputs + assert x_unobs_vv in logp_inputs From afad6acc405c7f8e371e72d34c2478aa97311892 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Tue, 7 Dec 2021 07:34:50 +0100 Subject: [PATCH 07/10] Refactor logpt and raise more informative ValueErrors --- pymc/distributions/logprob.py | 116 ++++++++++-------------- pymc/tests/test_distributions.py | 4 +- pymc/tests/test_distributions_random.py | 8 +- pymc/tests/test_transforms.py | 80 ++++++++-------- 4 files changed, 100 insertions(+), 108 deletions(-) diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py index 0a7ee30ce..8b6144275 100644 --- a/pymc/distributions/logprob.py +++ b/pymc/distributions/logprob.py @@ -24,10 +24,8 @@ from aeppl.logprob import logcdf as logcdf_aeppl from aeppl.logprob import logprob as logp_aeppl from aeppl.transforms import TransformValuesOpt -from aesara import config from aesara.graph.basic import graph_inputs, io_toposort -from aesara.graph.op import Op, compute_test_value -from aesara.tensor.random.op import RandomVariable +from aesara.graph.op import Op from aesara.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -164,59 +162,64 @@ def logpt( # joint_logprob directly. # If var is not a list make it one. - if not isinstance(var, list): + if not isinstance(var, (list, tuple)): var = [var] - # If logpt isn't provided values and the variable (provided in var) - # is an RV, it is assumed that the tagged value var or observation is - # the value variable for that particular RV. + # If logpt isn't provided values, it is assumed that the tagged value var or + # observation is the value variable for that particular RV.
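+ # Illustrative example (hypothetical model variable `x`, not prescribed by
+ # this patch): `logpt(x)` falls back to `x.tag.observations` or
+ # `x.tag.value_var`, while `logpt(x, {x: x_vv})` uses the explicit mapping.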
if rv_values is None: rv_values = {} - for _var in var: - if isinstance(_var.owner.op, RandomVariable): - rv_value_var = getattr( - _var.tag, "observations", getattr(_var.tag, "value_var", _var) - ) - rv_values = {_var: rv_value_var} + for rv in var: + value_var = getattr(rv.tag, "observations", getattr(rv.tag, "value_var", None)) + if value_var is None: + raise ValueError(f"No value variable found for var {rv}") + rv_values[rv] = value_var + # Else we assume we were given a single rv and respective value elif not isinstance(rv_values, Mapping): - # Else if we're given a single value and a single variable we assume a mapping among them. - rv_values = ( - {var[0]: at.as_tensor_variable(rv_values).astype(var[0].type)} if len(var) == 1 else {} - ) - - # Since the filtering of logp graph is based on value variables - # provided to this function - if not rv_values: - warnings.warn("No value variables provided the logp will be an empty graph") + if len(var) == 1: + rv_values = {var[0]: at.as_tensor_variable(rv_values).astype(var[0].type)} + else: + raise ValueError("rv_values must be a dict if more than one var is requested") if scaling: rv_scalings = {} - for _var in var: - rv_value_var = getattr(_var.tag, "observations", getattr(_var.tag, "value_var", _var)) - rv_scalings[rv_value_var] = _get_scaling( - getattr(_var.tag, "total_size", None), rv_value_var.shape, rv_value_var.ndim + for rv, value_var in rv_values.items(): + rv_scalings[value_var] = _get_scaling( + getattr(rv.tag, "total_size", None), value_var.shape, value_var.ndim ) # Aeppl needs all rv-values pairs, not just that of the requested var. # Hence we iterate through the graph to collect them. tmp_rvs_to_values = rv_values.copy() - transform_map = {} for node in io_toposort(graph_inputs(var), var): try: curr_vars = [node.default_output()] except ValueError: curr_vars = node.outputs for curr_var in curr_vars: - rv_value_var = getattr( + if curr_var in tmp_rvs_to_values: + continue + # Check if variable has a value variable + value_var = getattr( curr_var.tag, "observations", getattr(curr_var.tag, "value_var", None) ) - if rv_value_var is None: - continue - rv_value = rv_values.get(curr_var, rv_value_var) - tmp_rvs_to_values[curr_var] = rv_value - # Along with value variables we also check for transforms if any. - if hasattr(rv_value_var.tag, "transform") and transformed: - transform_map[rv_value] = rv_value_var.tag.transform + if value_var is not None: + tmp_rvs_to_values[curr_var] = value_var + + # After collecting all necessary rvs and values, we check for any value transforms + transform_map = {} + if transformed: + for rv, value_var in tmp_rvs_to_values.items(): + if hasattr(value_var.tag, "transform"): + transform_map[value_var] = value_var.tag.transform + # If the provided value_variable does not have transform information, we + # check if the original `rv.tag.value_var` does. + # TODO: This logic should be replaced by an explicit dict of + # `{value_var: transform}` similar to `rv_values`. + else: + original_value_var = getattr(rv.tag, "value_var", None) + if original_value_var is not None and hasattr(original_value_var.tag, "transform"): + transform_map[value_var] = original_value_var.tag.transform transform_opt = TransformValuesOpt(transform_map) temp_logp_var_dict = factorized_joint_logprob( @@ -224,40 +227,21 @@ def logpt( ) # aeppl returns the logpt for every single value term we provided to it. This includes - # the extra values we plugged in above so we need to filter those out. 
+ # the extra values we plugged in above, so we filter those we actually wanted in the + # same order they were given in. logp_var_dict = {} - for value_var, _logp in temp_logp_var_dict.items(): - if value_var in rv_values.values(): - logp_var_dict[value_var] = _logp + for value_var in rv_values.values(): + logp_var_dict[value_var] = temp_logp_var_dict[value_var] - # If it's an empty dictionary the logp is None - if not logp_var_dict: - logp_var = None - else: - # Otherwise apply appropriate scalings and at.add and/or at.sum the - # graphs accordingly. - if scaling: - for _value in logp_var_dict.keys(): - if _value in rv_scalings: - logp_var_dict[_value] *= rv_scalings[_value] - - if len(logp_var_dict) == 1: - logp_var_dict = tuple(logp_var_dict.values())[0] - if sum: - logp_var = at.sum(logp_var_dict) - else: - logp_var = logp_var_dict - else: - if sum: - logp_var = at.sum([at.sum(factor) for factor in logp_var_dict.values()]) - else: - logp_var = at.add(*logp_var_dict.values()) + if scaling: + for value_var in logp_var_dict.keys(): + if value_var in rv_scalings: + logp_var_dict[value_var] *= rv_scalings[value_var] - # Recompute test values for the changes introduced by the replacements - # above. - if config.compute_test_value != "off": - for node in io_toposort(graph_inputs((logp_var,)), (logp_var,)): - compute_test_value(node) + if sum: + logp_var = at.sum([at.sum(factor) for factor in logp_var_dict.values()]) + else: + logp_var = at.add(*logp_var_dict.values()) return logp_var diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index a4c1ee472..8f4ad3e2d 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -2521,9 +2521,11 @@ def test_continuous(self): assert logpt(InfBoundedNormal, 0).eval() != -np.inf assert logpt(InfBoundedNormal, 11).eval() != -np.inf - value = at.dscalar("x") + value = model.rvs_to_values[LowerNormalTransform] assert logpt(LowerNormalTransform, value).eval({value: -1}) != -np.inf + value = model.rvs_to_values[UpperNormalTransform] assert logpt(UpperNormalTransform, value).eval({value: 1}) != -np.inf + value = model.rvs_to_values[BoundedNormalTransform] assert logpt(BoundedNormalTransform, value).eval({value: 0}) != -np.inf assert logpt(BoundedNormalTransform, value).eval({value: 11}) != -np.inf diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py index 0216de75f..ab97fa26c 100644 --- a/pymc/tests/test_distributions_random.py +++ b/pymc/tests/test_distributions_random.py @@ -45,7 +45,7 @@ def random_polyagamma(*args, **kwargs): from pymc.distributions.continuous import get_tau_sigma, interpolated from pymc.distributions.discrete import _OrderedLogistic, _OrderedProbit from pymc.distributions.dist_math import clipped_beta_rvs -from pymc.distributions.logprob import logpt +from pymc.distributions.logprob import logp from pymc.distributions.multivariate import _OrderedMultinomial, quaddist_matrix from pymc.distributions.shape_utils import to_tuple from pymc.tests.helpers import SeededTest, select_by_precision @@ -1626,8 +1626,8 @@ def test_errors(self): rowcov=np.eye(3), colcov=np.eye(3), ) - with pytest.raises(TypeError): - logpt(matrixnormal, aesara.tensor.ones((3, 3, 3))) + with pytest.raises(ValueError): + logp(matrixnormal, aesara.tensor.ones((3, 3, 3))) with pm.Model(): with pytest.warns(FutureWarning): @@ -1856,7 +1856,7 @@ def test_density_dist_without_random(self): pm.DensityDist( "density_dist", mu, - logp=lambda value, mu: logpt(pm.Normal.dist(mu, 1, 
size=100), value), + logp=lambda value, mu: logp(pm.Normal.dist(mu, 1, size=100), value), observed=np.random.randn(100), initval=0, ) diff --git a/pymc/tests/test_transforms.py b/pymc/tests/test_transforms.py index 8fafd50e9..d38300cdc 100644 --- a/pymc/tests/test_transforms.py +++ b/pymc/tests/test_transforms.py @@ -23,7 +23,7 @@ import pymc as pm import pymc.distributions.transforms as tr -from pymc.aesaraf import jacobian +from pymc.aesaraf import floatX, jacobian from pymc.distributions import logpt from pymc.tests.checks import close_to, close_to_logical from pymc.tests.helpers import SeededTest @@ -285,40 +285,46 @@ def build_model(self, distfam, params, size, transform, initval=None): def check_transform_elementwise_logp(self, model): x = model.free_RVs[0] - x0 = x.tag.value_var - assert x.ndim == logpt(x, sum=False).ndim + x_val_transf = x.tag.value_var - pt = model.initial_point - array = np.random.randn(*pt[x0.name].shape) - transform = x0.tag.transform - logp_notrans = logpt(x, transform.backward(array, *x.owner.inputs), transformed=False) + pt = model.recompute_initial_point(0) + test_array_transf = floatX(np.random.randn(*pt[x_val_transf.name].shape)) + transform = x_val_transf.tag.transform + test_array_untransf = transform.backward(test_array_transf, *x.owner.inputs).eval() - jacob_det = transform.log_jac_det(aesara.shared(array), *x.owner.inputs) - assert logpt(x, sum=False).ndim == jacob_det.ndim + # Create input variable with same dimensionality as untransformed test_array + x_val_untransf = at.constant(test_array_untransf).type() - v1 = logpt(x, array, jacobian=False).eval() - v2 = logp_notrans.eval() + jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs) + assert logpt(x, sum=False).ndim == x.ndim == jacob_det.ndim + + v1 = logpt(x, x_val_transf, jacobian=False).eval({x_val_transf: test_array_transf}) + v2 = logpt(x, x_val_untransf, transformed=False).eval({x_val_untransf: test_array_untransf}) close_to(v1, v2, tol) - def check_vectortransform_elementwise_logp(self, model, vect_opt=0): + def check_vectortransform_elementwise_logp(self, model): x = model.free_RVs[0] - x0 = x.tag.value_var - # TODO: For some reason the ndim relations - # dont hold up here. But final log-probablity - # values are what we expected. 
- # assert (x.ndim - 1) == logpt(x, sum=False).ndim - - pt = model.initial_point - array = np.random.randn(*pt[x0.name].shape) - transform = x0.tag.transform - logp_nojac = logpt(x, transform.backward(array, *x.owner.inputs), transformed=False) - - jacob_det = transform.log_jac_det(aesara.shared(array), *x.owner.inputs) - # assert logpt(x).ndim == jacob_det.ndim - + x_val_transf = x.tag.value_var + + pt = model.recompute_initial_point(0) + test_array_transf = floatX(np.random.randn(*pt[x_val_transf.name].shape)) + transform = x_val_transf.tag.transform + test_array_untransf = transform.backward(test_array_transf, *x.owner.inputs).eval() + + # Create input variable with same dimensionality as untransformed test_array + x_val_untransf = at.constant(test_array_untransf).type() + + jacob_det = transform.log_jac_det(test_array_transf, *x.owner.inputs) + # Original distribution is univariate + if x.owner.op.ndim_supp == 0: + assert logpt(x, sum=False).ndim == x.ndim == (jacob_det.ndim + 1) + # Original distribution is multivariate + else: + assert logpt(x, sum=False).ndim == (x.ndim - 1) == jacob_det.ndim + + a = logpt(x, x_val_transf, jacobian=False).eval({x_val_transf: test_array_transf}) + b = logpt(x, x_val_untransf, transformed=False).eval({x_val_untransf: test_array_untransf}) # Hack to get relative tolerance - a = logpt(x, array.astype(aesara.config.floatX), jacobian=False).eval() - b = logp_nojac.eval() close_to(a, b, np.abs(0.5 * (a + b) * tol)) @pytest.mark.parametrize( @@ -406,7 +412,7 @@ def test_vonmises(self, mu, kappa, size): ) def test_dirichlet(self, a, size): model = self.build_model(pm.Dirichlet, {"a": a}, size=size, transform=tr.simplex) - self.check_vectortransform_elementwise_logp(model, vect_opt=1) + self.check_vectortransform_elementwise_logp(model) def test_normal_ordered(self): model = self.build_model( @@ -416,7 +422,7 @@ def test_normal_ordered(self): initval=np.asarray([-1.0, 1.0, 4.0]), transform=tr.ordered, ) - self.check_vectortransform_elementwise_logp(model, vect_opt=0) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize( "sd,size", @@ -434,7 +440,7 @@ def test_half_normal_ordered(self, sd, size): initval=initval, transform=tr.Chain([tr.log, tr.ordered]), ) - self.check_vectortransform_elementwise_logp(model, vect_opt=0) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize("lam,size", [(2.5, (2,)), (np.ones(3), (4, 3))]) def test_exponential_ordered(self, lam, size): @@ -446,7 +452,7 @@ def test_exponential_ordered(self, lam, size): initval=initval, transform=tr.Chain([tr.log, tr.ordered]), ) - self.check_vectortransform_elementwise_logp(model, vect_opt=0) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize( "a,b,size", @@ -468,7 +474,7 @@ def test_beta_ordered(self, a, b, size): initval=initval, transform=tr.Chain([tr.logodds, tr.ordered]), ) - self.check_vectortransform_elementwise_logp(model, vect_opt=0) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize( "lower,upper,size", @@ -491,7 +497,7 @@ def transform_params(*inputs): initval=initval, transform=tr.Chain([interval, tr.ordered]), ) - self.check_vectortransform_elementwise_logp(model, vect_opt=1) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize("mu,kappa,size", [(0.0, 1.0, (2,)), (np.zeros(3), np.ones(3), (4, 3))]) def test_vonmises_ordered(self, mu, kappa, size): @@ -503,7 +509,7 @@ def test_vonmises_ordered(self, mu, kappa, size): initval=initval, 
transform=tr.Chain([tr.circular, tr.ordered]), ) - self.check_vectortransform_elementwise_logp(model, vect_opt=0) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize( "lower,upper,size,transform", @@ -522,7 +528,7 @@ def test_uniform_other(self, lower, upper, size, transform): initval=initval, transform=transform, ) - self.check_vectortransform_elementwise_logp(model, vect_opt=1) + self.check_vectortransform_elementwise_logp(model) @pytest.mark.parametrize( "mu,cov,size,shape", @@ -536,7 +542,7 @@ def test_mvnormal_ordered(self, mu, cov, size, shape): model = self.build_model( pm.MvNormal, {"mu": mu, "cov": cov}, size=size, initval=initval, transform=tr.ordered ) - self.check_vectortransform_elementwise_logp(model, vect_opt=1) + self.check_vectortransform_elementwise_logp(model) def test_triangular_transform(): From f5467602096501ed368449ce1c2cc457df56557c Mon Sep 17 00:00:00 2001 From: Ricardo Date: Wed, 8 Dec 2021 09:32:36 +0100 Subject: [PATCH 08/10] Return separate logp terms when logpt is called with `sum==False` --- pymc/distributions/logprob.py | 16 +++++++++++----- pymc/tests/test_logprob.py | 3 ++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/pymc/distributions/logprob.py b/pymc/distributions/logprob.py index 8b6144275..956e555d1 100644 --- a/pymc/distributions/logprob.py +++ b/pymc/distributions/logprob.py @@ -15,7 +15,7 @@ from collections.abc import Mapping from functools import singledispatch -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union import aesara.tensor as at import numpy as np @@ -119,7 +119,7 @@ def _get_scaling(total_size, shape, ndim): def logpt( - var: TensorVariable, + var: Union[TensorVariable, List[TensorVariable]], rv_values: Optional[Union[TensorVariable, Dict[TensorVariable, TensorVariable]]] = None, *, jacobian: bool = True, @@ -127,7 +127,7 @@ def logpt( transformed: bool = True, sum: bool = True, **kwargs, -) -> TensorVariable: +) -> Union[TensorVariable, List[TensorVariable]]: """Create a measure-space (i.e. log-likelihood) graph for a random variable or a list of random variables at a given point. @@ -154,7 +154,7 @@ def logpt( transformed Apply transforms. sum - Sum the log-likelihood. + Sum the log-likelihood or return each term as a separate list item. """ # TODO: In future when we drop support for tag.value_var most of the following @@ -241,7 +241,13 @@ def logpt( if sum: logp_var = at.sum([at.sum(factor) for factor in logp_var_dict.values()]) else: - logp_var = at.add(*logp_var_dict.values()) + logp_var = list(logp_var_dict.values()) + # TODO: deprecate special behavior when only one variable is requested and + # always return a list. This is here for backwards compatibility as logpt + # started as a replacement to factor.logpt, but it should now be considered an + # internal function reached only via model.logp* methods. 
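+ # Illustrative (assumed) behavior of the branch below:
+ #   logpt([x, y], {x: x_vv, y: y_vv}, sum=False) -> [logp_x, logp_y]
+ #   logpt(x, x_vv, sum=False) -> logp_x (single term, backwards compatibility)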
+ if len(logp_var) == 1: + logp_var = logp_var[0] return logp_var diff --git a/pymc/tests/test_logprob.py b/pymc/tests/test_logprob.py index c12fbc92c..53a1061cb 100644 --- a/pymc/tests/test_logprob.py +++ b/pymc/tests/test_logprob.py @@ -144,7 +144,8 @@ def test_logpt_subtensor(): I_value_var = I_rv.type() I_value_var.name = "I_value" - A_idx_logp = logpt(A_idx, {A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False) + A_idx_logps = logpt(A_idx, {A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False) + A_idx_logp = at.add(*A_idx_logps) logp_vals_fn = aesara.function([A_idx_value_var, I_value_var], A_idx_logp) From 2e9d54ca41135c610b60b8a65dea77b103af5d7b Mon Sep 17 00:00:00 2001 From: Ricardo Date: Tue, 7 Dec 2021 11:17:26 +0100 Subject: [PATCH 09/10] Add model.logp_elemwiset --- pymc/model.py | 65 ++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 3 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index 371d61aff..dd526561d 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -284,9 +284,8 @@ def logp(self): """Compiled log probability density function""" return self.model.fn(self.logpt) - @property - def logp_elemwise(self): - return self.model.fn(self.logp_elemwiset) + def logp_elemwise(self, vars=None, jacobian=True): + return self.model.fn(self.logp_elemwiset(vars=vars, jacobian=jacobian)) def dlogp(self, vars=None): """Compiled log probability density gradient function""" @@ -728,6 +727,66 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs): } return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs) + def logp_elemwiset( + self, + vars: Optional[Union[Variable, List[Variable]]] = None, + jacobian: bool = True, + ) -> List[Variable]: + """Elemwise log-probability of the model. + + Parameters + ---------- + vars: list of random variables or potential terms, optional + Compute the elemwise log-probability of those variables. If None, use all + free and observed random variables, as well as potential terms in model. + jacobian + Whether to include jacobian terms in logprob graph. Defaults to True. + + Returns + ------- + Elemwise logp terms for each requested variable, in the same order as the input.
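+
+ Examples
+ --------
+ An illustrative sketch (``m`` is a hypothetical model with an observed RV ``y``):
+ ``m.logp_elemwiset(vars=[y])`` returns a one-element list holding the elemwise
+ logp term of ``y``; with ``vars=None`` one term is returned for each free RV,
+ observed RV and potential, in that order.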
+ """ + if vars is None: + vars = self.free_RVs + self.observed_RVs + self.potentials + elif not isinstance(vars, (list, tuple)): + vars = [vars] + + # We need to separate random variables from potential terms, and remember their + # original order so that we can merge them together in the same order at the end + rv_values = {} + potentials = [] + rv_order, potential_order = [], [] + for i, var in enumerate(vars): + value_var = self.rvs_to_values.get(var) + if value_var is not None: + rv_values[var] = value_var + rv_order.append(i) + else: + if var in self.potentials: + potentials.append(var) + potential_order.append(i) + else: + raise ValueError( + f"Requested variable {var} not found among the model variables" + ) + + rv_logps = [] + if rv_values: + rv_logps = logpt(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian) + if not isinstance(rv_logps, list): + rv_logps = [rv_logps] + + # Replace random variables by their value variables in potential terms + potential_logps = [] + if potentials: + potential_logps, _ = rvs_to_value_vars(potentials, apply_transforms=True) + + logp_elemwise = [None] * len(vars) + for logp_order, logp in zip((rv_order + potential_order), (rv_logps + potential_logps)): + logp_elemwise[logp_order] = logp + + return logp_elemwise + @property def logpt(self): """Aesara scalar of log-probability of the model""" From 046fc7a3c24261c393d75ab76742713b0dcf7ec2 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Tue, 7 Dec 2021 11:31:39 +0100 Subject: [PATCH 10/10] Use model.logp_elemwise in InferenceDataConverter --- pymc/backends/arviz.py | 9 ++++++--- pymc/sampling_jax.py | 2 +- pymc/tests/test_idata_conversion.py | 26 +++++++++++++++++++++----- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 26b784b36..ae44bb65e 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -24,7 +24,6 @@ import pymc from pymc.aesaraf import extract_obs_data -from pymc.distributions import logpt from pymc.model import modelcontext from pymc.util import get_default_varnames @@ -264,11 +263,15 @@ def _extract_log_likelihood(self, trace): if self.model is None: return None + # TODO: We no longer need one function per observed variable if self.log_likelihood is True: - cached = [(var, self.model.fn(logpt(var))) for var in self.model.observed_RVs] + cached = [ + (var, self.model.fn(self.model.logp_elemwiset(var)[0])) + for var in self.model.observed_RVs + ] else: cached = [ - (var, self.model.fn(logpt(var))) + (var, self.model.fn(self.model.logp_elemwiset(var)[0])) for var in self.model.observed_RVs if var.name in self.log_likelihood ] diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py index 051e9818f..a8785f3ab 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling_jax.py @@ -122,7 +122,7 @@ def _get_log_likelihood(model, samples): "Compute log-likelihood for all observations" data = {} for v in model.observed_RVs: - logp_v = replace_shared_variables([logpt(v)]) + logp_v = replace_shared_variables([model.logp_elemwiset(v)[0]]) fgraph = FunctionGraph(model.value_vars, logp_v, clone=False) optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"]) jax_fn = jax_funcify(fgraph) diff --git a/pymc/tests/test_idata_conversion.py b/pymc/tests/test_idata_conversion.py index bfd275c31..f950f994b 100644 --- a/pymc/tests/test_idata_conversion.py +++ b/pymc/tests/test_idata_conversion.py @@ -143,6 +143,11 @@ def test_to_idata(self, data, eight_schools_params, chains, draws): np.isclose(ivalues[chain], 
values[chain * draws : (chain + 1) * draws]) ) + chains = inference_data.posterior.dims["chain"] + draws = inference_data.posterior.dims["draw"] + obs = inference_data.observed_data["obs"] + assert inference_data.log_likelihood["obs"].shape == (chains, draws) + obs.shape + def test_predictions_to_idata(self, data, eight_schools_params): "Test that we can add predictions to a previously-existing InferenceData." test_dict = { @@ -329,6 +334,11 @@ def test_missing_data_model(self): fails = check_multiple_attrs(test_dict, inference_data) assert not fails + # The missing part of partially observed RVs is not included in log_likelihood + # See https://github.com/pymc-devs/pymc/issues/5255 + assert inference_data.log_likelihood["y_observed"].shape == (2, 100, 3) + + @pytest.mark.xfail(reason="Multivariate partially observed RVs not implemented for V4") @pytest.mark.xfail(reason="LKJCholeskyCov not refactored for v4") def test_mv_missing_data_model(self): data = ma.masked_values([[1, 2], [2, 2], [-1, 4], [2, -1], [-1, -1]], value=-1) @@ -375,8 +385,12 @@ def test_multiple_observed_rv(self, log_likelihood): if not log_likelihood: test_dict.pop("log_likelihood") test_dict["~log_likelihood"] = [] - if isinstance(log_likelihood, list): + elif isinstance(log_likelihood, list): test_dict["log_likelihood"] = ["y1", "~y2"] + assert inference_data.log_likelihood["y1"].shape == (2, 100, 10) + else: + assert inference_data.log_likelihood["y1"].shape == (2, 100, 10) + assert inference_data.log_likelihood["y2"].shape == (2, 100, 100) fails = check_multiple_attrs(test_dict, inference_data) assert not fails @@ -445,12 +459,12 @@ def test_single_observation(self): inference_data = pm.sample(500, chains=2, return_inferencedata=True) assert inference_data + assert inference_data.log_likelihood["w"].shape == (2, 500, 1) - @pytest.mark.xfail(reason="Potential not refactored for v4") def test_potential(self): with pm.Model(): x = pm.Normal("x", 0.0, 1.0) - pm.Potential("z", logpt(pm.Normal.dist(x, 1.0), np.random.randn(10))) + pm.Potential("z", pm.logp(pm.Normal.dist(x, 1.0), np.random.randn(10))) inference_data = pm.sample(100, chains=2, return_inferencedata=True) assert inference_data @@ -463,7 +477,7 @@ def test_constant_data(self, use_context): y = pm.Data("y", [1.0, 2.0, 3.0]) beta = pm.Normal("beta", 0, 1) obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable - trace = pm.sample(100, tune=100, return_inferencedata=False) + trace = pm.sample(100, chains=2, tune=100, return_inferencedata=False) if use_context: inference_data = to_inference_data(trace=trace) test_dict = {"posterior": ["beta"], "observed_data": ["obs"], "constant_data": ["x"]} fails = check_multiple_attrs(test_dict, inference_data) assert not fails + assert inference_data.log_likelihood["obs"].shape == (2, 100, 3) def test_predictions_constant_data(self): with pm.Model(): @@ -570,7 +585,7 @@ def test_multivariate_observations(self): with pm.Model(coords=coords): p = pm.Beta("p", 1, 1, size=(3,)) pm.Multinomial("y", 20, p, dims=("experiment", "direction"), observed=data) - idata = pm.sample(draws=50, tune=100, return_inferencedata=True) + idata = pm.sample(draws=50, chains=2, tune=100, return_inferencedata=True) test_dict = { "posterior": ["p"], "sample_stats": ["lp"], "log_likelihood": ["y"], "observed_data": ["y"], } fails = check_multiple_attrs(test_dict, idata) assert not fails assert "direction" not in idata.log_likelihood.dims assert "direction" in idata.observed_data.dims assert
idata.log_likelihood["y"].shape == (2, 50, 20) def test_constant_data_coords_issue_5046(self): """This is a regression test against a bug where a local coords variable was overwritten."""