Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 75 additions & 22 deletions toolz/itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,13 +809,74 @@ def getter(index):
return operator.itemgetter(index)


def _inner_join(leftkey, leftseq, rightkey, rightseq):
d = groupby(leftkey, leftseq)
for item in rightseq:
key = rightkey(item)
if key in d:
for left_match in d[key]:
yield (left_match, item)


def _right_join(leftkey, leftseq, rightkey, rightseq,
left_default=no_default):
d = groupby(leftkey, leftseq)
for item in rightseq:
key = rightkey(item)
if key in d:
for left_match in d[key]:
yield (left_match, item)
else:
yield (left_default, item)


def _left_join(leftkey, leftseq, rightkey, rightseq,
right_default=no_default):
d = groupby(leftkey, leftseq)
seen_keys = set()
for item in rightseq:
key = rightkey(item)
seen_keys.add(key)
if key in d:
for left_match in d[key]:
yield(left_match, item)

for key, matches in iteritems(d):
if key not in seen_keys:
for match in matches:
yield (match, right_default)


def _full_join(leftkey, leftseq, rightkey, rightseq,
left_default=no_default, right_default=no_default):
d = groupby(leftkey, leftseq)
seen_keys = set()
for item in rightseq:
key = rightkey(item)
seen_keys.add(key)
if key in d:
for left_match in d[key]:
yield (left_match, item)
else:
yield (left_default, item)

for key, matches in iteritems(d):
if key not in seen_keys:
for match in matches:
yield (match, right_default)


def join(leftkey, leftseq, rightkey, rightseq,
left_default=no_default, right_default=no_default):
""" Join two sequences on common attributes

This is a semi-streaming operation. The LEFT sequence is fully evaluated
and placed into memory. The RIGHT sequence is evaluated lazily and so can
be arbitrarily large.
and placed into memory. The RIGHT sequence is evaluated lazily and unless
right_default is defined, it can be arbitrarily large. If right_default is
defined, the unique keys of rightseq will be placed into memory.
The join is implemented as a hash join and the keys of leftseq must be
hashable. Additionally, if right_default is defined, then keys of rightseq
must also be hashable.

>>> friends = [('Alice', 'Edith'),
... ('Alice', 'Zhao'),
Expand Down Expand Up @@ -868,26 +929,18 @@ def join(leftkey, leftseq, rightkey, rightseq,
if not callable(rightkey):
rightkey = getter(rightkey)

d = groupby(leftkey, leftseq)
seen_keys = set()

left_default_is_no_default = (left_default == no_default)
for item in rightseq:
key = rightkey(item)
seen_keys.add(key)
try:
left_matches = d[key]
for match in left_matches:
yield (match, item)
except KeyError:
if not left_default_is_no_default:
yield (left_default, item)

if right_default != no_default:
for key, matches in d.items():
if key not in seen_keys:
for match in matches:
yield (match, right_default)
if (left_default == no_default) and (right_default == no_default):
return _inner_join(leftkey, leftseq, rightkey, rightseq)
elif (left_default != no_default) and (right_default == no_default):
return _right_join(leftkey, leftseq, rightkey, rightseq,
left_default=left_default)
elif (left_default == no_default) and (right_default != no_default):
return _left_join(leftkey, leftseq, rightkey, rightseq,
right_default=right_default)
else:
return _full_join(leftkey, leftseq, rightkey, rightseq,
left_default=left_default,
right_default=right_default)


def diff(*seqs, **kwargs):
Expand Down