Skip to content

Commit

Permalink
Consistent CategoricalDtype use in Categorical init
Browse files Browse the repository at this point in the history
Get a valid instance of `CategoricalDtype` as early as possible, and use that
throughout.
  • Loading branch information
TomAugspurger committed Sep 20, 2017
1 parent ed5c814 commit 416d1d7
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 15 deletions.
46 changes: 31 additions & 15 deletions pandas/core/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,21 @@ class Categorical(PandasObject):
def __init__(self, values, categories=None, ordered=None, dtype=None,
fastpath=False):

# Ways of specifying the dtype (in priority order)
# 1. dtype is a CategoricalDtype
# a.) with known categories, use dtype.categories
# b.) else with Categorical values, use values.dtype
# c.) else, infer from values
# d.) specifying dtype=CategoricalDtype and categories is an error
# 2. dtype is a string 'category'
# a.) use categories, ordered
# b.) use values.dtype
# c.) infer from values
# 3. dtype is None
# a.) use categories, ordered
# b.) use values.dtype
# c.) infer from values

This comment has been minimized.

Copy link
@jorisvandenbossche

jorisvandenbossche Sep 20, 2017

Member

This overview looks sensible!


if dtype is not None:
if isinstance(dtype, compat.string_types):
if dtype == 'category':
Expand All @@ -247,20 +262,24 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
categories = dtype.categories
ordered = dtype.ordered

if ordered is None:
ordered = False
elif is_categorical(values):
dtype = values.dtype._from_categorical_dtype(values.dtype,
categories, ordered)
else:
dtype = CategoricalDtype(categories, ordered)

# At this point, dtype is always a CategoricalDtype
# if dtype.categories is None, we are inferring

if fastpath:
if dtype is None:
dtype = CategoricalDtype(categories, ordered)
self._codes = coerce_indexer_dtype(values, categories)
self._dtype = dtype
return

# sanitize input
if is_categorical_dtype(values):

# we are either a Series, CategoricalIndex
# we are either a Series or a CategoricalIndex
if isinstance(values, (ABCSeries, ABCCategoricalIndex)):
values = values._values

Expand All @@ -271,6 +290,7 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
values = values.get_values()

elif isinstance(values, (ABCIndexClass, ABCSeries)):
# we'll do inference later
pass

else:
Expand All @@ -288,12 +308,12 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
# "object" dtype to prevent this. In the end objects will be
# casted to int/... in the category assignment step.
if len(values) == 0 or isna(values).any():
dtype = 'object'
sanitize_dtype = 'object'
else:
dtype = None
values = _sanitize_array(values, None, dtype=dtype)
sanitize_dtype = None
values = _sanitize_array(values, None, dtype=sanitize_dtype)

if categories is None:
if dtype.categories is None:
try:
codes, categories = factorize(values, sort=True)
except TypeError:
Expand All @@ -310,7 +330,8 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
raise NotImplementedError("> 1 ndim Categorical are not "
"supported at this time")

if dtype is None or isinstance(dtype, str):
if dtype.categories is None:
# we're inferring from values
dtype = CategoricalDtype(categories, ordered)

else:
Expand All @@ -321,11 +342,6 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
# - the new one, where each value is also in the categories array
# (or np.nan)

# make sure that we always have the same type here, no matter what
# we get passed in
if dtype is None or isinstance(dtype, str):
dtype = CategoricalDtype(categories, ordered)

codes = _get_codes_for_values(values, dtype.categories)

# TODO: check for old style usage. These warnings should be removed
Expand Down
13 changes: 13 additions & 0 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,22 @@ def _from_fastpath(cls, categories=None, ordered=False):
self._finalize(categories, ordered, fastpath=True)
return self

@classmethod
def _from_categorical_dtype(cls, dtype, categories=None, ordered=None):
    """
    Derive a dtype from an existing one, optionally overriding parts of it.

    Parameters
    ----------
    dtype : object exposing ``categories`` and ``ordered``
        Supplies the value for whichever of the two is not overridden.
    categories : optional
        Replacement categories; ``None`` means "keep ``dtype.categories``".
    ordered : optional
        Replacement orderedness; ``None`` means "keep ``dtype.ordered``".

    Returns
    -------
    ``dtype`` itself (the same object) when neither ``categories`` nor
    ``ordered`` is given, otherwise ``cls(categories, ordered)`` built
    from the merged values.
    """
    if categories is None and ordered is None:
        # nothing to override: hand back the very same object
        return dtype

    new_categories = dtype.categories if categories is None else categories
    new_ordered = dtype.ordered if ordered is None else ordered
    return cls(new_categories, new_ordered)

def _finalize(self, categories, ordered, fastpath=False):
from pandas.core.indexes.base import Index

if ordered is None:
ordered = False

