[FR] torch context with default device & dtype #27878

Closed
ssnl opened this issue Oct 14, 2019 · 19 comments
Labels
ezyang's list - Stuff ezyang doesn't want to lose
feature - A request for a proper, new feature.
module: cuda - Related to torch.cuda, and CUDA support in general
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@ssnl
Collaborator

ssnl commented Oct 14, 2019

🚀 Feature

There has long been a need to set a default device. This FR proposes an API that is compatible with 3rd-party libs.

Motivation

There has been a lot of discussion around a default device flag in PyTorch (#7535). Yet implementing such an API has been mostly blocked by the concern that 3rd-party libraries may assume that tensors are created on CPU.

A similar dilemma has also been seen in the Python multiprocessing library, where multiple start methods can be used (triggering bugs like librosa/librosa#747). It came up with the API multiprocessing.get_context(xxx), which returns a context object with the same set of functions as the multiprocessing module but associated with a different start method, enabling patterns like mp = mp.get_context('spawn').
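For reference, the multiprocessing pattern being referenced looks like this (standard-library API, shown only to illustrate the shape of the proposed torch analogue):

import multiprocessing

def worker(queue):
    queue.put("hello from a spawned process")

if __name__ == "__main__":
    mp = multiprocessing.get_context("spawn")  # module-like object bound to the 'spawn' start method
    q = mp.Queue()                             # same API surface as the top-level module
    p = mp.Process(target=worker, args=(q,))
    p.start()
    print(q.get())
    p.join()

The global state of multiprocessing is left untouched; only code that goes through mp picks up the different start method.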

In addition to tensor creation, for many users (including me) it is extremely verbose to write .to(device) for every tensor yielded from the data loader. It would be very handy if this handled the moving automatically as well.

Pitch

  1. Tensor creation
    import torch
    torch.empty(3)  # returns CPU float tensor
    torch = torch.get_context(device='cuda:2', dtype=torch.half)
    torch.empty(3)  # returns CUDA half tensor
  2. Module creation
    import torch
    torch = torch.get_context(device='cuda:2')
    MyModule()  # returns module on CUDA if MyModule uses torch tensor creation methods
  3. (needs discussion) affect torch_ctx.utils.data.DataLoader s.t. the yielded samples contain tensors moved to the device of torch_ctx.
  4. (needs discussion) affect torch_ctx.load(xxx) the same way as map_location.
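A very rough sketch of how such a context object could behave, purely to make the pitch concrete; the get_context name, the _TorchContext wrapper, and the set of patched factory functions are hypothetical, not an actual PyTorch API:

import functools
import torch

class _TorchContext:
    """Module-like wrapper that injects default device/dtype into tensor
    factory functions and falls through to the real torch module otherwise."""

    _FACTORIES = {"empty", "zeros", "ones", "full", "tensor", "arange", "rand", "randn"}

    def __init__(self, device=None, dtype=None):
        self._defaults = {k: v for k, v in (("device", device), ("dtype", dtype)) if v is not None}

    def __getattr__(self, name):
        attr = getattr(torch, name)
        if name in self._FACTORIES:
            # explicit keyword arguments at the call site still override the defaults
            return functools.partial(attr, **self._defaults)
        return attr

def get_context(device=None, dtype=None):
    return _TorchContext(device=device, dtype=dtype)

torch_ctx = get_context(device="cuda:2", dtype=torch.half)  # assumes that device exists
x = torch_ctx.empty(3)   # CUDA half tensor
y = torch.empty(3)       # the real module is untouched: CPU float tensor

Item 2 (module creation) is exactly what such a thin wrapper does not solve on its own, which is what much of the discussion below turns on.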

Concerns

Such features always have the potential issue of making things too "frictionless", obscure, and harder to debug. But I think this one is not too bad.

cc @ezyang @gchanan @zou3519 @bdhirsh @heitorschueroff @ngimel @ejguan

@soumith
Member

soumith commented Oct 14, 2019

after one framework finally got rid of a global context after 3 years, are you saying the competing framework should add one? :P

@ssnl
Collaborator Author

ssnl commented Oct 14, 2019

:D a different kind of context

@zou3519 zou3519 added the module: cuda (Related to torch.cuda, and CUDA support in general), needs research (We need to decide whether or not this merits inclusion, based on research world), triage review, and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels and removed the triage review label on Oct 15, 2019
@gchanan
Contributor

gchanan commented Feb 6, 2020

this sounds like a promising API -- there's probably some implementation complexity behind it, but we'll remove "needs research" as it looks acceptable.

@gchanan gchanan removed the needs research (We need to decide whether or not this merits inclusion, based on research world) label on Feb 6, 2020
@elistevens
Contributor

Unclear if this issue is still actually active, but it was pointed at recently as if it were, so I'd like to point out that #2 above would require some magic to make it work properly (assuming that the torch = part is required, as opposed to just setting a bunch of global state in torch.get_context(device='cuda:2'); in that case the function should be renamed to something more accurate). It would also leave a bunch of polluted global state behind, possibly just from importing a 3rd-party package that used it.

Please use a context manager like with torch.set_default_device('cuda:0'):.
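A minimal sketch of the context-manager shape being suggested here (the set_default_device name and the _default_device global are illustrative, not an existing API at the time of this comment):

import contextlib

_default_device = "cpu"

@contextlib.contextmanager
def set_default_device(device):
    """Temporarily override the process-wide default device, restoring the
    previous value even if the block raises."""
    global _default_device
    previous, _default_device = _default_device, device
    try:
        yield
    finally:
        _default_device = previous

with set_default_device("cuda:0"):
    pass  # factory calls in here would consult _default_device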

@ssnl
Collaborator Author

ssnl commented Sep 24, 2020

@elistevens Yeah I think you are right. #2 doesn't work normally, unfortunately.

@ezyang
Contributor

ezyang commented Oct 12, 2020

Bumping priority due to activity on this issue, the parent issue, and also other issues.

@malfet
Contributor

malfet commented Oct 19, 2020

Wouldn't a context manager be better suited than a global context?
I.e.

with torch.context(device='cuda:2', dtype=torch.double):
   x = torch.ones((5, 5))

@ssnl
Collaborator Author

ssnl commented Oct 19, 2020

@malfet Why not have the object do both? It is not always desirable to wrap the entire program in a context manager.

@ailzhang ailzhang added the feature (A request for a proper, new feature.) label and removed the high priority and triage review labels on Oct 19, 2020
@ailzhang
Contributor

Removing "high priority" label since this feature request still need some work in defining the user facing API and how it should work with existing code.
Probably also related: we also need define what to do with the current torch.set_default_dtype API.

@ezyang
Contributor

ezyang commented Oct 20, 2020

Context manager doesn't give you desirable behavior, as it will also affect library code that allocates tensors. A lot of the resistance against a global "set default device" API is because it will make it difficult for library authors to write code in a way that will work no matter what the default device is. A module-like torch context object bypasses this problem as the device defaulting is lexical.

This reminds me, though, that in the proposal above, Module creation isn't done using the torch context explicitly; some amount of dynamic scoping seems necessary there. So this proposal, unfortunately, isn't complete.
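A small illustration of the lexical-vs-dynamic distinction (the library code and the hypothetical_* names are made up):

import torch

# third-party library code, written on the assumption that tensors are on CPU
def make_mask(n):
    return torch.zeros(n, dtype=torch.bool)

# Dynamically scoped (context manager): the library call inside the block is
# affected too, even though its author never opted in.
#
#   with hypothetical_default_device("cuda:0"):
#       mask = make_mask(8)        # silently lands on CUDA
#
# Lexically scoped (module-like context object): only call sites that go
# through the context object change behaviour.
#
#   torch_ctx = hypothetical_get_context(device="cuda:0")
#   x = torch_ctx.zeros(8)         # CUDA
#   mask = make_mask(8)            # still CPU, as the library assumed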

@ezyang
Contributor

ezyang commented Nov 17, 2020

Oh, actually, module creation is done using the torch context, because you say torch.nn.Module, so if you have a torch context, that would also solve your NN problem too. (But you'll have to shave the yak of how to desugar these calls into Python module calls first.)

@ezyang ezyang added the ezyang's list (Stuff ezyang doesn't want to lose) label on Nov 17, 2020
@ssnl
Collaborator Author

ssnl commented Nov 29, 2020

@ezyang Yeah... that could work. But in the usual case where one has a separate module containing the module definition, it seems hard to make the context device/dtype configurable (because that separate module just uses the regular torch.nn), right?

@ezyang
Contributor

ezyang commented Nov 30, 2020

You have to solve the problem which is that we don't actually support direct on device Module creation. Supposing you add some explicit API for doing this, then you can just write some Python magic to sniff for nn modules and then partially apply them with the device arguments. You still avoid dynamic scoping in this case.
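A sketch of the "sniff for nn modules and partially apply the device" idea (a hypothetical helper; it assumes the explicit on-device-creation API takes the form of a device keyword accepted by Module constructors):

import functools
import torch.nn as nn

def bind_device(namespace, device):
    """Return constructors from `namespace` for every nn.Module subclass,
    with `device` partially applied (hypothetical helper)."""
    bound = {}
    for name in dir(namespace):
        obj = getattr(namespace, name)
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            bound[name] = functools.partial(obj, device=device)
    return bound

ctx_nn = bind_device(nn, "cuda:0")    # requires a CUDA device to exist
layer = ctx_nn["Linear"](4, 4)        # like nn.Linear(4, 4, device="cuda:0")

Note that device defaulting stays lexical here: the caller explicitly goes through ctx_nn.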

@ssnl
Collaborator Author

ssnl commented Nov 30, 2020

Hmm I don't think I understand. What I am thinking is a scenario like the following:

# model.py

import torch

class Model(torch.nn.Module):
  ...
# main.py

torch_ctx = torch.get_context(sys.argv[1])

from model import Model

m = Model()  # <- not using `torch_ctx`

Oh, maybe this comment of yours

You have to solve the problem which is that we don't actually support direct on device Module creation

is referring to something like Model(..., torch_ctx=torch_ctx) or torch_ctx.create(Model, ...), which would then make sense, although I can't think of an API that is nice for this...

@rgommers
Collaborator

Python modules are singletons, so it's a little unclear what model = torch_ctx.wrap(model) should be doing. The wrap() can't just tack some context onto the existing model (that would leak to other users of model), therefore model needs to be a new namespace-like object with Model also being a new generated version of the original Model.

I guess that's the yak that needed shaving, and the same applies to torch = torch.get_context(...). Having an idea of how it could be implemented would be nice - it looks to me like you need a factory function that produces new namespaces that look exactly like an existing namespace, with all new objects that are just like the old ones but with some partial applied. That might still produce some unexpected behaviour, like isinstance or issubclass failing. I don't think there's a great way to get "partial classes" (see this SO question).
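For concreteness, here is one shape the "partial class" yak could take; everything below is a hypothetical sketch, and it shows exactly the wrinkle mentioned above: instances still pass isinstance against the original class, but class identity changes and a fresh namespace has to be generated per context.

import torch.nn as nn

def wrap_module_class(cls, **defaults):
    """Generate a subclass whose __init__ fills in default keyword
    arguments (e.g. device=...), i.e. a crude 'partial class'."""
    class Wrapped(cls):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **{**defaults, **kwargs})
    Wrapped.__name__ = cls.__name__
    Wrapped.__qualname__ = cls.__qualname__
    return Wrapped

CtxLinear = wrap_module_class(nn.Linear, device="cpu")
layer = CtxLinear(4, 4)
print(isinstance(layer, nn.Linear))   # True: instances look like the original class
print(CtxLinear is nn.Linear)         # False: identity checks against the class break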

@carlosgmartin

In my opinion, something like this would be helpful. My code uses a lot of tensor constructors like tensor, arange, zeros, full, eye, empty, as_tensor, etc. and is therefore littered with device=self.device flags. Accidentally omitting any such flag will cause an error. A context manager like @elistevens suggests sounds promising.

@ezyang Could those libraries not use their own nested context?

@JackCaoG
Collaborator

JackCaoG commented Sep 7, 2022

Bumping up this thread. We (PyTorch/XLA) have recently been testing large models on TPU clusters. One issue we run into is that, by default, model initialization happens on CPU; with large models it is really easy to OOM the host memory. If we could specify the default device with a context (so we can init model weights on TPU directly), that would be really helpful.

@ezyang
Contributor

ezyang commented Sep 8, 2022

check out #82296 (comment)

ezyang added a commit that referenced this issue Dec 30, 2022
Fixes #82296
Fixes #27878
Fixes #260

Open to bikeshedding the module name.  Open to bikeshedding if any of
the identifiers should get reexported from torch top level.  Open to
bikeshedding the global setter API name / location.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: #91525