Isotonic calibration changes rank-based test metrics values #16321

Open
dsleo opened this issue Jan 30, 2020 · 3 comments
Labels
Bug

Comments

Contributor

@dsleo dsleo commented Jan 30, 2020

[As discussed with @ogrisel]

Describe the bug

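Using isotonic calibration changes the values of rank-based metrics (such as ROC AUC) computed on the test set. This is because isotonic calibration is monotonic but not strictly monotonic: it can map distinct raw scores to the same calibrated probability, which introduces ties. Sigmoid calibration, being strictly monotonic, does not suffer from this.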
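A minimal sketch of why ties matter, using hypothetical toy labels and scores (not the data from this report). A strictly increasing transform (a logistic function here, standing in for sigmoid calibration) leaves ROC AUC untouched. A non-decreasing step function that merges two scores changes it, because a tied positive/negative pair only counts as half a correctly ranked pair.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([0, 1, 0, 1])
raw = np.array([0.1, 0.2, 0.3, 0.4])  # hypothetical raw scores, all distinct

# Strictly increasing map (logistic): the ranking is preserved exactly.
strict = 1.0 / (1.0 + np.exp(-10.0 * raw))

# Non-decreasing step map: 0.1 (negative) and 0.2 (positive) collapse onto 0.0,
# turning a correctly ordered pair into a tie.
step = np.where(raw <= 0.2, 0.0, np.where(raw <= 0.3, 0.5, 1.0))

print(roc_auc_score(y_true, raw))     # 0.75
print(roc_auc_score(y_true, strict))  # 0.75
print(roc_auc_score(y_true, step))    # 0.625
```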

Steps/Code to Reproduce

Here is a quick example where we split the data into three train/calibration/test sets and compare the ROC AUC on the test set before and after calibration, for both isotonic and sigmoid calibration.

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import roc_auc_score

X, y = datasets.make_classification(n_samples=100000, n_features=20,
                                    n_informative=18, n_redundant=0,
                                    random_state=42)

X_train, X_, y_train, y_ = train_test_split(X, y, test_size=0.5,
                                            random_state=42)
X_test, X_calib, y_test, y_calib = train_test_split(X_, y_, test_size=0.5,
                                                    random_state=42)

clf = LogisticRegression(C=1.)
clf.fit(X_train, y_train)

y_pred = clf.predict_proba(X_test)
print(roc_auc_score(y_test, y_pred[:,1]))

The ROC AUC is then: 0.88368

isotonic = CalibratedClassifierCV(clf, method='isotonic', cv='prefit')
isotonic.fit(X_calib, y_calib)

y_pred_calib = isotonic.predict_proba(X_test)
print(roc_auc_score(y_test, y_pred_calib[:,1]))

After isotonic calibration, the ROC AUC becomes: 0.88338

sigmoid = CalibratedClassifierCV(clf, method='sigmoid', cv='prefit')
sigmoid.fit(X_calib, y_calib)

y_pred_calib = sigmoid.predict_proba(X_test)
print(roc_auc_score(y_test, y_pred_calib[:, 1]))

As expected for sigmoid calibration, the ROC AUC is unchanged: 0.88368
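One way to check that ties explain the difference is to count distinct scores before and after an isotonic mapping. The sketch below fits `sklearn.isotonic.IsotonicRegression` directly on hypothetical synthetic scores (label plus Gaussian noise) instead of going through `CalibratedClassifierCV`. The data, the noise scale and the 50/50 split are assumptions for illustration only.

```python
import numpy as np
from sklearn.isotonic import IsotonicRegression
from sklearn.metrics import roc_auc_score

rng = np.random.RandomState(0)
n = 2000
y = rng.randint(0, 2, size=n)
# Hypothetical raw scores: the label plus Gaussian noise (continuous, so no ties).
scores = y + rng.normal(scale=1.5, size=n)

# First half is used for calibration, second half is the held-out test set.
s_cal, s_test = scores[: n // 2], scores[n // 2:]
y_cal, y_test = y[: n // 2], y[n // 2:]

iso = IsotonicRegression(out_of_bounds="clip").fit(s_cal, y_cal)
p_test = iso.predict(s_test)

n_unique_raw = len(np.unique(s_test))
n_unique_cal = len(np.unique(p_test))
auc_raw = roc_auc_score(y_test, s_test)
auc_cal = roc_auc_score(y_test, p_test)
print(f"distinct scores: raw={n_unique_raw}, isotonic={n_unique_cal}")
print(f"ROC AUC: raw={auc_raw:.5f}, isotonic={auc_cal:.5f}")
```

Most held-out scores land on flat segments of the fitted function (or are clipped at its ends), so the isotonic predictions should contain far fewer distinct values than the raw scores.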

Versions

System:
python: 3.8.1 | packaged by conda-forge | (default, Jan 29 2020, 15:06:10) [Clang 9.0.1 ]
executable: /Users/leodreyfusschmidt/opt/miniconda2/envs/isotonic/bin/python
machine: macOS-10.13.4-x86_64-i386-64bit

Python dependencies:
pip: 20.0.2
setuptools: 45.1.0.post20200119
sklearn: 0.22.1
numpy: 1.16.5
scipy: 1.4.1
Cython: None
pandas: None
matplotlib: 3.1.2
joblib: 0.14.1

Built with OpenMP: True

@dsleo dsleo added the Bug label Jan 30, 2020
Member

@ogrisel ogrisel commented Jan 31, 2020

Thank you very much @dsleo for the detailed report. I started from your code and pushed the analysis further.

If one introspects the calibrator, one can observe the following:

y_pred_calib = isotonic.predict_proba(X_test)
print(roc_auc_score(y_test, y_pred_calib[:, 1]))

calibrator = isotonic.calibrated_classifiers_[0].calibrators_[0]

import matplotlib.pyplot as plt
plt.figure(figsize=(16, 10))
plt.plot(calibrator._necessary_X_, calibrator._necessary_y_)

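[Figure 1: fitted isotonic calibrator thresholds, calibrator._necessary_X_ vs. calibrator._necessary_y_]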

This kind of plot is very informative, and a related PR (#16289) aims to make it possible to expose those thresholds for inspection purposes.

It is a bit weird to use a piece-wise constant (step-wise) function to map y_raw_pred to y_cal_pred.

Furthermore, a closer look at the (x, y) threshold pairs gives the following (first 10 pairs):

for x, y in zip(calibrator._necessary_X_[:10],
                calibrator._necessary_y_[:10]):
    print(x, y)
-8.31123478939124 0.0
-5.600958281618134 0.0
-5.590794325886634 0.00980392156862745
-5.044899575456449 0.00980392156862745
-5.042974497498811 0.017804154302670624
-4.303066186168206 0.017804154302670624
-4.301859876807849 0.024193548387096774
-4.151532225028019 0.024193548387096774
-4.149327735774874 0.027548209366391185
-3.763637782629291 0.027548209366391185

The y values are monotonic (but not strictly), but what is weird is that the x values do not even seem to be monotonic and go back slightly at each step. This is a bit fishy, but fair enough: once IsotonicRegression builds its scipy.interpolate.interp1d function on the thresholds, these artifacts can be ignored.
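The effect of a flat segment can be seen directly with the ten threshold pairs printed above. This sketch uses `numpy.interp` as a stand-in for the linear `scipy.interpolate.interp1d` built inside IsotonicRegression (assumed equivalent for points inside the threshold range). `np.interp` requires increasing x-coordinates, which these ten values satisfy. The query points -5.3 and -5.1 are arbitrary.

```python
import numpy as np

# First 10 threshold pairs printed above (calibrator._necessary_X_ / _necessary_y_).
x_thr = np.array([-8.31123478939124, -5.600958281618134, -5.590794325886634,
                  -5.044899575456449, -5.042974497498811, -4.303066186168206,
                  -4.301859876807849, -4.151532225028019, -4.149327735774874,
                  -3.763637782629291])
y_thr = np.array([0.0, 0.0, 0.00980392156862745, 0.00980392156862745,
                  0.017804154302670624, 0.017804154302670624,
                  0.024193548387096774, 0.024193548387096774,
                  0.027548209366391185, 0.027548209366391185])

# Two distinct raw scores inside the flat segment [-5.5908, -5.0449].
a = np.interp(-5.3, x_thr, y_thr)
b = np.interp(-5.1, x_thr, y_thr)
print(a, b)
print(a == b)  # True: two different raw scores, one calibrated value (a tie)
```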

I am not sure why we want to use such a piece-wise constant mapping. One could instead use a piece-wise linear mapping, which would be strictly monotonic as long as consecutive threshold y values differ. Here is a quick hack showing that this fixes the issue when using isotonic regression for classifier calibration:

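# Keep every other threshold point (one per step, assuming each step is a
# pair of points) plus the last one, so that no flat segment remains.
X_thresholds_fixed = np.concatenate((calibrator._necessary_X_[::2],
                                     calibrator._necessary_X_[-1:]))
y_thresholds_fixed = np.concatenate((calibrator._necessary_y_[::2],
                                     calibrator._necessary_y_[-1:]))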


plt.figure(figsize=(16, 10))
plt.plot(calibrator._necessary_X_, calibrator._necessary_y_)
plt.plot(X_thresholds_fixed, y_thresholds_fixed)

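[Figure 2: original thresholds overlaid with the subsampled (X_thresholds_fixed, y_thresholds_fixed) curve]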

Linearly interpolating using those thresholds makes the mapping strictly monotonic and the ROC-AUC of the original model is recovered:

calibrator._build_f(X_thresholds_fixed, y_thresholds_fixed)
y_pred_calib = isotonic.predict_proba(X_test)

print(roc_auc_score(y_test, y_pred[:, 1]))
print(roc_auc_score(y_test, y_pred_calib[:, 1]))
0.8836876399954915
0.8836876399954915

Instead of the max value of each step, one could use mid-points. This could be explored as a fix in IsotonicRegression itself (perhaps as a new option, keeping the piece-wise constant mapping as the default behavior for backward-compatibility reasons).
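A sketch of the mid-point variant on the ten threshold pairs printed above. It assumes each step consists of exactly two points sharing one y value, which holds for these ten pairs but not necessarily for the full threshold arrays. Interpolating linearly through the step mid-points removes the flat segments inside the threshold range.

```python
import numpy as np

x_thr = np.array([-8.31123478939124, -5.600958281618134, -5.590794325886634,
                  -5.044899575456449, -5.042974497498811, -4.303066186168206,
                  -4.301859876807849, -4.151532225028019, -4.149327735774874,
                  -3.763637782629291])
y_thr = np.array([0.0, 0.0, 0.00980392156862745, 0.00980392156862745,
                  0.017804154302670624, 0.017804154302670624,
                  0.024193548387096774, 0.024193548387096774,
                  0.027548209366391185, 0.027548209366391185])

# Each step is a (start, end) pair sharing one y value: keep its x mid-point.
x_mid = (x_thr[0::2] + x_thr[1::2]) / 2.0
y_mid = y_thr[0::2]

# The two points that were tied on a flat segment now get distinct values.
a = np.interp(-5.3, x_mid, y_mid)
b = np.interp(-5.1, x_mid, y_mid)
print(a, b)
print(a < b)  # True
```

Outside `[x_mid[0], x_mid[-1]]`, `np.interp` clamps to the end values, so scores beyond that range would still be tied.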

Member

@ogrisel ogrisel commented Feb 2, 2020

The above analysis is wrong: I made incorrect assumptions about the structure of the steps. They are not always exact steps; the default prediction function already has piece-wise linear components.

Still, it would be nice to have an option to enforce strict monotonicity, maybe by adding a small eps to one of the edges whenever y is constant on a segment.
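A minimal sketch of the eps idea as a post-processing step on the threshold y values. The helper `strictly_increasing` and its default `eps=1e-9` are hypothetical choices, not an existing scikit-learn option. The y values are the first ten thresholds printed earlier in this thread.

```python
import numpy as np


def strictly_increasing(y, eps=1e-9):
    """Hypothetical helper: nudge y upward by eps wherever it fails to increase."""
    y_strict = np.asarray(y, dtype=float).copy()
    for i in range(1, len(y_strict)):
        # Keep the original value if it already exceeds its (possibly nudged)
        # predecessor; otherwise place it eps above that predecessor.
        y_strict[i] = max(y_strict[i], y_strict[i - 1] + eps)
    return y_strict


y_thr = np.array([0.0, 0.0, 0.00980392156862745, 0.00980392156862745,
                  0.017804154302670624, 0.017804154302670624,
                  0.024193548387096774, 0.024193548387096774,
                  0.027548209366391185, 0.027548209366391185])

y_new = strictly_increasing(y_thr)
print(np.all(np.diff(y_new) > 0))     # True: no flat segment left
print(np.max(np.abs(y_new - y_thr)))  # size of the largest nudge
```

Because values are only nudged upward, a flat segment at 1.0 would end slightly above 1. A real option would need to clip or nudge downward at that edge.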

@ogrisel ogrisel changed the title from "Isotonic calibration changes test metrics values" to "Isotonic calibration changes rank-based test metrics values" Feb 3, 2020
Member

@ogrisel ogrisel commented Feb 3, 2020

Rank-based metrics that should not be impacted by isotonic calibration, but are:

  • ROC AUC
  • Average Precision
  • NDCG...
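A toy illustration with hypothetical labels and scores. Collapsing a correctly ordered negative/positive pair (0.1 and 0.2) into a single value lowers both average precision and NDCG, just as it does for ROC AUC. `ndcg_score` expects 2D arrays (one row per query), hence the `[None, :]` reshaping.

```python
import numpy as np
from sklearn.metrics import average_precision_score, ndcg_score

y_true = np.array([0, 1, 0, 1])
raw = np.array([0.1, 0.2, 0.3, 0.4])  # hypothetical distinct scores
# Non-decreasing step map: 0.1 (negative) and 0.2 (positive) become tied.
tied = np.where(raw <= 0.2, 0.0, np.where(raw <= 0.3, 0.5, 1.0))

ap_raw = average_precision_score(y_true, raw)
ap_tied = average_precision_score(y_true, tied)
# ndcg_score averages gains over tied scores by default (ignore_ties=False).
ndcg_raw = ndcg_score(y_true[None, :], raw[None, :])
ndcg_tied = ndcg_score(y_true[None, :], tied[None, :])

print(ap_raw, ap_tied)      # 0.8333... vs 0.75
print(ndcg_raw, ndcg_tied)  # the tied scores give a lower NDCG
```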