Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix logpdf_grad errors in @dist DSL. #497

Merged
merged 4 commits into from
Jan 25, 2023

Conversation

ztangent
Copy link
Member

This PR addresses #496, including some of the indexing errors identified in that issue, but also by generalizing the backpropagation of gradients from the base distribution to the user-facing / custom-defined @dist distribution. Generalizing the backprop code was necessary to ensure that vector-valued arguments have appropriate gradients returned (e.g. the probability vector for a relabeled categorical), but also provides support for multivariate distributions (e.g. broadcasted_normal and mvnormal). Test cases have been added to check for this functionality.

The one additional test case I can think of is to make sure using a @dist-defined distribution in a generative function leads to the correct outputs (gradients, score, etc.). I can do this later today!

@alex-lew
Copy link
Contributor

This looks great, @ztangent! Thanks so much.

@alex-lew alex-lew merged commit cec0486 into master Jan 25, 2023
@ztangent ztangent deleted the 20230109-ztangent-fix_dist_dsl_grads branch March 18, 2024 20:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants