Skip to content

Commit

Permalink
Refactor Any and All behavior. Closes #636.
Browse files Browse the repository at this point in the history
  • Loading branch information
onyxfish committed Nov 14, 2016
1 parent 8c38c43 commit 8ee59ee
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 38 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
1.5.0
-----

* :class:`.Any` and :class:`.All` aggregations no longer behave differently for boolean data. (#636)
* :class:`.Any` and :class:`.All` aggregations now accept a single value as a test argument, in addition to a function.
* :class:`.Any` and :class:`.All` aggregations now require a test argument.
* Tables rendered by :meth:`.Table.print_table` are now Github Friendly Markdown (GFM) compatible. (#626)
* The agate tutorial has been converted to a Jupyter Notebook.
* :class:`.Table` now supports ``len`` as a proxy for ``len(table.rows)``.
Expand Down
20 changes: 9 additions & 11 deletions agate/aggregations/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,27 @@ class All(Aggregation):
"""
Check if all values in a column pass a test.
The test may be omitted when checking :class:`.Boolean` data.
:param column_name:
The name of the column to check.
:param test:
A function that takes a value and returns `True` or `False`.
Either a single value that all values in the column are compared against
(for equality) or a function that takes a column value and returns
`True` or `False`.
"""
def __init__(self, column_name, test=None):
def __init__(self, column_name, test):
self._column_name = column_name
self._test = test

if callable(test):
self._test = test
else:
self._test = lambda d: d == test

def get_aggregate_data_type(self, table):
return Boolean()

def validate(self, table):
column = table.columns[self._column_name]

if not isinstance(column.data_type, Boolean) and not self._test:
raise ValueError('You must supply a test function for columns containing non-Boolean data.')

def run(self, table):
"""
:returns:
Expand All @@ -36,7 +37,4 @@ def run(self, table):
column = table.columns[self._column_name]
data = column.values()

if isinstance(column.data_type, Boolean):
return all(data)

return all(self._test(d) for d in data)
20 changes: 9 additions & 11 deletions agate/aggregations/any.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,29 @@ class Any(Aggregation):
"""
Check if any value in a column passes a test.
The test may be omitted when checking :class:`.Boolean` data.
:param column_name:
The name of the column to check.
:param test:
A function that takes a value and returns `True` or `False`.
Either a single value that all values in the column are compared against
(for equality) or a function that takes a column value and returns
`True` or `False`.
"""
def __init__(self, column_name, test=None):
def __init__(self, column_name, test):
self._column_name = column_name
self._test = test

if callable(test):
self._test = test
else:
self._test = lambda d: d == test

def get_aggregate_data_type(self, table):
return Boolean()

def validate(self, table):
column = table.columns[self._column_name]

if not isinstance(column.data_type, Boolean) and not self._test:
raise ValueError('You must supply a test function for columns containing non-Boolean data.')

def run(self, table):
column = table.columns[self._column_name]
data = column.values()

if isinstance(column.data_type, Boolean) and self._test is None:
return any(data)

return any(self._test(d) for d in data)
34 changes: 18 additions & 16 deletions tests/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,20 @@ def test_has_nulls(self):
self.assertEqual(has_nulls.run(self.table), True)

def test_any(self):
with self.assertRaises(ValueError):
Any('one').validate(self.table)

Any('one', lambda d: d).validate(self.table)

self.assertIsInstance(Any('one').get_aggregate_data_type(None), Boolean)
self.assertIsInstance(Any('one', 2).get_aggregate_data_type(None), Boolean)

self.assertEqual(Any('one', 2).run(self.table), True)
self.assertEqual(Any('one', 5).run(self.table), False)

self.assertEqual(Any('one', lambda d: d == 2).run(self.table), True)
self.assertEqual(Any('one', lambda d: d == 5).run(self.table), False)

def test_all(self):
with self.assertRaises(ValueError):
All('one').validate(self.table)

All('one', lambda d: d).validate(self.table)

self.assertIsInstance(All('one').get_aggregate_data_type(None), Boolean)
self.assertIsInstance(All('one', 5).get_aggregate_data_type(None), Boolean)
self.assertEqual(All('one', lambda d: d != 5).run(self.table), True)
self.assertEqual(All('one', lambda d: d == 2).run(self.table), False)

Expand Down Expand Up @@ -137,8 +134,8 @@ def test_any(self):
]

table = Table(rows, ['test'], [Boolean()])
Any('test').validate(table)
self.assertEqual(Any('test').run(table), True)
Any('test', True).validate(table)
self.assertEqual(Any('test', True).run(table), True)

rows = [
[False],
Expand All @@ -147,8 +144,10 @@ def test_any(self):
]

table = Table(rows, ['test'], [Boolean()])
Any('test').validate(table)
self.assertEqual(Any('test').run(table), False)
Any('test', True).validate(table)
self.assertEqual(Any('test', True).run(table), False)
self.assertEqual(Any('test', lambda r: r).run(table), False)
self.assertEqual(Any('test', False).run(table), True)
self.assertEqual(Any('test', lambda r: not r).run(table), True)

def test_all(self):
Expand All @@ -159,8 +158,8 @@ def test_all(self):
]

table = Table(rows, ['test'], [Boolean()])
All('test').validate(table)
self.assertEqual(All('test').run(table), False)
All('test', True).validate(table)
self.assertEqual(All('test', True).run(table), False)

rows = [
[True],
Expand All @@ -169,8 +168,11 @@ def test_all(self):
]

table = Table(rows, ['test'], [Boolean()])
All('test').validate(table)
self.assertEqual(All('test').run(table), True)
All('test', True).validate(table)
self.assertEqual(All('test', True).run(table), True)
self.assertEqual(All('test', lambda r: r).run(table), True)
self.assertEqual(All('test', False).run(table), False)
self.assertEqual(All('test', lambda r: not r).run(table), False)


class TestDateTimeAggregation(unittest.TestCase):
Expand Down

0 comments on commit 8ee59ee

Please sign in to comment.