
Feature Request: Gumbel Mixture Models #1598

Open
bryorsnef opened this issue Aug 3, 2022 · 4 comments

Comments

@bryorsnef

It is possible to construct reparameterizable mixture distributions by replacing the categorical distribution with a Gumbel-softmax (relaxed categorical) distribution. The ability to use a relaxed one-hot categorical distribution in `Mixture` or `MixtureSameFamily` would be potentially very useful.

Differentiable mixture distributions are implemented in PyTorch here:
https://github.com/nextBillyonair/DPM/blob/master/dpm/distributions/gumbel_mixture_model.py
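For context, the core trick can be sketched in a few lines of NumPy: draw a relaxed one-hot sample via Gumbel-softmax and use it as soft weights over component samples. This is a self-contained illustrative sketch (all names are made up for the example); it is not the linked DPM or TFP implementation.

```python
import numpy as np

def gumbel_softmax(logits, temperature, rng):
    """Draw one relaxed one-hot sample via the Gumbel-softmax trick."""
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1).
    gumbel = -np.log(-np.log(rng.uniform(size=logits.shape)))
    y = (logits + gumbel) / temperature
    # Softmax over the component axis; entries are positive and sum to 1.
    e = np.exp(y - y.max())
    return e / e.sum()

def relaxed_mixture_sample(logits, component_samples, temperature, rng):
    """Differentiable mixture sample: a convex combination of one draw per
    component, weighted by a relaxed one-hot sample instead of a hard mask."""
    weights = gumbel_softmax(logits, temperature, rng)  # [k], sums to 1
    return np.dot(weights, component_samples)

rng = np.random.default_rng(0)
logits = np.log(np.array([0.2, 0.5, 0.3]))            # mixture weights
component_samples = rng.normal(loc=[-2.0, 0.0, 3.0])  # one draw per component
x = relaxed_mixture_sample(logits, component_samples, temperature=0.5, rng=rng)
```

As the temperature approaches 0 the soft weights concentrate on one component, recovering an ordinary (non-differentiable) mixture sample.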

@bryorsnef
Author

> I believe https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/RelaxedOneHotCategorical is what you are looking for (the distribution goes by the names Relaxed One-Hot Categorical, Gumbel-Softmax, and Concrete in the literature).

Yeah, the distribution is implemented; I was saying it would be useful to also have `Mixture` and `MixtureSameFamily` meta-distributions that can take a relaxed one-hot categorical as the `cat` argument instead of a `Categorical`.

@bryorsnef
Author

From the looks of https://github.com/tensorflow/probability/blob/v0.17.0/tensorflow_probability/python/distributions/mixture_same_family.py#L266-L270, the changes needed to allow this don't look too complicated. In `_sample_n`, the one-hot mask built in lines 266-270 can be replaced with samples from the relaxed one-hot categorical mixing distribution:

```python
mask = tf.one_hot(
    indices=mix_sample,   # [n, B]
    depth=num_components,
    on_value=npdt(1),
    off_value=npdt(0))    # [n, B, k]
```

and the dtype check on the mixture distribution would need to be removed:

```python
if is_init and not dtype_util.is_integer(self.mixture_distribution.dtype):
  raise ValueError(
      'mixture_distribution.dtype({}) is not over integers'.format(
          dtype_util.name(self.mixture_distribution.dtype)))
```

Other methods, like _log_prob, do not appear to need any changes.
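To make the proposed change concrete, here is a small NumPy sketch contrasting the current hard one-hot mask with a relaxed (soft) mask. The variable names loosely mirror `_sample_n`, but this is an illustrative stand-in, not the actual TFP code:

```python
import numpy as np

rng = np.random.default_rng(1)
num_components = 3
component_samples = rng.normal(size=(num_components,))  # one draw per component

# Current behaviour: an integer categorical sample becomes a hard one-hot
# mask, so exactly one component's sample survives the reduction.
mix_sample = 1                                  # drawn from the categorical
hard_mask = np.eye(num_components)[mix_sample]  # e.g. [0., 1., 0.]
hard_result = np.dot(hard_mask, component_samples)

# Proposed behaviour: a relaxed one-hot sample (soft, positive weights
# summing to 1) replaces the mask, making the selection differentiable.
logits = np.log(np.array([0.2, 0.5, 0.3]))
gumbel = -np.log(-np.log(rng.uniform(size=num_components)))
y = (logits + gumbel) / 0.5                     # temperature 0.5
soft_mask = np.exp(y - y.max()) / np.exp(y - y.max()).sum()
soft_result = np.dot(soft_mask, component_samples)
```

With the soft mask, gradients flow through the mixing weights as well as the component samples, which is the point of the feature request.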

@brianwa84
Contributor

brianwa84 commented Oct 11, 2022 via email

@bryorsnef
Author

bryorsnef commented Oct 11, 2022 via email
