
Length check fails for IterableDataset #1076

Closed
snie2012 opened this issue May 27, 2020 · 4 comments · Fixed by #1077

snie2012 commented May 27, 2020

I am using an IterableDataset of unknown length. The engine fails inside the following if clause:

if hasattr(data, "__len__"):

Here is why it fails:

  1. The data is a PyTorch DataLoader, which has the __len__ attribute, so execution enters the if clause
  2. It runs epoch_length = len(data), which leads to length = self._IterableDataset_len_called = len(self.dataset) in dataloader.py in PyTorch
  3. Since the dataset is an IterableDataset, it has no length, so this raises something like TypeError: object of type 'IterableDataset' has no len()

My concern is that we shouldn't expect an IterableDataset to have __len__, and the code should not fail because of it. Any thoughts on this?
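The failure described above can be reproduced in a few lines (a minimal sketch; StreamDataset is a hypothetical stand-in for any IterableDataset without __len__):

```python
from torch.utils.data import DataLoader, IterableDataset

# Hypothetical minimal IterableDataset of unknown length (no __len__).
class StreamDataset(IterableDataset):
    def __iter__(self):
        return iter(range(10))

loader = DataLoader(StreamDataset(), batch_size=4)

# The DataLoader class itself defines __len__, so the hasattr check passes...
assert hasattr(loader, "__len__")

# ...but calling it delegates to len(self.dataset), which raises TypeError.
try:
    len(loader)
except TypeError as e:
    print(e)  # object of type 'StreamDataset' has no len()
```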

@vfdev-5 vfdev-5 added the bug label May 28, 2020

vfdev-5 commented May 28, 2020

@snie2012 thanks for the report! Yes, you are right about that. We have a test with an IterableDataset, but with a defined epoch_length (test_engine_with_iterable_dataloader). We definitely need to support IterableDataset with unknown length!
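A guard along these lines would cover the case (a minimal sketch of the idea only, not necessarily the actual code of the fix in #1077): treat a TypeError from len() the same as a missing __len__, leaving the epoch length unknown.

```python
from torch.utils.data import DataLoader, IterableDataset

# Sketch: hasattr alone is not enough, because a DataLoader wrapping an
# IterableDataset *has* __len__ but raises TypeError when it is called,
# so the call itself must also be guarded.
def try_get_length(data):
    if hasattr(data, "__len__"):
        try:
            return len(data)
        except TypeError:
            pass  # unknown-length data, e.g. an IterableDataset
    return None

# Hypothetical unknown-length stream for demonstration.
class StreamDataset(IterableDataset):
    def __iter__(self):
        return iter(range(10))

print(try_get_length([1, 2, 3]))                    # 3
print(try_get_length(DataLoader(StreamDataset())))  # None
```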

@vfdev-5 vfdev-5 self-assigned this May 28, 2020
vfdev-5 added a commit to vfdev-5/ignite that referenced this issue May 28, 2020
@vfdev-5 vfdev-5 mentioned this issue May 28, 2020
3 tasks
sdesrozis pushed a commit that referenced this issue May 28, 2020

vfdev-5 commented May 28, 2020

@snie2012 the bug should be fixed in the next nightly release.

@SantoshGuptaML

Should DistributedSampler also support IterableDataset? I get this error message:

    train_sampler = DistributedSampler(TrainDataset(config))
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/distributed.py", line 63, in __init__
    self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
TypeError: object of type 'TrainDataset' has no len()

where TrainDataset is a subclass of IterableDataset.


vfdev-5 commented Apr 29, 2021

@SantoshGuptaML no, torch's DistributedSampler does not support iterable datasets, only map-style datasets.

There are various ways to route samples depending on the process (distributed or simply multiprocessing).

In the case of distributed training, the simplest way to distribute the data (if possible) is to set up your iterable dataset depending on the rank:

import ignite.distributed as idist
rank = idist.get_rank()
ws = idist.get_world_size()
dataset = get_iterable_dataset(..., rank=rank, world_size=ws, ...)

where get_iterable_dataset is your custom method that fetches unique data for a rank.
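One way such a get_iterable_dataset could shard the stream (a hypothetical sketch; in a real distributed run, rank and world_size would come from idist.get_rank() / idist.get_world_size() as above):

```python
from torch.utils.data import IterableDataset

# Hypothetical sketch: each process keeps every world_size-th sample
# starting at its own rank, so ranks see disjoint slices of the stream
# without needing DistributedSampler or a __len__.
class RankShardedDataset(IterableDataset):
    def __init__(self, source, rank, world_size):
        self.source = source
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        for i, sample in enumerate(self.source):
            if i % self.world_size == self.rank:
                yield sample

# Simulating two ranks in a single process:
shard0 = list(RankShardedDataset(range(10), rank=0, world_size=2))
shard1 = list(RankShardedDataset(range(10), rank=1, world_size=2))
print(shard0)  # [0, 2, 4, 6, 8]
print(shard1)  # [1, 3, 5, 7, 9]
```

Note that with this interleaving scheme every rank must be able to iterate the full source; for large streams, a per-rank data source (separate files per rank, for example) avoids the redundant iteration.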
