[WIP][DataLoader] Prototype of SamplerIterableDataset
ghstack-source-id: e2099d04d07f6da90f1a7a5da039a24f12f4d56a
Pull Request resolved: #49363
ejguan committed Dec 14, 2020
1 parent be5f2f0 commit 62df863
Showing 4 changed files with 54 additions and 4 deletions.
18 changes: 17 additions & 1 deletion test/test_dataset.py
@@ -1,6 +1,7 @@
import torch
from torch.utils.data import IterableDataset
from torch.utils.data.datasets import (CollateIterableDataset, BatchIterableDataset)
from torch.utils.data.datasets import (
    CollateIterableDataset, BatchIterableDataset, SamplerIterableDataset)
from torch.testing._internal.common_utils import (TestCase, run_tests)


@@ -94,6 +95,21 @@ def test_batch_dataset(self):
        with self.assertRaises(NotImplementedError):
            len(batch_ds_nolen)

    def test_sampler_dataset(self):
        arrs = range(10)
        ds = IterDatasetWithLen(arrs)
        sampled_ds = SamplerIterableDataset(ds)
        self.assertEqual(len(sampled_ds), 10)
        i = 0
        for x in sampled_ds:
            self.assertEqual(x, i)
            i += 1

        ds_nolen = IterDatasetWithoutLen(arrs)
        sampled_ds = SamplerIterableDataset(ds_nolen)
        with self.assertRaises(NotImplementedError):
            len(sampled_ds)


if __name__ == '__main__':
run_tests()
4 changes: 2 additions & 2 deletions torch/utils/data/__init__.py
@@ -3,11 +3,11 @@
Subset, random_split)
from .distributed import DistributedSampler
from .dataloader import DataLoader, _DatasetKind, get_worker_info
from .datasets import BatchIterableDataset, CollateIterableDataset
from .datasets import (BatchIterableDataset, CollateIterableDataset, SamplerIterableDataset)

__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler',
'SubsetRandomSampler', 'WeightedRandomSampler', 'BatchSampler',
'DistributedSampler', 'Dataset', 'IterableDataset', 'TensorDataset',
'ConcatDataset', 'ChainDataset', 'BufferedShuffleDataset', 'Subset',
'random_split', 'DataLoader', '_DatasetKind', 'get_worker_info',
'BatchIterableDataset', 'CollateIterableDataset']
'BatchIterableDataset', 'CollateIterableDataset', 'SamplerIterableDataset']
3 changes: 2 additions & 1 deletion torch/utils/data/datasets/__init__.py
@@ -1,4 +1,5 @@
from .batchdataset import BatchIterableDataset
from .collatedataset import CollateIterableDataset
from .samplerdataset import SamplerIterableDataset

__all__ = ['BatchIterableDataset', 'CollateIterableDataset']
__all__ = ['BatchIterableDataset', 'CollateIterableDataset', 'SamplerIterableDataset']
33 changes: 33 additions & 0 deletions torch/utils/data/datasets/samplerdataset.py
@@ -0,0 +1,33 @@
from torch.utils.data import IterableDataset, Sampler, SequentialSampler
from typing import TypeVar, Iterator, Sized

T_co = TypeVar('T_co', covariant=True)


class SamplerIterableDataset(IterableDataset[T_co]):
r""" Prototype of :class:`SamplerIterableDataset`.
IterableDataset to generate samples elements.
args:
dataset: IterableDataset being collated
sampler: Sampler to genereate sample elements from input dataset.
Default is :class:`SequentialSampler` for IterableDataset
"""
    def __init__(self,
                 dataset: IterableDataset[T_co],
                 *,
                 sampler: Sampler = SequentialSampler,
                 ) -> None:
        self.dataset = dataset
        self.sampler = sampler(self.dataset)

    def __iter__(self) -> Iterator[T_co]:
        return iter(self.sampler)

    def __len__(self) -> int:
        if isinstance(self.dataset, Sized) and \
           isinstance(self.sampler, Sized) and \
           len(self.sampler) >= 0:
            return len(self.sampler)
        else:
            raise NotImplementedError
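
A minimal usage sketch, assuming the modules added in this diff are importable as declared in torch/utils/data/datasets/__init__.py. The RangeIterableDataset helper below is hypothetical and exists only for illustration; any IterableDataset that implements __len__ behaves the same way. Note that the sampler argument is a sampler class rather than an instance, and that SequentialSampler and RandomSampler yield indices of the dataset, not its elements.

from torch.utils.data import IterableDataset, RandomSampler
from torch.utils.data.datasets import SamplerIterableDataset


class RangeIterableDataset(IterableDataset):
    # Hypothetical Sized IterableDataset, used only for illustration.
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

    def __len__(self):
        return self.n


ds = RangeIterableDataset(10)

# Default: SequentialSampler(ds) yields the indices 0..len(ds)-1 in order,
# which is the behavior test_sampler_dataset above checks against.
sequential = SamplerIterableDataset(ds)
print(list(sequential))   # [0, 1, 2, ..., 9]
print(len(sequential))    # 10

# The sampler class is constructed with the dataset inside __init__;
# RandomSampler yields a permutation of the indices.
shuffled = SamplerIterableDataset(ds, sampler=RandomSampler)
print(list(shuffled))     # e.g. [3, 7, 0, 9, ...]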
