Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:
- id: pyupgrade
args: [--py37-plus]
- repo: https://github.com/psf/black
rev: 21.11b1
rev: 22.3.0
hooks:
- id: black
- repo: https://github.com/PyCQA/pylint
Expand Down
205 changes: 166 additions & 39 deletions preclinpack/nested_hierarchy_rvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@

try:
import pymc as pm

from aesara import tensor as at

_v4 = True
except ImportError:
import pymc3 as pm

from .utils import CenteredNormal, dims_to_tuple
from theano import tensor as at

_v4 = False

from .utils import CenteredNormal, dims_to_tuple, read_parametrization, shape_from_dims


def make_hierarchical_weights(
Expand Down Expand Up @@ -98,7 +106,7 @@ def make_hierarchical_coefs(
)
coef_dims = dims_to_tuple(coef_dims)
leaf_rvs = {}
stack = [(0, None, None, hierarchy)]
stack = [(0, None, None, hierarchy, None)]
scales = make_hierarchical_weights(
hierarchy=hierarchy,
name_prefix=name_prefix,
Expand All @@ -107,59 +115,178 @@ def make_hierarchical_coefs(
)
unused_scales = set(scales.keys())
while stack:
level, dim, parent, node = stack.pop(0)
level, dim, parent, node, parent_pop_mean = stack.pop(0)
if isinstance(node, dict):
if parent is None:
rv = pm.Normal(f"{name_prefix}prior_mu", prior_mu, prior_mu_sigma, dims=coef_dims)
rv_name = f"{name_prefix}prior_mu"
rv = None
for i, (child_dim, child) in enumerate(node.items()):
stack.append((level + 1, child_dim, rv, child))
else:
# We can't use a CenteredNormal if dim size is 1
(shape,) = model.shape_from_dims((dim,))
safe_param = "non-centered" if shape > 1 else "centered"
if isinstance(parametrization, dict):
if parametrized_by_level:
param = parametrization.get(level, safe_param)
if isinstance(child, dict):
# If the next level follows a non-centered parametrization,
# we need to add the population mean RV to the current level
child_param = read_parametrization(
parametrization,
child_dim,
level + 1,
model,
parametrized_by_level=parametrized_by_level,
)
if rv is None:
if child_param == "non-centered":
rv = pm.Normal(
f"{rv_name}_sample_mean",
mu=prior_mu,
sigma=at.sqrt(
prior_mu_sigma**2 + 1 / shape_from_dims(model, child_dim)
),
dims=coef_dims,
)
rv_pop_mean = pm.Normal(
rv_name, mu=prior_mu, sigma=prior_mu_sigma, dims=coef_dims
)
print(
f"Non centered param, level={level}, sigma={rv.distribution.sigma.eval()}"
)
else:
rv = pm.Normal(
rv_name,
mu=prior_mu,
sigma=prior_mu_sigma,
dims=coef_dims,
)
print(
f"Centered param, level={level}, sigma={rv.distribution.sigma.eval()}"
)
rv_pop_mean = rv
else:
param = parametrization.get(dim, safe_param)
elif parametrization is None:
param = safe_param
else:
param = parametrization
if rv is None:
rv = pm.Normal(rv_name, prior_mu, prior_mu_sigma, dims=coef_dims)
rv_pop_mean = rv
stack.append((level + 1, child_dim, rv, child, rv_pop_mean))
else:
param = read_parametrization(
parametrization, dim, level, model, parametrized_by_level=parametrized_by_level
)

if param == "non-centered" and safe_param == "centered":
sd = scales[dim]
unused_scales.discard(dim)
if param not in ["centered", "non-centered"]:
raise ValueError(
f"non-centered parametrization cannot be used for dim {dim} "
f"because it's shape {shape} is <2"
f"parameterization must be one of [centered, non-centered], but got {param}"
)
# When the current node has child nodes that should create non-centered RV's, we need to modify
# the current node's sigma. This means that we need to
# 1. Look at the child nodes and their parametrizations to signal whether we need to
# modify the current node's RV sigma or not
# 2. Create the current node's RV
# 3. Trace the current node's children and add them to the stack.

sd = scales[dim]
unused_scales.discard(dim)
# We first go through the child nodes to peek whether they require an RV or whether they use a
# non-centered parametrization
child_requires_rv = False
child_modifies_sigma = []
child_size = []
for i, (child_dim, child) in enumerate(node.items()):
if isinstance(child, dict):
child_requires_rv = True
child_param = read_parametrization(
parametrization,
child_dim,
level + 1,
model,
parametrized_by_level=parametrized_by_level,
)
if child_param == "non-centered":
child_modifies_sigma.append(True)
else:
child_modifies_sigma.append(False)
child_size.append(shape_from_dims(model, child_dim))
else:
child_modifies_sigma.append(False)
child_size.append(0)
# Now we create the current node's RV
if param == "centered":
rv = pm.Normal(
f"{name_prefix}coefs_{dim}",
parent,
sd,
dims=(dim,) + coef_dims,
)
elif param == "non-centered":
raw = CenteredNormal(
f"{name_prefix}coefs_{dim}_raw",
1,
axis=0,
dims=(dim,) + coef_dims,
_sd = at.stack(
[
at.sqrt(sd**2 + 1 / _size) if modifies_sigma else sd
for modifies_sigma, _size in zip(child_modifies_sigma, child_size)
]
)
rv = pm.Deterministic(
print(f"Centered param, level={level}, sigma={_sd.eval()}")
rv = pm.Normal(
f"{name_prefix}coefs_{dim}",
parent + raw * sd,
mu=parent,
sigma=_sd,
dims=(dim,) + coef_dims,
)
rv_pop_mean = rv
else:
raise ValueError(
f"parameterization must be one of [centered, non-centered], but got {param}"
_sd = at.stack(
[
at.sqrt(1 + 1 / _size / sd**2) if modifies_sigma else 1.0
for modifies_sigma, _size in zip(child_modifies_sigma, child_size)
]
)
print(f"Non centered param, level={level}, sigma={_sd.eval()}")
if child_requires_rv:
raw = CenteredNormal(
f"{name_prefix}coefs_{dim}_sample_mean_raw",
_sd,
axis=0,
dims=(dim,) + coef_dims,
)
rv = pm.Deterministic(
f"{name_prefix}coefs_{dim}_sample_mean",
parent + raw * sd,
dims=(dim,) + coef_dims,
)
raw_pop_mean = pm.Normal(
f"{name_prefix}coefs_{dim}_raw",
mu=0,
sigma=1,
dims=(dim,) + coef_dims,
)
rv_pop_mean = pm.Deterministic(
f"{name_prefix}coefs_{dim}",
parent + raw_pop_mean * sd,
dims=(dim,) + coef_dims,
)
if _v4:
pm.Potential(
f"{name_prefix}coefs_{dim}_eps_logp",
pm.logp(
pm.Normal.dist(
mu=raw + (parent - parent_pop_mean) / sd,
sigma=1,
),
0,
),
)
else:
pm.Potential(
f"{name_prefix}coefs_{dim}_eps_logp",
pm.Normal.dist(
mu=raw + (parent - parent_pop_mean) / sd,
sigma=1,
).logp(0),
)
else:
raw = CenteredNormal(
f"{name_prefix}coefs_{dim}_raw",
1,
axis=0,
dims=(dim,) + coef_dims,
)
rv = pm.Deterministic(
f"{name_prefix}coefs_{dim}",
parent + raw * sd,
dims=(dim,) + coef_dims,
)
rv_pop_mean = rv

# Now we iterate through the children to add them to the stack
for i, (child_dim, child) in enumerate(node.items()):
stack.append((level + 1, child_dim, rv[i], child))
stack.append((level + 1, child_dim, rv[i], child, rv_pop_mean[i]))

else:
# TODO: We shouldn't be creating these scales as we are ignoring them
Expand Down
88 changes: 87 additions & 1 deletion preclinpack/tests/test_nested_hierarchy_rvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
from random import choice

import numpy as np
import pandas as pd
import pytest

from preclinpack.nested_hierarchy_rvs import (
make_hierarchical_coefs,
make_hierarchical_weights,
)
from preclinpack.nested_hierarchy_utils import slice_data_by_nested_hierarchies
from preclinpack.nested_hierarchy_utils import (
slice_data_by_nested_hierarchies,
sort_data,
)
from preclinpack.tests.utils import (
get_keys_nested_dict,
test_dataset,
Expand All @@ -17,6 +21,47 @@
from preclinpack.utils import pm


@pytest.fixture(scope="module", params=[1, 2])
def level_datasets(request):
if request.param == 1:
nested_hierarchies = ["level1"]
level1 = np.tile(np.array(["A", "B"]), (1000, 2))
means = np.tile(np.array([1.0, -1.0]), (1000, 2))
rng = np.random.default_rng()
obs = rng.normal(loc=means, scale=0.1)
dataset = pd.DataFrame(
data=np.stack(
[
level1.flatten(),
obs.flatten(),
],
axis=1,
),
columns=["level1", "feature"],
)
dataset = sort_data(dataset, ["level1"], [])
else:
nested_hierarchies = ["level1", "level2"]
level1 = np.tile(np.array([["A", "A", "A"], ["B", "B", "B"]]), (1000, 2, 3))
level2 = np.tile(np.array([["AA", "AB", "AC"], ["BA", "BB", "BC"]]), (1000, 2, 3))
means = np.tile(np.array([[1.5, 1.0, 0.5], [-0.5, -1.0, -1.5]]), (1000, 2, 3))
rng = np.random.default_rng()
obs = rng.normal(loc=means, scale=0.1)
dataset = pd.DataFrame(
data=np.stack(
[
level1.flatten(),
level2.flatten(),
obs.flatten(),
],
axis=1,
),
columns=["level1", "level2", "feature"],
)
dataset = sort_data(dataset, ["level1", "level2"], [])
return nested_hierarchies, dataset


class TestMakeHierarchicalWeights:
@classmethod
def setup_class(cls):
Expand Down Expand Up @@ -430,3 +475,44 @@ def test_unsafe_parametrization(self):
coef_dims="behavior",
parametrization="non-centered",
)

@pytest.mark.usefixtures("level_datasets")
def test_population_means(self, level_datasets):
nested_hierarchies, dataset = level_datasets
hierarchy_coords, hierarchy = slice_data_by_nested_hierarchies(
dataset,
nested_hierarchies,
return_slices=True,
)
with pm.Model(coords=hierarchy_coords) as model1:
make_hierarchical_coefs(
hierarchy=hierarchy,
name_prefix="test",
parametrization="non-centered",
level_scale_kwargs={
1: {"learnable": False, "value": 1.0},
2: {"learnable": False, "value": 1.0},
},
)
idata1 = pm.sample(chains=1, cores=1, tune=1000, draws=1000, return_inferencedata=True)
with pm.Model(coords=hierarchy_coords) as model2:
make_hierarchical_coefs(
hierarchy=hierarchy,
name_prefix="test",
parametrization="centered",
level_scale_kwargs={
1: {"learnable": False, "value": 1.0},
2: {"learnable": False, "value": 1.0},
},
)
idata2 = pm.sample(chains=1, cores=1, tune=1000, draws=1000, return_inferencedata=True)
import arviz as az

var_names = [
k
for k in idata1.posterior.data_vars
if not (k.endswith("_raw") or k.endswith("truncated_") or "sample_mean" in k)
]
print(var_names)
print(az.summary(idata1, var_names=var_names))
print(az.summary(idata2, var_names=var_names))
Loading