FIX delete feature_names_in_ when refitting on a ndarray (#21389)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
2 people authored and glemaitre committed Oct 25, 2021
1 parent ae223ee commit cd927c0
Showing 11 changed files with 65 additions and 70 deletions.
8 changes: 8 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -128,6 +128,14 @@ Fixed models
where the underlying check for an attribute did not work with NumPy arrays.
:pr:`21145` by :user:`Zahlii <Zahlii>`.

Miscellaneous
.............

- |Fix| Fitting an estimator on a dataset that has no feature names, when the estimator
  was previously fitted on a dataset with feature names, no longer keeps the old feature
  names stored in the `feature_names_in_` attribute. :pr:`21389` by
  :user:`Jérémie du Boisberranger <jeremiedbb>`.

.. _changes_1_0:

Version 1.0.0
4 changes: 4 additions & 0 deletions sklearn/base.py
@@ -421,6 +421,10 @@ def _check_feature_names(self, X, *, reset):
feature_names_in = _get_feature_names(X)
if feature_names_in is not None:
self.feature_names_in_ = feature_names_in
elif hasattr(self, "feature_names_in_"):
# Delete the attribute when the estimator is fitted on a new dataset
# that has no feature names.
delattr(self, "feature_names_in_")
return

fitted_feature_names = getattr(self, "feature_names_in_", None)
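The behavioral effect of this hunk is easiest to see end to end: refitting on a plain ndarray now removes a previously stored `feature_names_in_`. A minimal sketch, assuming scikit-learn 1.0.1 or later where this fix landed (any estimator would do; `StandardScaler` is just an example):

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

X_np = np.array([[1.0, 2.0], [3.0, 4.0]])
df = pd.DataFrame(X_np, columns=["a", "b"])

scaler = StandardScaler().fit(df)
print(scaler.feature_names_in_)  # ['a' 'b']

# Refitting on an ndarray used to keep the stale names; after this fix,
# _check_feature_names(reset=True) deletes the attribute instead.
scaler.fit(X_np)
print(hasattr(scaler, "feature_names_in_"))  # False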
29 changes: 18 additions & 11 deletions sklearn/cluster/_agglomerative.py
@@ -915,6 +915,22 @@ def fit(self, X, y=None):
Returns the fitted instance.
"""
X = self._validate_data(X, ensure_min_samples=2, estimator=self)
return self._fit(X)

def _fit(self, X):
"""Fit without validation
Parameters
----------
X : ndarray of shape (n_samples, n_features) or (n_samples, n_samples)
Training instances to cluster, or distances between instances if
``affinity='precomputed'``.
Returns
-------
self : object
Returns the fitted instance.
"""
memory = check_memory(self.memory)

if self.n_clusters is not None and self.n_clusters <= 0:
@@ -1218,17 +1234,8 @@ def fit(self, X, y=None):
self : object
Returns the transformer.
"""
X = self._validate_data(
X,
accept_sparse=["csr", "csc", "coo"],
ensure_min_features=2,
estimator=self,
)
# save n_features_in_ attribute here to reset it after, because it will
# be overridden in AgglomerativeClustering since we passed it X.T.
n_features_in_ = self.n_features_in_
AgglomerativeClustering.fit(self, X.T)
self.n_features_in_ = n_features_in_
X = self._validate_data(X, ensure_min_features=2, estimator=self)
super()._fit(X.T)
return self

@property
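With the new private `_fit`, `FeatureAgglomeration.fit` validates `X` itself (setting `n_features_in_` and, for DataFrames, `feature_names_in_` from the columns of `X`) and only then hands the transposed array to `_fit`, so the old save-and-restore workaround for `n_features_in_` is no longer needed. A quick sketch of the resulting behavior:

import numpy as np
from sklearn.cluster import FeatureAgglomeration

X = np.random.RandomState(0).rand(10, 6)  # 10 samples, 6 features
agglo = FeatureAgglomeration(n_clusters=2).fit(X)
print(agglo.n_features_in_)  # 6, not 10, even though clustering ran on X.T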
24 changes: 4 additions & 20 deletions sklearn/decomposition/_lda.py
@@ -684,20 +684,6 @@ def _unnormalized_transform(self, X):
doc_topic_distr : ndarray of shape (n_samples, n_components)
Document topic distribution for X.
"""
check_is_fitted(self)

# make sure feature size is the same in fitted model and in X
X = self._check_non_neg_array(
X, reset_n_features=True, whom="LatentDirichletAllocation.transform"
)
n_samples, n_features = X.shape
if n_features != self.components_.shape[1]:
raise ValueError(
"The provided data has %d dimensions while "
"the model was trained with feature size %d."
% (n_features, self.components_.shape[1])
)

doc_topic_distr, _ = self._e_step(X, cal_sstats=False, random_init=False)

return doc_topic_distr
@@ -851,12 +837,6 @@ def _perplexity_precomp_distr(self, X, doc_topic_distr=None, sub_sampling=False)
score : float
Perplexity score.
"""
check_is_fitted(self)

X = self._check_non_neg_array(
X, reset_n_features=True, whom="LatentDirichletAllocation.perplexity"
)

if doc_topic_distr is None:
doc_topic_distr = self._unnormalized_transform(X)
else:
@@ -902,4 +882,8 @@ def perplexity(self, X, sub_sampling=False):
score : float
Perplexity score.
"""
check_is_fitted(self)
X = self._check_non_neg_array(
X, reset_n_features=True, whom="LatentDirichletAllocation.perplexity"
)
return self._perplexity_precomp_distr(X, sub_sampling=sub_sampling)
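`_unnormalized_transform` and `_perplexity_precomp_distr` used to re-validate `X` with `reset_n_features=True` on an already fitted model, which is what overwrote `n_features_in_`; validation now happens once, in the public entry points. The manual shape check against `components_` is also gone because `_validate_data` already raises on a feature-count mismatch. A sketch of the post-fix behavior (the exact error wording is an assumption):

import numpy as np
from sklearn.decomposition import LatentDirichletAllocation

X = np.random.RandomState(0).randint(0, 5, size=(20, 10))
lda = LatentDirichletAllocation(n_components=3, random_state=0).fit(X)

lda.perplexity(X)        # X is validated once, in perplexity() itself
lda.transform(X[:, :5])  # ValueError: X has 5 features, but 10 were expected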
18 changes: 9 additions & 9 deletions sklearn/ensemble/_bagging.py
@@ -257,6 +257,15 @@ def fit(self, X, y, sample_weight=None):
self : object
Fitted estimator.
"""
# Convert data (X is required to be 2d and indexable)
X, y = self._validate_data(
X,
y,
accept_sparse=["csr", "csc"],
dtype=None,
force_all_finite=False,
multi_output=True,
)
return self._fit(X, y, self.max_samples, sample_weight=sample_weight)

def _parallel_args(self):
@@ -295,15 +304,6 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
"""
random_state = check_random_state(self.random_state)

# Convert data (X is required to be 2d and indexable)
X, y = self._validate_data(
X,
y,
accept_sparse=["csr", "csc"],
dtype=None,
force_all_finite=False,
multi_output=True,
)
if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X, dtype=None)

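Moving validation from `_fit` to `fit` follows the pattern applied throughout this commit: the public entry point validates exactly once (resetting `n_features_in_` / `feature_names_in_` as a side effect), while private helpers trust their input. That matters for `BaseBagging` because subclasses such as `IsolationForest` do their own validation and then call `_fit` directly, so a second validation inside `_fit` saw an already-converted ndarray and could wrongly drop the stored feature names. A runnable sketch of the pattern (hypothetical class, not sklearn code):

import numpy as np

class PublicPrivatePattern:
    def fit(self, X, y):
        X = np.asarray(X, dtype=float)  # stands in for _validate_data(...)
        self.n_features_in_ = X.shape[1]
        return self._fit(X, np.asarray(y))

    def _fit(self, X, y):
        # Assumes the caller already validated X and y, so a subclass with
        # different validation needs can call _fit directly.
        self.mean_ = X.mean(axis=0)
        return self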
9 changes: 2 additions & 7 deletions sklearn/ensemble/_forest.py
@@ -68,6 +68,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
from ..utils.fixes import _joblib_parallel_args
from ..utils.multiclass import check_classification_targets, type_of_target
from ..utils.validation import check_is_fitted, _check_sample_weight
from ..utils.validation import _num_samples


__all__ = [
@@ -2627,14 +2628,8 @@ def fit_transform(self, X, y=None, sample_weight=None):
X_transformed : sparse matrix of shape (n_samples, n_out)
Transformed dataset.
"""
X = self._validate_data(X, accept_sparse=["csc"])
if issparse(X):
# Pre-sort indices to avoid that each individual tree of the
# ensemble sorts the indices.
X.sort_indices()

rnd = check_random_state(self.random_state)
y = rnd.uniform(size=X.shape[0])
y = rnd.uniform(size=_num_samples(X))
super().fit(X, y, sample_weight=sample_weight)

self.one_hot_encoder_ = OneHotEncoder(sparse=self.sparse_output)
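`_num_samples` replaces `X.shape[0]` because `X` is no longer validated inside `fit_transform` (validation, and the sparse index pre-sorting, now happen in `super().fit`), so at this point `X` may still be a list or other array-like without a `.shape`. What the helper handles, sketched with the private utility (its API may change between versions):

import numpy as np
from scipy.sparse import csr_matrix
from sklearn.utils.validation import _num_samples

print(_num_samples([[1, 2], [3, 4], [5, 6]]))  # 3, works on a plain list
print(_num_samples(csr_matrix(np.eye(3))))     # 3, sparse input stays sparse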
19 changes: 9 additions & 10 deletions sklearn/linear_model/_ridge.py
@@ -700,16 +700,6 @@ def fit(self, X, y, sample_weight=None):
self.normalize, default=False, estimator_name=self.__class__.__name__
)

_dtype = [np.float64, np.float32]
_accept_sparse = _get_valid_accept_sparse(sparse.issparse(X), self.solver)
X, y = self._validate_data(
X,
y,
accept_sparse=_accept_sparse,
dtype=_dtype,
multi_output=True,
y_numeric=True,
)
if self.solver == "lbfgs" and not self.positive:
raise ValueError(
"'lbfgs' solver can be used only when positive=True. "
@@ -1008,6 +998,15 @@ def fit(self, X, y, sample_weight=None):
self : object
Fitted estimator.
"""
_accept_sparse = _get_valid_accept_sparse(sparse.issparse(X), self.solver)
X, y = self._validate_data(
X,
y,
accept_sparse=_accept_sparse,
dtype=[np.float64, np.float32],
multi_output=True,
y_numeric=True,
)
return super().fit(X, y, sample_weight=sample_weight)


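Validation moves out of the shared `_BaseRidge.fit` and into `Ridge.fit` because `RidgeClassifier.fit` already validates on its own (its `y` holds class labels rather than numeric targets), so the base class was validating a second time on already-converted input. Note the solver-dependent `accept_sparse`: `_get_valid_accept_sparse` narrows the allowed sparse formats for solvers like `sag`. A small sketch:

import numpy as np
from scipy import sparse
from sklearn.linear_model import Ridge

X = sparse.random(10, 3, density=0.5, format="csr", random_state=0)
y = np.arange(10, dtype=float)

# 'sag' handles CSR input; fit() must know the solver before deciding
# which sparse formats to accept, hence validating in Ridge.fit itself.
Ridge(alpha=1.0, solver="sag").fit(X, y)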
13 changes: 3 additions & 10 deletions sklearn/linear_model/_stochastic_gradient.py
@@ -648,19 +648,12 @@ def _fit(
):
self._validate_params()
if hasattr(self, "classes_"):
self.classes_ = None

X, y = self._validate_data(
X,
y,
accept_sparse="csr",
dtype=np.float64,
order="C",
accept_large_sparse=False,
)
# delete the attribute otherwise _partial_fit thinks it's not the first call
delattr(self, "classes_")

# labels can be encoded as float, int, or string literals
# np.unique sorts in asc order; largest class id is positive class
y = self._validate_data(y=y)
classes = np.unique(y)

if self.warm_start and hasattr(self, "coef_"):
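The change from `self.classes_ = None` to `delattr` matters because `_partial_fit` detects the first call by checking whether the `classes_` attribute exists, and an attribute set to `None` still counts as present. Also note `y = self._validate_data(y=y)`: only `y` is validated here, since `_partial_fit` validates `X` itself with `reset` tied to that first-call check. The attribute mechanics in isolation:

class Sketch:
    pass

est = Sketch()
est.classes_ = None
print(hasattr(est, "classes_"))  # True: assigning None keeps the attribute

del est.classes_                 # what the fix does via delattr()
print(hasattr(est, "classes_"))  # False: now reads as a true first call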
2 changes: 1 addition & 1 deletion sklearn/model_selection/tests/test_successive_halving.py
@@ -665,7 +665,7 @@ def test_groups_support(Est):
]
error_msg = "The 'groups' parameter should not be None."
for cv in group_cvs:
gs = Est(clf, grid, cv=cv)
gs = Est(clf, grid, cv=cv, random_state=0)
with pytest.raises(ValueError, match=error_msg):
gs.fit(X, y)
gs.fit(X, y, groups=groups)
2 changes: 1 addition & 1 deletion sklearn/naive_bayes.py
@@ -241,7 +241,7 @@ def fit(self, X, y, sample_weight=None):
self : object
Returns the instance itself.
"""
X, y = self._validate_data(X, y)
y = self._validate_data(y=y)
return self._partial_fit(
X, y, np.unique(y), _refit=True, sample_weight=sample_weight
)
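`GaussianNB.fit` now validates only `y` before delegating: `_partial_fit(..., _refit=True)` validates `X` itself with a reset, and under the old double validation the second pass saw an already-converted ndarray and could drop `feature_names_in_` even when fitting on a DataFrame. The `np.unique(y)` argument supplies the class list that `partial_fit` would otherwise need explicitly:

import numpy as np
from sklearn.naive_bayes import GaussianNB

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])

clf = GaussianNB().fit(X, y)                          # classes inferred from y
inc = GaussianNB().partial_fit(X, y, classes=[0, 1])  # classes given up front
print(clf.classes_, inc.classes_)                     # [0 1] [0 1]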
7 changes: 6 additions & 1 deletion sklearn/tests/test_base.py
@@ -638,6 +638,11 @@ def transform(self, X):
trans = NoOpTransformer().fit(df)
assert_array_equal(trans.feature_names_in_, df.columns)

# fitting again on an ndarray does not keep the previous feature names (see #21383)
trans.fit(X_np)
assert not hasattr(trans, "feature_names_in_")

trans.fit(df)
msg = "The feature names should match those that were passed"
df_bad = pd.DataFrame(X_np, columns=iris.feature_names[::-1])
with pytest.warns(FutureWarning, match=msg):
@@ -665,7 +670,7 @@ def transform(self, X):
assert not record

# fit on dataframe with no feature names or all integer feature names
# -> do not warn on trainsform
# -> do not warn on transform
Xs = [X_np, df_int_names]
for X in Xs:
with pytest.warns(None) as record:
