
Implement Theano graph optimization to remove normalizing constants from logp #4396

Closed
twiecki opened this issue Dec 31, 2020 · 26 comments

@twiecki
Member

twiecki commented Dec 31, 2020

Our logps include normalization terms that are unnecessary for sampling, since sampling works just as well with an unnormalized logp. We can therefore speed up model logp evaluation by taking the Theano computation graph and applying a custom graph optimization that identifies and removes these terms. This should lead to quite a nice speed-up, as it saves unnecessary computation.
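To make that concrete with the Normal logp used later in this thread (plain notation, parameterized by the precision tau):

logp(value | mu, tau) = -tau * (value - mu)**2 / 2 + log(tau) / 2 - log(2 * pi) / 2

The trailing log(2 * pi) / 2 term depends on neither the value nor the parameters, so it is a pure normalizing constant that can always be dropped; which of the remaining terms can go depends on which variables we are sampling (see the discussion below).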

I think the way to do this would be to traverse the logp graph, break the large sum into individual terms, and throw out any terms that do not depend on value nodes.

For more information on how to optimize the Theano graph, see https://theano-pymc.readthedocs.io/en/latest/extending/optimization.html and https://theano-pymc.readthedocs.io/en/latest/optimizations.html.

@twiecki
Member Author

twiecki commented Dec 31, 2020

I wrote a small NB to set up the environment to start experimenting with this: https://gist.github.com/twiecki/e758db2c3d2df5f3368fc49e6087e58f

@brandonwillard
Contributor

brandonwillard commented Jan 1, 2021

Here's an implementation that's nearly 75% of the way there:

import numpy as np
import theano.tensor as tt
from theano import config
from theano.compile import optdb
from theano.gof.fg import FunctionGraph
from theano.gof.graph import inputs as tt_inputs
from theano.gof.opt import EquilibriumOptimizer, PatternSub
from theano.gof.optdb import Query
from theano.printing import debugprint as tt_dprint
from theano.tensor.opt import get_clients

# We don't need to waste time compiling graphs to C
config.cxx = ""


# a / b -> a * 1/b, for a != 1 and b != 1
div_to_mul_pattern = PatternSub(
    (tt.true_div, "a", "b"),
    (tt.mul, "a", (tt.inv, "b")),
    allow_multiple_clients=True,
    name="div_to_mul",
    tracks=[tt.true_div],
    get_nodes=get_clients,
)

# a - b -> a + (-b)
sub_to_add_pattern = PatternSub(
    (tt.sub, "a", "b"),
    (tt.add, "a", (tt.neg, "b")),
    allow_multiple_clients=True,
    name="sub_to_add",
    tracks=[tt.sub],
    get_nodes=get_clients,
)

# a * (x + y) -> a * x + a * y
distribute_mul_pattern = PatternSub(
    (tt.mul, "a", (tt.add, "x", "y")),
    (tt.add, (tt.mul, "a", "x"), (tt.mul, "a", "y")),
    allow_multiple_clients=True,
    name="distribute_mul",
    tracks=[tt.mul],
    get_nodes=get_clients,
)

expand_opt = EquilibriumOptimizer(
    [div_to_mul_pattern, distribute_mul_pattern, sub_to_add_pattern],
    ignore_newtrees=False,
    tracks_on_change_inputs=True,
    max_use_ratio=config.optdb__max_use_ratio,
)


def optimize_graph(fgraph, include=["canonicalize"], custom_opt=None, **kwargs):
    if not isinstance(fgraph, FunctionGraph):
        inputs = tt_inputs([fgraph])
        fgraph = FunctionGraph(inputs, [fgraph], clone=False)

    canonicalize_opt = optdb.query(Query(include=include, **kwargs))
    _ = canonicalize_opt.optimize(fgraph)

    if custom_opt:
        custom_opt.optimize(fgraph)

    return fgraph


tau = tt.dscalar("tau")
value = tt.dscalar("value")
mu = tt.dscalar("mu")

logp = (-tau * (value - mu) ** 2 + tt.log(tau / np.pi / 2.0)) / 2.0

logp_fg = optimize_graph(logp, custom_opt=expand_opt)

tt_dprint(logp_fg)

# TODO: Remove additive terms that do not contain the desired terms (e.g. `mu`
# and `tau` when the likelihood is only a function of `mu`, `tau`)

# This is what we want from the optimization
# logp_goal = -tau * (value - mu) ** 2


@chandan5362
Contributor

Hey @twiecki, I am currently working on this issue. I went through the documentation and could not find a built-in method (like replace_validate) for removing a node from the FunctionGraph. Should I write the usual graph node-removal code myself, or is there already a built-in method I might have missed? Do let me know.
Thanks in advance :)

@twiecki
Member Author

twiecki commented Jan 11, 2021

That's great @chandan5362 - this is a super high-impact issue. Good question; I'm not sure, but @brandonwillard probably knows.

Did you make any progress on the algorithm itself?

@chandan5362
Contributor

chandan5362 commented Jan 11, 2021

Did you make any progress on the algorithm itself?

As per my understanding, most of the optimizations have already been done. The only things left are removing the node from the FunctionGraph, and registering this in optdb so that RemoveNormalizingConstants runs until no more nodes are left to optimize with respect to expand_opt. I am just looking for a concrete method to delete the unnecessary node.
I might need a little help here @brandonwillard.

@twiecki
Member Author

twiecki commented Jan 11, 2021

I think there is still some complexity to figure out. What's not implemented yet is identifying sub-terms that contain value nodes. And you can't just break everything up at additions and subtractions; for example, (value - x)**2 still needs to stay together. Does that make sense?

@chandan5362
Contributor

Now I guess I will need a little more information about which nodes can and cannot be removed. Where do we actually need to remove the normalisation term? Just to get a clearer picture, could you please provide a reference to such a sampling method?

@brandonwillard
Contributor

brandonwillard commented Jan 11, 2021

The outline in that Gist already covers most of it.

There's more than one way to accomplish this, but, given what's already in the Gist, one can use the MergeOptimizer after expanding all the terms—per the example—and guarantee that all the arguments to Add and Mul are de-duplicated. From there, terms that are present in the arguments of every Mul in the Add and do not contain the arguments of interest can be factored out. Determining whether or not a term contains an argument of interest is simple (albeit rather inefficient): arg in theano.graph.basic.ancestors([term]).

Here's a simple illustration: graph = (sqrt(b) * log(x) + 2**3 * sqrt(b) * log(x - 1)) / a and our argument of interest is x (i.e. graph is proportional to some f(x)).

1.) Apply the existing transforms: inv(a) * sqrt(b) * log(x) + inv(a) * 2**3 * sqrt(b) * log(x - 1). We now have an Add with the arguments a1 = inv(a) * sqrt(b) * log(x) and a2 = inv(a) * 2**3 * sqrt(b) * log(x - 1), i.e. Add(a1, a2). Each argument is a Mul with its own arguments, i.e. a1 = Mul(inv(a), sqrt(b), log(x)) and a2 = Mul(inv(a), 2**3, sqrt(b), log(x - 1)).

2.) Now, we know that we can "gather"/"collect" the inv(a) and sqrt(b) terms, because they appear in both Muls; however, Theano doesn't know that the inv(a) in a1 is the same inv(a) in a2—at least not with ==—and that's why we need to apply the MergeOptimizer. After applying it, a1.inputs[0] is a2.inputs[0] and so on, so we can programmatically determine which terms are present in all the Muls.

3.) Of the terms that are present in all Muls, we can only remove the ones that do not contain x, and that's when we can use an expression like x not in theano.graph.basic.ancestors([term]) for each shared term. Once we've obtained the set of terms that are shared amongst all the Muls and not x, we can remove them or create a new Add/Mul graph that simply omits them.

All of the graph manipulation should be done within the context of an Optimizer class, and any replacements to an existing graph are performed using the methods in FunctionGraph. There are many ways one could approach this using existing Optimizer classes, but creating a GlobalOptimizer from scratch is probably the most conceptually straightforward approach.
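To make steps 2.) and 3.) concrete, here is a rough, untested sketch of that term-collection logic. It assumes the graph has already been run through the expansion patterns and the MergeOptimizer, so that the output is an Add whose inputs are Muls and equal sub-graphs are literally the same objects; the helper name collect_removable_terms is made up for illustration, and the import path follows the comment above (adjust it to your Theano/Aesara version).

import theano.tensor as tt
from theano.graph.basic import ancestors


def collect_removable_terms(add_var, args_of_interest):
    # `add_var` is assumed to be the output of an `Add` whose inputs are all
    # `Mul`s; `args_of_interest` are the variables that must stay in the graph.
    add_node = add_var.owner
    if (
        add_node is None
        or add_node.op != tt.add
        or not all(
            inp.owner is not None and inp.owner.op == tt.mul
            for inp in add_node.inputs
        )
    ):
        return set()

    # Terms appearing in every `Mul`; this is an identity-based comparison,
    # which is why the `MergeOptimizer` needs to run first.
    shared_terms = set.intersection(
        *(set(inp.owner.inputs) for inp in add_node.inputs)
    )

    # Of the shared terms, only the ones that do not depend on any argument
    # of interest can be factored out and dropped.
    return {
        term
        for term in shared_terms
        if not any(arg in ancestors([term]) for arg in args_of_interest)
    }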

@chandan5362
Contributor

Thanks for the help @brandonwillard.
So, in your view, what should the final output be for the above expression? And for the one in the Gist?
For the example in the Gist, I think the output should be (value - mu) ** 2, with tau omitted as well. What do you say?

@brandonwillard
Contributor

Yes, tau doesn't need to be in the final output, just the term (value - mu)**2.

@twiecki
Member Author

twiecki commented Jan 16, 2021

Huh, how can we infer the posterior of tau if we throw it out of the logp?

@brandonwillard
Contributor

I was under the assumption that value was the "argument" in this example; however, yes, the actual arguments would be the parameters being estimated (e.g. mu and tau in the example above), since we're essentially trying to produce something proportional to the logp as a function of those parameters.

@chandan5362
Contributor

chandan5362 commented Jan 18, 2021

If our arguments of interest are mu and tau, then the final output should be -tau*(value-mu)**2 + log(tau), shouldn't it? If not, why did we remove the term tt.log(tau / np.pi / 2.0) even though it contains tau?

@twiecki
Member Author

twiecki commented Jan 22, 2021

@chandan5362 correct, we can only remove terms that are constant with respect to the parameters being sampled.

@twiecki
Member Author

twiecki commented Feb 7, 2021

@chandan5362 are you still working on this? need any help?

Here is a JAX implementation doing the same thing: rlouf/mcx#71

@brandonwillard
Contributor

Here is a JAX implementation doing the same thing: rlouf/mcx#71

I highly advise against using that as a reference. It looks like a very ground-up implementation that uses a considerable amount of functionality that overlaps with a lot of existing machinery in Theano/Aesara—both in terms of what it's trying to do and how it's doing it. Also, from a quick reading, it looks like it could be doing many of the exact same things discussed here, albeit in a much different context.

While I'm on the topic, I don't think I explicitly mentioned that the operations discussed above are primarily used to find sub-graphs, terms, and term/sub-graph relationships that can be used to remove terms and sub-graphs from the original graph. In other words, the manipulations detailed above do not necessarily produce graphs that we want or need to use as the final transformed graph.

This came up in conversations about concerns over applying a distributive law, which can lead to less numerically stable expressions. The main point is this: you can easily and cheaply manipulate a graph into an intermediate form for the purposes of obtaining information; however, those intermediate graphs can be discarded once the requisite information is obtained, or further transformations can be applied to produce a more computationally stable form (see the "stabilization" optimizations in Theano/Aesara).

@chandan5362
Contributor

@chandan5362 are you still working on this? need any help?

Hey @twiecki,
I got a little off track and thought of doing it later. Anyway, I will try to get to it asap.
Meanwhile, if someone else wants to work on it, they are welcome to.

@balancap

balancap commented Feb 8, 2021

Working on the PR rlouf/mcx#71 for MCX, I can confirm it should absolutely not be considered a reference :) The implementation is still fairly crude at the moment and, as @brandonwillard said, Theano has much better graph machinery for implementing this optimization in an elegant way.

I was planning on studying your codebase to refine my implementation, rather than the other way around!

@manthehunted

manthehunted commented Mar 31, 2021

Hi, I took a stab at this.

You can find what I did

Here is a short description of what I did:

  • I found that the hardest part was removing a term (and identifying which term) from the addition op, so that's what I implemented. A factorization over the multiplication op can then be performed on the graph that my function produces (function: normalized), which I did not implement.
  • Theano/Aesara has a lot of functionality, but I found it hard to manipulate, so all the functions I wrote are standalone.

Let me know what you think.

@brandonwillard
Contributor

From a quick look, it seems like you're on the right track.

  • I found that the hardest part was removing a term (and identifying which term) from the addition op.

In this context, removing a term is really replacing a term. FunctionGraph.replace is the correct way to do that. It can also be done using aesara.graph.basic.clone_replace, but the two are not the same (e.g. multiple overlapping replacements will not work).
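A minimal sketch of the difference, assuming the current Aesara module layout (the graph is just the Normal example from earlier in the thread, and replacing the constant term with a literal zero is only for illustration):

import numpy as np
import aesara.tensor as at
from aesara.graph.basic import clone_replace
from aesara.graph.fg import FunctionGraph

value = at.dscalar("value")
tau = at.dscalar("tau")
mu = at.dscalar("mu")

norm_const = at.log(at.constant(2.0 * np.pi)) / 2.0  # the term we want to drop
logp = -tau * (value - mu) ** 2 / 2.0 + at.log(tau) / 2.0 - norm_const

zero = at.constant(0.0, dtype=norm_const.dtype)

# `clone_replace` builds a brand new graph and leaves `logp` untouched
logp_clone = clone_replace(logp, replace={norm_const: zero})

# `FunctionGraph.replace` rewrites the graph in place; this is what an
# Optimizer does during an optimization pass
fgraph = FunctionGraph([value, tau, mu], [logp], clone=False)
fgraph.replace(norm_const, zero, reason="remove_normalizing_constant")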

  • Theano/Aesara has a lot of functionality, but I found it hard to manipulate, so all the functions I wrote are standalone.

Take a look at aesara.graph.basic.walk; it does what you're doing with unpack_with and more, but in a much more scalable, non-recursion-based way.
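For example, a small sketch of how walk can drive a non-recursive traversal (the expand function here is just an illustration):

import aesara.tensor as at
from aesara.graph.basic import walk

x = at.dscalar("x")
y = at.dscalar("y")
z = at.log(x) + at.sqrt(y)


def expand(var):
    # Follow each variable back through the `Apply` node that created it
    if var.owner is not None:
        return var.owner.inputs


# `walk` visits every variable reachable from `z` exactly once, using an
# explicit queue instead of recursion
for var in walk([z], expand):
    print(var)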

What you're calling Unit is essentially the Apply node that it contains, so I don't quite understand the role of that type.

If you have questions or run into issues using these Aesara tools, you can post questions here or in the Aesara repo and we can try to clarify.

@twiecki
Member Author

twiecki commented Apr 5, 2021

@manthehunted This looks really cool, great job! As @brandonwillard said, there's some cleaning up that would really help make this more efficient. Let us know if you have problems using these more advanced Aesara tools.

@manthehunted

manthehunted commented Apr 6, 2021

@twiecki thanks for the encouragement.

@brandonwillard
Is there any reason why you prefer replacing terms in an existing FunctionGraph over building a new FunctionGraph whose terms are a subset of the existing graph's? Or does the actual implementation not matter as long as it achieves the goal? Since the normalization operation trims terms (nodes, in this case), I feel it is more efficient to create a new graph rather than modify the existing one (i.e. FunctionGraph.replace).

Well, as I think more about this, what I am saying is:

  • if the number of terms kept < the number of terms trimmed, it is better to create a new graph;
  • otherwise, it is better to modify the existing graph.

Now I feel like I am making the problem a bit more complicated. Maybe I will just pick one approach and have you give me feedback.

@manthehunted

manthehunted commented Apr 14, 2021

@brandonwillard
Ok, I picked one approach (reconstructing the graph rather than removing nodes) without recursive operations (using a stack instead, i.e. aesara.graph.basic.walk and an extension of it).

Here is the work, and please let me know what you think of it.

@brandonwillard
Contributor

Sorry, I've been very busy with the v4 changes and Aesara, so it might take me a while to review and respond.

@ricardoV94
Member

Closing in favor of aesara-devs/aeppl#33
