\section{\Ourcal{}}
\label{sec:calibrating_models}
Section~\ref{sec:challenges-measuring} shows that the drawback of scaling methods is that we cannot estimate their calibration error. Their advantage is sample efficiency: if the function family\pl{assume $O(1)$ parameters} contains at least one function that can achieve calibration error $\epsilon$, scaling methods require $O(1/\epsilon^2)$ samples to reach calibration error $\epsilon$, whereas histogram binning requires $O(B/\epsilon^2)$ samples. Can we devise a method that is sample-efficient to calibrate and whose calibration error can also be estimated?\pl{at this point, maybe skip the rhetorical question and streamline into: we propose \ourcal{} which is both ...}
To achieve this, we propose \ourcal{} (Figure~\ref{fig:var_red_binning}), in which we first fit a scaling function and then bin the outputs of the scaling function.
% Note that in previous work binning the outputs of a scaling function was used for evaluation whereas here it is used for the method itself.
\subsection{Algorithm}
We split the recalibration data $T$ of size $n$ into 3 sets: $T_1$, $T_2$, $T_3$. \Ourcal{}, illustrated in Figure~\ref{fig:variance_reduced_illustration}, outputs $\hat{g_{\mathcal{B}}}$ such that $\hat{g_{\mathcal{B}}} \circ f$ has low calibration error:
\textbf{Step 1 (Function fitting):} Select $g = \argmin_{g' \in \mathcal{G}} \sum_{(z, y) \in T_1} (y - g'(z))^2$.
\pl{$g$ is used in both inside argmin and out; should be $\hat g$ outside?}
\pl{say this is Platt scaling}
\textbf{Step 2 (Binning scheme construction):} We choose the bins so that an equal number of $g(z_i)$ in $T_2$\pl{make this part more precise} land in each bin $I_j$ for each $j \in \{1, \dots, B\}$. This uniform-mass binning scheme, as opposed to equal-width binning~\cite{guo2017calibration}, is essential for being able to estimate the calibration error in Section~\ref{sec:verifying_calibration}.
\pl{should cite the papers that do "our scheme"}
\textbf{Step 3 (Discretization):} Discretize $g$ by outputting the average $g$ value in each bin; these are the gray circles in Figure~\ref{fig:var_red_binning}. Let $\mu(S) = \frac{1}{|S|} \sum_{s \in S} s$ denote the mean of a set of values $S$.
Let $\hat{\mu}[j] = \mu(\{ g(z_i) \; | \; g(z_i) \in I_j \wedge (z_i, y_i) \in T_3 \})$ be the mean of the $g(z_i)$ values that landed in the $j$-th bin.
Recall that if $z \in I_j$, \pl{use $\mathcal B$ as discussed before; this notation should appear in step 2 to connect things} $\beta(z) = j$ is the index of the interval that $z$ lands in.
Then we set $\hat{g_{\mathcal{B}}}(z) = \hat{\mu}[\beta(g(z))]$; that is, we simply output the mean value in the bin that $g(z)$ falls in.
\pl{why is $g$ not hatted and $\hat{g_\mathcal{B}}$ hatted? make consistent}
\ak{one reason is that I use $g_{\mathcal{B}}$ below. This is the binned version of $g$, assuming infinite data for binning. $\hat{g_\mathcal{B}}$ is the empirically binned version of $g$. The hat refers to the empirical binning. What do you think?}
\pl{sure, but then should be $\hat g$ for all of the above here}
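As a small illustration of Steps 2 and 3, consider the following worked example. All numbers are hypothetical and chosen only to keep the arithmetic simple, and placing the bin boundary at a midpoint is an assumption, since any boundary strictly between $0.2$ and $0.7$ gives the same equal-mass split. Take $B = 2$ and suppose the fitted $g$ takes the values $\{0.1, 0.2, 0.7, 0.9\}$ on $T_2$. Putting two values in each bin gives $I_1 = [0, 0.45]$ and $I_2 = (0.45, 1]$. If $g$ takes the values $\{0.15, 0.25, 0.8\}$ on $T_3$, then
\[ \hat{\mu}[1] = \frac{0.15 + 0.25}{2} = 0.2, \qquad \hat{\mu}[2] = \frac{0.8}{1} = 0.8. \]
A new input $z$ with $g(z) = 0.3$ then has $\beta(g(z)) = 1$ and receives the output $\hat{g_{\mathcal{B}}}(z) = \hat{\mu}[1] = 0.2$. Note that the labels $y_i$ in $T_3$ play no role in this step: only the $g(z_i)$ values are averaged.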
\subsection{Analysis}
We now show that \ourcal{} requires $O(B + 1/\epsilon^2)$ samples to calibrate, and in Section~\ref{sec:verifying_calibration} we show that we can efficiently measure its calibration error. For the main theorem, we make some standard regularity assumptions on $\mathcal{G}$ which we formalize in Appendix~\ref{sec:calibrating_models_appendix}. Our result is a generalization result---we show that if $\mathcal{G}$ contains some $g^*$ with low calibration error, then our method is \emph{at least} almost as well-calibrated as $g^*$ given sufficiently many samples.
\newcommand{\finalCalibText}{
Assume regularity conditions on $\mathcal{G}$ (finite parameters, injectivity, Lipschitz-continuity, consistency, twice differentiability). Given $\delta \in (0, 1)$, there is a constant $c$ such that \emph{for all} $B, \epsilon > 0$, with $n \geq c\Big(B\log{B} + \frac{\log{B}}{\epsilon^2}\Big)$ samples, \ourcal{} finds $\hat{g}_{\mathcal{B}}$ with $(\ce(\hat{g}_{\mathcal{B}}))^2 \leq 2\min_{g \in \mathcal{G}}(\ce(g))^2 + \epsilon^2$, with probability $\geq 1 - \delta$.
}
\begin{theorem}[Calibration bound]
\label{thm:final-calib}
\finalCalibText{}
\end{theorem}
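The bound in Theorem~\ref{thm:final-calib} is stated for squared calibration errors, but it can also be read in terms of the unsquared calibration error. As a short worked consequence, using only $\sqrt{a + b} \leq \sqrt{a} + \sqrt{b}$ for $a, b \geq 0$, on the same high-probability event we have
\[ \ce(\hat{g}_{\mathcal{B}}) \leq \sqrt{2\min_{g \in \mathcal{G}}(\ce(g))^2 + \epsilon^2} \leq \sqrt{2}\,\min_{g \in \mathcal{G}}\ce(g) + \epsilon. \]
In particular, if $\mathcal{G}$ contains a perfectly calibrated function, so that $\min_{g \in \mathcal{G}}\ce(g) = 0$, the guarantee reduces to $\ce(\hat{g}_{\mathcal{B}}) \leq \epsilon$.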
\pl{hope the $L_2$ notation can just be removed}
\newtheorem*{finalCalib}{Restatement of Theorem~\ref{thm:final-calib}}
% and potentially worsen the downstream performance of the model. As such, practitioners should use \ourcal{} when it is important to ensure that their model is well-calibrated, but in other applications should also try directly using a scaling method.
Note that our method can potentially be better calibrated than $g^*$, because we bin the outputs of the scaling function, which reduces its calibration error (Proposition~\ref{prop:bin_low_bound}). While binning worsens the sharpness and can increase the mean-squared error of the model, in Proposition~\ref{prop:mse-finite-binning} we show that if we use many bins, binning the outputs cannot increase the mean-squared error by much.
We prove Theorem~\ref{thm:final-calib} in Appendix~\ref{sec:calibrating_models_appendix} but give a sketch here. Step 1 of our algorithm is Platt scaling, which simply fits a function $g$ to the data---standard results in asymptotic statistics show that $g$ converges in $O(\frac{1}{\epsilon^2})$ samples.
% As usual with asymptotic analysis, the constant $c$ can depend arbitrarily on $\delta$ and $\mathcal{G}$ \pldel{---making these dependencies explicit is an important direction for future research.} \pl{I don't think it's that important}
% and only shows that if $\mathcal{G}$ is well-chosen, for example if $\min_{g \in \mathcal{G}}\squaredce(g) = 0$, then for large enough $n$ \ourcal{} will have lower calibration error than histogram binning.
\pl{make it clear whether uniform-mass bins are needed for this theorem}
Step 3, where we bin the outputs of $g$, is the main step of the algorithm. If we had infinite data, Proposition~\ref{prop:bin_low_bound} shows that the binned version $g_{\bins{}}$ has lower calibration error than $g$, so we would be done. However, we do not have infinite data. The core of our proof is to show that the empirically binned $\hat{g_{\bins{}}}$ is within $\epsilon$ of $g_{\bins{}}$ in $O(B + \frac{1}{\epsilon^2})$ samples, instead of the $O(B + \frac{B}{\epsilon^2})$ samples required by histogram binning. The intuition is illustrated in Figure~\ref{fig:variance_reduced_illustration}: the $g(z_i)$ values in each bin (gray circles in Figure~\ref{fig:var_red_binning}) lie in a narrower range than the $y_i$ values (black crosses in Figure~\ref{fig:hist_binning}) and thus have lower variance, so averaging them incurs less estimation error. The perhaps surprising part is that we are estimating $B$ numbers with $\widetilde{O}(1/\epsilon^2)$ samples. In fact, there may be a small number of bins where the $g(z_i)$ values are not in a narrow range, but our proof still shows that the overall estimation error is small.
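To get a rough feel for the gap between these two rates, one can plug in illustrative numbers while ignoring constants and logarithmic factors. The values $B = 100$ and $\epsilon = 0.1$ below are hypothetical, and actual sample requirements depend on the hidden constants:
\[ B + \frac{1}{\epsilon^2} = 100 + 100 = 200, \qquad B + \frac{B}{\epsilon^2} = 100 + 100 \cdot 100 = 10{,}100. \]
The first rate pays for the number of bins and the target accuracy additively, whereas the second pays for their product.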
Our uniform-mass binning scheme allows us to estimate the calibration error efficiently (see Section~\ref{sec:verifying_calibration}), unlike for scaling methods where we cannot estimate the calibration error (Section~\ref{sec:challenges-measuring}).
Recall that we chose our bins so that each bin has an equal proportion of points in the recalibration set.
Lemma~\ref{lem:well-balanced} shows that this property approximately holds in the population as well.
This allows us to estimate the calibration error efficiently (Theorem~\ref{thm:final-ours}).
\begin{definition}[Well-balanced binning]
Given a binning scheme $\mathcal{B}$ of size $B$ and $\alpha \geq 1$, we say $\mathcal{B}$ is $\alpha$-well-balanced if for all $j$,
\[ \frac{1}{\alpha B} \leq \prob(Z \in I_j) \leq \frac{\alpha}{B}. \]
\end{definition}
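As a quick numerical reading of this definition (the value $B = 10$ is chosen purely for illustration), with $B = 10$ and $\alpha = 2$ the condition becomes
\[ \frac{1}{20} = 0.05 \leq \prob(Z \in I_j) \leq \frac{2}{10} = 0.2 \quad \text{for all } j, \]
so every bin carries between half and twice the mass $1/B = 0.1$ of a perfectly balanced scheme. With $\alpha = 1$ the two sides coincide and every bin must have mass exactly $1/B$.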
\newcommand{\wellBalancedText}{
There is a universal constant $c$ such that if $n \geq cB\log{\frac{B}{\delta}}$, then with probability at least $1 - \delta$, the binning scheme $\bins{}$ we chose is 2-well-balanced.
}
\begin{lemma}
\label{lem:well-balanced}
\wellBalancedText{}
\end{lemma}
\newtheorem*{wellBalanced}{Restatement of Lemma~\ref{lem:well-balanced}}
While the way we choose bins is not novel~\cite{zadrozny2001calibrated}, we believe the guarantees around it are---not all binning schemes in the literature allow us to efficiently estimate the calibration error; for example, the binning scheme in~\cite{guo2017calibration} does not. Our proof of Lemma~\ref{lem:well-balanced} is in Appendix~\ref{sec:calibrating_models_appendix}. We use a discretization argument to prove the result---this gives a tighter bound than applying Chernoff bounds or a standard VC dimension argument which would tell us we need $O(B^2\log{\frac{B}{\delta}})$ samples.
\subsection{Experiments}
Our experiments on CIFAR-10 and ImageNet show that in the low-data regime, for example when we use $\leq 1000$ data points to recalibrate, \ourcal{} produces models with much lower calibration error than histogram binning. The uncalibrated model outputs a confidence score associated with each class. We recalibrated each class separately as in~\cite{zadrozny2002transforming}, using $B$ bins per class, and evaluated calibration using the marginal calibration error (Definition~\ref{dfn:marginal-ce}).
We describe our experimental protocol for CIFAR-10.
The CIFAR-10 validation set has 10,000 data points. We sampled, with replacement, a recalibration set of 1,000 points. We ran either \ourcal{} (fitting a sigmoid in the function fitting step) or histogram binning and measured the marginal calibration error on the entire set of 10,000 points.
% \footnote{This is equivalent to using the empirical distribution on the 10K validation points as the true data distribution, and comparing how these methods perform. We do this since we cannot measure the ground truth calibration error.}
We repeated this entire procedure 100 times and computed the mean and 90\% confidence intervals, and we did so for several values of the number of bins $B$. Figure~\ref{fig:marginal_calibrator_comparison_cifar} shows that \ourcal{} produces models with lower calibration error, for example $35\%$ lower calibration error when we use 100 bins per class.
\pl{what method do we use to estimate calibration error, plugin? should say this and say that we could use a fancier method (Section 5)}
Using more bins allows a model to produce more fine-grained predictions, e.g.~\cite{brocker2012empirical} use $B = 51$ bins, which improves the quality of predictions as measured by the mean-squared error---Figure~\ref{fig:cifar_calibrator_cmp_mse_ce} shows that our method achieves better mean-squared errors for any given calibration constraint.
% \footnote{By the decomposition of the mean-squared error, this means our method achieves better `sharpness' as well.}
More concretely, the figure shows a scatter plot of the mean-squared error and squared calibration error for histogram binning and \ourcal{} as we vary the number of bins. For example, if we want our models to have a calibration error $\leq 0.02 = 2\%$, \ourcal{} gives a $9\%$ lower mean-squared error than histogram binning. In Appendix~\ref{sec:calibrating_models_appendix_experiments} we show that we get \emph{$5\times$ lower top-label calibration error on ImageNet}, and give further experiment details.
% We also run an ablation to show that if we have more data points then as the theory predicts histogram binning catches up to the variance-reduced calibrator, and that the way we construct bins is crucial, the result do not hold if we use the binning scheme in~\cite{guo2017calibration}.
\pl{bit confused - this paragraph seems to be mostly about using more bins to lower MSE, but then we're talking about lower calibration error again;
based on the Fig 3b, what I gathered is that simply because calibration error is lower for scaling-binning, for the sample calibration error,
you can use more bins, which corresponds to a lower MSE
}
\textbf{Validating theoretical bounds}: In Appendix~\ref{sec:calibrating_models_appendix_experiments} we run synthetic experiments to validate the bound in Theorem~\ref{thm:final-calib}. In particular, we show that if we fix the number of samples $n$ and vary the number of bins $B$, the squared calibration error of \ourcal{} is nearly constant, whereas that of histogram binning increases nearly linearly with $B$. For both methods, the squared calibration error decreases approximately as $1/n$; that is, when we double the number of samples, the squared calibration error halves.
% We ran synthetic experiments to validate the bound in Theorem~\ref{thm:final-calib}. Our theory predicts that to get calibration error $\epsilon$, $n \gtrsim 1/\epsilon^2 + B$ for our method (if the model family is well-specified) but for histogram binning $n \gtrsim B/\epsilon^2$. We set the ground truth \pl{mathbb} $P(Y = 1 \mid Z=z) = g(z)$ where $g$ is from the Platt scaling family $\mathcal{G}$. In the first experiment, we fix $B$ and vary $n$, computing 90\% confidence intervals from 1000 trials---we see that $1/\epsilon^2$ is approximately linear in $n$ for both calibrators. For example, when $B=10$ if we increase from $n=1000$ to $n=2000$ the $\squaredce$ of histogram binning decreases by $2.00 \pm 0.06$ times, and the $\squaredce$ of our method decreases by $1.98 \pm 0.09$ times. In the second experiment, we fix $n$ and vary $B$---as predicted by the theory, for \ourcal{} $1/\epsilon^2$ is nearly constant, but for histogram binning $1/\epsilon^2$ scales close to $1/B$. When $n = 2000$ and we increase from $5$ to $20$ bins, our method's $\squaredce$ decreases by $2\% \pm 7\%$ but for histogram binning it increases by $3.71 \pm 0.15$ \emph{times} \pl{to make it easier to compare, say $371\%$}.
% We give experimental details and plots in Appendix~\ref{sec:calibrating_models_appendix_experiments}.
% \pl{if we don't have room, could put this in Appendix since it's more of a sanity check}
\begin{figure}
\centering
\begin{subfigure}[b]{0.54\textwidth}
\centering
\includegraphics[width=0.8\textwidth]{marginal_ces.png}
\caption{Effect of number of bins on squared calibration error. \pl{make vertical axis start at 0}}
\label{fig:marginal_calibrator_comparison_cifar}
\end{subfigure}
\hfill
\begin{subfigure}[b]{0.44\textwidth}
\centering
\includegraphics[width=\textwidth]{marginal_mse_vs_ces.png}
\caption{Tradeoff between calibration and MSE. \pl{get rid of title of plot since we already have caption}}
\label{fig:cifar_calibrator_cmp_mse_ce}
\end{subfigure}
\caption{
(\textbf{Left}) When recalibrating with 1,000 data points on CIFAR-10, \ourcal{} achieves lower squared calibration error than histogram binning, especially when the number of bins $B$ is large.
(\textbf{Right}) For a fixed calibration error, \ourcal{} allows us to use more bins. This results in models with more predictive power, as measured by the mean-squared error. Note that the vertical axis range is $[0.04, 0.08]$ to zoom into the relevant region.
\pl{axis labels say L2 squared... in text we say $L^2$ - this is confusing; make consistent}
}
\label{fig:nan2}
\end{figure}