Inconsistency between coef_ and coef_path_ in LassoLars #17658

Open
FloWPs opened this issue Jun 22, 2020 · 5 comments

@FloWPs

FloWPs commented Jun 22, 2020

Describe the issue linked to the documentation

Hello, the coef_ attribute of LassoLars returns the array of coefficients for the value of alpha passed as a parameter (1 by default). The coef_path_ attribute is described in the docs as follows: "The varying values of the coefficients along the path".
However, the last column of coef_path_ doesn't match the values of the coef_ attribute, even though the alpha value for this column corresponds to the one passed as a parameter, so the coefficient values should be the same.

For example, if I run:

from sklearn.linear_model import LassoLars
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=50,
                       n_features=10,
                       n_informative=3)

lasso = LassoLars().fit(X, y)

coef_ will look like:

array([ 0.        , 39.89632517, 77.75976515, 88.68204422,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ])

and coef_path_:

array([[  0.        ,   0.        ,   0.        ,   0.        ],
       [  0.        ,   0.        ,   0.        , 289.86068748],
       [  0.        ,   0.        , 122.80186215, 573.97268871],
       [  0.        , 123.33792921, 246.13979136, 584.906583  ],
       [  0.        ,   0.        ,   0.        ,   0.        ],
       [  0.        ,   0.        ,   0.        ,   0.        ],
       [  0.        ,   0.        ,   0.        ,   0.        ],
       [  0.        ,   0.        ,   0.        ,   0.        ],
       [  0.        ,   0.        ,   0.        ,   0.        ],
       [  0.        ,   0.        ,   0.        ,   0.        ]])

Suggest a potential alternative/fix

Either normalize the 2 attributes in the same way, or give additional information about the coef_path_ values in the doc.
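
A quick way to check whether a given installation shows the mismatch described above is to compare coef_ with the last column of coef_path_ directly. The sketch below is only illustrative: the random_state value and the variable names are assumptions that do not appear in the original report, and the result of the value comparison may differ between scikit-learn releases.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LassoLars

# random_state=0 is an arbitrary choice, added only for reproducibility.
X, y = make_regression(n_samples=50,
                       n_features=10,
                       n_informative=3,
                       random_state=0)

lasso = LassoLars().fit(X, y)

# Last column of the path: the coefficients at the final alpha on the path.
last_column = lasso.coef_path_[:, -1]

# The two vectors should select the same features even if their scales differ.
same_support = np.array_equal(lasso.coef_ != 0, last_column != 0)
# Whether the values also agree may depend on the installed version.
same_values = np.allclose(lasso.coef_, last_column)

print("final alpha on the path:", lasso.alphas_[-1])
print("same non-zero features:", same_support)
print("same values:", same_values)
```

If same_values prints False while same_support prints True, the two attributes differ only by a per-feature scale factor, which is the situation reported here.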

@thomasjpfan
Member

thomasjpfan commented Jun 26, 2020

The coef_ is normalized to the scale of X:

import numpy as np
from sklearn.linear_model import LassoLars
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=50,
                       n_features=10,
                       n_informative=3,
                       random_state=42)

lasso = LassoLars(random_state=42).fit(X, y)

# l2 norm of each centered column of X
X_scale = np.linalg.norm(X - np.average(X, axis=0), axis=0, ord=2)

lasso.coef_ * X_scale gives:

>>> lasso.coef_ * X_scale
array([351.66080929,   0.        ,   0.        ,   0.        ,
         0.        , 235.00745073,   0.        ,   0.        ,
         0.        , 375.50644259])

which is the last column of lasso.coef_path_. This happens when fit_intercept=True. Maybe the documentation could be clearer about this.
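
The rescaling described above can be illustrated without scikit-learn. The following is a minimal NumPy-only sketch that uses ordinary least squares in place of LARS, under the assumption that preprocessing centers each column of X and divides it by its l2 norm. All data and variable names (w_scaled, w_orig, X_std) are made up for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(loc=3.0, scale=2.0, size=(50, 4))
true_coef = np.array([1.5, 0.0, -2.0, 0.5])
y = X @ true_coef + 4.0  # noise-free linear target

# Assumed preprocessing: center each column, then divide by its l2 norm.
X_mean = X.mean(axis=0)
X_scale = np.linalg.norm(X - X_mean, axis=0, ord=2)
X_std = (X - X_mean) / X_scale
y_mean = y.mean()

# Coefficients on the preprocessed scale (the analogue of coef_path_).
w_scaled, *_ = np.linalg.lstsq(X_std, y - y_mean, rcond=None)

# Coefficients mapped back to the original scale of X (the analogue of coef_).
w_orig = w_scaled / X_scale
intercept = y_mean - X_mean @ w_orig

print(w_scaled)          # on the scale of the preprocessed features
print(w_orig)            # on the scale of the original features
print(w_orig * X_scale)  # matches w_scaled
```

Both coefficient vectors describe the same fitted model; they differ only by the per-feature factor X_scale, which is why multiplying coef_ by X_scale recovers the path values.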

@FloWPs
Author

FloWPs commented Jun 30, 2020

Thank you for your answer!

And what do you think of my suggestion to normalize the coef_path_ attribute as well? I don't really see the point of keeping the coefficients in coef_path_ at a different scale from the one used for the coef_ attribute.

@thomasjpfan
Member

And what do you think of my suggestion to normalize the coef_path_ attribute as well?

This makes sense, but it will break backward compatibility. Maybe @rth or @agramfort can comment on this?

@agramfort
Member

Agreed, this is inconsistent. I would consider this a bug fix and apply the scaling as suggested. I'm not super worried about backward compatibility, as coef_path_ is not used in any example.

@jerryqhyu

Why is this scaling done when fit_intercept is set to True, even when normalize is False? Which of the following produces the correct result, given that X has already passed through StandardScaler?

  • ElasticNet(fit_intercept=True).fit(X,y)
  • ElasticNet(fit_intercept=False).fit(sm.add_constant(X),y)

This scaling affects downstream tasks that access coef_, such as sklearn.feature_selection.SelectFromModel with a hardcoded threshold.
