Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][DataLoader] Prototype of SamplerIterableDataset #49363

Closed
wants to merge 7 commits into from
23 changes: 21 additions & 2 deletions test/test_dataset.py
Expand Up @@ -3,10 +3,10 @@

import torch
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.utils.data import IterableDataset
from torch.utils.data import IterableDataset, RandomSampler
from torch.utils.data.datasets import \
(CollateIterableDataset, BatchIterableDataset, ListDirFilesIterableDataset,
LoadFilesFromDiskIterableDataset)
LoadFilesFromDiskIterableDataset, SamplerIterableDataset)


def create_temp_dir_and_files():
Expand Down Expand Up @@ -140,6 +140,25 @@ def test_batch_dataset(self):
with self.assertRaises(NotImplementedError):
len(batch_ds_nolen)

def test_sampler_dataset(self):
    r"""Exercise SamplerIterableDataset with the default and a random sampler.

    Covers three things: the default sampler reproduces the input order,
    a ``RandomSampler`` with replacement yields the expected number of
    in-range samples, and a dataset without ``__len__`` is rejected.
    """
    arrs = range(10)
    ds = IterDatasetWithLen(arrs)

    # Default SequentialSampler
    sampled_ds = SamplerIterableDataset(ds)
    self.assertEqual(len(sampled_ds), 10)
    # Compare against the full expected sequence so an empty iteration
    # cannot pass silently (a bare per-item loop would).
    self.assertEqual(list(sampled_ds), list(arrs))

    # RandomSampler: the order is random, so only the sample count and the
    # value range can be pinned down.
    random_sampled_ds = SamplerIterableDataset(ds, sampler=RandomSampler, replacement=True)
    self.assertEqual(len(random_sampled_ds), 10)
    random_samples = list(random_sampled_ds)
    self.assertEqual(len(random_samples), 10)
    for x in random_samples:
        self.assertIn(x, arrs)

    # Requires `__len__` to build SamplerDataset
    ds_nolen = IterDatasetWithoutLen(arrs)
    with self.assertRaises(AssertionError):
        SamplerIterableDataset(ds_nolen)


# Run this file's tests when it is executed directly as a script.
if __name__ == '__main__':
    run_tests()
4 changes: 2 additions & 2 deletions torch/utils/data/__init__.py
Expand Up @@ -3,11 +3,11 @@
Subset, random_split)
from .distributed import DistributedSampler
from .dataloader import DataLoader, _DatasetKind, get_worker_info
from .datasets import BatchIterableDataset, CollateIterableDataset
from .datasets import (BatchIterableDataset, CollateIterableDataset, SamplerIterableDataset)

__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler',
'SubsetRandomSampler', 'WeightedRandomSampler', 'BatchSampler',
'DistributedSampler', 'Dataset', 'IterableDataset', 'TensorDataset',
'ConcatDataset', 'ChainDataset', 'BufferedShuffleDataset', 'Subset',
'random_split', 'DataLoader', '_DatasetKind', 'get_worker_info',
'BatchIterableDataset', 'CollateIterableDataset']
'BatchIterableDataset', 'CollateIterableDataset', 'SamplerIterableDataset']
3 changes: 2 additions & 1 deletion torch/utils/data/datasets/__init__.py
@@ -1,7 +1,8 @@
from .batchdataset import BatchIterableDataset
from .collatedataset import CollateIterableDataset
from .samplerdataset import SamplerIterableDataset
from .listdirfilesdataset import ListDirFilesIterableDataset
from .loadfilesfromdiskdataset import LoadFilesFromDiskIterableDataset

__all__ = ['BatchIterableDataset', 'CollateIterableDataset', 'ListDirFilesIterableDataset',
'LoadFilesFromDiskIterableDataset']
'LoadFilesFromDiskIterableDataset', 'SamplerIterableDataset']
38 changes: 38 additions & 0 deletions torch/utils/data/datasets/samplerdataset.py
@@ -0,0 +1,38 @@
from torch.utils.data import IterableDataset, Sampler, SequentialSampler
from typing import TypeVar, Type, Iterator, Sized

T_co = TypeVar('T_co', covariant=True)


class SamplerIterableDataset(IterableDataset[T_co]):
    r""" :class:`SamplerIterableDataset`.

    IterableDataset whose iteration is driven by a sampler built over the
    input dataset. The sampler is constructed internally as
    ``sampler(data_source=dataset, **kwargs)`` and iterating this dataset
    yields exactly what iterating that sampler yields.

    args:
        dataset: IterableDataset sampled from. Must implement ``__len__``.
        sampler: Sampler class used to generate sample elements from the
            input dataset. Default is :class:`SequentialSampler`.
        **kwargs: Extra keyword arguments forwarded unchanged to the
            ``sampler`` constructor (e.g. ``replacement=True`` when
            ``sampler`` is ``RandomSampler``).

    raises:
        AssertionError: if ``dataset`` does not implement ``__len__``.
    """
    dataset: IterableDataset
    sampler: Sampler

    def __init__(self,
                 dataset: IterableDataset,
                 *,
                 sampler: Type[Sampler] = SequentialSampler,
                 **kwargs
                 ) -> None:
        # Raised explicitly instead of via an `assert` statement so the check
        # is not stripped under `python -O`; the exception type and message
        # are kept so existing callers catching AssertionError still work.
        if not isinstance(dataset, Sized):
            raise AssertionError(
                "Sampler class requires input dataset implemented `__len__`")
        super().__init__()
        self.dataset = dataset
        # https://github.com/python/mypy/pull/9629 will solve
        self.sampler = sampler(data_source=self.dataset, **kwargs)  # type: ignore

    def __iter__(self) -> Iterator[T_co]:
        """Return a fresh iterator over the underlying sampler."""
        return iter(self.sampler)

    def __len__(self) -> int:
        """Return the sampler's length.

        Raises NotImplementedError when the sampler does not implement
        ``__len__``.
        """
        # `len()` never returns a negative value (it raises instead), so a
        # single call is sufficient once the sampler is known to be Sized.
        if isinstance(self.sampler, Sized):
            return len(self.sampler)
        raise NotImplementedError