Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions pymc/tests/test_variational_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,63 @@ def simple_model(simple_model_data):
return model


@pytest.fixture
def hierarchical_model_data():
group_coords = {
"group_d1": np.arange(3),
"group_d2": np.arange(7),
}
group_shape = tuple(len(d) for d in group_coords.values())
data_coords = {"data_d": np.arange(11)} | group_coords

data_shape = tuple(len(d) for d in data_coords.values())

mu = -5.0

sigma_group_mu = 3
group_mu = sigma_group_mu * np.random.randn(*group_shape)

sigma = 3.0

data = sigma * np.random.randn(*data_shape) + group_mu + mu

return dict(
group_coords=group_coords,
group_shape=group_shape,
data_coords=data_coords,
data_shape=data_shape,
mu=mu,
sigma_group_mu=sigma_group_mu,
sigma=sigma,
group_mu=group_mu,
data=data,
)


@pytest.fixture
def hierarchical_model(hierarchical_model_data):
with pm.Model(coords=hierarchical_model_data["data_coords"]) as model:
mu = pm.Normal("mu", mu=0, sigma=10)
sigma_group_mu = pm.HalfNormal("sigma_group_mu", sigma=5)

group_mu = pm.Normal(
"group_mu",
mu=0,
sigma=sigma_group_mu,
dims=list(hierarchical_model_data["group_coords"].keys()),
)

sigma = pm.HalfNormal("sigma", sigma=3)

pm.Normal(
"data",
mu=(mu + group_mu),
sigma=sigma,
observed=hierarchical_model_data["data"],
)
return model


@pytest.fixture(
scope="module",
params=[
Expand Down Expand Up @@ -571,6 +628,27 @@ def test_fit_oo(inference, fit_kwargs, simple_model_data):
np.testing.assert_allclose(np.std(trace.posterior["mu"]), np.sqrt(1.0 / d), rtol=0.2)


def test_fit_data(inference, fit_kwargs, simple_model_data):
fitted = inference.fit(**fit_kwargs)
mu_post = simple_model_data["mu_post"]
d = simple_model_data["d"]
np.testing.assert_allclose(fitted.mean_data["mu"].values, mu_post, rtol=0.05)
np.testing.assert_allclose(fitted.std_data["mu"], np.sqrt(1.0 / d), rtol=0.2)


def test_fit_data_coords(hierarchical_model, hierarchical_model_data):
with hierarchical_model:
fitted = pm.fit(1)

for data in [fitted.mean_data, fitted.std_data]:
assert set(data.keys()) == {"sigma_group_mu_log__", "sigma_log__", "group_mu", "mu"}
assert data["group_mu"].shape == hierarchical_model_data["group_shape"]
assert list(data["group_mu"].coords.keys()) == list(
hierarchical_model_data["group_coords"].keys()
)
assert data["mu"].shape == tuple()


def test_profile(inference):
inference.run_profiling(n=100).summary()

Expand Down
32 changes: 32 additions & 0 deletions pymc/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import aesara
import aesara.tensor as at
import numpy as np
import xarray

from aesara.graph.basic import Variable

Expand Down Expand Up @@ -1109,16 +1110,47 @@ def __str__(self):

@node_property
def std(self):
"""Standard deviation of the latent variables as an unstructured 1-dimensional Aesara variable"""
raise NotImplementedError

@node_property
def cov(self):
"""Covariance between the latent variables as an unstructured 2-dimensional Aesara variable"""
raise NotImplementedError

@node_property
def mean(self):
"""Mean of the latent variables as an unstructured 1-dimensional Aesara variable"""
raise NotImplementedError

def var_to_data(self, shared):
"""Takes a flat 1-dimensional Aesara variable and maps it to an xarray data set based on the information in
`self.ordering`.
"""
# This is somewhat similar to `DictToArrayBijection.rmap`, which doesn't work here since we don't have
# `RaveledVars` and need to take the information from `self.ordering` instead
shared = shared.eval()
result = dict()
for name, s, shape, dtype in self.ordering.values():
dims = self.model.RV_dims.get(name, None)
if dims is not None:
coords = {d: np.array(self.model.coords[d]) for d in dims}
else:
coords = None
values = np.array(shared[s]).reshape(shape).astype(dtype)
result[name] = xarray.DataArray(values, coords=coords, dims=dims, name=name)
return xarray.Dataset(result)

@property
def mean_data(self):
"""Mean of the latent variables as an xarray Dataset"""
return self.var_to_data(self.mean)

@property
def std_data(self):
"""Standard deviation of the latent variables as an xarray Dataset"""
return self.var_to_data(self.std)


group_for_params = Group.group_for_params
group_for_short_name = Group.group_for_short_name
Expand Down