Make Distributions GenerativeFunctions #274
base: master
Conversation
@georgematheos For these benchmark results, could you also run the experiment at 10x the number of datapoints? That will help ensure that nothing is sneakily asymptotically slower (though I don't see why it would be).
@alex-lew Here is some more benchmarking for asymptotics. I did 10x the datapoints for the static DSL, and 1/10 the datapoints for the dynamic DSL (the dynamic DSL is slow and scales superlinearly, so 10x the datapoints would take a very long time). I also did a few runs of each, since I found the results varied a decent amount. This PR:
Master branch:
So it looks like there is no asymptotic difference between the old implementation and the new one. Again we see that the dynamic DSL is slowed down a bit by this PR (I wouldn't be surprised if it is slowed down more in this small example, since there's less time for the compiler to optimize). In these runs, the static DSL speedup does not appear (though there doesn't seem to be a significant slowdown either).
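For reference, this kind of asymptotic spot-check can be sketched as a tiny timing harness. Everything below (the stand-in workload, the function names) is assumed for illustration and is not the PR's actual benchmark code:

```julia
# Hypothetical scaling check: time a workload at n and at 10n datapoints.
# Near-linear code should show roughly a 10x ratio; superlinear code
# (like the dynamic DSL case discussed above) shows noticeably more.
workload(n) = sum(log1p(abs(sin(i))) for i in 1:n)  # stand-in for one MH sweep

function scaling_ratio(n)
    workload(10)                    # warm-up call so compilation isn't timed
    t1 = @elapsed workload(n)
    t2 = @elapsed workload(10n)
    return t2 / t1
end

scaling_ratio(1_000_000)            # a ratio near 10 suggests linear scaling
```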
@marcoct I agree that changing the choicemap interface may be reasonable. Soon, I will implement the […]. In terms of implementing iteration behavior instead of […]
@georgematheos I'm interested in seeing if we can merge this in for the next breaking-changes release. (This PR mostly doesn't break anything for users, but it does change the ChoiceMap interface a bit, e.g. the behavior of […].) One question I have about the implementation is why the gradient-based GFI methods still special-case on `Distribution`.
@alex-lew I'm not totally sure why I did this -- I can take a look in more detail next week. If I remember correctly, at the time I wrote the pull request I did not understand Gen's backprop code, so I tried to keep it as similar to the previous version as possible, to make sure I didn't break anything.
src/distribution.jl (outdated):

```julia
    args
    score::Float64
end

@inline dist(::DistributionTrace{T, Dist}) where {T, Dist} = Dist()
```
@georgematheos Unfortunately, this implementation does not support more complicated distributions that are not 'singleton' structs, like the Mixture distributions. The problem is that they do not have zero-argument constructors. For example, `HomogeneousMixture`'s definition looks like this:

```julia
struct HomogeneousMixture{T} <: Distribution{T}
    base_dist::Distribution{T}
    dims::Vector{Int}
end
```
If someone creates such a distribution and then simulates a trace from it, the trace does not remember the `base_dist` and the `dims`, and so it cannot implement `Gen.get_gen_fn`.
One option is for the `DistributionTrace` to store a reference to the distribution object, and to have `dist(d::DistributionTrace) = d.dist`. That adds slight storage overhead, but maybe not enough to worry about -- it's one more pointer per choice in a trace.
Another option is to say that the Mixtures are not literally subtypes of Distribution -- they are just additional generative functions with ValueChoiceMaps.
I lean toward the first option -- thoughts?
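A minimal sketch of the first option, with assumed type and field names rather than the PR's actual code, storing the distribution object in the trace so it stays recoverable:

```julia
# Option 1 sketch (names assumed): the trace holds a reference to the
# distribution object, so non-singleton distributions like
# HomogeneousMixture remain recoverable from their traces.
abstract type Distribution{T} end

struct DistributionTrace{T, Dist <: Distribution{T}}
    dist::Dist        # one extra pointer per choice (free for singletons)
    args::Tuple
    retval::T
    score::Float64
end

# Reads the stored object instead of calling a zero-argument constructor:
dist(tr::DistributionTrace) = tr.dist

# Works for a parameterized, non-singleton distribution:
struct ScaledNormal <: Distribution{Float64}
    scale::Float64
end
tr = DistributionTrace{Float64, ScaledNormal}(ScaledNormal(2.5), (), 1.0, -0.5)
dist(tr).scale    # recovers 2.5 from the trace
```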
> One option is for the DistributionTrace to store a reference to the distribution object, and to have `dist(d::DistributionTrace) = d.dist`. That adds slight storage overhead, but maybe not enough to worry about -- it's one more pointer per choice in a trace.
I think this may actually add zero storage overhead if the `DistributionTrace` is parametrically typed and the type in question is a singleton (as it is for `Normal`, etc.)! My guess is that Julia optimizes that sort of thing away. But it's worth checking.
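The guess can be checked directly in Julia: singleton types have zero size, and a field whose type is a singleton adds no bytes to a struct's layout. A standalone sketch (the `NormalLike` stand-in type is assumed, not Gen's `Normal`):

```julia
# Stand-in singleton distribution type (no fields => a singleton):
struct NormalLike end

# A trace parametrically typed on the distribution, as suggested above:
struct TraceSketch{D}
    dist::D
    val::Float64
end

Base.issingletontype(NormalLike)        # true: exactly one instance exists
sizeof(TraceSketch{NormalLike})         # same as sizeof(Float64): the
                                        # singleton `dist` field is elided
```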
All right, I think I've removed most (all?) of the special-casing that the dynamic and static DSLs do on `Distribution`.

One interesting benefit of this PR's changes is that it is now possible to figure out, from the trace of a program, what distribution each choice was drawn from. This could be used to implement, e.g., automatic Gibbs inference for discrete choices without the user needing to pass in a list of valid values (the function could deduce the support of the choice from the distribution stored in the trace).

However, merging this PR would break external packages that examine the IR of Gen static DSL functions and expect to find […].
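As a rough illustration of the Gibbs idea, here is a self-contained sketch. The `Categorical` type and the `support` and `logpdf` helpers are assumed for the example; this is not Gen's API:

```julia
# With the distribution recoverable from a trace, a Gibbs move over a
# discrete choice can enumerate the support itself instead of asking the
# user for a list of valid values.
struct Categorical
    probs::Vector{Float64}
end
support(d::Categorical) = 1:length(d.probs)      # assumed helper
logpdf(d::Categorical, x) = log(d.probs[x])

# Conditional weights over the enumerated support (prior x likelihood):
function gibbs_weights(d::Categorical, loglike)
    logw = [logpdf(d, x) + loglike(x) for x in support(d)]
    w = exp.(logw .- maximum(logw))
    return w ./ sum(w)
end

gibbs_weights(Categorical([0.5, 0.5]), x -> x == 1 ? log(0.9) : log(0.1))
# -> approximately [0.9, 0.1]
```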
This builds on #263 and resolves #259.

I make `Distribution <: GenerativeFunction` true, and remove a lot of code in the static and dynamic DSLs specialized to distributions. Likewise, I remove the `ChoiceAt` combinator.

Note that I have not significantly modified the gradient code, since I have not yet taken the time to understand how it works. (So it still dispatches on whether something is a distribution or a generative function.) We may be able to further reduce the code footprint by removing specialization to distributions in the gradient calculations.
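The core change can be illustrated with a self-contained sketch; the type and function names here are assumed, and this shows only the shape of the unification, not Gen's actual code:

```julia
# Once Distribution <: GenerativeFunction holds, DSL code that used to
# branch on "distribution or generative function?" collapses into a
# single method on the common supertype.
abstract type GenerativeFunction end
abstract type Distribution <: GenerativeFunction end

struct NormalDist <: Distribution end
struct MyModel <: GenerativeFunction end

# One generic code path instead of two special-cased ones:
describe_call(gf::GenerativeFunction) = "traced call to $(nameof(typeof(gf)))"

describe_call(NormalDist())   # "traced call to NormalDist"
describe_call(MyModel())      # "traced call to MyModel"
```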
I have also added a `test/benchmark` folder with a couple of initial MH benchmarks adapted from the `examples`; we can add and revise benchmarks as we create them. Currently the benchmarks run MH on a static DSL model and a dynamic DSL model I took from the examples folder.

My initial benchmark results were sort of volatile (they varied a lot between runs), and more careful benchmarking should be done. It looks like the changes somewhat improve performance for static code, but slow down dynamic code somewhere between 1.1x and 1.6x. The dynamic performance slowdown appears to be caused by my previous PR (`ValueChoiceMap`), not the distributions-as-generative-functions PR. That said, I don't see how to do this PR without building on the other. I have not thoroughly investigated what causes the dynamic performance reduction; we may be able to improve this.

Benchmarking results:
After this PR:
After the `ValueChoiceMap` PR but before this PR:

Before either PR: