
Unnecessary compilation fails to optimize simple code #125652

Open
youkaichao opened this issue May 7, 2024 · 8 comments
Labels: module: dynamo, oncall: pt2, triaged (this issue has been looked at by a team member and prioritized into an appropriate module)

Comments


youkaichao commented May 7, 2024

🐛 Describe the bug

A minimal reproducible example:

import torch

class Layer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.randn((16,))
        self.variance_epsilon = 1e-5

    @torch.compile
    def forward(self, hidden_states, residuals=None):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(
            variance + self.variance_epsilon)
        hidden_states = self.weight.to(torch.float32) * hidden_states
        return hidden_states.to(input_dtype), residuals

layers = [Layer() for i in range(100)]
hidden_states = torch.randn((32, 16, 16))

for iteration in range(2):
    # simulate a model forward call
    for layer in layers:
        hidden_states, _ = layer(hidden_states)

For these 100 layers, torch.compile compiles the first 64 layers (64 being Dynamo's cache size limit for a single code object), and the remaining layers run unoptimized.

Ideally, however, there should be a single cache entry shared by all layers; there is no need to create a separate cache entry per id(self).
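
One way to make the per-layer cache entries visible is Dynamo's recompile logging and cache-size knobs. The snippet below is a sketch assuming a recent PyTorch 2.x build; the exact torch._dynamo.config attribute names vary between releases, so they are looked up defensively.

import torch
import torch._dynamo

# Sketch, assuming recent PyTorch 2.x behavior: print why Dynamo recompiles,
# which makes the per-layer cache entries visible (equivalent to running
# with TORCH_LOGS="recompiles").
torch._logging.set_logs(recompiles=True)

# The limit mentioned above lives in torch._dynamo.config; getattr is used
# because the attribute names differ across releases.
print("per-frame limit:", getattr(torch._dynamo.config, "cache_size_limit", None))
print("per-code-object limit:", getattr(torch._dynamo.config, "accumulated_cache_size_limit", None))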

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @jansel

Error logs

No response

Minified repro

No response

Versions

pytorch 2.3.0+cu121


bdhirsh commented May 7, 2024

The problem is that Dynamo burns each module's parameters/buffers into the graph it compiles, so every layer gets its own specialized graph.

@anijain2305 recently added an (experimental?) config to avoid that burn-in, which you can try by running with TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1.

I tried it on your repro, and with it turned on I no longer see any recompiles.
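
In builds that expose it, the same switch is also reachable as a torch._dynamo.config attribute. The sketch below shows both ways to turn it on; the attribute name inline_inbuilt_nn_modules is an assumption taken from later releases and is checked for before use.

import os

# Option 1: set the variable before torch is imported, since Dynamo reads
# TORCHDYNAMO_* variables when its config module is first loaded.
os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1"

import torch
import torch._dynamo

# Option 2 (assumed equivalent, only in builds that expose the attribute):
# flip the config flag directly.
if hasattr(torch._dynamo.config, "inline_inbuilt_nn_modules"):
    torch._dynamo.config.inline_inbuilt_nn_modules = True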

@youkaichao (Collaborator, Author)

@bdhirsh Can you explain the rationale behind this specialization? I'm quite puzzled here.


ezyang commented May 8, 2024

This is Animesh's thing; he's working on fixing it.

@youkaichao (Collaborator, Author)

@bdhirsh thanks for the information. I suppose it will take several months for this to become public, right?


ezyang commented May 9, 2024

@anijain2305 seemed pretty close when we talked about it a week ago


ezyang commented May 9, 2024

In particular, the flag is already available; you can opt into it and see if it works.

@youkaichao (Collaborator, Author)

Thanks for the answer. Setting export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 indeed solves this particular problem.

May I ask why this is not the default? Why do we need to set it manually?
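
A possible workaround sketch for builds where the flag is unavailable (not something suggested in this thread; the helper name rms_like_norm is hypothetical): compile a plain function that takes the weight and epsilon as explicit inputs, so the single compiled code object is guarded on tensor properties rather than on a particular module instance.

import torch

# Hypothetical refactor: the compiled object is a free function, so no
# specific nn.Module instance is burned into the graph and one cache entry
# serves every layer.
@torch.compile
def rms_like_norm(hidden_states, weight, eps):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    mean = hidden_states.mean(-1, keepdim=True)
    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
    hidden_states = (hidden_states - mean) * torch.rsqrt(variance + eps)
    return (weight.to(torch.float32) * hidden_states).to(input_dtype)

class Layer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.randn((16,))
        self.variance_epsilon = 1e-5

    def forward(self, hidden_states, residuals=None):
        # Delegate to the shared compiled function instead of compiling a
        # bound method per instance.
        out = rms_like_norm(hidden_states, self.weight, self.variance_epsilon)
        return out, residuals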


ezyang commented May 13, 2024

@anijain2305 is working on turning it on by default. It currently uncovers a pile of latent bugs that show up in the test suite.

bdhirsh added the triaged and module: dynamo labels on May 14, 2024.