Skip to content

Commit

Permalink
[DataLoader] Rename Callable to Map IterDataPipe (#51879)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #51879

Test Plan: Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D26314775

Pulled By: ejguan

fbshipit-source-id: ee77909eae97092155ed6a6c794540e68a04d754
  • Loading branch information
ejguan authored and facebook-github-bot committed Feb 10, 2021
1 parent 104371e commit 9eb70c3
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 48 deletions.
58 changes: 32 additions & 26 deletions test/test_datapipe.py
Expand Up @@ -245,26 +245,26 @@ class TestFunctionalIterDataPipe(TestCase):

def test_picklable(self):
arr = range(10)
picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, List, Dict[str, Any]]] = [
(dp.iter.Callable, IDP(arr), [], {}),
(dp.iter.Callable, IDP(arr), [0], {'fn': _fake_fn, 'test': True}),
(dp.iter.Collate, IDP(arr), [], {}),
(dp.iter.Collate, IDP(arr), [0], {'collate_fn': _fake_fn, 'test': True}),
(dp.iter.Filter, IDP(arr), [0], {'filter_fn': _fake_filter_fn, 'test': True}),
picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Dict[str, Any]]] = [
(dp.iter.Map, IDP(arr), {}),
(dp.iter.Map, IDP(arr), {'fn': _fake_fn, 'fn_args': (0), 'fn_kwargs': {'test': True}}),
(dp.iter.Collate, IDP(arr), {}),
(dp.iter.Collate, IDP(arr), {'collate_fn': _fake_fn, 'fn_args': (0), 'fn_kwargs': {'test': True}}),
(dp.iter.Filter, IDP(arr), {'filter_fn': _fake_filter_fn, 'fn_args': (0), 'fn_kwargs': {'test': True}}),
]
for dpipe, input_dp, args, kargs in picklable_datapipes:
p = pickle.dumps(dpipe(input_dp, *args, **kargs)) # type: ignore
for dpipe, input_dp, kargs in picklable_datapipes:
p = pickle.dumps(dpipe(input_dp, **kargs)) # type: ignore

unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, List, Dict[str, Any]]] = [
(dp.iter.Callable, IDP(arr), [], {'fn': lambda x: x}),
(dp.iter.Collate, IDP(arr), [], {'collate_fn': lambda x: x}),
(dp.iter.Filter, IDP(arr), [], {'filter_fn': lambda x: x >= 5}),
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Dict[str, Any]]] = [
(dp.iter.Map, IDP(arr), {'fn': lambda x: x}),
(dp.iter.Collate, IDP(arr), {'collate_fn': lambda x: x}),
(dp.iter.Filter, IDP(arr), {'filter_fn': lambda x: x >= 5}),
]
for dpipe, input_dp, args, kargs in unpicklable_datapipes:
for dpipe, input_dp, kargs in unpicklable_datapipes:
with self.assertRaises(AttributeError):
p = pickle.dumps(dpipe(input_dp, *args, **kargs)) # type: ignore
p = pickle.dumps(dpipe(input_dp, **kargs)) # type: ignore

def test_callable_datapipe(self):
def test_map_datapipe(self):
arr = range(10)
input_dp = IDP(arr)
input_dp_nl = IDP_NoLen(arr)
Expand All @@ -273,20 +273,26 @@ def fn(item, dtype=torch.float, *, sum=False):
data = torch.tensor(item, dtype=dtype)
return data if not sum else data.sum()

callable_dp = dp.iter.Callable(input_dp, fn=fn) # type: ignore
self.assertEqual(len(input_dp), len(callable_dp))
for x, y in zip(callable_dp, input_dp):
map_dp = dp.iter.Map(input_dp, fn=fn) # type: ignore
self.assertEqual(len(input_dp), len(map_dp))
for x, y in zip(map_dp, input_dp):
self.assertEqual(x, torch.tensor(y, dtype=torch.float))

