Handling reasoning about rationals in symbolic shapes #128053

Open
ezyang opened this issue Jun 5, 2024 · 24 comments
Labels
module: dynamic shapes triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@ezyang
Contributor

ezyang commented Jun 5, 2024

🐛 Describe the bug

In an early version of #126905 I attempted to add asserts to our sympy functions. Specifically, for sympy functions that expected to receive integer inputs (like IntTrueDiv), I asserted that the inputs is_integer. Because our fragment of Sympy reasoning follows Python int/float promotion rules (and in particular, there is no data dependence about it), in principle it should be possible to always accurately tell if any given expression is integral or not.

However, there are some cases where we end up generating rational expressions in reasoning. The easiest way to end up with rationals is when you use builtin Sympy division, which aggressively distributes division across addition:

>>> import sympy
>>> from sympy.abc import x, y
>>> (x + y) / 2
x/2 + y/2

When simplifications involving rationals happen, it is easy to lose track that a larger quantity is_integer (and sympy's assumption system is rather weak, so it only does superficial reasoning). In the example above, if x + y is divisible by 2, then x/2 + y/2 is integral, but Sympy would never be able to tell you this.
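To make this concrete with numbers rather than Sympy symbols, here is a small sketch using Python's standard `fractions` module (not Sympy itself; the helper name `termwise_is_integer` is made up for illustration). A term-by-term check, which is roughly the kind of superficial reasoning described above, cannot conclude that the sum is integral even when it is:

```python
from fractions import Fraction

def termwise_is_integer(terms):
    # Superficial check: claim integrality only if every term is itself an integer.
    return all(t.denominator == 1 for t in terms)

x, y = 3, 5                                # x + y = 8 is divisible by 2
terms = [Fraction(x, 2), Fraction(y, 2)]   # the distributed form x/2 + y/2

total = sum(terms)
print(termwise_is_integer(terms))  # False: neither 3/2 nor 5/2 is an integer
print(total.denominator == 1)      # True: the sum is exactly 4
```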

I always thought rationals were an unexpected and unwelcome style of simplification that I was happy to obliterate from Sympy, and I have mostly eliminated them in #126905. However, the offline solver makes use of sympy's inequality solver, which does not support custom functions (sympy/sympy#26632), so the offline solver rewrites FloorDiv into sympy division in order to use it. This means rationals can show up, and in particular it means that I can't really do any is_integer asserts.

This issue is to track what we should do about this. We appear to be internally split about it. @lezcano strongly favors stringent invariants, but @avikchaudhuri says:

Hmm, seems hard to do reasoning on ints without allowing rationals. Feels like rationals are to ints what complex numbers are to reals...

Is there a flag we could turn off to temporarily avoid these assertions, understanding that we might allow intermediate rationals to show up, and later check that the solutions would be integers?

One possibility is certainly to have some sort of context manager that turns off the assertions in some context, and enable this context manager in the offline solver. I'm not sure if anyone has better ideas.

Versions

main

@ezyang
Contributor Author

ezyang commented Jun 5, 2024

cc @shazqadeer

@janeyx99 janeyx99 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jun 6, 2024
@bhack
Contributor

bhack commented Jun 8, 2024

Is this related also to these recompilations or not?

V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles] Recompiling function forward in /workspace/networks/encoders/swin/swin_transformer.py:425
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles]     triggered by the following guard failure(s):
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles]     - Eq(IntTrueDiv(L['H'], 7), 19.4285714285714)                   # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles]     - L['H'] == 272                                               
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles]     - Eq(IntTrueDiv(L['H'], 7), 9.71428571428571)                   # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles]     - Eq(IntTrueDiv(L['H'], 7), 19.4285714285714)                   # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles]     - Eq(IntTrueDiv(L['H'], 7), 9.71428571428571)                   # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.560000 138203450480448 torch/_dynamo/guards.py:2610] [6/6] [__recompiles]     - Eq(IntTrueDiv(L['H'], 7), 19.4285714285714)                   # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles] Recompiling function torch_dynamo_resume_in_forward_at_433 in /workspace/networks/encoders/swin/swin_transformer.py:433
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles]     triggered by the following guard failure(s):
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles]     - Eq(IntTrueDiv(L['W'], 7), 19.4285714285714)                   # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles]     - L['W'] == 272                                               
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles]     - Eq(IntTrueDiv(L['W'], 7), 9.71428571428571)                   # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles]     - Eq(IntTrueDiv(L['W'], 7), 19.4285714285714)                   # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles]     - Eq(IntTrueDiv(L['W'], 7), 9.71428571428571)                   # _dynamo/output_graph.py:451 in init_ambient_guards
V0608 00:37:52.636000 138203450480448 torch/_dynamo/guards.py:2610] [7/6] [__recompiles]     - Eq(IntTrueDiv(L['W'], 7), 19.4285714285714)                   # _dynamo/output_graph.py:451 in init_ambient_guards

@lezcano
Collaborator

lezcano commented Jun 8, 2024

to see why it's recompiling, run it with TORCH_LOGS=explain

@bhack
Contributor

bhack commented Jun 8, 2024

Do we have explain?

Valid settings:
all, dynamo, aot, autograd, inductor, dynamic, torch, distributed, c10d, ddp, pp, fsdp, onnx, export, compiled_autograd_verbose, trace_call, bytecode, aot_joint_graph, fusion, custom_format_test_artifact, output_code, trace_source, aot_graphs, recompiles, onnx_diagnostics, not_implemented, verbose_guards, graph, graph_sizes, kernel_code, sym_node, graph_code, ddp_graphs, graph_breaks, post_grad_graphs, cudagraphs, perf_hints, trace_bytecode, guards, compiled_autograd, schedule, recompiles_verbose, overlap

@lezcano
Collaborator

lezcano commented Jun 8, 2024

Sorry, I meant TORCH_LOGS=recompiles, but that's the one you already used.
In your logs, it's clear that H / 7 != 9.71428571428571 and H / 7 != 19.4285714285714 for H = 272. Same for W, so it sounds reasonable that it's recompiling (and probably tracking it with a symint to not recompile later), no?
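As a quick arithmetic check of the log above (plain Python, not a Dynamo API; the variable names are just for illustration), the two guard constants appear to come from earlier traces with H = 68 and H = 136, and neither matches H = 272:

```python
H = 272
guard_constants = [9.71428571428571, 19.4285714285714]

# Each guard compares the true division H / 7 against a constant from an earlier trace.
# Multiplying back by 7 recovers the H value that trace specialized on.
earlier_H = [round(c * 7) for c in guard_constants]
guards_hold = [H / 7 == c for c in guard_constants]

print(earlier_H)    # [68, 136]
print(guards_hold)  # [False, False]
print(H / 7)        # approximately 38.857
```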

@bhack
Contributor

bhack commented Jun 8, 2024

Probably, but I am trying to investigate why the compilation never ends, because the only messages I have are this one and the other one at #127677 (comment).

@ezyang
Contributor Author

ezyang commented Jun 8, 2024

OK, here's another interesting application of reasoning about rationals.

Suppose we have s0 in [4, 10] and s0 = s1 * 2 (but s1 is otherwise unbounded). We would like to replace s0 with s1 * 2, but if we do so naively, we will lose knowledge that s0 in [4, 10]. We would like to update the bounds on s1. One way to do this is solve for s1 in terms of s0. However, the solution for this equation is s1 = s0 / 2: rationals! In #126905 I made these solutions illegal, and as a result also made it impossible to do replacements when the thing being replaced had min/max bounds.

The fix is to temporarily allow for rational compute, and then when you're all done requantize back to integers. So maybe this is a good reason to allow for rationals.

@shazqadeer
Contributor

@ezyang : Your example above is interesting. If you did allow for reasoning with rationals, would the upper and lower bounds of value ranges for ints become rational numbers?

@shazqadeer
Contributor

@ezyang : I believe I have your permission to ask random questions on your PRs and issues to fill background gaps. If not, ignore this message.

The PR summary refers to an "offline solver". Can you provide some context on the need for this offline solver and how it works?

@ezyang
Contributor Author

ezyang commented Jun 10, 2024

If you did allow for reasoning with rationals, would the upper and lower bounds of value ranges for ints become rational numbers?

You can end up computing a rational upper/lower bound, but it's always OK to round them, since if x in [1/2, 3/2], and x is integer, then actually it could only ever be 1.
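The rounding step can be sketched as follows, assuming rational bounds are represented with `fractions.Fraction`. The helper name `to_int_range` is hypothetical. The lower bound rounds up and the upper bound rounds down, and the range can turn out to be empty:

```python
from fractions import Fraction
import math

def to_int_range(lo, hi):
    """Tightest integer range inside the rational range [lo, hi], or None if empty."""
    ilo, ihi = math.ceil(lo), math.floor(hi)
    return (ilo, ihi) if ilo <= ihi else None

print(to_int_range(Fraction(1, 2), Fraction(3, 2)))  # (1, 1)
print(to_int_range(Fraction(1, 3), Fraction(2, 3)))  # None: no integer lies in between
```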

@ezyang
Contributor Author

ezyang commented Jun 10, 2024

The PR summary refers to an "offline solver". Can you provide some context on the need for this offline solver and how it works?

The offline solver is all the code in DimConstraints. It's offline because we assume we've collected all of the guards and then can solve them all together. The online solver in ShapeEnv has to be able to answer queries as we go, since symbolic evaluation cannot proceed without answers.

@lezcano
Collaborator

lezcano commented Jun 10, 2024

Suppose we have s0 in [4, 10] and s0 = s1 * 2 (but s1 is otherwise unbounded). [...]

I mean, s1 = s0 / 2 is a rational, but you can represent it with your SymPy functions, which is completely fine, right?

@ezyang
Contributor Author

ezyang commented Jun 10, 2024

I was going to say "but the division is generated by sympy solver" but we already reimplemented the solver so I can fix it.

But I'm not too sure what the rounding behavior should be if I turn this into an integer division. Let's suppose I have s0 in [3, 11] and s0 = s1 * 2. The desired refined range for s1 is [2, 5]. If I solve s1 = s0 // 2, I will get a less accurate s1 in [1, 5]. I need distinct rounding behavior for the upper and lower bound.
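The numbers in this example can be checked by brute force. This is only an illustrative sketch in plain Python (the search window of -20 to 20 is an arbitrary choice that comfortably covers the answer):

```python
# Brute-force the integers s1 for which s0 = s1 * 2 lands in [3, 11].
valid_s1 = [s1 for s1 in range(-20, 21) if 3 <= s1 * 2 <= 11]
exact = (min(valid_s1), max(valid_s1))

# Applying floor division to both endpoints of the s0 range.
floor_both = (3 // 2, 11 // 2)

print(exact)       # (2, 5)
print(floor_both)  # (1, 5): still sound, but the lower bound is looser
```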

@lezcano
Collaborator

lezcano commented Jun 10, 2024

To get tight bounds you want to map [lower, upper] into [CeilDiv(lower, div), FloorDiv(upper, div)].

@lezcano
Collaborator

lezcano commented Jun 10, 2024

Actually, that's not quite true: you first need to normalise by multiplying the range by the sign of div (i.e. potentially flipping it) to make sure div is positive, but yeah.

If div is a range, you first make sure that 0 is not in the div range, and then compute as above.
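A sketch of the mapping described in the last two comments, restricted to the case where div is a single nonzero integer constant (the case where div is itself a range is not handled here). The function name `div_range` is hypothetical:

```python
def div_range(lower, upper, div):
    """Integers s1 with lower <= s1 * div <= upper, as an inclusive (lo, hi) pair."""
    if div == 0:
        raise ValueError("div must be nonzero")
    if div < 0:
        # Multiply the constraint by -1: this flips the range and makes div positive.
        lower, upper, div = -upper, -lower, -div
    ceil_div = -(-lower // div)   # ceiling division using only integer operations
    floor_div = upper // div      # Python's // already rounds toward negative infinity
    return ceil_div, floor_div

print(div_range(3, 11, 2))    # (2, 5)
print(div_range(3, 11, -2))   # (-5, -2)
```

If the returned lower bound exceeds the upper bound, no integer s1 satisfies the constraint.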

@ezyang
Contributor Author

ezyang commented Jun 10, 2024

My point is that, the way the reverse propagation is currently written, there's no opportunity to directly remap the ValueRange. We have equation s0 = s1 * 2, we solve s1 = s0 / 2, and then we run value ranges prop on the new expression. Concretely, to do what you suggest, I need to introduce some new weird div operator s1 = MyDiv(s0, 2) that has the correct value ranges formula you described and... I'm not even sure what its runtime semantics are, maybe this is CleanDiv (where we guarantee that the output is integral?) Or are you suggesting that we stop using try_solve for reverse propagation and do it some other way?

@lezcano
Collaborator

lezcano commented Jun 10, 2024

So, there are a few possibilities here. The first one, which should be quite uncontroversial, is to check that lower % div == 0 and upper % div == 0. In this case, every division agrees, and you can say s1 = FloorDiv(s0, div) and compute rg_s1 = cls.floordiv(rg_s0, div) (or using TruncDiv if you know that you are going to generate this in a kernel with C semantics).

I would assume that most of the uses would fall in this case.
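To illustrate why the divisibility check makes the choice of division irrelevant, here is a small sketch in plain Python. The helper names `divisions_agree` and `all_roundings` are made up for this example and are not PyTorch APIs:

```python
import math
from fractions import Fraction

def divisions_agree(lower, upper, div):
    """True when both endpoints are multiples of div, so the rounding mode is irrelevant."""
    return lower % div == 0 and upper % div == 0

def all_roundings(n, div):
    # Floor, ceiling and truncating division of n by div, deduplicated and sorted.
    q = Fraction(n, div)
    return sorted({math.floor(q), math.ceil(q), math.trunc(q)})

# Endpoints divisible by 2: every rounding mode gives the same quotient.
print(divisions_agree(4, 10, 2), all_roundings(4, 2), all_roundings(10, 2))
# Endpoints not divisible by 2: the rounding modes disagree.
print(divisions_agree(3, 11, 2), all_roundings(3, 2), all_roundings(11, 2))
```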

Then, for the general case, we could have an op handler that coerces a float range into an int range as described above. We could describe this as Coerce(FloatTrueDiv(s0, 2)). Then, after the solve, we'd need to propagate the ranges from s1 to s0 using the initial equation s0 = s1 * 2, and finally turn Coerce o FloatTrueDiv into FloorDiv (or TruncDiv), now that we have simplified the problem to the previous one.

@ezyang
Contributor Author

ezyang commented Jun 10, 2024

While I agree your first proposal is sound, it feels very edge casey to me. It also changes some of our layering; right now our try_solve is value range oblivious, but for your proposal, we would need to consult the value ranges to determine if we could do a replacement.

For the general solution, there are problems. First, I don't want to use a FloatTrueDiv, because s0 and 2 are not floats, they are integers, and to satisfy typing I'd have to coerce them to floats too. And now this looks very suspicious: why am I doing floating point division for the reasoning here? In particular, I can now torture you with something like: what if my integer is outside the exactly representable floating-point range? Like, I know this is not going to happen in practice, but if we just want something that works, doing rationals works too!

@lezcano
Collaborator

lezcano commented Jun 10, 2024

I mean, I'm team rationals, so you don't have to convince me on the "rationals work" front :P In particular, what about implementing the general solution but using sympy.Div instead of FloatTrueDiv? In particular, we'd temporarily have a Div, but we'd turn it straight away into one of our divs, so life is good?

@ezyang
Contributor Author

ezyang commented Jun 10, 2024

But you're also on team asserts, and I'm giving this up here. The rational logic with sympy.Div solution is what has landed to main. This works great... I just can't assert that stuff is_integer now, because it might not be. In particular, the asserts cannot live in eval, because even if the div is temporary, eval would get run immediately when I create the expression.

@lezcano
Collaborator

lezcano commented Jun 10, 2024

Yeah, eval running eagerly and not having any sort of global information means that it may not be the best place to put the asserts. We might want to have a pass that simplifies the expressions using global information (VRs and otherwise) and then, after that, assert that the expressions have certain postconditions.

I'm not sure whether we have a place like that in the code at the moment, but it would certainly be desirable.

@ezyang
Contributor Author

ezyang commented Jun 10, 2024

A simple place we can do asserts is in guard accumulation, since this is where expressions persist. Otherwise, we can do it on SymNode creation. But I am worried about the latter as the invariant test is recursive so you will keep iterating into the same structs over and over again.

@shazqadeer
Contributor

The PR summary refers to an "offline solver". Can you provide some context on the need for this offline solver and how it works?

The offline solver is all the code in DimConstraints. It's offline because we assume we've collected all of the guards and then can solve them all together. The online solver in ShapeEnv has to be able to answer queries as we go, since symbolic evaluation cannot proceed without answers.

Thanks. Why is an offline solver needed?

@ezyang
Contributor Author

ezyang commented Jun 11, 2024

Export has a feature which is that if you specify dimensions as dynamic, but they are not fully dynamic due to guards, it will try to compute the simplest set of constraints you could specify to precisely specify what extra constraints your guards have imposed. Also cc @avikchaudhuri
