Refactor logpt calls to aeppl #5166
Conversation
Codecov Report
@@            Coverage Diff             @@
##             main    #5166      +/-   ##
==========================================
+ Coverage   78.02%   78.08%   +0.05%
==========================================
  Files          88       88
  Lines       14123    14161      +38
==========================================
+ Hits        11020    11057      +37
- Misses       3103     3104       +1
pymc/distributions/logprob.py
Why not remove it?
I didn't have time to properly check it.
The issue with removing that branch is that it will give unused-input errors while constructing the graph further down the line in aeppl.factorized_logprob. Technically we could turn those off, but that would cause problems on the user side (they wouldn't know about unused inputs if we suppress the check implicitly).
It does work if we allow unused inputs, but the above is the only reason I decided to keep it in the end.
Also, with the current way of building the value variable dictionary (by iterating over the graph nodes using toposort), I don't think we can avoid adding that RV value in the first place unless we put a similar Subtensor check on its parent node.
Probably the best way to deal with this issue long term will be to avoid the use of tag.value_var and switch to the value_var dictionary as the only way to provide value variables. (This approach would also remove a whole lot of extra logic from logpt.)
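For context, a minimal sketch of the unused-input behaviour described above (plain aesara, not the actual logpt code): aesara.function errors out by default when an input does not feed into the output graph, which is what would happen to the indexed RV's value variable if the Subtensor branch were removed.

import aesara
import aesara.tensor as at

x = at.scalar("x")
y = at.scalar("y")
# By default, compiling with an input that never feeds the output raises an unused-input error:
# f = aesara.function([x, y], 2 * x)  # UnusedInputError
# Turning the check off compiles fine, but the user never learns that y is ignored:
f = aesara.function([x, y], 2 * x, on_unused_input="ignore")
print(f(1.0, 2.0))  # 2.0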
@kc611 Can you provide an MWE? It's hard to tell from the comment alone what the purpose of that branch is and whether it is still valid. For instance, we no longer have the {var: None} thing going on in aeppl, but the comment above seems to refer to it.
import aesara
import aesara.tensor as at
import numpy as np
import pymc as pm
size = 5
mu_base = np.arange(size)
mu = np.stack([mu_base, -mu_base])
sigma = 0.001
A_rv = pm.Normal.dist(mu, sigma)
A_rv.name = "A"
p = 0.5
I_rv = pm.Bernoulli.dist(p, size=size)
I_rv.name = "I"
A_idx = A_rv[I_rv]
A_idx_value_var = A_idx.type()
A_idx_value_var.name = "A_idx_value"
I_value_var = I_rv.type()
I_value_var.name = "I_value"
A_idx_logp = pm.logpt(A_idx, {A_idx: A_idx_value_var, I_rv: I_value_var}, sum=False)
# If you comment out the subtensor branch the following line will not run
logp_vals_fn = aesara.function([A_idx_value_var, I_value_var], A_idx_logp)
# But this will
logp_vals_fn = aesara.function([A_idx_value_var, I_value_var], A_idx_logp, on_unused_input='ignore')
aesara.dprint(logp_vals_fn)
# Graph without the subtensor branch in logpt
# Elemwise{Composite{Switch(Cast{int8}((GE(i0, i1) * LE(i0, i2))), i3, i4)}} [id A] 'I_value_logprob' 0
# |I_value [id B]
# |TensorConstant{(1,) of 0} [id C]
# |TensorConstant{(1,) of 1} [id D]
# |TensorConstant{(1,) of -0..1805599453} [id E]
# |TensorConstant{(1,) of -inf} [id F]
# Graph with the subtensor branch in logpt
# Elemwise{Add}[(0, 1)] [id A] '' 8
# |Elemwise{Composite{Switch(Cast{int8}((GE(i0, i1) * LE(i0, i2))), i3, i4)}} [id B] '' 3
# | |InplaceDimShuffle{x,0} [id C] '' 0
# | | |I_value [id D]
# | |TensorConstant{(1, 1) of 0} [id E]
# | |TensorConstant{(1, 1) of 1} [id F]
# | |TensorConstant{(1, 1) of ..1805599453} [id G]
# | |TensorConstant{(1, 1) of -inf} [id H]
# |Assert{msg='sigma > 0'} [id I] 'A_idx_value_logprob' 7
# |Elemwise{Composite{((i0 + (i1 * sqr(((i2 - i3) / i4)))) - log(i4))}}[(0, 4)] [id J] '' 5
# | |TensorConstant{(1, 1) of ..5332046727} [id K]
# | |TensorConstant{(1, 1) of -0.5} [id L]
# | |A_idx_value [id M]
# | |AdvancedSubtensor1 [id N] '' 2
# | | |TensorConstant{[[ 0 1 2..-2 -3 -4]]} [id O]
# | | |I_value [id D]
# | |AdvancedSubtensor1 [id P] '' 1
# | |TensorConstant{(2, 5) of 0.001} [id Q]
# | |I_value [id D]
# |All [id R] '' 6
# |Elemwise{gt,no_inplace} [id S] '' 4
# |AdvancedSubtensor1 [id P] '' 1
# |TensorConstant{(1, 1) of 0.0} [id T]
This MWE is derived from the test_logprob.py file, so there are more instances of subtensors over there that show similar behavior.
Force-pushed from 7de7f3d to 0be20b4
pymc/model.py
@kc611 Do you remember if there was a reason why we need two separate calls to logpt, and why we can't just pass rv_values + obs_values once?
I guess in some cases there were overlapping keys in those value dictionaries.
That's not an issue for dictionaries, right?
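A minimal illustration of that point (hypothetical keys, not the real value variables): merging two dicts with overlapping keys does not raise, the later entry just wins, so a single merged rv_values + obs_values mapping should be well defined.

# Hypothetical contents, just to illustrate dict-merge semantics:
rv_values = {"x": "x_value"}
obs_values = {"x": "x_observed", "y": "y_value"}
# An overlapping key is not an error; the right-hand dict simply takes precedence.
merged = {**rv_values, **obs_values}
print(merged)  # {'x': 'x_observed', 'y': 'y_value'}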
Force-pushed from 0be20b4 to 69cbff3
Force-pushed from 69cbff3 to b00c8e9
This commit introduces varlogp_nojact and potentiallogpt properties
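A rough usage sketch based only on the property names in the commit message above (the access pattern and exact semantics are assumptions, not confirmed by this PR text):

import pymc as pm

with pm.Model() as m:
    x = pm.HalfNormal("x")
    pm.Potential("pot", -x)

# Assumed: symbolic log-probability of the free variables without the transform Jacobian term
varlogp_nojac = m.varlogp_nojact
# Assumed: symbolic contribution of Potential terms to the model logp
potential_logp = m.potentiallogpt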
Force-pushed from b00c8e9 to d392957
LGTM!
This snippet would raise an unnecessary warning before this PR:
CC @kc611