Update on "Adding merge function as an optional input for KeyZipper"
Fixes #44 by adding a merge function as an optional input for `KeyZipper`. It allows users to specify how the items yielded from the DataPipes are combined before they are yielded.

This PR changes the default behavior when `keep_key=True`, which may be BC-breaking.
Previously, when `keep_key=True`, the output was `(key, item1, item2)`; now it defaults to `(key, (item1, item2))`. It is possible to leave the default behavior unchanged.
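A minimal usage sketch of the new behavior (assuming `IterableWrapper` and the `zip_by_key` functional form exercised in the tests below; the variable names are illustrative):

```python
from torchdata.datapipes.iter import IterableWrapper

source_dp = IterableWrapper(range(5))
ref_dp = IterableWrapper(range(5))

# Default (no merge_fn): matched items are combined into a tuple, so with
# keep_key=True each element is (key, (item1, item2)) rather than the old
# flat (key, item1, item2).
zip_dp = source_dp.zip_by_key(
    ref_datapipe=ref_dp, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=True
)
# -> [(0, (0, 0)), (1, (1, 1)), ...]

# With a custom merge_fn, the user controls how the two matched items are combined.
zip_dp_merged = source_dp.zip_by_key(
    ref_datapipe=ref_dp,
    key_fn=lambda x: x,
    ref_key_fn=lambda x: x,
    keep_key=False,
    merge_fn=lambda a, b: f"{a},{b}",
)
# -> ["0,0", "1,1", ...]
```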

Note that `KeyZipper` may be renamed to `IterZipper` in #50 for better clarity.

Differential Revision: [D31820203](https://our.internmc.facebook.com/intern/diff/D31820203)

[ghstack-poisoned]
NivekT committed Oct 21, 2021
1 parent d86bec1 commit f90a03f
Showing 2 changed files with 29 additions and 10 deletions.
22 changes: 22 additions & 0 deletions test/test_datapipe.py
@@ -125,6 +125,28 @@ def merge_to_string(item1, item2):
)
self.assertEqual([(i, f"{i},{i}") for i in range(10)], list(zip_dp_w_key_str_merge))

# Functional Test: testing nested zipping
zip_dp = source_dp.zip_by_key(
ref_datapipe=ref_dp, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=False, buffer_size=100
)

# Without a custom merge function, there will be nested tuples
zip_dp2 = zip_dp.zip_by_key(
ref_datapipe=ref_dp, key_fn=lambda x: x[0], ref_key_fn=lambda x: x, keep_key=False, buffer_size=100
)
self.assertEqual([((i, i), i) for i in range(10)], list(zip_dp2))

# With a custom merge function, nesting can be prevented
zip_dp2_w_merge = zip_dp.zip_by_key(
ref_datapipe=ref_dp,
key_fn=lambda x: x[0],
ref_key_fn=lambda x: x,
keep_key=False,
buffer_size=100,
merge_fn=lambda x, y: list(x) + [y],
)
self.assertEqual([[i, i, i] for i in range(10)], list(zip_dp2_w_merge))

# Functional Test: element is in source but missing in reference
ref_dp_missing = IterableWrapper(range(1, 10))
zip_dp = source_dp.zip_by_key(
17 changes: 7 additions & 10 deletions torchdata/datapipes/iter/util/combining.py
@@ -3,11 +3,7 @@
from collections import OrderedDict

from torch.utils.data import IterDataPipe, MapDataPipe, functional_datapipe
from typing import Callable


def tuple_merge(item, item_from_map):
return item, item_from_map
from typing import Callable, Optional


@functional_datapipe("zip_by_key")
@@ -38,7 +34,7 @@ def __init__(
ref_key_fn: Callable = None,
keep_key: bool = False,
buffer_size: int = 10000,
merge_fn: Callable = tuple_merge,
merge_fn: Optional[Callable] = None,
):
self.source_datapipe = source_datapipe
self.ref_datapipe = ref_datapipe
@@ -76,10 +72,11 @@ def __iter__(self):
)
buffer.popitem(last=False)
buffer[ref_key] = ref_data
res = self.merge_fn(data, buffer.pop(key)) if self.merge_fn else (data, buffer.pop(key))
if self.keep_key:
yield key, self.merge_fn(data, buffer.pop(key))
yield key, res
else:
yield self.merge_fn(data, buffer.pop(key))
yield res

def __len__(self):
return len(self.source_datapipe)
@@ -107,7 +104,7 @@ def __init__(
source_iterdatapipe: IterDataPipe,
map_datapipe: MapDataPipe,
key_fn: Callable,
merge_fn: Callable = tuple_merge,
merge_fn: Optional[Callable] = None,
):
if not isinstance(map_datapipe, MapDataPipe):
raise TypeError(f"map_datapipe must be a MapDataPipe, but its type is {type(map_datapipe)} instead.")
@@ -124,7 +121,7 @@ def __iter__(self):
map_item = self.map_datapipe[key]
except (KeyError, IndexError):
raise KeyError(f"key_fn maps {item} to {key}, which is not a valid key in the given MapDataPipe.")
yield self.merge_fn(item, map_item)
yield self.merge_fn(item, map_item) if self.merge_fn else (item, map_item)

def __len__(self) -> int:
if self.length == -1:
