MAINT: Improve documentation and coverage
MechCoder committed Jul 22, 2014
1 parent a52174b commit 67585f6
Showing 3 changed files with 123 additions and 52 deletions.
139 changes: 93 additions & 46 deletions sklearn/linear_model/logistic.py
@@ -51,7 +51,7 @@ def _intercept_dot(w, X, y):
w = w[:-1]

z = safe_sparse_dot(X, w) + c
return w, c, y*z
return w, c, y * z
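For context outside the diff: when fit_intercept is True, the helper treats the last entry of w as the intercept. A minimal standalone sketch of the same idea, using plain NumPy in place of sklearn's sparse-aware safe_sparse_dot (names here are illustrative):

    import numpy as np

    def intercept_dot_sketch(w, X, y):
        # w holds the coefficients followed by the intercept c.
        c = w[-1]
        w = w[:-1]
        z = X.dot(w) + c          # decision values
        return w, c, y * z        # y * z are the signed margins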


def _logistic_loss_and_grad(w, X, y, alpha, sample_weight=None):
@@ -77,10 +77,10 @@ def _logistic_loss_and_grad(w, X, y, alpha, sample_weight=None):
Returns
-------
out: float
out : float
Logistic loss.
grad: ndarray, shape (n_features,) or (n_features + 1,)
grad : ndarray, shape (n_features,) or (n_features + 1,)
Logistic gradient.
"""
_, n_features = X.shape
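To make the (loss, grad) contract concrete, here is a hedged sketch of how such a function plugs into an L-BFGS driver; the toy data and the simplified loss below are illustrative, not the module's exact code:

    import numpy as np
    from scipy import optimize

    # Tiny binary problem with labels in {-1, +1}, as the private helpers expect.
    X = np.array([[0., 1.], [1., 0.], [1., 1.], [0., 0.]])
    y = np.array([1., -1., 1., -1.])

    def loss_and_grad(w, X, y, alpha):
        yz = y * X.dot(w)
        loss = np.sum(np.log(1. + np.exp(-yz))) + .5 * alpha * w.dot(w)
        p = 1. / (1. + np.exp(yz))             # sigmoid(-yz)
        grad = -X.T.dot(y * p) + alpha * w
        return loss, grad

    w0 = np.zeros(X.shape[1])
    w_opt, f_opt, info = optimize.fmin_l_bfgs_b(loss_and_grad, w0, args=(X, y, 1.))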
@@ -128,7 +128,7 @@ def _logistic_loss(w, X, y, alpha, sample_weight=None):
Returns
-------
out: float
out : float
Logistic loss.
"""
w, c, yz = _intercept_dot(w, X, y)
@@ -164,13 +164,13 @@ def _logistic_loss_grad_hess(w, X, y, alpha, sample_weight=None):
Returns
-------
out: float
out : float
Logistic loss.
grad: ndarray, shape (n_features,) or (n_features + 1,)
grad : ndarray, shape (n_features,) or (n_features + 1,)
Logistic gradient.
Hs: callable
Hs : callable
Function that takes the gradient as a parameter and returns the
matrix product of the Hessian and gradient.
"""
@@ -226,8 +226,9 @@ def Hs(s):

def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
max_iter=100, tol=1e-4, verbose=0,
solver='liblinear', coef=None, copy=True,
class_weight=None, dual=False, penalty='l2'):
solver='lbfgs', coef=None, copy=True,
class_weight=None, dual=False, penalty='l2',
intercept_scaling=1.):
"""Compute a Logistic Regression model for a list of regularization
parameters.
@@ -243,13 +244,13 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
y : array-like, shape (n_samples,)
Input data, target values.
Cs : array-like or integer of shape (n_cs,)
Cs : int | array-like, shape (n_cs,)
List of values for the regularization parameter or integer specifying
the number of regularization parameters that should be used. In this
case, the parameters will be chosen on a logarithmic scale between
1e-4 and 1e4.
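Concretely, an integer Cs presumably expands to something like np.logspace (a sketch, not necessarily the exact call used here):

    import numpy as np
    Cs = np.logspace(-4, 4, 10)   # 10 values from 1e-4 to 1e4, log-spaced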
pos_class: int, None
pos_class : int, None
The class with respect to which we perform a one-vs-all fit.
If None, then it is assumed that the given problem is binary.
@@ -261,20 +262,20 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
Maximum number of iterations for the solver.
tol : float
Stopping criterion. The iteration will stop when
``max{|g_i | i = 1, ..., n} <= tol``
Stopping criterion. For the newton-cg and lbfgs solvers, the iteration
will stop when ``max{|g_i | i = 1, ..., n} <= tol``
where ``g_i`` is the i-th component of the gradient.
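In NumPy terms, that stopping test is just (illustrative values):

    import numpy as np
    grad = np.array([3e-5, -2e-5])            # illustrative gradient
    converged = np.max(np.abs(grad)) <= 1e-4  # max |g_i| <= tol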
verbose: int
verbose : int
Print convergence messages for positive values.
solver : {'lbfgs', 'newton-cg', 'liblinear'}
Numerical solver to use.
coef: array-like, shape (n_features,), default None
coef : array-like, shape (n_features,), default None
Initialization value for coefficients of logistic regression.
copy: bool, default True
copy : bool, default True
Whether or not to produce a copy of the data. Setting this to
True will be useful in cases when logistic_regression_path
is called repeatedly with the same data, as y is modified
@@ -295,6 +296,18 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
Used to specify the norm used in the penalization. The newton-cg and
lbfgs solvers support only l2 penalties.
intercept_scaling : float, default 1.
This parameter is useful only when the solver 'liblinear' is used
and self.fit_intercept is set to True. In this case, x becomes
[x, self.intercept_scaling],
i.e. a "synthetic" feature with constant value equal to
intercept_scaling is appended to the instance vector.
The intercept becomes intercept_scaling * synthetic feature weight.
Note! The synthetic feature weight is subject to l1/l2 regularization,
like all other features.
To lessen the effect of regularization on the synthetic feature weight
(and therefore on the intercept), intercept_scaling has to be increased.
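A hedged illustration of the augmentation described above (liblinear does this internally; the NumPy below only mimics it):

    import numpy as np

    X = np.array([[1., 2.], [3., 4.]])
    intercept_scaling = 2.
    # Append the constant "synthetic" feature; its learned weight times
    # intercept_scaling acts as the intercept.
    X_aug = np.hstack([X, np.full((X.shape[0], 1), intercept_scaling)])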
Returns
-------
coefs : ndarray, shape (n_cs, n_features) or (n_cs, n_features + 1)
@@ -396,7 +409,8 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
elif solver == 'liblinear':
lr = LogisticRegression(C=C, fit_intercept=fit_intercept, tol=tol,
class_weight=class_weight, dual=dual,
penalty=penalty)
penalty=penalty,
intercept_scaling=intercept_scaling)
lr.fit(X, y)
if fit_intercept:
w0 = np.concatenate([lr.coef_.ravel(), lr.intercept_])
@@ -413,8 +427,8 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
scoring=None, fit_intercept=False,
max_iter=100, tol=1e-4, class_weight=None,
verbose=0, method='liblinear', penalty='l2',
dual=False, copy=True):
verbose=0, solver='lbfgs', penalty='l2',
dual=False, copy=True, intercept_scaling=1.):
"""Computes scores across logistic_regression_path
Parameters
@@ -431,11 +445,11 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
test : list of indices
The indices of the test set.
pos_class: int, None
pos_class : int, None
The class with respect to which we perform a one-vs-all fit.
If None, then it is assumed that the given problem is binary.
Cs: list of floats | int
Cs : list of floats | int
Each of the values in Cs describes the inverse of
regularization strength. If Cs is an int, then a grid of Cs
values is chosen on a logarithmic scale between 1e-4 and 1e4.
@@ -451,7 +465,7 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
term of each coef_ gives us the intercept.
max_iter : int
Maximum no. of iterations for the solver.
Maximum number of iterations for the solver.
tol : float
Tolerance for stopping criteria.
@@ -465,7 +479,7 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
verbose : int
Amount of verbosity.
method : {'lbfgs', 'newton-cg', 'liblinear'}
solver : {'lbfgs', 'newton-cg', 'liblinear'}
Decides which solver to use.
penalty : str, 'l1' or 'l2'
@@ -477,6 +491,18 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
l2 penalty with liblinear solver. Prefer dual=False when
n_samples > n_features.
intercept_scaling : float, default 1.
This parameter is useful only when the solver 'liblinear' is used
and self.fit_intercept is set to True. In this case, x becomes
[x, self.intercept_scaling],
i.e. a "synthetic" feature with constant value equal to
intercept_scaling is appended to the instance vector.
The intercept becomes intercept_scaling * synthetic feature weight.
Note! The synthetic feature weight is subject to l1/l2 regularization,
like all other features.
To lessen the effect of regularization on the synthetic feature weight
(and therefore on the intercept), intercept_scaling has to be increased.
Returns
-------
coefs : ndarray, shape (n_cs, n_features) or (n_cs, n_features + 1)
@@ -510,12 +536,13 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,

coefs, Cs = logistic_regression_path(X_train, y_train, Cs=Cs,
fit_intercept=fit_intercept,
solver=method,
solver=solver,
max_iter=max_iter,
class_weight=class_weight,
copy=copy, pos_class=pos_class,
tol=tol, verbose=verbose,
dual=dual, penalty=penalty)
dual=dual, penalty=penalty,
intercept_scaling=intercept_scaling)

scores = list()

@@ -593,30 +620,30 @@ class LogisticRegression(BaseLibLinear, LinearClassifierMixin,
Useful only for the newton-cg and lbfgs solvers. Maximum number of
iterations taken for the solvers to converge.
random_state: int seed, RandomState instance, or None (default)
random_state : int seed, RandomState instance, or None (default)
The seed of the pseudo random number generator to use when
shuffling the data.
solver: {'newton-cg', 'lbfgs', 'liblinear'}
solver : {'newton-cg', 'lbfgs', 'liblinear'}
Algorithm to use in the optimization problem.
tol: float, optional
tol : float, optional
Tolerance for stopping criteria.
Attributes
----------
`coef_` : array, shape (n_classes, n_features)
Coefficient of the features in the decision function.
`intercept_` : array, shape = (n_classes,)
`intercept_` : array, shape (n_classes,)
Intercept (a.k.a. bias) added to the decision function.
If `fit_intercept` is set to False, the intercept is set to zero.
See also
--------
SGDClassifier: incrementally trained logistic regression (when given
SGDClassifier : incrementally trained logistic regression (when given
the parameter ``loss="log"``).
sklearn.svm.LinearSVC: learns SVM models using the same algorithm.
sklearn.svm.LinearSVC : learns SVM models using the same algorithm.
Notes
-----
@@ -683,8 +710,8 @@ def predict_log_proba(self, X):
return np.log(self.predict_proba(X))


class LogisticRegressionCV(BaseEstimator, LinearClassifierMixin,
_LearntSelectorMixin):
class LogisticRegressionCV(LogisticRegression, BaseEstimator,
LinearClassifierMixin, _LearntSelectorMixin):
"""Logistic Regression CV (aka logit, MaxEnt) classifier.
This class implements logistic regression using liblinear, newton-cg or
@@ -694,14 +721,14 @@ class LogisticRegressionCV(BaseEstimator, LinearClassifierMixin,
Parameters
----------
Cs: list of floats | int
Cs : list of floats | int
Each of the values in Cs describes the inverse of regularization
strength. If Cs is an int, then a grid of Cs values is chosen
on a logarithmic scale between 1e-4 and 1e4.
Like in support vector machines, smaller values specify stronger
regularization.
fit_intercept: bool, default: True
fit_intercept : bool, default: True
Specifies if a constant (a.k.a. bias or intercept) should be
added to the decision function.
@@ -726,18 +753,18 @@ class LogisticRegressionCV(BaseEstimator, LinearClassifierMixin,
l2 penalty with liblinear solver. Prefer dual=False when
n_samples > n_features.
scoring: callabale
scoring : callable
Scoring function to use as cross-validation criteria. For a list of
scoring functions that can be used, look at :mod:`sklearn.metrics`.
The default scoring option used is accuracy_score.
solver: {'newton-cg', 'lbfgs', 'liblinear'}
solver : {'newton-cg', 'lbfgs', 'liblinear'}
Algorithm to use in the optimization problem.
tol: float, optional
tol : float, optional
Tolerance for stopping criteria.
max_iter: int, optional
max_iter : int, optional
Maximum number of iterations of the optimization algorithm.
class_weight : {dict, 'auto'}, optional
@@ -760,6 +787,18 @@ class LogisticRegressionCV(BaseEstimator, LinearClassifierMixin,
Otherwise the coefs, intercepts and C that correspond to the
best scores across folds are averaged.
intercept_scaling : float, default 1.
This parameter is useful only when the solver 'liblinear' is used
and self.fit_intercept is set to True. In this case, x becomes
[x, self.intercept_scaling],
i.e. a "synthetic" feature with constant value equal to
intercept_scaling is appended to the instance vector.
The intercept becomes intercept_scaling * synthetic feature weight.
Note! The synthetic feature weight is subject to l1/l2 regularization,
like all other features.
To lessen the effect of regularization on the synthetic feature weight
(and therefore on the intercept), intercept_scaling has to be increased.
Attributes
----------
`coef_` : array, shape (1, n_features) or (n_classes, n_features)
@@ -795,7 +834,9 @@ class LogisticRegressionCV(BaseEstimator, LinearClassifierMixin,
Each dict value has shape (n_folds, len(Cs))
`C_` : array, shape (n_classes,) or (n_classes - 1,)
Array of C that maps to the best scores across every class.
Array of C that maps to the best scores across every class. If refit is
set to False, then for each class, the best C is the average of the
C's that correspond to the best scores for each fold.
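A hedged sketch of the refit=False averaging described above, for a single class (shapes and values are illustrative assumptions):

    import numpy as np

    rng = np.random.RandomState(0)
    scores = rng.rand(3, 5)            # (n_folds, n_Cs)
    coefs_paths = rng.rand(3, 5, 4)    # (n_folds, n_Cs, n_features)
    Cs = np.logspace(-4, 4, 5)

    best_indices = np.argmax(scores, axis=1)        # best C index per fold
    w = np.mean([coefs_paths[i, best_indices[i]]
                 for i in range(len(best_indices))], axis=0)
    C_best = np.mean(Cs[best_indices])              # averaged best C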
See also
--------
@@ -804,9 +845,9 @@ class LogisticRegressionCV(BaseEstimator, LinearClassifierMixin,
"""

def __init__(self, Cs=10, fit_intercept=True, cv=None, dual=False,
penalty='l2', scoring=None, solver='newton-cg', tol=1e-4,
penalty='l2', scoring=None, solver='lbfgs', tol=1e-4,
max_iter=100, class_weight=None, n_jobs=1, verbose=False,
refit=True):
refit=True, intercept_scaling=1.):
self.Cs = Cs
self.fit_intercept = fit_intercept
self.cv = cv
@@ -820,6 +861,7 @@ def __init__(self, Cs=10, fit_intercept=True, cv=None, dual=False,
self.verbose = verbose
self.solver = solver
self.refit = refit
        self.intercept_scaling = intercept_scaling

def fit(self, X, y):
"""Fit the model according to the given training data.
@@ -864,7 +906,10 @@ def fit(self, X, y):
cv = _check_cv(self.cv, X, y, classifier=True)
folds = list(cv)

self.classes_ = labels = np.unique(y)
self._enc = LabelEncoder()
self._enc.fit(y)

labels = self.classes_
n_classes = len(labels)

if n_classes < 2:
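For context, the LabelEncoder swap means classes_ is now presumably exposed from the fitted encoder (an assumption based on the BaseLibLinear convention) rather than computed inline with np.unique; the encoder's behavior in isolation:

    import numpy as np
    from sklearn.preprocessing import LabelEncoder

    y = np.array(['b', 'a', 'b', 'c'])
    enc = LabelEncoder().fit(y)
    print(enc.classes_)        # ['a' 'b' 'c'] -- sorted unique labels
    print(enc.transform(y))    # [1 0 1 2]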
@@ -888,12 +933,13 @@ def fit(self, X, y):
fit_intercept=self.fit_intercept,
penalty=self.penalty,
dual=self.dual,
method=self.solver,
solver=self.solver,
max_iter=self.max_iter,
tol=self.tol,
class_weight=self.class_weight,
verbose=max(0, self.verbose - 1),
scoring=self.scoring)
scoring=self.scoring,
intercept_scaling=self.intercept_scaling)
for label in labels
for train, test in folds
)
@@ -930,7 +976,7 @@ def fit(self, X, y):

else:
# Take the best scores across every fold and the average of all
# coefficients coressponding to the best scores.
# coefficients corresponding to the best scores.
best_indices = np.argmax(scores, axis=1)
w = np.mean([
coefs_paths[i][best_indices[i]]
@@ -944,6 +990,7 @@ def fit(self, X, y):
else:
self.coef_.append(w)
self.intercept_.append(0.)
self.C_ = np.asarray(self.C_)
self.coef_ = np.asarray(self.coef_)
self.intercept_ = np.asarray(self.intercept_)
return self
