Skip to content

Commit

Permalink
[DataLoader] Implement FilterIterDataPipe (#51783)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #51783

Test Plan: Imported from OSS

Reviewed By: glaringlee

Differential Revision: D26277688

Pulled By: ejguan

fbshipit-source-id: 25ed7da9da88c030b29627142c2f04fed26cdcda
  • Loading branch information
ejguan authored and facebook-github-bot committed Feb 10, 2021
1 parent e964d77 commit 104371e
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 3 deletions.
30 changes: 30 additions & 0 deletions test/test_datapipe.py
Expand Up @@ -238,6 +238,8 @@ def __len__(self):
def _fake_fn(self, data, *args, **kwargs):
return data

def _fake_filter_fn(self, data, *args, **kwargs):
return data >= 5

class TestFunctionalIterDataPipe(TestCase):

Expand All @@ -248,13 +250,15 @@ def test_picklable(self):
(dp.iter.Callable, IDP(arr), [0], {'fn': _fake_fn, 'test': True}),
(dp.iter.Collate, IDP(arr), [], {}),
(dp.iter.Collate, IDP(arr), [0], {'collate_fn': _fake_fn, 'test': True}),
(dp.iter.Filter, IDP(arr), [0], {'filter_fn': _fake_filter_fn, 'test': True}),
]
for dpipe, input_dp, args, kargs in picklable_datapipes:
p = pickle.dumps(dpipe(input_dp, *args, **kargs)) # type: ignore

unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, List, Dict[str, Any]]] = [
(dp.iter.Callable, IDP(arr), [], {'fn': lambda x: x}),
(dp.iter.Collate, IDP(arr), [], {'collate_fn': lambda x: x}),
(dp.iter.Filter, IDP(arr), [], {'filter_fn': lambda x: x >= 5}),
]
for dpipe, input_dp, args, kargs in unpicklable_datapipes:
with self.assertRaises(AttributeError):
Expand Down Expand Up @@ -381,6 +385,32 @@ def _sort_fn(data):
_helper(batch_size=7, drop_last=False, bucket_size_mul=5, sort_key=_sort_fn)
_helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=_sort_fn)

def test_filter_datapipe(self):
    """Exercise the Filter datapipe: pass-through, clipping, length, bool check."""
    input_ds = IDP(range(10))

    def _filter_fn(data, val, clip=False):
        # Keep every element unless clipping against ``val`` is requested.
        if clip:
            return data >= val
        return True

    # NOTE: the pipes are drained with comprehensions / plain loops rather
    # than ``list(pipe)``. ``list()`` probes ``len()`` for a size hint and only
    # tolerates TypeError, while this pipe's ``__len__`` raises
    # NotImplementedError.
    #
    # The whole output is compared against the expected list; zipping the pipe
    # with the expected range would silently pass if the pipe yielded too few
    # (or zero) elements.

    # Extra positional argument (5) reaches ``_filter_fn`` as ``val``;
    # without ``clip`` nothing is dropped.
    filter_dp = dp.iter.Filter(input_ds, 5, filter_fn=_filter_fn)
    self.assertEqual([d for d in filter_dp], list(range(10)))

    # Keyword arguments are forwarded too; with ``clip`` only values >= 5 remain.
    filter_dp = dp.iter.Filter(input_ds, filter_fn=_filter_fn, val=5, clip=True)
    self.assertEqual([d for d in filter_dp], list(range(5, 10)))

    # The number of surviving elements is not known up front.
    with self.assertRaises(NotImplementedError):
        len(filter_dp)

    def _non_bool_fn(data):
        # Truthy, but not a ``bool``.
        return 1

    filter_dp = dp.iter.Filter(input_ds, filter_fn=_non_bool_fn)
    with self.assertRaises(ValueError):
        for _ in filter_dp:
            pass

def test_sampler_datapipe(self):
arrs = range(10)
input_dp = IDP(arrs)
Expand Down
11 changes: 8 additions & 3 deletions torch/utils/data/datapipes/iter/__init__.py
Expand Up @@ -7,8 +7,13 @@

# Functional DataPipe
from torch.utils.data.datapipes.iter.batch import BatchIterDataPipe as Batch, BucketBatchIterDataPipe as BucketBatch
from torch.utils.data.datapipes.iter.callable import CallableIterDataPipe as Callable, CollateIterDataPipe as Collate
from torch.utils.data.datapipes.iter.sampler import SamplerIterDataPipe as Sampler
from torch.utils.data.datapipes.iter.callable import \
(CallableIterDataPipe as Callable, CollateIterDataPipe as Collate)
from torch.utils.data.datapipes.iter.selecting import \
(FilterIterDataPipe as Filter)
from torch.utils.data.datapipes.iter.sampler import \
(SamplerIterDataPipe as Sampler)


__all__ = ['ListDirFiles', 'LoadFilesFromDisk', 'ReadFilesFormTar', 'ReadFilesFromZip', 'RoutedDecoder', 'GroupByKey',
'Batch', 'BucketBatch', 'Callable', 'Collate', 'Sampler']
'Batch', 'BucketBatch', 'Callable', 'Collate', 'Filter', 'Sampler']
36 changes: 36 additions & 0 deletions torch/utils/data/datapipes/iter/selecting.py
@@ -0,0 +1,36 @@
from torch.utils.data import IterDataPipe
from typing import Callable, TypeVar, Iterator

from .callable import CallableIterDataPipe

T_co = TypeVar('T_co', covariant=True)


class FilterIterDataPipe(CallableIterDataPipe[T_co]):
    r""" :class:`FilterIterDataPipe`.

    Iterable DataPipe to filter elements from datapipe according to filter_fn.
    args:
        datapipe: Iterable DataPipe being filtered
        *args: Extra positional arguments handed to `filter_fn` after each element.
        filter_fn: Customized function mapping an element to a boolean.
        **kwargs: Extra keyword arguments handed to `filter_fn`.
    """
    def __init__(self,
                 datapipe: IterDataPipe[T_co],
                 *args,
                 filter_fn: Callable[..., bool],
                 **kwargs,
                 ) -> None:
        # The predicate is registered as the parent's ``fn``; presumably the
        # parent stores it together with ``datapipe``, ``args`` and ``kwargs``
        # on the instance, since ``__iter__`` below reads them from ``self``.
        super().__init__(datapipe, *args, fn=filter_fn, **kwargs)

    def __iter__(self) -> Iterator[T_co]:
        """Yield, in order, the source elements for which `filter_fn` returns True.

        Raises:
            ValueError: if `filter_fn` returns anything that is not a `bool`.
        """
        for data in self.datapipe:
            res = self.fn(data, *self.args, **self.kwargs)
            # Deliberately strict: truthy non-bool results (e.g. ``1``) are
            # rejected rather than coerced.
            if not isinstance(res, bool):
                raise ValueError("Boolean output is required for "
                                 "`filter_fn` of FilterIterDataPipe")
            if res:
                yield data

    def __len__(self):
        """Always raise: the number of kept elements is unknown before iterating.

        NOTE(review): `list(pipe)` probes `len()` for a size hint and only
        suppresses TypeError, so this NotImplementedError propagates out of
        `list(pipe)`. The exception type is kept as-is because callers may
        already catch it.
        """
        raise NotImplementedError(
            "{} has no valid length: the number of elements kept by "
            "`filter_fn` is only known after iteration".format(type(self).__name__))

0 comments on commit 104371e

Please sign in to comment.