
FIX CalibratedClassifierCV with sigmoid and large confidence scores #26913

Merged
merged 20 commits into main on Aug 25, 2023

Conversation

OmarManzoor
Contributor

Reference Issues/PRs

Fixes: #26766

What does this implement/fix? Explain your changes.

  • Adds a fallback to Nelder-Mead if CalibratedClassifierCV with the sigmoid method runs into convergence problems with BFGS (see the sketch after this list).
  • Adds a non-regression test.
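A minimal sketch of the fallback idea described in the first bullet (an illustration only, not necessarily the code that was merged; objective, grad and AB0 stand for the closures and starting point defined in _sigmoid_calibration):

from scipy.optimize import minimize

def _fit_sigmoid_params(objective, grad, AB0):
    # Try the quasi-Newton solver first, as before.
    opt = minimize(objective, AB0, method="BFGS", jac=grad)
    if not opt.success:
        # Fall back to the gradient-free Nelder-Mead method, which is more
        # robust to the line-search failures triggered by extreme scores.
        opt = minimize(objective, AB0, method="Nelder-Mead")
    return opt.x[0], opt.x[1]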

Any other comments?

CC: @ogrisel Could you kindly check if this looks suitable?

@github-actions

github-actions bot commented Jul 27, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: db13ac0. Link to the linter CI: here

@OmarManzoor OmarManzoor changed the title Fix CalibratedClassifierCV with sigmoid and large confidence scores FIX CalibratedClassifierCV with sigmoid and large confidence scores Jul 27, 2023
@lorentzenchr
Member

lorentzenchr commented Jul 27, 2023

What is the root cause that lbfgs fails? I would like to understand that better before introducing a fallback solver (use scipy.optimize.minimize(..., method='L-BFGS-B', verbose=3)).
Also, could we use average instead of sum in the objective, please? It helps the solver.
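For reference, a small sketch of the average-instead-of-sum suggestion (assuming the objective and grad closures and the scores F from _sigmoid_calibration): dividing both by the number of samples leaves the minimizer unchanged but gives the solver a better-scaled problem.

n_samples = F.shape[0]

def objective_mean(AB):
    # Same optimum as `objective`, but on a per-sample scale.
    return objective(AB) / n_samples

def grad_mean(AB):
    return grad(AB) / n_samples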

@OmarManzoor
Contributor Author

What is the root cause that lbfgs fails? I would like to understand that better before introducing a fallback solver (use scipy.optimize.minimize(..., method='L-BFGS-B', verbose=3)). Also, could we use average instead of sum in the objective, please? It helps the solver.

Hi @lorentzenchr

Will using scipy.optimize.minimize(..., method='L-BFGS-B', verbose=3) produce more information than using fmin_bfgs(objective, AB0, fprime=grad, disp=True)  # Turn on outputs, which is already done in the original issue?

@lorentzenchr
Member

Will using scipy.optimize.minimize(..., method='L-BFGS-B', verbose=3) produce more information than using fmin_bfgs(objective, AB0, fprime=grad, disp=True)  # Turn on outputs, which is already done in the original issue?

Yes, I think so (maybe play with the verbosity level).

@OmarManzoor
Contributor Author

OmarManzoor commented Jul 31, 2023

@lorentzenchr I don't think we are using lbfgs here, but rather plain bfgs. Also, I don't think there is any verbose option in scipy.optimize.minimize. This is how it is done in the example at https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html

res = minimize(rosen, x0, method='BFGS', jac=rosen_der,
               options={'gtol': 1e-6, 'disp': True})

This shows that 'disp': True is the option for getting verbose output.

@lorentzenchr
Member

@OmarManzoor You're right. What I meant was

scipy.optimize.minimize(func, x0, method="L-BFGS-B", options={"iprint": 101})

This only works for lbfgs, but I would switch from bfgs to lbfgs anyway.

@OmarManzoor
Contributor Author

@lorentzenchr for simplicity I used the public function to keep it consistent with the file.

    AB0 = np.array([0.0, log((prior0 + 1.0) / (prior1 + 1.0))])
    AB_ = fmin_l_bfgs_b(objective, AB0, fprime=grad, iprint=101)

    return AB_[0], AB_[1]

Here is the output for this script

import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score

# Setting up a trivially easy classification problem
r = 0.67
N = 1000
y_train = np.array([1] * int(N * r) + [0] * (N - int(N * r)))
X_train = 1e5 * y_train.reshape((-1, 1)) + np.random.default_rng(42).normal(size=N)
model = CalibratedClassifierCV(SGDClassifier(loss='squared_hinge', random_state=42))
print('Logistic calibration: ', cross_val_score(model, X_train, y_train, scoring='roc_auc').mean())


# model = CalibratedClassifierCV(SGDClassifier(loss='squared_hinge', random_state=42), method='isotonic')
# print('Isotonic calibration: ', cross_val_score(model, X_train, y_train, scoring='roc_auc').mean())

Output file

@lorentzenchr
Member

It seems something is happening there. The ValueError in the output is a good hint.
But the snippet above is still wrong: have a look at the return value of the minimizer. I strongly recommend using minimize(…, method=…) as already written. It returns an OptimizeResult object.

@OmarManzoor
Contributor Author

scipy.optimize.minimize(func, x0, method="L-BFGS-B", options={"iprint": 101})

Okay I fixed it. It was a simple matter of just picking the first value in the result array.

AB_ = fmin_l_bfgs_b(objective, AB0, fprime=grad, iprint=101)

return AB_[0][0], AB_[0][1]

Here is the output
https://gist.github.com/OmarManzoor/fe4a992b14a7007cb540f778e33e1b99

@sktin

sktin commented Aug 1, 2023

What is the root cause that lbfgs fails? I would like to understand that better before introducing a fallback solver (use scipy.optimize.minimize(..., method='L-BFGS-B', verbose=3)). Also, could we use average instead of sum in the objective, please? It helps the solver.

@OmarManzoor @lorentzenchr Thanks for looking into this. Since the original submission of the issue, I took another look at what's going on in the optimization. The failure comes from within scipy.optimize, in the _zoom routine that implements the "zoom algorithm" ( Nocedal and Wright 1999 Algorithm 3.3) to determine the step size. In the scipy implementation, the maximum number of iterations is hardcoded to 10.

https://github.com/scipy/scipy/blob/v1.11.1/scipy/optimize/_linesearch.py#L539

Increasing maxiter would resolve the example I concocted, but there is no guarantee that it would work for more extreme cases, not to mention that we would need to ask the scipy team to change their code.

from scipy.optimize._linesearch import _cubicmin, _quadmin

def _zoom_debug(a_lo, a_hi, phi_lo, phi_hi, derphi_lo,
          phi, derphi, phi0, derphi0, c1, c2, extra_condition):
    """Zoom stage of approximate linesearch satisfying strong Wolfe conditions.

    Part of the optimization algorithm in `scalar_search_wolfe2`.

    Notes
    -----
    Implements Algorithm 3.6 (zoom) in Wright and Nocedal,
    'Numerical Optimization', 1999, pp. 61.

    """
    
    maxiter = 100 # DEBUG: Increase from 10 to 100
    i = 0
    delta1 = 0.2  # cubic interpolant check
    delta2 = 0.1  # quadratic interpolant check
    phi_rec = phi0
    a_rec = 0
    while True:
        # interpolate to find a trial step length between a_lo and
        # a_hi Need to choose interpolation here. Use cubic
        # interpolation and then if the result is within delta *
        # dalpha or outside of the interval bounded by a_lo or a_hi
        # then use quadratic interpolation, if the result is still too
        # close, then use bisection

        dalpha = a_hi - a_lo
        if dalpha < 0:
            a, b = a_hi, a_lo
        else:
            a, b = a_lo, a_hi

        # minimizer of cubic interpolant
        # (uses phi_lo, derphi_lo, phi_hi, and the most recent value of phi)
        #
        # if the result is too close to the end points (or out of the
        # interval), then use quadratic interpolation with phi_lo,
        # derphi_lo and phi_hi if the result is still too close to the
        # end points (or out of the interval) then use bisection

        if (i > 0):
            cchk = delta1 * dalpha
            a_j = _cubicmin(a_lo, phi_lo, derphi_lo, a_hi, phi_hi,
                            a_rec, phi_rec)
        if (i == 0) or (a_j is None) or (a_j > b - cchk) or (a_j < a + cchk):
            qchk = delta2 * dalpha
            a_j = _quadmin(a_lo, phi_lo, derphi_lo, a_hi, phi_hi)
            if (a_j is None) or (a_j > b-qchk) or (a_j < a+qchk):
                a_j = a_lo + 0.5*dalpha

        # Check new value of a_j

        phi_aj = phi(a_j)
        if (phi_aj > phi0 + c1*a_j*derphi0) or (phi_aj >= phi_lo):
            phi_rec = phi_hi
            a_rec = a_hi
            a_hi = a_j
            phi_hi = phi_aj
        else:
            derphi_aj = derphi(a_j)
            if abs(derphi_aj) <= -c2*derphi0 and extra_condition(a_j, phi_aj):
                a_star = a_j
                val_star = phi_aj
                valprime_star = derphi_aj
                break
            if derphi_aj*(a_hi - a_lo) >= 0:
                phi_rec = phi_hi
                a_rec = a_hi
                a_hi = a_lo
                phi_hi = phi_lo
            else:
                phi_rec = phi_lo
                a_rec = a_lo
            a_lo = a_j
            phi_lo = phi_aj
            derphi_lo = derphi_aj
        i += 1
        if (i > maxiter):
            # Failed to find a conforming step size
            print('Failed to find a conforming step size') # DEBUG: maxiter exceeded
            a_star = None
            val_star = None
            valprime_star = None
            break
    return a_star, val_star, valprime_star

import scipy

# Monkeypatch the line-search zoom routine and re-run the reproducer above
# (X_train and y_train come from the earlier script).
scipy.optimize._linesearch._zoom = _zoom_debug
model = CalibratedClassifierCV(SGDClassifier(loss='squared_hinge', random_state=42))
print('Logistic calibration: ', cross_val_score(model, X_train, y_train, scoring='roc_auc').mean())

Expected output:

Logistic calibration:  1.0

@lorentzenchr
Member

Please have a look at how the inner maxiter is set for lbfgs in

opt_res = scipy.optimize.minimize(

Maybe that solves the issue.

@OmarManzoor
Contributor Author

Please have a look at how the inner maxiter is set for lbfgs in

opt_res = scipy.optimize.minimize(

Maybe that solves the issue.

Is there anything to do in this PR then? If not maybe we can close it.

@lorentzenchr
Member

Is there anything to do in this PR then? If not maybe we can close it.

Yes. The added test is valuable. And I would like to see the current solver changed to (copy & paste) the GLM way of doing it, hoping that the tests pass.

@OmarManzoor
Contributor Author

@lorentzenchr I tried using this. All the tests pass aside from the one that we added in this PR. Even with this, we get a cross-val score of 0.5.

opt_res = minimize(
    objective,
    AB0,
    method="L-BFGS-B",
    jac=grad,
    options={
        "maxiter": 100,
        "maxls": 100,
        "gtol": 1e-6,
        # The constant 64 was found empirically to pass the test suite.
        # The point is that ftol is very small, but a bit larger than
        # machine precision for float64, which is the dtype used by lbfgs.
        "ftol": 64 * np.finfo(float).eps,
    },
)
AB_ = opt_res.x
return AB_[0], AB_[1]

@lorentzenchr
Member

@OmarManzoor Could you try replacing the objective and grad by the ones from sklearn._loss? Something like:

def loss_grad(AB):
    l, g = loss.loss_gradient(y_true=T, raw_prediction=(AB[0] * F + AB[1]))
    grad = [F @ g, g.sum()]
    return l.sum(), grad

opt_res = minimize(
    loss_grad,
    AB0,
    method="L-BFGS-B",
    jac=True,
    options={"iprint": 1},  # possibly also "gtol": tol, "maxiter": max_iter
)
AB_ = opt_res.x

It might help to compare with LinearModelLoss.

@OmarManzoor
Contributor Author

@lorentzenchr

Code:

    loss = HalfBinomialLoss()

    def loss_grad(AB):
        l, g = loss.loss_gradient(y_true=T, raw_prediction=-(AB[0] * F + AB[1]))
        grad = np.array([F @ g, g.sum()])
        return l.sum(), grad

    opt_res = minimize(
        loss_grad,
        AB0,
        method="L-BFGS-B",
        jac=True,
        options={
            "iprint": 1,
            "maxiter": 15000,
            "maxls": 100,
            "gtol": 1e-6,
            "ftol": 64 * np.finfo(float).eps,
        },
    )
    AB_ = opt_res.x
    return AB_[0], AB_[1]

Output of test output.txt

@lorentzenchr
Member

@OmarManzoor Thanks for testing.
BTW, you can use a <details> ... </details> section in a GitHub comment to make larger output available without disturbing the reading flow.

The idea of using the _loss module was that its implementations are numerically stable.

What happens if we fit a LogisticRegression on T, with F as the single feature and fit_intercept=True?

@ogrisel pointed out that we might also supply the Hessian, as it is very easy to calculate. This corresponds to LogisticRegression(solver="newton-cholesky", penalty=None).fit(T, F[:, None]).
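For illustration, the Hessian of the Platt objective is indeed cheap to compute; here is a sketch (editorial, using the np, F and expit names already available inside _sigmoid_calibration, with sample weights omitted):

def hessian(AB):
    P = expit(-(AB[0] * F + AB[1]))
    W = P * (1.0 - P)  # per-sample curvature of the log loss
    dAA = np.dot(W, F * F)
    dAB = np.dot(W, F)
    dBB = W.sum()
    return np.array([[dAA, dAB], [dAB, dBB]])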

Another idea, already mentioned somewhere, is to re-scale F by a constant, such that max(abs(F)) < some reasonable number, say, 10. Indeed, logistic regression without penalty is invariant to multiplying each feature of X by a constant per feature.

After that I run out of ideas.
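As a quick numerical check of that invariance claim (an illustrative sketch, not code from this PR; it uses moderate scales so that both fits converge easily):

import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
F = rng.normal(scale=3.0, size=1000)  # a single "confidence score" feature
y = (F + rng.normal(size=1000) > 0).astype(int)

M = 10.0  # rescaling constant
lr_raw = LogisticRegression(penalty=None, max_iter=1000).fit(F[:, None], y)
lr_scaled = LogisticRegression(penalty=None, max_iter=1000).fit((F / M)[:, None], y)

# The slope on the rescaled feature is M times the original slope, and the
# predicted probabilities agree up to solver tolerance.
print(lr_scaled.coef_[0, 0] / M, lr_raw.coef_[0, 0])
proba_raw = lr_raw.predict_proba(F[:, None])
proba_scaled = lr_scaled.predict_proba((F / M)[:, None])
print(np.abs(proba_raw - proba_scaled).max())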

@OmarManzoor
Contributor Author

OmarManzoor commented Aug 9, 2023

@OmarManzoor Thanks for testing. BTW, you can use a <details> ... </details> section in a GitHub comment to make larger output available without disturbing the reading flow.

The idea of using the _loss module was that its implementations are numerically stable.

What happens if we fit a LogisticRegression on T, with F as the single feature and fit_intercept=True?

@ogrisel pointed out that we might also supply the Hessian, as it is very easy to calculate. This corresponds to LogisticRegression(solver="newton-cholesky", penalty=None).fit(T, F[:, None]).

Another idea, already mentioned somewhere, is to re-scale F by a constant, such that max(abs(F)) < some reasonable number, say, 10. Indeed, logistic regression without penalty is invariant to multiplying each feature of X by a constant per feature.

After that I run out of ideas.

I tried this

lg = LogisticRegression(solver="newton-cholesky", penalty=None).fit(T[:, None], F[:, None])

This is the error we get

ValueError: Unknown label type: continuous. Maybe you are trying to fit a classifier, which expects discrete classes on a regression target with continuous values.

This is a preview of the T array

[0.99090909 0.99090909 0.99090909 0.99090909 0.99090909 0.99090909 0.99090909 0.99090909 0.99090909 0.99090909 ...]

This is a preview of the F array

[3.65029719e+13 3.65029719e+13 3.65029719e+13 3.65029719e+13 3.65029719e+13 3.65029719e+13 3.65029719e+13 3.65029719e+13 3.65029719e+13 3.65029719e+13 ...]

@sktin

sktin commented Aug 9, 2023

@lorentzenchr @OmarManzoor, @ogrisel LogisticRegression would not work as is because it checks for binary target.

The way I see it, if we are to keep the quasi-Newton approach (instead of a derivative-free approach), the only robust fix would be to scale the confidence scores if they are too large. We can scale only if the scores are really large (e.g., larger than 1e7 numerically) so that most working existing code would continue to get identical outputs from CalibratedClassifierCV.

For example, the following code should work (feel free to adjust according to sklearn coding standard).

from math import log

import numpy as np
from scipy.optimize import fmin_bfgs
from scipy.special import expit, xlogy

from sklearn.utils import column_or_1d


def _sigmoid_calibration(predictions, y, sample_weight=None):
    """Probability Calibration with sigmoid method (Platt 2000)

    Parameters
    ----------
    predictions : ndarray of shape (n_samples,)
        The decision function or predict proba for the samples.
    y : ndarray of shape (n_samples,)
        The targets.
    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights. If None, then samples are equally weighted.
    Returns
    -------
    a : float
        The slope.
    b : float
        The intercept.
    References
    ----------
    Platt, "Probabilistic Outputs for Support Vector Machines"
    """
    predictions = column_or_1d(predictions)
    y = column_or_1d(y)

    F = predictions  # F follows Platt's notations
    
    ### New code begins
    M0 = 1e7 # or some suitably chosen threshold 
    M = np.amax(np.abs(F))
    if M > M0:
        F /= M
    ### New code ends

    # Bayesian priors (see Platt end of section 2.2):
    # It corresponds to the number of samples, taking into account the
    # `sample_weight`.
    mask_negative_samples = y <= 0
    if sample_weight is not None:
        prior0 = (sample_weight[mask_negative_samples]).sum()
        prior1 = (sample_weight[~mask_negative_samples]).sum()
    else:
        prior0 = float(np.sum(mask_negative_samples))
        prior1 = y.shape[0] - prior0
    T = np.zeros_like(y, dtype=np.float64)
    T[y > 0] = (prior1 + 1.0) / (prior1 + 2.0)
    T[y <= 0] = 1.0 / (prior0 + 2.0)
    T1 = 1.0 - T

    def objective(AB):
        # From Platt (beginning of Section 2.2)
        P = expit(-(AB[0] * F + AB[1]))
        loss = -(xlogy(T, P) + xlogy(T1, 1.0 - P))
        if sample_weight is not None:
            return (sample_weight * loss).sum()
        else:
            return loss.sum()

    def grad(AB):
        # gradient of the objective function
        P = expit(-(AB[0] * F + AB[1]))
        TEP_minus_T1P = T - P
        if sample_weight is not None:
            TEP_minus_T1P *= sample_weight
        dA = np.dot(TEP_minus_T1P, F)
        dB = np.sum(TEP_minus_T1P)
        return np.array([dA, dB])

    AB0 = np.array([0.0, log((prior0 + 1.0) / (prior1 + 1.0))])
    AB_ = fmin_bfgs(objective, AB0, fprime=grad, disp=False)
    ### Changed code begins
    return AB_[0] / (M if M > M0 else 1), AB_[1]
    ### Changed code ends

@OmarManzoor
Contributor Author

OmarManzoor commented Aug 9, 2023

@lorentzenchr @OmarManzoor, @ogrisel LogisticRegression would not work as is because it checks for binary target.

The way I see it, if we are to keep the quasi-Newton approach (instead of a derivative-free approach), the only robust fix would be to scale the confidence scores if they are too large. We can scale only if the scores are really large (e.g., larger than 1e7 numerically) so that most working existing code would continue to get identical outputs from CalibratedClassifierCV.

For example, the following code should work (feel free to adjust according to sklearn coding standard).
...

Thank you for sharing. @lorentzenchr What do you suggest?

@lorentzenchr
Member

Thank you for sharing. @lorentzenchr What do you suggest?

Let's try the scaling approach and fix the threshold later. Additionally, I'd like to see the loss replaced by the existing _loss module, as in the above snippets.

@OmarManzoor
Contributor Author

@lorentzenchr It seems to work now. The test passes when using re-scaling.
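For the record, a sketch of what such a non-regression test can look like (illustrative only; the test name, data shape and assertion threshold here are editorial, not necessarily those merged in the PR):

import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score


def test_sigmoid_calibration_with_large_confidence_scores():
    # Trivially separable problem whose decision values are huge, as in the
    # original issue; with the re-scaling fix the calibrated probabilities
    # keep the ranking, so the ROC AUC stays close to 1 instead of 0.5.
    N, r = 1000, 0.67
    y = np.array([1] * int(N * r) + [0] * (N - int(N * r)))
    X = 1e5 * y.reshape(-1, 1) + np.random.default_rng(42).normal(size=(N, 1))
    model = CalibratedClassifierCV(SGDClassifier(loss="squared_hinge", random_state=42))
    score = cross_val_score(model, X, y, scoring="roc_auc").mean()
    assert score > 0.9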

@lorentzenchr
Member

LogisticRegression would not work as is because it checks for binary target.

BTW, continuous targets are supported by the private _GeneralizedLinearRegressor:

from sklearn._loss import HalfBinomialLoss
from sklearn.linear_model._glm import _GeneralizedLinearRegressor


class BinomialRegressor(_GeneralizedLinearRegressor):
    def _get_loss(self):
        return HalfBinomialLoss()

See https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/linear_model/_glm/tests/test_glm.py.

@OmarManzoor
Contributor Author

@lorentzenchr Could you kindly have a look and see if this looks okay? Then we can integrate the _loss module methods to replace the current obj and grad functions.

@lorentzenchr
Member

@OmarManzoor Can you mark the resolved comments as resolved, and maybe also add a 👍 in the thread?

@ogrisel left a comment (Member)


Thanks for the PR, I like the minimally invasive scaling solution. I think the tests could be improved as suggested below:

@ogrisel
Member

ogrisel commented Aug 24, 2023

Has anyone tried to see if #26913 (comment) has a similar numerical stability problem in the end?

If so, we need to make sure that the error message is informative enough, e.g. to tell the user to scale their data or increase the penalization. But this should be done in a dedicated PR if needed (so as not to delay the merge of this fix).

@lorentzenchr
Member

Has anyone tried to see if #26913 (comment) has a similar numerical stability problem in the end?

import numpy as np
from sklearn._loss import HalfBinomialLoss
from sklearn.linear_model._glm import _GeneralizedLinearRegressor
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score


class BinomialRegressor(_GeneralizedLinearRegressor):
    def _get_loss(self):
        return HalfBinomialLoss()


r = 0.67
N = 1000
y_train = np.array([1] * int(N * r) + [0] * (N - int(N * r)))
X_train = 1e5 * y_train.reshape((-1, 1)) + np.random.default_rng(42).normal(size=N)
sgd = SGDClassifier(loss='squared_hinge', random_state=42).fit(X_train, y_train)
F = sgd.decision_function(X_train)
print(f"{F.min()=}, {F.max()=}")  # (-7221324.4004922705, 25003110945558.523)

glm = BinomialRegressor(alpha=0)
glm.fit(F[:, None], y_train)
print(f"{glm.intercept_=} {glm.coef_}")  # 0.7081850579244856, 3.47115448e-06

No errors, no warnings.

sklearn/calibration.py:
scale_constant = max_prediction
# We rescale the features in a copy: inplace rescaling could confuse
# the caller and make the code harder to reason about.
F = F / scale_constant
@OmarManzoor
Contributor Author


Nice 👍 I did not realize that the way I had done it actually modified predictions too. Thanks for catching this! I guess this was the main reason why I was getting wrong results.
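A tiny illustration of the pitfall mentioned here (not code from the PR): an augmented assignment rescales the caller's array in place, while a plain division creates a new array and leaves the input untouched.

import numpy as np

predictions = np.array([1e9, 2e9, 3e9])
F = predictions     # same underlying buffer
F /= 1e9            # in-place: also changes `predictions`
print(predictions)  # [1. 2. 3.] -- the caller's data was modified

predictions = np.array([1e9, 2e9, 3e9])
F = predictions
F = F / 1e9         # new array: `predictions` is left untouched
print(predictions)  # [1.e+09 2.e+09 3.e+09]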

@ogrisel left a comment (Member)


LGTM (assuming green CI)!

@lorentzenchr left a comment (Member)


LGTM
@OmarManzoor Are you interested in a follow up PR replacing obj and grad by our _loss?

@OmarManzoor
Contributor Author

LGTM @OmarManzoor Are you interested in a follow up PR replacing obj and grad by our _loss?

@lorentzenchr Yes sure.

@lorentzenchr lorentzenchr merged commit 8e9cd7d into scikit-learn:main Aug 25, 2023
27 checks passed
@OmarManzoor OmarManzoor deleted the calibrated_classifier_cv branch August 25, 2023 11:50
glemaitre pushed a commit to glemaitre/scikit-learn that referenced this pull request Aug 29, 2023
glemaitre pushed a commit to glemaitre/scikit-learn that referenced this pull request Sep 18, 2023
jeremiedbb pushed a commit that referenced this pull request Sep 20, 2023
…26913)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
REDVM pushed a commit to REDVM/scikit-learn that referenced this pull request Nov 16, 2023
Development

Successfully merging this pull request may close these issues.

CalibratedClassifierCV fails silently with large confidence scores
4 participants