diff --git a/toolz/itertoolz.py b/toolz/itertoolz.py index 941b9a9f..6fb03967 100644 --- a/toolz/itertoolz.py +++ b/toolz/itertoolz.py @@ -809,13 +809,74 @@ def getter(index): return operator.itemgetter(index) +def _inner_join(leftkey, leftseq, rightkey, rightseq): + d = groupby(leftkey, leftseq) + for item in rightseq: + key = rightkey(item) + if key in d: + for left_match in d[key]: + yield (left_match, item) + + +def _right_join(leftkey, leftseq, rightkey, rightseq, + left_default=no_default): + d = groupby(leftkey, leftseq) + for item in rightseq: + key = rightkey(item) + if key in d: + for left_match in d[key]: + yield (left_match, item) + else: + yield (left_default, item) + + +def _left_join(leftkey, leftseq, rightkey, rightseq, + right_default=no_default): + d = groupby(leftkey, leftseq) + seen_keys = set() + for item in rightseq: + key = rightkey(item) + seen_keys.add(key) + if key in d: + for left_match in d[key]: + yield(left_match, item) + + for key, matches in iteritems(d): + if key not in seen_keys: + for match in matches: + yield (match, right_default) + + +def _full_join(leftkey, leftseq, rightkey, rightseq, + left_default=no_default, right_default=no_default): + d = groupby(leftkey, leftseq) + seen_keys = set() + for item in rightseq: + key = rightkey(item) + seen_keys.add(key) + if key in d: + for left_match in d[key]: + yield (left_match, item) + else: + yield (left_default, item) + + for key, matches in iteritems(d): + if key not in seen_keys: + for match in matches: + yield (match, right_default) + + def join(leftkey, leftseq, rightkey, rightseq, left_default=no_default, right_default=no_default): """ Join two sequences on common attributes This is a semi-streaming operation. The LEFT sequence is fully evaluated - and placed into memory. The RIGHT sequence is evaluated lazily and so can - be arbitrarily large. + and placed into memory. The RIGHT sequence is evaluated lazily and unless + right_default is defined, it can be arbitrarily large. If right_default is + defined, the unique keys of rightseq will be placed into memory. + The join is implemented as a hash join and the keys of leftseq must be + hashable. Additionally, if right_default is defined, then keys of rightseq + must also be hashable. >>> friends = [('Alice', 'Edith'), ... ('Alice', 'Zhao'), @@ -868,26 +929,18 @@ def join(leftkey, leftseq, rightkey, rightseq, if not callable(rightkey): rightkey = getter(rightkey) - d = groupby(leftkey, leftseq) - seen_keys = set() - - left_default_is_no_default = (left_default == no_default) - for item in rightseq: - key = rightkey(item) - seen_keys.add(key) - try: - left_matches = d[key] - for match in left_matches: - yield (match, item) - except KeyError: - if not left_default_is_no_default: - yield (left_default, item) - - if right_default != no_default: - for key, matches in d.items(): - if key not in seen_keys: - for match in matches: - yield (match, right_default) + if (left_default == no_default) and (right_default == no_default): + return _inner_join(leftkey, leftseq, rightkey, rightseq) + elif (left_default != no_default) and (right_default == no_default): + return _right_join(leftkey, leftseq, rightkey, rightseq, + left_default=left_default) + elif (left_default == no_default) and (right_default != no_default): + return _left_join(leftkey, leftseq, rightkey, rightseq, + right_default=right_default) + else: + return _full_join(leftkey, leftseq, rightkey, rightseq, + left_default=left_default, + right_default=right_default) def diff(*seqs, **kwargs):