Skip to content

Commit

Permalink
Add support for full outer joins. wireservice/csvkit#711.
Browse files Browse the repository at this point in the history
  • Loading branch information
onyxfish committed Dec 26, 2016
1 parent 0c09a6c commit 28912b4
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
1.5.3
-----

* :meth:`.Table.join` now supports full outer joins via the ``full_outer`` keyword.
* :meth:`.Table.join` can now accept column indices instead of column names.
* :meth:`.Table.from_csv` now buffers input files to prevent issues with using STDIN as an input.

Expand Down
43 changes: 31 additions & 12 deletions agate/table/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from agate import utils


def join(self, right_table, left_key, right_key=None, inner=False, require_match=False, columns=None):
def join(self, right_table, left_key, right_key=None, inner=False, full_outer=False, require_match=False, columns=None):
"""
Create a new table by joining two tables on common values.
Expand Down Expand Up @@ -38,16 +38,22 @@ def join(self, right_table, left_key, right_key=None, inner=False, require_match
Perform a SQL-style "inner join" instead of a left outer join. Rows
which have no match for :code:`left_key` will not be included in
the output table.
:param full_outer:
Perform a SQL-style "full outer" join rather than a left outer join:
rows from either table that have no match in the other are kept and
padded with nulls. May not be used in combination with :code:`inner`.
:param require_match:
If true, an exception will be raised if there is a left_key with no
matching right_key.
:param columns:
A sequence of column names from :code:`right_table` to include in
the final output table. Defaults to all columns not in
:code:`right_key`.
:code:`right_key`. Ignored when :code:`full_outer` is :code:`True`.
:returns:
A new :class:`.Table`.
"""
if inner and full_outer:
raise ValueError('A join can not be both "inner" and "full_outer".')

if right_key is None:
right_key = left_key

Expand Down Expand Up @@ -92,11 +98,12 @@ def join(self, right_table, left_key, right_key=None, inner=False, require_match
for i, column in enumerate(right_table._columns):
name = column.name

if columns is None and i in right_key_indices:
continue
if not full_outer:
if columns is None and i in right_key_indices:
continue

if columns is not None and name not in columns:
continue
if columns is not None and name not in columns:
continue

if name in self.column_names:
column_names.append('%s2' % name)
Expand All @@ -105,7 +112,7 @@ def join(self, right_table, left_key, right_key=None, inner=False, require_match

column_types.append(column.data_type)

if columns is not None:
if columns is not None and not full_outer:
right_table = right_table.select([n for n in right_table._column_names if n in columns])

right_hash = {}
Expand All @@ -119,7 +126,7 @@ def join(self, right_table, left_key, right_key=None, inner=False, require_match
# Collect new rows
rows = []

if self._row_names is not None:
if self._row_names is not None and not full_outer:
row_names = []
else:
row_names = None
Expand All @@ -137,28 +144,40 @@ def join(self, right_table, left_key, right_key=None, inner=False, require_match
new_row = list(self._rows[left_index])

for k, v in enumerate(right_row):
if columns is None and k in right_key_indices:
if columns is None and k in right_key_indices and not full_outer:
continue

new_row.append(v)

rows.append(Row(new_row, column_names))

if self._row_names is not None:
if self._row_names is not None and not full_outer:
row_names.append(self._row_names[left_index])
# Rows without matches
elif not inner:
new_row = list(self._rows[left_index])

for k, v in enumerate(right_table._column_names):
if columns is None and k in right_key_indices:
if columns is None and k in right_key_indices and not full_outer:
continue

new_row.append(None)

rows.append(Row(new_row, column_names))

if self._row_names is not None:
if self._row_names is not None and not full_outer:
row_names.append(self._row_names[left_index])

# Full outer join
if full_outer:
left_set = set(left_data)

for right_index, right_value in enumerate(right_data):
if right_value in left_set:
continue

new_row = ([None] * len(self._columns)) + list(right_table.rows[right_index])

rows.append(Row(new_row, column_names))

return self._fork(rows, column_names, column_types, row_names=row_names)
29 changes: 29 additions & 0 deletions tests/test_table/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,32 @@ def test_join_rows_are_tuples(self):
new_table = self.left.join(self.right, 'one', 'four', columns=['four', 'six'])

self.assertIsInstance(new_table.rows[0].values(), tuple)

def test_full_outer(self):
    """
    A full outer join keeps every row from both tables, filling the side
    that has no match with nulls.
    """
    left = Table(
        (
            (1, 4, 'a'),
            (2, 3, 'b'),
            (3, 2, 'c')
        ),
        self.left_column_names,
        self.column_types
    )
    right = Table(
        (
            (1, 4, 'a'),
            (2, 3, 'b'),
            (4, 2, 'c')
        ),
        self.right_column_names,
        self.column_types
    )

    joined = left.join(right, 'one', 'four', full_outer=True)

    # The join must build a fresh table rather than hand back an input.
    for source in (left, right):
        self.assertIsNot(joined, source)

    # Both key columns ('one' and 'four') appear in the output.
    self.assertColumnNames(joined, ['one', 'two', 'three', 'four', 'five', 'six'])
    self.assertColumnTypes(joined, [Number, Number, Text, Number, Number, Text])

    # Matched rows first, then the unmatched left row, then the unmatched
    # right row.
    expected_rows = [
        (1, 4, 'a', 1, 4, 'a'),
        (2, 3, 'b', 2, 3, 'b'),
        (3, 2, 'c', None, None, None),
        (None, None, None, 4, 2, 'c')
    ]
    self.assertRows(joined, expected_rows)

0 comments on commit 28912b4

Please sign in to comment.