# AM207 Final Report

## A Simple Baseline for Bayesian Uncertainty in Deep Learning

Group Members:

Yiwen Wang, Zihao Xu, Ruoxi Yang, Liyang Zhao

#### Problem statement - what is the problem the paper aims to solve?

Bayesian methods provide a natural probabilistic representation of uncertainty in deep learning, and had previously been a gold standard for inference with neural networks. However, existing approaches are often highly sensitive to hyperparameter choices and hard to scale to modern datasets and architectures, which limits their general applicability in modern deep learning. The paper aims to develop an algorithm that is convenient, efficient, and accurate, and that produces well-calibrated predictions across a broad range of computer vision tasks.

#### Context/scope - why is this problem important or interesting?

This problem is important because machine learning models are ultimately used to make decisions, and representing uncertainty is crucial for decision making. For example, in medical diagnosis and autonomous driving we want to protect against rare but costly mistakes. Deep learning models typically lack a representation of uncertainty and produce overconfident, miscalibrated predictions. This is therefore one of the key issues that keep current deep learning models from being more widely applicable in practice.

#### Existing work - what has been done in literature?

Several lines of existing work in the literature are relevant:

**Bayesian Methods**:
1. **Markov chain Monte Carlo (MCMC)**: Hamiltonian Monte Carlo (HMC), a widely used MCMC method for neural networks, requires full-batch gradients, which are computationally intractable for modern neural networks.

2. **Stochastic gradient HMC (SGHMC)**: SGHMC allows stochastic gradients to be used in Bayesian inference, which is crucial both for scalability and for exploring a space of solutions that generalize well. Theoretically, the method samples asymptotically from the posterior in the limit of infinitesimally small step sizes. In practice, finite learning rates introduce approximation errors, and tuning stochastic gradient MCMC methods can be quite difficult.

3. **Variational Inference (VI)**: VI fits a Gaussian variational approximation to the posterior over the weights of a neural network. While variational methods achieve strong performance for moderately sized networks, they are empirically noted to be difficult to train on larger architectures such as deep residual networks. Recent key advances in variational inference for deep learning typically focus on smaller-scale datasets and architectures.

4. **Dropout Variational Inference**: Dropout VI uses a spike-and-slab variational distribution to interpret dropout at test time as approximate variational Bayesian inference. Concrete dropout extends this idea by also optimizing the dropout probabilities. From a practical perspective, these approaches are appealing because they only require ensembling dropout predictions at test time, and they have been successfully applied to several downstream tasks.

5. **Laplace Approximations**: This method assumes a Gaussian posterior, $\mathcal{N}(\theta_{*}, I(\theta_{*})^{-1})$, where $\theta_{*}$ is a MAP estimate and $I(\theta_{*})^{-1}$ is the inverse of the Fisher information matrix (the expected Hessian of the negative log-likelihood, evaluated at $\theta_{*}$).
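To make the Laplace approximation concrete, here is a minimal NumPy sketch for Bayesian logistic regression, a toy setting chosen for illustration rather than one used in the paper. It finds the MAP estimate by Newton's method under an assumed isotropic Gaussian prior with precision `alpha`, then uses the inverse Hessian of the negative log-posterior as the covariance. For logistic regression the Hessian of the negative log-likelihood does not depend on the labels, so it coincides with the Fisher information; the prior simply adds `alpha` to the diagonal. The data and the function name `laplace_logistic` are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def laplace_logistic(X, y, alpha=1.0, n_iter=50):
    """Laplace approximation N(theta_map, H^{-1}) for logistic regression.

    Assumes an isotropic Gaussian prior N(0, alpha^{-1} I) on the weights.
    """
    d = X.shape[1]
    theta = np.zeros(d)
    for _ in range(n_iter):
        p = sigmoid(X @ theta)
        grad = X.T @ (p - y) + alpha * theta                         # gradient of negative log-posterior
        H = X.T @ (X * (p * (1 - p))[:, None]) + alpha * np.eye(d)  # Hessian (Fisher + prior precision)
        theta = theta - np.linalg.solve(H, grad)                     # Newton step
    p = sigmoid(X @ theta)
    H = X.T @ (X * (p * (1 - p))[:, None]) + alpha * np.eye(d)
    return theta, np.linalg.inv(H)

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
true_theta = np.array([1.5, -2.0])
y = (rng.uniform(size=200) < sigmoid(X @ true_theta)).astype(float)

theta_map, cov = laplace_logistic(X, y)
samples = rng.multivariate_normal(theta_map, cov, size=1000)  # approximate posterior draws
print(theta_map, np.sqrt(np.diag(cov)))
```

The expensive part at scale is exactly the $d \times d$ Hessian that is formed and inverted here, which is why Laplace approximations for large networks usually rely on further structural approximations.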

**SGD-Based Methods for Bayesian Deep Learning**

1. **Stochastic Gradient Descent (SGD)**: SGD is an iterative method for optimizing an objective function with suitable smoothness properties (e.g. differentiable or subdifferentiable). It can be regarded as a stochastic approximation of gradient descent, since it replaces the actual gradient (calculated from the entire dataset) with an estimate (calculated from a randomly selected subset of the data). Especially in high-dimensional optimization problems, this reduces the computational burden, achieving faster iterations in exchange for a lower convergence rate.

2. **Stochastic Weight Averaging (SWA)**: The main idea of SWA is to run SGD with a constant (or cyclical) learning rate schedule starting from a pre-trained solution, and to average the weights of the models it traverses. The intuition comes from the empirical observation that the solutions reached at the end of each learning rate cycle tend to accumulate at the border of low-loss regions of the loss surface. By averaging several such points, SWA moves toward the center of a wide, low-loss region, which tends to generalize better and can have even lower loss.
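The SWA recipe can be illustrated with a minimal NumPy sketch on a toy least-squares problem (an assumption made for illustration; SWA is normally applied to deep networks). Starting from a point near the optimum that stands in for a pre-trained solution, it runs minibatch SGD with a constant learning rate and folds every `c`-th iterate into a running average. The helper `sgd_with_swa` and all hyperparameter values are hypothetical.

```python
import numpy as np

def sgd_with_swa(X, y, theta0, lr=0.01, steps=2000, c=10, batch=16, seed=0):
    """Constant-learning-rate minibatch SGD on mean squared error, with SWA averaging."""
    rng = np.random.default_rng(seed)
    theta = theta0.copy()
    theta_swa, n_avg = theta0.copy(), 1          # running average starts at the pre-trained point
    for i in range(1, steps + 1):
        idx = rng.choice(len(y), size=batch, replace=False)
        grad = 2.0 * X[idx].T @ (X[idx] @ theta - y[idx]) / batch  # stochastic gradient
        theta = theta - lr * grad
        if i % c == 0:                            # fold in every c-th iterate
            theta_swa = (n_avg * theta_swa + theta) / (n_avg + 1)
            n_avg += 1
    return theta, theta_swa

rng = np.random.default_rng(1)
X = rng.normal(size=(500, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true + 0.5 * rng.normal(size=500)

w_ls = np.linalg.lstsq(X, y, rcond=None)[0]       # exact least-squares optimum
theta_last, theta_swa = sgd_with_swa(X, y, theta0=w_ls + 0.3)
print(np.linalg.norm(theta_last - w_ls), np.linalg.norm(theta_swa - w_ls))
```

The last SGD iterate keeps bouncing around the optimum because the learning rate never decays, while the running average smooths out that noise at the cost of one extra copy of the weights.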

#### Contribution - what is the gap in literature that the paper is trying to fill? What is the unique contribution?

Each of the previous Bayesian methods has drawbacks, as discussed above. In this paper, the authors observe that theoretical analysis of the stationary distribution of SGD iterates suggests that the SGD trajectory contains useful information about the geometry of the posterior. Specifically, they find that in the low-dimensional subspace spanned by SGD iterates, the posterior distribution is approximately Gaussian within a basin of attraction. Building on this, they propose a new algorithm, ``SWAG``, based on ``SWA``. ``SWAG`` additionally computes a low-rank plus diagonal approximation to the covariance of the iterates, which, together with the SWA mean, defines a Gaussian posterior approximation over neural network weights.

#### Technical content (high level) - what are the high level ideas behind their technical contribution

SWAG fits a Gaussian distribution whose first moment is the SWA solution and whose covariance is a low-rank plus diagonal matrix, also derived from the SGD iterates; this Gaussian serves as an approximate posterior over neural network weights. SWAG then samples from it to perform Bayesian model averaging. The authors find that the Gaussian fitted to the first two moments of the SGD iterates, obtained with a modified learning rate schedule, captures the local geometry of the posterior surprisingly well.

#### Technical content (details) - highlight the relevant details that are important to focus on (e.g. if there's a model, define it; if there is a theorem, state it and explain why it's important, etc).


We first define the following parameters:
$$
\begin{cases}
\theta_{0}: \text{pretrained weights}\\
\eta: \text{learning rate}\\
T: \text{number of steps}\\
c: \text{moment update frequency}\\
K: \text{maximum number of columns in deviation matrix}\\
S: \text{number of samples in Bayesian model averaging}
\end{cases}
$$

**Training SWAG**:

We first initialize $\bar{\theta} \leftarrow \theta_{0}$ and $\overline{\theta^{2}} \leftarrow \theta_{0}^{2}$ (squares are taken elementwise), and then perform $T$ SGD updates, i.e., for $i = 1, 2, \cdots, T$:
$$
\theta_{i} \leftarrow \theta_{i-1} - \eta \nabla_{\theta}\mathcal{L}(\theta_{i-1})
$$
where $\mathcal{L}$ is the loss function.

We also maintain a deviation matrix $D$, initially empty, whose columns are vectors of length $d = \text{len}(\theta)$ and which holds at most $K$ columns. Within the loop, the moments are updated every $c$ steps, i.e., whenever $i \bmod c = 0$.

When $i \bmod c = 0$, let $n = i/c$ denote the number of models already included in the running averages, and update the first and second moments as:
$$
\bar{\theta} \leftarrow \frac{n\bar{\theta} + \theta_{i}}{n+1}
$$

$$
\overline{\theta^{2}} \leftarrow \frac{n\overline{\theta^{2}} + \theta_{i}^{2}}{n+1}
$$

At each moment update, we also append the deviation $\theta_{i} - \bar{\theta}$ (computed with the updated mean) as a new column of $D$. If $D$ already has $K$ columns, we first remove its first (oldest) column, so that $D$ always holds the $K$ most recent deviations.

After the loop finishes, we set $\theta_{SWA} = \bar{\theta}$ and $\Sigma_{diag} = \text{diag}(\overline{\theta^{2}} - \bar{\theta}^{2})$.
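The training procedure above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: `grad_fn` is a hypothetical stand-in for a minibatch loss gradient (here, a noisy quadratic), `swag_train` is an illustrative name, and the diagonal variance is clipped at zero to guard against floating-point round-off.

```python
import numpy as np

def swag_train(grad_fn, theta0, lr, T, c, K):
    """Collect SWAG moments. Returns theta_swa, sigma_diag (a vector), and D (d x <=K)."""
    theta = theta0.copy()
    mean = theta0.copy()           # first moment, initialized at theta_0
    sq_mean = theta0 ** 2          # uncentered second moment (elementwise)
    cols = []                      # deviation columns, oldest first (needs T >= c)
    for i in range(1, T + 1):
        theta = theta - lr * grad_fn(theta)           # SGD step
        if i % c == 0:
            n = i // c                                # models already in the averages
            mean = (n * mean + theta) / (n + 1)
            sq_mean = (n * sq_mean + theta ** 2) / (n + 1)
            if len(cols) == K:                        # keep only the K most recent deviations
                cols.pop(0)
            cols.append(theta - mean)                 # deviation from the updated mean
    sigma_diag = np.maximum(sq_mean - mean ** 2, 0.0)
    D = np.stack(cols, axis=1)
    return mean, sigma_diag, D

rng = np.random.default_rng(0)
A = np.diag([1.0, 4.0, 9.0])                          # toy loss 0.5 * theta^T A theta
noisy_grad = lambda th: A @ th + 0.5 * rng.normal(size=3)  # stand-in for a minibatch gradient

theta_swa, sigma_diag, D = swag_train(noisy_grad, np.ones(3), lr=0.05, T=5000, c=5, K=20)
print(theta_swa, sigma_diag, D.shape)
```

Only two $d$-vectors and a $d \times K$ matrix are stored, which is what keeps SWAG's memory cost linear in the number of weights.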


**Test: Bayesian Model Averaging**

For $i = 1, 2, \cdots, S$, we first draw a weight sample:
$$
\tilde{\theta}_{i} \sim \mathcal{N}\left(\theta_{SWA}, \frac{1}{2}\Sigma_{diag} + \frac{DD^{\top}}{2(K-1)}\right)
$$
We then update the batch normalization statistics using the new sample $\tilde{\theta}_{i}$, and accumulate its contribution to the predictive distribution:
$$
p(y_{*}\mid \text{Data}) \mathrel{+}= \frac{1}{S}\,p(y_{*}\mid \tilde{\theta}_{i})
$$
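A minimal sketch of this test-time loop for a toy model: it draws $S$ weight samples from the SWAG Gaussian and averages the per-sample predictive probabilities. The "network" here is a hypothetical linear softmax classifier with no batch normalization, so the batch-norm update is only marked by a comment; for such a small $d$ the full covariance is formed explicitly and sampled with NumPy's `Generator.multivariate_normal`. The function names and toy inputs are illustrative assumptions.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)         # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def bma_predict(x, theta_swa, sigma_diag, D, S, n_classes, rng):
    """Average p(y | x, theta_s) over S draws theta_s ~ N(theta_swa, Sigma_SWAG)."""
    K = D.shape[1]
    cov = 0.5 * np.diag(sigma_diag) + D @ D.T / (2 * (K - 1))
    probs = np.zeros((x.shape[0], n_classes))
    for _ in range(S):
        theta = rng.multivariate_normal(theta_swa, cov)
        W = theta.reshape(x.shape[1], n_classes)  # hypothetical linear softmax model
        # A real network would recompute batch-norm statistics for theta here.
        probs += softmax(x @ W) / S
    return probs

rng = np.random.default_rng(0)
d_in, n_classes, K = 4, 3, 5
d = d_in * n_classes
theta_swa = rng.normal(size=d)
sigma_diag = 0.1 * np.ones(d)
D = 0.3 * rng.normal(size=(d, K))
x = rng.normal(size=(6, d_in))

p = bma_predict(x, theta_swa, sigma_diag, D, S=30, n_classes=n_classes, rng=rng)
print(p.sum(axis=1))
```

Averaging probabilities (not logits or weights) is what makes this a Bayesian model average: each row of `p` remains a valid distribution over classes.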

Note: Since forming the $d \times d$ matrix $\frac{DD^{\top}}{2(K-1)}$ is computationally expensive, we sample from $\mathcal{N}\left(\theta_{SWA}, \frac{1}{2}\Sigma_{diag} + \frac{DD^{\top}}{2(K-1)}\right)$ using the following identity instead:

$$
\tilde{\theta} = \theta_{SWA} + \frac{1}{\sqrt{2}}\Sigma^{\frac{1}{2}}_{diag}z_{1} + \frac{1}{\sqrt{2(K-1)}}Dz_{2}
$$
where $z_{1} \sim \mathcal{N}(0,I_{d})$ and $z_{2} \sim \mathcal{N}(0,I_{K})$ are independent. The two noise terms therefore have covariances $\frac{1}{2}\Sigma_{diag}$ and $\frac{DD^{\top}}{2(K-1)}$, which add up to the target covariance.
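The sampling identity can be checked numerically with the short sketch below. It draws samples using only elementwise products and a $d \times K$ matrix-vector product (cost $O(dK)$ per sample, never forming a $d \times d$ matrix), then compares the empirical covariance with $\frac{1}{2}\Sigma_{diag} + \frac{DD^{\top}}{2(K-1)}$ on a small random problem. The dimensions, sample count, and the name `sample_swag` are illustrative assumptions; `sigma_diag` holds the diagonal of $\Sigma_{diag}$ as a vector.

```python
import numpy as np

def sample_swag(theta_swa, sigma_diag, D, rng, n_samples=1):
    """theta = theta_swa + sqrt(sigma_diag) * z1 / sqrt(2) + D z2 / sqrt(2(K-1))."""
    d, K = D.shape
    z1 = rng.standard_normal((n_samples, d))
    z2 = rng.standard_normal((n_samples, K))
    return (theta_swa
            + np.sqrt(sigma_diag) * z1 / np.sqrt(2.0)    # diagonal part
            + z2 @ D.T / np.sqrt(2.0 * (K - 1)))          # low-rank part: each row is D z2

rng = np.random.default_rng(0)
d, K = 5, 4
theta_swa = rng.normal(size=d)
sigma_diag = rng.uniform(0.1, 1.0, size=d)
D = rng.normal(size=(d, K))

# The dense target is formed here only to verify the sampler on a tiny problem.
target = 0.5 * np.diag(sigma_diag) + D @ D.T / (2 * (K - 1))
samples = sample_swag(theta_swa, sigma_diag, D, rng, n_samples=200_000)
emp = np.cov(samples, rowvar=False)
print(np.abs(emp - target).max())
```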

**Output**

For training, we obtain $\theta_{SWA}$, $\Sigma_{diag}$, and $D$.

For testing, Bayesian model averaging yields an approximation of the posterior predictive distribution $p(y_{*}\mid \text{Data})$.

#### Experiments - which types of experiments were performed? What claims were these experiments trying to prove? Did the results prove the claims?

The following experiments provide a thorough empirical evaluation of SWAG. Compared with a range of baselines, the paper shows that SWAG performs strongly in terms of predictive accuracy and uncertainty estimates on image classification tasks, as well as on transfer learning and out-of-domain data detection.
 
1.	Accuracy & calibration of uncertainty:
 
To evaluate predictive accuracy and the quality of uncertainty estimates, the paper uses image classification datasets such as CIFAR-10, CIFAR-100, and ImageNet, and architectures such as VGG-16 and PreResNet-164. It reports negative log-likelihood (NLL), which reflects both accuracy and the quality of predictive uncertainty. To evaluate calibration specifically, the paper uses a variant of reliability diagrams that shows the difference between a method's confidence in its predictions and its accuracy. To produce this plot for a given method, the test data are split into 20 bins uniformly based on the method's confidence (maximum predicted probability); the accuracy and mean confidence of the method are then computed on the images in each bin, and the difference between confidence and accuracy is plotted. For a well-calibrated model, this difference should be close to zero for each bin.
 
Uncertainty calibration: The concept of calibration pertains to the agreement between predictions and the actual observed relative frequency. For example, for a binary classification task, if we were to inspect the samples that were estimated to be positive with p=0.85, we would expect that 85% of them are in fact positive.
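The reliability-diagram procedure and the NLL metric described above can be sketched as follows. This NumPy version uses 20 equal-width confidence bins over $[0, 1]$; reading "uniformly" as equal-width bins is an assumption, as are the function names and the synthetic predictions, which are calibrated by construction so the per-bin gaps should be near zero.

```python
import numpy as np

def reliability(probs, labels, n_bins=20):
    """Per-bin (mean confidence - accuracy), with bins on the max predicted probability."""
    conf = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_idx = np.clip(np.digitize(conf, edges[1:-1]), 0, n_bins - 1)
    gaps = np.full(n_bins, np.nan)                 # NaN marks empty bins
    for b in range(n_bins):
        mask = bin_idx == b
        if mask.any():
            gaps[b] = conf[mask].mean() - correct[mask].mean()
    return gaps

def nll(probs, labels):
    """Mean negative log-likelihood of the true labels (epsilon guards against log(0))."""
    return -np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12))

# Synthetic binary predictions: each label is drawn with exactly the predicted probability.
rng = np.random.default_rng(0)
p1 = rng.uniform(size=100_000)
probs = np.stack([1 - p1, p1], axis=1)
labels = (rng.uniform(size=p1.size) < p1).astype(int)

gaps = reliability(probs, labels)
print(np.round(gaps, 3), nll(probs, labels))
```

With two classes the maximum probability is at least 0.5, so the lower ten bins stay empty; positive gaps in a real model would indicate over-confidence.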
 
Result:
In terms of NLL, the paper shows that SWAG and SWAG-Diagonal perform comparably to or better than all the considered alternatives across datasets and architectures.
Regarding calibration, the paper shows reliability plots for all combinations of datasets and architectures. A perfectly calibrated network has zero difference between confidence and accuracy, and SWAG's curves are the closest to this horizontal line; from these plots, the authors conclude that SWAG is better calibrated than the alternatives.
 
2.	Comparison to ensembling SGD solutions:
 
The paper evaluates ensembles of independently trained SGD solutions (Deep Ensembles) with PreResNet-164 on CIFAR-100. Although an ensemble of 3 SGD solutions achieves high accuracy, its NLL is worse than that of a single SWAG solution; that is, while this ensemble is accurate, the SWAG solution is much better calibrated. An ensemble of 5 SGD solutions achieves NLL comparable to a single SWAG solution, which requires roughly 5 times less computation to train. These results indicate that a single SWAG model is competitive with ensembles of SGD solutions at a fraction of the training cost.
 
 
3.	Out-of-Domain Image Detection:
 
To evaluate SWAG on out-of-domain data detection, the paper trains a WideResNet on data from five classes of CIFAR-10 and then analyzes the predictions of SWAG variants and baselines on the full test set. We expect the predicted class probabilities for images from classes absent from the training data to have high entropy, reflecting the model's high uncertainty, and considerably lower entropy for images similar to those on which the network was trained.
 
The paper visualizes histograms of predictive entropies on the in-domain and out-of-domain classes and computes the symmetrized KL divergence between the binned in-domain and out-of-domain entropy distributions (larger is better); SWAG performs best on this measure.
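A sketch of this metric: compute the predictive entropy of each test image, histogram the in-domain and out-of-domain entropies on shared bins, and take the symmetrized KL divergence $\mathrm{KL}(P\|Q) + \mathrm{KL}(Q\|P)$ between the two histograms. The bin count, the sum (rather than average) of the two KL terms, the small smoothing constant for empty bins, and the synthetic Dirichlet-distributed predictions are all assumptions, since the paper's exact discretization is not described here.

```python
import numpy as np

def predictive_entropy(probs):
    """Entropy (in nats) of each row of a matrix of class probabilities."""
    return -np.sum(probs * np.log(probs + 1e-12), axis=1)

def symmetrized_kl(a, b, n_bins=20, eps=1e-6):
    """KL(P||Q) + KL(Q||P) between histograms of samples a and b on shared bins."""
    lo, hi = min(a.min(), b.min()), max(a.max(), b.max())
    edges = np.linspace(lo, hi, n_bins + 1)
    p = np.histogram(a, bins=edges)[0] + eps      # smoothing avoids log(0) in empty bins
    q = np.histogram(b, bins=edges)[0] + eps
    p, q = p / p.sum(), q / q.sum()
    return np.sum(p * np.log(p / q)) + np.sum(q * np.log(q / p))

rng = np.random.default_rng(0)
# Synthetic 5-class predictions: peaked for in-domain images, diffuse for out-of-domain ones
in_probs = rng.dirichlet(0.2 * np.ones(5), size=2000)
out_probs = rng.dirichlet(5.0 * np.ones(5), size=2000)

h_in, h_out = predictive_entropy(in_probs), predictive_entropy(out_probs)
print(h_in.mean(), h_out.mean(), symmetrized_kl(h_in, h_out))
```

A model that separates the two groups well pushes the entropy histograms apart, which increases the divergence.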
 
 
4. 	Language Modeling with LSTMs:
 
The paper also applies SWAG to an LSTM network for language modeling on the Penn Treebank and WikiText-2 datasets, and shows that SWAG outperforms the baseline models in terms of test and validation perplexity.
 
Perplexity is a metric used to judge how good a language model is. It is defined as the inverse probability of the test set, normalized by the number of words (i.e., the $N$-th root of the inverse probability of a test set of $N$ words). Intuitively, it can be viewed as a weighted branching factor: a perplexity of 100 means that whenever the model tries to guess the next word, it is as uncertain as if it had to choose uniformly among 100 words.
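By the chain rule, this definition can be written as $\mathrm{PPL} = p(w_1, \dots, w_N)^{-1/N} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log p(w_i \mid w_{<i})\right)$. The short sketch below computes it from per-token probabilities (the inputs are made up for illustration) and reproduces the branching-factor intuition: a model that always assigns probability 1/100 to the correct next word has perplexity 100.

```python
import numpy as np

def perplexity(token_probs):
    """exp of the mean negative log-probability assigned to the observed tokens."""
    token_probs = np.asarray(token_probs, dtype=float)
    return float(np.exp(-np.mean(np.log(token_probs))))

uniform_100 = perplexity(np.full(50, 1 / 100))  # always 1-in-100: perplexity 100
mixed = perplexity([0.5, 0.25, 0.125])          # geometric mean is 1/4: perplexity 4
print(uniform_100, mixed)
```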
 
5.	Regression:
 
The paper also applies SWAG to a set of UCI regression tasks and compares it with additional approximate BNN inference methods. Based on test log-likelihoods, RMSEs, and test calibration results, SWAG is competitive with these methods. Specifically, even though all models predict heteroscedastic uncertainty, SWAG outperforms the other methods on three of the six datasets. The authors also note the strong performance of well-tuned SGD as a baseline against the other approximate inference methods. Finally, they compare the calibration of SWAG and SGD, measured as the coverage of SWAG's 95% credible sets and SGD's 95% confidence regions. Neither method is ever severely over-confident (far below 95% coverage), and SWAG is considerably better calibrated on four of the six datasets.
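The coverage check can be sketched as follows, assuming Gaussian predictive distributions $\mathcal{N}(\mu_i, \sigma_i^2)$: the interval $\mu_i \pm 1.96\,\sigma_i$ has nominal 95% coverage, and the empirical coverage is the fraction of test targets falling inside it. The mean Gaussian test log-likelihood is computed alongside. The synthetic data and the helper names are hypothetical.

```python
import numpy as np

def coverage_95(y, mu, sigma):
    """Fraction of targets inside the central 95% Gaussian interval mu +/- 1.96 sigma."""
    return float(np.mean(np.abs(y - mu) <= 1.96 * sigma))

def gaussian_test_ll(y, mu, sigma):
    """Mean Gaussian log-likelihood of the targets under N(mu, sigma^2)."""
    return float(np.mean(-0.5 * np.log(2 * np.pi * sigma ** 2)
                         - 0.5 * ((y - mu) / sigma) ** 2))

rng = np.random.default_rng(0)
mu = rng.normal(size=50_000)
sigma = rng.uniform(0.5, 2.0, size=mu.size)       # heteroscedastic predictive std
y = mu + sigma * rng.normal(size=mu.size)         # targets drawn from the predictive itself

print(coverage_95(y, mu, sigma))          # well calibrated: close to 0.95
print(coverage_95(y, mu, 0.5 * sigma))    # over-confident: far below 0.95
```

Shrinking the predicted standard deviation both lowers coverage and, in this example, lowers the test log-likelihood, which is why the two metrics are reported together.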
