Skip to content

Commit

Permalink
Fix: unequal comparisons of categorical and scalar
Browse files Browse the repository at this point in the history
Before, unequal comparisons were not checking the order of the
categories.

This was due to a conversion to an ndarray, which turned the
comparison to one between ndarray and scalar, which of course
has no categories to take into account.

Also add test cases and remove the one which actually tested the
wrong behaviour.
  • Loading branch information
jankatins committed Apr 10, 2015
1 parent 2734fff commit 86e0376
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 14 deletions.
28 changes: 17 additions & 11 deletions pandas/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,20 +594,26 @@ def wrapper(self, other):

mask = isnull(self)

values = self.get_values()
other = _index.convert_scalar(values,_values_from_object(other))
if com.is_categorical_dtype(self):
# cats are a special case as get_values() would return an ndarray, which would then
# not take categories ordering into account
# we can go directly to op, as the na_op would just test again and dispatch to it.
res = op(self.values, other)
else:
values = self.get_values()
other = _index.convert_scalar(values,_values_from_object(other))

if issubclass(values.dtype.type, (np.datetime64, np.timedelta64)):
values = values.view('i8')
if issubclass(values.dtype.type, (np.datetime64, np.timedelta64)):
values = values.view('i8')

# scalars
res = na_op(values, other)
if np.isscalar(res):
raise TypeError('Could not compare %s type with Series'
% type(other))
# scalars
res = na_op(values, other)
if np.isscalar(res):
raise TypeError('Could not compare %s type with Series'
% type(other))

# always return a full value series here
res = _values_from_object(res)
# always return a full value series here
res = _values_from_object(res)

res = pd.Series(res, index=self.index, name=self.name,
dtype='bool')
Expand Down
35 changes: 32 additions & 3 deletions pandas/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def f():
Categorical([1,2], [1,2,np.nan, np.nan])
self.assertRaises(ValueError, f)

# The default should be unordered
c1 = Categorical(["a", "b", "c", "a"])
self.assertFalse(c1.ordered)

# Categorical as input
c1 = Categorical(["a", "b", "c", "a"])
Expand Down Expand Up @@ -367,6 +370,13 @@ def f():
self.assertRaises(TypeError, lambda: a < cat)
self.assertRaises(TypeError, lambda: a < cat_rev)

# Make sure that unequal comparison take the categories order in account
cat_rev = pd.Categorical(list("abc"), categories=list("cba"), ordered=True)
exp = np.array([True, False, False])
res = cat_rev > "b"
self.assert_numpy_array_equal(res, exp)


def test_na_flags_int_categories(self):
# #1457

Expand Down Expand Up @@ -2390,6 +2400,18 @@ def test_comparisons(self):
exp = Series([False, False, True])
tm.assert_series_equal(res, exp)

scalar = base[1]
res = cat > scalar
exp = Series([False, False, True])
exp2 = cat.values > scalar
tm.assert_series_equal(res, exp)
tm.assert_numpy_array_equal(res.values, exp2)
res_rev = cat_rev > scalar
exp_rev = Series([True, False, False])
exp_rev2 = cat_rev.values > scalar
tm.assert_series_equal(res_rev, exp_rev)
tm.assert_numpy_array_equal(res_rev.values, exp_rev2)

# Only categories with same categories can be compared
def f():
cat > cat_rev
Expand All @@ -2408,9 +2430,16 @@ def f():
self.assertRaises(TypeError, lambda: a < cat)
self.assertRaises(TypeError, lambda: a < cat_rev)

# Categoricals can be compared to scalar values
res = cat_rev > base[0]
tm.assert_series_equal(res, exp)
# unequal comparison should raise for unordered cats
cat = Series(Categorical(list("abc")))
def f():
cat > "b"
self.assertRaises(TypeError, f)
cat = Series(Categorical(list("abc"), ordered=False))
def f():
cat > "b"
self.assertRaises(TypeError, f)


# And test NaN handling...
cat = Series(Categorical(["a","b","c", np.nan]))
Expand Down

0 comments on commit 86e0376

Please sign in to comment.