Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[DataPipe] Improve Mapper to accept input/output index when apply fn #64951

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 125 additions & 3 deletions test/test_datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import torch.utils.data.datapipes as dp
import torch.utils.data.graph
import torch.utils.data.sharding
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings
from torch.utils.data import (
DataLoader,
DataChunk,
Expand Down Expand Up @@ -837,7 +837,7 @@ def test_demux_datapipe(self):
with self.assertRaises(TypeError):
len(dp2)


@suppress_warnings # Suppress warning for lambda fn
def test_map_datapipe(self):
input_dp = IDP(range(10))

Expand All @@ -862,12 +862,134 @@ def fn(item, dtype=torch.float, *, sum=False):
self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())

input_dp_nl = IDP_NoLen(range(10))
map_dp_nl = input_dp_nl.map()
map_dp_nl = input_dp_nl.map(lambda x: x)
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
len(map_dp_nl)
for x, y in zip(map_dp_nl, input_dp_nl):
self.assertEqual(x, torch.tensor(y, dtype=torch.float))

@suppress_warnings  # Suppress warning for lambda fn
def test_map_tuple_list_with_col_datapipe(self):
    """Mapper over list/tuple rows: exercise `input_col`/`output_col` handling."""

    # One input column -> one output value
    def neg(d):
        return -d

    # One input column -> tuple output
    def neg_and_id(d):
        return -d, d

    # Two input columns -> one output value
    def add(a, b):
        return a + b

    # Two input columns -> tuple output
    def neg_neg_add(a, b):
        return -a, -b, a + b

    def check(ref_fn, fn, input_col=None, output_col=None):
        # Run the same comparison for both sequence flavours.
        for constr in (list, tuple):
            datapipe = IDP([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
            res_dp = datapipe.map(fn, input_col, output_col)
            ref_dp = datapipe.map(ref_fn)
            self.assertEqual(list(res_dp), list(ref_dp))
            # Iterate a second time to verify the pipe resets cleanly.
            self.assertEqual(list(res_dp), list(ref_dp))

    # Replace a single column; the result goes back into that column.
    check(lambda data: (data[0], -data[1], data[2]), neg, 1)
    check(lambda data: (data[0], (-data[1], data[1]), data[2]), neg_and_id, 1)
    # Input index beyond the row length.
    with self.assertRaises(IndexError):
        check(None, neg_and_id, 3)
    # fn arity does not match the number of input columns.
    with self.assertRaises(TypeError):
        check(None, add, 1)
    # Multiple input columns; default output is the left-most input column.
    check(lambda data: (data[0], data[1], data[2] + data[0]), add, [2, 0])
    check(lambda data: (data[0], data[1], (-data[2], -data[1], data[2] + data[1])), neg_neg_add, [2, 1])

    # output_col may only be given together with input_col.
    with self.assertRaises(ValueError):
        check(None, add, None, 1)
    # A list/tuple output_col must hold exactly one element.
    with self.assertRaises(ValueError):
        check(None, add, None, [0, 1])
    # Single-element list as output_col.
    check(lambda data: (-data[1], data[1], data[2]), neg, 1, [0])
    # Explicitly chosen output column.
    check(lambda data: (-data[1], data[1], data[2]), neg, 1, 0)
    check(lambda data: (data[0], data[1], (-data[1], data[1])), neg_and_id, 1, 2)
    # Output index beyond the row length.
    with self.assertRaises(IndexError):
        check(None, neg_and_id, 1, 3)
    check(lambda data: (data[0], data[0] + data[2], data[2]), add, [0, 2], 1)
    check(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), neg_neg_add, [1, 2], 0)

    # output_col == -1 appends the result to the end of the row.
    check(lambda data: (*data, -data[1]), neg, 1, -1)
    check(lambda data: (*data, (-data[1], data[1])), neg_and_id, 1, -1)
    check(lambda data: (*data, data[0] + data[2]), add, [0, 2], -1)
    check(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), neg_neg_add, [1, 2], -1)

