diff --git a/docs/api_reference.rst b/docs/api_reference.rst
index 3020e3d01..9763d525f 100644
--- a/docs/api_reference.rst
+++ b/docs/api_reference.rst
@@ -29,3 +29,6 @@ methods in the current release of PyMC experimental.
 
 .. automodule:: pymc_experimental.utils.spline
    :members: bspline_interpolation
+.. automodule:: pymc_experimental.utils.prior
+   :members: prior_from_idata
+
diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py
new file mode 100644
index 000000000..34f907152
--- /dev/null
+++ b/pymc_experimental/tests/test_prior_from_trace.py
@@ -0,0 +1,141 @@
+import pymc_experimental as pmx
+from pymc.distributions import transforms
+import pytest
+import arviz as az
+import numpy as np
+import pymc as pm
+
+
+@pytest.mark.parametrize(
+    "case",
+    [
+        (("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
+        (("a", None), dict(name="a", transform=None, dims=None)),
+        (("a", transforms.log), dict(name="a", transform=transforms.log, dims=None)),
+        (
+            ("a", dict(transform=transforms.log)),
+            dict(name="a", transform=transforms.log, dims=None),
+        ),
+        (("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
+        (("a", dict(name="b", dims="test")), dict(name="b", transform=None, dims="test")),
+        (("a", ("test",)), dict(name="a", transform=None, dims=("test",))),
+    ],
+)
+def test_parsing_arguments(case):
+    inp, out = case
+    test = pmx.utils.prior._arg_to_param_cfg(*inp)
+    assert test == out
+
+
+@pytest.fixture
+def coords():
+    return dict(test=range(3), simplex=range(4))
+
+
+@pytest.fixture
+def user_param_cfg():
+    return ("t",), dict(
+        a="d",
+        b=dict(transform=transforms.log, dims=("test",)),
+        c=dict(transform=transforms.simplex, dims=("simplex",)),
+    )
+
+
+@pytest.fixture
+def param_cfg(user_param_cfg):
+    return pmx.utils.prior._parse_args(user_param_cfg[0], **user_param_cfg[1])
+
+
+@pytest.fixture
+def transformed_data(param_cfg, coords):
+    vars = dict()
+    for k, cfg in param_cfg.items():
+        if cfg["dims"] is not None:
+            extra_dims = [len(coords[d]) for d in cfg["dims"]]
+            if cfg["transform"] is not None:
+                t = np.random.randn(*extra_dims)
+                extra_dims = tuple(cfg["transform"].forward(t).shape.eval())
+        else:
+            extra_dims = []
+        orig = np.random.randn(4, 100, *extra_dims)
+        vars[k] = orig
+    return vars
+
+
+@pytest.fixture
+def idata(transformed_data, param_cfg):
+    vars = dict()
+    for k, orig in transformed_data.items():
+        cfg = param_cfg[k]
+        if cfg["transform"] is not None:
+            var = cfg["transform"].backward(orig).eval()
+        else:
+            var = orig
+        assert not np.isnan(var).any()
+        vars[k] = var
+    return az.convert_to_inference_data(vars)
+
+
+def test_idata_for_tests(idata, param_cfg):
+    assert set(idata.posterior.keys()) == set(param_cfg)
+    assert len(idata.posterior.coords["chain"]) == 4
+    assert len(idata.posterior.coords["draw"]) == 100
+
+
+def test_args_compose():
+    cfg = pmx.utils.prior._parse_args(
+        var_names=["a"],
+        b=("test",),
+        c=transforms.log,
+        d="e",
+        f=dict(dims="test"),
+        g=dict(name="h", dims="test", transform=transforms.log),
+    )
+    assert cfg == dict(
+        a=dict(name="a", dims=None, transform=None),
+        b=dict(name="b", dims=("test",), transform=None),
+        c=dict(name="c", dims=None, transform=transforms.log),
+        d=dict(name="e", dims=None, transform=None),
+        f=dict(name="f", dims="test", transform=None),
+        g=dict(name="h", dims="test", transform=transforms.log),
+    )
+
+
+def test_transform_idata(transformed_data, idata, param_cfg):
+    flat_info = pmx.utils.prior._flatten(idata, **param_cfg)
+    expected_shape = 0
+    for v in transformed_data.values():
+        expected_shape += int(np.prod(v.shape[2:]))
+    assert flat_info["data"].shape[1] == expected_shape
+    assert len(flat_info["info"]) == len(param_cfg)
+    assert "sinfo" in flat_info["info"][0]
+    assert "vinfo" in flat_info["info"][0]
+
+
+@pytest.fixture
+def flat_info(idata, param_cfg):
+    return pmx.utils.prior._flatten(idata, **param_cfg)
+
+
+def test_mean_chol(flat_info):
+    mean, chol = pmx.utils.prior._mean_chol(flat_info["data"])
+    assert mean.shape == (flat_info["data"].shape[1],)
+    assert chol.shape == (flat_info["data"].shape[1],) * 2
+
+
+def test_mvn_prior_from_flat_info(flat_info, coords, param_cfg):
+    with pm.Model(coords=coords) as model:
+        priors = pmx.utils.prior._mvn_prior_from_flat_info("trace_prior_", flat_info)
+        test_prior = pm.sample_prior_predictive(1)
+    names = [p["name"] for p in param_cfg.values()]
+    assert set(model.named_vars) == {"trace_prior_", *names}
+
+
+def test_prior_from_idata(idata, user_param_cfg, coords, param_cfg):
+    with pm.Model(coords=coords) as model:
+        priors = pmx.utils.prior.prior_from_idata(
+            idata, var_names=user_param_cfg[0], **user_param_cfg[1]
+        )
+        test_prior = pm.sample_prior_predictive(1)
+    names = [p["name"] for p in param_cfg.values()]
+    assert set(model.named_vars) == {"trace_prior_", *names}
diff --git a/pymc_experimental/utils/__init__.py b/pymc_experimental/utils/__init__.py
index f645838eb..f7f933433 100644
--- a/pymc_experimental/utils/__init__.py
+++ b/pymc_experimental/utils/__init__.py
@@ -1 +1,2 @@
 from pymc_experimental.utils import spline
+from pymc_experimental.utils import prior
diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py
new file mode 100644
index 000000000..ffd243a40
--- /dev/null
+++ b/pymc_experimental/utils/prior.py
@@ -0,0 +1,181 @@
+from typing import TypedDict, Optional, Union, Tuple, Sequence, Dict, List
+import aeppl.transforms
+import arviz
+import pymc as pm
+import aesara.tensor as at
+import numpy as np
+
+
+class ParamCfg(TypedDict):
+    name: str
+    transform: Optional[aeppl.transforms.RVTransform]
+    dims: Optional[Union[str, Tuple[str]]]
+
+
+class ShapeInfo(TypedDict):
+    # shape might not match slice due to a transform
+    shape: Tuple[int]  # transformed shape
+    slice: slice
+
+
+class VarInfo(TypedDict):
+    sinfo: ShapeInfo
+    vinfo: ParamCfg
+
+
+class FlatInfo(TypedDict):
+    data: np.ndarray
+    info: List[VarInfo]
+
+
+def _arg_to_param_cfg(
+    key, value: Optional[Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]] = None
+):
+    if value is None:
+        cfg = ParamCfg(name=key, transform=None, dims=None)
+    elif isinstance(value, Tuple):
+        cfg = ParamCfg(name=key, transform=None, dims=value)
+    elif isinstance(value, str):
+        cfg = ParamCfg(name=value, transform=None, dims=None)
+    elif isinstance(value, aeppl.transforms.RVTransform):
+        cfg = ParamCfg(name=key, transform=value, dims=None)
+    else:
+        cfg = value.copy()
+        cfg.setdefault("name", key)
+        cfg.setdefault("transform", None)
+        cfg.setdefault("dims", None)
+    return cfg
+
+
+def _parse_args(
+    var_names: Sequence[str], **kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]
+) -> Dict[str, ParamCfg]:
+    results = dict()
+    for var in var_names:
+        results[var] = _arg_to_param_cfg(var)
+    for key, val in kwargs.items():
+        results[key] = _arg_to_param_cfg(key, val)
+    return results
+
+
+def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo:
+    posterior = idata.posterior
+    vars = list()
+    info = list()
+    begin = 0
+    for key, cfg in kwargs.items():
+        data = (
+            posterior[key]
+            # combine all draws from all chains
+            .stack(__sample__=["chain", "draw"])
+            # move sample dim to the first position
+            # no matter where it was before
+            .transpose("__sample__", ...)
+            # we need numpy data for all the rest functionality
+            .values
+        )
+        # omitting __sample__
+        # we need shape in the untransformed space
+        if cfg["transform"] is not None:
+            # some transforms need original shape
+            data = cfg["transform"].forward(data).eval()
+        shape = data.shape[1:]
+        # now we can get rid of shape
+        data = data.reshape(data.shape[0], -1)
+        end = begin + data.shape[1]
+        vars.append(data)
+        sinfo = dict(shape=shape, slice=slice(begin, end))
+        info.append(dict(sinfo=sinfo, vinfo=cfg))
+        begin = end
+    return dict(data=np.concatenate(vars, axis=-1), info=info)
+
+
+def _mean_chol(flat_array: np.ndarray):
+    mean = flat_array.mean(0)
+    cov = np.cov(flat_array, rowvar=False)
+    chol = np.linalg.cholesky(cov)
+    return mean, chol
+
+
+def _mvn_prior_from_flat_info(name, flat_info: FlatInfo):
+    mean, chol = _mean_chol(flat_info["data"])
+    base_dist = pm.Normal(name, np.zeros_like(mean))
+    interim = mean + chol @ base_dist
+    result = dict()
+    for var_info in flat_info["info"]:
+        sinfo = var_info["sinfo"]
+        vinfo = var_info["vinfo"]
+        var = interim[sinfo["slice"]].reshape(sinfo["shape"])
+        if vinfo["transform"] is not None:
+            var = vinfo["transform"].backward(var)
+        var = pm.Deterministic(vinfo["name"], var, dims=vinfo["dims"])
+        result[vinfo["name"]] = var
+    return result
+
+
+def prior_from_idata(
+    idata: arviz.InferenceData,
+    name="trace_prior_",
+    *,
+    var_names: Sequence[str],
+    **kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]
+) -> Dict[str, at.TensorVariable]:
+    """
+    Create a prior from posterior using MvNormal approximation.
+
+    The approximation uses MvNormal distribution.
+    Keep in mind that this function will only work well for unimodal
+    posteriors and will fail when complicated interactions happen.
+
+    Moreover, if a retrieved variable is constrained, you
+    should specify a transform for the variable, e.g.
+    ``pymc.distributions.transforms.log`` for standard
+    deviation posterior.
+
+    Parameters
+    ----------
+    idata: arviz.InferenceData
+        Inference data with posterior group
+    var_names: Sequence[str]
+        names of variables to take as is from the posterior
+    kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]
+        names of variables with additional configuration, see more in Examples
+
+    Examples
+    --------
+    >>> import pymc as pm
+    >>> import pymc.distributions.transforms as transforms
+    >>> import numpy as np
+    >>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model1:
+    ...     a = pm.Normal("a")
+    ...     b = pm.Normal("b", dims="test")
+    ...     c = pm.HalfNormal("c")
+    ...     d = pm.Normal("d")
+    ...     e = pm.Normal("e")
+    ...     f = pm.Dirichlet("f", np.ones(3), dims="options")
+    ...     trace = pm.sample(progressbar=False)
+
+    You can reuse the posterior in the new model.
+
+    >>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model2:
+    ...     priors = prior_from_idata(
+    ...         trace,                  # the old trace (posterior)
+    ...         var_names=["a", "d"],   # take variables as is
+    ...
+    ...         e="new_e",              # assign new name "new_e" for a variable
+    ...                                 # similar to dict(name="new_e")
+    ...
+    ...         b=("test", ),           # set a dim to "test"
+    ...                                 # similar to dict(dims=("test", ))
+    ...
+    ...         c=transforms.log,       # apply log transform to a positive variable
+    ...                                 # similar to dict(transform=transforms.log)
+    ...
+    ...         # set a name, assign a dim and apply simplex transform
+    ...         f=dict(name="new_f", dims="options", transform=transforms.simplex)
+    ...     )
+    ...     trace1 = pm.sample_prior_predictive(100)
+    """
+    param_cfg = _parse_args(var_names=var_names, **kwargs)
+    flat_info = _flatten(idata, **param_cfg)
+    return _mvn_prior_from_flat_info(name, flat_info)