diff --git a/RELEASE.rst b/RELEASE.rst index 610dc390898df..7faaf84f55c72 100644 --- a/RELEASE.rst +++ b/RELEASE.rst @@ -112,7 +112,7 @@ pandas 0.8.0 - Enable storage of sparse data structures in HDFStore (#85) - Enable Series.asof to work with arrays of timestamp inputs - Cython implementation of DataFrame.corr speeds up by > 100x (#1349, #1354) - + - Exclude "nuisance" columns automatically in GroupBy.transform (#1364) **API Changes** diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index c83d5f7831fca..7f8791269cc49 100644 --- a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -1626,14 +1626,17 @@ def transform(self, func, *args, **kwargs): obj = self._obj_with_exclusions gen = self.grouper.get_iterator(obj, axis=self.axis) + wrapper = lambda x: func(x, *args, **kwargs) + for name, group in gen: object.__setattr__(group, 'name', name) try: - wrapper = lambda x: func(x, *args, **kwargs) res = group.apply(wrapper, axis=self.axis) + except TypeError: + return self._transform_item_by_item(obj, wrapper) except Exception: # pragma: no cover - res = func(group, *args, **kwargs) + res = wrapper(group) # broadcasting if isinstance(res, Series): @@ -1651,6 +1654,25 @@ def transform(self, func, *args, **kwargs): axis=self.axis, verify_integrity=False) return concatenated.reindex_like(obj) + def _transform_item_by_item(self, obj, wrapper): + # iterate through columns + output = {} + inds = [] + for i, col in enumerate(obj): + try: + output[col] = self[col].transform(wrapper) + inds.append(i) + except Exception: + pass + + if len(output) == 0: + raise TypeError('Transform function invalid for data types') + + columns = obj.columns + if len(output) < len(obj.columns): + columns = columns.take(inds) + + return DataFrame(output, index=obj.index, columns=columns) class DataFrameGroupBy(NDFrameGroupBy): diff --git a/pandas/tests/test_groupby.py b/pandas/tests/test_groupby.py index 340ccfadb61b1..c2e24e7e90c53 100644 --- a/pandas/tests/test_groupby.py +++ b/pandas/tests/test_groupby.py @@ -419,6 +419,16 @@ def test_transform_select_columns(self): assert_frame_equal(result, expected) + def test_transform_exclude_nuisance(self): + expected = {} + grouped = self.df.groupby('A') + expected['C'] = grouped['C'].transform(np.mean) + expected['D'] = grouped['D'].transform(np.mean) + expected = DataFrame(expected) + + result = self.df.groupby('A').transform(np.mean) + + assert_frame_equal(result, expected) def test_with_na(self): index = Index(np.arange(10))