Skip to content

Commit

Permalink
ENH: exclude nuisance columns in GroupBy.transform, close #1364
Browse files Browse the repository at this point in the history
  • Loading branch information
wesm committed Jun 2, 2012
1 parent 0d3c3ec commit f62f571
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
2 changes: 1 addition & 1 deletion RELEASE.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ pandas 0.8.0
- Enable storage of sparse data structures in HDFStore (#85)
- Enable Series.asof to work with arrays of timestamp inputs
- Cython implementation of DataFrame.corr speeds up by > 100x (#1349, #1354)

- Exclude "nuisance" columns automatically in GroupBy.transform (#1364)

**API Changes**

Expand Down
26 changes: 24 additions & 2 deletions pandas/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1626,14 +1626,17 @@ def transform(self, func, *args, **kwargs):
obj = self._obj_with_exclusions
gen = self.grouper.get_iterator(obj, axis=self.axis)

wrapper = lambda x: func(x, *args, **kwargs)

for name, group in gen:
object.__setattr__(group, 'name', name)

try:
wrapper = lambda x: func(x, *args, **kwargs)
res = group.apply(wrapper, axis=self.axis)
except TypeError:
return self._transform_item_by_item(obj, wrapper)
except Exception: # pragma: no cover
res = func(group, *args, **kwargs)
res = wrapper(group)

# broadcasting
if isinstance(res, Series):
Expand All @@ -1651,6 +1654,25 @@ def transform(self, func, *args, **kwargs):
axis=self.axis, verify_integrity=False)
return concatenated.reindex_like(obj)

def _transform_item_by_item(self, obj, wrapper):
# iterate through columns
output = {}
inds = []
for i, col in enumerate(obj):
try:
output[col] = self[col].transform(wrapper)
inds.append(i)
except Exception:
pass

if len(output) == 0:
raise TypeError('Transform function invalid for data types')

columns = obj.columns
if len(output) < len(obj.columns):
columns = columns.take(inds)

return DataFrame(output, index=obj.index, columns=columns)


class DataFrameGroupBy(NDFrameGroupBy):
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,16 @@ def test_transform_select_columns(self):

assert_frame_equal(result, expected)

def test_transform_exclude_nuisance(self):
expected = {}
grouped = self.df.groupby('A')
expected['C'] = grouped['C'].transform(np.mean)
expected['D'] = grouped['D'].transform(np.mean)
expected = DataFrame(expected)

result = self.df.groupby('A').transform(np.mean)

assert_frame_equal(result, expected)

def test_with_na(self):
index = Index(np.arange(10))
Expand Down

0 comments on commit f62f571

Please sign in to comment.