ENH Poisson loss for HistGradientBoostingRegressor #16692

Merged (23 commits, Apr 23, 2020)
5 changes: 3 additions & 2 deletions doc/modules/ensemble.rst
@@ -952,8 +952,9 @@ controls the number of iterations of the boosting process::
>>> clf.score(X_test, y_test)
0.8965

Available losses for regression are 'least_squares' and
'least_absolute_deviation', which is less sensitive to outliers. For
Available losses for regression are 'least_squares',
'least_absolute_deviation', which is less sensitive to outliers, and
'poisson', which is well suited to model counts and frequencies. For
classification, 'binary_crossentropy' is used for binary classification and
'categorical_crossentropy' is used for multiclass classification. By default
the loss is 'auto' and will select the appropriate loss depending on
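
For count-valued targets, the new option can be tried directly. A minimal sketch (the toy data below is invented for illustration; only ``loss='poisson'`` itself comes from this pull request)::

    >>> import numpy as np
    >>> from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    >>> from sklearn.ensemble import HistGradientBoostingRegressor
    >>> rng = np.random.RandomState(0)
    >>> X = rng.uniform(size=(100, 3))
    >>> y = rng.poisson(lam=np.exp(X[:, 0]))  # non-negative counts
    >>> reg = HistGradientBoostingRegressor(loss='poisson').fit(X, y)
    >>> bool(np.all(reg.predict(X) > 0))  # log-link => positive predictions
    True
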
7 changes: 6 additions & 1 deletion doc/whats_new/v0.23.rst
@@ -202,6 +202,11 @@ Changelog
to obtain the input to the meta estimator.
:pr:`16539` by :user:`Bill DeRose <wderose>`.

- |Feature| Added additional option `loss="poisson"` to
:class:`ensemble.HistGradientBoostingRegressor`, which adds Poisson deviance
with log-link useful for modeling count data.
:pr:`16692` by :user:`Christian Lorentzen <lorentzenchr>`.

:mod:`sklearn.feature_extraction`
.................................

@@ -296,7 +301,7 @@ Changelog
- |API| Changed the formatting of values in
:meth:`metrics.ConfusionMatrixDisplay.plot` and
:func:`metrics.plot_confusion_matrix` to pick the shorter format (either '2g'
or 'd'). :pr:`16159` by :user:`Rick Mackenbach <Rick-Mackenbach>` and
or 'd'). :pr:`16159` by :user:`Rick Mackenbach <Rick-Mackenbach>` and
`Thomas Fan`_.

- |Enhancement| :func:`metrics.pairwise.pairwise_distances_chunked` now allows
33 changes: 31 additions & 2 deletions sklearn/ensemble/_hist_gradient_boosting/_loss.pyx
@@ -10,7 +10,7 @@ from cython.parallel import prange
import numpy as np
cimport numpy as np

from libc.math cimport exp
from libc.math cimport exp, log

from .common cimport Y_DTYPE_C
from .common cimport G_H_DTYPE_C
@@ -27,7 +27,7 @@ def _update_gradients_least_squares(

n_samples = raw_predictions.shape[0]
for i in prange(n_samples, schedule='static', nogil=True):
# Note: a more correct exp is 2 * (raw_predictions - y_true)
# Note: a more correct expression is 2 * (raw_predictions - y_true)
# but since we use 1 for the constant hessian value (and not 2) this
# is strictly equivalent for the leaves values.
gradients[i] = raw_predictions[i] - y_true[i]
@@ -87,6 +87,35 @@ def _update_gradients_least_absolute_deviation(
gradients[i] = 2 * (y_true[i] - raw_predictions[i] < 0) - 1


def _update_gradients_hessians_poisson(
G_H_DTYPE_C [::1] gradients, # OUT
G_H_DTYPE_C [::1] hessians, # OUT
const Y_DTYPE_C [::1] y_true, # IN
const Y_DTYPE_C [::1] raw_predictions, # IN
const Y_DTYPE_C [::1] sample_weight): # IN

cdef:
int n_samples
int i
Y_DTYPE_C y_pred

n_samples = raw_predictions.shape[0]
if sample_weight is None:
for i in prange(n_samples, schedule='static', nogil=True):
# Note: We use only half of the deviance loss. Therefore, there is
# no factor of 2.
y_pred = exp(raw_predictions[i])
gradients[i] = (y_pred - y_true[i])
hessians[i] = y_pred
else:
for i in prange(n_samples, schedule='static', nogil=True):
# Note: We use only half of the deviance loss. Therefore, there is
# no factor of 2.
y_pred = exp(raw_predictions[i])
gradients[i] = (y_pred - y_true[i]) * sample_weight[i]
hessians[i] = y_pred * sample_weight[i]


def _update_gradients_hessians_binary_crossentropy(
G_H_DTYPE_C [::1] gradients, # OUT
G_H_DTYPE_C [::1] hessians, # OUT
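
For readers who do not want to parse the Cython above, the same update can be sketched in plain NumPy. The helper name `poisson_grad_hess` is made up for this sketch; it mirrors what `_update_gradients_hessians_poisson` computes for half the Poisson deviance with a log-link.

import numpy as np

def poisson_grad_hess(y_true, raw_predictions, sample_weight=None):
    # Half Poisson deviance with log-link: loss = exp(raw) - y_true * raw (+ const).
    y_pred = np.exp(raw_predictions)      # inverse link
    gradients = y_pred - y_true           # d loss / d raw
    hessians = y_pred.copy()              # d^2 loss / d raw^2
    if sample_weight is not None:
        gradients *= sample_weight
        hessians *= sample_weight
    return gradients, hessians

# Example: at raw = log(y_true) the gradient vanishes and the hessian equals y_true.
y = np.array([1.0, 4.0])
g, h = poisson_grad_hess(y, np.log(y))
assert np.allclose(g, 0.0) and np.allclose(h, y)
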
23 changes: 16 additions & 7 deletions sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -765,11 +765,13 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):

Parameters
----------
loss : {'least_squares', 'least_absolute_deviation'}, \
loss : {'least_squares', 'least_absolute_deviation', 'poisson'}, \
optional (default='least_squares')
The loss function to use in the boosting process. Note that the
"least squares" loss actually implements an "half least squares loss"
to simplify the computation of the gradient.
"least squares" and "poisson" losses actually implement
"half least squares loss" and "half poisson deviance" to simplify the
computation of the gradient. Furthermore, "poisson" loss internally
uses a log-link and requires ``y >= 0``.
learning_rate : float, optional (default=0.1)
The learning rate, also known as *shrinkage*. This is used as a
multiplicative factor for the leaves values. Use ``1`` for no
@@ -875,7 +877,8 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
0.98...
"""

_VALID_LOSSES = ('least_squares', 'least_absolute_deviation')
_VALID_LOSSES = ('least_squares', 'least_absolute_deviation',
'poisson')

def __init__(self, loss='least_squares', learning_rate=0.1,
max_iter=100, max_leaf_nodes=31, max_depth=None,
@@ -908,14 +911,20 @@ def predict(self, X):
y : ndarray, shape (n_samples,)
The predicted values.
"""
# Return raw predictions after converting shape
# (n_samples, 1) to (n_samples,)
return self._raw_predict(X).ravel()
check_is_fitted(self)
# Return inverse link of raw predictions after converting
# shape (n_samples, 1) to (n_samples,)
return self.loss_.inverse_link_function(self._raw_predict(X).ravel())

def _encode_y(self, y):
# Just convert y to the expected dtype
self.n_trees_per_iteration_ = 1
y = y.astype(Y_DTYPE, copy=False)
if self.loss == 'poisson':
# Ensure y >= 0 and sum(y) > 0
if not (np.all(y >= 0) and np.sum(y) > 0):
raise ValueError("loss='poisson' requires non-negative y and "
"sum(y) > 0.")
return y

def _get_loss(self, sample_weight):
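
Taken together, the two changes above mean that predictions come back on the original scale of `y` (via the exponential inverse link) and that invalid targets are rejected at fit time. A hedged sketch with invented toy data:

import numpy as np
from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingRegressor

rng = np.random.RandomState(0)
X = rng.uniform(size=(200, 4))
y = rng.poisson(lam=np.exp(X[:, 0] + X[:, 1]))   # non-negative counts

reg = HistGradientBoostingRegressor(loss='poisson', max_iter=20).fit(X, y)
# predict() applies the inverse link (exp), so outputs are positive and on
# the scale of y rather than on the log scale of the raw predictions.
assert np.all(reg.predict(X[:10]) > 0)

# _encode_y raises if any y_i < 0 or if sum(y) == 0.
try:
    HistGradientBoostingRegressor(loss='poisson').fit(X, np.zeros(len(X)))
except ValueError as err:
    print(err)  # loss='poisson' requires non-negative y and sum(y) > 0.
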
54 changes: 52 additions & 2 deletions sklearn/ensemble/_hist_gradient_boosting/loss.py
@@ -9,7 +9,7 @@
from abc import ABC, abstractmethod

import numpy as np
from scipy.special import expit, logsumexp
from scipy.special import expit, logsumexp, xlogy

from .common import Y_DTYPE
from .common import G_H_DTYPE
@@ -19,11 +19,13 @@
from ._loss import _update_gradients_hessians_least_absolute_deviation
from ._loss import _update_gradients_hessians_binary_crossentropy
from ._loss import _update_gradients_hessians_categorical_crossentropy
from ._loss import _update_gradients_hessians_poisson
from ...utils.stats import _weighted_percentile


class BaseLoss(ABC):
"""Base class for a loss."""

def __init__(self, hessians_are_constant):
self.hessians_are_constant = hessians_are_constant

@@ -153,6 +155,7 @@ class LeastSquares(BaseLoss):
the computation of the gradients and get a unit hessian (and be consistent
with what is done in LightGBM).
"""

def __init__(self, sample_weight):
# If sample weights are provided, the hessians and gradients
# are multiplied by sample_weight, which means the hessians are
@@ -195,6 +198,7 @@ class LeastAbsoluteDeviation(BaseLoss):

loss(x_i) = |y_true_i - raw_pred_i|
"""

def __init__(self, sample_weight):
# If sample weights are provided, the hessians and gradients
# are multiplied by sample_weight, which means the hessians are
@@ -265,6 +269,51 @@ def update_leaves_values(self, grower, y_true, raw_predictions,
# Note that the regularization is ignored here


class Poisson(BaseLoss):
"""Poisson deviance loss with log-link, for regression.

For a given sample x_i, Poisson deviance loss is defined as::

loss(x_i) = y_true_i * log(y_true_i/exp(raw_pred_i))
- y_true_i + exp(raw_pred_i)

This actually computes half the Poisson deviance to simplify
the computation of the gradients.
"""

def __init__(self, sample_weight):
super().__init__(hessians_are_constant=False)

inverse_link_function = staticmethod(np.exp)

def pointwise_loss(self, y_true, raw_predictions):
# shape (1, n_samples) --> (n_samples,). reshape(-1) is more likely to
# return a view.
raw_predictions = raw_predictions.reshape(-1)
# TODO: For speed, we could remove the constant xlogy(y_true, y_true)
# Advantage of this form: minimum of zero at raw_predictions = y_true.
Member: Are we taking advantage of this advantage somewhere?

Member Author: Not that I know of. Might be interesting to see, if it matters (at all).
loss = (xlogy(y_true, y_true) - y_true * (raw_predictions + 1)
+ np.exp(raw_predictions))
return loss

def get_baseline_prediction(self, y_train, sample_weight, prediction_dim):
y_pred = np.average(y_train, weights=sample_weight)
eps = np.finfo(y_train.dtype).eps
y_pred = np.clip(y_pred, eps, None)
return np.log(y_pred)

def update_gradients_and_hessians(self, gradients, hessians, y_true,
raw_predictions, sample_weight):
# shape (1, n_samples) --> (n_samples,). reshape(-1) is more likely to
# return a view.
raw_predictions = raw_predictions.reshape(-1)
gradients = gradients.reshape(-1)
hessians = hessians.reshape(-1)
_update_gradients_hessians_poisson(gradients, hessians,
y_true, raw_predictions,
sample_weight)


class BinaryCrossEntropy(BaseLoss):
"""Binary cross-entropy loss, for binary classification.

@@ -372,5 +421,6 @@ def predict_proba(self, raw_predictions):
'least_squares': LeastSquares,
'least_absolute_deviation': LeastAbsoluteDeviation,
'binary_crossentropy': BinaryCrossEntropy,
'categorical_crossentropy': CategoricalCrossEntropy
'categorical_crossentropy': CategoricalCrossEntropy,
'poisson': Poisson,
}
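
Under the docstring above, `pointwise_loss` is half the per-sample Poisson deviance of `exp(raw_predictions)`, and the baseline raw prediction is the log of the (clipped) mean target. A small sketch against the private module of this PR, as an illustration rather than an additional test:

import numpy as np
from sklearn.metrics import mean_poisson_deviance
from sklearn.ensemble._hist_gradient_boosting.loss import _LOSSES

rng = np.random.RandomState(0)
y_true = rng.poisson(lam=3.0, size=20).astype(np.float64)
raw = rng.normal(size=20)                 # raw predictions on the log scale

loss = _LOSSES['poisson'](sample_weight=None)

# pointwise_loss is half the Poisson deviance of exp(raw).
half = loss.pointwise_loss(y_true, raw).mean()
assert np.isclose(2 * half, mean_poisson_deviance(y_true, np.exp(raw)))

# The baseline raw prediction is log(mean(y)), clipped away from zero.
baseline = loss.get_baseline_prediction(y_true, None, 1)
assert np.isclose(baseline, np.log(y_true.mean()))
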
sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
@@ -2,10 +2,13 @@
import pytest
from numpy.testing import assert_allclose, assert_array_equal
from sklearn.datasets import make_classification, make_regression
from sklearn.datasets import make_low_rank_matrix
from sklearn.preprocessing import KBinsDiscretizer, MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.base import clone, BaseEstimator, TransformerMixin
from sklearn.pipeline import make_pipeline
from sklearn.metrics import mean_poisson_deviance
from sklearn.dummy import DummyRegressor

# To use this experimental feature, we need to explicitly ask for it:
from sklearn.experimental import enable_hist_gradient_boosting # noqa
@@ -192,6 +195,45 @@ def test_least_absolute_deviation():
assert gbdt.score(X, y) > .9


@pytest.mark.parametrize('y', [([1., -2., 0.]), ([0., 0., 0.])])
def test_poisson_y_positive(y):
# Test that ValueError is raised if either one y_i < 0 or sum(y_i) <= 0.
err_msg = r"loss='poisson' requires non-negative y and sum\(y\) > 0."
gbdt = HistGradientBoostingRegressor(loss='poisson', random_state=0)
with pytest.raises(ValueError, match=err_msg):
gbdt.fit(np.zeros(shape=(len(y), 1)), y)


def test_poisson():
# For Poisson distributed target, Poisson loss should give better results
# than least squares measured in Poisson deviance as metric.
rng = np.random.RandomState(42)
n_train, n_test, n_features = 500, 100, 100
X = make_low_rank_matrix(n_samples=n_train+n_test, n_features=n_features,
random_state=rng)
# We create a log-linear Poisson model and downscale coef as it will get
# exponentiated.
coef = rng.uniform(low=-2, high=2, size=n_features) / np.max(X, axis=0)
y = rng.poisson(lam=np.exp(X @ coef))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=n_test,
random_state=rng)
gbdt_pois = HistGradientBoostingRegressor(loss='poisson', random_state=rng)
gbdt_ls = HistGradientBoostingRegressor(loss='least_squares',
random_state=rng)
gbdt_pois.fit(X_train, y_train)
gbdt_ls.fit(X_train, y_train)
dummy = DummyRegressor(strategy="mean").fit(X_train, y_train)

for X, y in [(X_train, y_train), (X_test, y_test)]:
metric_pois = mean_poisson_deviance(y, gbdt_pois.predict(X))
# least_squares might produce non-positive predictions => clip
metric_ls = mean_poisson_deviance(y, np.clip(gbdt_ls.predict(X), 1e-15,
None))
metric_dummy = mean_poisson_deviance(y, dummy.predict(X))
assert metric_pois < metric_ls
assert metric_pois < metric_dummy


def test_binning_train_validation_are_separated():
# Make sure training and validation data are binned separately.
# See issue 13926
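
The clip on the least-squares predictions in test_poisson is needed because `mean_poisson_deviance` only accepts strictly positive predictions, which a least-squares model does not guarantee. A tiny illustration with invented numbers:

import numpy as np
from sklearn.metrics import mean_poisson_deviance

y_true = np.array([0., 1., 3.])
y_pred = np.array([-0.2, 1.1, 2.9])    # a least-squares fit can go negative

try:
    mean_poisson_deviance(y_true, y_pred)
except ValueError as err:
    print(err)                          # strictly positive y_pred is required

# Clipping to a small positive value makes the metric well defined.
print(mean_poisson_deviance(y_true, np.clip(y_pred, 1e-15, None)))
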
39 changes: 35 additions & 4 deletions sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py
@@ -52,6 +52,9 @@ def get_hessians(y_true, raw_predictions):
# ('binary_crossentropy', 0.3, 0),
('binary_crossentropy', -12, 1),
('binary_crossentropy', 30, 1),
('poisson', 12., 1.),
('poisson', 0., 2.),
('poisson', -22., 10.),
])
@pytest.mark.skipif(sp_version == (1, 2, 0),
reason='bug in scipy 1.2.0, see scipy issue #9608')
@@ -76,17 +79,19 @@ def fprime(x):
def fprime2(x):
return get_hessians(y_true, x)

optimum = newton(func, x0=x0, fprime=fprime, fprime2=fprime2)
optimum = newton(func, x0=x0, fprime=fprime, fprime2=fprime2,
maxiter=70, tol=2e-8)
assert np.allclose(loss.inverse_link_function(optimum), y_true)
assert np.allclose(loss.pointwise_loss(y_true, optimum), 0)
assert np.allclose(get_gradients(y_true, optimum), 0)
assert np.allclose(get_gradients(y_true, optimum), 0, atol=1e-7)


@pytest.mark.parametrize('loss, n_classes, prediction_dim', [
('least_squares', 0, 1),
('least_absolute_deviation', 0, 1),
('binary_crossentropy', 2, 1),
('categorical_crossentropy', 3, 3),
('poisson', 0, 1),
])
@pytest.mark.skipif(Y_DTYPE != np.float64,
reason='Need 64 bits float precision for numerical checks')
@@ -100,6 +105,8 @@ def test_numerical_gradients(loss, n_classes, prediction_dim, seed=0):
n_samples = 100
if loss in ('least_squares', 'least_absolute_deviation'):
y_true = rng.normal(size=n_samples).astype(Y_DTYPE)
elif loss in ('poisson',):
y_true = rng.poisson(size=n_samples).astype(Y_DTYPE)
else:
y_true = rng.randint(0, n_classes, size=n_samples).astype(Y_DTYPE)
raw_predictions = rng.normal(
@@ -114,7 +121,7 @@ def test_numerical_gradients(loss, n_classes, prediction_dim, seed=0):

# Approximate gradients
# For multiclass loss, we should only change the predictions of one tree
# (here the first), hence the use of offset[:, 0] += eps
# (here the first), hence the use of offset[0, :] += eps
# As a softmax is computed, offsetting the whole array by a constant would
# have no effect on the probabilities, and thus on the loss
eps = 1e-9
@@ -164,6 +171,27 @@ def test_baseline_least_absolute_deviation():
assert baseline_prediction == pytest.approx(np.median(y_train))


def test_baseline_poisson():
rng = np.random.RandomState(0)

loss = _LOSSES['poisson'](sample_weight=None)
y_train = rng.poisson(size=100).astype(np.float64)
# Sanity check, make sure at least one sample is non-zero so we don't take
# log(0)
assert y_train.sum() > 0
baseline_prediction = loss.get_baseline_prediction(y_train, None, 1)
assert np.isscalar(baseline_prediction)
assert baseline_prediction.dtype == y_train.dtype
assert_all_finite(baseline_prediction)
# Make sure baseline prediction produces the log of the mean of all targets
assert_almost_equal(np.log(y_train.mean()), baseline_prediction)

# Test baseline for y_true = 0
y_train.fill(0.)
baseline_prediction = loss.get_baseline_prediction(y_train, None, 1)
assert_all_finite(baseline_prediction)


def test_baseline_binary_crossentropy():
rng = np.random.RandomState(0)

@@ -215,7 +243,8 @@ def test_baseline_categorical_crossentropy():
('least_squares', 'regression'),
('least_absolute_deviation', 'regression'),
('binary_crossentropy', 'classification'),
('categorical_crossentropy', 'classification')
('categorical_crossentropy', 'classification'),
('poisson', 'poisson_regression'),
])
@pytest.mark.parametrize('sample_weight', ['ones', 'random'])
def test_sample_weight_multiplies_gradients(loss, problem, sample_weight):
@@ -232,6 +261,8 @@ def test_sample_weight_multiplies_gradients(loss, problem, sample_weight):

if problem == 'regression':
y_true = rng.normal(size=n_samples).astype(Y_DTYPE)
elif problem == 'poisson_regression':
y_true = rng.poisson(size=n_samples).astype(Y_DTYPE)
else:
y_true = rng.randint(0, n_classes, size=n_samples).astype(Y_DTYPE)

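
In the same spirit as `test_numerical_gradients`, a compact finite-difference check of the new Poisson gradients could look like the sketch below (illustrative only, reusing the private loss module of this PR):

import numpy as np
from sklearn.ensemble._hist_gradient_boosting.loss import _LOSSES

rng = np.random.RandomState(0)
y_true = rng.poisson(size=50).astype(np.float64)
raw = rng.normal(size=(1, 50))            # shape (prediction_dim, n_samples)

loss = _LOSSES['poisson'](sample_weight=None)
gradients = np.empty_like(raw, dtype=np.float32)
hessians = np.empty_like(raw, dtype=np.float32)
loss.update_gradients_and_hessians(gradients, hessians, y_true, raw, None)

# Central finite differences of the pointwise loss w.r.t. the raw predictions.
eps = 1e-6
numerical = (loss.pointwise_loss(y_true, raw + eps)
             - loss.pointwise_loss(y_true, raw - eps)) / (2 * eps)
assert np.allclose(gradients.ravel(), numerical.ravel(), rtol=1e-4, atol=1e-6)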