[torchdynamo] Use ProcessPoolExecutor for triton compiles #87032
Conversation
Stack from ghstack (oldest at bottom):

This patch significantly improves parallel compilation performance for compiling triton kernels by using ProcessPoolExecutor to create a persistent pool of compilation workers. Previously, os.fork overhead and GIL contention limited the achieved parallelism. This patch replaces the worker threads with a pool of processes that do the raw compilation, and does everything else serially on the main thread. That other work could not be parallelized anyway, since it is mostly in Python. In cold start situations, the time to get the workers started can be a significant portion of the total time, so this patch starts them earlier so they are ready to perform compilation (see code comments) by the time dynamo gets to that point.

I have only tested this on one example benchmark (tf_efficientnet_b0), but the results are significant, almost eliminating the difference between a warm and a cold compilation:

```
39.613s   - warm
41.290s   - cold, this patch
2m53.197s - cold, single threaded
1m7.092s  - cold, old setup with n = 8 (its best config)
```

(Cold compilation is measured after running `rm -rf /tmp/torchinductor_$USER`.)
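For readers skimming the diff, the shape of the change is roughly the following. This is a minimal sketch, not the actual torch/_inductor/codecache.py code; the class name AsyncCompile, the helper _worker_compile, the warm-up trick, and the worker count of 8 are illustrative assumptions.

```python
# Sketch only: a persistent ProcessPoolExecutor that is warmed up early so
# worker startup overlaps with dynamo's serial Python-side work. All names
# here are illustrative, not the real codecache.py implementation.
import functools
from concurrent.futures import Future, ProcessPoolExecutor


def _worker_compile(source_code: str) -> None:
    # Stand-in for the real work: runs inside a worker process, doing the
    # expensive, GIL-free raw compile of one triton kernel and populating
    # the on-disk cache as a side effect.
    ...


class AsyncCompile:
    @staticmethod
    @functools.lru_cache(maxsize=1)
    def pool() -> ProcessPoolExecutor:
        # One persistent pool reused for every kernel, instead of paying
        # os.fork() overhead per compilation.
        return ProcessPoolExecutor(max_workers=8)

    @classmethod
    def warm_pool(cls) -> None:
        # Call this as early as possible; submitting trivial tasks nudges
        # the executor to start its worker processes right away.
        pool = cls.pool()
        for _ in range(8):
            pool.submit(int)

    def triton(self, source_code: str) -> Future:
        # Hand the raw compile to a worker; the main thread keeps doing the
        # remaining Python-only work serially and waits on the future later.
        return self.pool().submit(_worker_compile, source_code)
```

The point of the sketch is that the pool outlives individual compilations and is created before it is needed, so by the time inductor has kernel source to compile the workers are already idle and waiting.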
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87032
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Failure. As of commit ba930e5, the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This had been previously approved at pytorch/torchdynamo#1666.
@pytorchbot merge -f "test_quantization in windows build seems unrelated"

Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Hey @zdevito.

@zdevito with this PR I am getting:
#87032 seems to have an issue that breaks our benchmark script; it might have to do with the benchmark script also using subprocess. Before this PR:

```
$ ./benchmarks/dynamo/torchbench.py --performance --inductor --raise --training --float16
...
Traceback (most recent call last):
  File "/home/jansel/conda/envs/pytorch/lib/python3.9/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/home/jansel/pytorch/torch/_inductor/codecache.py", line 239, in _worker_compile
    kernel = TritonCodeCache.load(source_code)
  File "/home/jansel/pytorch/torch/_inductor/codecache.py", line 234, in load
    mod = PyCodeCache.load(source_code)
  File "/home/jansel/pytorch/torch/_inductor/codecache.py", line 212, in load
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_jansel/ij/cij7smji4sw2a56i4yz45bjkrosd2sb2raqnxzsxxpg4kwzuo2ta.py", line 5, in <module>
    from torch._inductor.triton_ops.autotune import reduction
  File "/home/jansel/pytorch/torch/_inductor/triton_ops/__init__.py", line 3, in <module>
    if has_triton():
  File "/home/jansel/pytorch/torch/_inductor/utils.py", line 38, in has_triton
    return triton is not None and torch.cuda.get_device_capability() >= (7, 0)
  File "/home/jansel/pytorch/torch/cuda/__init__.py", line 368, in get_device_capability
    prop = get_device_properties(device)
  File "/home/jansel/pytorch/torch/cuda/__init__.py", line 382, in get_device_properties
    _lazy_init()  # will define _get_device_properties
  File "/home/jansel/pytorch/torch/cuda/__init__.py", line 228, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
```

cc @zdevito
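The failure is the generic fork-after-CUDA-init problem: the compile workers are forked after the parent process has already initialized CUDA, so `has_triton()` cannot query the device in the child. A minimal sketch of the direction the error message itself points at, assuming the worker pool can be given an explicit multiprocessing context (illustrative only, not the actual follow-up fix; `make_compile_pool` is a hypothetical helper):

```python
# Illustrative sketch: build the compile-worker pool with the "spawn" start
# method so child processes never inherit the parent's CUDA state via fork.
import multiprocessing
from concurrent.futures import ProcessPoolExecutor


def make_compile_pool(max_workers: int = 8) -> ProcessPoolExecutor:
    # "spawn" starts a fresh interpreter per worker, avoiding
    # "RuntimeError: Cannot re-initialize CUDA in forked subprocess".
    ctx = multiprocessing.get_context("spawn")
    return ProcessPoolExecutor(max_workers=max_workers, mp_context=ctx)
```

The trade-off is that spawned workers pay a fresh interpreter and import cost instead of inheriting state cheaply via fork, which is part of why warming the pool early matters in the first place.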
Pull Request resolved: #87048
Approved by: https://github.com/soumith