
[dynamo] nn parameterization causing increased compile time #125314

Closed
williamwen42 opened this issue May 1, 2024 · 5 comments
Assignees
williamwen42
Labels
module: dynamo · module: startup-tracing-compile (compilation mechanism or time spent in (re)compilation, tracing, startup) · oncall: pt2 · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


williamwen42 commented May 1, 2024

Internal link: https://fb.workplace.com/groups/1075192433118967/permalink/1421189078519299/

Minimal repro:

"""
torch==2.3.0.dev20240226+cu121
parametrize=True: 4005ms
parametrize=False: 1285ms
parametrize=True: 1848ms
parametrize=False: 1417ms

torch==2.4.0.dev20240415+cu118
parametrize=True: 18487ms
parametrize=False: 1064ms
parametrize=True: 17213ms
parametrize=False: 1064ms
"""

import torch
from torch import nn
import time

class Module(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(*[
            nn.Linear(128, 128, bias=True)
            for _ in range(32)
        ])

    def forward(self, x):
        return self.seq(x)

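# Parametrization that stores parameters in float32 but exposes them as
# float16 on every access.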
class Parametrization(torch.nn.Module):
    def forward(self, x):
        return x.half()

    def right_inverse(self, x):
        return x.float()

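# Register the dtype-converting parametrization on every parameter of every submodule.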
def parametrize(model: nn.Module):
    mods = list(model.modules())
    for mod in mods:
        params = list(mod._parameters.items())
        for name, p in params:
            if p is not None:
                torch.nn.utils.parametrize.register_parametrization(mod, name, Parametrization(), unsafe=True)

x = torch.randn([1, 128], device="cuda", dtype=torch.half)

# Compare compile time with parametrization vs. a plain .half() cast:
print(f"torch=={torch.__version__}")
for USE_PARAMETRIZATION in [True, False, True, False]:
    torch._dynamo.reset()
    m = Module().cuda()
    if USE_PARAMETRIZATION:
        parametrize(m)
    else:
        m.half()
    m(x)

    m.compile()

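    # The first call after compile() triggers dynamo tracing, so dt is
    # dominated by compile time.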
    t = time.time()
    m(x)
    dt = time.time() - t
    print(f"parametrize={USE_PARAMETRIZATION}: {int(1000 * dt)}ms")

Bisection identified the following commits as responsible: #121041 (~200ms -> ~4500ms locally) and #123804 (~4500ms -> ~20000ms locally).

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng

williamwen42 added the oncall: pt2, module: dynamo, and module: startup-tracing-compile labels May 1, 2024
williamwen42 self-assigned this May 1, 2024
yanboliang added the triaged label May 2, 2024
williamwen42 changed the title from "get_instructions_bytes appears to be causing increased compile time" to "nn parameterization causing increased compile time" May 3, 2024
williamwen42 (Member, Author) commented:

cc @jbschlosser @jansel for initial ideas on getting compile times back down.


jansel commented May 4, 2024

#123804 was a correctness fix -- though maybe we could allowlist away whatever property is triggering the extra compile time.

jbschlosser (Contributor) commented:

Well, one quick fix would be to not restart analysis for parametrizations and instead skip tracing into them entirely, as I did in the original implementation of #121041. But it'd be better to figure out how to fully inline them.
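For illustration only, here is a user-level analogue of that skip, using the public `torch._dynamo.disable` API; the actual #121041 implementation handled this inside dynamo rather than with a decorator:

```python
import torch
from torch import nn

class Parametrization(nn.Module):
    # Sketch: decorating forward with torch._dynamo.disable makes dynamo
    # graph-break and run it eagerly rather than tracing into it.
    @torch._dynamo.disable
    def forward(self, x):
        return x.half()
```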

> though maybe we could allowlist away whatever property is triggering the extra compile time.

The parametrization mechanism injects a @property in place of a parametrized module parameter, so those properties would be the ones to allowlist.
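A minimal sketch of what that looks like from the outside (the `Half` parametrization is just an illustrative stand-in, not PyTorch internals):

```python
import torch
from torch import nn
from torch.nn.utils import parametrize

class Half(nn.Module):
    def forward(self, x):
        return x.half()

lin = nn.Linear(2, 2)
parametrize.register_parametrization(lin, "weight", Half(), unsafe=True)

# "weight" is no longer a plain nn.Parameter: it is now a @property on a
# dynamically generated subclass, and every access re-runs Half.forward.
print(type(lin).__name__)                      # ParametrizedLinear
print(isinstance(type(lin).weight, property))  # True
print("weight" in lin._parameters)             # False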

williamwen42 changed the title from "nn parameterization causing increased compile time" to "[dynamo] nn parameterization causing increased compile time" May 7, 2024

ezyang commented May 8, 2024

As we are recommending parametrization as the way to solve serialization for tensor subclasses, it is much, much better to support them.

pytorchmergebot pushed a commit that referenced this issue May 8, 2024
Workaround for #125314 and #125478.

We no longer make parametrized nn.Modules unspecialized. Instead, when we are about to call a function from the `torch.nn.utils.parametrize` module, we skip the frame.

The script from #125314 now outputs
```
parametrize=True: 6587ms
parametrize=False: 1729ms
parametrize=True: 4497ms
parametrize=False: 1539ms
```

Pull Request resolved: #125710
Approved by: https://github.com/jansel, https://github.com/jbschlosser
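
A hedged sketch of the frame-skipping check described above; `should_skip_frame` is a hypothetical name, and the real logic lives in dynamo's internals rather than in any public API:

```python
import torch.nn.utils.parametrize as parametrize_mod

def should_skip_frame(func) -> bool:
    # Skip (run eagerly, don't trace) any function defined in
    # torch.nn.utils.parametrize, per the PR description.
    return getattr(func, "__module__", None) == parametrize_mod.__name__

print(should_skip_frame(parametrize_mod.register_parametrization))  # True
```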
williamwen42 (Member, Author) commented:

We now trace through nn module parametrization.
