
Conversation

@xinglong-li (Contributor) commented Jun 7, 2022

Add gauss_inv_wishart_utils.py for the Gaussian inverse Wishart distribution.
Add multivariate_t_utils.py, computing log_pdf and log_prob_of_pos_predict for the multivariate T distribution.
Add gibbs_finite_mix_gauss_utils.py, implementing Gibbs sampling for the finite Gaussian mixture model.
Add dp_mixgauss_utils.py, implementing forward simulation of the DP mixture model and clustering analysis using the DP mixture model.

@xinglong-li (Contributor, Author) commented Jun 7, 2022

resolves probml/pyprobml#863

"""
Evaluating the logarithm of probability of the posterior predictive multivariate T distribution.
The likelihood of the observation given the parameter is Gaussian distribution.
The prior distribution is Normal Inverse Wishart (NIW) with parameters given by hyper_params.
Review comment (Member):

Please add a comment with all the math details spelled out (in LaTeX form).
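For reference, the standard NIW posterior predictive result, written with generic hyperparameter names $(\mu_0, \kappa_0, \nu_0, \Lambda_0)$ that may differ from those used in this PR: given observations $x_1, \dots, x_n \in \mathbb{R}^D$ with mean $\bar{x}$ and scatter matrix $S = \sum_{i=1}^n (x_i - \bar{x})(x_i - \bar{x})^\top$, the posterior is NIW with

$$\mu_n = \frac{\kappa_0 \mu_0 + n\bar{x}}{\kappa_0 + n}, \qquad \kappa_n = \kappa_0 + n, \qquad \nu_n = \nu_0 + n,$$

$$\Lambda_n = \Lambda_0 + S + \frac{\kappa_0 n}{\kappa_0 + n}(\bar{x} - \mu_0)(\bar{x} - \mu_0)^\top,$$

and the posterior predictive density is a multivariate Student T:

$$p(x_\star \mid x_{1:n}) = t_{\nu_n - D + 1}\!\left(x_\star \,\middle|\, \mu_n, \; \frac{\kappa_n + 1}{\kappa_n(\nu_n - D + 1)}\Lambda_n\right).$$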

key: jax.random.PRNGKey
Seed of the initial random cluster assignment
--------------------------------------------
* array(N):
Review comment (Member):

Specify that these are return parameters, and give the variable names (e.g., Z: array(N): ...)
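For example, the return block could read as follows (the description text here is illustrative, not taken from the PR):

Returns
-------
Z: array(N)
    Cluster assignments of the N observations.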

from multivariate_t_utils import log_predic_t


def dp_mixture_simu(N, alpha, H, key):
Review comment (Member):

Rename to dp_mixgauss_ancestral_sample

Number of samples to be generated from the mixture model
alpha: float
Concentration parameter of the Dirichlet process
H: object of NormalInverseWishart
Review comment (Member):

It is better to avoid short, ambiguous variable names. Replace H with niw_prior.

Z = jnp.full(N, 0)
# Sample cluster assignment from the Chinese restaurant process prior
CR = []
for i in range(N):
Review comment (Member):

A fun (optional!) exercise would be to figure out how to vectorize this (e.g. with lax.scan). It might be tricky because the shapes need to be of fixed size. I think you could pre-allocate CR as a fixed-size vector and then use a binary mask to select the 'valid' prefix.
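To sketch the pre-allocation idea (a hypothetical helper, not code from this PR): keep the cluster counts in a fixed length-N vector, so empty slots get zero probability (log of a zero count is -inf) and the slot for the next new table gets weight alpha; the whole CRP draw then fits in a lax.scan:

import jax
import jax.numpy as jnp
from jax import lax

def crp_ancestral_sample(key, N, alpha):
    # Draw CRP table assignments for N customers with a pre-allocated,
    # fixed-size count vector (at most N distinct tables can appear).
    def step(carry, key_i):
        counts, next_label = carry
        # Existing tables have weight = count; slot `next_label` gets
        # weight alpha (a new table); empty slots get log(0) = -inf.
        logits = jnp.log(counts.at[next_label].set(alpha))
        z = jax.random.categorical(key_i, logits)
        counts = counts.at[z].add(1.0)
        next_label = jnp.maximum(next_label, z + 1)
        return (counts, next_label), z
    keys = jax.random.split(key, N)
    (counts, _), Z = lax.scan(step, (jnp.zeros(N), jnp.array(0)), keys)
    return Z, counts  # counts > 0 is the binary mask of 'valid' tables

Here Z has fixed shape (N,), so downstream code can use the mask counts > 0 instead of a growing Python list.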

return Z, jnp.array(X), Mu, Sigma


def dp_cluster(T, X, alpha, hyper_params, key):
Review comment (Member):

Give this function a more descriptive name, e.g. dp_mixgauss_gibbs_sample.

new_label = 1
for t in range(T):
    # Update the cluster assignment for every observation
    for i in range(n):
Review comment (Member):

Can this be vectorized?
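A caveat plus a partial answer (a sketch, with log_predictive and cluster_stats as hypothetical stand-ins for whatever per-cluster sufficient statistics the PR tracks): the sweep over observations in a collapsed Gibbs sampler is inherently sequential, since each reassignment changes the statistics seen by the next observation, but the inner loop over candidate clusters for a single observation can be vectorized with jax.vmap:

import jax
import jax.numpy as jnp

def sample_assignment(key, x_i, counts, cluster_stats, log_predictive):
    # log p(z_i = k | rest) is proportional to
    # log(counts[k]) + log p(x_i | data in cluster k);
    # vmap evaluates the posterior predictive for all clusters at once.
    log_preds = jax.vmap(lambda stats: log_predictive(x_i, stats))(cluster_stats)
    log_weights = jnp.log(counts) + log_preds
    return jax.random.categorical(key, log_weights)

Empty clusters (count 0) drop out automatically because log(0) = -inf.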

from multivariate_t_utils import log_predic_t


def gibbs_gmm(T, X, alpha, K, hyper_params, key):
Review comment (Member):

Rename this to mixgauss_gibbs_sample.

Reply from @xinglong-li (Contributor, Author):

Thank you very much for these detailed comments. I'll try to fix these issues and take more care in my future code.
I did find it hard to vectorize the DP mixture model, since the sizes of many arrays are not fixed. In fact, even for the finite mixture model, the size of the cluster-assignment arrays is not fixed.
As you mentioned, one possible solution is to pre-allocate a large enough vector and then use binary masks, so we might gain time efficiency by sacrificing some space. I'm just a little concerned about whether this approach remains viable when the data is huge.

@coveralls commented

Pull Request Test Coverage Report for Build 2452176098

  • 0 of 151 (0.0%) changed or added relevant lines in 4 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.4%) to 11.526%

Changes missing coverage:

| File | Covered Lines | Changed/Added Lines | % |
| --- | --- | --- | --- |
| probml_utils/gibbs_finite_mixgauss_utils.py | 0 | 25 | 0.0% |
| probml_utils/multivariate_t_utils.py | 0 | 29 | 0.0% |
| probml_utils/gauss_inv_wishart_utils.py | 0 | 46 | 0.0% |
| probml_utils/dp_mixgauss_utils.py | 0 | 51 | 0.0% |

Totals:
  • Change from base Build 2419970093: -0.4%
  • Covered Lines: 466
  • Relevant Lines: 4043

💛 - Coveralls

@murphyk murphyk merged commit fde1ce4 into probml:main Jun 17, 2022