Isotonic calibration changes rank-based test metrics values #16321
Thank you very much @dsleo for the detailed report. I started from your code and pushed the analysis further. If one introspects the calibrator, one can observe the following:

```python
y_pred_calib = isotonic.predict_proba(X_test)
print(roc_auc_score(y_test, y_pred_calib[:, 1]))

calibrator = isotonic.calibrated_classifiers_[0].calibrators_[0]

import matplotlib.pyplot as plt

plt.figure(figsize=(16, 10))
plt.plot(calibrator._necessary_X_, calibrator._necessary_y_)
```

Note that this kind of plot is very interesting, and there is another related PR (#16289) to make it possible to expose those thresholds for inspection purposes. It is a bit weird to have a piece-wise constant function for this mapping.

Furthermore, a close look at the (x, y) pairs of thresholds (for the first 10 pairs):

```python
for x, y in zip(calibrator._necessary_X_[:10],
                calibrator._necessary_y_[:10]):
    print(x, y)
```

The y values are monotonic (but not strictly). I don't know why we want to use such a piece-wise constant mapping. One could instead use a piece-wise linear mapping that would be trivially strictly monotonic. Here is a quick hack to show that this would trivially fix the issue when using isotonic regression for classifier calibration:

```python
X_thresholds_fixed = np.concatenate((calibrator._necessary_X_[::2],
                                     calibrator._necessary_X_[-1:]))
y_thresholds_fixed = np.concatenate((calibrator._necessary_y_[::2],
                                     calibrator._necessary_y_[-1:]))

plt.figure(figsize=(16, 10))
plt.plot(calibrator._necessary_X_, calibrator._necessary_y_)
plt.plot(X_thresholds_fixed, y_thresholds_fixed)
```

Linearly interpolating using those thresholds makes the mapping strictly monotonic, and the ROC AUC of the original model is recovered:

```python
calibrator._build_f(X_thresholds_fixed, y_thresholds_fixed)

y_pred_calib = isotonic.predict_proba(X_test)
print(roc_auc_score(y_test, y_pred[:, 1]))
print(roc_auc_score(y_test, y_pred_calib[:, 1]))
```

Instead of the max value of each step, one could have used mid-points. This could be explored in a follow-up fix.
The above analysis is wrong: I made wrong assumptions about the structure of the steps. They are not always exact steps; there are already piece-wise linear components in the default prediction function. Still, it would be nice to have an option to enforce strict monotonicity, maybe by adding a small eps on one of the edges whenever y is constant on a segment.
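To make that eps idea concrete, here is a minimal sketch operating on an array of y thresholds (the function name and the default eps value are made up for illustration; this is not a proposed scikit-learn API):

```python
import numpy as np

def make_strictly_increasing(y_thr, eps=1e-8):
    """Nudge the y thresholds of an isotonic fit so that each value is
    strictly larger than the previous one. The cumulative eps offsets break
    exact isotonicity by a negligible amount (and may push the last values
    slightly above 1 for probabilities), but they restore strict ranking.
    """
    y_thr = np.asarray(y_thr, dtype=float).copy()
    for i in range(1, len(y_thr)):
        if y_thr[i] <= y_thr[i - 1]:
            y_thr[i] = y_thr[i - 1] + eps
    return y_thr
```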
Metrics that should not be impacted by isotonic calibration but are:
This would solve the problem, but we would have to make clear that we are not strictly doing isotonic regression. I had a look at what R does. While the core stats package and a number of external packages implement isotonic regression, none seem to deal with this issue. There are also few mentions of this in the literature (that I could find). 'Predicting accurate probabilities with a ranking loss' mentions:
This would be a simple way to ensure rank stays consistent, and we would still strictly be using isotonic regression, but the API would be difficult to work out. The only other paper I found that touches on this is 'Smooth Isotonic Regression: A New Method to Calibrate Predictive Models'. Their method would solve the rank problem, but I think it is too complex for our purposes. Summary of their method:
Should we just add a note in the UG, something like:

Slightly related, it would be interesting to discuss in the UG when one might prefer sigmoid vs isotonic. I personally have no idea.
I am happy to add this to the doc. This was explained in some of the papers I looked at.
Please ping me on the PR :)
I am curious if this can be worked into our
Sigmoid calibration is parametric and strongly biased (it assumes the calibration curve of the un-calibrated model has a sigmoid shape), while isotonic calibration is non-parametric and does not make this assumption. In general, non-parametric calibration will tend to overfit on datasets with a small number of samples, while parametric sigmoid calibration can underfit if the assumption is wrong. TL;DR: sigmoid on small datasets, isotonic on large datasets.
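As a rough, hedged illustration of this trade-off (dataset, estimator, and sizes below are arbitrary choices of mine, not from this thread), one could compare both methods of CalibratedClassifierCV on small and large training/calibration data and look at the Brier score:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import brier_score_loss

X, y = make_classification(n_samples=60_000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=10_000, random_state=0)

# Compare both methods when the data available for fitting + calibration is
# small vs. large; the expectation is that sigmoid is safer on the small
# subset while isotonic catches up (or wins) on the large one.
for n in (500, 50_000):
    for method in ("sigmoid", "isotonic"):
        clf = CalibratedClassifierCV(GaussianNB(), method=method, cv=5)
        clf.fit(X_train[:n], y_train[:n])
        proba = clf.predict_proba(X_test)[:, 1]
        print(n, method, round(brier_score_loss(y_test, proba), 4))
```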
Not all applications of isotonic regression are for probabilistic calibration of classifiers. For the calibration use case it might make sense to further impose strict monotonicity. This could be a non-default option for the IsotonicRegression estimator.
Maybe @dsleo has suggestions for how to deal with this issue :)
Thanks @lucyleeow for the references and the upcoming additions to the doc! The solution of the second reference seems a bit expensive. Regarding the first reference, this is not clear to me:

Wouldn't that break monotonicity by substituting calibrated probabilities with their original predicted probabilities?

We can extend @ogrisel's hack with another one, by adding linear interpolation on all constant sub-arrays of the calibrator's y thresholds (`_necessary_y_`). Here's a draft (reusing the `calibrator` object from the snippets above):

```python
import numpy as np

def interpolate(min_v, max_v, length):
    """Return `length` evenly spaced offsets from 0 (included) to
    max_v - min_v (excluded), used to spread out a constant segment."""
    delta = max_v - min_v
    # linspace avoids the off-by-one that np.arange can produce with a
    # floating point step
    return np.linspace(0, delta, num=length, endpoint=False)

necessary_y = calibrator._necessary_y_
necessary_y_fixed = necessary_y.copy()

# Split the thresholds into maximal constant sub-arrays. np.split returns
# views, so modifying each sub-array in place updates necessary_y_fixed.
sub_ = np.split(necessary_y_fixed,
                np.nonzero(np.diff(necessary_y_fixed))[0] + 1)
n_splits = len(sub_)

for i in range(n_splits - 1):
    sub_length = len(sub_[i])
    if sub_length > 1:
        min_v = sub_[i][0]
        max_v = sub_[i + 1][0]
        correction = interpolate(min_v=min_v, max_v=max_v, length=sub_length)
        sub_[i] += correction
```

And we can check, as a sanity check, that

```python
calibrator._build_f(calibrator._necessary_X_, necessary_y_fixed)

y_pred_calib = isotonic.predict_proba(X_test)
print(roc_auc_score(y_test, y_pred[:, 1]))
print(roc_auc_score(y_test, y_pred_calib[:, 1]))
```

indeed gives back the ROC AUC of the original model.

Does that seem a reasonable enough strategy? It would be nice to have something of that sort as an optional post-processing step of the isotonic calibrator.
Thanks @dsleo! I would be in favor of exploring that solution in a PR.
Having read through the issue, I recommend adding a note to the UG as suggested by @NicolasHug in #16321 (comment), and closing this issue without further action. I think that it is in fact a non-issue. Isotonic regression does well by predicting constant values for monotonicity-violating points. It has very desirable properties which make it a very good choice for recalibration of a classifier.

If the goal is to (auto-)calibrate a classifier so that afterwards we have calibrated probabilities: it is impossible to get a monotonic (in y) regression without changing the order of the predicted probabilities. Isotonic regression does this minimally, in the sense that it sets some of them to the same value but does not reorder them further. This part could go into the UG.

Furthermore, the given example with logistic regression does not make sense to me. LR already uses a logistic function, so recalibrating it with a logistic function is like a no-op (apart, maybe, from using a different data set for this step). That's the reason why the score/metric does not change. In addition, LR is usually already quite well calibrated due to the balance property (a consequence of the canonical link function for GLMs). The way to improve it is then feature engineering, not recalibration. The example would be more convincing for an SVM, which does not give probability predictions out of the box and really needs recalibration.

At last, if the goal is to use the model for ranking (only), as suggested by looking at AUC, then there is no need for recalibration.
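Building on that last remark about SVMs, here is a small sketch (my own choices of dataset and estimator, only to mirror the point above) of calibrating a classifier that has no predict_proba of its own:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import brier_score_loss

X, y = make_classification(n_samples=20_000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# LinearSVC has no predict_proba: wrapping it in CalibratedClassifierCV is
# what provides probability estimates in the first place, so some form of
# recalibration is not optional here.
svm_calibrated = CalibratedClassifierCV(LinearSVC(dual=False),
                                        method="isotonic", cv=5)
svm_calibrated.fit(X_train, y_train)
proba = svm_calibrated.predict_proba(X_test)[:, 1]
print(brier_score_loss(y_test, proba))
```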
I suppose that there is also a problem when isotonic calibration learns flat segments on smallish calibration sets (even with an originally non-diagonal reliability curve): that would yield collapsed predictions on a final held-out test set, which would cause a degradation in the resolution measured by ROC AUC. I think it's good to document this problem with vanilla isotonic calibration, but that should not prevent us from studying alternatives that do not have this problem. For instance, Centered Isotonic Regression could be considered as a slightly regularized variant of isotonic regression that would probably perform slightly better than vanilla isotonic calibration when the calibration set has a limited size, both in terms of ECE^2 and full Brier score (or NLL), thanks to its ability to introduce less resolution degradation.
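For reference, a minimal sketch of that centered-isotonic idea could look like the following (the function name and the handling of the end points are simplifications of mine, not the reference CIR implementation nor an existing scikit-learn API):

```python
import numpy as np
from sklearn.isotonic import IsotonicRegression

def centered_isotonic_fit(x, y):
    """Rough sketch of the Centered Isotonic Regression idea: fit plain
    isotonic regression, collapse each flat segment to a single point
    located at the mean x of the segment, and interpolate linearly
    between those points. Illustration only, not a reference
    implementation.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    order = np.argsort(x)
    x, y = x[order], y[order]

    y_iso = IsotonicRegression(out_of_bounds="clip").fit(x, y).predict(x)

    # Collect one (mean x, fitted y) knot per constant block of the fit.
    x_knots, y_knots = [], []
    start = 0
    for i in range(1, len(y_iso) + 1):
        if i == len(y_iso) or y_iso[i] != y_iso[start]:
            x_knots.append(x[start:i].mean())
            y_knots.append(y_iso[start])
            start = i

    # np.interp clips to the first/last knot outside their range.
    return lambda x_new: np.interp(x_new, x_knots, y_knots)
```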
I think it is better to calibrate on a large training set than on a (too) small additional set. I also think it is good to teach a user that her/his model performance measures have an uncertainty, and the potential drop in AUC with isotonic calibration is likely within this uncertainty. On top of this, using log loss or the Brier score is even better for measuring model performance :smirk: After all, they are sensitive to (mis-)calibration while AUC is not. My goal is to close this issue with #25900 so that further improvements can be discussed in new issues.
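For completeness, assuming the y_pred / y_pred_calib arrays from the snippets earlier in the thread, looking at all three metrics side by side would be something like (sketch only):

```python
from sklearn.metrics import brier_score_loss, log_loss, roc_auc_score

for name, proba in [("uncalibrated", y_pred[:, 1]),
                    ("isotonic", y_pred_calib[:, 1])]:
    print(name,
          "ROC AUC:", roc_auc_score(y_test, proba),
          "log loss:", log_loss(y_test, proba),
          "Brier:", brier_score_loss(y_test, proba))
```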
In many settings, data is limited, and in those cases some form of regularization is generally the best thing we can practically do to improve the modeling (although no silver bullet either). As I just commented in #21454 (comment), I think CIR is a valid, practical solution to the problem originally reported by @dsleo. But I agree it needs empirical validation to back those claims. +1 for leaving this issue open in the meantime (while merging #25900 ASAP as a net improvement).
[As discussed with @ogrisel]
Describe the bug
Using isotonic calibration changes rank-based metric values. This is because the isotonic mapping is monotonic but not strictly monotonic. Sigmoid calibration, being strictly monotonic, doesn't suffer from this.
Steps/Code to Reproduce
Here is a quick example where we split the data into three train/calibration/test sets and compare the ROC AUC on the test set before and after calibrating, for both isotonic and sigmoid calibration.
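The original snippet is not preserved in this extract; a rough sketch of the described setup (dataset, estimator choice, and split sizes are assumptions of mine, while the numbers quoted below are from the original report) could be:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import roc_auc_score

X, y = make_classification(n_samples=30_000, n_features=20, random_state=42)
X_train, X_rest, y_train, y_rest = train_test_split(
    X, y, test_size=0.5, random_state=42)
X_calib, X_test, y_calib, y_test = train_test_split(
    X_rest, y_rest, test_size=0.5, random_state=42)

clf = LogisticRegression().fit(X_train, y_train)
y_pred = clf.predict_proba(X_test)
print("uncalibrated:", roc_auc_score(y_test, y_pred[:, 1]))

# cv="prefit" reuses the already fitted classifier and only fits the
# calibration mapping on the held-out calibration set.
isotonic = CalibratedClassifierCV(clf, method="isotonic", cv="prefit")
isotonic.fit(X_calib, y_calib)
y_pred_calib = isotonic.predict_proba(X_test)
print("isotonic:", roc_auc_score(y_test, y_pred_calib[:, 1]))

sigmoid = CalibratedClassifierCV(clf, method="sigmoid", cv="prefit")
sigmoid.fit(X_calib, y_calib)
print("sigmoid:", roc_auc_score(y_test, sigmoid.predict_proba(X_test)[:, 1]))
```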
The ROC AUC is then: 0.88368
After isotonic calibration, the ROC AUC becomes: 0.88338
As expected, for sigmoid calibration the ROC AUC is unchanged: 0.88368
Versions
System:
python: 3.8.1 | packaged by conda-forge | (default, Jan 29 2020, 15:06:10) [Clang 9.0.1 ]
executable: /Users/leodreyfusschmidt/opt/miniconda2/envs/isotonic/bin/python
machine: macOS-10.13.4-x86_64-i386-64bit
Python dependencies:
pip: 20.0.2
setuptools: 45.1.0.post20200119
sklearn: 0.22.1
numpy: 1.16.5
scipy: 1.4.1
Cython: None
pandas: None
matplotlib: 3.1.2
joblib: 0.14.1
Built with OpenMP: True