callable_dp = dp.iter.Callable(input_dp, torch.int, fn=fn, sum=True) # type: ignore
self.assertEqual(len(input_dp), len(callable_dp))
for x, y in zip(callable_dp, input_dp):
map_dp = dp.iter.Map(input_dp, fn=fn, fn_args=(torch.int, ), fn_kwargs={'sum': True}) # type: ignore
self.assertEqual(len(input_dp), len(map_dp))
for x, y in zip(map_dp, input_dp):
self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())

callable_dp_nl = dp.iter.Callable(input_dp_nl) # type: ignore
from functools import partial
map_dp = dp.iter.Map(input_dp, fn=partial(fn, dtype=torch.int, sum=True)) # type: ignore
self.assertEqual(len(input_dp), len(map_dp))
for x, y in zip(map_dp, input_dp):
self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())

map_dp_nl = dp.iter.Map(input_dp_nl) # type: ignore
with self.assertRaises(NotImplementedError):
len(callable_dp_nl)
for x, y in zip(callable_dp_nl, input_dp_nl):
len(map_dp_nl)
for x, y in zip(map_dp_nl, input_dp_nl):
self.assertEqual(x, torch.tensor(y, dtype=torch.float))

def test_collate_datapipe(self):
Expand Down Expand Up @@ -393,11 +399,11 @@ def _filter_fn(data, val, clip=False):
return data >= val
return True

filter_dp = dp.iter.Filter(input_ds, 5, filter_fn=_filter_fn)
filter_dp = dp.iter.Filter(input_ds, filter_fn=_filter_fn, fn_args=(5, ))
for data, exp in zip(filter_dp, range(10)):
self.assertEqual(data, exp)

filter_dp = dp.iter.Filter(input_ds, filter_fn=_filter_fn, val=5, clip=True)
filter_dp = dp.iter.Filter(input_ds, filter_fn=_filter_fn, fn_kwargs={'val': 5, 'clip': True})
for data, exp in zip(filter_dp, range(5, 10)):
self.assertEqual(data, exp)

Expand Down
4 changes: 2 additions & 2 deletions torch/utils/data/datapipes/iter/__init__.py
Expand Up @@ -8,12 +8,12 @@
# Functional DataPipe
from torch.utils.data.datapipes.iter.batch import BatchIterDataPipe as Batch, BucketBatchIterDataPipe as BucketBatch
from torch.utils.data.datapipes.iter.callable import \
(CallableIterDataPipe as Callable, CollateIterDataPipe as Collate)
(MapIterDataPipe as Map, CollateIterDataPipe as Collate)
from torch.utils.data.datapipes.iter.selecting import \
(FilterIterDataPipe as Filter)
from torch.utils.data.datapipes.iter.sampler import \
(SamplerIterDataPipe as Sampler)


__all__ = ['ListDirFiles', 'LoadFilesFromDisk', 'ReadFilesFormTar', 'ReadFilesFromZip', 'RoutedDecoder', 'GroupByKey',
'Batch', 'BucketBatch', 'Callable', 'Collate', 'Filter', 'Sampler']
'Batch', 'BucketBatch', 'Collate', 'Filter', 'Map', 'Sampler']
37 changes: 23 additions & 14 deletions torch/utils/data/datapipes/iter/callable.py
@@ -1,6 +1,6 @@
import warnings
from torch.utils.data import IterDataPipe, _utils
from typing import TypeVar, Callable, Iterator, Sized
from typing import TypeVar, Callable, Iterator, Sized, Optional, Tuple, Dict

T_co = TypeVar('T_co', covariant=True)

Expand All @@ -12,31 +12,37 @@ def default_fn(data):
return data


