Improve support for dims in LKJCholeskyCov #6828

Draft
wants to merge 1 commit into
base: main
45 changes: 43 additions & 2 deletions pymc/distributions/multivariate.py
@@ -18,6 +18,7 @@
 import warnings

 from functools import reduce
+from itertools import product
 from typing import Optional

 import numpy as np
@@ -1433,14 +1434,32 @@
Codecov / codecov/patch: added lines #L1437, #L1439, #L1441, #L1444-L1445, #L1454-L1455, #L1458 and #L1460-L1461 were not covered by tests.
     """

     def __new__(cls, name, eta, n, sd_dist, *, compute_corr=True, store_in_trace=True, **kwargs):
+        dims = kwargs.pop("dims", None)

+        if dims is not None:
+            # TODO: Add check for 2d dims?
+            packed_dim_name, packed_dim_value = cls._make_packed_coord_from_dims(
+                n, dims, "packed_tril"
+            )
+            cls._register_new_coords_with_model(packed_dim_name, packed_dim_value)
+            kwargs["dims"] = [packed_dim_name]

         packed_chol = _LKJCholeskyCov(name, eta=eta, n=n, sd_dist=sd_dist, **kwargs)

         if not compute_corr:
             return packed_chol
         else:
             chol, corr, stds = cls.helper_deterministics(n, packed_chol)
             if store_in_trace:
-                corr = pm.Deterministic(f"{name}_corr", corr)
-                stds = pm.Deterministic(f"{name}_stds", stds)
+                corr_triu = corr[pt.triu_indices_from(corr, k=1)]
+                corr_triu_dim_name, corr_triu_dim_value = cls._make_packed_coord_from_dims(
+                    n, dims, "corr", lower=False, k=1
+                )
+                cls._register_new_coords_with_model(corr_triu_dim_name, corr_triu_dim_value)

+                corr_tril = pm.Deterministic(f"{name}_corr", corr_triu, dims=corr_triu_dim_name)
+                stds = pm.Deterministic(f"{name}_stds", stds, dims=dims[0])

             return chol, corr, stds

     @classmethod
@@ -1463,6 +1482,28 @@
Codecov / codecov/patch: added lines #L1487-L1490 and #L1492 were not covered by tests.
         corr = inv_stds[None, :] * cov * inv_stds[:, None]
         return chol, corr, stds

+    @classmethod
+    def _make_packed_coord_from_dims(cls, n, dims, name_prefix, lower=True, k=0):
+        mod = pm.modelcontext(None)
+        chol_dims = [mod.coords[dim] for dim in dims]
+        if lower:
+            f_idx = np.tril_indices
+        else:
+            f_idx = np.triu_indices

+        flat_tri_idx = np.arange(n**2, dtype=int).reshape(n, n)[f_idx(n, k=k)]
+        coord_product = np.fromiter([f"{x}" for x in product(*chol_dims)], dtype="object")
Member

I'm a bit scared that this could break things when we serialize traces. Can we at least store those in netcdf and zarr?
I wouldn't mind if this just had integer coords either...

+        tri_coords = coord_product[flat_tri_idx].tolist()
Codecov / codecov/patch: added lines #L1494-L1496 were not covered by tests.

+        packed_dim_name = f"{name_prefix}_{dims[0]}"
Member

I would much prefer a postfix instead of a prefix, especially in this case, because we already use that for the deterministics.

+        return packed_dim_name, tri_coords
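
To make the coordinate discussion above concrete, here is a minimal standalone sketch of the labels that `_make_packed_coord_from_dims` appears to build. It is not the PR's code: `packed_coord_labels` is a hypothetical helper written in pure Python, it assumes both dims have the same length, and it follows the row-major ordering of `np.tril_indices` / `np.triu_indices`. The `"row-col"` string format is an assumption for illustration only, and the integer variant is the alternative raised in the review.

```python
from itertools import product


def packed_coord_labels(row_labels, col_labels, lower=True, k=0):
    """Return the (row, col) label pair of each packed triangular entry.

    Entries are listed in row-major order, the same order in which
    np.tril_indices / np.triu_indices enumerate a triangle.
    """
    n = len(row_labels)
    if len(col_labels) != n:
        raise ValueError("both dims must have the same length")
    pairs = []
    for (i, row), (j, col) in product(enumerate(row_labels), enumerate(col_labels)):
        # Lower triangle keeps j <= i + k, upper triangle keeps j >= i + k.
        keep = (j <= i + k) if lower else (j >= i + k)
        if keep:
            pairs.append((row, col))
    return pairs


dim = ["a", "b", "c"]  # hypothetical coordinate values for one dim

# Packed Cholesky factor: lower triangle including the diagonal.
chol_pairs = packed_coord_labels(dim, dim)
string_coords = [f"{row}-{col}" for row, col in chol_pairs]  # assumed label format
integer_coords = list(range(len(chol_pairs)))  # integer alternative

# Off-diagonal correlations: strict upper triangle (k=1).
corr_pairs = packed_coord_labels(dim, dim, lower=False, k=1)

print(string_coords)
print(integer_coords)
print(corr_pairs)
```

Either variant keeps every coordinate value a plain string or integer rather than a stringified tuple. Whether that round-trips cleanly through NetCDF and Zarr would still need to be checked.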

Codecov / codecov/patch: added lines #L1498-L1499 were not covered by tests.

+    @classmethod
+    def _register_new_coords_with_model(cls, name, value):
+        mod = pm.modelcontext(None)
+        mod.coords[name] = value
+        mod.dim_lengths[name] = pt.TensorConstant(pt.lscalar, np.array(len(value)))
Codecov / codecov/patch: added lines #L1503-L1505 were not covered by tests.

Comment on lines +1504 to +1505
Member

@ricardoV94, Aug 3, 2023

Should probably use model.add_coord/s and check name doesn't clash with existing one
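
A rough sketch of what that suggestion could look like, under stated assumptions: `register_packed_coord` is a hypothetical helper, and `StubModel` is a hypothetical stand-in that mimics only a `coords` mapping and an `add_coord(name, values)` method so the example runs without PyMC installed. With a real model, the same helper would be expected to go through the model's own `add_coord`, but that is not verified here.

```python
class StubModel:
    """Hypothetical stand-in for a model; only mimics `coords` and `add_coord`."""

    def __init__(self):
        self.coords = {}

    def add_coord(self, name, values):
        self.coords[name] = tuple(values)


def register_packed_coord(model, name, values):
    """Register `values` under `name`, refusing to overwrite a different coord."""
    values = tuple(values)
    existing = model.coords.get(name)
    if existing is not None:
        if tuple(existing) != values:
            raise ValueError(
                f"Dimension name {name!r} is already in use with different coordinates"
            )
        return  # identical coord already registered, nothing to do
    # Go through the public method instead of writing to model.coords directly.
    model.add_coord(name, values)


model = StubModel()
register_packed_coord(model, "coefs_packed_tril", [0, 1, 2])
register_packed_coord(model, "coefs_packed_tril", [0, 1, 2])  # same values: no-op

try:
    register_packed_coord(model, "coefs_packed_tril", [0, 1])
except ValueError as err:
    print(f"clash detected: {err}")
```

The design choice here is to treat re-registering identical values as harmless and to raise only on a genuine clash, so that building the same distribution twice in one model would not fail for that reason alone.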



 class LKJCorrRV(RandomVariable):
     name = "lkjcorr"