
DistributedSampler internal asserts if len(dataset) * 2 < number of GPUs #45324

Closed
agemor opened this issue Sep 25, 2020 · 4 comments
Labels: high priority, oncall: distributed, triaged

Comments

agemor (Contributor) commented Sep 25, 2020

🐛 Bug

A DistributedSampler crash has been reported in several threads (links below), but its cause had not been identified.

The failure is caused by an implementation error in DistributedSampler: it breaks whenever len(dataset) * 2 < num_gpus. The problematic part is the padding line in the snippet below (L101 in the source): the slice end self.total_size - len(indices) exceeds len(indices), so the slice is capped and indices can at most double; when self.total_size > 2 * len(indices) the list therefore never reaches self.total_size and the assert at the end fails (see the sketch after the snippet).

if not self.drop_last:
    # add extra samples to make it evenly divisible
    indices += indices[:(self.total_size - len(indices))]
else:
    # remove tail of data to make it evenly divisible.
    indices = indices[:self.total_size]
assert len(indices) == self.total_size
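
For illustration, here is a minimal standalone sketch (hypothetical helper names, not the actual PyTorch code or an official patch) of why a single slice cannot pad the index list once total_size > 2 * len(indices), and how repeating the list instead would:

import math

def pad_by_single_slice(indices, total_size):
    # Current behavior: the slice end is capped at len(indices), so the
    # list can at most double and may still fall short of total_size.
    return indices + indices[:(total_size - len(indices))]

def pad_by_repetition(indices, total_size):
    # One possible fix: repeat the whole list as many times as needed,
    # then trim to exactly total_size.
    repeats = math.ceil(total_size / len(indices))
    return (indices * repeats)[:total_size]

indices = list(range(4))   # dataset_size = 4
total_size = 9             # num_samples * num_replicas with num_gpus = 9

print(len(pad_by_single_slice(indices, total_size)))  # 8 -> assert would fail
print(len(pad_by_repetition(indices, total_size)))    # 9 -> assert would pass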

To Reproduce

import os
import torch.utils.data
import torch.multiprocessing as mp
import torch.distributed

# Error occurs when dataset_size * 2 < num_gpus
num_gpus = 9
dataset_size = 4

def main(gpu):
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=num_gpus, rank=gpu)

    dataset = DummyDataset()
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=num_gpus, rank=gpu)
    loader = torch.utils.data.DataLoader(dataset, 1, sampler=sampler)

    for batch in loader:
        print(batch)

class DummyDataset(torch.utils.data.Dataset):
    def __getitem__(self, item):
        return 0
    def __len__(self):
        return dataset_size

if __name__ == '__main__':
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '10000'
    mp.spawn(main, nprocs=num_gpus)

If dataset_size * 2 < num_gpus (e.g. dataset_size = 4 and num_gpus = 9), the following error occurs.

-- Process 3 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/home/user/pytorch_bug.py", line 25, in main
    for batch in loader:
  File "/usr/local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 384, in _next_data
    index = self._next_index()  # may raise StopIteration
  File "/usr/local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 339, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "/usr/local/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 200, in __iter__
    for idx in self.sampler:
  File "/usr/local/lib/python3.6/site-packages/torch/utils/data/distributed.py", line 68, in __iter__
    assert len(indices) == self.total_size
AssertionError

cc @ezyang @gchanan @zou3519 @ssnl @VitalyFedyunin @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar @jiayisuse @agolynski

agemor (Contributor, Author) commented Sep 25, 2020

A simple workaround is to set drop_last=True in such cases, but a more robust fix is still needed. A usage sketch follows.
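
For concreteness, the workaround applied to the reproduction script above would look roughly like this (assuming dataset, num_gpus, and gpu as defined there):

# Workaround sketch: drop_last=True makes the sampler trim the tail instead
# of padding it, so the assert is never hit. Note that with a dataset this
# small each rank may then end up with zero samples.
sampler = torch.utils.data.distributed.DistributedSampler(
    dataset, num_replicas=num_gpus, rank=gpu, drop_last=True)
loader = torch.utils.data.DataLoader(dataset, 1, sampler=sampler)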

zou3519 added the module: dataloader, high priority, oncall: distributed, and triaged labels Sep 25, 2020
zou3519 (Contributor) commented Sep 25, 2020

Hi-pri for crash

gchanan (Contributor) commented Sep 28, 2020

not a crash.

zou3519 removed the module: dataloader label Sep 28, 2020
zou3519 (Contributor) commented Sep 28, 2020

As @gchanan mentioned, this isn't a crash; it's an internal assertion error being shown to users. We should generally not surface bare assertions and should instead raise clear error messages where possible, so I'm leaving this at high priority. A sketch of what such a check could look like is below.
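
For illustration, a standalone sketch of the kind of message a bare assert could be replaced with (hypothetical helper, not the fix that actually landed upstream):

def check_padded_length(indices, total_size, num_replicas):
    # Raise a descriptive error instead of a bare AssertionError.
    if len(indices) != total_size:
        raise ValueError(
            f"DistributedSampler produced {len(indices)} indices but expected "
            f"{total_size}. This can happen when the dataset is much smaller "
            f"than the number of replicas (num_replicas={num_replicas}).")

check_padded_length(list(range(8)), 9, 9)  # raises ValueError with a clear message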

zou3519 changed the title from "DistributedSampler crashes if len(dataset) * 2 < number of GPUs" to "DistributedSampler internal asserts if len(dataset) * 2 < number of GPUs" Sep 28, 2020