@suppress_warnings  # Suppress warning for lambda fn
def test_map_dict_with_col_datapipe(self):
    """Mapper over dict rows: exercise `input_col`/`output_col` handling."""

    # One input key -> one output value
    def neg(d):
        return -d

    # One input key -> tuple output
    def neg_and_id(d):
        return -d, d

    # Two input keys -> one output value
    def add(a, b):
        return a + b

    # Two input keys -> tuple output
    def neg_neg_add(a, b):
        return -a, -b, a + b

    def updated(data, newdata):
        # Build a fresh dict rather than mutating in place, so the source
        # data survives a pipeline reset unmodified.
        merged = dict(data)
        merged.update(newdata)
        return merged

    def check(ref_fn, fn, input_col=None, output_col=None):
        datapipe = IDP([{"x": 0, "y": 1, "z": 2},
                        {"x": 3, "y": 4, "z": 5},
                        {"x": 6, "y": 7, "z": 8}])
        res_dp = datapipe.map(fn, input_col, output_col)
        ref_dp = datapipe.map(ref_fn)
        self.assertEqual(list(res_dp), list(ref_dp))
        # Iterate a second time to verify the pipe resets cleanly.
        self.assertEqual(list(res_dp), list(ref_dp))

    # Replace a single key; the result goes back under that key.
    check(lambda data: updated(data, {"y": -data["y"]}), neg, "y")
    check(lambda data: updated(data, {"y": (-data["y"], data["y"])}), neg_and_id, "y")
    # Missing input key.
    with self.assertRaises(KeyError):
        check(None, neg_and_id, "a")
    # fn arity does not match the number of input columns.
    with self.assertRaises(TypeError):
        check(None, add, "y")
    # Multiple input keys; default output is the left-most input key.
    check(lambda data: updated(data, {"z": data["x"] + data["z"]}), add, ["z", "x"])
    check(lambda data: updated(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}), neg_neg_add, ["z", "y"])

    # output_col may only be given together with input_col.
    with self.assertRaises(ValueError):
        check(None, add, None, "x")
    # A list/tuple output_col must hold exactly one element.
    with self.assertRaises(ValueError):
        check(None, add, None, ["x", "y"])
    # Single-element list as output_col.
    check(lambda data: updated(data, {"x": -data["y"]}), neg, "y", ["x"])
    # Explicitly chosen output key.
    check(lambda data: updated(data, {"x": -data["y"]}), neg, "y", "x")
    check(lambda data: updated(data, {"z": (-data["y"], data["y"])}), neg_and_id, "y", "z")
    check(lambda data: updated(data, {"y": data["x"] + data["z"]}), add, ["x", "z"], "y")
    check(lambda data: updated(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}), neg_neg_add, ["y", "z"], "x")

    # A brand-new key stores the result without touching existing entries.
    check(lambda data: updated(data, {"a": -data["y"]}), neg, "y", "a")
    check(lambda data: updated(data, {"a": (-data["y"], data["y"])}), neg_and_id, "y", "a")
    check(lambda data: updated(data, {"a": data["x"] + data["z"]}), add, ["x", "z"], "a")
    check(lambda data: updated(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}), neg_neg_add, ["y", "z"], "a")

# TODO(VitalyFedyunin): If dill installed this test fails
def _test_map_datapipe_nested_level(self):

Expand Down
149 changes: 113 additions & 36 deletions torch/utils/data/datapipes/iter/callable.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import warnings
from torch.utils.data import IterDataPipe, _utils, functional_datapipe, DataChunk
from typing import Callable, Dict, Iterator, Optional, Sized, Tuple, TypeVar
Expand All @@ -14,19 +15,12 @@
except ImportError:
DILL_AVAILABLE = False

T_co = TypeVar('T_co', covariant=True)
T_co = TypeVar("T_co", covariant=True)


# Default function to return each item directly
# In order to keep datapipe picklable, eliminates the usage
# of python lambda function
def default_fn(data):
return data


@functional_datapipe('map')
@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe[T_co]):
r""" :class:`MapperIterDataPipe`.
r""":class:`MapperIterDataPipe`.

Iterable DataPipe to run a function over each item from the source DataPipe.
The function can be any regular python function or partial object. Lambda
Expand All @@ -35,6 +29,15 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
Args:
datapipe: Source Iterable DataPipe
fn: Function called over each item
input_col: Index or indices of data which `fn` is applied
- None as default to apply `fn` to the data directly.
- Integer(s) is used for list/tuple.
- Key(s) is used for dict.
output_col: Index of data where result of `fn` is placed. Can be specified only when `input_col` is not None
- None as default to replace the index that `input_col` specified;
For multiple indices of `input_col`, the left-most one is used.
- Integer is used for list/tuple. -1 represents to appending result at the end.
- Key is used for dict. New key is acceptable.
fn_args: Positional arguments for `fn`
fn_kwargs: Keyword arguments for `fn`
nesting_level: Determines which level the fn gets applied to, by default it applies to the top level (= 0).
Expand All @@ -44,43 +47,98 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
datapipe: IterDataPipe
fn: Callable

def __init__(self,
datapipe: IterDataPipe,
fn: Callable = default_fn,
fn_args: Optional[Tuple] = None,
fn_kwargs: Optional[Dict] = None,
nesting_level: int = 0,
) -> None:
def __init__(
self,
datapipe: IterDataPipe,
fn: Callable,
input_col=None,
output_col=None,
*,
fn_args: Optional[Tuple] = None,
fn_kwargs: Optional[Dict] = None,
nesting_level: int = 0,
) -> None:
super().__init__()
self.datapipe = datapipe
# Partial object has no attribute '__name__', but can be pickled
if hasattr(fn, '__name__') and fn.__name__ == '<lambda>' and not DILL_AVAILABLE:
warnings.warn("Lambda function is not supported for pickle, please use "
"regular python function or functools.partial instead.")
if hasattr(fn, "__name__") and fn.__name__ == "<lambda>" and not DILL_AVAILABLE:
warnings.warn(
"Lambda function is not supported for pickle, please use "
"regular python function or functools.partial instead."
)
self.fn = fn # type: ignore[assignment]
self.input_col = input_col
if input_col is None and output_col is not None:
raise ValueError("`output_col` must be None when `input_col` is None.")
if isinstance(output_col, (list, tuple)):
if len(output_col) > 1:
raise ValueError("`output_col` must be a single-element list or tuple")
output_col = output_col[0]
self.output_col = output_col
self.args = () if fn_args is None else fn_args
self.kwargs = {} if fn_kwargs is None else fn_kwargs
if nesting_level < -1:
raise ValueError("nesting_level must be -1 or >= 0")
self.nesting_level = nesting_level

