🚀 Feature
There has been a long-standing need to set a default device. This FR proposes an API for doing so that is compatible with 3rd-party libraries.
Motivation
There has been a lot of discussion around a default-device flag in PyTorch (#7535). Yet implementing such an API has been mostly blocked by the concern that 3rd-party libraries may assume that tensors are created on CPU.
A similar dilemma arose in the Python `multiprocessing` library, where multiple start methods can be used (triggering bugs like librosa/librosa#747). It was resolved with the API `multiprocessing.get_context(xxx)`, which returns a context object with the same set of functions as the `multiprocessing` module but associated with a different start method, enabling patterns like `mp = mp.get_context('spawn')`.
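For reference, the `multiprocessing` pattern this proposal is modeled on looks like the following. The context object mirrors the module's API but is bound to a specific start method, leaving the global default untouched:

```python
import multiprocessing as mp

# get_context returns an object exposing the same API as the
# multiprocessing module, but bound to a particular start method.
ctx = mp.get_context('spawn')
print(ctx.get_start_method())  # -> 'spawn'

# Objects created through the context (queues, processes, pools)
# use that start method; the module-level default is unchanged.
q = ctx.Queue()
```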
In addition to tensor creation, for many users (including me) it is extremely verbose to write `.to(device)` for every tensor yielded by the data loader. It would be very handy if this API handled the move automatically as well.
Pitch
- Tensor creation

  ```python
  import torch

  torch.empty(3)  # returns CPU float tensor
  torch = torch.get_context(device='cuda:2', dtype=torch.half)
  torch.empty(3)  # returns CUDA half tensor
  ```
- Module creation

  ```python
  import torch

  torch = torch.get_context(device='cuda:2')
  MyModule()  # returns a module on CUDA if MyModule uses torch tensor creation methods
  ```
- (needs discussion) Affect `torch_ctx.utils.data.DataLoader` s.t. the yielded samples contain tensors moved to the device of `torch_ctx`.
- (needs discussion) Affect `torch_ctx.load(xxx)` the same way as `map_location`.
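To make the pitch concrete, here is a minimal, torch-free sketch of how such a `get_context` could work (this is NOT an existing torch API; `get_context` and the stand-in module below are hypothetical). It mirrors a module's public callables into a namespace with default keyword arguments pre-bound:

```python
import functools
import types

# Hypothetical sketch: return a namespace mirroring a module's public
# callables, with default keyword arguments (e.g. device=..., dtype=...)
# pre-bound to each factory function.
def get_context(module, **defaults):
    ctx = types.SimpleNamespace()
    for name in dir(module):
        if name.startswith('_'):
            continue  # skip private/dunder attributes
        attr = getattr(module, name)
        if callable(attr):
            setattr(ctx, name, functools.partial(attr, **defaults))
        else:
            setattr(ctx, name, attr)
    return ctx

# Stand-in for a tensor factory like torch.empty, used here only to
# demonstrate the binding; it just reports which device it was given.
fake_torch = types.SimpleNamespace(empty=lambda n, device='cpu': (n, device))

t = get_context(fake_torch, device='cuda:2')
print(t.empty(3))  # -> (3, 'cuda:2')
```

Since `functools.partial` lets later keyword arguments override bound ones, callers could still opt out per call, e.g. `t.empty(3, device='cpu')`. A real implementation would need more care (dispatching on which functions accept `device`/`dtype`, handling submodules, etc.).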
Concerns
Such features always carry the risk of making things too "frictionless", obscuring behavior, and making bugs harder to debug. But I think this one is not too bad.
cc @ezyang @gchanan @zou3519 @bdhirsh @heitorschueroff @ngimel @ejguan