if categories is not None:
categories = Index(categories, tupleize_cols=False)
# validation
Expand Down
27 changes: 27 additions & 0 deletions pandas/tests/dtypes/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,3 +622,30 @@ def test_mixed(self):
a = CategoricalDtype(['a', 'b', 1, 2])
b = CategoricalDtype(['a', 'b', '1', '2'])
assert hash(a) != hash(b)

def test_from_categorical_dtype_identity(self):
    # _from_categorical_dtype works on a dtype, so feed it the
    # Categorical's dtype rather than the Categorical itself
    original = Categorical([1, 2], categories=[1, 2, 3], ordered=True).dtype
    # nothing is overridden -> the very same object must come back
    result = CategoricalDtype._from_categorical_dtype(original)
    assert result is original

def test_from_categorical_dtype_categories(self):
    # pass the dtype (not the Categorical) as the base to derive from
    dtype = Categorical([1, 2], categories=[1, 2, 3], ordered=True).dtype
    # override categories only; ``ordered`` carries over from ``dtype``
    result = CategoricalDtype._from_categorical_dtype(
        dtype, categories=[2, 3])
    assert result == CategoricalDtype([2, 3], ordered=True)

def test_from_categorical_dtype_ordered(self):
    # pass the dtype (not the Categorical) as the base to derive from
    dtype = Categorical([1, 2], categories=[1, 2, 3], ordered=True).dtype
    # override ordered only; the categories carry over from ``dtype``
    result = CategoricalDtype._from_categorical_dtype(
        dtype, ordered=False)
    assert result == CategoricalDtype([1, 2, 3], ordered=False)

def test_from_categorical_dtype_both(self):
    # pass the dtype (not the Categorical) as the base to derive from
    dtype = Categorical([1, 2], categories=[1, 2, 3], ordered=True).dtype
    # override both categories and ordered; nothing carries over
    result = CategoricalDtype._from_categorical_dtype(
        dtype, categories=[1, 2], ordered=False)
    assert result == CategoricalDtype([1, 2], ordered=False)
41 changes: 41 additions & 0 deletions pandas/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,37 @@ def test_constructor_str_unknown(self):
with tm.assert_raises_regex(ValueError, "Unknown `dtype`"):
Categorical([1, 2], dtype="foo")

def test_constructor_from_categorical_with_dtype(self):
    # an explicit dtype with known categories wins over the categories
    # of the Categorical being passed in as values
    source = Categorical(['a', 'b', 'd'])
    target_dtype = CategoricalDtype(['a', 'b', 'c'], ordered=True)
    expected = Categorical(['a', 'b', 'd'], categories=['a', 'b', 'c'],
                           ordered=True)

    result = Categorical(source, dtype=target_dtype)
    tm.assert_categorical_equal(result, expected)

def test_constructor_from_categorical_with_unknown_dtype(self):
    # the dtype carries no categories, so they are taken from the values;
    # ``ordered`` is still taken from the dtype
    source = Categorical(['a', 'b', 'd'])
    ordered_only = CategoricalDtype(None, ordered=True)
    expected = Categorical(['a', 'b', 'd'], categories=['a', 'b', 'd'],
                           ordered=True)

    result = Categorical(source, dtype=ordered_only)
    tm.assert_categorical_equal(result, expected)

def test_contructor_from_categorical_string(self):
    values = Categorical(['a', 'b', 'd'])
    expected = Categorical(['a', 'b', 'd'], categories=['a', 'b', 'c'],
                           ordered=True)

    # dtype given as the string 'category': the explicit categories and
    # ordered arguments are the ones that get used
    with_string = Categorical(values, dtype='category',
                              categories=['a', 'b', 'c'], ordered=True)
    tm.assert_categorical_equal(with_string, expected)

    # dtype left out altogether: same outcome
    without_string = Categorical(values, categories=['a', 'b', 'c'],
                                 ordered=True)
    tm.assert_categorical_equal(without_string, expected)

def test_from_codes(self):

# too few categories
Expand Down Expand Up @@ -932,6 +963,16 @@ def test_set_dtype_nans(self):
tm.assert_numpy_array_equal(result.codes, np.array([0, -1, -1],
dtype='int8'))

def test_set_categories(self):
    # widen the categories from 'abcd' to 'abcde'; the expected values
    # stay 'a', 'b', 'c'
    cat = Categorical(['a', 'b', 'c'], categories=['a', 'b', 'c', 'd'])
    expected = Categorical(['a', 'b', 'c'], categories=list('abcde'))

    # regular path first, then the fastpath; both must match ``expected``
    for extra in ({}, {'fastpath': True}):
        result = cat._set_categories(list('abcde'), **extra)
        tm.assert_categorical_equal(result, expected)

@pytest.mark.parametrize('values, categories, new_categories', [
# No NaNs, same cats, same order
(['a', 'b', 'a'], ['a', 'b'], ['a', 'b'],),
Expand Down

0 comments on commit 416d1d7

Please sign in to comment.