Rename CallableIterDataPipe to MapIterDataPipe (exported as dp.iter.Map).

Extra arguments for the wrapped function are no longer collected through
*args/**kwargs on the DataPipe constructor; they are passed explicitly as
the keyword-only parameters `fn_args` (a tuple) and `fn_kwargs` (a dict).
Collate and Filter follow the same signature. Callables without a
`__name__` attribute (e.g. functools.partial objects) are accepted, and the
lambda warning is only emitted when `fn.__name__ == '<lambda>'`.

Note: `fn_args` must be a tuple -- write `(0, )`, not `(0)`, which is just
the integer 0 and cannot be star-unpacked when the function is called.

diff --git a/test/test_datapipe.py b/test/test_datapipe.py
index 00bdeb1306c09c5..dca281bf8ceb931 100644
--- a/test/test_datapipe.py
+++ b/test/test_datapipe.py
@@ -245,26 +245,26 @@ class TestFunctionalIterDataPipe(TestCase):
 
     def test_picklable(self):
         arr = range(10)
-        picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, List, Dict[str, Any]]] = [
-            (dp.iter.Callable, IDP(arr), [], {}),
-            (dp.iter.Callable, IDP(arr), [0], {'fn': _fake_fn, 'test': True}),
-            (dp.iter.Collate, IDP(arr), [], {}),
-            (dp.iter.Collate, IDP(arr), [0], {'collate_fn': _fake_fn, 'test': True}),
-            (dp.iter.Filter, IDP(arr), [0], {'filter_fn': _fake_filter_fn, 'test': True}),
+        picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Dict[str, Any]]] = [
+            (dp.iter.Map, IDP(arr), {}),
+            (dp.iter.Map, IDP(arr), {'fn': _fake_fn, 'fn_args': (0, ), 'fn_kwargs': {'test': True}}),
+            (dp.iter.Collate, IDP(arr), {}),
+            (dp.iter.Collate, IDP(arr), {'collate_fn': _fake_fn, 'fn_args': (0, ), 'fn_kwargs': {'test': True}}),
+            (dp.iter.Filter, IDP(arr), {'filter_fn': _fake_filter_fn, 'fn_args': (0, ), 'fn_kwargs': {'test': True}}),
         ]
-        for dpipe, input_dp, args, kargs in picklable_datapipes:
-            p = pickle.dumps(dpipe(input_dp, *args, **kargs))  # type: ignore
+        for dpipe, input_dp, kargs in picklable_datapipes:
+            p = pickle.dumps(dpipe(input_dp, **kargs))  # type: ignore
 
-        unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, List, Dict[str, Any]]] = [
-            (dp.iter.Callable, IDP(arr), [], {'fn': lambda x: x}),
-            (dp.iter.Collate, IDP(arr), [], {'collate_fn': lambda x: x}),
-            (dp.iter.Filter, IDP(arr), [], {'filter_fn': lambda x: x >= 5}),
+        unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Dict[str, Any]]] = [
+            (dp.iter.Map, IDP(arr), {'fn': lambda x: x}),
+            (dp.iter.Collate, IDP(arr), {'collate_fn': lambda x: x}),
+            (dp.iter.Filter, IDP(arr), {'filter_fn': lambda x: x >= 5}),
         ]
-        for dpipe, input_dp, args, kargs in unpicklable_datapipes:
+        for dpipe, input_dp, kargs in unpicklable_datapipes:
             with self.assertRaises(AttributeError):
-                p = pickle.dumps(dpipe(input_dp, *args, **kargs))  # type: ignore
+                p = pickle.dumps(dpipe(input_dp, **kargs))  # type: ignore
 
-    def test_callable_datapipe(self):
+    def test_map_datapipe(self):
         arr = range(10)
         input_dp = IDP(arr)
         input_dp_nl = IDP_NoLen(arr)
@@ -273,20 +273,26 @@ def fn(item, dtype=torch.float, *, sum=False):
             data = torch.tensor(item, dtype=dtype)
             return data if not sum else data.sum()
 
-        callable_dp = dp.iter.Callable(input_dp, fn=fn)  # type: ignore
-        self.assertEqual(len(input_dp), len(callable_dp))
-        for x, y in zip(callable_dp, input_dp):
+        map_dp = dp.iter.Map(input_dp, fn=fn)  # type: ignore
+        self.assertEqual(len(input_dp), len(map_dp))
+        for x, y in zip(map_dp, input_dp):
             self.assertEqual(x, torch.tensor(y, dtype=torch.float))
 
-        callable_dp = dp.iter.Callable(input_dp, torch.int, fn=fn, sum=True)  # type: ignore
-        self.assertEqual(len(input_dp), len(callable_dp))
-        for x, y in zip(callable_dp, input_dp):
+        map_dp = dp.iter.Map(input_dp, fn=fn, fn_args=(torch.int, ), fn_kwargs={'sum': True})  # type: ignore
+        self.assertEqual(len(input_dp), len(map_dp))
+        for x, y in zip(map_dp, input_dp):
             self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
 
-        callable_dp_nl = dp.iter.Callable(input_dp_nl)  # type: ignore
+        from functools import partial
+        map_dp = dp.iter.Map(input_dp, fn=partial(fn, dtype=torch.int, sum=True))  # type: ignore
+        self.assertEqual(len(input_dp), len(map_dp))
+        for x, y in zip(map_dp, input_dp):
+            self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
+
+        map_dp_nl = dp.iter.Map(input_dp_nl)  # type: ignore
         with self.assertRaises(NotImplementedError):
-            len(callable_dp_nl)
-        for x, y in zip(callable_dp_nl, input_dp_nl):
+            len(map_dp_nl)
+        for x, y in zip(map_dp_nl, input_dp_nl):
             self.assertEqual(x, torch.tensor(y, dtype=torch.float))
 
     def test_collate_datapipe(self):
@@ -393,11 +399,11 @@ def _filter_fn(data, val, clip=False):
                 return data >= val
             return True
 
-        filter_dp = dp.iter.Filter(input_ds, 5, filter_fn=_filter_fn)
+        filter_dp = dp.iter.Filter(input_ds, filter_fn=_filter_fn, fn_args=(5, ))
         for data, exp in zip(filter_dp, range(10)):
             self.assertEqual(data, exp)
 
-        filter_dp = dp.iter.Filter(input_ds, filter_fn=_filter_fn, val=5, clip=True)
+        filter_dp = dp.iter.Filter(input_ds, filter_fn=_filter_fn, fn_kwargs={'val': 5, 'clip': True})
         for data, exp in zip(filter_dp, range(5, 10)):
             self.assertEqual(data, exp)
 
diff --git a/torch/utils/data/datapipes/iter/__init__.py b/torch/utils/data/datapipes/iter/__init__.py
index 76030963f37aa5c..13f77653a0f1183 100755
--- a/torch/utils/data/datapipes/iter/__init__.py
+++ b/torch/utils/data/datapipes/iter/__init__.py
@@ -8,7 +8,7 @@
 # Functional DataPipe
 from torch.utils.data.datapipes.iter.batch import BatchIterDataPipe as Batch, BucketBatchIterDataPipe as BucketBatch
 from torch.utils.data.datapipes.iter.callable import \
-    (CallableIterDataPipe as Callable, CollateIterDataPipe as Collate)
+    (MapIterDataPipe as Map, CollateIterDataPipe as Collate)
 from torch.utils.data.datapipes.iter.selecting import \
     (FilterIterDataPipe as Filter)
 from torch.utils.data.datapipes.iter.sampler import \
@@ -16,4 +16,4 @@
 
 
 __all__ = ['ListDirFiles', 'LoadFilesFromDisk', 'ReadFilesFormTar', 'ReadFilesFromZip', 'RoutedDecoder', 'GroupByKey',
-           'Batch', 'BucketBatch', 'Callable', 'Collate', 'Filter', 'Sampler']
+           'Batch', 'BucketBatch', 'Collate', 'Filter', 'Map', 'Sampler']
diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py
index deccbd8ef390569..10a2e07b29aa049 100644
--- a/torch/utils/data/datapipes/iter/callable.py
+++ b/torch/utils/data/datapipes/iter/callable.py
@@ -1,6 +1,6 @@
 import warnings
 from torch.utils.data import IterDataPipe, _utils
-from typing import TypeVar, Callable, Iterator, Sized
+from typing import TypeVar, Callable, Iterator, Sized, Optional, Tuple, Dict
 
 T_co = TypeVar('T_co', covariant=True)
 
@@ -12,31 +12,37 @@ def default_fn(data):
     return data
 
 
-class CallableIterDataPipe(IterDataPipe[T_co]):
-    r""" :class:`CallableIterDataPipe`.
+class MapIterDataPipe(IterDataPipe[T_co]):
+    r""" :class:`MapIterDataPipe`.
 
     Iterable DataPipe to run a function over each item from the source DataPipe.
+    The function can be any regular python function or partial object. Lambda
+    function is not recommended as it is not supported by pickle.
     args:
         datapipe: Source Iterable DataPipe
         fn: Function called over each item
+        fn_args: Positional arguments for `fn`
+        fn_kwargs: Keyword arguments for `fn`
     """
     datapipe: IterDataPipe
     fn: Callable
 
     def __init__(self,
                  datapipe: IterDataPipe,
-                 *args,
+                 *,
                  fn: Callable = default_fn,
-                 **kwargs,
+                 fn_args: Optional[Tuple] = None,
+                 fn_kwargs: Optional[Dict] = None,
                  ) -> None:
         super().__init__()
         self.datapipe = datapipe
-        if fn.__name__ == '<lambda>':
-            warnings.warn("Lambda function is not supported for pickle, "
-                          "please use regular python function instead.")
+        # Partial object has no attribute '__name__', but can be pickled
+        if hasattr(fn, '__name__') and fn.__name__ == '<lambda>':
+            warnings.warn("Lambda function is not supported for pickle, please use "
+                          "regular python function or functools.partial instead.")
         self.fn = fn  # type: ignore
-        self.args = args
-        self.kwargs = kwargs
+        self.args = () if fn_args is None else fn_args
+        self.kwargs = {} if fn_kwargs is None else fn_kwargs
 
     def __iter__(self) -> Iterator[T_co]:
         for data in self.datapipe:
@@ -48,7 +54,7 @@ def __len__(self) -> int:
         raise NotImplementedError
 
 
-class CollateIterDataPipe(CallableIterDataPipe):
+class CollateIterDataPipe(MapIterDataPipe):
     r""" :class:`CollateIterDataPipe`.
 
     Iterable DataPipe to collate samples from datapipe to Tensor(s) by `util_.collate.default_collate`,
@@ -57,6 +63,8 @@ class CollateIterDataPipe(CallableIterDataPipe):
         datapipe: Iterable DataPipe being collated
         collate_fn: Customized collate function to collect and combine data or a batch of data.
             Default function collates to Tensor(s) based on data type.
+        fn_args: Positional arguments for `collate_fn`
+        fn_kwargs: Keyword arguments for `collate_fn`
 
     Example: Convert integer data to float Tensor
         >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
@@ -85,8 +93,9 @@ class CollateIterDataPipe(CallableIterDataPipe):
     """
     def __init__(self,
                  datapipe: IterDataPipe,
-                 *args,
+                 *,
                  collate_fn: Callable = _utils.collate.default_collate,
-                 **kwargs,
+                 fn_args: Optional[Tuple] = None,
+                 fn_kwargs: Optional[Dict] = None,
                  ) -> None:
-        super().__init__(datapipe, *args, fn=collate_fn, **kwargs)
+        super().__init__(datapipe, fn=collate_fn, fn_args=fn_args, fn_kwargs=fn_kwargs)
diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py
index 78a2a74d31f39af..084f2ea839803dd 100644
--- a/torch/utils/data/datapipes/iter/selecting.py
+++ b/torch/utils/data/datapipes/iter/selecting.py
@@ -1,26 +1,29 @@
 from torch.utils.data import IterDataPipe
-from typing import Callable, TypeVar, Iterator
+from typing import Callable, TypeVar, Iterator, Optional, Tuple, Dict
 
-from .callable import CallableIterDataPipe
+from .callable import MapIterDataPipe
 
 T_co = TypeVar('T_co', covariant=True)
 
 
-class FilterIterDataPipe(CallableIterDataPipe[T_co]):
+class FilterIterDataPipe(MapIterDataPipe[T_co]):
     r""" :class:`FilterIterDataPipe`.
 
     Iterable DataPipe to filter elements from datapipe according to filter_fn.
     args:
         datapipe: Iterable DataPipe being filterd
         filter_fn: Customized function mapping an element to a boolean.
+        fn_args: Positional arguments for `filter_fn`
+        fn_kwargs: Keyword arguments for `filter_fn`
     """
     def __init__(self,
                  datapipe: IterDataPipe[T_co],
-                 *args,
+                 *,
                  filter_fn: Callable[..., bool],
-                 **kwargs,
+                 fn_args: Optional[Tuple] = None,
+                 fn_kwargs: Optional[Dict] = None,
                  ) -> None:
-        super().__init__(datapipe, *args, fn=filter_fn, **kwargs)
+        super().__init__(datapipe, fn=filter_fn, fn_args=fn_args, fn_kwargs=fn_kwargs)
 
     def __iter__(self) -> Iterator[T_co]:
         res: bool