Skip to content

Commit

Permalink
BUG: pass index name in GroupBy.apply, GH #416
Browse files Browse the repository at this point in the history
  • Loading branch information
wesm committed Nov 25, 2011
1 parent 0505033 commit 6418067
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
1 change: 1 addition & 0 deletions RELEASE.rst
Expand Up @@ -76,6 +76,7 @@ pandas 0.6.0
- MaskedArray can be passed to DataFrame constructor and masked values will be
converted to NaN (PR #396)
- Add `DataFrame.boxplot` function (GH #368, others)
- Can pass extra args, kwds to DataFrame.apply (GH #376)
**Improvements to existing features**

Expand Down
18 changes: 11 additions & 7 deletions pandas/core/groupby.py
Expand Up @@ -756,12 +756,18 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):

key_names = [ping.name for ping in self.groupings]

def _get_index():
if len(self.groupings) > 1:
index = MultiIndex.from_tuples(keys, names=key_names)
else:
index = Index(keys, name=key_names[0])
return index

if isinstance(values[0], Series):
if not_indexed_same:
data_dict = dict(zip(keys, values))
result = DataFrame(data_dict).T
if len(self.groupings) > 1:
result.index = MultiIndex.from_tuples(keys, names=key_names)
result.index = _get_index()
return result
else:
cat_values = np.concatenate([x.values for x in values])
Expand All @@ -774,11 +780,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
return self._wrap_frames(keys, values,
not_indexed_same=not_indexed_same)
else:
if len(self.groupings) > 1:
index = MultiIndex.from_tuples(keys, names=key_names)
return Series(values, index)
else:
return Series(values, keys)
return Series(values, index=_get_index())

def _aggregate_multiple_funcs(self, arg):
if not isinstance(arg, dict):
Expand Down Expand Up @@ -1071,6 +1073,8 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
else:
if len(self.groupings) > 1:
keys = MultiIndex.from_tuples(keys, names=key_names)
else:
keys = Index(keys, name=key_names[0])

if isinstance(values[0], np.ndarray):
if self.axis == 0:
Expand Down
14 changes: 14 additions & 0 deletions pandas/tests/test_groupby.py
Expand Up @@ -732,6 +732,9 @@ def test_groupby_level(self):
expected0 = frame.groupby(deleveled['first']).sum()
expected1 = frame.groupby(deleveled['second']).sum()

self.assert_(result0.index.name == 'first')
self.assert_(result1.index.name == 'second')

assert_frame_equal(result0, expected0)
assert_frame_equal(result1, expected1)
self.assertEquals(result0.index.name, frame.index.names[0])
Expand All @@ -753,6 +756,17 @@ def test_groupby_level(self):
# raise exception for non-MultiIndex
self.assertRaises(ValueError, self.df.groupby, level=0)

def test_groupby_level_apply(self):
frame = self.mframe

result = frame.groupby(level=0).count()
self.assert_(result.index.name == 'first')
result = frame.groupby(level=1).count()
self.assert_(result.index.name == 'second')

result = frame['A'].groupby(level=0).count()
self.assert_(result.index.name == 'first')

def test_groupby_level_mapper(self):
frame = self.mframe
deleveled = frame.delevel()
Expand Down

0 comments on commit 6418067

Please sign in to comment.