
[DDP] Add PackedSequence support when device_ids is specified #86614

Closed · wants to merge 4 commits

Conversation

awgu (Contributor) commented Oct 10, 2022

Stack from ghstack:

Before this PR, if a user runs DDP with `device_ids` specified and with a `PackedSequence` input, the forward pass errors with something like:

```
raise ValueError(
  ValueError: batch_sizes should always be on CPU. Instances of PackedSequence should never be created manually. They should be instantiated by
 functions like pack_sequence and pack_padded_sequences in nn.utils.rnn. https://pytorch.org/docs/stable/nn.html...
```

This is because the DDP forward calls `_to_kwargs()`, which calls `_recursive_to()`, which moves the inputs to GPU. However, `_is_namedtuple(packed_sequence)` returns `True`, so execution takes the branch `return [type(obj)(*args) for args in zip(*map(to_map, obj))]`, which reconstructs the `PackedSequence` directly via `type(obj)(*args)` after every field (including `batch_sizes`) has been moved to GPU, triggering the error above.
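
For context, a minimal sketch of the kind of special-casing needed inside the recursive move (not the exact diff in this PR; `to_map` here is an illustrative stand-in for the closure in `_recursive_to()`):

```
import torch
from torch.nn.utils.rnn import PackedSequence

def to_map(obj, device):
    # Sketch only: handle PackedSequence before the generic namedtuple branch.
    # PackedSequence.to() moves .data (and the index tensors) to `device`
    # while leaving batch_sizes on CPU, so no manual reconstruction via
    # type(obj)(*args) is needed.
    if isinstance(obj, PackedSequence):
        return obj.to(device)
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    # ... generic namedtuple / list / dict handling as in _recursive_to() ...
    return obj
```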

Repro for `_is_namedtuple(packed_sequence)` returning `True`:

```
import random

import torch
import torch.nn.utils.rnn as rnn_utils
from torch.nn.parallel.scatter_gather import _is_namedtuple

def _ordered_sequence(tensor_type):
    seqs = [tensor_type(random.randint(1, 256))
            for _ in range(32)]
    seqs = [s.random_(-128, 128) for s in seqs]
    ordered = sorted(seqs, key=len, reverse=True)
    return ordered

def _padded_sequence(tensor_type):
    ordered = _ordered_sequence(tensor_type)
    lengths = [len(i) for i in ordered]
    padded_tensor = rnn_utils.pad_sequence(ordered)
    return padded_tensor, lengths

padded, lengths = _padded_sequence(torch.Tensor)
packed = rnn_utils.pack_padded_sequence(
    padded, lengths, enforce_sorted=False)
print(type(packed), packed.data.device)
print(_is_namedtuple(packed))
```
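
To see the failure mode in isolation (requires a CUDA device; this just mimics what the namedtuple branch in `_recursive_to()` does):

```
# Iterate the PackedSequence as a namedtuple and rebuild it with every field
# moved to GPU -- this raises the ValueError quoted above because the
# constructor requires batch_sizes to stay on CPU.
moved = [t.cuda() if isinstance(t, torch.Tensor) else t for t in packed]
type(packed)(*moved)  # ValueError: batch_sizes should always be on CPU ...
```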

Test Plan:

```
python test/distributed/test_c10d_nccl.py -k test_ddp_packed_sequence
```

Without the fix, the added unit test fails with the expected error.
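
For reference, a rough single-rank sketch of the scenario the unit test exercises (not the actual test body in test_c10d_nccl.py, which runs under the usual multi-process NCCL test harness; the rendezvous setup below is illustrative):

```
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)
    device = torch.device("cuda", 0)

    # Specifying device_ids is what routes inputs through _to_kwargs()/_recursive_to().
    model = DDP(nn.LSTM(1, 8).to(device), device_ids=[0])

    # Build a PackedSequence on CPU; DDP's forward is responsible for moving it.
    seqs = sorted(
        [torch.randn(torch.randint(1, 10, (1,)).item(), 1) for _ in range(4)],
        key=len, reverse=True,
    )
    packed = rnn_utils.pack_padded_sequence(
        rnn_utils.pad_sequence(seqs), [len(s) for s in seqs]
    )

    out, _ = model(packed)  # raised the ValueError above before this PR
    print(type(out), out.data.device)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```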

pytorch-bot (bot) commented Oct 10, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/86614

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4d7c335:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the `release notes: distributed (c10d)` label Oct 10, 2022
awgu added a commit that referenced this pull request Oct 10, 2022
ghstack-source-id: 6cdbf552770ec74c5295bd96fe14a80ec4cf3c30
Pull Request resolved: #86614
awgu (Contributor, Author) commented on the diff at `@@ -1,10 +1,12 @@` (`from typing import Any, Dict, List, Tuple`):

Sorted imports.

awgu added a commit that referenced this pull request Oct 10, 2022
ghstack-source-id: 4edd04038037ceea107d4946253583c369f8dde1
Pull Request resolved: #86614
awgu added the `module: ddp` and `release notes: distributed (ddp)` labels and removed the `release notes: distributed (c10d)` label Oct 10, 2022
rohan-varma (Member) left a comment:

Thanks for adding this!

Review thread on test/distributed/test_c10d_nccl.py (resolved)
awgu added a commit that referenced this pull request Oct 10, 2022
ghstack-source-id: d7445464b164f702035912100a0caffe51c02df3
Pull Request resolved: #86614
awgu added a commit that referenced this pull request Oct 10, 2022
ghstack-source-id: 78bab0549d3b65a8e2b4933ad07544d96f49f616
Pull Request resolved: #86614
@rohan-varma rohan-varma self-requested a review October 10, 2022 18:48
rohan-varma (Member) left a comment:

LG

pytorch-bot added the `ciflow/trunk` label Oct 10, 2022
awgu (Contributor, Author) commented Oct 10, 2022

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 hours).

Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Advanced debugging: check the merge workflow status here.

github-actions (bot) commented:
Hey @awgu.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@facebook-github-bot facebook-github-bot deleted the gh/awgu/119/head branch June 8, 2023 15:22
Labels: ciflow/trunk · Merged · module: ddp · release notes: distributed (ddp)