45 changes: 44 additions & 1 deletion test/test_dataset.py
@@ -5,7 +5,8 @@
from torch.testing._internal.common_utils import (TestCase, run_tests)
from torch.utils.data import IterableDataset
from torch.utils.data.datasets import \
-    (CollateIterableDataset, ListDirFilesIterableDataset, LoadFilesFromDiskIterableDataset)
+    (CollateIterableDataset, BatchIterableDataset, ListDirFilesIterableDataset,
+     LoadFilesFromDiskIterableDataset)


def create_temp_dir_and_files():
@@ -97,6 +98,48 @@ def _collate_fn(batch):
            y = next(ds_nolen_iter)
            self.assertEqual(x, torch.tensor(y))

    def test_batch_dataset(self):
        arrs = range(10)
        ds = IterDatasetWithLen(arrs)
        with self.assertRaises(AssertionError):
            batch_ds0 = BatchIterableDataset(ds, batch_size=0)

        # By default, the last (partial) batch is kept rather than dropped
        batch_ds1 = BatchIterableDataset(ds, batch_size=3)
        self.assertEqual(len(batch_ds1), 4)
        batch_iter = iter(batch_ds1)
        value = 0
        for i in range(len(batch_ds1)):
            batch = next(batch_iter)
            if i == 3:
                self.assertEqual(len(batch), 1)
                self.assertEqual(batch, [9])
            else:
                self.assertEqual(len(batch), 3)
                for x in batch:
                    self.assertEqual(x, value)
                    value += 1

        # Drop the last batch
        batch_ds2 = BatchIterableDataset(ds, batch_size=3, drop_last=True)
        self.assertEqual(len(batch_ds2), 3)
        value = 0
        for batch in batch_ds2:
            self.assertEqual(len(batch), 3)
            for x in batch:
                self.assertEqual(x, value)
                value += 1

        batch_ds3 = BatchIterableDataset(ds, batch_size=2)
        self.assertEqual(len(batch_ds3), 5)
        batch_ds4 = BatchIterableDataset(ds, batch_size=2, drop_last=True)
        self.assertEqual(len(batch_ds4), 5)

        ds_nolen = IterDatasetWithoutLen(arrs)
        batch_ds_nolen = BatchIterableDataset(ds_nolen, batch_size=5)
        with self.assertRaises(NotImplementedError):
            len(batch_ds_nolen)


if __name__ == '__main__':
    run_tests()
4 changes: 2 additions & 2 deletions torch/utils/data/__init__.py
@@ -3,11 +3,11 @@
Subset, random_split)
from .distributed import DistributedSampler
from .dataloader import DataLoader, _DatasetKind, get_worker_info
-from .datasets import CollateIterableDataset
+from .datasets import BatchIterableDataset, CollateIterableDataset

__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler',
           'SubsetRandomSampler', 'WeightedRandomSampler', 'BatchSampler',
           'DistributedSampler', 'Dataset', 'IterableDataset', 'TensorDataset',
           'ConcatDataset', 'ChainDataset', 'BufferedShuffleDataset', 'Subset',
           'random_split', 'DataLoader', '_DatasetKind', 'get_worker_info',
-           'CollateIterableDataset']
+           'BatchIterableDataset', 'CollateIterableDataset']
4 changes: 3 additions & 1 deletion torch/utils/data/datasets/__init__.py
@@ -1,5 +1,7 @@
+from .batchdataset import BatchIterableDataset
from .collatedataset import CollateIterableDataset
from .listdirfilesdataset import ListDirFilesIterableDataset
from .loadfilesfromdiskdataset import LoadFilesFromDiskIterableDataset

-__all__ = ['CollateIterableDataset', 'ListDirFilesIterableDataset', 'LoadFilesFromDiskIterableDataset']
+__all__ = ['BatchIterableDataset', 'CollateIterableDataset', 'ListDirFilesIterableDataset',
+           'LoadFilesFromDiskIterableDataset']
57 changes: 57 additions & 0 deletions torch/utils/data/datasets/batchdataset.py
@@ -0,0 +1,57 @@
from torch.utils.data import IterableDataset
from typing import TypeVar, Optional, Iterator, List, Sized

T_co = TypeVar('T_co', covariant=True)


class BatchIterableDataset(IterableDataset[List[T_co]]):
r""" :class:`BatchIterableDataset`.

IterableDataset to create mini-batches of data. An outer dimension will be added as
`batch_size` if `drop_last` is set to `True`, or `length % batch_size` for the
last batch if `drop_last` is set to `False`.
args:
dataset: IterableDataset being batched
batch_size: The size of each batch
drop_last: Option to drop the last batch if it's not full
"""
    dataset: IterableDataset[T_co]
    batch_size: int
    drop_last: bool
    length: Optional[int]

    def __init__(self,
                 dataset: IterableDataset[T_co],
                 *,
                 batch_size: int,
                 drop_last: bool = False,
                 ) -> None:
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        super(BatchIterableDataset, self).__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.length = None

    def __iter__(self) -> Iterator[List[T_co]]:
        batch: List[T_co] = []
        for x in self.dataset:
            batch.append(x)
            if len(batch) == self.batch_size:
                yield batch
                # Start a fresh list rather than clearing in place, so batches
                # already handed to the consumer are not mutated afterwards.
                batch = []
        # Yield the trailing partial batch unless drop_last is set.
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        if self.length is not None:
            return self.length
        if isinstance(self.dataset, Sized) and len(self.dataset) >= 0:
            if self.drop_last:
                self.length = len(self.dataset) // self.batch_size
            else:
                # Ceiling division: a trailing partial batch still counts toward the length.
                self.length = (len(self.dataset) + self.batch_size - 1) // self.batch_size
            return self.length
        raise NotImplementedError
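
For reference, a minimal usage sketch of `BatchIterableDataset`, assuming the patch above is applied so the new module is importable as shown; `RangeIterableDataset` is a toy helper defined only for illustration and is not part of the change.

from torch.utils.data import IterableDataset
from torch.utils.data.datasets import BatchIterableDataset


class RangeIterableDataset(IterableDataset):
    """Toy IterableDataset yielding 0..n-1; illustrative only."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

    def __len__(self):
        return self.n


ds = RangeIterableDataset(10)

# Keep the trailing partial batch (default): 4 batches of sizes 3, 3, 3, 1.
for batch in BatchIterableDataset(ds, batch_size=3):
    print(batch)

# Drop the trailing partial batch: len() is 10 // 3 == 3.
print(len(BatchIterableDataset(ds, batch_size=3, drop_last=True)))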