diff --git a/test/test_datapipe.py b/test/test_datapipe.py index ec228c0086f1..00bdeb1306c0 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -238,6 +238,8 @@ def __len__(self): def _fake_fn(self, data, *args, **kwargs): return data +def _fake_filter_fn(self, data, *args, **kwargs): + return data >= 5 class TestFunctionalIterDataPipe(TestCase): @@ -248,6 +250,7 @@ def test_picklable(self): (dp.iter.Callable, IDP(arr), [0], {'fn': _fake_fn, 'test': True}), (dp.iter.Collate, IDP(arr), [], {}), (dp.iter.Collate, IDP(arr), [0], {'collate_fn': _fake_fn, 'test': True}), + (dp.iter.Filter, IDP(arr), [0], {'filter_fn': _fake_filter_fn, 'test': True}), ] for dpipe, input_dp, args, kargs in picklable_datapipes: p = pickle.dumps(dpipe(input_dp, *args, **kargs)) # type: ignore @@ -255,6 +258,7 @@ def test_picklable(self): unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, List, Dict[str, Any]]] = [ (dp.iter.Callable, IDP(arr), [], {'fn': lambda x: x}), (dp.iter.Collate, IDP(arr), [], {'collate_fn': lambda x: x}), + (dp.iter.Filter, IDP(arr), [], {'filter_fn': lambda x: x >= 5}), ] for dpipe, input_dp, args, kargs in unpicklable_datapipes: with self.assertRaises(AttributeError): @@ -381,6 +385,32 @@ def _sort_fn(data): _helper(batch_size=7, drop_last=False, bucket_size_mul=5, sort_key=_sort_fn) _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=_sort_fn) + def test_filter_datapipe(self): + input_ds = IDP(range(10)) + + def _filter_fn(data, val, clip=False): + if clip: + return data >= val + return True + + filter_dp = dp.iter.Filter(input_ds, 5, filter_fn=_filter_fn) + for data, exp in zip(filter_dp, range(10)): + self.assertEqual(data, exp) + + filter_dp = dp.iter.Filter(input_ds, filter_fn=_filter_fn, val=5, clip=True) + for data, exp in zip(filter_dp, range(5, 10)): + self.assertEqual(data, exp) + + with self.assertRaises(NotImplementedError): + len(filter_dp) + + def _non_bool_fn(data): + return 1 + + filter_dp = 
dp.iter.Filter(input_ds, filter_fn=_non_bool_fn) + with self.assertRaises(ValueError): + temp = list(d for d in filter_dp) + def test_sampler_datapipe(self): arrs = range(10) input_dp = IDP(arrs) diff --git a/torch/utils/data/datapipes/iter/__init__.py b/torch/utils/data/datapipes/iter/__init__.py index dcfe3474aaf3..76030963f37a 100755 --- a/torch/utils/data/datapipes/iter/__init__.py +++ b/torch/utils/data/datapipes/iter/__init__.py @@ -7,8 +7,13 @@ # Functional DataPipe from torch.utils.data.datapipes.iter.batch import BatchIterDataPipe as Batch, BucketBatchIterDataPipe as BucketBatch -from torch.utils.data.datapipes.iter.callable import CallableIterDataPipe as Callable, CollateIterDataPipe as Collate -from torch.utils.data.datapipes.iter.sampler import SamplerIterDataPipe as Sampler +from torch.utils.data.datapipes.iter.callable import \ + (CallableIterDataPipe as Callable, CollateIterDataPipe as Collate) +from torch.utils.data.datapipes.iter.selecting import \ + (FilterIterDataPipe as Filter) +from torch.utils.data.datapipes.iter.sampler import \ + (SamplerIterDataPipe as Sampler) + __all__ = ['ListDirFiles', 'LoadFilesFromDisk', 'ReadFilesFormTar', 'ReadFilesFromZip', 'RoutedDecoder', 'GroupByKey', - 'Batch', 'BucketBatch', 'Callable', 'Collate', 'Sampler'] + 'Batch', 'BucketBatch', 'Callable', 'Collate', 'Filter', 'Sampler'] diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py new file mode 100644 index 000000000000..78a2a74d31f3 --- /dev/null +++ b/torch/utils/data/datapipes/iter/selecting.py @@ -0,0 +1,36 @@ +from torch.utils.data import IterDataPipe +from typing import Callable, TypeVar, Iterator + +from .callable import CallableIterDataPipe + +T_co = TypeVar('T_co', covariant=True) + + +class FilterIterDataPipe(CallableIterDataPipe[T_co]): + r""" :class:`FilterIterDataPipe`. + + Iterable DataPipe to filter elements from datapipe according to filter_fn. 
+ args: + datapipe: Iterable DataPipe being filtered + filter_fn: Customized function mapping an element to a boolean. + """ + def __init__(self, + datapipe: IterDataPipe[T_co], + *args, + filter_fn: Callable[..., bool], + **kwargs, + ) -> None: + super().__init__(datapipe, *args, fn=filter_fn, **kwargs) + + def __iter__(self) -> Iterator[T_co]: + res: bool + for data in self.datapipe: + res = self.fn(data, *self.args, **self.kwargs) + if not isinstance(res, bool): + raise ValueError("Boolean output is required for " + "`filter_fn` of FilterIterDataPipe") + if res: + yield data + + def __len__(self): + raise(NotImplementedError)