[MRG+1] Add Huber Estimator to sklearn linear models #5291

Merged
merged 1 commit into scikit-learn:master from MechCoder:huber_loss on Feb 25, 2016

8 participants
@MechCoder
Member

MechCoder commented Sep 21, 2015

Add robust regression model that filters outliers based on http://statweb.stanford.edu/~owen/reports/hhu.pdf

  • Add fix for random OverflowErrors.
  • Add documentation to the helper function
  • Add extensive testing
  • Add narrative docs
  • Add example
  • Support for sparse data
  • Support sample_weights
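For context, a minimal usage sketch of the proposed estimator, written against the API as eventually merged (the data and parameter values below are made up):

```python
import numpy as np
from sklearn.linear_model import HuberRegressor, Ridge

rng = np.random.RandomState(0)
X = rng.normal(size=(100, 2))
y = X @ np.array([1.0, 2.0]) + rng.normal(scale=0.5, size=100)
y[:5] += 20.0  # a few strong outliers

# The Huber loss downweights the outliers; ridge is pulled towards them.
huber = HuberRegressor(epsilon=1.35, alpha=1e-4).fit(X, y)
ridge = Ridge(alpha=1e-4).fit(X, y)
print(huber.coef_, ridge.coef_)
```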
@jmschrei


Member

jmschrei commented Sep 21, 2015

HuberLoss is already implemented in ensemble/GradientBoosting.py. I feel like we need to have a more centralized loss module, rather than reimplementing them when needed. What are your thoughts?

@MechCoder


Member

MechCoder commented Sep 21, 2015

I agree with you. Has there been an ongoing discussion already that I've missed?

@jmschrei


Member

jmschrei commented Sep 21, 2015

#5044 has some comments about it, but no consensus has been reached yet.

@MechCoder


Member

MechCoder commented Sep 21, 2015

I just looked through the code in sklearn.ensemble.gradient_boosting. I agree that the refactoring has to be done, but I'm not sure it should block this PR.

  1. The current code in HuberLoss does not take regularization into account (it looks like the alpha there is not the regularization alpha, but a way to calculate the epsilon value).
  2. I could also wrap my methods around calls to HuberLoss(n_c)(y, pred) and negative_gradient(), but there would be some duplicated computation, which might be expensive when calculating the loss and the gradient together.

What do you think?

@MechCoder


Member

MechCoder commented Sep 21, 2015

Btw, this fixes #4990

@jmschrei


Member

jmschrei commented Sep 21, 2015

If there is a way of doing it with the current code, it may be worth benchmarking to see how expensive it is. If there is a significant slowdown, you should go ahead, since it solves a pressing issue.

@mblondel


Member

mblondel commented Sep 22, 2015

> I feel like we need to have a more centralized loss module, rather than reimplementing them when needed. What are your thoughts?

I don't think this would be so useful. This PR uses the gradient with respect to the linear model parameters (size n_features), while GradientBoosting uses the gradient with respect to the predictions (size n_samples), so the code is different. The one part that could possibly be shared is computing the objective value, but the regularizer is different.
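A toy illustration of the difference in gradient shapes, using a plain squared loss (all names here are made up for the sketch):

```python
import numpy as np

rng = np.random.RandomState(0)
n_samples, n_features = 1000, 20
X = rng.randn(n_samples, n_features)
y = X @ rng.randn(n_features)
w = np.zeros(n_features)

residual = y - X @ w
# Gradient w.r.t. the linear model coefficients: what this PR's solver needs.
grad_coef = -2 * X.T @ residual   # shape (n_features,)
# Gradient w.r.t. the predictions: what gradient boosting needs.
grad_pred = -2 * residual         # shape (n_samples,)
```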

@jmschrei


Member

jmschrei commented Sep 22, 2015

Okay, then!

@dbtsai


dbtsai commented Sep 26, 2015

I am working on robust regression for Spark's MLlib project based on Prof. Art Owen's paper, A robust hybrid of lasso and ridge regression. In MLlib/Breeze we don't support L-BFGS-B, while the scaling factor \sigma in Eq. (6) has to be >= 0, so we're going to replace it by \exp(\sigma). However, the second derivative of the Huber loss is not continuous, which can cause stability issues since L-BFGS requires it for guaranteed convergence. The workaround I'm going to implement is the pseudo-Huber loss function, which can be used as a smooth approximation of the Huber loss and ensures that derivatives are continuous to all orders.

BTW, in robust regression the scaling factor \sigma has to be estimated as well, and this corresponds to \epsilon in your case. This value cannot be a constant. Imagine that the optimization has just started from some initial condition: if the initial guess is not good, most of the training instances will be treated as outliers. As a result \epsilon will be larger, but it will be one of the parameters to be estimated. See the details in section 4 of Prof. Art Owen's paper. Thanks.
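For concreteness, a small sketch of the pseudo-Huber loss mentioned above and its derivative, which is smooth everywhere (the function names and the default delta are mine):

```python
import numpy as np

def pseudo_huber(residual, delta=1.35):
    """Smooth approximation of the Huber loss: quadratic for small residuals,
    asymptotically linear for large ones, with continuous derivatives of all orders."""
    r = residual / delta
    return delta ** 2 * (np.sqrt(1.0 + r ** 2) - 1.0)

def pseudo_huber_grad(residual, delta=1.35):
    # d/d(residual): residual / sqrt(1 + (residual/delta)^2); no kink at zero.
    return residual / np.sqrt(1.0 + (residual / delta) ** 2)
```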

@MechCoder


Member

MechCoder commented Sep 28, 2015

Thanks for the comment and the link to the paper. (And it comes at a time when my benchmarks weren't looking too great)

Previously I was using grid search to find the optimal value of epsilon, but it always corresponded to the lowest value of epsilon (i.e., assuming both X and y are centered and scaled).

To summarize the paper briefly:

  1. It seems that the epsilon here corresponds more to the parameter M in the paper, which is fixed at 1.35.
  2. In addition, the term y - X*w - c is scaled down by a factor sigma, which makes the algorithm scale-independent.
  3. And since the new function described in Eq. (8) is jointly convex, we could optimize sigma together with the coefficients, right? (A small sketch of that joint objective follows below.)
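For reference, my reading of that joint objective as a rough Python sketch (an illustration under the assumptions above, not the PR's actual code; sigma is simply assumed positive here):

```python
import numpy as np

def huber_joint_loss(w, sigma, X, y, M=1.35, alpha=0.0):
    """Jointly convex Huber objective in (w, sigma), following Owen's
    "A robust hybrid of lasso and ridge regression" (illustrative sketch)."""
    residual = y - X @ w
    z = residual / sigma
    inliers = np.abs(z) <= M
    # rho_M(z): quadratic inside [-M, M], linear outside.
    rho = np.where(inliers, z ** 2, 2.0 * M * np.abs(z) - M ** 2)
    # sigma + sigma * rho_M(residual / sigma) per sample, plus an L2 penalty.
    return np.sum(sigma + sigma * rho) + alpha * np.dot(w, w)
```

Since sigma * rho_M(r / sigma) is the perspective of a convex function, the sum is jointly convex in (w, sigma) for sigma > 0, which is what allows optimizing both at once.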
@dbtsai


dbtsai commented Sep 28, 2015

You are right. But you may want to replace \sigma with \exp(\alpha) so that you don't need the condition \sigma > 0. In theory the Hessian is not continuous, so L-BFGS may not work well, but I don't know the exact impact of this; we may need to do some benchmarks.

@dbtsai


dbtsai commented Sep 28, 2015

Also, for the pseudo-Huber loss there is no proof that it is jointly convex with \sigma, although I suspect it would be if we went through the proof.

@MechCoder


Member

MechCoder commented Sep 28, 2015

Great, I'll try two things:

Change the present loss function to accommodate minimizing sigma.
After that, I can try the pseudo-Huber loss to check whether there is any noticeable change in convergence.

@dbtsai


dbtsai commented Sep 28, 2015

Sounds great. Let me know the result, so I can learn from you when I implement this in Spark. Thanks.

@MechCoder


Member

MechCoder commented Sep 29, 2015

@dbtsai I've made changes to the loss function, but I'm not getting good results. Could you please verify if the loss function is right?

@MechCoder


Member

MechCoder commented Sep 29, 2015

By "good results" I mean that this is the plot I generated from the coefficients :/

[figure_1: plot of the fitted lines]

The red line is the one from the HuberRegressor and the green one is from RidgeRegression. As you can clearly see, it is not what it is supposed to look like.

@dbtsai


dbtsai commented Sep 29, 2015

@MechCoder I will compare with the note I have at home tonight. What do you mean by not getting good results? How do you test it? Also, when \epsilon is large, does it converge to normal linear regression? Thanks.

@dbtsai


dbtsai commented Sep 29, 2015

Can you try to make \epsilon very large and see if you can reproduce the RidgeRegression?

@MechCoder


Member

MechCoder commented Sep 29, 2015

I tried that as well, but it seems that epsilon has almost no effect (for both very high and very low values), since

|(y - X'w) / exp(sigma)| < M  is equivalent to  |y - X'w| < M*exp(sigma).

So the limit will change for every iteration the loss function is called, no?

Or am I understanding it wrong?

@MechCoder


Member

MechCoder commented Sep 29, 2015

Just in case you are interested, here is the script for the plot generation:

https://gist.github.com/MechCoder/8205f0fce4395a9ab907
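The gist itself is not reproduced here; a minimal script in the same spirit (one feature, a few injected outliers, comparing the two fitted lines) could look like this, with all data made up:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import HuberRegressor, Ridge

rng = np.random.RandomState(0)
X = rng.uniform(-1, 1, size=(50, 1))
y = 3.0 * X.ravel() + rng.normal(scale=0.2, size=50)
y[:4] += 10.0  # inject a few outliers

huber = HuberRegressor().fit(X, y)
ridge = Ridge(alpha=1e-4).fit(X, y)

grid = np.linspace(-1, 1, 100).reshape(-1, 1)
plt.scatter(X.ravel(), y, color='k', s=10)
plt.plot(grid.ravel(), huber.predict(grid), 'r-', label='HuberRegressor')
plt.plot(grid.ravel(), ridge.predict(grid), 'g-', label='Ridge')
plt.legend()
plt.show()
```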

@MechCoder


Member

MechCoder commented Sep 29, 2015

Oops, it seems I made a mistake. Just a second.

@dbtsai


dbtsai commented Sep 29, 2015

Here is the note where I compute dL/d\sigma. Let's compare whether we get the same formula.

[img_0271: photo of a handwritten note with the dL/d\sigma derivation]
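Since the photo of the note is not legible here, a plausible reconstruction of that derivative (an assumption on my part, for the objective L(w, \sigma) = \sum_i [\sigma + \sigma H_M((y_i - x_i'w)/\sigma)] + \alpha ||w||^2, with H_M(z) = z^2 for |z| <= M and 2M|z| - M^2 otherwise):

For an inlier (|y_i - x_i'w| <= M\sigma): d/d\sigma [\sigma + (y_i - x_i'w)^2 / \sigma] = 1 - (y_i - x_i'w)^2 / \sigma^2
For an outlier: d/d\sigma [\sigma + 2M|y_i - x_i'w| - M^2 \sigma] = 1 - M^2

Summing over the samples, dL/d\sigma = n_samples - \sum_{inliers} (y_i - x_i'w)^2 / \sigma^2 - n_outliers * M^2, and with the reparameterization \sigma = \exp(a), dL/da = \exp(a) * dL/d\sigma, which is consistent with the grad[-1] terms in the diff further down.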

@MechCoder


Member

MechCoder commented Sep 29, 2015

@dbtsai I modified the loss function just before your comment :P. I have commented out the gradient for now and set approx_grad=True in fmin_l_bfgs_b.

I just wanted a rough idea of whether the loss function is correct. After making the changes to the loss function (it should be clearer now), I am able to replicate the behavior of ridge for high values of epsilon (note that the lines coincide).

[figure_1: plot with the two fitted lines coinciding for large epsilon]

@MechCoder


Member

MechCoder commented Sep 29, 2015

I will re-add the gradient and verify it against your note in a bit.

@dbtsai


dbtsai commented Sep 29, 2015

Cool. It looks nice! How about small \epsilon? Does it help to filter out the outliers?

@dbtsai


dbtsai commented Sep 29, 2015

Also, can you try the smooth (pseudo-)Huber loss as well? Thanks.

@MechCoder


Member

MechCoder commented Sep 29, 2015

I just derived the gradient function and added it back. I verified against the formula in your note that it is correct. Could you also check from the code whether the gradient is right?

I can play with the code tomorrow with different values of epsilon and different data and check the performance.

After that we can check about the smooth huber loss.

Thanks.

grad[-1] -= n_outliers * epsilon**2 * exp(sigma)
grad[-1] -= squared_loss
return X.shape[0] * exp(sigma) + squared_loss + outlier_loss + alpha * np.dot(w, w), grad


@dbtsai

dbtsai Sep 30, 2015

In fact, I don't know how to handle the regularization easily, since that term should be \alpha / exp(sigma) in order to get the same result as the non-robust version. As a result, that term will also contribute to dL/da.

outliers_true_pos = np.logical_and(linear_loss >= 0, outliers_true)
outliers_true_neg = np.logical_and(linear_loss < 0, outliers_true)
grad[:n_features] -= epsilon * X[outliers_true_pos, :].sum(axis=0)
grad[:n_features] += epsilon * X[outliers_true_neg, :].sum(axis=0)


@dbtsai

dbtsai Sep 30, 2015

I thought this is 2 * epsilon?

@dbtsai


dbtsai commented Sep 30, 2015

Oh, it seems that you are mixing two notations. In general I prefer to have the / 2 in the squared loss, but it looks like you are mixing both, if I read it correctly.

  1. \epsilon has to be >= 1, unless we handle it based on Eq. (14). But that is a rare case, so let's make it work for M >= 1.
  2. Regularization may not work the way this comment suggests. Let's work on this later.
  3. It's easier to check in Python. Can you add a couple of strong outliers to your synthetic data and tune M from 1.0 to infinity to see how it performs? When M is very large, we should get back the normal ridge regression (see the sketch after this list).
  4. Although it's only a scaling constant, I wonder how the robust regression will perform with/without dividing by two in the squared loss.

Thanks.
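A sketch of the check in point 3, written against the merged estimator's API (the data and the sweep values are arbitrary):

```python
import numpy as np
from sklearn.linear_model import HuberRegressor, Ridge

rng = np.random.RandomState(42)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=200)
y[:10] += 30.0  # strong outliers

ridge = Ridge(alpha=1e-4).fit(X, y)
for epsilon in [1.1, 1.35, 3.0, 10.0, 1000.0]:
    huber = HuberRegressor(epsilon=epsilon, alpha=1e-4).fit(X, y)
    # For very large epsilon no sample is treated as an outlier,
    # so the coefficients should approach the ridge solution.
    print(epsilon, np.linalg.norm(huber.coef_ - ridge.coef_))
```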

with the number of samples
* :ref:`HuberRegressor <huber_regression>` should be faster than
:ref:`RANSAC <ransac_regression>` and :ref:`Theil Sen <theil_sen_regression>`
unless the number of samples are very large, i.e n_samples >> n_features.


@amueller

amueller Feb 25, 2016

Member

nitpick: double backticks around n_samples and n_features

a break point above which it performs worst than OLS.
a break point above which it performs worse than OLS.
- HuberRegressor should not differ much in performance to both RANSAC


@amueller

amueller Feb 25, 2016

Member

this seems to be slightly in conflict with what the user guide says. Also: is "performance" here accuracy or training time?

Parameters
----------
w: ndarray, shape (n_features + 1,) or (n_features + 2,)


@amueller

amueller Feb 25, 2016

Member

nitpick: space in front of : for consistency

----------
w: ndarray, shape (n_features + 1,) or (n_features + 2,)
Feature vector.
w[:n_features] gives the feature vector


@amueller

amueller Feb 25, 2016

Member

I would say "is the" not "gives the". Also: feature vector? you mean coefficient, right?

loss: float
Huber loss.
gradient: ndarray, shape (n_features + 1,) or (n_features + 2,)


@amueller

amueller Feb 25, 2016

Member

maybe shape len(w) ?

``|(y - X'w) / sigma| < epsilon`` and the absolute loss for the samples
where ``|(y - X'w) / sigma| > epsilon``, where w and sigma are parameters
to be optimized. The parameter sigma makes sure that if y is scaled up
or down by a certain factor, one does not need to rescale epsilon to acheive


@amueller

amueller Feb 25, 2016

Member

achieve

----------
epsilon : float, greater than 1.0, default 1.35
The parameter epsilon controls the number of samples that should be
classified as outliers. The lesser the epsilon, the more robust it is


@amueller

amueller Feb 25, 2016

Member

The smaller

to outliers.
max_iter : int, default 100
Number of iterations that scipy.optimize.fmin_l_bfgs_b should run for.


@amueller

amueller Feb 25, 2016

Member

Maximum number of iterations that...

if the data is already centered around the origin.
tol : float, default 1e-5
The iteration will stop when max{|proj g_i | i = 1, ..., n} <= tol


@amueller

amueller Feb 25, 2016

Member

How does that look in sphinx? maybe double backticks? Please check the rendering, ok? Maybe you want math?

scale_ : float
The value by which ``|y - X'w - c|`` is scaled down.
n_iter_: int


@amueller

amueller Feb 25, 2016

Member

space before :

y : array-like, shape (n_samples,)
Target vector relative to X.
sample_weight: array-like, shape (n_samples,)


@amueller

amueller Feb 25, 2016

Member

space before :

if self.epsilon < 1.0:
raise ValueError(
"epsilon should be greater than 1.0, got %f" % self.epsilon)


@amueller

amueller Feb 25, 2016

Member

greater or equal?

assert_array_almost_equal(huber_sparse.coef_, huber_coef, 3)
def return_outliers(X, y, huber):


@amueller

amueller Feb 25, 2016

Member

do we want to add an attribute to the estimator for that?

@@ -60,6 +60,26 @@ def safe_mask(X, mask):
return mask
def axis0_safe_slice(X, mask, len_mask):


@amueller

amueller Feb 25, 2016

Member

have we reimplemented this in RANSAC?


@MechCoder

MechCoder Feb 25, 2016

Member

RANSAC raises an error if all the data points are classified as outliers. But in this case we don't have a fixed threshold; the term epsilon*sigma varies at each iteration.
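To illustrate the corner case the helper guards against (a mask that selects no rows), a minimal sketch of the intended behaviour; the actual implementation lives in sklearn.utils and may differ in detail:

```python
import numpy as np

def axis0_safe_slice_sketch(X, mask, len_mask):
    # When the mask selects no rows, return an explicitly empty
    # (0, n_features) array instead of fancy-indexing with an all-False mask.
    if len_mask != 0:
        return X[mask, :]
    return np.zeros((0, X.shape[1]))

X = np.arange(12.0).reshape(4, 3)
mask = np.zeros(4, dtype=bool)  # e.g. no sample fell on this side of the threshold
print(axis0_safe_slice_sketch(X, mask, mask.sum()).shape)  # (0, 3)
```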

assert_greater(estimator.n_iter_, 0)
# HuberRegressor depends on scipy.optimize.fmin_l_bfgs_b
# which does return a n_iter for old versions of SciPy.


@amueller

amueller Feb 25, 2016

Member

does or doesn't?

@amueller


Member

amueller commented Feb 25, 2016

lgtm apart from nitpicks. Maybe adding an attribute that stores which points are outliers on the training set would be interesting.

@amueller


Member

amueller commented Feb 25, 2016

Feel free to merge once tests pass. We can always add an attribute for the outliers later.
We do need an entry in whatsnew and a "versionadded" tag.

@MechCoder


Member

MechCoder commented Feb 25, 2016

I have 3 more minor todos

  1. Add a property for the outliers (a rough sketch follows below)
  2. Check the gradient doc rendering
  3. Modify the documentation of the example.

Will address in a while
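For the first point, a rough sketch of how the outlier mask can be computed from a fitted model, using the epsilon * scale_ threshold from the loss above (exposing it as an attribute, and the exact name, are still open here; the data is made up):

```python
import numpy as np
from sklearn.linear_model import HuberRegressor

rng = np.random.RandomState(0)
X = rng.normal(size=(100, 2))
y = X @ np.array([1.0, 2.0]) + rng.normal(scale=0.1, size=100)
y[:5] += 15.0  # injected outliers

huber = HuberRegressor().fit(X, y)
# A sample counts as an outlier if its absolute residual exceeds
# epsilon times the fitted scale (sigma).
residual = np.abs(y - X @ huber.coef_ - huber.intercept_)
outlier_mask = residual > huber.epsilon * huber.scale_
print(outlier_mask[:5])  # the injected outliers should be flagged
```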

@agramfort


Member

agramfort commented Feb 25, 2016

@MechCoder


Member

MechCoder commented Feb 25, 2016

@agramfort done!!

Huber regressor
Add gradient calculation in _huber_loss_and_gradient

Add tests to check the correctness of the loss and gradient

Fix for old scipy

Add parameter sigma for robust linear regression

Add gradient formula to robust _huber_loss_and_gradient

Add fit_intercept option and fix tests

Add docs to HuberRegressor and the helper functions

Add example demonstrating ridge_regression vs huber_regression

Add sample_weight implementation

Add scaling invariant huber test

Remove exp and add bounds to fmin_l_bfgs_b

Add sparse data support

Add more tests and refactoring of code

Add narrative docs

review huber regressor

Minor additions to docs and tests

Minor fixes that deal with NaN values in targets
and old versions of SciPy and NumPy

Add HuberRegressor to robust estimator

Refactored computation of gradient and make docs render properly

Temp

Remove float64 dtype conversion

trivial optimizations and add a note about R

Remove sample_weights special_casing

address @amueller comments
@MechCoder


Member

MechCoder commented Feb 25, 2016

Tests pass!! Merging with master :D

MechCoder added a commit that referenced this pull request Feb 25, 2016

Merge pull request #5291 from MechCoder/huber_loss
[MRG+1] Add Huber Estimator to sklearn linear models

@MechCoder MechCoder merged commit 540c7c6 into scikit-learn:master Feb 25, 2016

4 checks passed

ci/circleci: Your tests passed on CircleCI!
continuous-integration/appveyor/pr: AppVeyor build succeeded
continuous-integration/travis-ci/pr: The Travis CI build passed
coverage/coveralls: Coverage increased (+0.02%) to 94.209%

@MechCoder MechCoder deleted the MechCoder:huber_loss branch Feb 25, 2016

@agramfort


Member

agramfort commented Feb 26, 2016

great work @MechCoder !

@amueller


Member

amueller commented Feb 26, 2016

thanks @MechCoder ! 🍻

@GaelVaroquaux


Member

GaelVaroquaux commented Feb 26, 2016
