Add KL Divergence helper #7062
base: main
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #7062      +/-   ##
==========================================
+ Coverage   90.17%   90.18%   +0.01%
==========================================
  Files         101      103       +2
  Lines       16932    16952      +20
==========================================
+ Hits        15269    15289      +20
  Misses       1663     1663
This looks very good @ferrine! One thing that you definitely need to do before you can merge this is to add a page about this to the documentation. Maybe make a folder for logprob and include a subfolder for the KL divergence, since it will start to grow as more divergences get added.
pymc/logprob/kl_divergence.py
Outdated
    q_inputs: List[TensorVariable],
    p_inputs: List[TensorVariable],
):
    _, _, _, q_mu, q_sigma = q_inputs
Should probably consider size like moment does?
I found that our …
I am also confused about the design choice, since …
The base functionality should be in logprob, but the specific implementations should be in …
You can, but it won't work for things that look like Distributions but are just helpers to create distributions, like …
Moved kl into a private pymc.distribution._stats because these are functions that will never be used by anyone.
I don't like the underscore, why not just …
@ricardoV94 can you please revisit the review? Did I miss something?
Tests are failing with an import issue.
How many pairs of distributions do we expect to actually be able to support?
Many of them: https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/kl_divergence
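Supporting many pairs suggests pairwise dispatch. The sketch below is hypothetical plain Python, not PyMC's actual API (the names `register_kl`, `kl_div`, and `_KL_REGISTRY` are illustrative): it keys implementations on the pair of op types, mirroring how a `_kl_div` dispatcher could select an implementation from the two distributions' ops.

```python
# Hypothetical sketch of pairwise KL dispatch: implementations are keyed on
# the (type(q_op), type(p_op)) pair, so new distribution pairs can be added
# incrementally without touching the dispatcher.
_KL_REGISTRY = {}


def register_kl(q_cls, p_cls):
    """Decorator registering a KL implementation for a (q, p) op-type pair."""

    def decorator(func):
        _KL_REGISTRY[(q_cls, p_cls)] = func
        return func

    return decorator


def kl_div(q_op, p_op, *args, **kwargs):
    """Look up and call the KL implementation registered for this pair."""
    try:
        impl = _KL_REGISTRY[(type(q_op), type(p_op))]
    except KeyError:
        raise NotImplementedError(
            f"No KL registered for {type(q_op).__name__} vs {type(p_op).__name__}"
        )
    return impl(q_op, p_op, *args, **kwargs)
```

Unregistered pairs fail loudly with `NotImplementedError`, which keeps the supported set explicit as more divergences get added.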
    _, _, _, q_mu, q_sigma = q_inputs
    _, _, _, p_mu, p_sigma = p_inputs
    diff_log_scale = pt.log(q_sigma) - pt.log(p_sigma)
    return (
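The truncated snippet computes the Normal–Normal KL in closed form. As a sanity reference, here is a self-contained NumPy sketch of that closed form (the standalone function and its name are illustrative; the PR builds the equivalent expression with PyTensor ops):

```python
import numpy as np


def kl_normal_normal(q_mu, q_sigma, p_mu, p_sigma):
    """Closed-form KL(q || p) for univariate normals:

        KL = log(p_sigma / q_sigma)
             + (q_sigma**2 + (q_mu - p_mu)**2) / (2 * p_sigma**2)
             - 1/2

    Broadcasts elementwise over array-valued parameters.
    """
    q_mu, q_sigma, p_mu, p_sigma = map(np.asarray, (q_mu, q_sigma, p_mu, p_sigma))
    diff_log_scale = np.log(q_sigma) - np.log(p_sigma)  # as in the snippet above
    return (
        -diff_log_scale
        + (q_sigma**2 + (q_mu - p_mu) ** 2) / (2.0 * p_sigma**2)
        - 0.5
    )
```

As quick checks, KL(q‖q) is 0, and KL(N(1,1)‖N(0,1)) = 0.5.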
May want to broadcast to size, like we do with moment, if someone does kl_div(pm.Normal.dist(shape=5), pm.Normal.dist(mu=1)).
This is still relevant: you're ignoring batch dimensions encoded in the size parameter.
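One hedged way to honor those batch dimensions, sketched in NumPy rather than PyTensor (the helper name, and treating each distribution's size as a shape tuple, are assumptions for illustration only):

```python
import numpy as np


def broadcast_kl_to_batch(kl_value, q_batch_shape, p_batch_shape):
    """Hypothetical helper: broadcast a pointwise KL term to the combined
    batch shape of the two distributions, so that e.g.
    kl_div(pm.Normal.dist(shape=5), pm.Normal.dist(mu=1)) yields 5 values
    instead of silently dropping the batch dimension.
    """
    batch_shape = np.broadcast_shapes(q_batch_shape, p_batch_shape)
    return np.broadcast_to(np.asarray(kl_value), batch_shape)
```

A scalar KL between `shape=(5,)` and scalar distributions then comes back as five identical values, matching what a user iterating over batches would expect.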
Just did the rebase; anything we can add or change on top of that?
    q_inputs: List[TensorVariable],
    p_inputs: List[TensorVariable],
This is wrong: RVs have non-tensor inputs as well.
-    q_inputs: List[TensorVariable],
-    p_inputs: List[TensorVariable],
+    q_inputs: List[Variable],
+    p_inputs: List[Variable],
kl = _kl_div(
    q_rv.owner.op,
    p_rv.owner.op,
    q_inputs=q_rv.owner.inputs,
Perhaps pass the node instead of inputs. Allows stuff like op.dist_params(node) and op.size_param(node) inside the dispatch functions. Not sure though
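To illustrate that suggestion with stub classes (none of these names are PyTensor's real API; this only shows why passing the node is more convenient than passing a flat input list):

```python
# Stub classes illustrating node-based dispatch. A random-variable node's
# inputs begin with non-parameter variables (rng, size, dtype here, per the
# `_, _, _, mu, sigma = inputs` unpacking above); when the whole node is
# passed, the op itself can slice out just the distribution parameters.
class FakeApplyNode:
    def __init__(self, inputs):
        self.inputs = inputs


class FakeNormalOp:
    n_non_param_inputs = 3  # rng, size, dtype

    def dist_params(self, node):
        """Return only the distribution parameters of `node`."""
        return node.inputs[self.n_non_param_inputs:]


op = FakeNormalOp()
node = FakeApplyNode(["rng", "size", "dtype", 0.0, 1.0])
mu, sigma = op.dist_params(node)  # dispatch code never touches rng/size/dtype
```

Each dispatch function then stays agnostic to how many bookkeeping inputs precede the parameters, which is exactly what hard-coded `_, _, _, mu, sigma` unpacking cannot guarantee.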
📚 Documentation preview 📚: https://pymc--7062.org.readthedocs.build/en/7062/