
[torch.compile] nn.Module int/float attributes caused re-compile #115711

Closed
yanboliang opened this issue Dec 12, 2023 · 10 comments
Assignees
Labels
dynamo-must-fix These bugs affect TorchDynamo reliability. · enhancement Not as big of a feature, but technically not a bug. Should be easy to fix. · module: dynamic shapes · module: dynamo · oncall: pt2 · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.

Comments

@yanboliang (Contributor) commented Dec 12, 2023

🚀 The feature, motivation and pitch

Repro:

import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)
        self.step = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + 1
        self.step += 1
        return self.layer(x)

m = MyModule()
opt_m = torch.compile(backend="eager")(m)

x = torch.randn(3, 4)
print(opt_m(x))

x = torch.randn(4, 4)
print(opt_m(x))

x = torch.randn(5, 4)
print(opt_m(x))

In this example, self.step is an int attribute of the nn.Module that Dynamo treats as a constant. Even though it is not part of the FX graph, Dynamo still produces a guard on it, and because it is used as a counter that changes on every iteration, each call triggers a re-compilation.
We could probably wrap the int/float attributes of nn.Module as torch.SymInt/SymFloat to avoid these guard failures and re-compilations, but only when they are not used in control flow.
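
For illustration only, one way to observe the re-compilations is to reuse MyModule from the repro with a trivial counting backend (counting_backend is my own helper, not a PyTorch API; on builds where this issue reproduces, the count climbs with every call even though the input shape never changes):

import torch

compile_count = 0

def counting_backend(gm, example_inputs):
    # Dynamo invokes the backend once per (re)compilation of the frame.
    global compile_count
    compile_count += 1
    return gm.forward  # just run the captured graph eagerly

m = MyModule()
opt_m = torch.compile(backend=counting_backend)(m)

for _ in range(3):
    opt_m(torch.randn(3, 4))  # same input shape on every call

# Expected: 1 compile. With a guard on self.step's value, we get one per call.
print(compile_count)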

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng @jansel

Alternatives

No response

Additional context

No response

@voznesenskym (Collaborator) commented:

(regurgitating offline discussion)

Thanks for making the issue! The way we handle this today is via residuals. The TLDR of that is int->ConstantVariable->codegen

             22 LOAD_CONST               4 (k)
             24 LOAD_FAST                0 (self)
             26 STORE_ATTR               0 (step)

where k is a placeholder for some burned-in static value. This is annoying because we have essentially specialized the entire frame on that one value, and it doesn't actually matter what it is! It could ostensibly be rewritten as bytecode that just performs self.step += 1.

Now, this requires some analysis to get right, because if self.step is used anywhere as an input to an op or in control flow, it can influence how other things are executed:

Changing the return of the example program above to:

        if self.step % 2 == 0:
            return x.cos()
        return self.layer(x)

moves this guard from annoying to required for soundness!
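
Spelled out (my own combination of the original repro and the change above, with a class name of my choosing), the variant looks like:

import torch

class MyModuleWithBranch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)
        self.step = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + 1
        self.step += 1
        # The value of self.step now picks the branch, so a guard on it is
        # genuinely needed for correctness, not just an artifact of tracing.
        if self.step % 2 == 0:
            return x.cos()
        return self.layer(x)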

My proposal here is that we look at replacing

int->ConstantVariable
with
int->SymInt(via shape_env)->SymNodeVariable AND teach dynamo residuals to generate code for dynamic shape updates

This will be a little tricky, and require the automatic_dynamic approach to these ints, but it should allow us to remove specializations on arbitrary integers floating through the program in favor of intelligent guards.

In the proposal, @yanboliang's example program would produce a guard on self.step that admits any value, and in my updated example, would produce a guard like self.step % 2 == 0 (and != for the other set of cases), both of which are a strict improvement over recompiling for every single fixed integer value we see in step.

@jansel (Contributor) commented Dec 13, 2023

I think we can treat integers stored on nn.Modules as dynamic. I wouldn't replace all constants though.

@yanboliang (Contributor, Author) commented:

> I think we can treat integers stored on nn.Modules as dynamic. I wouldn't replace all constants though.

Agreed, I think it's hard to replace all of them, just as it isn't easy to enable dynamic shapes for all cases.

@yanboliang added the triaged label on Dec 13, 2023
@voznesenskym (Collaborator) commented:

> I think we can treat integers stored on nn.Modules as dynamic. I wouldn't replace all constants though.

> Agreed, I think it's hard to replace all of them, just as it isn't easy to enable dynamic shapes for all cases.

If it's automatic_dynamic, this should strictly be better. I don't understand why we would split it on nn.Modules vs. others.

@bdhirsh (Contributor) commented Dec 13, 2023

> This will be a little tricky, and require the automatic_dynamic approach to these ints

Only partially related to this issue, but at some point do we want to consider allowing users to call torch.mark_dynamic(self.step) (i.e., support mark_dynamic for scalars, not just tensors)? That seems like nice parity to have between mark_dynamic and automatic-dynamic-shapes.

FWIW: if a user has a scalar value that flows as an input to their model, for the most part, wrapping that scalar in a torch.Tensor() is enough for them (avoids specializing, prevents the compiler from guarding on that value anywhere). But that still requires a model change from the user, and using a SymInt seems like it's probably strictly better (what if you want to specialize on my_value > 10? you need a hint).
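
For reference, a minimal sketch of that workaround applied to the repro above (storing the counter as a 0-dim tensor buffer; the class name is mine, and exact guard behavior may vary across PyTorch versions):

import torch

class MyModuleTensorStep(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)
        # Keep the counter as a tensor instead of a Python int: Dynamo guards
        # on tensor metadata (dtype/shape/device) rather than the stored
        # value, so incrementing it should not force a recompile.
        self.register_buffer("step", torch.zeros((), dtype=torch.int64))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + 1
        self.step += 1
        return self.layer(x)

m = MyModuleTensorStep()
opt_m = torch.compile(backend="eager")(m)
print(opt_m(torch.randn(3, 4)))
print(opt_m(torch.randn(3, 4)))  # no value-based recompile expected here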

@yanboliang (Contributor, Author) commented:

> If it's automatic_dynamic, this should strictly be better. I don't understand why we would split it on nn.Modules vs. others.

I think making arbitrary ints/floats dynamic would easily trigger long-tail coverage issues; I remember that was the major reason we rolled back from all dynamic shapes to automatic dynamic shapes.

@yanboliang (Contributor, Author) commented:

@bdhirsh Agreed on this direction. We can probably follow this order: 1/ make integers on nn.Module dynamic; 2/ support calling torch.mark_dynamic on arbitrary scalars; 3/ make scalars dynamic automatically. I think the underlying implementation is the same.

@ezyang (Contributor) commented Jan 1, 2024

The reason you can't easily mark_dynamic an int is that mark_dynamic works by mutating an attribute on the Tensor it's given, but ints don't have attributes in the traditional sense and cannot be mutated this way. So you would need some sort of wrapper class that replaces the int, but then you have to make sure it looks enough like an int that it doesn't break user code. It probably involves fewer user-code changes if the nn.Module manages which attributes are dynamic or not.
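
To make the wrapper-class point concrete, a purely hypothetical sketch (DynamicInt and _dynamo_mark_dynamic are invented names, not an existing PyTorch API):

class DynamicInt(int):
    # Hypothetical marker type: behaves like a plain int for user code, but
    # carries a flag a tracer could check before deciding to allocate a
    # SymInt instead of specializing on the concrete value.
    _dynamo_mark_dynamic = True

step = DynamicInt(0)
result = step + 1
assert result == 1                         # arithmetic works like a normal int
assert not isinstance(result, DynamicInt)  # but the result silently drops the
                                           # marker, illustrating why the wrapper
                                           # has to "look enough like an int"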

@voznesenskym (Collaborator) commented:

> @bdhirsh Agreed on this direction. We can probably follow this order: 1/ make integers on nn.Module dynamic; 2/ support calling torch.mark_dynamic on arbitrary scalars; 3/ make scalars dynamic automatically. I think the underlying implementation is the same.

I don't see why we wouldn't just do 3 to start?

@voznesenskym (Collaborator) commented:

> If it's automatic_dynamic, this should strictly be better. I don't understand why we would split it on nn.Modules vs. others.

> I think making arbitrary ints/floats dynamic would easily trigger long-tail coverage issues; I remember that was the major reason we rolled back from all dynamic shapes to automatic dynamic shapes.

No, the major reason we did this was that if we made everything automatic, it would have too many guards. Note that my proposal is not full dynamism for all constants, but rather full automatic dynamism for all constants. See my original post in this thread; it describes the approach.

@anijain2305 added the enhancement and dynamo-must-fix labels on Feb 1, 2024
ZelboK pushed a commit to ZelboK/pytorch that referenced this issue May 19, 2024