
FIX more precise log loss gradient and hessian #28048

Merged
merged 8 commits from logloss_gradient into scikit-learn:main on Jan 9, 2024

Conversation

lorentzenchr
Member

Reference Issues/PRs

Fixes #28046.

What does this implement/fix? Explain your changes.

This PR improves the gradient and hessian of HalfBinomialLoss, preventing overflow of exp(large number) that would otherwise result in inf/nan return values.

The implemented change is carefully designed and benchmarked to incur minimal to no runtime/performance penalty.
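The gist of the change, as a minimal NumPy sketch only (the actual implementation lives in the Cython loss templates; the helper names here are made up, and the -37 cutoff mirrors the benchmark below rather than the final code):

import numpy as np

def sketch_gradient(y_true, raw):
    # y_true and raw are float64 ndarrays of the same shape.
    # gradient of the half binomial loss: expit(raw) - y_true
    out = np.empty_like(raw)
    small = raw <= -37  # expit(raw) ~= exp(raw) within double precision here
    out[small] = np.exp(raw[small]) - y_true[small]
    exp_tmp = np.exp(-raw[~small])  # cannot overflow because raw > -37
    out[~small] = ((1 - y_true[~small]) - y_true[~small] * exp_tmp) / (1 + exp_tmp)
    return out

def sketch_hessian(raw):
    # hessian: expit(raw) * (1 - expit(raw)), independent of y_true
    out = np.empty_like(raw)
    small = raw <= -37
    out[small] = np.exp(raw[small])  # p * (1 - p) ~= p for tiny p
    exp_tmp = np.exp(-raw[~small])
    out[~small] = exp_tmp / (1 + exp_tmp) ** 2
    return out

Only a single exp per sample is evaluated in either branch, which keeps the runtime close to version 1 in the benchmark below.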

Any other comments?

@github-actions github-actions bot added the cython label Jan 2, 2024

github-actions bot commented Jan 2, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 0293fad.

@lorentzenchr
Member Author

lorentzenchr commented Jan 2, 2024

Benchmark for gradient on arrays of length 100_000:

| Gradient version | timing, large values | timing, values in [-10, 10] |
|---|---|---|
| sklearn v1.3 (version 1) | 723 µs ± 54.9 µs | 691 µs ± 71.8 µs |
| stable, see [1] (version 2) | 1.02 ms ± 116 µs | 673 µs ± 54.3 µs |
| this PR (version 3) | 740 µs ± 81.5 µs | 658 µs ± 17.2 µs |

[1] https://fa.bianp.net/blog/2019/evaluate_logistic/

%load_ext cython

import numpy as np

# 1. numpy ufunc version (version 1, the current sklearn implementation)
#    Problem: returns NaN for large negative values of raw.
def np_gradient_stable1(y_true, raw):
    exp_tmp = np.exp(-raw)
    return ((1 - y_true) - y_true * exp_tmp) / (1 + exp_tmp)


# 2. numpy ufunc version, the more stable version 2
# See https://fa.bianp.net/blog/2019/evaluate_logistic/
def np_gradient_stable2(y_true, raw):
    """Compute expit(x) - b component-wise."""
    out = np.empty_like(raw)
    idx = raw < 0
    exp_r = np.exp(raw[idx])
    y_idx = y_true[idx]
    out[idx] = ((1 - y_idx) * exp_r - y_idx) / (1 + exp_r)
    exp_nr = np.exp(-raw[~idx])
    y_nidx = y_true[~idx]
    out[~idx] = ((1 - y_nidx) - y_nidx * exp_nr) / (1 + exp_nr)
    return out


%%cython -3
# distutils: extra_compile_args = -O3
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False

import cython
import numpy as np

from libc.math cimport exp
cimport numpy as np


np.import_array()

    
cdef inline double c_gradient1(double y_true, double raw) nogil:
    cdef double exp_tmp = exp(-raw)
    return ((1 - y_true) - y_true * exp_tmp) / (1 + exp_tmp)


cdef inline double c_gradient2(double y_true, double raw) nogil:
    cdef double exp_tmp
    if raw < 0:
        exp_tmp = exp(raw)
        return ((1 - y_true) * exp_tmp - y_true) / (1 + exp_tmp)
    else:
        exp_tmp = exp(-raw)
        return ((1 - y_true) - y_true * exp_tmp) / (1 + exp_tmp)


cdef inline double c_gradient3(double y_true, double raw) nogil:
    cdef double exp_tmp
    # Help branch prediction.
    # Note that scipy.special.logit(np.finfo(float).eps) ~ -36.04365
    if raw > -37:
        exp_tmp = exp(-raw)
        return ((1 - y_true) - y_true * exp_tmp) / (1 + exp_tmp)
    else:
        # expit(raw) ≈ exp(raw) within double precision for raw <= -37
        return exp(raw) - y_true


### 2. Cython functions: loop over the ndarray by calling the C level kernels
def cy_gradient_stable1(double[::1] y_true, double[::1] raw):
    cdef:
        int n_samples
        int i
        double[::1] out = np.empty_like(y_true)

    n_samples = y_true.shape[0]
    for i in range(n_samples):
        out[i] = c_gradient1(y_true[i], raw[i])

    return out


def cy_gradient_stable2(double[::1] y_true, double[::1] raw):
    cdef:
        int n_samples
        int i
        double[::1] out = np.empty_like(y_true)

    n_samples = y_true.shape[0]
    for i in range(n_samples):
        out[i] = c_gradient2(y_true[i], raw[i])

    return out


def cy_gradient_stable3(double[::1] y_true, double[::1] raw):
    cdef:
        int n_samples
        int i
        double[::1] out = np.empty_like(y_true)

    n_samples = y_true.shape[0]
    for i in range(n_samples):
        out[i] = c_gradient3(y_true[i], raw[i])

    return out


rng = np.random.default_rng(0)
y_true = rng.binomial(1, 0.5, size=100_000).astype(np.float64)
raw = 20 * rng.standard_normal(100_000, dtype=np.float64)  # make sure some values are <= -37 and > 33
print(f"min and max raw = {np.min(raw)}, {np.max(raw)}")
# min and max raw = -91.87980390791544, 85.34683410460212

np.allclose(np_gradient_stable1(y_true, raw), np_gradient_stable2(y_true, raw)), \
np.allclose(np_gradient_stable1(y_true, raw), cy_gradient_stable1(y_true, raw)), \
np.allclose(np_gradient_stable1(y_true, raw), cy_gradient_stable2(y_true, raw)), \
np.allclose(np_gradient_stable1(y_true, raw), cy_gradient_stable3(y_true, raw))
# (True, True, True, True)
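For context (this check is my addition, not part of the original benchmark): the failure mode from #28046 shows up once raw_prediction drops below roughly -710, where exp(-raw) overflows to inf and version 1 returns NaN, while version 3 stays finite.

y_ext = np.ones(1)
raw_ext = np.full(1, -1000.0)
print(np_gradient_stable1(y_ext, raw_ext))              # [nan], with an overflow RuntimeWarning
print(np.asarray(cy_gradient_stable3(y_ext, raw_ext)))  # [-1.], since exp(-1000) - 1 == -1.0 in float64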

%timeit -r20 np_gradient_stable1(y_true, raw)
# 666 µs ± 20.3 µs per loop (mean ± std. dev. of 20 runs, 1,000 loops each)

%timeit -r20 np_gradient_stable2(y_true, raw)
# 3.88 ms ± 223 µs per loop (mean ± std. dev. of 20 runs, 100 loops each)

%timeit -r20 cy_gradient_stable1(y_true, raw)
# 723 µs ± 54.9 µs per loop (mean ± std. dev. of 20 runs, 1,000 loops each)

%timeit -r20 cy_gradient_stable2(y_true, raw)
# 1.02 ms ± 116 µs per loop (mean ± std. dev. of 20 runs, 1,000 loops each)

%timeit -r20 cy_gradient_stable3(y_true, raw)
# 740 µs ± 81.5 µs per loop (mean ± std. dev. of 20 runs, 1,000 loops each)


# Same for smaller values of raw
raw2 = np.linspace(-10, 10, 100_000)

%timeit -r20 cy_gradient_stable1(y_true, raw2)
# 691 µs ± 71.8 µs per loop (mean ± std. dev. of 20 runs, 1,000 loops each)

%timeit -r20 cy_gradient_stable2(y_true, raw2)
# 673 µs ± 54.3 µs per loop (mean ± std. dev. of 20 runs, 1,000 loops each)

%timeit -r20 cy_gradient_stable3(y_true, raw2)
# 658 µs ± 17.2 µs per loop (mean ± std. dev. of 20 runs, 1,000 loops each)

Precision

[figure: scatter plots of relative precision vs raw_prediction for versions 1–3]

"version 1" is the current implementation, "version 3" is this PR.
Observe the red outliers in the top right corner of version 1!

import mpmath as mp


# Stolen from scipy
def mpf2float(x):
    """
    Convert an mpf to the nearest floating point number. Just using
    float directly doesn't work because of results like this:

    with mp.workdps(50):
        float(mpf("0.99999999999999999")) = 0.9999999999999999

    """
    return float(mp.nstr(x, 17, min_fixed=0, max_fixed=0))


def mp_gradient(y_true, raw, dps=50):
    y_true, raw = np.asarray(y_true), np.asarray(raw)
    out = np.empty_like(y_true)
    with mp.workdps(dps):
        for i in range(len(y_true)):
            y, r = mp.mpf(float(y_true[i])), mp.mpf(float(raw[i]))
            res = mp.mpf(1) / (mp.mpf(1) + mp.exp(-r)) - y
            out[i] = mpf2float(res)
    return out


def rel_accuracy(test, reference):
    # Error relative to max(1, reference); NaNs are mapped to 10 so that they
    # show up as (red) outliers in the plots below.
    result = np.abs((test - reference) / np.maximum(1, reference))
    result[np.isnan(result)] = 10
    return result


import matplotlib.pyplot as plt


def scatter_with_outlier(ax, x, y, threshold=1e-5, **kwargs):
    ax.scatter(x, y, **kwargs)
    mask = y >= threshold
    ax.scatter(x[mask], y[mask], color="red", **kwargs)

raw3 = np.sinh(np.linspace(np.arcsinh(-1000), np.arcsinh(1000), 20_001))  # symlog-like spacing over [-1000, 1000]
exact = mp_gradient(np.ones_like(raw3), raw3)
result1 = np.asarray(cy_gradient_stable1(np.ones_like(raw3), raw3))
result2 = np.asarray(cy_gradient_stable2(np.ones_like(raw3), raw3))
result3 = np.asarray(cy_gradient_stable3(np.ones_like(raw3), raw3))


fig, axes = plt.subplots(ncols=3, figsize=(12, 4), sharey=True)
scatter_with_outlier(axes[0], raw3, rel_accuracy(result1, exact), s=1, label="version 1")
scatter_with_outlier(axes[1], raw3, rel_accuracy(result2, exact), s=1, label="version 2")
scatter_with_outlier(axes[2], raw3, rel_accuracy(result3, exact), s=1, label="version 3")
axes[0].set_ylabel('relative precision')
for i in range(len(axes)):
    axes[i].set_xlabel('raw_prediction')
    axes[i].set_xscale('symlog')
    axes[i].set_yscale('symlog', linthresh=1e-30)
    axes[i].set_title(f"version {i+1}")
fig.suptitle("Precision with y_true=1")

Details of the mpmath reference implementations:

def mp_logloss(y_true, raw):
    with mp.workdps(100):
        y_true, raw = mp.mpf(float(y_true)), mp.mpf(float(raw))
        out = mp.log1p(mp.exp(raw)) - y_true * raw
    return mpf2float(out)


def mp_gradient(y_true, raw):
    with mp.workdps(100):
        y_true, raw = mp.mpf(float(y_true)), mp.mpf(float(raw))
        out = mp.mpf(1) / (mp.mpf(1) + mp.exp(-raw)) - y_true
    return mpf2float(out)

def mp_hessian(y_true, raw):
    with mp.workdps(100):
        y_true, raw = mp.mpf(float(y_true)), mp.mpf(float(raw))
        p = mp.mpf(1) / (mp.mpf(1) + mp.exp(-raw))
        out = p * (mp.mpf(1) - p)
    return mpf2float(out)


y, raw = 0.0, 37.
mp_logloss(y, raw), mp_gradient(y, raw), mp_hessian(y, raw)
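As a usage illustration (my addition, not from the PR), the reference can be compared against the double-precision implementation above, e.g. at the raw_prediction = 37 example:

# Compare the double-precision gradient (version 3 above) with the mpmath
# reference at an extreme raw_prediction; illustrative only.
y_arr = np.zeros(1)
raw_arr = np.full(1, 37.0)
g_double = np.asarray(cy_gradient_stable3(y_arr, raw_arr))[0]
g_exact = mp_gradient(0.0, 37.0)  # scalar reference defined just above
print(g_double, g_exact, abs(g_double - g_exact))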

@lorentzenchr
Member Author

@jjerphan gentle ping for a review. The actual change is quite small; comments and tests make up the majority of the diff.

Member

@jjerphan jjerphan left a comment

LGTM. Thank you, @lorentzenchr.

Side comment: I wonder if we could optionally use mpmath in tests; see one of the comments in context.

Review threads (resolved): sklearn/_loss/tests/test_loss.py, sklearn/_loss/_loss.pyx.tp
@lorentzenchr
Member Author

lorentzenchr commented Jan 8, 2024

@glemaitre @lesteve While working on #28063, I played around a bit with https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_regularization.html#sphx-glr-auto-examples-ensemble-plot-gradient-boosting-regularization-py and quickly ran into errors.
Those errors are fixed by this PR, so I would advocate including it in v1.4.1 as a bugfix.

I'm sorry to have introduced this bug with the new loss functions for the old Gradient Boosting. There are, however, no tests for it in the old GB tests!

@lorentzenchr lorentzenchr changed the title ENH more precise log loss gradient and hessian FIX more precise log loss gradient and hessian Jan 8, 2024
@lorentzenchr lorentzenchr added this to the 1.4 milestone Jan 9, 2024
@glemaitre
Member

We can still include this in 1.4.0; it is a regression that we caught before releasing (which is nice). To be safe, it would still be good to have an entry in the changelog, since the bug could occur in HistGradientBoosting prior to 1.4?

Member

@ogrisel ogrisel left a comment

I played around a bit with https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_regularization.html#sphx-glr-auto-examples-ensemble-plot-gradient-boosting-regularization-py and quickly ran into errors.

Out of curiosity, which hyperparameter combinations produce raw_prediction values large enough to cause numerical stability problems during training?

Would it make sense to include such a case as a public-API-level non-regression test?

Otherwise I am fine with the private-level loss module non-regression tests only. They are quite extensive and look good to me.

Please move the changelog entry to 1.4.0 prior to merging this PR, though.

:class:`ensemble.HistGradientBoostingClassifier` and
:class:`linear_model.LogisticRegression`.
:pr:`28048` by :user:`Christian Lorentzen <lorentzenchr>`.

Member

I agree this fix should be included in 1.4.0. Please move the changelog entry accordingly.

@lorentzenchr
Member Author

lorentzenchr commented Jan 9, 2024

Out of curiosity, which hyperparameter combinations produce raw_prediction values large enough to cause numerical stability problems during training?

I changed "min_samples_split": 5 to "min_samples_leaf": 2.

Would it make sense to include such a case as a public-API-level non-regression test?

I think the tests of the loss functions cover that very thoroughly. To be sure, I added 0293fad.

Please move the changelog entry to 1.4.0 prior to merging this PR, though.

Done. But have a look at the placement; it affects different classes from very different modules.

@jjerphan jjerphan merged commit 5ad8e45 into scikit-learn:main Jan 9, 2024
27 checks passed
@lorentzenchr lorentzenchr deleted the logloss_gradient branch January 9, 2024 19:30
@ogrisel
Member

ogrisel commented Jan 10, 2024

Thanks for the fix. @glemaitre @jeremiedbb this will require a backport to 1.4.X.

@glemaitre glemaitre added the To backport PR merged in master that need a backport to a release branch defined based on the milestone. label Jan 10, 2024
jeremiedbb pushed a commit to jeremiedbb/scikit-learn that referenced this pull request Jan 17, 2024
Labels
Bug, cython, Numerical Stability, Performance, To backport

Successfully merging this pull request may close these issues:
Log Loss gradient and hessian returns NaN for large negative values