[feature request] torch.to(obj, device) supporting recursive lists/dicts/tuples of tensors probably by uplifting/promoting torch.distributed.utils._recursive_to #69431

Open
vadimkantorov opened this issue Dec 5, 2021 · 27 comments
Labels
function request: A request for a new function or the addition of new arguments/modes to an existing function.
needs research: We need to decide whether or not this merits inclusion, based on research world.
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.

Comments

@vadimkantorov
Contributor

vadimkantorov commented Dec 5, 2021

It is often necessary to move model results to CPU (or inputs to GPU). Once the data structures get a bit complicated, model results often contain dicts and lists, and we end up rolling a little utility like the one below. If other people have had to write this sort of utility as well, it may make sense to include something like it in core.

import torch

def to(obj, device):
  # Recursively move any tensor nested in dicts/lists/tuples to `device`;
  # everything else is passed through unchanged.
  if torch.is_tensor(obj):
    return obj.to(device)
  if isinstance(obj, dict):
    return {k: to(v, device) for k, v in obj.items()}
  if isinstance(obj, tuple):
    return tuple(to(v, device) for v in obj)
  if isinstance(obj, list):
    return [to(v, device) for v in obj]
  return obj
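
For example (a hypothetical call, assuming a machine with a CUDA device), moving nested model outputs back to CPU:

outputs = {"logits": torch.randn(2, 3, device="cuda"), "aux": [torch.ones(1, device="cuda"), 0.5]}
cpu_outputs = to(outputs, "cpu")  # tensors are moved, the float 0.5 passes through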

cc @albanD @mruberry @jbschlosser @walterddr

@vadimkantorov vadimkantorov changed the title [proposal] [util] torch.to(obj, device) supporting reecursive lists/dicts/tuples of tensors [proposal] [util] torch.to(obj, device) supporting recursive lists/dicts/tuples of tensors Dec 6, 2021
@anjali411 anjali411 added module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module function request A request for a new function or the addition of new arguments/modes to an existing function. and removed module: nn Related to torch.nn labels Dec 6, 2021
@albanD albanD added the needs research We need to decide whether or not this merits inclusion, based on research world label Dec 7, 2021
@albanD
Collaborator

albanD commented Dec 7, 2021

Wouldn't this be nicely solved by just using pytree from here, where you can do tree_map(torch.to, your_obj)?

@vadimkantorov
Contributor Author

vadimkantorov commented Dec 7, 2021

Maybe! But I think it's still useful to have this primitive in core, as it appears a lot in user code and very few people know about pytrees (I certainly don't). It might also be better to have a simple manual version, so it's exactly clear what is happening. E.g. a simple version makes it clear that it would be slow-ish for giant lists of Python numbers (as there would be a lot of checks before returning each value); it's not clear what happens with pytree. It also seems to have some flatten/unflatten concept...
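
For illustration, a minimal sketch of that flatten/unflatten idea, assuming the private (undocumented) torch.utils._pytree module:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

nested = {"a": [torch.randn(2), 3.0], "b": (torch.zeros(1),)}
leaves, spec = tree_flatten(nested)   # flat list of leaves plus the structure
leaves = [x.to("cpu") if torch.is_tensor(x) else x for x in leaves]
nested_again = tree_unflatten(leaves, spec)  # same nested structure, tensors moved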

also, torch.to doesn't exist now or isn't documented

@albanD
Collaborator

albanD commented Dec 7, 2021

also, torch.to doesn't exist now or isn't documented

Oh, I used that because you mentioned it above :p But indeed, that doesn't even exist.

@mruberry
Collaborator

mruberry commented Dec 8, 2021

The lack of torch.to isn't actually an impediment because it's easy to make the method a function using a lambda.

We do write "crawlers" like this occasionally. assert_close is the most recent example that comes to mind and has one. I think there's also a slightly different version in common_methods_invocations.py. I'm inclined to suggest people write their own crawlers or use a Python utility package for functionality like this. There are a lot of vagaries (like what if the "elements" in the containers aren't other containers or tensors?), and the functionality is more about dealing with Python data structures than it is about dealing with tensors.

NumPy doesn't have any functionality like this, does it?

@vadimkantorov
Contributor Author

vadimkantorov commented Dec 8, 2021

Yeah, I just propose to use the opportunity and introduce torch.to as a slightly more generic version that supports basic Python structures (also supported by TorchScript), be it with pytree utils or not. I'm proposing this on UX grounds, so I can't argue with the point that this isn't an "impediment". However, similar functions get re-rolled in practically every project.

I somewhat agree, but even when implementing custom device-movers for custom structures, having an existing routine that works for the known types would make the code simpler. My code above just passes through non-PyTorch things. And I also think it would be nicer to have more modular/reusable (but still simple enough) "collation" / "conversion" routines.

NumPy doesn't have a concept of device anyway, so it obviously doesn't have this functionality (and the need for other automated conversions, like dtype conversions, is much smaller), but that comparison isn't really relevant IMO.

@vadimkantorov
Contributor Author

vadimkantorov commented Mar 29, 2022

I've since seen more examples of pytree, so I now understand better that pytree traversal can indeed be used to implement this recursive device transfer :) But I would still suggest that this is a good default behavior for a (to-be-introduced) torch.to, without forcing users to write their own pytree'd version, since this recursive mode is often needed.

There might also be optimizations like preserving data sharing, so that if we do torch.to(torch.split(...), ...) the converted tensors are still views over parts of a single tensor. But it might be too special-case.
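
A hypothetical sketch of that optimization (assuming 1D contiguous views as produced by torch.split, and an available CUDA device; the private ._base attribute is used only to illustrate the idea): move the shared base tensor once, then re-create the views on the device.

import torch

x = torch.randn(10)
chunks = torch.split(x, 2)                  # views over x's storage
base = chunks[0]._base                      # the shared base tensor (here: x); private attribute
moved_base = base.to("cuda")                # a single host->device copy
moved_chunks = [moved_base[c.storage_offset():c.storage_offset() + c.numel()]
                for c in chunks]            # still views over one device tensor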

@vadimkantorov
Contributor Author

I also wonder if tree_map is parallel (akin to the _foreach methods). For torch.to(..., non_blocking=True) it probably doesn't matter as much, but still. It could also probably benefit from a single cudaMalloc call (at least for all strided tensors) by flattening all the tensors in the input.
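
A hedged sketch of that single-allocation idea, using the internal helpers torch._utils._flatten_dense_tensors / _unflatten_dense_tensors (private APIs; this assumes all tensors share one dtype and is only meant to illustrate the concept):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def batched_to(tensors, device, non_blocking=True):
  # Pack all tensors into one contiguous buffer, issue a single copy,
  # then carve the result back into per-tensor views over that buffer.
  flat = _flatten_dense_tensors(tensors)
  flat = flat.to(device, non_blocking=non_blocking)
  return _unflatten_dense_tensors(flat, tensors)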

@vadimkantorov
Contributor Author

vadimkantorov commented May 10, 2022

Related PR that actually implements this recursive utility: #77187, but not in a generic namespace :(

@rohan-varma
Member

Found this issue after @vadimkantorov's comment on a related PR. Agreed that such a utility would be quite useful, and PT-D would then not need custom logic to move inputs for DDP / FSDP.

@albanD @jbschlosser @mruberry Do we have any new thoughts on this feature and whether this is something that the core team might be able to address?

@vadimkantorov vadimkantorov changed the title [proposal] [util] torch.to(obj, device) supporting recursive lists/dicts/tuples of tensors [proposal] [util] torch.to(obj, device) supporting recursive lists/dicts/tuples of tensors probably by uplifting/promoting trorch.distributed.utils._recursive_to May 11, 2022
@vadimkantorov vadimkantorov changed the title [proposal] [util] torch.to(obj, device) supporting recursive lists/dicts/tuples of tensors probably by uplifting/promoting trorch.distributed.utils._recursive_to [proposal] [util] torch.to(obj, device) supporting recursive lists/dicts/tuples of tensors probably by uplifting/promoting torch.distributed.utils._recursive_to May 11, 2022
@albanD
Collaborator

albanD commented May 11, 2022

Any reason for not using pytree as suggested above in your case?

@mruberry
Collaborator

mruberry commented May 12, 2022

I understand this would be convenient, but we try not to add sugar to core PyTorch operations.

@jbschlosser
Contributor

jbschlosser commented May 12, 2022

To be explicit: this would involve use of pytree.tree_map on the data structure with lambda t: t.to(...).
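
A minimal sketch of that, assuming the private torch.utils._pytree module and an available CUDA device:

import torch
from torch.utils._pytree import tree_map

batch = {"input": torch.randn(2, 3), "lengths": [3, 2], "mask": (torch.ones(2, 3),)}
moved = tree_map(lambda x: x.to("cuda") if torch.is_tensor(x) else x, batch)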

@albanD
Collaborator

albanD commented May 12, 2022

Well, for nn.Module that already works, since you can call .to() on the top-level Module.

@vadimkantorov
Contributor Author

There is some stuff about streams; not sure if tree_map will be able to handle it?

It may also be a good idea to support non_blocking, and to call custom .to methods (if they exist) in addition to tensors.

@vadimkantorov
Contributor Author

vadimkantorov commented Dec 23, 2022

So now in theory, torch.to could just be implemented in terms of torch.utils._pytree.tree_map.

Compared to vanilla tree_map, torch.to could theoretically do all allocations in one batch (asynchronously) / in a single allocation, but I'm not sure what a good design for that would be. Although I'm not even sure whether one large allocation is better than many smaller ones, w.r.t. reuse of already-allocated segments that might not be contiguous.

@vadimkantorov
Contributor Author

Although in another context (sparse tensors), this post describes an optimization that torch.to could also do (minimizing the number of host->device copies): https://pytorch.org/blog/optimizing-production-pytorch-performance-with-graph-transformations/#31-combining-input-sparse-features

@vadimkantorov vadimkantorov changed the title [proposal] [util] torch.to(obj, device) supporting recursive lists/dicts/tuples of tensors probably by uplifting/promoting torch.distributed.utils._recursive_to [feature request] torch.to(obj, device) supporting recursive lists/dicts/tuples of tensors probably by uplifting/promoting torch.distributed.utils._recursive_to Jan 2, 2024
@albanD
Collaborator

albanD commented Jan 2, 2024

I guess pytree would be the way to go for this in 2024 :D
new = tree_map_only(torch.Tensor, lambda t: t.to(device), old)
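
Spelled out with the import (tree_map_only lives in the private torch.utils._pytree module; device and old below are just placeholders):

import torch
from torch.utils._pytree import tree_map_only

device = "cuda"                           # placeholder target device
old = {"x": torch.randn(3), "y": [1, 2]}  # placeholder nested object
new = tree_map_only(torch.Tensor, lambda t: t.to(device), old)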

@vadimkantorov
Contributor Author

vadimkantorov commented Jan 2, 2024

On the workability side, yes, it would cut it, but having some built-in option for a separate copy thread, or for efficiently moving a lot of tensors in one allocation, might be interesting - especially if it's already implemented in the distributed context. Otherwise, on the functional programming side, I think implementing torch.to as tree_map_only is a worthy shortcut for users to have in core, and good for consistency between functions/methods :)

@albanD
Collaborator

albanD commented Jan 2, 2024

I actually do not think that changing torch.to() like this would be good for consistency. That would make the behavior significantly different between the function and the method.

@vadimkantorov
Contributor Author

For basic individual tensors, the behavior would be exactly the same anyway? So for all inputs acceptable to the instance .to method, the torch.to function would behave the same, right? I'm proposing to keep them strictly identical for all individual tensors (so both .to and torch.to would continue to exist), and to potentially have more knobs/functionality for more complex inputs, like tensor lists or pytree object hierarchies.
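
A hypothetical sketch of what such a torch.to could look like (the name and behavior here are assumptions, not an existing API), built on tree_map_only so that plain tensors behave exactly like Tensor.to:

import torch
from torch.utils._pytree import tree_map_only

def to(obj, *args, **kwargs):
  # Plain tensors: identical to the Tensor.to method.
  if torch.is_tensor(obj):
    return obj.to(*args, **kwargs)
  # Nested containers: move every tensor leaf, pass everything else through.
  return tree_map_only(torch.Tensor, lambda t: t.to(*args, **kwargs), obj)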

@vadimkantorov
Contributor Author

vadimkantorov commented Apr 16, 2024

Btw _recursive_to got copied over in https://github.com/pytorch/pytorch/blob/main/torch/distributed/optim/zero_redundancy_optimizer.py#L29

So maybe it is the right moment for it to be promoted at least to torch._to?

@albanD
Collaborator

albanD commented Apr 16, 2024

We should remove this code and just use tree_map there instead. It's going to be a lot more reliable...

@vadimkantorov
Contributor Author

just use tree_map there instead

I think it'd also be a great implementation for a publicly advertised shortcut :) Otherwise, this duplicated code exists both in torch/distributed/utils.py and in this zero_redundancy_optimizer.py.

A more custom op could also somehow parallelize the copies (if a large tensor list is passed as input) using multiple CUDA threads, or allocate a single large contiguous memory chunk on the GPU (and maybe do all of this asynchronously, if that makes it possible to schedule the copies on a user-provided background CUDA stream used for copies).
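
A hypothetical sketch of issuing the copies on a user-provided side stream (assumes a CUDA device; for truly asynchronous host->device copies the source tensors would also need to be in pinned memory):

import torch

copy_stream = torch.cuda.Stream()

def to_device_async(tensors, device="cuda"):
  with torch.cuda.stream(copy_stream):
    out = [t.to(device, non_blocking=True) for t in tensors]
  # Make the current stream wait for the copies before the results are used.
  torch.cuda.current_stream().wait_stream(copy_stream)
  return out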

@albanD
Collaborator

albanD commented Apr 17, 2024

I think that if we need to add one more API for our users to learn about, I'd prefer for them to learn about pytree. It will allow them to solve many of their problems with one line, compared to a specialized API that only solves a single problem.

@vadimkantorov
Contributor Author

I think torch.to is a very natural thing to search for if you have already used tensor.to (as most operations exist both as instance methods and as functions). But even if this pytree-to-implement-generic-to idiom is just publicized via HF and PyTorch code examples - that would already be great!

Another advantage of introducing torch.to is that some optimizations could later be made for processing large lists.

@albanD
Collaborator

albanD commented Apr 17, 2024

In theory, nothing prevents Dynamo from tracing pytree work and doing fancy optimization there :D

cc @zou3519 for pytree as a public feature
