
[FR] torch context with default device & dtype #27878

@ssnl

🚀 Feature

There has been demand for a way to set the default device. This FR proposes an API for it that is compatible with third-party libraries.

Motivation

There has been a lot of discussion around a default device flag in PyTorch (#7535). However, implementing such an API has been mostly blocked by the concern that third-party libraries may assume tensors are created on the CPU.

A similar dilemma has also arisen in Python's multiprocessing library, which supports multiple start methods (a source of bugs like librosa/librosa#747). Its solution is the API multiprocessing.get_context(xxx), which returns a context object with the same set of functions as the multiprocessing module but bound to a different start method, enabling patterns like mp = mp.get_context('spawn').
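For reference, the multiprocessing pattern looks like this in practice. This is a small illustration using only the standard library; the point is that the context object mirrors the module's API while leaving the global default start method untouched, which is the property this FR wants for device and dtype.

```python
import multiprocessing

# get_context returns a context object exposing the same API as the
# multiprocessing module (Process, Queue, Pool, ...), bound to one start
# method. The global default start method is not changed, so third-party
# code that calls multiprocessing directly is unaffected.
mp = multiprocessing.get_context("spawn")

q = mp.Queue()  # created through the "spawn" context
q.close()

print(mp.get_start_method())  # spawn
```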

Beyond tensor creation, for many users (including me) it is extremely verbose to call .to(device) on every tensor yielded by the data loader. It would be very handy if the context handled this move automatically as well.
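Without a context API, the manual moving can be centralized by wrapping the loader. The sketch below is one possible approach, not part of PyTorch: `move_to` and `DeviceLoader` are hypothetical names. It duck-types on a `.to(device)` method (which `torch.Tensor` has), so it walks nested dicts, lists, tuples, and namedtuples without importing torch.

```python
from collections.abc import Iterable, Iterator, Mapping
from typing import Any


def move_to(obj: Any, device: Any) -> Any:
    """Recursively move every object with a callable `.to` (e.g. torch.Tensor)
    inside nested dicts, lists, tuples, and namedtuples to `device`."""
    if callable(getattr(obj, "to", None)):
        return obj.to(device)
    if isinstance(obj, Mapping):
        return {k: move_to(v, device) for k, v in obj.items()}
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # namedtuple
        return type(obj)(*(move_to(v, device) for v in obj))
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to(v, device) for v in obj)
    return obj  # ints, strings, None, ... pass through unchanged


class DeviceLoader:
    """Hypothetical wrapper: iterate over `loader` (e.g. a DataLoader) and
    yield each batch with its tensors already moved to `device`."""

    def __init__(self, loader: Iterable, device: Any) -> None:
        self.loader = loader
        self.device = device

    def __iter__(self) -> Iterator:
        for batch in self.loader:
            yield move_to(batch, self.device)

    def __len__(self) -> int:
        return len(self.loader)
```

Usage would look like `for x, y in DeviceLoader(DataLoader(dataset, batch_size=32), "cuda:2"): ...`. This keeps the move explicit at one call site, which is roughly what pitch item 3 below would make implicit.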

Pitch

  1. Tensor creation
    import torch
    torch.empty(3)  # returns CPU float tensor
    torch = torch.get_context(device='cuda:2', dtype=torch.half)
    torch.empty(3)  # returns CUDA half tensor
  2. Module creation
    import torch
    torch = torch.get_context(device='cuda:2')
    MyModule()  # returns module on CUDA if MyModule uses torch tensor creation methods
  3. (needs discussion) make torch_ctx.utils.data.DataLoader, where torch_ctx is a context returned by torch.get_context, yield samples whose tensors are already moved to the device of torch_ctx.
  4. (needs discussion) make torch_ctx.load(xxx) move loaded tensors to the context's device, the same way map_location does.
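To make items 1 and 4 concrete, here is a minimal sketch of what the object returned by `torch.get_context` could look like. Everything here is hypothetical (`TorchContext` and `FACTORY_FUNCS` are illustrative names, and the factory list is not exhaustive); it only relies on the fact that torch factory functions such as `torch.empty` accept `device=`/`dtype=` keywords and that `torch.load` accepts `map_location=`. All other attribute lookups fall through to the wrapped module.

```python
import functools
from typing import Any

# Hypothetical, non-exhaustive set of factory functions that take
# device=/dtype= keywords (these names exist in torch).
FACTORY_FUNCS = frozenset(
    {"empty", "zeros", "ones", "full", "rand", "randn", "arange", "tensor"}
)


class TorchContext:
    """Hypothetical sketch of the proposed torch.get_context object.

    Attribute access falls through to the wrapped module, but factory
    functions get default device/dtype keywords, and `load` gets a
    default map_location (pitch item 4).
    """

    def __init__(self, module: Any, device: Any = None, dtype: Any = None) -> None:
        self._module = module
        self._device = device
        # Only pass keywords that were actually set, so the wrapped
        # function's own defaults apply otherwise.
        self._defaults = {
            k: v for k, v in (("device", device), ("dtype", dtype)) if v is not None
        }

    def __getattr__(self, name: str) -> Any:
        # Called only for names not found on the instance itself.
        attr = getattr(self._module, name)
        if name in FACTORY_FUNCS and self._defaults:
            # functools.partial lets explicit call-site keywords override.
            return functools.partial(attr, **self._defaults)
        if name == "load" and self._device is not None:
            return functools.partial(attr, map_location=self._device)
        return attr
```

With real torch this would be used as `torch = TorchContext(torch, device="cuda:2", dtype=torch.half)` followed by `torch.empty(3)`. Because the rebinding is local to the calling file, library code (including `torch.nn` layers) keeps using its own `torch` import. That is what protects third-party libraries, and it is also why item 2 only applies when `MyModule` itself creates tensors through the rebound name. A real design would also need to decide whether a default `dtype` should override type inference in calls like `torch.tensor([1, 2])`.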

Concerns

Features like this always risk making things too "frictionless", obscure, and harder to debug. But I think this one is not too bad.

cc @ezyang @gchanan @zou3519 @bdhirsh @heitorschueroff @ngimel @ejguan

Metadata

Assignees

No one assigned

    Labels

    ezyang's list, feature, module: cuda, triaged
