diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 83837a3..8985237 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,7 +24,7 @@ repos: - id: pyupgrade args: [--py37-plus] - repo: https://github.com/psf/black - rev: 21.11b1 + rev: 22.3.0 hooks: - id: black - repo: https://github.com/PyCQA/pylint diff --git a/preclinpack/nested_hierarchy_rvs.py b/preclinpack/nested_hierarchy_rvs.py index b2161fe..e2b9a15 100644 --- a/preclinpack/nested_hierarchy_rvs.py +++ b/preclinpack/nested_hierarchy_rvs.py @@ -2,10 +2,18 @@ try: import pymc as pm + + from aesara import tensor as at + + _v4 = True except ImportError: import pymc3 as pm -from .utils import CenteredNormal, dims_to_tuple + from theano import tensor as at + + _v4 = False + +from .utils import CenteredNormal, dims_to_tuple, read_parametrization, shape_from_dims def make_hierarchical_weights( @@ -98,7 +106,7 @@ def make_hierarchical_coefs( ) coef_dims = dims_to_tuple(coef_dims) leaf_rvs = {} - stack = [(0, None, None, hierarchy)] + stack = [(0, None, None, hierarchy, None)] scales = make_hierarchical_weights( hierarchy=hierarchy, name_prefix=name_prefix, @@ -107,59 +115,178 @@ def make_hierarchical_coefs( ) unused_scales = set(scales.keys()) while stack: - level, dim, parent, node = stack.pop(0) + level, dim, parent, node, parent_pop_mean = stack.pop(0) if isinstance(node, dict): if parent is None: - rv = pm.Normal(f"{name_prefix}prior_mu", prior_mu, prior_mu_sigma, dims=coef_dims) + rv_name = f"{name_prefix}prior_mu" + rv = None for i, (child_dim, child) in enumerate(node.items()): - stack.append((level + 1, child_dim, rv, child)) - else: - # We can't use a CenteredNormal if dim size is 1 - (shape,) = model.shape_from_dims((dim,)) - safe_param = "non-centered" if shape > 1 else "centered" - if isinstance(parametrization, dict): - if parametrized_by_level: - param = parametrization.get(level, safe_param) + if isinstance(child, dict): + # If the next level follows a non-centered parametrization, + # we need to add the population mean RV to the current level + child_param = read_parametrization( + parametrization, + child_dim, + level + 1, + model, + parametrized_by_level=parametrized_by_level, + ) + if rv is None: + if child_param == "non-centered": + rv = pm.Normal( + f"{rv_name}_sample_mean", + mu=prior_mu, + sigma=at.sqrt( + prior_mu_sigma**2 + 1 / shape_from_dims(model, child_dim) + ), + dims=coef_dims, + ) + rv_pop_mean = pm.Normal( + rv_name, mu=prior_mu, sigma=prior_mu_sigma, dims=coef_dims + ) + print( + f"Non centered param, level={level}, sigma={rv.distribution.sigma.eval()}" + ) + else: + rv = pm.Normal( + rv_name, + mu=prior_mu, + sigma=prior_mu_sigma, + dims=coef_dims, + ) + print( + f"Centered param, level={level}, sigma={rv.distribution.sigma.eval()}" + ) + rv_pop_mean = rv else: - param = parametrization.get(dim, safe_param) - elif parametrization is None: - param = safe_param - else: - param = parametrization + if rv is None: + rv = pm.Normal(rv_name, prior_mu, prior_mu_sigma, dims=coef_dims) + rv_pop_mean = rv + stack.append((level + 1, child_dim, rv, child, rv_pop_mean)) + else: + param = read_parametrization( + parametrization, dim, level, model, parametrized_by_level=parametrized_by_level + ) - if param == "non-centered" and safe_param == "centered": + sd = scales[dim] + unused_scales.discard(dim) + if param not in ["centered", "non-centered"]: raise ValueError( - f"non-centered parametrization cannot be used for dim {dim} " - f"because it's shape {shape} is <2" + f"parameterization must be one of [centered, non-centered], but got {param}" ) + # When the current node has child nodes that should create non-centered RV's, we need to modify + # the current node's sigma. This means that we need to + # 1. Look at the child nodes and their parametrizations to signal whether we need to + # modify the current node's RV sigma or not + # 2. Create the current node's RV + # 3. Trace the current node's children and add them to the stack. - sd = scales[dim] - unused_scales.discard(dim) + # We first go through the child nodes to peek whether they require an RV or whether they use a + # non-centered parametrization + child_requires_rv = False + child_modifies_sigma = [] + child_size = [] + for i, (child_dim, child) in enumerate(node.items()): + if isinstance(child, dict): + child_requires_rv = True + child_param = read_parametrization( + parametrization, + child_dim, + level + 1, + model, + parametrized_by_level=parametrized_by_level, + ) + if child_param == "non-centered": + child_modifies_sigma.append(True) + else: + child_modifies_sigma.append(False) + child_size.append(shape_from_dims(model, child_dim)) + else: + child_modifies_sigma.append(False) + child_size.append(0) + # Now we create the current node's RV if param == "centered": - rv = pm.Normal( - f"{name_prefix}coefs_{dim}", - parent, - sd, - dims=(dim,) + coef_dims, - ) - elif param == "non-centered": - raw = CenteredNormal( - f"{name_prefix}coefs_{dim}_raw", - 1, - axis=0, - dims=(dim,) + coef_dims, + _sd = at.stack( + [ + at.sqrt(sd**2 + 1 / _size) if modifies_sigma else sd + for modifies_sigma, _size in zip(child_modifies_sigma, child_size) + ] ) - rv = pm.Deterministic( + print(f"Centered param, level={level}, sigma={_sd.eval()}") + rv = pm.Normal( f"{name_prefix}coefs_{dim}", - parent + raw * sd, + mu=parent, + sigma=_sd, dims=(dim,) + coef_dims, ) + rv_pop_mean = rv else: - raise ValueError( - f"parameterization must be one of [centered, non-centered], but got {param}" + _sd = at.stack( + [ + at.sqrt(1 + 1 / _size / sd**2) if modifies_sigma else 1.0 + for modifies_sigma, _size in zip(child_modifies_sigma, child_size) + ] ) + print(f"Non centered param, level={level}, sigma={_sd.eval()}") + if child_requires_rv: + raw = CenteredNormal( + f"{name_prefix}coefs_{dim}_sample_mean_raw", + _sd, + axis=0, + dims=(dim,) + coef_dims, + ) + rv = pm.Deterministic( + f"{name_prefix}coefs_{dim}_sample_mean", + parent + raw * sd, + dims=(dim,) + coef_dims, + ) + raw_pop_mean = pm.Normal( + f"{name_prefix}coefs_{dim}_raw", + mu=0, + sigma=1, + dims=(dim,) + coef_dims, + ) + rv_pop_mean = pm.Deterministic( + f"{name_prefix}coefs_{dim}", + parent + raw_pop_mean * sd, + dims=(dim,) + coef_dims, + ) + if _v4: + pm.Potential( + f"{name_prefix}coefs_{dim}_eps_logp", + pm.logp( + pm.Normal.dist( + mu=raw + (parent - parent_pop_mean) / sd, + sigma=1, + ), + 0, + ), + ) + else: + pm.Potential( + f"{name_prefix}coefs_{dim}_eps_logp", + pm.Normal.dist( + mu=raw + (parent - parent_pop_mean) / sd, + sigma=1, + ).logp(0), + ) + else: + raw = CenteredNormal( + f"{name_prefix}coefs_{dim}_raw", + 1, + axis=0, + dims=(dim,) + coef_dims, + ) + rv = pm.Deterministic( + f"{name_prefix}coefs_{dim}", + parent + raw * sd, + dims=(dim,) + coef_dims, + ) + rv_pop_mean = rv + + # Now we iterate through the children to add them to the stack for i, (child_dim, child) in enumerate(node.items()): - stack.append((level + 1, child_dim, rv[i], child)) + stack.append((level + 1, child_dim, rv[i], child, rv_pop_mean[i])) else: # TODO: We shouldn't be creating these scales as we are ignoring them diff --git a/preclinpack/tests/test_nested_hierarchy_rvs.py b/preclinpack/tests/test_nested_hierarchy_rvs.py index df43de8..a0ea5df 100644 --- a/preclinpack/tests/test_nested_hierarchy_rvs.py +++ b/preclinpack/tests/test_nested_hierarchy_rvs.py @@ -2,13 +2,17 @@ from random import choice import numpy as np +import pandas as pd import pytest from preclinpack.nested_hierarchy_rvs import ( make_hierarchical_coefs, make_hierarchical_weights, ) -from preclinpack.nested_hierarchy_utils import slice_data_by_nested_hierarchies +from preclinpack.nested_hierarchy_utils import ( + slice_data_by_nested_hierarchies, + sort_data, +) from preclinpack.tests.utils import ( get_keys_nested_dict, test_dataset, @@ -17,6 +21,47 @@ from preclinpack.utils import pm +@pytest.fixture(scope="module", params=[1, 2]) +def level_datasets(request): + if request.param == 1: + nested_hierarchies = ["level1"] + level1 = np.tile(np.array(["A", "B"]), (1000, 2)) + means = np.tile(np.array([1.0, -1.0]), (1000, 2)) + rng = np.random.default_rng() + obs = rng.normal(loc=means, scale=0.1) + dataset = pd.DataFrame( + data=np.stack( + [ + level1.flatten(), + obs.flatten(), + ], + axis=1, + ), + columns=["level1", "feature"], + ) + dataset = sort_data(dataset, ["level1"], []) + else: + nested_hierarchies = ["level1", "level2"] + level1 = np.tile(np.array([["A", "A", "A"], ["B", "B", "B"]]), (1000, 2, 3)) + level2 = np.tile(np.array([["AA", "AB", "AC"], ["BA", "BB", "BC"]]), (1000, 2, 3)) + means = np.tile(np.array([[1.5, 1.0, 0.5], [-0.5, -1.0, -1.5]]), (1000, 2, 3)) + rng = np.random.default_rng() + obs = rng.normal(loc=means, scale=0.1) + dataset = pd.DataFrame( + data=np.stack( + [ + level1.flatten(), + level2.flatten(), + obs.flatten(), + ], + axis=1, + ), + columns=["level1", "level2", "feature"], + ) + dataset = sort_data(dataset, ["level1", "level2"], []) + return nested_hierarchies, dataset + + class TestMakeHierarchicalWeights: @classmethod def setup_class(cls): @@ -430,3 +475,44 @@ def test_unsafe_parametrization(self): coef_dims="behavior", parametrization="non-centered", ) + + @pytest.mark.usefixtures("level_datasets") + def test_population_means(self, level_datasets): + nested_hierarchies, dataset = level_datasets + hierarchy_coords, hierarchy = slice_data_by_nested_hierarchies( + dataset, + nested_hierarchies, + return_slices=True, + ) + with pm.Model(coords=hierarchy_coords) as model1: + make_hierarchical_coefs( + hierarchy=hierarchy, + name_prefix="test", + parametrization="non-centered", + level_scale_kwargs={ + 1: {"learnable": False, "value": 1.0}, + 2: {"learnable": False, "value": 1.0}, + }, + ) + idata1 = pm.sample(chains=1, cores=1, tune=1000, draws=1000, return_inferencedata=True) + with pm.Model(coords=hierarchy_coords) as model2: + make_hierarchical_coefs( + hierarchy=hierarchy, + name_prefix="test", + parametrization="centered", + level_scale_kwargs={ + 1: {"learnable": False, "value": 1.0}, + 2: {"learnable": False, "value": 1.0}, + }, + ) + idata2 = pm.sample(chains=1, cores=1, tune=1000, draws=1000, return_inferencedata=True) + import arviz as az + + var_names = [ + k + for k in idata1.posterior.data_vars + if not (k.endswith("_raw") or k.endswith("truncated_") or "sample_mean" in k) + ] + print(var_names) + print(az.summary(idata1, var_names=var_names)) + print(az.summary(idata2, var_names=var_names)) diff --git a/preclinpack/tests/test_utils.py b/preclinpack/tests/test_utils.py index c3ef259..c4cad86 100644 --- a/preclinpack/tests/test_utils.py +++ b/preclinpack/tests/test_utils.py @@ -13,6 +13,8 @@ compute_scalar_log_likelihood, number_free_parameters, number_observed_datapoints, + read_parametrization, + shape_from_dims, ) @@ -175,3 +177,62 @@ def test_potential_warning(self): log_like = compute_scalar_log_likelihood(model, idata) assert np.all(np.isclose(log_like, st.norm(x_values, 1).logpdf(0))) + + +def test_shape_from_dims(): + N = 4 + with pm.Model(coords={"test": range(N)}) as model: + assert shape_from_dims(model, "test") == N + + +@pytest.mark.parametrize( + ["parametrization", "parametrized_by_level"], + [ + [ + {"A": "centered", "B": "non-centered", "C": "centered"}, + False, + ], + [ + {0: "centered", 1: "non-centered"}, + True, + ], + [ + "centered", + False, + ], + [ + "non-centered", + False, + ], + [ + None, + False, + ], + ], +) +@pytest.mark.parametrize(["level", "dim"], [[0, "A"], [1, "B"], [1, "C"]]) +def test_read_parametrization(parametrization, parametrized_by_level, level, dim): + with pm.Model(coords={"A": range(1), "B": range(3), "C": range(1)}) as model: + pass + if isinstance(parametrization, str): + expected_param = parametrization + elif isinstance(parametrization, dict): + expected_param = parametrization[level if parametrized_by_level else dim] + else: + expected_param = ( + "centered" + if (parametrized_by_level and level == 0) + or (not parametrized_by_level and dim in ["A", "C"]) + else "non-centered" + ) + if expected_param == "non-centered" and (level == 0 or dim == "C"): + match = f"non-centered parametrization cannot be used for {'level' if parametrized_by_level else 'dim'}" + with pytest.raises(ValueError, match=match): + read_parametrization( + parametrization, dim, level, model, parametrized_by_level=parametrized_by_level + ) + else: + param = read_parametrization( + parametrization, dim, level, model, parametrized_by_level=parametrized_by_level + ) + assert param == expected_param diff --git a/preclinpack/utils.py b/preclinpack/utils.py index f685ecc..64ca43d 100644 --- a/preclinpack/utils.py +++ b/preclinpack/utils.py @@ -145,3 +145,33 @@ def compute_scalar_log_likelihood(model, idata): ) return log_like + + +def shape_from_dims(model, dim): + if _v4: + (shape,) = model.shape_from_dims((dim,)) + else: + (shape,) = np.asarray(model.coords[dim]).shape + return shape + + +def read_parametrization(parametrization, dim, level, model, parametrized_by_level=False): + shape = shape_from_dims(model, dim) + safe_param = "non-centered" if shape > 1 else "centered" + if isinstance(parametrization, dict): + if parametrized_by_level: + param = parametrization.get(level, safe_param) + else: + param = parametrization.get(dim, safe_param) + elif parametrization is None: + param = safe_param + else: + param = parametrization + + if param == "non-centered" and safe_param == "centered": + raise ValueError( + f"non-centered parametrization cannot be used for " + f"{'level {}'.format(level) if parametrized_by_level else 'dim {}'.format(dim)} " + f"because it's shape {shape} is <2" + ) + return param