
[BUG] Logistic regression coefficients (for feature importance) significantly differ from Scikit-learn #3645

Closed
cjnolet opened this issue Mar 22, 2021 · 3 comments
Labels
bug Something isn't working

Comments

@cjnolet
Member

cjnolet commented Mar 22, 2021

Linking in the initial issue from our rapids-single-cell-examples repository: NVIDIA-Genomics-Research/rapids-single-cell-examples#29

TL;DR: there has been an ongoing discrepancy between the gene rankings produced by Scikit-learn and cuML. We originally thought a bug fix in cuML's regularization would resolve it, but that doesn't appear to be the case. This use case specifically relies on regularization for feature ranking.

Here's a small reproducer:

import pandas as pd
import numpy as np
import cupy as cp  # needed for cp.argpartition / cp.argsort below

penalty = 'l2'

X = np.load("scaled_normalized_5000_most_variable_genes.npy")
genes = pd.read_csv("genes.csv").index

print("Genes: %s" % genes)

cluster_labels = np.load("cluster_labels.npy")

print("Unique labels: %s" % (np.unique(cluster_labels)))

reference_indices = np.arange(len(genes), dtype=int)

# Get results from cuml
from cuml.linear_model import LogisticRegression as cuLR
cu_clf = cuLR(penalty=penalty)
cu_clf.fit(X, cluster_labels)

scores_all = cu_clf.coef_.T  # cuML's coef_ is (n_features, n_classes); transpose to (n_classes, n_features)

scores = scores_all[0]  # coefficients for the first class
partition = cp.argpartition(scores, -20)[-20:]  # indices of the 20 largest coefficients

print("gpu scores: %s" % scores)
partial_indices = cp.argsort(scores[partition])[::-1]
global_indices = reference_indices[partition][partial_indices]
v1 = set(genes[global_indices])

# Get results from sklearn
from sklearn.linear_model import LogisticRegression as skLR
clf = skLR(penalty=penalty, solver='lbfgs', multi_class="multinomial")
clf.fit(X, cluster_labels)

scores_all_2 = clf.coef_  # sklearn's coef_ is already (n_classes, n_features)

scores = scores_all_2[0]

print("cpu scores: %s" % scores)

partition = np.argpartition(scores, -20)[-20:]
partial_indices = np.argsort(scores[partition])[::-1]
global_indices = reference_indices[partition][partial_indices]
v2 = set(genes[global_indices])

print(v1.difference(v2))

This code is executed on highly variable genes, so there shouldn't be much correlation at all between features. I've attached the datasets to run the example. The input has also been centered and normalized into z-scores.

Here's the output for penalty='l2'. The last line contains the elements that differed between the two rankings. I've also tried setting the regularization weight to a few different values without much benefit to the corresponding rankings.

Genes: RangeIndex(start=0, stop=5000, step=1)
Unique labels: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33.]
gpu scores: [0.9988345 1.017644  1.0146879 ... 0.9996884 1.0217109 0.9978616]
cpu scores: [ 9.36030523e-06  5.40392233e-02  6.03259959e-03 ...  6.87140608e-05
  2.09311302e-02 -5.89751794e-04]
{2945, 3235, 3213, 4816, 4628, 3220, 87, 2713, 4157}

Here are the correlations between the coefficients for each class:

for i in range(cu_clf.coef_.shape[1]):  # one column of coefficients per class label
    print("label=%s, corr=%s" % (i, np.corrcoef(cu_clf.coef_.T[i], clf.coef_[i])[0, 1]))

Output:

label=0, corr=0.8722224532465772
label=1, corr=0.9071688027178909
label=2, corr=0.8594651479788517
label=3, corr=0.7429112395548328
label=4, corr=0.8959628180657048
label=5, corr=0.8738619348785087
label=6, corr=0.7677549532252999
label=7, corr=0.8148034429585321
label=8, corr=0.9224657818847507
label=9, corr=0.9213459227542046
label=10, corr=0.7954198534191471
label=11, corr=0.9334880778429713
label=12, corr=0.9333968992708077
label=13, corr=0.6589866893939161
label=14, corr=0.6921231806295167
label=15, corr=0.7771132972291463
label=16, corr=0.9370104163617797
label=17, corr=0.8695761183736089
label=18, corr=0.6340255296688718
label=19, corr=0.8089468431354706
label=20, corr=0.8858338331801668
label=21, corr=0.8668927858705701
label=22, corr=0.9581991893962013
label=23, corr=0.8915519718694837
label=24, corr=0.8062668358508315
label=25, corr=0.7936559811933184
label=26, corr=0.7107820580088036
label=27, corr=0.8627937167303378
label=28, corr=0.9164664452598484
label=29, corr=0.8991919117085125
label=30, corr=0.8066033530177997
label=31, corr=0.9412297580697218
label=32, corr=0.8399293870289924
label=33, corr=0.8948670340898179

Here are the data files to run the MRE: https://drive.google.com/file/d/1SU5EVP8Om0Q7ZBo9ifkk4AYlIvWNTO32/view?usp=sharing

@cjnolet cjnolet added the bug Something isn't working label Mar 22, 2021
@cjnolet cjnolet added this to Needs prioritizing in Bug Squashing via automation Mar 22, 2021
@cjnolet
Member Author

cjnolet commented Mar 22, 2021

cc @avantikalal

@JohnZed JohnZed added this to Issue-Needs prioritizing in v21.06 Release via automation Apr 8, 2021
@JohnZed JohnZed moved this from Issue-Needs prioritizing to Issue-P1 in v21.06 Release Apr 8, 2021
@tfeher
Contributor

tfeher commented Apr 15, 2021

cuML's QN solver stops much earlier than sklearn's for this test case. A few issues with the QN stopping condition were spotted by @achirkin while updating the QN solver's default parameters. He will fix the stopping condition and look at how the fixes affect this issue.

rapids-bot bot pushed a commit that referenced this issue Apr 21, 2021
Change the starting coefficients of the QN model from `ones` to `zeros` for a few reasons:

 - This behavior better matches the sklearn reference implementation
 - It makes the initial model predict all classes with the same probabilities (for both sigmoid and softmax losses)
 - It makes the model converge faster in some cases

In addition, it enables the `warm_start` feature (same as in sklearn).

Contributes to solving #3645

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #3774
rapids-bot bot pushed a commit that referenced this issue Apr 23, 2021
…sklearn closer (#3766)

Change the QN solver (logistic regression) stopping conditions to avoid early stops in some cases (#3645):
  - primary: 
    ```
    || f' ||_inf <= fmag * param.epsilon
    ```
  - secondary:
    ```
    |f - f_prev| <= fmag * param.delta
    ```
where `fmag = max(|f|, param.epsilon)`.

Also change the default value of `tol` in the QN solver (which sets `param.delta`) to be consistent (`1e-4`) with the logistic regression solver.
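
For clarity, here is a minimal Python restatement of the two checks (the function name and its defaults are illustrative, not the actual C++ implementation):

```
import numpy as np

def qn_converged(grad, f, f_prev, epsilon=1e-4, delta=1e-4):
    # fmag rescales both tolerances by the magnitude of the objective
    fmag = max(abs(f), epsilon)
    primary = np.max(np.abs(grad)) <= fmag * epsilon   # || f' ||_inf <= fmag * param.epsilon
    secondary = abs(f - f_prev) <= fmag * delta        # |f - f_prev| <= fmag * param.delta
    return primary or secondary
```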


#### Background

The original primary stopping condition is inconsistent with the sklearn reference implementation and is often triggered too early:
```
|| f' ||_2 <= param.epsilon * max(1.0, || x ||_2)
```

Here are the sklearn conditions for reference:
  - primary: 
    ```
    || grad f ||_inf <= gtol
    ```
  - secondary:
    ```
    |f - f_prev| <= ftol * max(|f|, |f_prev|, 1.0)
    ```
where `gtol` is an exposed parameter like `param.epsilon`, and `ftol = 2.2e-9` (hardcoded).
In addition, `f` in sklearn is scaled with the sample size (softmax or sigmoid over the dataset), so it's not exactly comparable to the cuML version.

Currently, cuML checks the gradient w.r.t. the logistic regression weights `x`. As a result, the effective tolerance grows with the number of classes and features; the model stops too early and stays underfit. This may in part be a reason for #3645.
In this proposal I change the stopping condition to be closer to the sklearn version, but compromise consistency with sklearn for better scaling (the tolerance scales with the absolute value of the objective function). Without this scaling, the sklearn version often seems to run until the maximum iteration limit is reached, as the toy illustration below suggests.
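
To see why the old check loosens as the model grows, note that `|| x ||_2` scales with the square root of the number of weights. A toy illustration (the numbers are purely illustrative):

```
import numpy as np

epsilon = 1e-4
for n_weights in (100, 5_000, 5_000 * 34):  # e.g. features, features x classes
    x = np.ones(n_weights)                  # weights of magnitude ~1
    threshold = epsilon * max(1.0, np.linalg.norm(x))
    print(n_weights, threshold)             # effective tolerance grows ~sqrt(n_weights)
```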

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #3766
rapids-bot bot pushed a commit that referenced this issue Apr 26, 2021
- Expose a parameter `delta` of the `QN` solver to control the loss value change stopping condition
 - Set a reasonable default for the parameter value that should keep the behavior close to sklearn in most cases

Note that this change does not expose `delta` to the wrapper class `LogisticRegression`.

Note that although this change does not break the Python API, it does break the C/C++ API.
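
For illustration, a minimal sketch of using the standalone `QN` solver with the new `delta` parameter (the import path and the other arguments are assumptions; check the cuML docs for your version):

```
import numpy as np
from cuml import QN  # import path assumed; may live under cuml.solvers in some versions

X = np.random.rand(200, 10).astype(np.float32)
y = (np.random.rand(200) > 0.5).astype(np.float32)

# `tol` feeds the gradient-based (primary) check; `delta` controls the
# loss-change (secondary) check exposed by this PR.
qn = QN(loss='sigmoid', tol=1e-4, delta=1e-4)
qn.fit(X, y)
```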

Contributes to solving #3645

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #3777
@achirkin
Contributor

As we anticipated, the issue was in the stopping conditions, which forced the cuML solver to stop much earlier than sklearn's. With PRs #3766, #3774, and #3777, cuML is now much more careful about when to stop the optimization process. Yet, with the default parameters for both variants, some difference still remains:

Genes: RangeIndex(start=0, stop=5000, step=1)
Unique labels: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33.]
gpu scores: [ 0.00011192  0.05343969  0.00733729 ... -0.00075544  0.01274147
 -0.00045122]
cpu scores: [ 9.36025545e-06  5.40392171e-02  6.03260444e-03 ...  6.87133030e-05
  2.09311163e-02 -5.89751603e-04]
{1442, 4628}

However, the remaining discrepancy is due to sklearn's default iteration limit of 100 (cuML's is 1000). Raising the sklearn iteration limit to 1000 produces scores almost identical to cuML's, as the sketch and output below show.
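
Only `max_iter` changes relative to the reproducer above:

```
from sklearn.linear_model import LogisticRegression as skLR

# sklearn's lbfgs solver defaults to max_iter=100; cuML's QN defaults to 1000
clf = skLR(penalty='l2', solver='lbfgs', multi_class='multinomial', max_iter=1000)
clf.fit(X, cluster_labels)  # X, cluster_labels as in the reproducer
```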

Genes: RangeIndex(start=0, stop=5000, step=1)
Unique labels: [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33.]
gpu scores: [ 0.00011524  0.0538229   0.00724842 ... -0.00081786  0.01299944
 -0.00045121]
cpu scores: [ 9.20957968e-05  5.32906829e-02  7.21765057e-03 ... -7.04705335e-04
  1.29110687e-02 -4.55303706e-04]
set()

Here is the convergence plot after the fixes (1000 iterations max). Note the constant offset between the two loss curves; it suggests the results must be very close:
[Figure: convergence plot of the cuML and sklearn losses over iterations after the fixes (issue-3645)]

With this I believe we can close the issue now :)

Bug Squashing automation moved this from Needs prioritizing to Closed Apr 27, 2021
v21.06 Release automation moved this from Issue-P1 to Done Apr 27, 2021