Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for dims in LKJCholeskyCov #6828

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Jul 14, 2023

What is this PR about?

There are currently several pain points when using labeled dims withLKJCholeskyCov:

  1. The distribution samples 2d matrices, but passing a pair of dimensions results in an error, because internally the distribution is represented in packed lower-triangular form.
  2. Any dimensions given to LKJCholeskyCov are not propagated to internally generated deterministics, {name}_std and {name}_corr. If a user wants these to be labeled, he has to pass a long dictionary to idata_kwargs
  3. After sampling, the 1's on the diagonal of all samples drawn from {name}_corr causes an error in arviz when computing within-chain variance.

This PR tries to correct all three of these. Here is an example model under this PR:

from string import ascii_uppercase

n = 3
n_obs = 100
mean = np.zeros(n)
L = np.random.normal(size=(n, n))
cov = L @ L.T
data = np.random.multivariate_normal(mean=mean, cov=cov, size=(n_obs, ))

with pm.Model(coords={'dim':ascii_uppercase[:n], 
                      'dim_aux':ascii_uppercase[:n]},
              coords_mutable={'obs_idx':np.arange(n_obs, dtype='int')}) as mod:
    sd_dist = pm.Exponential.dist(1)
    chol, *_ = LKJCholeskyCov('chol', n=n, sd_dist=sd_dist, eta=1, dims=['dim', 'dim_aux'])
    obs = pm.MvNormal('obs', mu=0, chol=chol, observed=data, dims=['obs_idx', 'dim'])
    
    idata = pm.sample()

First, I pass two dimensions to LKJCholeskyCov -- one for the columns, and one for the rows. This corresponds to the expectation that I am drawing from a matrix-valued random variable.

Internally, I take the Cartesian product between these two dims, and use the lower triangle of the resulting matrix to make and register a new coordinate: packed_tril_{name}. This is then set as the dims on packed_chol.

Next, only the upper triangle (excluding the diagonal) of the correlation matrix is stored in a deterministic. Another new coordinate is registered: corr_{name}.

Finally, the first dim is used to add a labeled dimension to {name}_std.

This results in the following graph:
image

Here is the result plotted with az.plot_trace:
image

This PR still needs a bit of work, including:

  1. Unit tests for the new functionality
  2. Documentation
  3. The generated dimensions are added in a "hacky" way, I hope this can be improved
  4. The names on the "packed" dimensions of chol and cov are not great. It would be nice if a MultiIndex could be specified here, but I don't think it's currently possible without some changes (see here, but maybe this is out of date?)

It's also possible that the problem (3) is a problem on the arviz side of things, and should be fixed there instead of here. But in that case, it would still be nice to propagate the matrix dims to the full square correlation matrix.

Because of these points, I'm marking this as a draft PR. But I would still like feedback on the idea of automatically generating coords, or at least on how dim handling can be improved in LKJCholeskyCov

Checklist

Major / Breaking Changes

  • None

New features

  • Allow LKJCholeskyCov to generate and register new model coords corresponding to the distributions it internally registers

Bugfixes

  • Allows plotting of generated correlation matrix in az.plot_trace, but that might not be something that should be fixed on the PyMC side. See discussion above.

Documentation

-None

Maintenance

-None


馃摎 Documentation preview 馃摎: https://pymc--6828.org.readthedocs.build/en/6828/

@codecov
Copy link

codecov bot commented Jul 14, 2023

Codecov Report

Merging #6828 (7ecbfbd) into main (13e7c88) will decrease coverage by 25.20%.
The diff coverage is 23.68%.

Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff             @@
##             main    #6828       +/-   ##
===========================================
- Coverage   92.02%   66.82%   -25.20%     
===========================================
  Files          95       95               
  Lines       16262    16293       +31     
===========================================
- Hits        14965    10888     -4077     
- Misses       1297     5405     +4108     
Impacted Files Coverage 螖
pymc/distributions/multivariate.py 35.64% <17.85%> (-56.57%) 猬囷笍
pymc/distributions/distribution.py 71.63% <33.33%> (-25.43%) 猬囷笍
pymc/model.py 89.87% <42.85%> (-1.09%) 猬囷笍

... and 51 files with indirect coverage changes

Comment on lines +1504 to +1505
mod.coords[name] = value
mod.dim_lengths[name] = pt.TensorConstant(pt.lscalar, np.array(len(value)))
Copy link
Member

@ricardoV94 ricardoV94 Aug 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably use model.add_coord/s and check name doesn't clash with existing one

f_idx = np.triu_indices

flat_tri_idx = np.arange(n**2, dtype=int).reshape(n, n)[f_idx(n, k=k)]
coord_product = np.fromiter([f"{x}" for x in product(*chol_dims)], dtype="object")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit scared that this could break things when we serialize traces. Can we at least store those in netcdf and zarr?
I wouldn't mind if this just had integer coords either...

coord_product = np.fromiter([f"{x}" for x in product(*chol_dims)], dtype="object")
tri_coords = coord_product[flat_tri_idx].tolist()

packed_dim_name = f"{name_prefix}_{dims[0]}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would much prefer a postfix instead of a prefix, especially in this case, because we already use that for the deterministics.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants