Implement Theano graph optimization to remove normalizing constants from logp #4396
I wrote a small notebook to set up the environment to start experimenting with this: https://gist.github.com/twiecki/e758db2c3d2df5f3368fc49e6087e58f
Here's an implementation that's nearly 75% of the way there:

```python
import numpy as np
import theano.tensor as tt

from theano import config
from theano.compile import optdb
from theano.gof.fg import FunctionGraph
from theano.gof.graph import inputs as tt_inputs
from theano.gof.opt import EquilibriumOptimizer, PatternSub
from theano.gof.optdb import Query
from theano.printing import debugprint as tt_dprint
from theano.tensor.opt import get_clients

# We don't need to waste time compiling graphs to C
config.cxx = ""

# a / b -> a * 1/b, for a != 1 and b != 1
div_to_mul_pattern = PatternSub(
    (tt.true_div, "a", "b"),
    (tt.mul, "a", (tt.inv, "b")),
    allow_multiple_clients=True,
    name="div_to_mul",
    tracks=[tt.true_div],
    get_nodes=get_clients,
)

# a - b -> a + (-b)
sub_to_add_pattern = PatternSub(
    (tt.sub, "a", "b"),
    (tt.add, "a", (tt.neg, "b")),
    allow_multiple_clients=True,
    name="sub_to_add",
    tracks=[tt.sub],
    get_nodes=get_clients,
)

# a * (x + y) -> a * x + a * y
distribute_mul_pattern = PatternSub(
    (tt.mul, "a", (tt.add, "x", "y")),
    (tt.add, (tt.mul, "a", "x"), (tt.mul, "a", "y")),
    allow_multiple_clients=True,
    name="distribute_mul",
    tracks=[tt.mul],
    get_nodes=get_clients,
)

expand_opt = EquilibriumOptimizer(
    [div_to_mul_pattern, distribute_mul_pattern, sub_to_add_pattern],
    ignore_newtrees=False,
    tracks_on_change_inputs=True,
    max_use_ratio=config.optdb__max_use_ratio,
)


def optimize_graph(fgraph, include=["canonicalize"], custom_opt=None, **kwargs):
    if not isinstance(fgraph, FunctionGraph):
        inputs = tt_inputs([fgraph])
        fgraph = FunctionGraph(inputs, [fgraph], clone=False)

    canonicalize_opt = optdb.query(Query(include=include, **kwargs))
    _ = canonicalize_opt.optimize(fgraph)

    if custom_opt:
        custom_opt.optimize(fgraph)

    return fgraph


tau = tt.dscalar("tau")
value = tt.dscalar("value")
mu = tt.dscalar("mu")

logp = (-tau * (value - mu) ** 2 + tt.log(tau / np.pi / 2.0)) / 2.0

logp_fg = optimize_graph(logp, custom_opt=expand_opt)

tt_dprint(logp_fg)

# TODO: Remove additive terms that do not contain the desired terms (e.g. `mu`
# and `tau` when the likelihood is only a function of `mu` and `tau`)

# This is what we want from the optimization
# logp_goal = -tau * (value - mu) ** 2
```
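The remaining TODO can be sketched in pure Python with a toy tuple-based expression tree: split the top-level sum into additive terms and keep only the terms that reference the variable of interest. All of the names below (`Var`, `split_sum`, `drop_constant_terms`, ...) are hypothetical illustrations of the idea, not Theano API.

```python
# Toy sketch (pure Python, NOT Theano): drop additive terms that never
# reference a given variable. Everything here is a hypothetical stand-in
# for the corresponding graph manipulations in the Gist.

class Var:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name

def free_vars(expr):
    """Collect the variable names appearing in a nested tuple tree."""
    if isinstance(expr, Var):
        return {expr.name}
    if isinstance(expr, tuple):  # e.g. ("add", a, b), ("mul", a, b), ...
        return set().union(*(free_vars(e) for e in expr[1:]))
    return set()  # numeric constants

def split_sum(expr):
    """Flatten nested ("add", a, b) nodes into a list of additive terms."""
    if isinstance(expr, tuple) and expr[0] == "add":
        return split_sum(expr[1]) + split_sum(expr[2])
    return [expr]

def drop_constant_terms(expr, wrt):
    """Keep only the additive terms that mention the variable `wrt`."""
    kept = [t for t in split_sum(expr) if wrt in free_vars(t)]
    out = kept[0]
    for t in kept[1:]:
        out = ("add", out, t)
    return out

value, tau, mu = Var("value"), Var("tau"), Var("mu")

# The expanded logp from the Gist as a sum of two terms:
#   -tau * (value - mu)**2 / 2  +  log(tau / (2*pi)) / 2
logp = (
    "add",
    ("mul", -0.5, ("mul", tau, ("pow", ("sub", value, mu), 2))),
    ("mul", 0.5, ("log", ("div", tau, 6.283185307179586))),
)

# Only the first term references `value`, so the log-normalizer is dropped.
print(drop_constant_terms(logp, "value"))
```

The real implementation would of course operate on `Apply` nodes inside a `FunctionGraph` rather than tuples, but the term-filtering logic is the same.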
Updated the gist: https://gist.github.com/e758db2c3d2df5f3368fc49e6087e58f
Hey @twiecki, I am currently working on this issue. I went through the documentation and I could not find any inbuilt method like …
That's great, @chandan5362 - this is a super high-impact issue. Good questions; I'm not sure. Probably @brandonwillard knows. Did you make any progress on the algorithm itself?
As per my understanding, most of the optimizations have already been done. The only thing left is to remove the node from the …
I think there is still some complexity to figure out. What's not implemented yet is identifying sub-terms that contain …
Now I guess I will need a little more information about which nodes can and can't be removed. Where do we actually need to remove the normalisation term? Just to have a clearer picture, could you please provide a reference to such a sampling method?
The outline in that Gist already covers most of it. There's more than one way to accomplish this, but, given what's already in the Gist, one can use the …

Here's a simple illustration:

1. Apply the existing transforms: …
2. Now, we know that we can "gather"/"collect" the …
3. Of the terms that are present in all …

All of the graph manipulation should be done within the context of an …
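The expand/gather/partition steps described above can be illustrated with SymPy standing in for the analogous Theano rewrites. This is purely a sketch of the idea; the actual work would happen via `EquilibriumOptimizer` and `FunctionGraph`, as in the Gist.

```python
import sympy as sp

value, mu, tau = sp.symbols("value mu tau")

# The example logp from the Gist, before any rewriting:
logp = (-tau * (value - mu) ** 2 + sp.log(tau / sp.pi / 2)) / 2

# 1) Apply the existing transforms: distribute products over sums so the
#    expression becomes a flat sum of terms (the EquilibriumOptimizer step).
expanded = sp.expand(logp)

# 2) "Gather"/"collect" terms by a chosen sub-expression (here tau), which
#    exposes how each additive term is built.
collected = sp.collect(expanded, tau)
print(collected)

# 3) Partition the additive terms by whether they reference `value`; the
#    terms without it are the normalizing constants we can drop.
terms = sp.Add.make_args(expanded)
depends = [t for t in terms if t.has(value)]
free = [t for t in terms if not t.has(value)]
print(depends)  # terms to keep
print(free)     # candidates for removal
```

Note that the expanded form is only an intermediate used to identify droppable terms; it need not be the form that is ultimately evaluated.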
Thanks for the help, @brandonwillard.
Yes, …
Huh, how can we infer the posterior of …
If our arguments of interest are …
@chandan5362 Correct, we can only remove terms that are constant.
@chandan5362 Are you still working on this? Need any help? Here is a JAX implementation doing the same thing: rlouf/mcx#71
I highly advise against using that as a reference. It looks like a very ground-up implementation that uses a considerable amount of functionality that overlaps with a lot of existing machinery in Theano/Aesara, both in terms of what it's trying to do and how it's doing it. Also, from a quick reading, it looks like it could be doing many of the exact same things discussed here, albeit in a much different context.

While I'm on the topic, I don't think I explicitly mentioned that the operations discussed above are primarily used to find sub-graphs, terms, and term/sub-graph relationships that can be used to remove terms and sub-graphs from the original graph. In other words, the manipulations detailed above do not necessarily produce graphs that we want or need to use as the final transformed graph.

This came up in conversations about concerns over applying a distributive law, which can lead to less numerically stable expressions. The main point is this: you can easily and cheaply manipulate a graph into an intermediate form for the purposes of obtaining information; however, those intermediate graphs can be discarded once the requisite information is obtained, or further transformations can be applied to produce a more computationally stable form (see the "stabilization" optimizations in Theano/Aesara).
Working on the PR rlouf/mcx#71 for MCX, I can confirm it should absolutely not be considered a reference :) The implementation is still fairly crude at the moment, and as said by @brandonwillard, Theano has much better graph machinery for implementing this optimization in an elegant way. I was planning on studying your codebase to refine my implementation, more than the other way around!
Hi, I took a stab at this. You can find what I did at … Here is a short description of what I did:
Let me know what you think. |
From a quick look, it seems like you're on the right track.
In this context, removing a term is really replacing a term.
Take a look at …

What you're calling …

If you have questions or run into issues using these Aesara tools, you can post questions here or in the Aesara repo and we can try to clarify.
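The point that "removing a term is really replacing a term" can be sketched in pure Python on a toy tuple tree: substitute the additive identity 0 for the matched sub-expression and let a later constant-folding pass simplify `x + 0`. The function `replace_subtree` below is hypothetical; in Aesara, the analogous operation is `FunctionGraph.replace`.

```python
# Toy sketch (pure Python, not the Aesara API) of removal-as-replacement:
# swap a matched sub-expression for 0 instead of deleting the node.

def replace_subtree(expr, target, replacement):
    """Return a copy of `expr` with every occurrence of `target` swapped out."""
    if expr == target:
        return replacement
    if isinstance(expr, tuple):
        return tuple(replace_subtree(e, target, replacement) for e in expr)
    return expr

# logp = -tau * (value - mu)**2 / 2 + log-normalizer
log_norm_term = ("mul", 0.5, ("log", ("div", "tau", "2*pi")))
logp = (
    "add",
    ("mul", -0.5, ("mul", "tau", ("pow", ("sub", "value", "mu"), 2))),
    log_norm_term,
)

# "Remove" the normalizing constant by replacing it with 0; a subsequent
# constant-folding/canonicalization pass would then drop the `+ 0`.
pruned = replace_subtree(logp, log_norm_term, 0)
print(pruned)
```

Working by replacement keeps the graph well-formed at every step, which is why the Aesara machinery is built around it.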
@manthehunted This looks really cool, great job! As @brandonwillard said there's some cleaning up that would really help make this more efficient. Let us know if you have problems using these more advanced Aesara tools. |
@twiecki Thanks for the encouragement. @brandonwillard Well, now that I think more about this, what I am saying is …
Now I feel like I'm making the problem a bit more complicated. Maybe I will just pick one approach and have you give me feedback.
@brandonwillard Here is the work, and please let me know what you think of it. |
Sorry, I've been very busy with the …
Closing in favor of aesara-devs/aeppl#33 |
Our logps include normalization terms that are unnecessary for sampling, since sampling works fine with an unnormalized logp. As such, we can speed up model logp evaluation by taking the Theano computation graph and applying a custom graph optimization that identifies and removes these terms. This should give a nice speed-up, as it avoids unnecessary computation.
I think the way to do this would be to traverse the logp graph, break the large sum into individual terms, and throw out any terms that do not depend on `value` nodes.

For more information on how to optimize the Theano graph, see https://theano-pymc.readthedocs.io/en/latest/extending/optimization.html and https://theano-pymc.readthedocs.io/en/latest/optimizations.html.