Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

Already on GitHub? Sign in to your account

ENH: Adding argmin, argmax to Series and DataFrame. #286

Closed
wants to merge 1 commit into
from
Jump to file or symbol
Failed to load files and symbols.
+195 −1
Split
View
@@ -2488,6 +2488,30 @@ def min(self, axis=0, skipna=True):
np.putmask(values, -np.isfinite(values), np.inf)
return Series(values.min(axis), index=self._get_agg_axis(axis))
+ def argmin(self, axis=0, skipna=True):
+ """
+ Return index of first occurrence of minimum over requested axis.
+ NA/null values are excluded.
+
+ Parameters
+ ----------
+ axis : {0, 1}
+ 0 for row-wise, 1 for column-wise
+ skipna : boolean, default True
+ Exclude NA/null values. If an entire row/column is NA, the result
+ will be first index.
+
+ Returns
+ -------
+ argmin : Series
+ """
+ values = self.values.copy()
+ if skipna and not issubclass(values.dtype.type, np.integer):
+ np.putmask(values, -np.isfinite(values), np.inf)
+ argmin_index = self._get_agg_axis([1, 0][axis])
+ return Series([argmin_index[i] for i in values.argmin(axis)],
+ index=self._get_agg_axis(axis))
+
def max(self, axis=0, skipna=True):
"""
Return maximum over requested axis. NA/null values are excluded
@@ -2498,7 +2522,7 @@ def max(self, axis=0, skipna=True):
0 for row-wise, 1 for column-wise
skipna : boolean, default True
Exclude NA/null values. If an entire row/column is NA, the result
- will be NA
+ will be first index.
Returns
-------
@@ -2509,6 +2533,30 @@ def max(self, axis=0, skipna=True):
np.putmask(values, -np.isfinite(values), -np.inf)
return Series(values.max(axis), index=self._get_agg_axis(axis))
+ def argmax(self, axis=0, skipna=True):
+ """
+ Return index of first occurrence of maximum over requested axis.
+ NA/null values are excluded.
+
+ Parameters
+ ----------
+ axis : {0, 1}
+ 0 for row-wise, 1 for column-wise
+ skipna : boolean, default True
+ Exclude NA/null values. If an entire row/column is NA, the result
+ will be first index.
+
+ Returns
+ -------
+ argmax : Series
+ """
+ values = self.values.copy()
+ if skipna and not issubclass(values.dtype.type, np.integer):
+ np.putmask(values, -np.isfinite(values), -np.inf)
+ argmax_index = self._get_agg_axis([1, 0][axis])
+ return Series([argmax_index[i] for i in values.argmax(axis)],
+ index=self._get_agg_axis(axis))
+
def prod(self, axis=0, skipna=True):
"""
Return product over requested axis. NA/null values are treated as 1
View
@@ -718,6 +718,25 @@ def min(self, axis=None, out=None, skipna=True):
np.putmask(arr, isnull(arr), np.inf)
return arr.min()
+ def argmin(self, axis=None, out=None, skipna=True):
+ """
+ Index of first occurrence of minimum of values.
+
+ Parameters
+ ----------
+ skipna : boolean, default True
+ Exclude NA/null values
+
+ Returns
+ -------
+ Index of minimum of values.
+ """
+ arr = self.values.copy()
+ if skipna:
+ if not issubclass(arr.dtype.type, np.integer):
+ np.putmask(arr, isnull(arr), np.inf)
+ return self.index[arr.argmin()]
+
def max(self, axis=None, out=None, skipna=True):
"""
Maximum of values
@@ -737,6 +756,25 @@ def max(self, axis=None, out=None, skipna=True):
np.putmask(arr, isnull(arr), -np.inf)
return arr.max()
+ def argmax(self, axis=None, out=None, skipna=True):
+ """
+ Index of first occurrence of maximum of values.
+
+ Parameters
+ ----------
+ skipna : boolean, default True
+ Exclude NA/null values
+
+ Returns
+ -------
+ Index of maximum of values.
+ """
+ arr = self.values.copy()
+ if skipna:
+ if not issubclass(arr.dtype.type, np.integer):
+ np.putmask(arr, isnull(arr), -np.inf)
+ return self.index[arr.argmax()]
+
def std(self, axis=None, dtype=None, out=None, ddof=1, skipna=True):
"""
Unbiased standard deviation of values
View
@@ -2729,10 +2729,74 @@ def test_min(self):
self._check_stat_op('min', np.min)
self._check_stat_op('min', np.min, frame=self.intframe)
+ def test_argmin(self):
+ def validate(f, s, axis, skipna):
+ def get_result(f, i, v, axis, skipna):
+ if axis == 0:
+ return (f[i][v], f[i].min(skipna=skipna))
+ else:
+ return (f[v][i], f.ix[i].min(skipna=skipna))
+ for i, v in s.iteritems():
+ (r1, r2) = get_result(f, i, v, axis, skipna)
+ if np.isnan(r1) or np.isinf(r1):
+ self.assert_(np.isnan(r2) or np.isinf(r2))
+ elif np.isnan(r2) or np.isinf(r2):
+ self.assert_(np.isnan(r1) or np.isinf(r1))
+ else:
+ self.assertEqual(r1, r2)
+
+ frame = self.frame
+ frame.ix[5:10] = np.nan
+ frame.ix[15:20, -2:] = np.nan
+ for skipna in [True, False]:
+ for axis in [0, 1]:
+ validate(frame,
+ frame.argmin(axis=axis, skipna=skipna),
+ axis,
+ skipna)
+ validate(self.intframe,
+ self.intframe.argmin(axis=axis, skipna=skipna),
+ axis,
+ skipna)
+
+ self.assertRaises(Exception, frame.argmin, axis=2)
+
def test_max(self):
self._check_stat_op('max', np.max)
self._check_stat_op('max', np.max, frame=self.intframe)
+ def test_argmax(self):
+ def validate(f, s, axis, skipna):
+ def get_result(f, i, v, axis, skipna):
+ if axis == 0:
+ return (f[i][v], f[i].max(skipna=skipna))
+ else:
+ return (f[v][i], f.ix[i].max(skipna=skipna))
+ for i, v in s.iteritems():
+ (r1, r2) = get_result(f, i, v, axis, skipna)
+ if np.isnan(r1) or np.isinf(r1):
+ self.assert_(np.isnan(r2) or np.isinf(r2))
+ elif np.isnan(r2) or np.isinf(r2):
+ self.assert_(np.isnan(r1) or np.isinf(r1))
+ else:
+ self.assertEqual(r1, r2)
+
+ frame = self.frame
+ frame.ix[5:10] = np.nan
+ frame.ix[15:20, -2:] = np.nan
+ for skipna in [True, False]:
+ for axis in [0, 1]:
+ validate(frame,
+ frame.argmax(axis=axis, skipna=skipna),
+ axis,
+ skipna)
+ validate(self.intframe,
+ self.intframe.argmax(axis=axis, skipna=skipna),
+ axis,
+ skipna)
+
+ self.assertRaises(Exception, frame.argmax, axis=2)
+
def test_mad(self):
f = lambda x: np.abs(x - x.mean()).mean()
self._check_stat_op('mad', f)
@@ -492,9 +492,53 @@ def test_prod(self):
def test_min(self):
self._check_stat_op('min', np.min)
+ def test_argmin(self):
+ """
+ test argmin
+ _check_stat_op approach cannot be used here because of isnull check.
+ """
+ # add some NaNs
+ self.series[5:15] = np.NaN
+
+ # skipna or not
+ self.assertEqual(self.series[self.series.argmin()], self.series.min())
+ self.assert_(isnull(self.series[self.series.argmin(skipna=False)]))
+
+ # no NaNs
+ nona = self.series.dropna()
+ self.assertEqual(nona[nona.argmin()], nona.min())
+ self.assertEqual(nona.index.values.tolist().index(nona.argmin()),
+ nona.values.argmin())
+
+ # all NaNs
+ allna = self.series * nan
+ self.assertEqual(allna.argmin(), allna.index[0])
+
def test_max(self):
self._check_stat_op('max', np.max)
+ def test_argmax(self):
+ """
+ test argmax
+ _check_stat_op approach cannot be used here because of isnull check.
+ """
+ # add some NaNs
+ self.series[5:15] = np.NaN
+
+ # skipna or not
+ self.assertEqual(self.series[self.series.argmax()], self.series.max())
+ self.assert_(isnull(self.series[self.series.argmax(skipna=False)]))
+
+ # no NaNs
+ nona = self.series.dropna()
+ self.assertEqual(nona[nona.argmax()], nona.max())
+ self.assertEqual(nona.index.values.tolist().index(nona.argmax()),
+ nona.values.argmax())
+
+ # all NaNs
+ allna = self.series * nan
+ self.assertEqual(allna.argmax(), allna.index[0])
+
def test_std(self):
alt = lambda x: np.std(x, ddof=1)
self._check_stat_op('std', alt)