Skip to content

Commit

Permalink
Merge 09825e3 into 69f4b45
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaggard committed Jan 31, 2021
2 parents 69f4b45 + 09825e3 commit d5d432a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
7 changes: 7 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
Changes
=======

Version 1.7.2
-------------

* Allow specifying output field name for simple aggregation
By :user:`bmaggard`, :issue:`370`.


Version 1.7.1
-------------

Expand Down
7 changes: 7 additions & 0 deletions petl/test/transform/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ def test_aggregate_simple():
ieq(expect4, table4)
ieq(expect4, table4)

table5 = aggregate(table1, 'foo', len, field='nrows')
expect5 = (('foo', 'nrows'),
('a', 2),
('b', 3),
('c', 1))
ieq(expect5, table5)
ieq(expect5, table5)

def test_aggregate_multifield():

Expand Down
18 changes: 10 additions & 8 deletions petl/transform/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def iterrowreduce(source, key, reducer, header):


def aggregate(table, key, aggregation=None, value=None, presorted=False,
buffersize=None, tempdir=None, cache=True):
buffersize=None, tempdir=None, cache=True, field='value'):
"""Group rows under the given key then apply aggregation functions.
E.g.::
Expand Down Expand Up @@ -183,7 +183,7 @@ def aggregate(table, key, aggregation=None, value=None, presorted=False,
return SimpleAggregateView(table, key, aggregation=aggregation,
value=value, presorted=presorted,
buffersize=buffersize, tempdir=tempdir,
cache=cache)
cache=cache, field=field)
elif aggregation is None or isinstance(aggregation, (list, tuple, dict)):
# ignore value arg
return MultiAggregateView(table, key, aggregation=aggregation,
Expand All @@ -200,7 +200,8 @@ def aggregate(table, key, aggregation=None, value=None, presorted=False,
class SimpleAggregateView(Table):

def __init__(self, table, key, aggregation=list, value=None,
presorted=False, buffersize=None, tempdir=None, cache=True):
presorted=False, buffersize=None, tempdir=None,
cache=True, field='value'):
if presorted:
self.table = table
else:
Expand All @@ -209,25 +210,26 @@ def __init__(self, table, key, aggregation=list, value=None,
self.key = key
self.aggregation = aggregation
self.value = value
self.field = field

def __iter__(self):
return itersimpleaggregate(self.table, self.key, self.aggregation,
self.value)
self.value, self.field)


def itersimpleaggregate(table, key, aggregation, value):
def itersimpleaggregate(table, key, aggregation, value, field):

# special case counting
if aggregation == len:
aggregation = lambda g: sum(1 for _ in g) # count length of iterable

# determine output header
if isinstance(key, (list, tuple)):
outhdr = tuple(key) + ('value',)
outhdr = tuple(key) + (field,)
elif callable(key):
outhdr = ('key', 'value')
outhdr = ('key', field)
else:
outhdr = (key, 'value')
outhdr = (key, field)
yield outhdr

# generate data
Expand Down

0 comments on commit d5d432a

Please sign in to comment.