def _apply_fn(self, data):
if self.input_col is None and self.output_col is None:
return self.fn(data, *self.args, **self.kwargs)

if self.input_col is None:
res = self.fn(data, *self.args, **self.kwargs)
elif isinstance(self.input_col, (list, tuple)):
args = tuple(data[col] for col in self.input_col)
res = self.fn(*args, *self.args, **self.kwargs)
else:
res = self.fn(data[self.input_col], *self.args, **self.kwargs)

# Copy tuple to list and run in-place modification because tuple is immutable.
if isinstance(data, tuple):
t_flag = True
data = list(data)
else:
t_flag = False
# Deepcopy data to prevent the original data modified. E.g. list, dict
data = copy.deepcopy(data)

if self.output_col is None:
if isinstance(self.input_col, (list, tuple)):
data[self.input_col[0]] = res
else:
data[self.input_col] = res
else:
if self.output_col == -1:
data.append(res)
else:
data[self.output_col] = res

# Convert list back to tuple
return tuple(data) if t_flag else data

def _apply(self, data, nesting_level):
if nesting_level == 0:
return self.fn(data, *self.args, **self.kwargs)
return self._apply_fn(data)
elif nesting_level > 0:
if isinstance(data, DataChunk):
return type(data)([self._apply(i, nesting_level - 1) for i in data.raw_iterator()])
return type(data)(
[self._apply(i, nesting_level - 1) for i in data.raw_iterator()]
)
elif isinstance(data, list):
return [self._apply(i, nesting_level - 1) for i in data]
else:
raise IndexError(f"nesting_level {self.nesting_level} out of range (exceeds data pipe depth)")
raise IndexError(
f"nesting_level {self.nesting_level} out of range (exceeds data pipe depth)"
)
else:
if isinstance(data, DataChunk):
return type(data)([self._apply(i, nesting_level) for i in data.raw_iterator()])
return type(data)(
[self._apply(i, nesting_level) for i in data.raw_iterator()]
)
elif isinstance(data, list):
return [self._apply(i, nesting_level) for i in data]
else:
return self.fn(data, *self.args, **self.kwargs)
return self._apply_fn(data)

def __iter__(self) -> Iterator[T_co]:
for data in self.datapipe:
Expand All @@ -89,27 +147,45 @@ def __iter__(self) -> Iterator[T_co]:
def __len__(self) -> int:
if isinstance(self.datapipe, Sized):
return len(self.datapipe)
raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
raise TypeError(
"{} instance doesn't have valid length".format(type(self).__name__)
)

def __getstate__(self):
if DILL_AVAILABLE:
dill_function = dill.dumps(self.fn)
else:
dill_function = self.fn
state = (self.datapipe, dill_function, self.args, self.kwargs, self.nesting_level)
state = (
self.datapipe,
dill_function,
self.input_col,
self.output_col,
self.args,
self.kwargs,
self.nesting_level,
)
return state

def __setstate__(self, state):
(self.datapipe, dill_function, self.args, self.kwargs, self.nesting_level) = state
(
self.datapipe,
dill_function,
self.input_col,
self.output_col,
self.args,
self.kwargs,
self.nesting_level,
) = state
if DILL_AVAILABLE:
self.fn = dill.loads(dill_function) # type: ignore[assignment]
else:
self.fn = dill_function # type: ignore[assignment]


@functional_datapipe('collate')
@functional_datapipe("collate")
class CollatorIterDataPipe(MapperIterDataPipe):
r""" :class:`CollatorIterDataPipe`.
r""":class:`CollatorIterDataPipe`.

Iterable DataPipe to collate samples from datapipe to Tensor(s) by `util_.collate.default_collate`,
or customized Data Structure by collate_fn.
Expand Down Expand Up @@ -147,10 +223,11 @@ class CollatorIterDataPipe(MapperIterDataPipe):
[tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
"""

def __init__(self,
datapipe: IterDataPipe,
collate_fn: Callable = _utils.collate.default_collate,
fn_args: Optional[Tuple] = None,
fn_kwargs: Optional[Dict] = None,
) -> None:
def __init__(
self,
datapipe: IterDataPipe,
collate_fn: Callable = _utils.collate.default_collate,
fn_args: Optional[Tuple] = None,
fn_kwargs: Optional[Dict] = None,
) -> None:
super().__init__(datapipe, fn=collate_fn, fn_args=fn_args, fn_kwargs=fn_kwargs)