class CallableIterDataPipe(IterDataPipe[T_co]):
r""" :class:`CallableIterDataPipe`.
class MapIterDataPipe(IterDataPipe[T_co]):
r""" :class:`MapIterDataPipe`.
Iterable DataPipe to run a function over each item from the source DataPipe.
The function can be any regular Python function or partial object. Lambda
functions are not recommended as they are not supported by pickle.
args:
datapipe: Source Iterable DataPipe
fn: Function called over each item
fn_args: Positional arguments for `fn`
fn_kwargs: Keyword arguments for `fn`
"""
datapipe: IterDataPipe
fn: Callable

def __init__(self,
datapipe: IterDataPipe,
*args,
*,
fn: Callable = default_fn,
**kwargs,
fn_args: Optional[Tuple] = None,
fn_kwargs: Optional[Dict] = None,
) -> None:
super().__init__()
self.datapipe = datapipe
if fn.__name__ == '<lambda>':
warnings.warn("Lambda function is not supported for pickle, "
"please use regular python function instead.")
# Partial object has no attribute '__name__', but can be pickled
if hasattr(fn, '__name__') and fn.__name__ == '<lambda>':
warnings.warn("Lambda function is not supported for pickle, please use "
"regular python function or functools.partial instead.")
self.fn = fn # type: ignore
self.args = args
self.kwargs = kwargs
self.args = () if fn_args is None else fn_args
self.kwargs = {} if fn_kwargs is None else fn_kwargs

def __iter__(self) -> Iterator[T_co]:
for data in self.datapipe:
Expand All @@ -48,7 +54,7 @@ def __len__(self) -> int:
raise NotImplementedError


class CollateIterDataPipe(CallableIterDataPipe):
class CollateIterDataPipe(MapIterDataPipe):
r""" :class:`CollateIterDataPipe`.
Iterable DataPipe to collate samples from datapipe to Tensor(s) by `_utils.collate.default_collate`,
Expand All @@ -57,6 +63,8 @@ class CollateIterDataPipe(CallableIterDataPipe):
datapipe: Iterable DataPipe being collated
collate_fn: Customized collate function to collect and combine data or a batch of data.
Default function collates to Tensor(s) based on data type.
fn_args: Positional arguments for `collate_fn`
fn_kwargs: Keyword arguments for `collate_fn`
Example: Convert integer data to float Tensor
>>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
Expand Down Expand Up @@ -85,8 +93,9 @@ class CollateIterDataPipe(CallableIterDataPipe):
"""
def __init__(self,
datapipe: IterDataPipe,
*args,
*,
collate_fn: Callable = _utils.collate.default_collate,
**kwargs,
fn_args: Optional[Tuple] = None,
fn_kwargs: Optional[Dict] = None,
) -> None:
super().__init__(datapipe, *args, fn=collate_fn, **kwargs)
super().__init__(datapipe, fn=collate_fn, fn_args=fn_args, fn_kwargs=fn_kwargs)
15 changes: 9 additions & 6 deletions torch/utils/data/datapipes/iter/selecting.py
@@ -1,26 +1,29 @@
from torch.utils.data import IterDataPipe
from typing import Callable, TypeVar, Iterator
from typing import Callable, TypeVar, Iterator, Optional, Tuple, Dict

from .callable import CallableIterDataPipe
from .callable import MapIterDataPipe

T_co = TypeVar('T_co', covariant=True)


class FilterIterDataPipe(CallableIterDataPipe[T_co]):
class FilterIterDataPipe(MapIterDataPipe[T_co]):
r""" :class:`FilterIterDataPipe`.
Iterable DataPipe to filter elements from datapipe according to filter_fn.
args:
datapipe: Iterable DataPipe being filtered
filter_fn: Customized function mapping an element to a boolean.
fn_args: Positional arguments for `filter_fn`
fn_kwargs: Keyword arguments for `filter_fn`
"""
def __init__(self,
datapipe: IterDataPipe[T_co],
*args,
*,
filter_fn: Callable[..., bool],
**kwargs,
fn_args: Optional[Tuple] = None,
fn_kwargs: Optional[Dict] = None,
) -> None:
super().__init__(datapipe, *args, fn=filter_fn, **kwargs)
super().__init__(datapipe, fn=filter_fn, fn_args=fn_args, fn_kwargs=fn_kwargs)

def __iter__(self) -> Iterator[T_co]:
res: bool
Expand Down

0 comments on commit 9eb70c3

Please sign in to comment.