Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

Already on GitHub? Sign in to your account

ENH: DataFrame where and mask #2109 #2151

Closed
wants to merge 1 commit into
from
Jump to file or symbol
Failed to load files and symbols.
+71 −0
Split
View
@@ -4832,6 +4832,49 @@ def combineMult(self, other):
"""
return self.mul(other, fill_value=1.)
+ def where(self, cond, other):
+ """
+ Return a DataFrame with the same shape as self and whose corresponding
+ entries are from self where cond is True and otherwise are from other.
+
+
+ Parameters
+ ----------
+ cond: boolean DataFrame or array
+ other: scalar or DataFrame
+
+ Returns
+ -------
+ wh: DataFrame
+ """
+ if isinstance(cond, np.ndarray):
+ if cond.shape != self.shape:
+ raise ValueError('Array onditional must be same shape as self')
+ cond = self._constructor(cond, index=self.index, columns=self.columns)
+ if cond.shape != self.shape:
+ cond = cond.reindex(self.index, columns=self.columns)
+ cond = cond.fillna(False)
+
+ if isinstance(other, DataFrame):
+ _, other = self.align(other, join='left', fill_value=np.nan)
+
+ rs = np.where(cond, self, other)
+ return self._constructor(rs, self.index, self.columns)
+
+ def mask(self, cond):
+ """
+ Returns copy of self whose values are replaced with nan if the
+ corresponding entry in cond is False
+
+ Parameters
+ ----------
+ cond: boolean DataFrame or array
+
+ Returns
+ -------
+ wh: DataFrame
+ """
+ return self.where(cond, np.nan)
_EMPTY_SERIES = Series([])
View
@@ -5063,6 +5063,34 @@ def test_align_int_fill_bug(self):
expected = df2 - df2.mean()
assert_frame_equal(result, expected)
+ def test_where(self):
+ df = DataFrame(np.random.randn(5, 3))
+ cond = df > 0
+
+ other1 = df + 1
+ rs = df.where(cond, other1)
+ for k, v in rs.iteritems():
+ assert_series_equal(v, np.where(cond[k], df[k], other1[k]))
+
+ other2 = (df + 1).values
+ rs = df.where(cond, other2)
+ for k, v in rs.iteritems():
+ assert_series_equal(v, np.where(cond[k], df[k], other2[:, k]))
+
+ other5 = np.nan
+ rs = df.where(cond, other5)
+ for k, v in rs.iteritems():
+ assert_series_equal(v, np.where(cond[k], df[k], other5))
+
+ assert_frame_equal(rs, df.mask(cond))
+
+ err1 = (df + 1).values[0:2, :]
+ self.assertRaises(ValueError, df.where, cond, err1)
+
+ err2 = cond.ix[:2, :].values
+ self.assertRaises(ValueError, df.where, err2, other1)
+
+
#----------------------------------------------------------------------
# Transposing