Skip to content

Commit

Permalink
Merge remote-tracking branch 'chang/groupby-last'
Browse files — browse the repository at this point in the history
* chang/groupby-last:
  cython methods for group bins #1809
  BUG: allow non-numeric columns in groupby first/last #1809
  • Loading branch information
wesm committed Sep 20, 2012
2 parents 163cc8a + d0c9957 commit 8fc8d6f
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 27 deletions.
74 changes: 51 additions & 23 deletions pandas/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ class DataError(GroupByError):
class SpecificationError(GroupByError):
pass

def _groupby_function(name, alias, npfunc):
def _groupby_function(name, alias, npfunc, numeric_only=True):
def f(self):
try:
return self._cython_agg_general(alias)
return self._cython_agg_general(alias, numeric_only=numeric_only)
except Exception:
return self.aggregate(lambda x: npfunc(x, axis=self.axis))

Expand Down Expand Up @@ -350,8 +350,9 @@ def size(self):
prod = _groupby_function('prod', 'prod', np.prod)
min = _groupby_function('min', 'min', np.min)
max = _groupby_function('max', 'max', np.max)
first = _groupby_function('first', 'first', _first_compat)
last = _groupby_function('last', 'last', _last_compat)
first = _groupby_function('first', 'first', _first_compat,
numeric_only=False)
last = _groupby_function('last', 'last', _last_compat, numeric_only=False)

def ohlc(self):
"""
Expand All @@ -370,10 +371,11 @@ def picker(arr):
return np.nan
return self.agg(picker)

def _cython_agg_general(self, how):
def _cython_agg_general(self, how, numeric_only=True):
output = {}
for name, obj in self._iterate_slices():
if not issubclass(obj.dtype.type, (np.number, np.bool_)):
is_numeric = issubclass(obj.dtype.type, (np.number, np.bool_))
if numeric_only and not is_numeric:
continue

result, names = self.grouper.aggregate(obj.values, how)
Expand Down Expand Up @@ -668,6 +670,11 @@ def get_group_levels(self):
'last': lib.group_last
}

_cython_object_functions = {
'first' : lambda a, b, c, d: lib.group_nth_object(a, b, c, d, 1),
'last' : lib.group_last_object
}

_cython_transforms = {
'std' : np.sqrt
}
Expand All @@ -681,7 +688,13 @@ def get_group_levels(self):
_filter_empty_groups = True

def aggregate(self, values, how, axis=0):
values = com._ensure_float64(values)
values = com.ensure_float(values)
is_numeric = True

if not issubclass(values.dtype.type, (np.number, np.bool_)):
values = values.astype(object)
is_numeric = False

arity = self._cython_arity.get(how, 1)

vdim = values.ndim
Expand All @@ -698,15 +711,19 @@ def aggregate(self, values, how, axis=0):
out_shape = (self.ngroups,) + values.shape[1:]

# will be filled in Cython function
result = np.empty(out_shape, dtype=np.float64)
result = np.empty(out_shape, dtype=values.dtype)
counts = np.zeros(self.ngroups, dtype=np.int64)

result = self._aggregate(result, counts, values, how)
result = self._aggregate(result, counts, values, how, is_numeric)

if self._filter_empty_groups:
if result.ndim == 2:
result = lib.row_bool_subset(result,
(counts > 0).view(np.uint8))
if is_numeric:
result = lib.row_bool_subset(result,
(counts > 0).view(np.uint8))
else:
result = lib.row_bool_subset_object(result,
(counts > 0).view(np.uint8))
else:
result = result[counts > 0]

Expand All @@ -724,8 +741,11 @@ def aggregate(self, values, how, axis=0):

return result, names

def _aggregate(self, result, counts, values, how):
agg_func = self._cython_functions[how]
def _aggregate(self, result, counts, values, how, is_numeric):
fdict = self._cython_functions
if not is_numeric:
fdict = self._cython_object_functions
agg_func = fdict[how]
trans_func = self._cython_transforms.get(how, lambda x: x)

comp_ids, _, ngroups = self.group_info
Expand Down Expand Up @@ -913,14 +933,22 @@ def names(self):
'last': lib.group_last_bin
}

_cython_object_functions = {
'first' : lambda a, b, c, d: lib.group_nth_bin_object(a, b, c, d, 1),
'last' : lib.group_last_bin_object
}

_name_functions = {
'ohlc' : lambda *args: ['open', 'high', 'low', 'close']
}

_filter_empty_groups = True

def _aggregate(self, result, counts, values, how):
agg_func = self._cython_functions[how]
def _aggregate(self, result, counts, values, how, is_numeric=True):
fdict = self._cython_functions
if not is_numeric:
fdict = self._cython_object_functions
agg_func = fdict[how]
trans_func = self._cython_transforms.get(how, lambda x: x)

if values.ndim > 3:
Expand Down Expand Up @@ -1385,8 +1413,8 @@ def _iterate_slices(self):

yield val, slicer(val)

def _cython_agg_general(self, how):
new_blocks = self._cython_agg_blocks(how)
def _cython_agg_general(self, how, numeric_only=True):
new_blocks = self._cython_agg_blocks(how, numeric_only=numeric_only)
return self._wrap_agged_blocks(new_blocks)

def _wrap_agged_blocks(self, blocks):
Expand All @@ -1408,18 +1436,20 @@ def _wrap_agged_blocks(self, blocks):

_block_agg_axis = 0

def _cython_agg_blocks(self, how):
def _cython_agg_blocks(self, how, numeric_only=True):
data, agg_axis = self._get_data_to_aggregate()

new_blocks = []

for block in data.blocks:
values = block.values
if not issubclass(values.dtype.type, (np.number, np.bool_)):
is_numeric = issubclass(values.dtype.type, (np.number, np.bool_))
if numeric_only and not is_numeric:
continue

values = com._ensure_float64(values)
result, names = self.grouper.aggregate(values, how, axis=agg_axis)
if is_numeric:
values = com.ensure_float(values)
result, _ = self.grouper.aggregate(values, how, axis=agg_axis)
newb = make_block(result, block.items, block.ref_items)
new_blocks.append(newb)

Expand Down Expand Up @@ -2210,5 +2240,3 @@ def complete_dataframe(obj, prev_completions):
install_ipython_completers()
except Exception:
pass


Loading

0 comments on commit 8fc8d6f

Please sign in to comment.