TST Ensure that attributes ending _ are not set in __init__ (#7464)
lesteve authored and jnothman committed Dec 12, 2016
1 parent 83beb5f commit e542efa
Showing 16 changed files with 64 additions and 72 deletions.
10 changes: 10 additions & 0 deletions doc/whats_new.rst
@@ -112,6 +112,16 @@ Bug fixes
when a numpy array is passed in for weights. :issue:`7983` by
:user:`Vincent Pham <vincentpham1991>`.

API changes summary
-------------------

- Ensure that estimators' attributes ending with ``_`` are not set
in the constructor but only in the ``fit`` method. Most notably,
ensemble estimators (deriving from :class:`ensemble.BaseEnsemble`)
now only have ``self.estimators_`` available after ``fit``.
:issue:`7464` by `Lars Buitinck`_ and `Loic Esteve`_.
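Illustrative aside (not part of the changelog diff): a minimal sketch of the new behaviour, assuming a scikit-learn version that includes this change. Attributes ending in ``_`` only appear once ``fit`` has run:

    from sklearn.ensemble import RandomForestClassifier
    from sklearn.exceptions import NotFittedError
    from sklearn.utils.validation import check_is_fitted

    clf = RandomForestClassifier(n_estimators=5)
    print(hasattr(clf, "estimators_"))   # False: nothing ending in "_" is set in __init__

    try:
        check_is_fitted(clf, "estimators_")
    except NotFittedError:
        print("not fitted yet")

    clf.fit([[0.0], [1.0], [2.0], [3.0]], [0, 1, 0, 1])
    print(hasattr(clf, "estimators_"))   # True: created by fit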


.. _changes_0_18_1:

Version 0.18.1
5 changes: 4 additions & 1 deletion sklearn/decomposition/dict_learning.py
@@ -122,7 +122,10 @@ def _sparse_encode(X, dictionary, gram, cov=None, algorithm='lasso_lars',
# argument that we could pass in from Lasso.
clf = Lasso(alpha=alpha, fit_intercept=False, normalize=False,
precompute=gram, max_iter=max_iter, warm_start=True)
clf.coef_ = init

if init is not None:
clf.coef_ = init

clf.fit(dictionary.T, X.T, check_input=check_input)
new_code = clf.coef_

2 changes: 1 addition & 1 deletion sklearn/ensemble/bagging.py
@@ -330,7 +330,7 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
if hasattr(self, "oob_score_") and self.warm_start:
del self.oob_score_

if not self.warm_start or len(self.estimators_) == 0:
if not self.warm_start or not hasattr(self, 'estimators_'):
# Free allocated memory, if any
self.estimators_ = []
self.estimators_features_ = []
3 changes: 1 addition & 2 deletions sklearn/ensemble/base.py
@@ -88,8 +88,7 @@ def __init__(self, base_estimator, n_estimators=10,

# Don't instantiate estimators now! Parameters of base_estimator might
# still change. Eg., when grid-searching with the nested object syntax.
# This needs to be filled by the derived classes.
self.estimators_ = []
# self.estimators_ needs to be filled by the derived classes in fit.

def _validate_estimator(self, default=None):
"""Check the estimator and the n_estimator attribute, set the
9 changes: 5 additions & 4 deletions sklearn/ensemble/forest.py
@@ -63,6 +63,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
from .base import BaseEnsemble, _partition_estimators
from ..utils.fixes import bincount, parallel_helper
from ..utils.multiclass import check_classification_targets
from ..utils.validation import check_is_fitted

__all__ = ["RandomForestClassifier",
"RandomForestRegressor",
@@ -286,7 +287,7 @@ def fit(self, X, y, sample_weight=None):

random_state = check_random_state(self.random_state)

if not self.warm_start:
if not self.warm_start or not hasattr(self, "estimators_"):
# Free allocated memory, if any
self.estimators_ = []

@@ -361,9 +362,7 @@ def feature_importances_(self):
-------
feature_importances_ : array, shape = [n_features]
"""
if self.estimators_ is None or len(self.estimators_) == 0:
raise NotFittedError("Estimator not fitted, "
"call `fit` before `feature_importances_`.")
check_is_fitted(self, 'estimators_')

all_importances = Parallel(n_jobs=self.n_jobs,
backend="threading")(
@@ -557,6 +556,7 @@ class in a leaf.
The class probabilities of the input samples. The order of the
classes corresponds to that in the attribute `classes_`.
"""
check_is_fitted(self, 'estimators_')
# Check data
X = self._validate_X_predict(X)

@@ -669,6 +669,7 @@ def predict(self, X):
y : array of shape = [n_samples] or [n_samples, n_outputs]
The predicted values.
"""
check_is_fitted(self, 'estimators_')
# Check data
X = self._validate_X_predict(X)

6 changes: 1 addition & 5 deletions sklearn/ensemble/gradient_boosting.py
@@ -742,8 +742,6 @@ def __init__(self, loss, learning_rate, n_estimators, criterion,
self.warm_start = warm_start
self.presort = presort

self.estimators_ = np.empty((0, 0), dtype=np.object)

def _fit_stage(self, i, X, y, y_pred, sample_weight, sample_mask,
random_state, X_idx_sorted, X_csc=None, X_csr=None):
"""Fit another stage of ``n_classes_`` trees to the boosting model. """
@@ -923,9 +921,7 @@ def _is_initialized(self):

def _check_initialized(self):
"""Check that the estimator is initialized, raising an error if not."""
if self.estimators_ is None or len(self.estimators_) == 0:
raise NotFittedError("Estimator not fitted, call `fit`"
" before making predictions`.")
check_is_fitted(self, 'estimators_')

def fit(self, X, y, sample_weight=None, monitor=None):
"""Fit the gradient boosting model.
9 changes: 3 additions & 6 deletions sklearn/ensemble/partial_dependence.py
@@ -14,6 +14,7 @@
from ..externals import six
from ..externals.six.moves import map, range, zip
from ..utils import check_array
from ..utils.validation import check_is_fitted
from ..tree._tree import DTYPE

from ._gradient_boosting import _partial_dependence_tree
@@ -121,9 +122,7 @@ def partial_dependence(gbrt, target_variables, grid=None, X=None,
"""
if not isinstance(gbrt, BaseGradientBoosting):
raise ValueError('gbrt has to be an instance of BaseGradientBoosting')
if gbrt.estimators_.shape[0] == 0:
raise ValueError('Call %s.fit before partial_dependence' %
gbrt.__class__.__name__)
check_is_fitted(gbrt, 'estimators_')
if (grid is None and X is None) or (grid is not None and X is not None):
raise ValueError('Either grid or X must be specified')

@@ -245,9 +244,7 @@ def plot_partial_dependence(gbrt, X, features, feature_names=None,

if not isinstance(gbrt, BaseGradientBoosting):
raise ValueError('gbrt has to be an instance of BaseGradientBoosting')
if gbrt.estimators_.shape[0] == 0:
raise ValueError('Call %s.fit before partial_dependence' %
gbrt.__class__.__name__)
check_is_fitted(gbrt, 'estimators_')

# set label_idx for multi-class GBRT
if hasattr(gbrt, 'classes_') and np.size(gbrt.classes_) > 2:
6 changes: 1 addition & 5 deletions sklearn/linear_model/coordinate_descent.py
@@ -625,7 +625,6 @@ def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True,
random_state=None, selection='cyclic'):
self.alpha = alpha
self.l1_ratio = l1_ratio
self.coef_ = None
self.fit_intercept = fit_intercept
self.normalize = normalize
self.precompute = precompute
@@ -634,7 +633,6 @@ def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True,
self.tol = tol
self.warm_start = warm_start
self.positive = positive
self.intercept_ = 0.0
self.random_state = random_state
self.selection = selection

@@ -697,7 +695,7 @@ def fit(self, X, y, check_input=True):
if self.selection not in ['cyclic', 'random']:
raise ValueError("selection should be either random or cyclic.")

if not self.warm_start or self.coef_ is None:
if not self.warm_start or not hasattr(self, "coef_"):
coef_ = np.zeros((n_targets, n_features), dtype=X.dtype,
order='F')
else:
@@ -1648,7 +1646,6 @@ def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True,
warm_start=False, random_state=None, selection='cyclic'):
self.l1_ratio = l1_ratio
self.alpha = alpha
self.coef_ = None
self.fit_intercept = fit_intercept
self.normalize = normalize
self.max_iter = max_iter
@@ -1832,7 +1829,6 @@ def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
copy_X=True, max_iter=1000, tol=1e-4, warm_start=False,
random_state=None, selection='cyclic'):
self.alpha = alpha
self.coef_ = None
self.fit_intercept = fit_intercept
self.normalize = normalize
self.max_iter = max_iter
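Illustrative aside (not part of the diff): since ``coef_`` is no longer pre-set to ``None`` in ``__init__``, the warm-start check in ``fit`` above now uses ``hasattr``. A minimal sketch with ``ElasticNet``:

    import numpy as np
    from sklearn.linear_model import ElasticNet

    X = np.array([[0.0], [1.0], [2.0], [3.0]])
    y = np.array([0.0, 1.0, 2.0, 3.0])

    enet = ElasticNet(alpha=0.1, warm_start=True)
    print(hasattr(enet, "coef_"))   # False before the first fit
    enet.fit(X, y)                  # cold start: coefficients start from zeros
    enet.fit(X, y)                  # warm start: re-uses the previous coef_
    print(enet.coef_)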
28 changes: 9 additions & 19 deletions sklearn/linear_model/stochastic_gradient.py
@@ -68,15 +68,6 @@ def __init__(self, loss, penalty='l2', alpha=0.0001, C=1.0,

self._validate_params()

self.coef_ = None

if self.average > 0:
self.standard_coef_ = None
self.average_coef_ = None
# iteration count for learning rate schedule
# must not be int (e.g. if ``learning_rate=='optimal'``)
self.t_ = None

def set_params(self, *args, **kwargs):
super(BaseSGD, self).set_params(*args, **kwargs)
self._validate_params()
@@ -332,7 +323,6 @@ def __init__(self, loss="hinge", penalty='l2', alpha=0.0001, l1_ratio=0.15,
warm_start=warm_start,
average=average)
self.class_weight = class_weight
self.classes_ = None
self.n_jobs = int(n_jobs)

def _partial_fit(self, X, y, alpha, C,
@@ -353,15 +343,15 @@ def _partial_fit(self, X, y, alpha, C,
self.classes_, y)
sample_weight = self._validate_sample_weight(sample_weight, n_samples)

if self.coef_ is None or coef_init is not None:
if getattr(self, "coef_", None) is None or coef_init is not None:
self._allocate_parameter_mem(n_classes, n_features,
coef_init, intercept_init)
elif n_features != self.coef_.shape[-1]:
raise ValueError("Number of features %d does not match previous "
"data %d." % (n_features, self.coef_.shape[-1]))

self.loss_function = self._get_loss_function(loss)
if self.t_ is None:
if not hasattr(self, "t_"):
self.t_ = 1.0

# delegate to concrete training procedure
@@ -391,7 +381,7 @@ def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
# np.unique sorts in asc order; largest class id is positive class
classes = np.unique(y)

if self.warm_start and self.coef_ is not None:
if self.warm_start and hasattr(self, "coef_"):
if coef_init is None:
coef_init = self.coef_
if intercept_init is None:
@@ -407,7 +397,7 @@ def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
self.average_intercept_ = None

# Clear iteration count for multiple call to fit.
self.t_ = None
self.t_ = 1.0

self._partial_fit(X, y, alpha, C, loss, learning_rate, self.n_iter,
classes, sample_weight, coef_init, intercept_init)
@@ -871,13 +861,13 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate,
# Allocate datastructures from input arguments
sample_weight = self._validate_sample_weight(sample_weight, n_samples)

if self.coef_ is None:
if getattr(self, "coef_", None) is None:
self._allocate_parameter_mem(1, n_features,
coef_init, intercept_init)
elif n_features != self.coef_.shape[-1]:
raise ValueError("Number of features %d does not match previous "
"data %d." % (n_features, self.coef_.shape[-1]))
if self.average > 0 and self.average_coef_ is None:
if self.average > 0 and getattr(self, "average_coef_", None) is None:
self.average_coef_ = np.zeros(n_features,
dtype=np.float64,
order="C")
@@ -917,7 +907,7 @@ def partial_fit(self, X, y, sample_weight=None):

def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
intercept_init=None, sample_weight=None):
if self.warm_start and self.coef_ is not None:
if self.warm_start and getattr(self, "coef_", None) is not None:
if coef_init is None:
coef_init = self.coef_
if intercept_init is None:
@@ -933,7 +923,7 @@ def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None,
self.average_intercept_ = None

# Clear iteration count for multiple call to fit.
self.t_ = None
self.t_ = 1.0

return self._partial_fit(X, y, alpha, C, loss, learning_rate,
self.n_iter, sample_weight,
@@ -1012,7 +1002,7 @@ def _fit_regressor(self, X, y, alpha, C, loss, learning_rate,
penalty_type = self._get_penalty_type(self.penalty)
learning_rate_type = self._get_learning_rate_type(learning_rate)

if self.t_ is None:
if not hasattr(self, "t_"):
self.t_ = 1.0

random_state = check_random_state(self.random_state)
1 change: 0 additions & 1 deletion sklearn/manifold/t_sne.py
@@ -667,7 +667,6 @@ def __init__(self, n_components=2, perplexity=30.0,
self.random_state = random_state
self.method = method
self.angle = angle
self.embedding_ = None

def _fit(self, X, skip_num_points=0):
"""Fit the model using X as training data.
6 changes: 0 additions & 6 deletions sklearn/mixture/gmm.py
@@ -275,12 +275,6 @@ def __init__(self, n_components=1, covariance_type='diag',
if n_init < 1:
raise ValueError('GMM estimation requires at least one run')

self.weights_ = np.ones(self.n_components) / self.n_components

# flag to indicate exit status of fit() method: converged (True) or
# n_iter reached (False)
self.converged_ = False

def _get_covars(self):
"""Covariance parameters for each mixture component.
1 change: 1 addition & 0 deletions sklearn/neighbors/tests/test_approximate.py
@@ -54,6 +54,7 @@ def test_neighbors_accuracy_with_n_candidates():

accuracies[i] = accuracies[i] / float(n_iter)
# Sorted accuracies should be equal to original accuracies
print('accuracies:', accuracies)
assert_true(np.all(np.diff(accuracies) >= 0),
msg="Accuracies are not non-decreasing.")
# Highest accuracy should be strictly greater than the lowest
12 changes: 3 additions & 9 deletions sklearn/random_projection.py
@@ -41,9 +41,8 @@
from .utils import check_random_state
from .utils.extmath import safe_sparse_dot
from .utils.random import sample_without_replacement
from .utils.validation import check_array
from .utils.validation import check_array, check_is_fitted
from .exceptions import DataDimensionalityWarning
from .exceptions import NotFittedError


__all__ = ["SparseRandomProjection",
@@ -303,9 +302,6 @@ def __init__(self, n_components='auto', eps=0.1, dense_output=False,
self.dense_output = dense_output
self.random_state = random_state

self.components_ = None
self.n_components_ = None

@abstractmethod
def _make_random_matrix(n_components, n_features):
""" Generate the random projection matrix
@@ -365,7 +361,7 @@ def fit(self, X, y=None):
else:
if self.n_components <= 0:
raise ValueError("n_components must be greater than 0, got %s"
% self.n_components_)
% self.n_components)

elif self.n_components > n_features:
warnings.warn(
@@ -408,8 +404,7 @@ def transform(self, X, y=None):
"""
X = check_array(X, accept_sparse=['csr', 'csc'])

if self.components_ is None:
raise NotFittedError('No random projection matrix had been fit.')
check_is_fitted(self, 'components_')

if X.shape[1] != self.components_.shape[1]:
raise ValueError(
@@ -596,7 +591,6 @@ def __init__(self, n_components='auto', density='auto', eps=0.1,
random_state=random_state)

self.density = density
self.density_ = None

def _make_random_matrix(self, n_components, n_features):
""" Generate the random projection matrix
2 changes: 1 addition & 1 deletion sklearn/svm/tests/test_svm.py
@@ -246,7 +246,7 @@ def test_oneclass():
assert_array_almost_equal(clf.dual_coef_,
[[0.632, 0.233, 0.633, 0.234, 0.632, 0.633]],
decimal=3)
assert_false(hasattr(clf, "coef_"))
assert_raises(AttributeError, lambda: clf.coef_)


def test_oneclass_decision_function():
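Illustrative aside (not part of the diff): ``coef_`` on the libsvm-based estimators is a property that raises ``AttributeError`` for non-linear kernels, so the test now asserts that access raises, rather than relying on ``hasattr`` (which silently turns the ``AttributeError`` into ``False``). A minimal sketch:

    from sklearn.svm import OneClassSVM

    clf = OneClassSVM(kernel="rbf")
    clf.fit([[0.0], [1.0], [2.0]])
    try:
        clf.coef_
    except AttributeError as exc:
        print(exc)   # the property raises because the kernel is not linear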