Skip to content

Commit

Permalink
Added shape inferring in LKJCholeskyCov
Browse files Browse the repository at this point in the history
  • Loading branch information
kc611 committed Jul 14, 2021
1 parent 6d82bab commit 88d34db
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
7 changes: 6 additions & 1 deletion pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,11 @@ def _lkj_normalizing_constant(eta, n):


class _LKJCholeskyCovRV(RandomVariable):
def _shape_from_params(self, dist_params, **kwargs):
n = dist_params[1]
dist_shape = ((n * (n + 1)) // 2,)
return dist_shape

def __init__(self, *args, sd_dist=None, **kwargs):
self.sd_dist = sd_dist
self._print_name = _print_name = ("_LKJCholeskyCov", "\\operatorname{_LKJCholeskyCov}")
Expand Down Expand Up @@ -1152,7 +1157,7 @@ def dist(cls, eta, n, sd_dist, *args, **kwargs):

cls.rv_op = _LKJCholeskyCovRV(
"_lkjcholeskycov",
0,
2,
(0, 0, 0),
"floatX",
sd_dist=sd_dist,
Expand Down
9 changes: 4 additions & 5 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3140,14 +3140,13 @@ def test_lkjcholeskycov():
dist_shape = ((D * (D + 1)) // 2,)

with pm.Model() as model:
sd_dist = pm.HalfCauchy.dist(beta=2.5, size=(D))
sd_dist = pm.HalfCauchy.dist(beta=2.5)
packedL = pm.LKJCholeskyCov("packedL", eta=2, n=D, sd_dist=sd_dist)

with model:
prior = pm.sample()

pt = np.random.random(120)
pt = {"packedL_cholesky-cov-packed__": pt}
logp = model.fastlogp(pt)

with model:
prior = pm.sample_prior_predictive(5)

assert prior["packedL"].shape == (samples,) + dist_shape

0 comments on commit 88d34db

Please sign in to comment.