Skip to content

Commit

Permalink
Allow grouping by non-text keys. Closes #205.
Browse files Browse the repository at this point in the history
  • Loading branch information
onyxfish committed Sep 3, 2015
1 parent bc46266 commit 50515c0
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
0.7.0
-----

* Add key_type argument to TableSet and Table.group_by. (#205)
* Nested TableSets and multi-dimensional aggregates. (#204)
* TableSet.aggregate will now use key_name as the group column name. (#203)
* Added key_name argument to TableSet and Table.group_by.
Expand Down
6 changes: 4 additions & 2 deletions agate/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def left_outer_join(self, left_key, table, right_key):

return self._fork(rows, zip(column_names, column_types))

def group_by(self, key, key_name=None):
def group_by(self, key, key_name=None, key_type=None):
"""
Create a new :class:`Table` for each unique value and return them as a
:class:`.TableSet`. The :code:`key` can be either a column name
Expand All @@ -524,6 +524,8 @@ def group_by(self, key, key_name=None):
:param key_name: A name that describes the grouped properties.
Defaults to the column name that was grouped on or "group" if
grouping with a key function. See :class:`.TableSet` for more.
:param key_type: An instance of some subclass of :class:`.ColumnType`. If
not provided it will default to a :class:`.TextType`.
:returns: A :class:`.TableSet` mapping where the keys are unique
values from the :code:`key` and the values are new :class:`Table`
instances containing the grouped rows.
Expand Down Expand Up @@ -561,7 +563,7 @@ def group_by(self, key, key_name=None):
for group, rows in groups.items():
output[group] = self._fork(rows)

return TableSet(output, key_name=key_name)
return TableSet(output, key_name=key_name, key_type=key_type)

def compute(self, computations):
"""
Expand Down
13 changes: 8 additions & 5 deletions agate/tableset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ordereddict import OrderedDict

from agate.aggregations import Aggregation
from agate.column_types import TextType, NumberType
from agate.column_types import *
from agate.exceptions import ColumnDoesNotExistError
from agate.rows import RowSequence

Expand Down Expand Up @@ -65,12 +65,15 @@ class TableSet(Mapping):
values.
:param group: A dictionary mapping group keys to :class:`Table` values.
:param group_name: A name that describes the grouping properties. Used as
:param key_name: A name that describes the grouping properties. Used as
the column header when the groups are aggregated. Defaults to the
column name that was grouped on.
:param key_type: An instance of some subclass of :class:`.ColumnType`. If not
provided it will default to a :class:`.TextType`.
"""
def __init__(self, group, key_name='group'):
def __init__(self, group, key_name='group', key_type=None):
self._key_name = key_name
self._key_type = key_type or TextType()

# Note: list call is a workaround for Python 3 "ValuesView"
self._sample_table = list(group.values())[0]
Expand Down Expand Up @@ -206,11 +209,11 @@ def _aggregate(self, aggregations=[]):
output.append(row)

column_names.insert(0, self._key_name)
column_types.insert(0, TextType())
column_types.insert(0, self._key_type)
# Regular Tables
else:
column_names = [self._key_name]
column_types = [TextType()]
column_types = [self._key_type]

for column_name, aggregation, new_column_name in aggregations:
c = self._sample_table.columns[column_name]
Expand Down
24 changes: 23 additions & 1 deletion tests/test_tableset.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_compute(self):
self.assertSequenceEqual(new_table._column_types, (self.number_type,))
self.assertSequenceEqual(new_table._column_names, ('number',))

def test_aggregate_grouper_name(self):
def test_aggregate_key_name(self):
tableset = TableSet(self.tables, key_name='test')

new_table = tableset.aggregate([
Expand All @@ -157,6 +157,28 @@ def test_aggregate_grouper_name(self):
self.assertEqual(len(new_table.rows), 3)
self.assertEqual(len(new_table.columns), 2)
self.assertSequenceEqual(new_table._column_names, ('test', 'count'))
self.assertIsInstance(new_table._column_types[0], TextType)
self.assertIsInstance(new_table._column_types[1], NumberType)

def test_aggregate_key_type(self):
    """
    Aggregating a :class:`TableSet` created with an explicit
    :code:`key_type` should give the key column that type.
    """
    # Integer keys, inserted in order, paired with a number key type.
    keyed_tables = OrderedDict()
    keyed_tables[1] = Table(self.table1, self.columns)
    keyed_tables[2] = Table(self.table2, self.columns)
    keyed_tables[3] = Table(self.table3, self.columns)

    grouped = TableSet(
        keyed_tables,
        key_name='test',
        key_type=self.number_type
    )

    aggregations = [('number', Length(), 'count')]
    result = grouped.aggregate(aggregations)

    # One output row per group; the key column plus the one aggregate.
    self.assertIsInstance(result, Table)
    self.assertEqual(len(result.rows), 3)
    self.assertEqual(len(result.columns), 2)
    self.assertSequenceEqual(result._column_names, ('test', 'count'))

    # The key column carries the supplied key type.
    self.assertIsInstance(result._column_types[0], NumberType)
    self.assertIsInstance(result._column_types[1], NumberType)

def test_aggregate_sum(self):
tableset = TableSet(self.tables)
Expand Down

0 comments on commit 50515c0

Please sign in to comment.