[torch.compile] nn.Module int/float attributes caused re-compile #115711
(Regurgitating offline discussion) Thanks for making the issue! The way we handle this today is via residuals: the mutation to the attribute is replayed outside the compiled region after the graph runs.
Now, this requires some analysis to get right, because if self.step is used anywhere as an input to an op, or in control flow, it can influence how other things are executed. Changing the return of the example program above would
move this guard from annoying to required for soundness! My proposal here is that we look at replacing these exact-value specializations with guards.
This will be a little tricky, and will require the automatic_dynamic approach to these ints, but it should allow us to remove specializations on arbitrary integers floating through the program in favor of intelligent guards. Under this proposal, @yanboliang's example program would produce a guard on self.step that admits any value, and my updated example would produce a narrower guard reflecting the control-flow dependence.
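The automatic_dynamic idea above can be illustrated with a toy cache (a hypothetical sketch for intuition, not Dynamo's actual guard machinery): the first compile specializes on the exact value of the int, and a single guard failure promotes the attribute to dynamic, so all later values hit the same cache entry instead of re-compiling.

```python
# Toy sketch of automatic-dynamic guarding for an int attribute (assumed
# semantics for illustration; not Dynamo's real implementation).
class AutoDynamicCache:
    def __init__(self):
        self.seen = None       # exact value guarded on at first compile
        self.dynamic = False   # True once the attribute is promoted
        self.compiles = 0      # number of (re)compilations

    def lookup(self, value):
        if self.dynamic or self.seen == value:
            return "cache hit"             # guard admits this value
        if self.seen is not None:
            self.dynamic = True            # one failure -> promote to dynamic
        self.seen = value
        self.compiles += 1                 # guard miss -> (re)compile
        return "compiled"

cache = AutoDynamicCache()
for step in range(5):                      # the counter changes every call
    cache.lookup(step)

print(cache.compiles)  # 2: initial compile + one re-compile, then dynamic
```

Contrast this with exact-value specialization, where all five lookups would miss and re-compile.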
I think we can treat integers stored on nn.Modules as dynamic first, before tackling arbitrary integers.
Agreed, I think it's hard to replace them all, just as it wasn't easy to enable dynamic shapes for all cases.
If it's automatic_dynamic, this should be strictly better. I don't understand why we would split it on nn.Modules vs. others.
Only partially related to this issue, but at some point do we want to consider allowing users to opt into dynamism explicitly? FWIW: if a user has a scalar value that flows as an input to their model, for the most part, wrapping that scalar in a tensor avoids the specialization today.
I think making arbitrary ints/floats dynamic could easily trigger long-tail coverage issues; I remember that was the major reason we rolled back from all dynamic shapes to automatic dynamic shapes.
@bdhirsh Agreed on this direction, probably we can follow this order: 1/ Make integers on nn.Modules …
The reason you can't easily …
I don't see why we wouldn't just do 3 to start?
No, the major reason we did this was that if we made everything dynamic, it would produce too many guards. Note that my proposal is not full dynamism for all constants, but rather full automatic dynamism for all constants. See my original post in this thread; it describes the approach.
Fixes pytorch#115711 Pull Request resolved: pytorch#126466 Approved by: https://github.com/jansel
🚀 The feature, motivation and pitch
Repro:
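The repro code itself is not preserved here; the following is a hypothetical minimal sketch of the described pattern, simulating Dynamo's exact-value specialization with a plain dict cache (no torch dependency), where a counter attribute mutated in `forward()` misses the guard on every call:

```python
# Hypothetical stand-in for the repro: `compiled_call` mimics Dynamo caching
# a compiled artifact under a guard on the exact value of `self.step`.
class Model:
    def __init__(self):
        self.step = 0          # plain Python int attribute of the module

    def forward(self, x):
        self.step += 1         # counter mutated every iteration
        return x * 2

compile_cache = {}             # guard key (exact step value) -> artifact
recompiles = []

def compiled_call(mod, x):
    key = mod.step                         # Dynamo specializes on this int
    if key not in compile_cache:
        recompiles.append(key)             # guard miss -> re-compile
        compile_cache[key] = Model.forward
    return compile_cache[key](mod, x)

mod = Model()
for i in range(3):
    compiled_call(mod, float(i))

print(len(recompiles))  # 3: every call re-compiles because self.step changed
```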
In this example, `self.step` is a constant attribute of `nn.Module`. Though it's not part of the FX graph, Dynamo produces a guard for it and triggers re-compilation, since it is used as a counter that changes as model iterations go. Probably we can wrap the int/float attributes of `nn.Module` as `torch.SymInt`/`torch.SymFloat` to avoid guard failures and re-compilations, but only do this when they are not used in control flow.

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng @jansel
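The "only when they are not used in control flow" condition could be approximated with a source-level check. A hypothetical sketch using Python's `ast` module (the function and attribute names are illustrative, and real tracing would inspect the captured graph rather than source text):

```python
import ast

# Example forward() body: self.step feeds a branch condition, while
# self.scale only feeds arithmetic.
SRC = """
def forward(self, x):
    self.step += 1
    if self.step % 10 == 0:
        x = x + 1
    return x * self.scale
"""

def used_in_control_flow(src, attr):
    """Report whether `self.<attr>` appears inside a branch condition."""
    for node in ast.walk(ast.parse(src)):
        if isinstance(node, (ast.If, ast.While)):
            for sub in ast.walk(node.test):
                if isinstance(sub, ast.Attribute) and sub.attr == attr:
                    return True
    return False

print(used_in_control_flow(SRC, "step"))   # True: must keep guarding on it
print(used_in_control_flow(SRC, "scale"))  # False: safe to wrap symbolically
```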