Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[DataLoader] Implement FilterIterDataPipe (#51783)
Summary: Pull Request resolved: #51783 Test Plan: Imported from OSS Reviewed By: glaringlee Differential Revision: D26277688 Pulled By: ejguan fbshipit-source-id: 25ed7da9da88c030b29627142c2f04fed26cdcda
- Loading branch information
1 parent
e964d77
commit 104371e
Showing
3 changed files
with
74 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from torch.utils.data import IterDataPipe | ||
from typing import Callable, TypeVar, Iterator | ||
|
||
from .callable import CallableIterDataPipe | ||
|
||
T_co = TypeVar('T_co', covariant=True) | ||
|
||
|
||
class FilterIterDataPipe(CallableIterDataPipe[T_co]): | ||
r""" :class:`FilterIterDataPipe`. | ||
Iterable DataPipe to filter elements from datapipe according to filter_fn. | ||
args: | ||
datapipe: Iterable DataPipe being filterd | ||
filter_fn: Customized function mapping an element to a boolean. | ||
""" | ||
def __init__(self, | ||
datapipe: IterDataPipe[T_co], | ||
*args, | ||
filter_fn: Callable[..., bool], | ||
**kwargs, | ||
) -> None: | ||
super().__init__(datapipe, *args, fn=filter_fn, **kwargs) | ||
|
||
def __iter__(self) -> Iterator[T_co]: | ||
res: bool | ||
for data in self.datapipe: | ||
res = self.fn(data, *self.args, **self.kwargs) | ||
if not isinstance(res, bool): | ||
raise ValueError("Boolean output is required for " | ||
"`filter_fn` of FilterIterDataPipe") | ||
if res: | ||
yield data | ||
|
||
def __len__(self): | ||
raise(NotImplementedError) |