Rename MICEImputer to ChainedImputer (#11314)
sergeyf authored and ogrisel committed Jun 22, 2018
1 parent 5b29ae6 commit 93382cc
Showing 8 changed files with 158 additions and 113 deletions.
9 changes: 8 additions & 1 deletion doc/glossary.rst
@@ -435,6 +435,13 @@ General Concepts
hyper-parameter
See :term:`parameter`.

impute
imputation
Most machine learning algorithms require that their inputs have no
:term:`missing values`, and will not work if this requirement is
violated. Algorithms that attempt to fill in (or impute) missing values
are referred to as imputation algorithms.

indexable
An :term:`array-like`, :term:`sparse matrix`, pandas DataFrame or
sequence (usually a list).
@@ -486,7 +493,7 @@ General Concepts
do (e.g. in :class:`impute.SimpleImputer`), NaN is the preferred
representation of missing values in float arrays. If the array has
integer dtype, NaN cannot be represented. For this reason, we support
-specifying another ``missing_values`` value when imputation or
+specifying another ``missing_values`` value when :term:`imputation` or
learning can be performed in integer space. :term:`Unlabeled data`
is a special case of missing values in the :term:`target`.

2 changes: 1 addition & 1 deletion doc/modules/classes.rst
@@ -653,7 +653,7 @@ Kernels:
:template: class.rst

impute.SimpleImputer
-impute.MICEImputer
+impute.ChainedImputer

.. _kernel_approximation_ref:

72 changes: 54 additions & 18 deletions doc/modules/impute.rst
@@ -13,9 +13,22 @@ array are numerical, and that all have and hold meaning. A basic strategy to use
incomplete datasets is to discard entire rows and/or columns containing missing
values. However, this comes at the price of losing data which may be valuable
(even though incomplete). A better strategy is to impute the missing values,
-i.e., to infer them from the known part of the data.
+i.e., to infer them from the known part of the data. See the :ref:`glossary`
+entry on imputation.


Univariate vs. Multivariate Imputation
======================================

One type of imputation algorithm is univariate, which imputes values in the i-th
feature dimension using only non-missing values in that feature dimension
(e.g. :class:`impute.SimpleImputer`). By contrast, multivariate imputation
algorithms use the entire set of available feature dimensions to estimate the
missing values (e.g. :class:`impute.ChainedImputer`).
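
As an illustration of the distinction (a sketch with made-up toy data, not
part of this commit)::

    import numpy as np
    from sklearn.impute import SimpleImputer, ChainedImputer

    X = [[1, 2], [np.nan, 3], [7, 6]]

    # Univariate: column 0 is filled with its own mean, (1 + 7) / 2 = 4,
    # regardless of what is observed in column 1.
    print(SimpleImputer(strategy='mean').fit(X).transform([[np.nan, 4]]))

    # Multivariate: column 0 is regressed on column 1, so the observed
    # value 4 in column 1 influences the imputed entry.
    imp = ChainedImputer(n_imputations=10, random_state=0)
    print(imp.fit(X).transform([[np.nan, 4]]))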


.. _single_imputer:

Univariate feature imputation
=============================

@@ -74,35 +87,58 @@ string values or pandas categoricals when using the ``'most_frequent'`` or
['a' 'y']
['b' 'y']]
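
For illustration (not part of this commit), a minimal sketch of the related
``'constant'`` strategy, which fills in a fixed placeholder rather than the
per-column mode; the toy data here is made up::

    import numpy as np
    from sklearn.impute import SimpleImputer

    X = np.array([['a', 'x'], [np.nan, 'y'], ['b', np.nan]], dtype=object)
    imp = SimpleImputer(strategy='constant', fill_value='missing')
    # Every np.nan entry is replaced by the string 'missing'.
    print(imp.fit_transform(X))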

-.. _mice:
+.. _chained_imputer:


Multivariate feature imputation
===============================

-A more sophisticated approach is to use the :class:`MICEImputer` class, which
-implements the Multivariate Imputation by Chained Equations technique. MICE
-models each feature with missing values as a function of other features, and
-uses that estimate for imputation. It does so in a round-robin fashion: at
-each step, a feature column is designated as output `y` and the other feature
-columns are treated as inputs `X`. A regressor is fit on `(X, y)` for known `y`.
-Then, the regressor is used to predict the unknown values of `y`. This is
-repeated for each feature, and then is done for a number of imputation rounds.
-Here is an example snippet::
+A more sophisticated approach is to use the :class:`ChainedImputer` class, which
+implements the imputation technique from MICE (Multivariate Imputation by
+Chained Equations). MICE models each feature with missing values as a function of
+other features, and uses that estimate for imputation. It does so in a round-robin
+fashion: at each step, a feature column is designated as output `y` and the other
+feature columns are treated as inputs `X`. A regressor is fit on `(X, y)` for known `y`.
+Then, the regressor is used to predict the unknown values of `y`. This is repeated
+for each feature in a chained fashion, and then is done for a number of imputation
+rounds. Here is an example snippet::

>>> import numpy as np
->>> from sklearn.impute import MICEImputer
->>> imp = MICEImputer(n_imputations=10, random_state=0)
+>>> from sklearn.impute import ChainedImputer
+>>> imp = ChainedImputer(n_imputations=10, random_state=0)
>>> imp.fit([[1, 2], [np.nan, 3], [7, np.nan]])
-MICEImputer(imputation_order='ascending', initial_strategy='mean',
-       max_value=None, min_value=None, missing_values=nan, n_burn_in=10,
-       n_imputations=10, n_nearest_features=None, predictor=None,
-       random_state=0, verbose=False)
+ChainedImputer(imputation_order='ascending', initial_strategy='mean',
+       max_value=None, min_value=None, missing_values=nan, n_burn_in=10,
+       n_imputations=10, n_nearest_features=None, predictor=None,
+       random_state=0, verbose=False)
>>> X_test = [[np.nan, 2], [6, np.nan], [np.nan, 6]]
>>> print(np.round(imp.transform(X_test)))
[[ 1. 2.]
[ 6. 4.]
[13. 6.]]

-Both :class:`SimpleImputer` and :class:`MICEImputer` can be used in a Pipeline
+Both :class:`SimpleImputer` and :class:`ChainedImputer` can be used in a Pipeline
as a way to build a composite estimator that supports imputation.
See :ref:`sphx_glr_auto_examples_plot_missing_values.py`.
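
A minimal sketch of such a pipeline (illustrative only, not part of this
change; the toy data and estimator choice are made up)::

    import numpy as np
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.impute import ChainedImputer
    from sklearn.pipeline import Pipeline

    X = np.array([[1., 2.], [np.nan, 3.], [7., np.nan], [4., 5.]])
    y = np.array([0.5, 1.0, 2.0, 1.5])

    # The imputer fills in the NaNs before the forest ever sees the data.
    pipe = Pipeline([
        ('imputer', ChainedImputer(n_imputations=10, random_state=0)),
        ('forest', RandomForestRegressor(n_estimators=10, random_state=0)),
    ])
    pipe.fit(X, y)
    print(pipe.predict([[np.nan, 4.]]))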


.. _multiple_imputation:

Multiple vs. Single Imputation
==============================

In the statistics community, it is common practice to perform multiple
imputations, generating, for example, 10 separate imputations for a single
feature matrix. Each of these 10 imputations is then put through the
subsequent analysis pipeline (e.g. feature engineering, clustering, regression,
classification). The 10 final analysis results (e.g. held-out validation
errors) allow the data scientist to assess the uncertainty inherent in the
missing values. This practice is called multiple imputation. As implemented,
the :class:`ChainedImputer` class generates a single (averaged) imputation for
each missing value, because this is the most common use case for machine
learning applications. However, it can also be used for multiple imputation by
applying it repeatedly to the same dataset with different random seeds and the
``n_imputations`` parameter set to 1, as sketched below.

Note that a call to the ``transform`` method of :class:`ChainedImputer` is not
allowed to change the number of samples. Therefore multiple imputations cannot be
achieved by a single call to ``transform``.
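
A sketch of that multiple-imputation recipe (illustrative, not part of this
diff; the toy data is made up)::

    import numpy as np
    from sklearn.impute import ChainedImputer

    X = [[1, 2], [np.nan, 3], [7, np.nan], [4, 5]]

    # One single imputation per seed; each completed dataset would then be
    # run through the full downstream analysis separately.
    completed = [
        ChainedImputer(n_imputations=1, random_state=seed).fit_transform(X)
        for seed in range(10)
    ]

    # The spread across seeds reflects the uncertainty due to missingness.
    print(np.std(completed, axis=0))
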
2 changes: 1 addition & 1 deletion doc/whats_new/v0.20.rst
@@ -123,7 +123,7 @@ Preprocessing
back to the original space via an inverse transform. :issue:`9041` by
`Andreas Müller`_ and :user:`Guillaume Lemaitre <glemaitre>`.

-- Added :class:`impute.MICEImputer`, which is a strategy for imputing missing
+- Added :class:`impute.ChainedImputer`, which is a strategy for imputing missing
values by modeling each feature with missing values as a function of
other features in a round-robin fashion. :issue:`8478` by
:user:`Sergey Feldman <sergeyf>`.
25 changes: 13 additions & 12 deletions examples/plot_missing_values.py
@@ -8,10 +8,11 @@
The median is a more robust estimator for data with high magnitude variables
which could dominate results (otherwise known as a 'long tail').
-Another option is the MICE imputer. This uses round-robin linear regression,
-treating every variable as an output in turn. The version implemented assumes
-Gaussian (output) variables. If your features are obviously non-Normal,
-consider transforming them to look more Normal so as to improve performance.
+Another option is the ``ChainedImputer``. This uses round-robin linear
+regression, treating every variable as an output in turn. The version
+implemented assumes Gaussian (output) variables. If your features are obviously
+non-Normal, consider transforming them to look more Normal so as to improve
+performance.
"""

import numpy as np
@@ -21,7 +22,7 @@
from sklearn.datasets import load_boston
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import Pipeline
-from sklearn.impute import SimpleImputer, MICEImputer
+from sklearn.impute import SimpleImputer, ChainedImputer
from sklearn.model_selection import cross_val_score

rng = np.random.RandomState(0)
@@ -66,18 +67,18 @@ def get_results(dataset):
mean_impute_scores = cross_val_score(estimator, X_missing, y_missing,
scoring='neg_mean_squared_error')

-    # Estimate the score after imputation (MICE strategy) of the missing values
-    estimator = Pipeline([("imputer", MICEImputer(missing_values=0,
-                                                  random_state=0)),
+    # Estimate the score after chained imputation of the missing values
+    estimator = Pipeline([("imputer", ChainedImputer(missing_values=0,
+                                                     random_state=0)),
("forest", RandomForestRegressor(random_state=0,
n_estimators=100))])
-    mice_impute_scores = cross_val_score(estimator, X_missing, y_missing,
-                                         scoring='neg_mean_squared_error')
+    chained_impute_scores = cross_val_score(estimator, X_missing, y_missing,
+                                            scoring='neg_mean_squared_error')

return ((full_scores.mean(), full_scores.std()),
(zero_impute_scores.mean(), zero_impute_scores.std()),
(mean_impute_scores.mean(), mean_impute_scores.std()),
-            (mice_impute_scores.mean(), mice_impute_scores.std()))
+            (chained_impute_scores.mean(), chained_impute_scores.std()))


results_diabetes = np.array(get_results(load_diabetes()))
@@ -94,7 +95,7 @@ def get_results(dataset):
x_labels = ['Full data',
'Zero imputation',
'Mean Imputation',
-            'MICE Imputation']
+            'Chained Imputation']
colors = ['r', 'g', 'b', 'orange']

# plot diabetes results
40 changes: 21 additions & 19 deletions sklearn/impute.py
@@ -30,13 +30,13 @@
zip = six.moves.zip
map = six.moves.map

-MICETriplet = namedtuple('MICETriplet', ['feat_idx',
-                                         'neighbor_feat_idx',
-                                         'predictor'])
+ImputerTriplet = namedtuple('ImputerTriplet', ['feat_idx',
+                                               'neighbor_feat_idx',
+                                               'predictor'])

__all__ = [
'SimpleImputer',
-    'MICEImputer',
+    'ChainedImputer',
]


@@ -423,12 +423,12 @@ def transform(self, X):
return X


-class MICEImputer(BaseEstimator, TransformerMixin):
-    """MICE transformer to impute missing values.
+class ChainedImputer(BaseEstimator, TransformerMixin):
+    """Chained imputer transformer to impute missing values.
-    Basic implementation of MICE (Multivariate Imputations by Chained
-    Equations) package from R. This version assumes all of the features are
-    Gaussian.
+    Basic implementation of the chained equations approach from the MICE
+    (Multivariate Imputation by Chained Equations) package in R. This version
+    assumes all of the features are Gaussian.
Read more in the :ref:`User Guide <mice>`.
@@ -453,11 +453,11 @@ class MICEImputer(BaseEstimator, TransformerMixin):
A random order for each round.
n_imputations : int, optional (default=100)
-        Number of MICE rounds to perform, the results of which will be
-        used in the final average.
+        Number of chained imputation rounds to perform, the results of which
+        will be used in the final average.
n_burn_in : int, optional (default=10)
-        Number of initial MICE rounds to perform the results of which
+        Number of initial imputation rounds to perform, the results of which
will not be returned.
predictor : estimator object, default=BayesianRidge()
@@ -858,7 +858,8 @@ def fit_transform(self, X, y=None):
Xt = np.zeros((n_samples, n_features), dtype=X.dtype)
self.imputation_sequence_ = []
if self.verbose > 0:
print("[MICE] Completing matrix with shape %s" % (X.shape,))
print("[ChainedImputer] Completing matrix with shape %s"
% (X.shape,))
start_t = time()
for i_rnd in range(n_rounds):
if self.imputation_order == 'random':
@@ -871,15 +872,15 @@
X_filled, predictor = self._impute_one_feature(
X_filled, mask_missing_values, feat_idx, neighbor_feat_idx,
predictor=None, fit_mode=True)
-                predictor_triplet = MICETriplet(feat_idx,
-                                                neighbor_feat_idx,
-                                                predictor)
+                predictor_triplet = ImputerTriplet(feat_idx,
+                                                   neighbor_feat_idx,
+                                                   predictor)
self.imputation_sequence_.append(predictor_triplet)

if i_rnd >= self.n_burn_in:
Xt += X_filled
if self.verbose > 0:
-                print('[MICE] Ending imputation round '
+                print('[ChainedImputer] Ending imputation round '
'%d/%d, elapsed time %0.2f'
% (i_rnd + 1, n_rounds, time() - start_t))

@@ -921,7 +922,8 @@ def transform(self, X):
i_rnd = 0
Xt = np.zeros(X.shape, dtype=X.dtype)
if self.verbose > 0:
print("[MICE] Completing matrix with shape %s" % (X.shape,))
print("[ChainedImputer] Completing matrix with shape %s"
% (X.shape,))
start_t = time()
for it, predictor_triplet in enumerate(self.imputation_sequence_):
X_filled, _ = self._impute_one_feature(
@@ -936,7 +938,7 @@ def transform(self, X):
if i_rnd >= self.n_burn_in:
Xt += X_filled
if self.verbose > 1:
-            print('[MICE] Ending imputation round '
+            print('[ChainedImputer] Ending imputation round '
'%d/%d, elapsed time %0.2f'
% (i_rnd + 1, n_rounds, time() - start_t))
i_rnd += 1
