FEA Refactor CategoricalEncoder into OneHotEncoder and OrdinalEncoder (#10523)

Deprecated some OneHotEncoder behaviour
jorisvandenbossche authored and jnothman committed Jun 21, 2018
1 parent bb5110b commit 007aa71
Showing 17 changed files with 1,454 additions and 1,075 deletions.
5 changes: 2 additions & 3 deletions doc/datasets/index.rst
@@ -456,9 +456,8 @@ refer to:
for reading WAV files into a numpy array

Categorical (or nominal) features stored as strings (common in pandas DataFrames)
-will need converting to integers, and integer categorical variables may be best
-exploited when encoded as one-hot variables
-(:class:`sklearn.preprocessing.OneHotEncoder`) or similar.
+will need converting to numerical features using :class:`sklearn.preprocessing.OneHotEncoder`
+or :class:`sklearn.preprocessing.OrdinalEncoder` or similar.
See :ref:`preprocessing`.

Note: if you manage your own numerical data it is recommended to use an
6 changes: 4 additions & 2 deletions doc/glossary.rst
@@ -165,8 +165,10 @@ General Concepts
tree-based models such as random forests and gradient boosting
models that often work better and faster with integer-coded
categorical variables.
-:class:`~sklearn.preprocessing.CategoricalEncoder` helps
-encoding string-valued categorical features.
+:class:`~sklearn.preprocessing.OrdinalEncoder` helps encoding
+string-valued categorical features as ordinal integers, and
+:class:`~sklearn.preprocessing.OneHotEncoder` can be used to
+one-hot encode categorical features.
See also :ref:`preprocessing_categorical_features` and the
`category_encoders <http://contrib.scikit-learn.org/categorical-encoding>`_ package for tools related to encoding
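A concrete illustration of the glossary's point above (a minimal sketch with made-up data; the toy features and estimator settings are illustrative, not from this diff): tree ensembles split on thresholds, so compact ordinal codes are enough for them, whereas linear models usually want one-hot columns.

from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OrdinalEncoder

# Toy string-valued categorical data (illustrative only).
X = [['red', 'S'], ['blue', 'M'], ['red', 'L'], ['green', 'M']]
y = [0, 1, 0, 1]

# OrdinalEncoder maps each string category to an integer code,
# which tree-based models can split on directly.
clf = make_pipeline(OrdinalEncoder(), RandomForestClassifier(n_estimators=10))
clf.fit(X, y)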
2 changes: 1 addition & 1 deletion doc/modules/classes.rst
@@ -1248,7 +1248,7 @@ Model validation
preprocessing.MinMaxScaler
preprocessing.Normalizer
preprocessing.OneHotEncoder
-preprocessing.CategoricalEncoder
+preprocessing.OrdinalEncoder
preprocessing.PolynomialFeatures
preprocessing.PowerTransformer
preprocessing.QuantileTransformer
35 changes: 18 additions & 17 deletions doc/modules/preprocessing.rst
@@ -508,15 +508,13 @@ Such features can be efficiently coded as integers, for instance
``[1, 2, 1]``.

To convert categorical features to such integer codes, we can use the
-:class:`CategoricalEncoder`. When specifying that we want to perform an
-ordinal encoding, the estimator transforms each categorical feature to one
+:class:`OrdinalEncoder`. This estimator transforms each categorical feature to one
new feature of integers (0 to n_categories - 1)::

->>> enc = preprocessing.CategoricalEncoder(encoding='ordinal')
+>>> enc = preprocessing.OrdinalEncoder()
>>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
>>> enc.fit(X) # doctest: +ELLIPSIS
-CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>,
-encoding='ordinal', handle_unknown='error')
+OrdinalEncoder(categories='auto', dtype=<... 'numpy.float64'>)
>>> enc.transform([['female', 'from US', 'uses Safari']])
array([[0., 1., 1.]])
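As a side note (not part of the diff): after fitting, the categories the encoder learned can be inspected through the ``categories_`` attribute that the new classes expose; a minimal sketch continuing the example above::

>>> enc.categories_  # doctest: +ELLIPSIS
[array(['female', 'male'], ...), array(['from Europe', 'from US'], ...), array(['uses Firefox', 'uses Safari'], ...)]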

@@ -528,18 +526,19 @@ browsers was ordered arbitrarily).
Another possibility to convert categorical features to features that can be used
with scikit-learn estimators is to use a one-of-K, also known as one-hot or
dummy encoding.
-This type of encoding is the default behaviour of the :class:`CategoricalEncoder`.
-The :class:`CategoricalEncoder` then transforms each categorical feature with
+This type of encoding can be obtained with the :class:`OneHotEncoder`,
+which transforms each categorical feature with
``n_categories`` possible values into ``n_categories`` binary features, with
one of them 1, and all others 0.

Continuing the example above::

->>> enc = preprocessing.CategoricalEncoder()
+>>> enc = preprocessing.OneHotEncoder()
>>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
>>> enc.fit(X) # doctest: +ELLIPSIS
-CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>,
-encoding='onehot', handle_unknown='error')
+OneHotEncoder(categorical_features=None, categories=None,
+dtype=<... 'numpy.float64'>, handle_unknown='error',
+n_values=None, sparse=True)
>>> enc.transform([['female', 'from US', 'uses Safari'],
... ['male', 'from Europe', 'uses Safari']]).toarray()
array([[1., 0., 0., 1., 0., 1.],
@@ -558,14 +557,15 @@ dataset::
>>> genders = ['female', 'male']
>>> locations = ['from Africa', 'from Asia', 'from Europe', 'from US']
>>> browsers = ['uses Chrome', 'uses Firefox', 'uses IE', 'uses Safari']
->>> enc = preprocessing.CategoricalEncoder(categories=[genders, locations, browsers])
+>>> enc = preprocessing.OneHotEncoder(categories=[genders, locations, browsers])
>>> # Note that there are missing categorical values for the 2nd and 3rd
>>> # feature
>>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
>>> enc.fit(X) # doctest: +ELLIPSIS
-CategoricalEncoder(categories=[...],
-dtype=<... 'numpy.float64'>, encoding='onehot',
-handle_unknown='error')
+OneHotEncoder(categorical_features=None,
+categories=[...],
+dtype=<... 'numpy.float64'>, handle_unknown='error',
+n_values=None, sparse=True)
>>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray()
array([[1., 0., 0., 1., 0., 0., 1., 0., 0., 0.]])
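As a side note (not part of the diff): when a dense array is preferred over the default sparse matrix, the pre-existing ``sparse=False`` option carries over to the refactored class; a sketch reusing the lists defined above::

>>> enc = preprocessing.OneHotEncoder(categories=[genders, locations, browsers], sparse=False)
>>> enc.fit_transform(X)[0]
array([0., 1., 0., 0., 0., 1., 0., 0., 0., 1.])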

Expand All @@ -577,11 +577,12 @@ during transform, no error will be raised but the resulting one-hot encoded
columns for this feature will be all zeros
(``handle_unknown='ignore'`` is only supported for one-hot encoding)::

->>> enc = preprocessing.CategoricalEncoder(handle_unknown='ignore')
+>>> enc = preprocessing.OneHotEncoder(handle_unknown='ignore')
>>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
>>> enc.fit(X) # doctest: +ELLIPSIS
-CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>,
-encoding='onehot', handle_unknown='ignore')
+OneHotEncoder(categorical_features=None, categories=None,
+dtype=<... 'numpy.float64'>, handle_unknown='ignore',
+n_values=None, sparse=True)
>>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray()
array([[1., 0., 0., 0., 0., 0.]])
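Conversely (a sketch, not part of the diff; the exact error message may differ): :class:`OrdinalEncoder` has no ``handle_unknown='ignore'`` mode at this point, so an unseen category raises at transform time::

>>> enc = preprocessing.OrdinalEncoder().fit(X)  # doctest: +SKIP
>>> enc.transform([['female', 'from Asia', 'uses Chrome']])  # doctest: +SKIP
ValueError: Found unknown categories ...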

23 changes: 17 additions & 6 deletions doc/whats_new/v0.20.rst
@@ -70,14 +70,14 @@ Classifiers and regressors

Preprocessing

-- Added :class:`preprocessing.CategoricalEncoder`, which allows to encode
-categorical features as a numeric array, either using a one-hot (or dummy)
-encoding scheme or by converting to ordinal integers. Compared to the
-existing :class:`~preprocessing.OneHotEncoder`, this new class handles
+- Expanded :class:`preprocessing.OneHotEncoder` to allow encoding
+categorical string features as a numeric array using a one-hot (or dummy)
+encoding scheme, and added :class:`preprocessing.OrdinalEncoder` to
+convert to ordinal integers. Those two classes now handle
encoding of all feature types (also handles string-valued features) and
derives the categories based on the unique values in the features instead of
-the maximum value in the features. :issue:`9151` by :user:`Vighnesh Birodkar
-<vighneshbirodkar>` and `Joris Van den Bossche`_.
+the maximum value in the features. :issue:`9151` and :issue:`10521` by
+:user:`Vighnesh Birodkar <vighneshbirodkar>` and `Joris Van den Bossche`_.

- Added :class:`compose.ColumnTransformer`, which allows applying
different transformers to different columns of arrays or pandas
@@ -584,6 +584,17 @@ Linear, kernelized and related models
:class:`linear_model.LogisticRegression` when ``verbose`` is set to 0.
:issue:`10881` by :user:`Alexandre Sevin <AlexandreSev>`.

+Preprocessing
+
+- Deprecate ``n_values`` and ``categorical_features`` parameters and
+``active_features_``, ``feature_indices_`` and ``n_values_`` attributes
+of :class:`preprocessing.OneHotEncoder`. The ``n_values`` parameter can be
+replaced with the new ``categories`` parameter, and the attributes with the
+new ``categories_`` attribute. Selecting the categorical features with
+the ``categorical_features`` parameter is now better supported using the
+:class:`compose.ColumnTransformer`.
+:issue:`10521` by `Joris Van den Bossche`_.
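To make the ``categorical_features`` migration described above concrete, a hedged sketch (the column index and toy data are illustrative; the old call shape is paraphrased from the deprecated API):

import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

# Before: OneHotEncoder(categorical_features=[0]) one-hot encoded column 0
# and passed the remaining columns through. After: select columns explicitly.
X = np.array([['a', 1.0], ['b', 2.0], ['a', 3.0]], dtype=object)
ct = ColumnTransformer(
    [('onehot', OneHotEncoder(sparse=False), [0])],  # encode column 0 only
    remainder='passthrough')                         # keep column 1 as-is
Xt = ct.fit_transform(X)  # columns: onehot('a'), onehot('b'), original col 1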

Decomposition, manifold learning and clustering

- Deprecate ``precomputed`` parameter in function
6 changes: 3 additions & 3 deletions examples/ensemble/plot_feature_transformation.py
@@ -34,7 +34,7 @@
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import (RandomTreesEmbedding, RandomForestClassifier,
GradientBoostingClassifier)
-from sklearn.preprocessing import CategoricalEncoder
+from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve
from sklearn.pipeline import make_pipeline
@@ -62,7 +62,7 @@

# Supervised transformation based on random forests
rf = RandomForestClassifier(max_depth=3, n_estimators=n_estimator)
-rf_enc = CategoricalEncoder()
+rf_enc = OneHotEncoder()
rf_lm = LogisticRegression()
rf.fit(X_train, y_train)
rf_enc.fit(rf.apply(X_train))
@@ -72,7 +72,7 @@
fpr_rf_lm, tpr_rf_lm, _ = roc_curve(y_test, y_pred_rf_lm)

grd = GradientBoostingClassifier(n_estimators=n_estimator)
-grd_enc = CategoricalEncoder()
+grd_enc = OneHotEncoder()
grd_lm = LogisticRegression()
grd.fit(X_train, y_train)
grd_enc.fit(grd.apply(X_train)[:, :, 0])
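A note on the ``[:, :, 0]`` indexing used with the encoder here: ``GradientBoostingClassifier.apply`` returns leaf indices of shape ``(n_samples, n_estimators, n_classes_per_stage)``, and binary problems grow a single tree per stage, so the trailing axis is squeezed before one-hot encoding. A sketch reusing the example's variables:

leaves = grd.apply(X_train)   # shape (n_samples, n_estimator, 1) for binary y
grd_enc.fit(leaves[:, :, 0])  # one leaf-id column per boosting stage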
8 changes: 4 additions & 4 deletions sklearn/compose/_column_transformer.py
@@ -636,19 +636,19 @@ def make_column_transformer(*transformers, **kwargs):
Examples
--------
->>> from sklearn.preprocessing import StandardScaler, CategoricalEncoder
+>>> from sklearn.preprocessing import StandardScaler, OneHotEncoder
>>> from sklearn.compose import make_column_transformer
>>> make_column_transformer(
... (['numerical_column'], StandardScaler()),
-... (['categorical_column'], CategoricalEncoder()))
+... (['categorical_column'], OneHotEncoder()))
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
ColumnTransformer(n_jobs=1, remainder='passthrough',
transformer_weights=None,
transformers=[('standardscaler',
StandardScaler(...),
['numerical_column']),
-('categoricalencoder',
-CategoricalEncoder(...),
+('onehotencoder',
+OneHotEncoder(...),
['categorical_column'])])
"""
3 changes: 2 additions & 1 deletion sklearn/ensemble/forest.py
@@ -1932,7 +1932,8 @@ def fit_transform(self, X, y=None, sample_weight=None):
super(RandomTreesEmbedding, self).fit(X, y,
sample_weight=sample_weight)

-self.one_hot_encoder_ = OneHotEncoder(sparse=self.sparse_output)
+self.one_hot_encoder_ = OneHotEncoder(sparse=self.sparse_output,
+categories='auto')
return self.one_hot_encoder_.fit_transform(self.apply(X))

def transform(self, X):
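Context for the added ``categories='auto'`` argument (an inference from the deprecations above, not stated in the diff): tree leaves are plain integer indices, and on integer input the refactored OneHotEncoder would otherwise fall back to the deprecated ``n_values``-style range handling and warn. Requesting ``'auto'`` derives categories from the unique leaf ids actually seen in ``fit``:

# Hedged illustration of the difference on integer leaf indices:
# apply(X) -> [[3, 11], [5, 11]] gives categories [3, 5] and [11] with
# categories='auto' (three one-hot columns), whereas the legacy behaviour
# encoded range(max + 1) per column, i.e. 6 + 12 = 18 columns.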
4 changes: 2 additions & 2 deletions sklearn/feature_extraction/dict_vectorizer.py
@@ -39,7 +39,7 @@ class DictVectorizer(BaseEstimator, TransformerMixin):
However, note that this transformer will only do a binary one-hot encoding
when feature values are of type string. If categorical features are
represented as numeric values such as int, the DictVectorizer can be
-followed by :class:`sklearn.preprocessing.CategoricalEncoder` to complete
+followed by :class:`sklearn.preprocessing.OneHotEncoder` to complete
binary one-hot encoding.
Features that do not occur in a sample (mapping) will have a zero value
@@ -89,7 +89,7 @@ class DictVectorizer(BaseEstimator, TransformerMixin):
See also
--------
FeatureHasher : performs vectorization using only a hash function.
-sklearn.preprocessing.CategoricalEncoder : handles nominal/categorical
+sklearn.preprocessing.OrdinalEncoder : handles nominal/categorical
features encoded as columns of arbitrary data types.
"""

3 changes: 1 addition & 2 deletions sklearn/feature_extraction/hashing.py
@@ -81,8 +81,7 @@ class FeatureHasher(BaseEstimator, TransformerMixin):
See also
--------
DictVectorizer : vectorizes string-valued features using a hash table.
-sklearn.preprocessing.OneHotEncoder : handles nominal/categorical features
-encoded as columns of integers.
+sklearn.preprocessing.OneHotEncoder : handles nominal/categorical features.
"""

def __init__(self, n_features=(2 ** 20), input_type="dict",
9 changes: 6 additions & 3 deletions sklearn/preprocessing/__init__.py
@@ -22,18 +22,21 @@
from .data import minmax_scale
from .data import quantile_transform
from .data import power_transform
-from .data import OneHotEncoder
from .data import PowerTransformer
-from .data import CategoricalEncoder
from .data import PolynomialFeatures

+from ._encoders import OneHotEncoder
+from ._encoders import OrdinalEncoder

from .label import label_binarize
from .label import LabelBinarizer
from .label import LabelEncoder
from .label import MultiLabelBinarizer

from .imputation import Imputer

+# stub, remove in version 0.21
+from .data import CategoricalEncoder  # noqa

__all__ = [
'Binarizer',
@@ -48,7 +51,7 @@
'QuantileTransformer',
'Normalizer',
'OneHotEncoder',
-'CategoricalEncoder',
+'OrdinalEncoder',
'PowerTransformer',
'RobustScaler',
'StandardScaler',
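On the stub import above (a hedged reading; the stub's body lives in ``data.py``, outside this diff): keeping ``CategoricalEncoder`` importable until 0.21 means existing imports keep working while pointing users at the new classes.

from sklearn.preprocessing import CategoricalEncoder  # still importable in 0.20
# Using the stub is expected to steer users to OneHotEncoder / OrdinalEncoder,
# e.g. via a deprecation warning or error raised on instantiation.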
