-
Notifications
You must be signed in to change notification settings - Fork 25.6k
init tls grad_mode/local_dispatch_key set while fork new thread in #113246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113246
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit f7f34af with merge base 3ff4572 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
throughputbenchmark [ghstack-poisoned]
…hread in" TorchDynamo will guard grad_mode and the local dispatch key set. https://github.com/pytorch/pytorch/blob/3a429423fcf72430e7a36c79e263c877d7a4ef72/torch/csrc/dynamo/guards.cpp#L13-L16 While using ThroughputBenchmark, those tls state will not be init as same as the main thread status. https://github.com/pytorch/pytorch/blob/3a429423fcf72430e7a36c79e263c877d7a4ef72/torch/csrc/utils/throughput_benchmark-inl.h#L64-L94 Run following scripts ``` import torch linear = torch.nn.Linear(128, 128) compiled = torch.compile(linear) x = torch.rand(10, 128) with torch.no_grad(), torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16): compiled(x) compiled(x) from torch._dynamo import config config.error_on_recompile = True from torch.utils import ThroughputBenchmark with torch.no_grad(), torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16): bench = ThroughputBenchmark(compiled) bench.add_input(x) stats = bench.benchmark( num_calling_threads=10, num_warmup_iters=100, num_iters=100, ) print(stats) ``` will lead to 2 re-compile reasons: ``` triggered by the following guard failure(s): ___check_global_state() triggered by the following guard failure(s): tensor 'x' dispatch key set mismatch. ``` This will trigger a re-compile in torchdynamo. But since `ThroughputBenchmark` is used for sharing weight within threads, the model should not be changed anymore while running the benchmark. So this PR is to init the tls state as same as main thread. Then we can use ` ThroughputBenchmark` to run torchdynamo optimized models. throughputbenchmark [ghstack-poisoned]
@pytorchbot merge |
Merge failedReason: Approval needed from one of the following: |
@pytorchbot rebase |
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
…hread in" TorchDynamo will guard grad_mode and the local dispatch key set. https://github.com/pytorch/pytorch/blob/3a429423fcf72430e7a36c79e263c877d7a4ef72/torch/csrc/dynamo/guards.cpp#L13-L16 While using ThroughputBenchmark, those tls state will not be init as same as the main thread status. https://github.com/pytorch/pytorch/blob/3a429423fcf72430e7a36c79e263c877d7a4ef72/torch/csrc/utils/throughput_benchmark-inl.h#L64-L94 Run following scripts ``` import torch linear = torch.nn.Linear(128, 128) compiled = torch.compile(linear) x = torch.rand(10, 128) with torch.no_grad(), torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16): compiled(x) compiled(x) from torch._dynamo import config config.error_on_recompile = True from torch.utils import ThroughputBenchmark with torch.no_grad(), torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16): bench = ThroughputBenchmark(compiled) bench.add_input(x) stats = bench.benchmark( num_calling_threads=10, num_warmup_iters=100, num_iters=100, ) print(stats) ``` will lead to 2 re-compile reasons: ``` triggered by the following guard failure(s): ___check_global_state() triggered by the following guard failure(s): tensor 'x' dispatch key set mismatch. ``` This will trigger a re-compile in torchdynamo. But since `ThroughputBenchmark` is used for sharing weight within threads, the model should not be changed anymore while running the benchmark. So this PR is to init the tls state as same as main thread. Then we can use ` ThroughputBenchmark` to run torchdynamo optimized models. throughputbenchmark [ghstack-poisoned]
Successfully rebased |
Hi, @iseeyuan, @chenyang78 , @atalman May you help to review this PR, the mergebot shows you are the owner. |
…hread in" TorchDynamo will guard grad_mode and the local dispatch key set. https://github.com/pytorch/pytorch/blob/3a429423fcf72430e7a36c79e263c877d7a4ef72/torch/csrc/dynamo/guards.cpp#L13-L16 While using ThroughputBenchmark, those tls state will not be init as same as the main thread status. https://github.com/pytorch/pytorch/blob/3a429423fcf72430e7a36c79e263c877d7a4ef72/torch/csrc/utils/throughput_benchmark-inl.h#L64-L94 Run following scripts ``` import torch linear = torch.nn.Linear(128, 128) compiled = torch.compile(linear) x = torch.rand(10, 128) with torch.no_grad(), torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16): compiled(x) compiled(x) from torch._dynamo import config config.error_on_recompile = True from torch.utils import ThroughputBenchmark with torch.no_grad(), torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16): bench = ThroughputBenchmark(compiled) bench.add_input(x) stats = bench.benchmark( num_calling_threads=10, num_warmup_iters=100, num_iters=100, ) print(stats) ``` will lead to 2 re-compile reasons: ``` triggered by the following guard failure(s): ___check_global_state() triggered by the following guard failure(s): tensor 'x' dispatch key set mismatch. ``` This will trigger a re-compile in torchdynamo. But since `ThroughputBenchmark` is used for sharing weight within threads, the model should not be changed anymore while running the benchmark. So this PR is to init the tls state as same as main thread. Then we can use ` ThroughputBenchmark` to run torchdynamo optimized models. throughputbenchmark [ghstack-poisoned]
Hi, @desertfire May you help to review this PR? |
@pytorchbot rebase |
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
…hread in" TorchDynamo will guard grad_mode and the local dispatch key set. https://github.com/pytorch/pytorch/blob/3a429423fcf72430e7a36c79e263c877d7a4ef72/torch/csrc/dynamo/guards.cpp#L13-L16 While using ThroughputBenchmark, those tls state will not be init as same as the main thread status. https://github.com/pytorch/pytorch/blob/3a429423fcf72430e7a36c79e263c877d7a4ef72/torch/csrc/utils/throughput_benchmark-inl.h#L64-L94 Run following scripts ``` import torch linear = torch.nn.Linear(128, 128) compiled = torch.compile(linear) x = torch.rand(10, 128) with torch.no_grad(), torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16): compiled(x) compiled(x) from torch._dynamo import config config.error_on_recompile = True from torch.utils import ThroughputBenchmark with torch.no_grad(), torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16): bench = ThroughputBenchmark(compiled) bench.add_input(x) stats = bench.benchmark( num_calling_threads=10, num_warmup_iters=100, num_iters=100, ) print(stats) ``` will lead to 2 re-compile reasons: ``` triggered by the following guard failure(s): ___check_global_state() triggered by the following guard failure(s): tensor 'x' dispatch key set mismatch. ``` This will trigger a re-compile in torchdynamo. But since `ThroughputBenchmark` is used for sharing weight within threads, the model should not be changed anymore while running the benchmark. So this PR is to init the tls state as same as main thread. Then we can use ` ThroughputBenchmark` to run torchdynamo optimized models. throughputbenchmark [ghstack-poisoned]
Successfully rebased |
@pytorchbot merge |
Merge failedReason: This PR needs a If not, please add the To add a label, you can comment to pytorchbot, for example For more information, see Details for Dev Infra teamRaised by workflow job |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
TorchDynamo will guard grad_mode and the local dispatch key set.
pytorch/torch/csrc/dynamo/guards.cpp
Lines 13 to 16 in 3a42942
While using ThroughputBenchmark, those tls state will not be init as same as the main thread status.
pytorch/torch/csrc/utils/throughput_benchmark-inl.h
Lines 64 to 94 in 3a42942
Run following scripts
will lead to 2 re-compile reasons:
This will trigger a re-compile in torchdynamo. But since
ThroughputBenchmark
is used for sharing weight within threads, the model should not be changed anymore while running the benchmark. So this PR is to init the tls state as same as main thread. Then we can useThroughputBenchmark
to run torchdynamo optimized models.Stack from ghstack (oldest at bottom):
throughputbenchmark