Description
🚀 Feature
When using CUDA, by default, synchronization of the host and GPU works through a tight loop that polls a particular memory location, running a CPU core at 100%. CUDA provides two alternative modes: (1) a tight loop that yields to other threads in between, and (2) synchronization via an interrupt (translated into an event by the driver). PyTorch should provide some way to select any of the three modes.
Motivation
Burning a CPU core at 100% is not very green. It doesn't save a lot of time either, especially when the GPU has to process large workloads in between synchronizations. And it makes it harder to see how much the CPU is actually utilized.
Pitch
On the backend, the synchronization mode is selected by calling one of:
cudaSetDeviceFlags(cudaDeviceScheduleAuto)
cudaSetDeviceFlags(cudaDeviceScheduleSpin)
cudaSetDeviceFlags(cudaDeviceScheduleYield)
cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync)
This can be done right after the first cudaSetDevice() call, but it must happen before the device is used (otherwise it fails with a cudaErrorSetOnActiveProcess error), which may be hard to guarantee. However, it can also be done before any cudaSetDevice() call; in that case it sets the default for any devices activated later.
So I'd suggest adding a function such as torch.cuda.set_sync_mode() that takes a string "auto", "spin", "yield", or "block" and directly issues the corresponding cudaSetDeviceFlags() call. It's then up to the user to call it early enough.
Disclaimer: I haven't looked into what would happen with multiprocessing.
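To make the pitch concrete, here is a minimal sketch of what such a setter might look like if prototyped in pure Python via ctypes. It is only an illustration: `set_sync_mode` here is a hypothetical standalone function, not an existing torch.cuda API, the optional `cudart` parameter is an assumption added so the library handle can be injected, and loading the runtime as 'libcudart.so' follows the user-code workaround in this issue and may not match every installation. The flag values are the ones defined in the CUDA runtime headers.

```python
import ctypes

# Flag values as defined in the CUDA runtime headers (driver_types.h).
_SYNC_MODE_FLAGS = {
    "auto": 0x00,   # cudaDeviceScheduleAuto
    "spin": 0x01,   # cudaDeviceScheduleSpin
    "yield": 0x02,  # cudaDeviceScheduleYield
    "block": 0x04,  # cudaDeviceScheduleBlockingSync
}


def set_sync_mode(mode, cudart=None):
    """Hypothetical helper (not part of PyTorch): map a mode string to
    a cudaSetDeviceFlags() call and return the flag that was passed."""
    try:
        flag = _SYNC_MODE_FLAGS[mode]
    except KeyError:
        raise ValueError(
            f"unknown sync mode {mode!r}; expected one of {sorted(_SYNC_MODE_FLAGS)}"
        ) from None
    if cudart is None:
        # Assumption: the CUDA runtime can be loaded under this name.
        cudart = ctypes.CDLL("libcudart.so")
    status = cudart.cudaSetDeviceFlags(flag)
    if status != 0:  # 0 is cudaSuccess
        raise RuntimeError(
            f"cudaSetDeviceFlags({flag}) failed with CUDA error code {status}"
        )
    return flag
```

As in the pitch, the caller is responsible for invoking this before the device becomes active; a late call would surface as a RuntimeError carrying the CUDA error code.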
Alternatives
The main alternative for implementing this would be figuring out where to safely place a cudaSetDeviceFlags() call in the device managing code, and have it read some global variables at that time. We could have a torch.cuda.sync_mode variable then, or still guard it by a setter function.
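A rough sketch of this deferred variant, under stated assumptions: `request_sync_mode` and `apply_requested_sync_mode` are hypothetical names, the module-level `_sync_mode` stands in for a torch.cuda.sync_mode variable, and `cudart` is whatever handle the device-management code uses to reach cudaSetDeviceFlags(). Where exactly the apply step could be safely called is the open question and is not answered here.

```python
# Flag values as defined in the CUDA runtime headers (driver_types.h).
_SYNC_MODE_FLAGS = {"auto": 0x00, "spin": 0x01, "yield": 0x02, "block": 0x04}

# Hypothetical global; stands in for a torch.cuda.sync_mode variable.
_sync_mode = "auto"


def request_sync_mode(mode):
    """Setter that only records the requested mode; no CUDA call is made."""
    global _sync_mode
    if mode not in _SYNC_MODE_FLAGS:
        raise ValueError(f"unknown sync mode {mode!r}")
    _sync_mode = mode


def apply_requested_sync_mode(cudart):
    """Hypothetical hook for the device-management code to call at a
    point where the device is known not to be active yet.
    Returns the status code reported by cudaSetDeviceFlags()."""
    return cudart.cudaSetDeviceFlags(_SYNC_MODE_FLAGS[_sync_mode])
```

With this split, the user can set the mode at any time before the first device use, and the timing constraint moves into the library.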
It's also possible to solve this in user code doing something like:
import ctypes
ctypes.CDLL('libcudart.so').cudaSetDeviceFlags(4)  # 4 = cudaDeviceScheduleBlockingSync

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @ngimel @VitalyFedyunin @mruberry