FIX: Doctests
MechCoder committed Jun 6, 2014
1 parent 3bfd1c8 commit cd1e238
Showing 2 changed files with 71 additions and 72 deletions.
136 changes: 69 additions & 67 deletions sklearn/linear_model/logistic.py
@@ -1,26 +1,24 @@
# Authors: Fabian Pedregosa
# Alexandre Gramfort
# License: 3-clause BSD

"""
Logistic Regression
"""

# Author: Gael Varoquaux <gael.varoquaux@normalesup.org>
# Fabian Pedregosa <f@bianp.net>
# Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
# Manoj Kumar <manojkumarsivaraj334@gmail.com>

import numbers

import numpy as np
from scipy import optimize, sparse, special
from scipy import optimize, sparse

from .base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator
from ..feature_selection.from_model import _LearntSelectorMixin
from ..preprocessing import LabelEncoder, LabelBinarizer
from ..svm.base import BaseLibLinear
from ..utils import atleast2d_or_csc, as_float_array, check_arrays
from ..utils.extmath import log_logistic, safe_sparse_dot
from ..utils.fixes import expit
from ..externals.joblib import Parallel, delayed
from ..cross_validation import check_cv
from ..utils.optimize import newton_cg
@@ -30,19 +28,19 @@

# .. some helper functions for logistic_regression_path ..
def _intercept_dot(w, X, y):
"""
Computes y * np.dot(w, X), taking into consideration
if the intercept should be fit or not.
"""Computes y * np.dot(w, X).
It takes into consideration if the intercept should be fit or not.
Parameters
----------
w : ndarray, shape = (n_features,) or (n_features + 1,)
w : ndarray, shape (n_features,) or (n_features + 1,)
Coefficient vector
X : {array-like, sparse matrix}, shape (n_samples, n_features)
Training data
y : ndarray, shape (n_samples)
y : ndarray, shape (n_samples,)
Array of labels
"""
c = None
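
As context for the hunk above: the helper computes the margins y * (X.w + c), reading the intercept c from the last entry of w when it is fit. A minimal standalone sketch (illustrative only, not the library code):

```python
import numpy as np

def intercept_dot_sketch(w, X, y, fit_intercept=True):
    # When the intercept is fit, it is stored as the last entry of w.
    c = 0.
    if fit_intercept:
        w, c = w[:-1], w[-1]
    z = np.dot(X, w) + c   # decision values X.w + c
    return y * z           # margins y * (X.w + c)

X = np.array([[1., 2.], [3., 4.]])
y = np.array([1., -1.])               # labels in {-1, +1}
w = np.array([.5, -.25, .1])          # last entry is the intercept
print(intercept_dot_sketch(w, X, y))  # [ 0.1 -0.6]
```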
@@ -57,18 +55,17 @@ def _intercept_dot(w, X, y):


def _logistic_loss_and_grad(w, X, y, alpha):
"""
Computes the logistic loss and gradient.
"""Computes the logistic loss and gradient.
Parameters
----------
w : ndarray, shape = (n_features,) or (n_features + 1,)
w : ndarray, shape (n_features,) or (n_features + 1,)
Coefficient vector
X : {array-like, sparse matrix}, shape (n_samples, n_features)
Training data
y : ndarray, shape (n_samples)
y : ndarray, shape (n_samples,)
Array of labels.
alpha : float
@@ -82,7 +79,7 @@ def _logistic_loss_and_grad(w, X, y, alpha):
# Logistic loss is the negative of the log of the logistic function.
out = -np.sum(log_logistic(yz)) + .5 * alpha * np.dot(w, w)

z = special.expit(yz)
z = expit(yz)
z0 = (z - 1) * y

grad[:n_features] = safe_sparse_dot(X.T, z0) + alpha * w
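
A quick numerical sanity check on the loss/gradient pair above: the sketch below re-implements the same objective, sum(log(1 + exp(-y*z))) + 0.5 * alpha * ||w||^2, densely and verifies the gradient with scipy.optimize.check_grad. No intercept is handled; this is an assumption-laden sketch, not the library implementation.

```python
import numpy as np
from scipy.optimize import check_grad
from scipy.special import expit

def loss(w, X, y, alpha):
    yz = y * np.dot(X, w)
    # -log(sigmoid(x)) written stably as logaddexp(0, -x)
    return np.sum(np.logaddexp(0, -yz)) + .5 * alpha * np.dot(w, w)

def grad(w, X, y, alpha):
    yz = y * np.dot(X, w)
    z0 = (expit(yz) - 1) * y
    return np.dot(X.T, z0) + alpha * w

rng = np.random.RandomState(0)
X, y, w = rng.randn(20, 3), np.sign(rng.randn(20)), rng.randn(3)
print(check_grad(loss, grad, w, X, y, 1.))  # should be ~1e-6 or smaller
```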
@@ -92,12 +89,11 @@ def _logistic_loss_and_grad(w, X, y, alpha):


def _logistic_loss(w, X, y, alpha, fit_intercept=False):
"""
Computes the logistic loss and gradient.
"""Computes the logistic loss.
Parameters
----------
w : ndarray, shape = (n_features,) or (n_features + 1,)
w : ndarray, shape (n_features,) or (n_features + 1,)
Coefficient vector
X : {array-like, sparse matrix}, shape (n_samples, n_features)
@@ -117,12 +113,11 @@ def _logistic_loss(w, X, y, alpha, fit_intercept=False):


def _logistic_loss_grad_hess(w, X, y, alpha):
"""
Computes the logistic loss, gradient and the Hessian.
"""Computes the logistic loss, gradient and the Hessian.
Parameters
----------
w : ndarray, shape = (n_features,) or (n_features + 1,)
w : ndarray, shape (n_features,) or (n_features + 1,)
Coefficient vector
X : {array-like, sparse matrix}, shape (n_samples, n_features)
@@ -142,7 +137,7 @@ def _logistic_loss_grad_hess(w, X, y, alpha):
# Logistic loss is the negative of the log of the logistic function.
out = -np.sum(log_logistic(yz)) + .5 * alpha * np.dot(w, w)

z = special.expit(yz)
z = expit(yz)
z0 = (z - 1) * y
grad[:n_features] = safe_sparse_dot(X.T, z0) + alpha * w
if c is not None:
@@ -159,7 +154,7 @@ def _logistic_loss_grad_hess(w, X, y, alpha):
dX = d[:, np.newaxis] * X

if c is not None:
# Calculate the double derivative with respect to intecept
# Calculate the double derivative with respect to intercept
# In the case of sparse matrices this returns a matrix object.
dd_intercept = np.squeeze(np.array(dX.sum(axis=0)))
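
The `Hs` closure touched by the next hunk implements a Hessian-vector product rather than materializing the Hessian: for this loss the Hessian is X.T @ diag(d) @ X + alpha * I with d = sigmoid(z) * (1 - sigmoid(z)). A matrix-free sketch of the same idea, without the intercept handling (hypothetical helper, not the library code):

```python
import numpy as np
from scipy.special import expit

def make_hvp(w, X, alpha):
    z = expit(np.dot(X, w))
    d = z * (1 - z)             # per-sample curvature weights
    def Hs(s):
        # H @ s = X.T @ (d * (X @ s)) + alpha * s, never forming H
        return np.dot(X.T, d * np.dot(X, s)) + alpha * s
    return Hs

rng = np.random.RandomState(0)
X, w = rng.randn(10, 4), rng.randn(4)
print(make_hvp(w, X, alpha=1.)(np.ones(4)).shape)  # (4,)
```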

@@ -178,11 +173,10 @@ def Hs(s):


def logistic_regression_path(X, y, Cs=10, fit_intercept=True,
max_iter=100, gtol=1e-4, verbose=0,
max_iter=100, tol=1e-4, verbose=0,
solver='liblinear', callback=None,
coef=None):
"""
Compute a Logistic Regression model for a list of regularization
"""Compute a Logistic Regression model for a list of regularization
parameters.
This is an implementation that uses the result of the previous model
@@ -210,11 +204,10 @@ def logistic_regression_path(X, y, Cs=10, fit_intercept=True,
max_iter : integer
Maximum number of iterations for the solver.
gtol : float
tol : float
Stopping criterion. The iteration will stop when
``max{|g_i | i = 1, ..., n} <= gtol``
where ``g_i`` is the i-th component of the gradient. Only used
by the methods 'lbfgs' and 'trust-ncg'
``max{|g_i | i = 1, ..., n} <= tol``
where ``g_i`` is the i-th component of the gradient.
verbose: int
Print convergence message if True.
@@ -231,20 +224,18 @@ def logistic_regression_path(X, y, Cs=10, fit_intercept=True,
Returns
-------
coefs: array of shape (n_cs, n_features) or (n_cs, n_features + 1)
coefs: ndarray, shape (n_cs, n_features) or (n_cs, n_features + 1)
List of coefficients for the Logistic Regression model. If
fit_intercept is set to True then the seconds dimension will be
fit_intercept is set to True then the second dimension will be
n_features + 1, where the last item represents the intercept.
Notes
-----
You might get slightly different results with the solver trust-ncg than
with the others, since LIBLINEAR penalizes the intercept.
"""
if isinstance(Cs, numbers.Integral):
Cs = np.logspace(-4, 4, Cs)
Cs = np.sort(Cs)

y = np.sign(y - np.asarray(y).mean())
X = atleast2d_or_csc(X, dtype=np.float64)
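
Per the lines above, an integer `Cs` expands to a log-spaced grid of inverse-regularization values; for example, the grid that `Cs=5` would produce:

```python
import numpy as np
print(np.logspace(-4, 4, 5))  # [1.e-04 1.e-02 1.e+00 1.e+02 1.e+04]
```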
@@ -273,20 +264,20 @@ def logistic_regression_path(X, y, Cs=10, fit_intercept=True,
out = optimize.fmin_l_bfgs_b(
func, w0, fprime=None,
args=(X, y, 1. / C),
iprint=verbose > 0, pgtol=gtol, maxiter=max_iter)
iprint=verbose > 0, pgtol=tol, maxiter=max_iter)
except TypeError:
# old scipy doesn't have maxiter
out = optimize.fmin_l_bfgs_b(
func, w0, fprime=None,
args=(X, y, 1. / C),
iprint=verbose > 0, pgtol=gtol)
iprint=verbose > 0, pgtol=tol)
w0 = out[0]
elif solver == 'newton-cg':
grad = lambda x, *args: _logistic_loss_and_grad(x, *args)[1]
w0 = newton_cg(_logistic_loss_grad_hess, _logistic_loss, grad,
w0, args=(X, y, 1./C), maxiter=max_iter)
elif solver == 'liblinear':
lr = LogisticRegression(C=C, fit_intercept=fit_intercept, tol=gtol)
lr = LogisticRegression(C=C, fit_intercept=fit_intercept, tol=tol)
lr.fit(X, y)
if fit_intercept:
w0 = np.concatenate([lr.coef_.ravel(), lr.intercept_])
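
The solver dispatch above carries `w0` from one `C` to the next, so each solve is warm-started from the previous solution, which is what makes the full path cheaper than independent fits. A schematic of the pattern (the `fit_one` callable is a hypothetical stand-in for one lbfgs/newton-cg/liblinear solve):

```python
import numpy as np

def path_sketch(Cs, n_features, fit_one):
    w0 = np.zeros(n_features)
    coefs = []
    for C in np.sort(Cs):
        w0 = fit_one(w0, C)     # warm start from the previous C's solution
        coefs.append(w0.copy())
    return np.array(coefs)

# Dummy solve just to show the shape of the result: (len(Cs), n_features).
print(path_sketch([10., .1, 1.], 3, lambda w0, C: w0 + C).shape)  # (3, 3)
```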
@@ -304,33 +295,43 @@ def logistic_regression_path(X, y, Cs=10, fit_intercept=True,
# helper function for LogisticCV
def _log_reg_scoring_path(X, y, train, test, Cs=10, scoring=None,
fit_intercept=False,
max_iter=100, gtol=1e-4,
tol=1e-4, verbose=0, method='liblinear'):
"""
Computes scores across logistic_regression_path
max_iter=100, tol=1e-4,
verbose=0, method='liblinear'):
"""Computes scores across logistic_regression_path
Parameters
----------
X : {array-like, sparse matrix}, shape (n_samples, n_features)
Training data.
y : array-like, shape (n_samples,) or (n_samples, n_targets)
Target values
Target labels
train : list of indices
The indices of the train set
test : list of indices
The indices of the test set
Cs: list of floats, integer
Each of the values in Cs describes the inverse of
regularization strength and must be a positive float.
If not provided, then a fixed set of values for Cs are used.
scoring : callable
For a list of scoring functions that can be used, look at
:mod:`sklearn.metrics`. The default scoring option used is
accuracy_score.
fit_intercept : bool
If False, then the bias term is set to zero.
If False, then the bias term is set to zero. Else the last
term of each coef_ gives us the intercept.
max_iter : int
Maximum number of iterations for the solver.
gtol : float
Stopping criteria
tol : float
Tolerance for stopping criteria.
verbose : int
Amount of verbosity
@@ -346,7 +347,7 @@ def _log_reg_scoring_path(X, y, train, test, Cs=10, scoring=None,
fit_intercept=fit_intercept,
solver=method,
max_iter=max_iter,
gtol=gtol, verbose=verbose)
tol=tol, verbose=verbose)
scores = list()
X_test = X[test]
y_test = y[test]
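
Below this point the helper scores every coefficient vector on the held-out fold. Roughly, per value of C it does something like the following sketch, assuming {-1, +1} labels, no intercept, and accuracy as the default scorer; the helper name is hypothetical:

```python
import numpy as np

def score_path_sketch(coefs, X_test, y_test):
    # One held-out accuracy per point on the C-path.
    return [np.mean(np.sign(np.dot(X_test, w)) == y_test) for w in coefs]

X_test = np.array([[1., 0.], [-1., 0.]])
y_test = np.array([1., -1.])
print(score_path_sketch([np.array([1., 0.])], X_test, y_test))  # [1.0]
```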
@@ -423,13 +424,13 @@ class LogisticRegression(BaseLibLinear, LinearClassifierMixin,
Attributes
----------
`coef_` : array, shape = [n_classes, n_features]
`coef_` : array, shape (n_classes, n_features)
Coefficient of the features in the decision function.
`coef_` is readonly property derived from `raw_coef_` that \
follows the internal memory layout of liblinear.
`intercept_` : array, shape = [n_classes]
`intercept_` : array, shape = (n_classes,)
Intercept (a.k.a. bias) added to the decision function.
If `fit_intercept` is set to False, the intercept is set to zero.
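
As a quick illustration of the attribute shapes documented above, a binary fit gives `coef_` a single row. This uses only the public estimator API, so it should hold for this version modulo later deprecations:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.array([[0., 0.], [1., 1.], [2., 2.], [3., 3.5]])
y = np.array([0, 0, 1, 1])
clf = LogisticRegression(C=1.).fit(X, y)
print(clf.coef_.shape)       # (1, 2): one row for a binary, two-feature problem
print(clf.intercept_.shape)  # (1,)
```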
@@ -528,47 +529,49 @@ class LogisticRegressionCV(BaseEstimator, LinearClassifierMixin,
See the module :mod:`sklearn.cross_validation` module for the
list of possible cross-validation objects.
max_iter: integer, optional
Maximum number of iterations of the optimization algorithm.
gtol: float, optional
Tolerance for stopping criteria.
scoring: callable
Scoring function to use as cross-validation criteria.
Scoring function to use as cross-validation criteria. For a list of
scoring functions that can be used, look at :mod:`sklearn.metrics`.
The default scoring option used is accuracy_score.
solver: {'newton-cg', 'lbfgs', 'liblinear'}
Algorithm to use in the optimization problem.
verbose : bool or integer
Amount of verbosity.
tol: float, optional
Tolerance for stopping criteria.
max_iter: integer, optional
Maximum number of iterations of the optimization algorithm.
n_jobs : integer, optional
Number of CPU cores used during the cross-validation loop. If given
a value of -1, all cores are used.
verbose : bool or integer
Amount of verbosity.
Attributes
----------
`coef_` : array, shape = (n_classes-1, n_features)
`coef_` : array, shape (n_classes-1, n_features)
Coefficient of the features in the decision function.
`coef_` is readonly property derived from `raw_coef_` that \
follows the internal memory layout of liblinear.
`intercept_` : array, shape = (n_classes-1)
`intercept_` : array, shape (n_classes-1)
Intercept (a.k.a. bias) added to the decision function.
It is available only when parameter intercept is set to True.
`Cs_` : array
Array of C, i.e. inverse of regularization parameter values, used
for cross-validation.
`coefs_paths_` : array, shape = (n_folds, len(Cs_), n_features + 1) or
`coefs_paths_` : array, shape (n_folds, len(Cs_), n_features + 1) or
(n_folds, len(Cs_), n_features)
path of coefficients obtained during cross-validating across each
fold and then across each Cs.
`scores_` : array, shape = [n_folds, len(Cs_)]
`scores_` : array, shape (n_folds, len(Cs_))
grid of scores obtained during cross-validating each fold.
See also
@@ -578,14 +581,13 @@ class LogisticRegressionCV(BaseEstimator, LinearClassifierMixin,
"""

def __init__(self, Cs=10, fit_intercept=True, cv=None, scoring=None,
solver='newton-cg', tol=1e-4, gtol=1e-4, max_iter=100,
solver='newton-cg', tol=1e-4, max_iter=100,
n_jobs=1, verbose=False):
self.Cs = Cs
self.fit_intercept = fit_intercept
self.cv = cv
self.scoring = scoring
self.tol = tol
self.gtol = gtol
self.max_iter = max_iter
self.n_jobs = n_jobs
self.verbose = verbose
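
A usage sketch for the constructor above; the attributes inspected at the end (`C_`, `scores_`) are the ones this diff documents, while the data and settings are illustrative assumptions:

```python
import numpy as np
from sklearn.linear_model import LogisticRegressionCV

rng = np.random.RandomState(0)
X = rng.randn(100, 4)
y = (X[:, 0] + .5 * rng.randn(100) > 0).astype(int)

clf = LogisticRegressionCV(Cs=10, cv=3, solver='newton-cg').fit(X, y)
print(clf.C_)             # best C found by cross-validation
print(clf.scores_.shape)  # (n_folds, len(Cs_)), per the docstring above
```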
@@ -596,11 +598,11 @@ def fit(self, X, y):
Parameters
----------
X : {array-like, sparse matrix}, shape = [n_samples, n_features]
X : {array-like, sparse matrix}, shape (n_samples, n_features)
Training vector, where n_samples is the number of samples and
n_features is the number of features.
y : array-like, shape = [n_samples]
y : array-like, shape (n_samples,)
Target vector relative to X
Returns
@@ -627,7 +629,7 @@ def fit(self, X, y):
fit_intercept=self.fit_intercept,
method=self.solver,
max_iter=self.max_iter,
gtol=self.gtol, tol=self.tol,
tol=self.tol,
verbose=max(0, self.verbose - 1),
scoring=self.scoring,
)
@@ -643,7 +645,7 @@ def fit(self, X, y):
w, _ = logistic_regression_path(
X, y, Cs=[self.C_], fit_intercept=self.fit_intercept,
coef=coef_init, solver=self.solver, max_iter=self.max_iter,
gtol=self.gtol, verbose=max(0, self.verbose - 1))
tol=self.tol, verbose=max(0, self.verbose - 1))
w = w[0][:, np.newaxis].T
if self.fit_intercept:
self.coef_ = w[:, :-1]
