[DataLoader] Move BufferedShuffle from Dataset to DataPipe (#52141)
Summary:
Pull Request resolved: #52141

Remove BufferedShuffleDataset, as it is not used anywhere within PyTorch (no usage found on GitHub based on a search) and it was not included in the PyTorch 1.7.1 release. The buffered shuffle functionality moves to the DataPipe API as ShuffleIterDataPipe, exposed as dp.iter.Shuffle.
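
For illustration only (not part of this diff), a minimal sketch of the migration; the RangePipe helper, the buffer size, and the seed value are hypothetical placeholders:

import random
import torch.utils.data.datapipes as dp
from torch.utils.data import DataLoader, IterDataPipe

class RangePipe(IterDataPipe):
    # Hypothetical pipe yielding the integers 0 .. n-1.
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

# Before this commit:
#   shuffled = BufferedShuffleDataset(RangePipe(20), buffer_size=5)
# After this commit (buffer_size is keyword-only on the new DataPipe):
shuffled = dp.iter.Shuffle(RangePipe(20), buffer_size=5)

# Single-process mode (num_workers=0): seed the global RNG in the main process.
random.seed(123)
print(list(DataLoader(shuffled, num_workers=0)))

# Multi-process mode (num_workers > 0): seed each worker via worker_init_fn.
def init_fn(worker_id):
    random.seed(123)

print(list(DataLoader(shuffled, num_workers=1, worker_init_fn=init_fn)))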

Test Plan: Imported from OSS

Reviewed By: H-Huang

Differential Revision: D26710940

Pulled By: ejguan

fbshipit-source-id: 90023b4bfb105d6aa392753082100f9181ecebd0
ejguan authored and facebook-github-bot committed Mar 1, 2021
1 parent f2657d2 commit 89b1053
Showing 7 changed files with 96 additions and 105 deletions.
1 change: 0 additions & 1 deletion docs/source/data.rst
@@ -405,7 +405,6 @@ Example::
.. autoclass:: TensorDataset
.. autoclass:: ConcatDataset
.. autoclass:: ChainDataset
.. autoclass:: BufferedShuffleDataset
.. autoclass:: Subset
.. autofunction:: torch.utils.data.get_worker_info
.. autofunction:: torch.utils.data.random_split
39 changes: 1 addition & 38 deletions test/test_dataloader.py
@@ -12,10 +12,8 @@
import itertools
import warnings
import tempfile
import random
from torch import multiprocessing as mp
from torch.utils.data import (_utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset,
ChainDataset, BufferedShuffleDataset)
from torch.utils.data import _utils, Dataset, IterableDataset, TensorDataset, DataLoader, ConcatDataset, ChainDataset
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.dataset import random_split
from torch._utils import ExceptionWrapper
@@ -724,10 +722,6 @@ def init_fn(worker_id):
torch.manual_seed(12345)


def shuffle_ds_init_fn(worker_id):
random.seed(123)


# used with test_error_in_init
class ErrorIterableDataset(IterableDataset):
def __iter__(self):
@@ -1245,37 +1239,6 @@ def test_chain_iterable_style_dataset(self):
with self.assertRaisesRegex(AssertionError, "ChainDataset only supports IterableDataset"):
list(iter(ChainDataset([dataset1, self.dataset])))

def test_buffer_shuffle_dataset(self):
dataset = CountingIterableDataset(20)
expected = list(range(20))
buffer_sizes = [5, 20, 25]
for num_workers in [0, 1]:
# Buffer Size <= 1: Not shuffled dataset
fetched_nos = list(self._get_data_loader(BufferedShuffleDataset(dataset, 1), num_workers=num_workers))
self.assertEqual(len(fetched_nos), len(expected))
for e, d in zip(expected, fetched_nos):
self.assertIsInstance(d, torch.Tensor)
self.assertEqual(e, d)
# Buffer Size > 1: Shuffled dataset
for buffer_size in buffer_sizes:
fetched = sorted(list(self._get_data_loader(BufferedShuffleDataset(dataset, buffer_size), num_workers=num_workers)))
self.assertEqual(len(fetched), len(expected))
for e, d in zip(expected, fetched):
self.assertIsInstance(d, torch.Tensor)
self.assertEqual(e, d)
# Random Seed for single process
random.seed(123)
fetched_seed1 = list(self._get_data_loader(BufferedShuffleDataset(dataset, buffer_size), num_workers=num_workers,
worker_init_fn=shuffle_ds_init_fn))
random.seed(123)
fetched_seed2 = list(self._get_data_loader(BufferedShuffleDataset(dataset, buffer_size), num_workers=num_workers,
worker_init_fn=shuffle_ds_init_fn))
self.assertEqual(len(fetched_seed1), len(fetched_seed2))
for d1, d2 in zip(fetched_seed1, fetched_seed2):
self.assertIsInstance(d1, torch.Tensor)
self.assertIsInstance(d2, torch.Tensor)
self.assertEqual(d1, d2)

def test_multiprocessing_contexts(self):
reference = [
torch.arange(3),
34 changes: 32 additions & 2 deletions test/test_datapipe.py
@@ -12,11 +12,10 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.utils.data import IterDataPipe, RandomSampler
from torch.utils.data import IterDataPipe, RandomSampler, DataLoader
from typing import List, Tuple, Dict, Any, Type

import torch.utils.data.datapipes as dp

from torch.utils.data.datapipes.utils.decoder import (
basichandlers as decoder_basichandlers,
imagehandler as decoder_imagehandler)
@@ -250,6 +249,10 @@ def _fake_fn(data, *args, **kwargs):
def _fake_filter_fn(data, *args, **kwargs):
return data >= 5

def _worker_init_fn(worker_id):
random.seed(123)


class TestFunctionalIterDataPipe(TestCase):

def test_picklable(self):
@@ -446,6 +449,33 @@ def test_sampler_datapipe(self):
with self.assertRaises(AssertionError):
sampled_dp = dp.iter.Sampler(input_dp_nolen)

def test_shuffle_datapipe(self):
exp = list(range(20))
input_ds = IDP(exp)

with self.assertRaises(AssertionError):
shuffle_dp = dp.iter.Shuffle(input_ds, buffer_size=0)


for bs in (5, 20, 25):
shuffle_dp = dp.iter.Shuffle(input_ds, buffer_size=bs)
self.assertEqual(len(shuffle_dp), len(input_ds))

random.seed(123)
res = list(d for d in shuffle_dp)
self.assertEqual(sorted(res), exp)

# Test Deterministic
for num_workers in (0, 1):
random.seed(123)
dl = DataLoader(shuffle_dp, num_workers=num_workers, worker_init_fn=_worker_init_fn)
dl_res = list(d for d in dl)
self.assertEqual(res, dl_res)

shuffle_dp_nl = dp.iter.Shuffle(IDP_NoLen(range(20)), buffer_size=5)
with self.assertRaises(NotImplementedError):
len(shuffle_dp_nl)

@skipIfNoTorchVision
def test_transforms_datapipe(self):
torch.set_default_dtype(torch.float)
6 changes: 3 additions & 3 deletions torch/utils/data/__init__.py
@@ -1,5 +1,5 @@
from .sampler import Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler, BatchSampler
from .dataset import (Dataset, IterableDataset, TensorDataset, ConcatDataset, ChainDataset, BufferedShuffleDataset,
from .dataset import (Dataset, IterableDataset, TensorDataset, ConcatDataset, ChainDataset,
Subset, random_split)
from .dataset import IterableDataset as IterDataPipe
from .dataset import functional_datapipe
@@ -10,6 +10,6 @@
__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler',
'SubsetRandomSampler', 'WeightedRandomSampler', 'BatchSampler',
'DistributedSampler', 'Dataset', 'IterableDataset', 'TensorDataset',
'ConcatDataset', 'ChainDataset', 'BufferedShuffleDataset', 'Subset',
'random_split', 'DataLoader', '_DatasetKind', 'get_worker_info',
'ConcatDataset', 'ChainDataset', 'Subset', 'random_split',
'DataLoader', '_DatasetKind', 'get_worker_info',
'IterDataPipe', 'functional_datapipe']
4 changes: 2 additions & 2 deletions torch/utils/data/datapipes/iter/__init__.py
@@ -9,12 +9,12 @@
from torch.utils.data.datapipes.iter.callable import \
(MapIterDataPipe as Map, CollateIterDataPipe as Collate, TransformsIterDataPipe as Transforms)
from torch.utils.data.datapipes.iter.combinatorics import \
(SamplerIterDataPipe as Sampler)
(SamplerIterDataPipe as Sampler, ShuffleIterDataPipe as Shuffle)
from torch.utils.data.datapipes.iter.grouping import \
(BatchIterDataPipe as Batch, BucketBatchIterDataPipe as BucketBatch)
from torch.utils.data.datapipes.iter.selecting import \
(FilterIterDataPipe as Filter)


__all__ = ['ListDirFiles', 'LoadFilesFromDisk', 'ReadFilesFromTar', 'ReadFilesFromZip', 'RoutedDecoder', 'GroupByKey',
'Batch', 'BucketBatch', 'Collate', 'Filter', 'Map', 'Sampler', 'Transforms']
'Batch', 'BucketBatch', 'Collate', 'Filter', 'Map', 'Sampler', 'Shuffle', 'Transforms']
59 changes: 58 additions & 1 deletion torch/utils/data/datapipes/iter/combinatorics.py
@@ -1,5 +1,7 @@
import random

from torch.utils.data import IterDataPipe, Sampler, SequentialSampler
from typing import TypeVar, Type, Iterator, Sized, Optional, Tuple, Dict
from typing import TypeVar, Type, Iterator, Sized, Optional, Tuple, Dict, List

T_co = TypeVar('T_co', covariant=True)

@@ -39,3 +41,58 @@ def __len__(self) -> int:
if isinstance(self.sampler, Sized) and len(self.sampler) >= 0:
return len(self.sampler)
raise NotImplementedError


class ShuffleIterDataPipe(IterDataPipe[T_co]):
r""" :class:`ShuffleIterDataPipe`
Iterable DataPipe that shuffles the input DataPipe with a buffer. The buffer
of size `buffer_size` is filled with elements from the datapipe first; then
each item is yielded from the buffer by reservoir sampling via the iterator.
`buffer_size` is required to be larger than 0. For `buffer_size == 1`, the
datapipe is not shuffled. In order to fully shuffle all elements from the datapipe,
`buffer_size` is required to be greater than or equal to the size of the datapipe.
When it is used with :class:`~torch.utils.data.DataLoader`, the way to set up
the random seed differs based on :attr:`num_workers`.
For single-process mode (:attr:`num_workers == 0`), the random seed is set before
the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
mode (:attr:`num_workers > 0`), `worker_init_fn` is used to set up a random seed
for each worker process.

Args:
    datapipe: The IterDataPipe being shuffled
    buffer_size: The buffer size for shuffling
"""
datapipe: IterDataPipe[T_co]
buffer_size: int
_buffer: List[T_co]

def __init__(self,
datapipe: IterDataPipe[T_co],
*,
buffer_size: int) -> None:
super().__init__()
assert buffer_size > 0, "buffer_size should be larger than 0"
self.datapipe = datapipe
self.buffer_size = buffer_size
self._buffer = []

def __iter__(self) -> Iterator[T_co]:
for x in self.datapipe:
if len(self._buffer) == self.buffer_size:
idx = random.randint(0, self.buffer_size - 1)
yield self._buffer[idx]
self._buffer[idx] = x
else:
self._buffer.append(x)
random.shuffle(self._buffer)
while self._buffer:
yield self._buffer.pop()

def __len__(self) -> int:
if isinstance(self.datapipe, Sized) and len(self.datapipe) >= 0:
return len(self.datapipe)
raise NotImplementedError
58 changes: 0 additions & 58 deletions torch/utils/data/dataset.py
@@ -1,5 +1,4 @@
import bisect
import random
import warnings
import functools

@@ -287,63 +286,6 @@ def __len__(self):
return total


class BufferedShuffleDataset(IterableDataset[T_co]):
r"""Dataset shuffled from the original dataset.
This class is useful to shuffle an existing instance of an IterableDataset.
The buffer with `buffer_size` is filled with the items from the dataset first. Then,
each item will be yielded from the buffer by reservoir sampling via iterator.
`buffer_size` is required to be larger than 0. For `buffer_size == 1`, the
dataset is not shuffled. In order to fully shuffle the whole dataset, `buffer_size`
is required to be greater than or equal to the size of dataset.
When it is used with :class:`~torch.utils.data.DataLoader`, each item in the
dataset will be yielded from the :class:`~torch.utils.data.DataLoader` iterator.
And, the method to set up a random seed is different based on :attr:`num_workers`.
For single-process mode (:attr:`num_workers == 0`), the random seed is required to
be set before the :class:`~torch.utils.data.DataLoader` in the main process.
>>> ds = BufferedShuffleDataset(dataset)
>>> random.seed(...)
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
For multi-process mode (:attr:`num_workers > 0`), the random seed is set by a callable
function in each worker.
>>> ds = BufferedShuffleDataset(dataset)
>>> def init_fn(worker_id):
... random.seed(...)
>>> print(list(torch.utils.data.DataLoader(ds, ..., num_workers=n, worker_init_fn=init_fn)))
Args:
dataset (IterableDataset): The original IterableDataset.
buffer_size (int): The buffer size for shuffling.
"""
dataset: IterableDataset[T_co]
buffer_size: int

def __init__(self, dataset: IterableDataset[T_co], buffer_size: int) -> None:
super(BufferedShuffleDataset, self).__init__()
assert buffer_size > 0, "buffer_size should be larger than 0"
self.dataset = dataset
self.buffer_size = buffer_size

def __iter__(self) -> Iterator[T_co]:
buf: List[T_co] = []
for x in self.dataset:
if len(buf) == self.buffer_size:
idx = random.randint(0, self.buffer_size - 1)
yield buf[idx]
buf[idx] = x
else:
buf.append(x)
random.shuffle(buf)
while buf:
yield buf.pop()


class Subset(Dataset[T_co]):
r"""
Subset of a dataset at specified indices.
