Enable fused foreach Adam compilation #104121
Closed
+127 −14
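This PR makes Adam's foreach implementation traceable by torch.compile so that Inductor can fuse the multi-tensor update into a small number of kernels. A rough sketch of the user-facing goal, assuming a build where the optimizer step compiles end to end (the test added below shows the more involved setup the PR actually uses, including unwrapping `step`); this is not code from the PR itself:

```python
# Rough sketch of the goal (assumes optimizer-step compilation works end to
# end in your build); not code from this PR.
import torch

model = torch.nn.Linear(10, 10, device="cuda")
model(torch.ones(10, 10, device="cuda")).sum().backward()

opt = torch.optim.Adam(model.parameters(), lr=0.01, foreach=True)

@torch.compile(backend="inductor")
def compiled_step():
    opt.step()

with torch.no_grad():
    compiled_step()  # Inductor can now fuse the foreach Adam update
```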
Commits (19):
c1f503f  Changes to fully fuse Adam (mlazos)
96c205b  Add compiled optimizer test suite (mlazos)
e33ff2f  Fix adam change for pow (mlazos)
e069c26  Merge remote-tracking branch 'origin/main' into mlazos/adam-fused (mlazos)
445ab88  Fix test case (mlazos)
298a982  Implement internal compile flag + test (mlazos)
97d0664  Merge remote-tracking branch 'origin/main' into mlazos/adam-fused (mlazos)
1da06d8  Add compiled flag to serdes (mlazos)
6f192d4  Override capturable in dynamo (mlazos)
3e026ee  Merge branch 'main' into mlazos/adam-fused (mlazos)
948f6cc  Ignore asserts handled by the compiler (mlazos)
827d5ea  Fix bug in load_state_dict (mlazos)
c6d3b40  Updated comments (mlazos)
d2260ea  Added additional commenting and test (mlazos)
643e385  Merge branch 'mlazos/adam-fused' of github.com:pytorch/pytorch into m… (mlazos)
84402fd  Update torch/optim/adamw.py (mlazos)
8769945  Update torch/optim/optimizer.py (mlazos)
d8e6c45  Update test/inductor/test_compiled_optimizers.py (mlazos)
1cb8777  Updated comments (mlazos)
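For context, the foreach path batches the per-parameter updates into multi-tensor ops, which is what gives Inductor large fusable regions. A rough illustration using the public torch._foreach_* helpers; this is a simplification, not Adam's actual math (it omits the second moment and bias correction):

```python
# Simplified first-moment update in foreach style; NOT the real Adam kernel.
import torch

params = [torch.randn(4) for _ in range(3)]
grads = [torch.randn(4) for _ in range(3)]
exp_avgs = [torch.zeros(4) for _ in range(3)]
beta1, lr = 0.9, 0.01

torch._foreach_mul_(exp_avgs, beta1)                   # m = beta1 * m
torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)  # m += (1 - beta1) * g
torch._foreach_add_(params, exp_avgs, alpha=-lr)       # p -= lr * m
```

Under torch.compile, each of these list ops can be traced and fused rather than launching one kernel per op per parameter.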
test/inductor/test_compiled_optimizers.py (new file, +93 lines)

```python
# Owner(s): ["module: inductor"]

import sys
import unittest

from copy import deepcopy

import torch

import torch._inductor

from torch.testing._internal.common_utils import TestCase


aten = torch.ops.aten

try:
    try:
        from .test_torchinductor import check_model, check_model_cuda, requires_cuda
    except ImportError:
        from test_torchinductor import check_model, check_model_cuda, requires_cuda
except (unittest.SkipTest, ImportError) as e:
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)
    raise


def make_test(optim_cls, closure=None, **kwargs):
    @requires_cuda()
    def test_fn(self):
        input = torch.ones([10, 10], device="cuda:0")
        model_eager = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, device="cuda:0") for _ in range(2)]
        )
        model_eager(input).sum().backward()

        input = torch.ones([10, 10], device="cuda:0")
        model_compiled = deepcopy(model_eager)
        model_compiled(input).sum().backward()

        opt_eager = optim_cls(model_eager.parameters(), **kwargs)
        opt_compiled = optim_cls(model_compiled.parameters(), **kwargs)
        # run the patcher so that step has the expected structure
        torch._dynamo.eval_frame.TorchPatcher.patch()

        # unwrap step to avoid a deliberate graph break due to
        # a limitation of functionalization/no_grad detection
        # see the [Note on graph break] in optimizer.py
        # This ignores the outer _use_grad_if_differentiable wrapper
        # and instead manually disables grad before calling step, which is fine
        # for now as dynamo does not support differentiable optimizers anyway
        step_fn = opt_compiled.step.__wrapped__
        if closure is not None:

            def fn():
                step_fn(opt_compiled, closure)

        else:

            def fn():
                step_fn(opt_compiled)

        with torch.set_grad_enabled(False):
            torch.compile(fn, backend="inductor", fullgraph=True)()
            opt_eager.step()

        self.assertEqual(
            list(model_eager.parameters()), list(model_compiled.parameters())
        )
        if self.check_kernel_count:
            # currently, we compile the step and the rest of the computation
            # separately because the step is a single element tensor
            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    return test_fn


class CompiledOptimizerTests(TestCase):
    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()

    def tearDown(self):
        super().tearDown()
        torch._inductor.metrics.reset()

    test_adam = make_test(torch.optim.Adam, lr=0.01)
    test_adam_weight_decay = make_test(torch.optim.Adam, lr=0.01, weight_decay=0.01)
```
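Because make_test is parameterized over the optimizer class and its hyperparameters, the suite grows by one line per configuration. A hypothetical extension (not part of this diff) would sit alongside the two Adam tests in the class body:

```python
# Hypothetical additions (not in this PR) showing how make_test generalizes;
# they would live inside CompiledOptimizerTests next to test_adam.
test_adam_amsgrad = make_test(torch.optim.Adam, lr=0.01, amsgrad=True)
test_adamw = make_test(torch.optim.AdamW, lr=0.01)
```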