add dp_mixgauss related functions #19
Conversation
resolves probml/pyprobml#863
""" | ||
Evaluating the logarithm of probability of the posterior predictive multivariate T distribution. | ||
The likelihood of the observation given the parameter is Gaussian distribution. | ||
The prior distribution is Normal Inverse Wishart (NIW) with parameters given by hyper_params. |
Please add a comment with all the math details spelled out (in latex form).
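For reference, the requested comment could spell out the standard NIW conjugate-update result along these lines (a sketch only; the hyperparameter symbols below are illustrative, not the names stored in hyper_params):

```latex
% Prior: (\mu, \Sigma) \sim \mathrm{NIW}(m_0, \kappa_0, \nu_0, S_0),
% likelihood: x_i \mid \mu, \Sigma \sim \mathcal{N}(\mu, \Sigma), with x_i \in \mathbb{R}^d.
% Posterior updates after n observations with sample mean \bar{x}:
\kappa_n = \kappa_0 + n, \qquad \nu_n = \nu_0 + n, \qquad
m_n = \frac{\kappa_0 m_0 + n \bar{x}}{\kappa_n},
\qquad
S_n = S_0 + \sum_{i=1}^{n} (x_i - \bar{x})(x_i - \bar{x})^\top
    + \frac{\kappa_0 n}{\kappa_n} (\bar{x} - m_0)(\bar{x} - m_0)^\top .
% The posterior predictive is then a multivariate Student-t:
p(x_\ast \mid x_{1:n}) = \mathcal{T}\!\left( x_\ast \;\middle|\;
    m_n,\ \frac{\kappa_n + 1}{\kappa_n (\nu_n - d + 1)}\, S_n,\ \nu_n - d + 1 \right).
```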
key: jax.random.PRNGKey
    Seed of initial random cluster
--------------------------------------------
* array(N):
Specify that these are return parameters, and give the variable names (e.g. Z: array(N): ...).
from multivariate_t_utils import log_predic_t

def dp_mixture_simu(N, alpha, H, key):
Rename to dp_mixgauss_ancestral_sample
    Number of samples to be generated from the mixture model
alpha: float
    Concentration parameter of the Dirichlet process
H: object of NormalInverseWishart
It is better to avoid short, ambiguous variable names. Replace H with niw_prior.
Z = jnp.full(N, 0)
# Sample cluster assignment from the Chinese restaurant process prior
CR = []
for i in range(N):
A fun (optional!) exercise would be to figure out how to vectorize this (e.g. with lax.scan). Might be tricky because the shapes need to be of fixed size. I think you could pre-allocate CR to a fixed-size vector and then use a binary mask to select the 'valid' prefix.
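A minimal sketch of that suggestion, assuming a standalone CRP prior sampler (the name crp_sample and its interface are hypothetical, not part of this PR):

```python
import jax
import jax.numpy as jnp
from jax import lax

def crp_sample(N, alpha, key):
    # counts[k] = number of customers at table k; zeros mark unused tables.
    # Pre-allocating to length N (worst case: every customer opens a table)
    # keeps all shapes static, so lax.scan can trace the loop.
    def step(counts, key_i):
        n_tables = jnp.sum(counts > 0)  # occupied tables so far
        # Existing tables are weighted by their counts; the first empty slot
        # gets weight alpha (opening a new table). Remaining empty slots keep
        # weight 0, which becomes -inf under log and is never sampled.
        weights = counts.at[n_tables].set(alpha)
        z = jax.random.categorical(key_i, jnp.log(weights))
        return counts.at[z].add(1.0), z

    keys = jax.random.split(key, N)
    _, Z = lax.scan(step, jnp.zeros(N), keys)
    return Z
```

For example, Z = crp_sample(100, 1.0, jax.random.PRNGKey(0)) draws 100 assignments; jnp.sum(Z[:, None] == jnp.arange(100), axis=0) then recovers the masked table counts.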
return Z, jnp.array(X), Mu, Sigma

def dp_cluster(T, X, alpha, hyper_params, key):
Give this function a more descriptive name, e.g. dp_mixgauss_gibbs_sample.
new_label = 1
for t in range(T):
    # Update the cluster assignment for every observation
    for i in range(n):
Can this be vectorized?
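The sweep over observations is inherently sequential in a collapsed Gibbs sampler (each assignment changes the sufficient statistics seen by the next draw), but the inner loop over candidate clusters can be vectorized with vmap. A rough sketch, assuming log_predic_t(x, params) evaluates the log posterior predictive for one cluster; the name sample_assignment, the stacked cluster_stats container, and the fixed K_max shapes are hypothetical stand-ins:

```python
import jax
import jax.numpy as jnp
from multivariate_t_utils import log_predic_t

def sample_assignment(key, x_i, counts, cluster_stats, alpha):
    # counts: (K_max,) pre-allocated cluster sizes, already excluding
    # observation i; zeros mark empty slots.
    # cluster_stats: per-cluster NIW posterior hyperparameters stacked along a
    # leading K_max axis; empty slots are assumed to hold the prior, so the
    # predictive there is the prior predictive for a brand-new cluster.
    log_pred = jax.vmap(lambda stats: log_predic_t(x_i, stats))(cluster_stats)
    n_used = jnp.sum(counts > 0)
    crp_weights = counts.at[n_used].set(alpha)  # first empty slot = new cluster
    logits = jnp.where(crp_weights > 0,
                       jnp.log(crp_weights) + log_pred,
                       -jnp.inf)                # other empty slots never chosen
    return jax.random.categorical(key, logits)
```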
from multivariate_t_utils import log_predic_t

def gibbs_gmm(T, X, alpha, K, hyper_params, key):
Rename this to mixgauss_gibbs_sample.
Thank you very much for these detailed comments. I'll try to fix these issues and take more care in my future code.
I did find it hard to vectorize the DP mixture model, since the sizes of many arrays are not fixed. In fact, even for the finite mixture model, the size of the cluster-assignment arrays is not fixed.
As you mentioned, one possible solution is to pre-allocate a large enough vector and then use binary masks, gaining time efficiency at the cost of some space. I'm just a little concerned about whether this approach remains viable when the data is huge.
Pull Request Test Coverage Report for Build 2452176098
💛 - Coveralls
add gauss_inv_wishart_utils.py for the Gaussian inverse Wishart distribution;
add multivariate_t_utils.py computing log_pdf and log_prob_of_pos_predict for the multivariate T distribution;
add gibbs_finite_mix_gauss_utils.py implementing Gibbs sampling for the finite Gaussian mixture model;
add dp_mixgauss_utils.py implementing forward simulation of the DP mixture model and clustering analysis using the DP mixture model.