From f41ffcd490fb554bc225719e97f07f35c20c034a Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 29 Jun 2022 18:08:24 +0000 Subject: [PATCH 01/23] add argument parser --- .../tests/test_prior_from_trace.py | 20 ++++++++++++++++++ pymc_experimental/utils/__init__.py | 1 + pymc_experimental/utils/prior.py | 21 +++++++++++++++++++ 3 files changed, 42 insertions(+) create mode 100644 pymc_experimental/tests/test_prior_from_trace.py create mode 100644 pymc_experimental/utils/prior.py diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py new file mode 100644 index 000000000..3b3a419b8 --- /dev/null +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -0,0 +1,20 @@ +import pymc_experimental as pmx +import pymc as pm +from pymc.distributions import transforms +import pytest + + +@pytest.mark.parametrize( + "case", + [ + (("a", dict(name="b")), dict(name="b", transform=None)), + (("a", None), dict(name="a", transform=None)), + (("a", transforms.log), dict(name="a", transform=transforms.log)), + (("a", dict(transform=transforms.log)), dict(name="a", transform=transforms.log)), + (("a", dict(name="b")), dict(name="b", transform=None)), + ], +) +def test_parsing_arguments(case): + inp, out = case + test = pmx.utils.prior.arg_to_param_cfg(*inp) + assert test == out diff --git a/pymc_experimental/utils/__init__.py b/pymc_experimental/utils/__init__.py index f645838eb..f7f933433 100644 --- a/pymc_experimental/utils/__init__.py +++ b/pymc_experimental/utils/__init__.py @@ -1 +1,2 @@ from pymc_experimental.utils import spline +from pymc_experimental.utils import prior diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py new file mode 100644 index 000000000..5e0e289e8 --- /dev/null +++ b/pymc_experimental/utils/prior.py @@ -0,0 +1,21 @@ +from typing import TypedDict, Optional, Union +import aeppl.transforms + + +class ParamCfg(TypedDict): + name: str + transform: Optional[aeppl.transforms.RVTransform] + + +def arg_to_param_cfg(key, value: Optional[Union[ParamCfg, aeppl.transforms.RVTransform, str]]): + if value is None: + cfg = ParamCfg(name=key, transform=None) + elif isinstance(value, str): + cfg = ParamCfg(name=value, transform=None) + elif isinstance(value, aeppl.transforms.RVTransform): + cfg = ParamCfg(name=key, transform=value) + else: + cfg = value.copy() + cfg.setdefault("name", key) + cfg.setdefault("transform", None) + return cfg From 70d1b1013365357bd7a88874c2c04131d39a1f5a Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 29 Jun 2022 18:26:35 +0000 Subject: [PATCH 02/23] extend argument parser --- pymc_experimental/tests/test_prior_from_trace.py | 15 ++++++++++----- pymc_experimental/utils/prior.py | 16 +++++++++++----- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index 3b3a419b8..78b803e1c 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -7,11 +7,16 @@ @pytest.mark.parametrize( "case", [ - (("a", dict(name="b")), dict(name="b", transform=None)), - (("a", None), dict(name="a", transform=None)), - (("a", transforms.log), dict(name="a", transform=transforms.log)), - (("a", dict(transform=transforms.log)), dict(name="a", transform=transforms.log)), - (("a", dict(name="b")), dict(name="b", transform=None)), + (("a", dict(name="b")), dict(name="b", transform=None, dims=None)), + (("a", None), dict(name="a", transform=None, dims=None)), + (("a", transforms.log), dict(name="a", transform=transforms.log, dims=None)), + ( + ("a", dict(transform=transforms.log)), + dict(name="a", transform=transforms.log, dims=None), + ), + (("a", dict(name="b")), dict(name="b", transform=None, dims=None)), + (("a", dict(name="b", dims="test")), dict(name="b", transform=None, dims="test")), + (("a", ("test",)), dict(name="a", transform=None, dims=("test",))), ], ) def test_parsing_arguments(case): diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index 5e0e289e8..f1ad61659 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -1,21 +1,27 @@ -from typing import TypedDict, Optional, Union +from typing import TypedDict, Optional, Union, Tuple import aeppl.transforms class ParamCfg(TypedDict): name: str transform: Optional[aeppl.transforms.RVTransform] + dims: Optional[Union[str, Tuple[str]]] -def arg_to_param_cfg(key, value: Optional[Union[ParamCfg, aeppl.transforms.RVTransform, str]]): +def arg_to_param_cfg( + key, value: Optional[Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]] +): if value is None: - cfg = ParamCfg(name=key, transform=None) + cfg = ParamCfg(name=key, transform=None, dims=None) + elif isinstance(value, Tuple): + cfg = ParamCfg(name=key, transform=None, dims=value) elif isinstance(value, str): - cfg = ParamCfg(name=value, transform=None) + cfg = ParamCfg(name=value, transform=None, dims=None) elif isinstance(value, aeppl.transforms.RVTransform): - cfg = ParamCfg(name=key, transform=value) + cfg = ParamCfg(name=key, transform=value, dims=None) else: cfg = value.copy() cfg.setdefault("name", key) cfg.setdefault("transform", None) + cfg.setdefault("dims", None) return cfg From 5fdb819e2090d76fe62b016a556dc2a408764be2 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 29 Jun 2022 18:50:46 +0000 Subject: [PATCH 03/23] prepare a valid fixture --- pymc_experimental/tests/test_prior_from_trace.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index 78b803e1c..8a50d9983 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -1,7 +1,8 @@ import pymc_experimental as pmx -import pymc as pm from pymc.distributions import transforms import pytest +import arviz as az +import numpy as np @pytest.mark.parametrize( @@ -23,3 +24,16 @@ def test_parsing_arguments(case): inp, out = case test = pmx.utils.prior.arg_to_param_cfg(*inp) assert test == out + + +@pytest.fixture +def idata(): + a = np.random.randn(4, 1000, 3) + b = np.exp(np.random.randn(4, 1000, 5)) + return az.convert_to_inference_data(dict(a=a, b=b)) + + +def test_idata_for_tests(idata): + assert set(idata.posterior.keys()) == {"a", "b"} + assert len(idata.posterior.coords["chain"]) == 4 + assert len(idata.posterior.coords["draw"]) == 1000 From 45582f586fc0a0ff58b229aa2790ac8fcc16a9a3 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 29 Jun 2022 19:26:10 +0000 Subject: [PATCH 04/23] improve fixture --- .../tests/test_prior_from_trace.py | 37 +++++++++++++++---- pymc_experimental/utils/prior.py | 2 +- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index 8a50d9983..9672892f6 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -27,13 +27,36 @@ def test_parsing_arguments(case): @pytest.fixture -def idata(): - a = np.random.randn(4, 1000, 3) - b = np.exp(np.random.randn(4, 1000, 5)) - return az.convert_to_inference_data(dict(a=a, b=b)) +def coords(): + return dict(test=range(3)) -def test_idata_for_tests(idata): - assert set(idata.posterior.keys()) == {"a", "b"} +@pytest.fixture +def param_cfg(): + return dict( + a=pmx.utils.prior.arg_to_param_cfg("a"), + b=pmx.utils.prior.arg_to_param_cfg("b", dict(transform=transforms.log, dims=("test",))), + ) + + +@pytest.fixture +def idata(param_cfg, coords): + vars = dict() + for k, cfg in param_cfg.items(): + if cfg["dims"] is not None: + extra_dims = [len(coords[d]) for d in cfg["dims"]] + else: + extra_dims = [] + orig = np.random.randn(4, 100, *extra_dims) + if cfg["transform"] is not None: + var = cfg["transform"].backward(orig).eval() + else: + var = orig + vars[k] = var + return az.convert_to_inference_data(vars, coords=coords) + + +def test_idata_for_tests(idata, param_cfg): + assert set(idata.posterior.keys()) == set(param_cfg) assert len(idata.posterior.coords["chain"]) == 4 - assert len(idata.posterior.coords["draw"]) == 1000 + assert len(idata.posterior.coords["draw"]) == 100 diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index f1ad61659..a1adf3bd9 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -9,7 +9,7 @@ class ParamCfg(TypedDict): def arg_to_param_cfg( - key, value: Optional[Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]] + key, value: Optional[Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]] = None ): if value is None: cfg = ParamCfg(name=key, transform=None, dims=None) From 8ea35dbd2a61bb9c36f8c106a6f44ab28b607c1c Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 29 Jun 2022 19:34:56 +0000 Subject: [PATCH 05/23] improve fixture --- pymc_experimental/tests/test_prior_from_trace.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index 9672892f6..be5370920 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -35,7 +35,10 @@ def coords(): def param_cfg(): return dict( a=pmx.utils.prior.arg_to_param_cfg("a"), - b=pmx.utils.prior.arg_to_param_cfg("b", dict(transform=transforms.log, dims=("test",))), + b=pmx.utils.prior.arg_to_param_cfg( + "b", dict(transform=transforms.sum_to_1, dims=("test",)) + ), + c=pmx.utils.prior.arg_to_param_cfg("c", dict(transform=transforms.log, dims=("test",))), ) From 8d047fc93edc4a032d4659477c253a88841b7ffb Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Wed, 29 Jun 2022 19:42:29 +0000 Subject: [PATCH 06/23] use simplex transform for the test case --- pymc_experimental/tests/test_prior_from_trace.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index be5370920..5f95b65fc 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -28,17 +28,17 @@ def test_parsing_arguments(case): @pytest.fixture def coords(): - return dict(test=range(3)) + return dict(test=range(3), simplex=range(4)) @pytest.fixture def param_cfg(): return dict( a=pmx.utils.prior.arg_to_param_cfg("a"), - b=pmx.utils.prior.arg_to_param_cfg( - "b", dict(transform=transforms.sum_to_1, dims=("test",)) + b=pmx.utils.prior.arg_to_param_cfg("b", dict(transform=transforms.log, dims=("test",))), + c=pmx.utils.prior.arg_to_param_cfg( + "c", dict(transform=transforms.simplex, dims=("simplex",)) ), - c=pmx.utils.prior.arg_to_param_cfg("c", dict(transform=transforms.log, dims=("test",))), ) @@ -55,6 +55,7 @@ def idata(param_cfg, coords): var = cfg["transform"].backward(orig).eval() else: var = orig + assert not np.isnan(var).any() vars[k] = var return az.convert_to_inference_data(vars, coords=coords) From 20c76bfd006ce7e75164ca4613569b3973e49b2a Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 30 Jun 2022 12:26:29 +0000 Subject: [PATCH 07/23] add parse args --- .../tests/test_prior_from_trace.py | 27 ++++++++++++++++--- pymc_experimental/utils/prior.py | 15 +++++++++-- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index 5f95b65fc..9a5565bc9 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -22,7 +22,7 @@ ) def test_parsing_arguments(case): inp, out = case - test = pmx.utils.prior.arg_to_param_cfg(*inp) + test = pmx.utils.prior._arg_to_param_cfg(*inp) assert test == out @@ -34,9 +34,9 @@ def coords(): @pytest.fixture def param_cfg(): return dict( - a=pmx.utils.prior.arg_to_param_cfg("a"), - b=pmx.utils.prior.arg_to_param_cfg("b", dict(transform=transforms.log, dims=("test",))), - c=pmx.utils.prior.arg_to_param_cfg( + a=pmx.utils.prior._arg_to_param_cfg("a"), + b=pmx.utils.prior._arg_to_param_cfg("b", dict(transform=transforms.log, dims=("test",))), + c=pmx.utils.prior._arg_to_param_cfg( "c", dict(transform=transforms.simplex, dims=("simplex",)) ), ) @@ -64,3 +64,22 @@ def test_idata_for_tests(idata, param_cfg): assert set(idata.posterior.keys()) == set(param_cfg) assert len(idata.posterior.coords["chain"]) == 4 assert len(idata.posterior.coords["draw"]) == 100 + + +def test_args_compose(): + cfg = pmx.utils.prior._parse_args( + var_names=["a"], + b=("test",), + c=transforms.log, + d="e", + f=dict(dims="test"), + g=dict(name="h", dims="test", transform=transforms.log), + ) + assert cfg == dict( + a=dict(name="a", dims=None, transform=None), + b=dict(name="b", dims=("test",), transform=None), + c=dict(name="c", dims=None, transform=transforms.log), + d=dict(name="e", dims=None, transform=None), + f=dict(name="f", dims="test", transform=None), + g=dict(name="h", dims="test", transform=transforms.log), + ) diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index a1adf3bd9..16c184693 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -1,4 +1,4 @@ -from typing import TypedDict, Optional, Union, Tuple +from typing import TypedDict, Optional, Union, Tuple, Sequence, Dict import aeppl.transforms @@ -8,7 +8,7 @@ class ParamCfg(TypedDict): dims: Optional[Union[str, Tuple[str]]] -def arg_to_param_cfg( +def _arg_to_param_cfg( key, value: Optional[Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]] = None ): if value is None: @@ -25,3 +25,14 @@ def arg_to_param_cfg( cfg.setdefault("transform", None) cfg.setdefault("dims", None) return cfg + + +def _parse_args( + var_names: Sequence[str], **kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple] +) -> Dict[str, ParamCfg]: + results = dict() + for var in var_names: + results[var] = _arg_to_param_cfg(var) + for key, val in kwargs.items(): + results[key] = _arg_to_param_cfg(key, val) + return results From 661eacb56ce2294a0624c560cb230e6cee3472c9 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 30 Jun 2022 13:31:44 +0000 Subject: [PATCH 08/23] add flatten util --- .../tests/test_prior_from_trace.py | 23 +++++++++-- pymc_experimental/utils/prior.py | 40 ++++++++++++++++++- 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index 9a5565bc9..0f6c11519 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -34,7 +34,7 @@ def coords(): @pytest.fixture def param_cfg(): return dict( - a=pmx.utils.prior._arg_to_param_cfg("a"), + a=pmx.utils.prior._arg_to_param_cfg("d"), b=pmx.utils.prior._arg_to_param_cfg("b", dict(transform=transforms.log, dims=("test",))), c=pmx.utils.prior._arg_to_param_cfg( "c", dict(transform=transforms.simplex, dims=("simplex",)) @@ -43,7 +43,7 @@ def param_cfg(): @pytest.fixture -def idata(param_cfg, coords): +def transformed_data(param_cfg, coords): vars = dict() for k, cfg in param_cfg.items(): if cfg["dims"] is not None: @@ -51,13 +51,22 @@ def idata(param_cfg, coords): else: extra_dims = [] orig = np.random.randn(4, 100, *extra_dims) + vars[k] = orig + return vars + + +@pytest.fixture +def idata(transformed_data, param_cfg): + vars = dict() + for k, orig in transformed_data.items(): + cfg = param_cfg[k] if cfg["transform"] is not None: var = cfg["transform"].backward(orig).eval() else: var = orig assert not np.isnan(var).any() vars[k] = var - return az.convert_to_inference_data(vars, coords=coords) + return az.convert_to_inference_data(vars) def test_idata_for_tests(idata, param_cfg): @@ -83,3 +92,11 @@ def test_args_compose(): f=dict(name="f", dims="test", transform=None), g=dict(name="h", dims="test", transform=transforms.log), ) + + +def test_transform_idata(transformed_data, idata, param_cfg): + flat_info = pmx.utils.prior._flatten(idata, **param_cfg) + expected_shape = 0 + for v in transformed_data.values(): + expected_shape += int(np.prod(v.shape[2:])) + assert flat_info["data"].shape[1] == expected_shape diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index 16c184693..687bb618b 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -1,5 +1,7 @@ -from typing import TypedDict, Optional, Union, Tuple, Sequence, Dict +from typing import TypedDict, Optional, Union, Tuple, Sequence, Dict, List import aeppl.transforms +import arviz +import numpy as np class ParamCfg(TypedDict): @@ -8,6 +10,22 @@ class ParamCfg(TypedDict): dims: Optional[Union[str, Tuple[str]]] +class ShapeInfo(TypedDict): + # shape might not match slice due to a transform + shape: Tuple[int] + slice: slice + + +class VarInfo(TypedDict): + sinfo: ShapeInfo + vinfo: ParamCfg + + +class FlatInfo(TypedDict): + data: np.ndarray + info: List[VarInfo] + + def _arg_to_param_cfg( key, value: Optional[Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]] = None ): @@ -36,3 +54,23 @@ def _parse_args( for key, val in kwargs.items(): results[key] = _arg_to_param_cfg(key, val) return results + + +def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo: + posterior = idata.posterior + vars = list() + info = list() + begin = 0 + for key, cfg in kwargs.items(): + data = posterior[key].values + # omitting chain, draw + shape = data.shape[2:] + if cfg["transform"] is not None: + data = cfg["transform"].forward(data).eval() + data = data.reshape(*data.shape[:2], -1) + data = data.reshape(-1, data.shape[2]) + end = begin + data.shape[1] + vars.append(data) + info.append(dict(shape=shape, slice=slice(begin, end))) + begin = end + return dict(data=np.concatenate(vars, axis=-1), infp=info) From 58ff72ad0db8f31b0ecc91a3f744efbbb5a6db3d Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 1 Jul 2022 13:24:50 +0000 Subject: [PATCH 09/23] fix typo --- pymc_experimental/utils/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index 687bb618b..1a57d2dee 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -73,4 +73,4 @@ def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo: vars.append(data) info.append(dict(shape=shape, slice=slice(begin, end))) begin = end - return dict(data=np.concatenate(vars, axis=-1), infp=info) + return dict(data=np.concatenate(vars, axis=-1), info=info) From 2ba7a633d27eb72aa0b382b54091518233e20267 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 1 Jul 2022 14:59:22 +0000 Subject: [PATCH 10/23] refactor flattening --- .../tests/test_prior_from_trace.py | 1 + pymc_experimental/utils/prior.py | 21 ++++++++++++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index 0f6c11519..20af29dd7 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -100,3 +100,4 @@ def test_transform_idata(transformed_data, idata, param_cfg): for v in transformed_data.values(): expected_shape += int(np.prod(v.shape[2:])) assert flat_info["data"].shape[1] == expected_shape + assert len(flat_info["info"]) == len(param_cfg) diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index 1a57d2dee..91cbbb22c 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -62,13 +62,24 @@ def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo: info = list() begin = 0 for key, cfg in kwargs.items(): - data = posterior[key].values - # omitting chain, draw - shape = data.shape[2:] + data = ( + posterior[key] + # combine all draws from all chains + .stack(__sample__=["chain", "draw"]) + # move sample dim to the first position + # no matter where it was before + .transpose("__sample__", ...) + # we need numpy data for all the rest functionality + .values + ) + # omitting __sample__ + # we need shape in the untransformed space + shape = data.shape[1:] if cfg["transform"] is not None: + # some transforms need original shape data = cfg["transform"].forward(data).eval() - data = data.reshape(*data.shape[:2], -1) - data = data.reshape(-1, data.shape[2]) + # now we can get rid of shape + data = data.reshape(data.shape[0], -1) end = begin + data.shape[1] vars.append(data) info.append(dict(shape=shape, slice=slice(begin, end))) From 7b09ff62d46098db78834012017240af6a612c4f Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 1 Jul 2022 15:04:22 +0000 Subject: [PATCH 11/23] add mean chol --- pymc_experimental/tests/test_prior_from_trace.py | 11 +++++++++++ pymc_experimental/utils/prior.py | 7 +++++++ 2 files changed, 18 insertions(+) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index 20af29dd7..eabda2370 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -101,3 +101,14 @@ def test_transform_idata(transformed_data, idata, param_cfg): expected_shape += int(np.prod(v.shape[2:])) assert flat_info["data"].shape[1] == expected_shape assert len(flat_info["info"]) == len(param_cfg) + + +@pytest.fixture +def flat_info(idata, param_cfg): + return pmx.utils.prior._flatten(idata, **param_cfg) + + +def test_mean_chol(flat_info): + mean, chol = pmx.utils.prior._mean_chol(flat_info["data"]) + assert mean.shape == (flat_info["data"].shape[1],) + assert chol.shape == (flat_info["data"].shape[1],) * 2 diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index 91cbbb22c..f742a3d40 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -85,3 +85,10 @@ def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo: info.append(dict(shape=shape, slice=slice(begin, end))) begin = end return dict(data=np.concatenate(vars, axis=-1), info=info) + + +def _mean_chol(flat_array: np.ndarray): + mean = flat_array.mean(0) + cov = np.cov(flat_array, rowvar=False) + chol = np.linalg.cholesky(cov) + return mean, chol From bb5cfd1d7fd207c80a0d4a917f2b6baaad181f75 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 1 Jul 2022 15:27:04 +0000 Subject: [PATCH 12/23] add test for mvn_prior --- .../tests/test_prior_from_trace.py | 10 +++++++ pymc_experimental/utils/prior.py | 29 +++++++++++++++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index eabda2370..d3dbd8ccd 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -3,6 +3,7 @@ import pytest import arviz as az import numpy as np +import pymc as pm @pytest.mark.parametrize( @@ -101,6 +102,8 @@ def test_transform_idata(transformed_data, idata, param_cfg): expected_shape += int(np.prod(v.shape[2:])) assert flat_info["data"].shape[1] == expected_shape assert len(flat_info["info"]) == len(param_cfg) + assert "sinfo" in param_cfg["info"][0] + assert "vinfo" in param_cfg["info"][0] @pytest.fixture @@ -112,3 +115,10 @@ def test_mean_chol(flat_info): mean, chol = pmx.utils.prior._mean_chol(flat_info["data"]) assert mean.shape == (flat_info["data"].shape[1],) assert chol.shape == (flat_info["data"].shape[1],) * 2 + + +def test_mvn_prior_from_flat_info(flat_info, coords, param_cfg): + with pm.Model(coords=coords) as model: + priors = pmx.utils.prior._mvn_prior_from_flat_info("trace_prior_", flat_info) + names = [p["name"] for p in param_cfg.values()] + assert set(model.named_vars) == {"trace_prior_", *names} diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index f742a3d40..2acf9c5ce 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -1,6 +1,7 @@ from typing import TypedDict, Optional, Union, Tuple, Sequence, Dict, List import aeppl.transforms import arviz +import pymc as pm import numpy as np @@ -12,7 +13,8 @@ class ParamCfg(TypedDict): class ShapeInfo(TypedDict): # shape might not match slice due to a transform - shape: Tuple[int] + shape_u: Tuple[int] # untransformed shape + shape_t: Tuple[int] # transformed shape slice: slice @@ -74,15 +76,19 @@ def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo: ) # omitting __sample__ # we need shape in the untransformed space - shape = data.shape[1:] + shape_u = data.shape[1:] if cfg["transform"] is not None: # some transforms need original shape data = cfg["transform"].forward(data).eval() + shape_t = data.shape[1:] + else: + shape_t = shape_u # now we can get rid of shape data = data.reshape(data.shape[0], -1) end = begin + data.shape[1] vars.append(data) - info.append(dict(shape=shape, slice=slice(begin, end))) + sinfo = dict(shape_t=shape_t, shape_u=shape_u, slice=slice(begin, end)) + info.append(dict(sinfo=sinfo, vinfo=cfg)) begin = end return dict(data=np.concatenate(vars, axis=-1), info=info) @@ -92,3 +98,20 @@ def _mean_chol(flat_array: np.ndarray): cov = np.cov(flat_array, rowvar=False) chol = np.linalg.cholesky(cov) return mean, chol + + +def _mvn_prior_from_flat_info(name, flat_info: FlatInfo): + mean, chol = _mean_chol(flat_info["data"]) + base_dist = pm.Normal(name, np.zeros_like(mean)) + interim = mean + chol @ base_dist + result = dict() + for var_info in flat_info["info"]: + sinfo = var_info["sinfo"] + vinfo = var_info["vinfo"] + var = interim[sinfo["slice"]].reshape(sinfo["shape_t"]) + if vinfo["transform"] is not None: + var = vinfo["transform"].backward(var) + var = var.reshape(sinfo["shape_u"]) + var = pm.Deterministic(vinfo["name"], var, dims=vinfo["dims"]) + result[vinfo["name"]] = var + return result From 1eeb82d408d4dc0831198659dc00d3e9443903b4 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 1 Jul 2022 15:45:48 +0000 Subject: [PATCH 13/23] test final api --- .../tests/test_prior_from_trace.py | 31 ++++++++++++++----- pymc_experimental/utils/prior.py | 6 ++++ 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index d3dbd8ccd..7d05fbc8d 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -33,22 +33,28 @@ def coords(): @pytest.fixture -def param_cfg(): - return dict( - a=pmx.utils.prior._arg_to_param_cfg("d"), - b=pmx.utils.prior._arg_to_param_cfg("b", dict(transform=transforms.log, dims=("test",))), - c=pmx.utils.prior._arg_to_param_cfg( - "c", dict(transform=transforms.simplex, dims=("simplex",)) - ), +def user_param_cfg(): + return (), dict( + a="d", + b=dict(transform=transforms.log, dims=("test",)), + c=dict(transform=transforms.simplex, dims=("simplex",)), ) +@pytest.fixture +def param_cfg(user_param_cfg): + return pmx.utils.prior._parse_args(user_param_cfg[0], **user_param_cfg[1]) + + @pytest.fixture def transformed_data(param_cfg, coords): vars = dict() for k, cfg in param_cfg.items(): if cfg["dims"] is not None: extra_dims = [len(coords[d]) for d in cfg["dims"]] + if cfg["transform"] is not None: + t = np.random.randn(*extra_dims) + extra_dims = tuple(cfg["transform"].forward(t).shape.eval()) else: extra_dims = [] orig = np.random.randn(4, 100, *extra_dims) @@ -120,5 +126,16 @@ def test_mean_chol(flat_info): def test_mvn_prior_from_flat_info(flat_info, coords, param_cfg): with pm.Model(coords=coords) as model: priors = pmx.utils.prior._mvn_prior_from_flat_info("trace_prior_", flat_info) + test_prior = pm.sample_prior_predictive(1) + names = [p["name"] for p in param_cfg.values()] + assert set(model.named_vars) == {"trace_prior_", *names} + + +def test_prior_from_idata(idata, user_param_cfg, coords): + with pm.Model(coords=coords) as model: + priors = pmx.utils.prior.prior_from_idata( + idata, var_names=user_param_cfg[0], **user_param_cfg[1] + ) + test_prior = pm.sample_prior_predictive(1) names = [p["name"] for p in param_cfg.values()] assert set(model.named_vars) == {"trace_prior_", *names} diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index 2acf9c5ce..dfd80e90d 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -115,3 +115,9 @@ def _mvn_prior_from_flat_info(name, flat_info: FlatInfo): var = pm.Deterministic(vinfo["name"], var, dims=vinfo["dims"]) result[vinfo["name"]] = var return result + + +def prior_from_idata(idata, name="trace_prior_", *, var_names: Sequence[str] = (), **kwargs): + param_cfg = _parse_args(var_names=var_names, **kwargs) + flat_info = _flatten(idata, **param_cfg) + return _mvn_prior_from_flat_info(name, flat_info) From 83beb1d20e592a251f078ee57c31baeb8f6f8405 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 1 Jul 2022 15:46:39 +0000 Subject: [PATCH 14/23] add additional argument --- pymc_experimental/tests/test_prior_from_trace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index 7d05fbc8d..78fc0a4d2 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -34,7 +34,7 @@ def coords(): @pytest.fixture def user_param_cfg(): - return (), dict( + return ("t", ), dict( a="d", b=dict(transform=transforms.log, dims=("test",)), c=dict(transform=transforms.simplex, dims=("simplex",)), From 0c2a3a75fa123359b4e196b1bd46f5df9bf8c455 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 1 Jul 2022 15:50:10 +0000 Subject: [PATCH 15/23] add type hints --- pymc_experimental/tests/test_prior_from_trace.py | 2 +- pymc_experimental/utils/prior.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index 78fc0a4d2..32e328854 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -34,7 +34,7 @@ def coords(): @pytest.fixture def user_param_cfg(): - return ("t", ), dict( + return ("t",), dict( a="d", b=dict(transform=transforms.log, dims=("test",)), c=dict(transform=transforms.simplex, dims=("simplex",)), diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index dfd80e90d..a0a163783 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -2,6 +2,7 @@ import aeppl.transforms import arviz import pymc as pm +import aesara.tensor as at import numpy as np @@ -117,7 +118,13 @@ def _mvn_prior_from_flat_info(name, flat_info: FlatInfo): return result -def prior_from_idata(idata, name="trace_prior_", *, var_names: Sequence[str] = (), **kwargs): +def prior_from_idata( + idata, + name="trace_prior_", + *, + var_names: Sequence[str], + **kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple] +) -> Dict[str, at.TensorVariable]: param_cfg = _parse_args(var_names=var_names, **kwargs) flat_info = _flatten(idata, **param_cfg) return _mvn_prior_from_flat_info(name, flat_info) From a2e0db21ab6f94770d118acbdf182da3e1e7466e Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 1 Jul 2022 18:11:44 +0000 Subject: [PATCH 16/23] fix tests --- pymc_experimental/tests/test_prior_from_trace.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc_experimental/tests/test_prior_from_trace.py b/pymc_experimental/tests/test_prior_from_trace.py index 32e328854..34f907152 100644 --- a/pymc_experimental/tests/test_prior_from_trace.py +++ b/pymc_experimental/tests/test_prior_from_trace.py @@ -108,8 +108,8 @@ def test_transform_idata(transformed_data, idata, param_cfg): expected_shape += int(np.prod(v.shape[2:])) assert flat_info["data"].shape[1] == expected_shape assert len(flat_info["info"]) == len(param_cfg) - assert "sinfo" in param_cfg["info"][0] - assert "vinfo" in param_cfg["info"][0] + assert "sinfo" in flat_info["info"][0] + assert "vinfo" in flat_info["info"][0] @pytest.fixture @@ -131,7 +131,7 @@ def test_mvn_prior_from_flat_info(flat_info, coords, param_cfg): assert set(model.named_vars) == {"trace_prior_", *names} -def test_prior_from_idata(idata, user_param_cfg, coords): +def test_prior_from_idata(idata, user_param_cfg, coords, param_cfg): with pm.Model(coords=coords) as model: priors = pmx.utils.prior.prior_from_idata( idata, var_names=user_param_cfg[0], **user_param_cfg[1] From c10e88075eabff40d5eb03fcb97ef63ce87ea203 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 1 Jul 2022 18:35:08 +0000 Subject: [PATCH 17/23] add a docstring --- pymc_experimental/utils/prior.py | 49 +++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index a0a163783..2c85359d8 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -119,12 +119,59 @@ def _mvn_prior_from_flat_info(name, flat_info: FlatInfo): def prior_from_idata( - idata, + idata: arviz.InferenceData, name="trace_prior_", *, var_names: Sequence[str], **kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple] ) -> Dict[str, at.TensorVariable]: + """ + Create a prior from posterior. + + Parameters + ---------- + idata: arviz.InferenceData + Inference data with posterior group + var_names: Sequence[str] + names of variables to take as is from the posterior + kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple] + names of variables with additional configuration, see more in Examples + + Examples + -------- + >>> import pymc as pm + >>> import pymc.distributions.transforms as transforms + >>> import numpy as np + >>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model1: + ... a = pm.Normal("a") + ... b = pm.Normal("b", dims="test") + ... c = pm.HalfNormal("c") + ... d = pm.Normal("d") + ... e = pm.Normal("e") + ... f = pm.Dirichlet("f", np.ones(3), dims="options") + ... trace = pm.sample(progressbar=False) + + You can reuse the posterior in the new model. + + >>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model2: + ... priors = prior_from_idata( + ... trace, # the old trace (posterior) + ... var_names=["a", "d"], # take variables as is + ... + ... e="new_e", # assign new name "new_e" for a variable + ... # similar to dict(name="new_e") + ... + ... b=("test", ), # set a coord to "test" + ... # similar to dict(dims=("test", )) + ... + ... c=transforms.log, # apply log transform to a positive variable + ... # similar to dict(transform=transforms.log) + ... + ... # set a name, assign a coord and apply simplex transform + ... f=dict(name="new_f", dims="options", transform=transforms.simplex) + ... ) + ... trace1 = pm.sample_prior_predictive(100) + """ param_cfg = _parse_args(var_names=var_names, **kwargs) flat_info = _flatten(idata, **param_cfg) return _mvn_prior_from_flat_info(name, flat_info) From 71eb6b7d9f436af24b9acc4b91f0d856dc2f93a9 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Fri, 1 Jul 2022 18:36:14 +0000 Subject: [PATCH 18/23] add to docs --- docs/api_reference.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 3020e3d01..9763d525f 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -29,3 +29,6 @@ methods in the current release of PyMC experimental. .. automodule:: pymc_experimental.utils.spline :members: bspline_interpolation +.. automodule:: pymc_experimental.utils.prior + :members: prior_from_idata + From f9b563280feb726c64372ccc9e0db5d85ef075df Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sat, 2 Jul 2022 19:08:03 +0000 Subject: [PATCH 19/23] simplify implementation --- pymc_experimental/utils/prior.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index 2c85359d8..a879c08ea 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -14,8 +14,7 @@ class ParamCfg(TypedDict): class ShapeInfo(TypedDict): # shape might not match slice due to a transform - shape_u: Tuple[int] # untransformed shape - shape_t: Tuple[int] # transformed shape + shape: Tuple[int] # transformed shape slice: slice @@ -77,18 +76,15 @@ def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo: ) # omitting __sample__ # we need shape in the untransformed space - shape_u = data.shape[1:] if cfg["transform"] is not None: # some transforms need original shape data = cfg["transform"].forward(data).eval() - shape_t = data.shape[1:] - else: - shape_t = shape_u + shape = data.shape[1:] # now we can get rid of shape data = data.reshape(data.shape[0], -1) end = begin + data.shape[1] vars.append(data) - sinfo = dict(shape_t=shape_t, shape_u=shape_u, slice=slice(begin, end)) + sinfo = dict(shape=shape, slice=slice(begin, end)) info.append(dict(sinfo=sinfo, vinfo=cfg)) begin = end return dict(data=np.concatenate(vars, axis=-1), info=info) @@ -109,10 +105,9 @@ def _mvn_prior_from_flat_info(name, flat_info: FlatInfo): for var_info in flat_info["info"]: sinfo = var_info["sinfo"] vinfo = var_info["vinfo"] - var = interim[sinfo["slice"]].reshape(sinfo["shape_t"]) + var = interim[sinfo["slice"]].reshape(sinfo["shape"]) if vinfo["transform"] is not None: var = vinfo["transform"].backward(var) - var = var.reshape(sinfo["shape_u"]) var = pm.Deterministic(vinfo["name"], var, dims=vinfo["dims"]) result[vinfo["name"]] = var return result From 88ee0a9db7996328be0f93ba269a92aadfc85f13 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Tue, 5 Jul 2022 20:22:44 +0300 Subject: [PATCH 20/23] Update pymc_experimental/utils/prior.py Co-authored-by: Oriol Abril-Pla --- pymc_experimental/utils/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index a879c08ea..ebea97849 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -156,7 +156,7 @@ def prior_from_idata( ... e="new_e", # assign new name "new_e" for a variable ... # similar to dict(name="new_e") ... - ... b=("test", ), # set a coord to "test" + ... b=("test", ), # set a dim to "test" ... # similar to dict(dims=("test", )) ... ... c=transforms.log, # apply log transform to a positive variable From 678c84904d93522ee834de0a9f2a716814b82b36 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Tue, 5 Jul 2022 20:22:57 +0300 Subject: [PATCH 21/23] Update pymc_experimental/utils/prior.py Co-authored-by: Oriol Abril-Pla --- pymc_experimental/utils/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index ebea97849..a91ada940 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -162,7 +162,7 @@ def prior_from_idata( ... c=transforms.log, # apply log transform to a positive variable ... # similar to dict(transform=transforms.log) ... - ... # set a name, assign a coord and apply simplex transform + ... # set a name, assign a dim and apply simplex transform ... f=dict(name="new_f", dims="options", transform=transforms.simplex) ... ) ... trace1 = pm.sample_prior_predictive(100) From c2baf4c2e47b247e35e5ff53b4facf62783520b6 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Tue, 5 Jul 2022 17:30:57 +0000 Subject: [PATCH 22/23] update the docstring --- pymc_experimental/utils/prior.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index a91ada940..4ab26601d 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -121,7 +121,16 @@ def prior_from_idata( **kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple] ) -> Dict[str, at.TensorVariable]: """ - Create a prior from posterior. + Create a prior from posterior using MvNormal approximation. + + The approximation uses MvNormal distribution. + Keep in mind that this function will only work well for unimodal + posteriors and will fail when complicated interactions happen. + + Moreover, if a retrieved variable is constrained, you + should specify the transform for the variable, e.g. + ``pymc.distributions.transforms.log`` for standard + deviation posterior. Parameters ---------- From 27ec67b4fa28856a0c8ec940c59004238db400d1 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Tue, 5 Jul 2022 17:31:36 +0000 Subject: [PATCH 23/23] update the docstring --- pymc_experimental/utils/prior.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index 4ab26601d..ffd243a40 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -128,7 +128,7 @@ def prior_from_idata( posteriors and will fail when complicated interactions happen. Moreover, if a retrieved variable is constrained, you - should specify the transform for the variable, e.g. + should specify a transform for the variable, e.g. ``pymc.distributions.transforms.log`` for standard deviation posterior.