# Momentum 

## Why momentum matters

When using gradient descent to iteratively improve parameter values, one may add a momentum term to the update. The informal justification is that momentum makes the descent smoother (less "oscillatory") in multi-dimensional and stochastic settings, which often lets training converge faster. Let's see why this might be the case.

### Oscillations in gradient descent

First, we recall from the [previous chapter](./P13-C02-gd-and-sgd.ipynb) the formulation of gradient descent: for an objective function $f: \mathbb{R}^d \rightarrow \mathbb{R}$ taking $\mathbf{w} = [w_1, w_2, \ldots, w_d]^\top$ as input, we perform the update

$$\mathbf{w}_t := \mathbf{w}_{t-1} - \eta \nabla f(\mathbf{w}_{t-1}),$$

where $\nabla f(\mathbf{w}_{t-1})$ is the gradient (the vector of partial derivatives of $f$) evaluated at the previous iterate $\mathbf{w}_{t-1}$, and $\eta > 0$ is the learning rate.
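As a worked example of a single step, take the quadratic $f(\mathbf{w}) = w_1^2 + 5w_2^2$, whose gradient is $\nabla f(\mathbf{w}) = [2w_1, 10w_2]^\top$. With the illustrative (assumed) values $\mathbf{w}_0 = [1, 1]^\top$ and $\eta = 0.15$:

$$\mathbf{w}_1 = \mathbf{w}_0 - \eta \nabla f(\mathbf{w}_0) = \begin{bmatrix}1\\1\end{bmatrix} - 0.15 \begin{bmatrix}2\\10\end{bmatrix} = \begin{bmatrix}0.7\\-0.5\end{bmatrix}.$$

The step moves $w_1$ steadily toward the minimum at the origin, but overshoots in the steeper $w_2$ direction and flips its sign. Repeated steps of this kind are the source of the oscillations this section is about.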

**[TODO: continue here]**

```python
from __future__ import print_function
import mxnet as mx
from mxnet import autograd
from mxnet import gluon
import numpy as np

# TODO: Plot f(x,y) = x^2 + 5y^2 in 3D space
# TODO: Plot f(x,y) as contour plot
```

Now suppose we want to find parameters $(w_1, w_2)$ that minimize the loss function $f(w_1, w_2) = w_1^2 + 5w_2^2$ from the cell above. Let's see how gradient descent performs:

```python
# TODO: Plot oscillating arrows on the contour plot
```
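Until the plot exists, here is a minimal sketch of what it should show, assuming plain NumPy rather than MXNet and the hand-picked values $\mathbf{w}_0 = [-4, 1]^\top$ and $\eta = 0.19$ (the names `f`, `grad_f` and `run_gd` are illustrative). On $f(w_1, w_2) = w_1^2 + 5w_2^2$, each step multiplies $w_1$ by $1 - 2\eta = 0.62$ but multiplies $w_2$ by $1 - 10\eta = -0.9$, so $w_2$ flips sign at every step while shrinking only slowly.

```python
import numpy as np

def f(w):
    # Objective: f(w_1, w_2) = w_1^2 + 5 w_2^2
    return w[0] ** 2 + 5 * w[1] ** 2

def grad_f(w):
    # Partial derivatives: (2 w_1, 10 w_2)
    return np.array([2 * w[0], 10 * w[1]])

def run_gd(w0, eta, num_steps):
    """Plain gradient descent; returns the iterates w_0, ..., w_T as rows."""
    w = np.array(w0, dtype=float)
    path = [w.copy()]
    for _ in range(num_steps):
        w = w - eta * grad_f(w)  # w_t = w_{t-1} - eta * grad f(w_{t-1})
        path.append(w.copy())
    return np.array(path)

path = run_gd(w0=[-4.0, 1.0], eta=0.19, num_steps=10)
for t, w in enumerate(path):
    print(f"t={t:2d}  w1={w[0]:+.4f}  w2={w[1]:+.4f}  f={f(w):.4f}")
```

The printed $w_2$ column alternates in sign: these are the zig-zagging arrows the contour plot is meant to display.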

**[TODO: explain observation]**

### The momentum update

We now formulate momentum for gradient descent, writing the parameter vector as $\mathbf{x}$ (rather than $\mathbf{w}$ as above). The single update rule is replaced by the following two updates:

\begin{align}
\mathbf{v}_{t} &:= \gamma \mathbf{v}_{t-1} - \eta \nabla f(\mathbf{x}_{t-1}),\\
\mathbf{x}_{t} &:= \mathbf{x}_{t-1} + \mathbf{v}_{t}.
\end{align}

Here, the subscripts denote the iteration $t$. We have a new **velocity** vector $\mathbf{v}$ to keep track of, along with a new **momentum** hyperparameter $\gamma \in [0, 1)$ to set. When $\gamma = 0$, we recover regular gradient descent; otherwise, $\mathbf{x}$ moves along the velocity $\mathbf{v}$, which combines the previous velocity (scaled by $\gamma$) with the negative gradient at the previous iterate (scaled by $\eta$).

To see how this works over time, let's unroll the velocity recursion:

\begin{align}
\mathbf{v}_{t} &= \gamma \mathbf{v}_{t-1} - \eta \nabla f(\mathbf{x}_{t-1})\\
&= \gamma (\gamma \mathbf{v}_{t-2} - \eta \nabla f(\mathbf{x}_{t-2})) - \eta \nabla f(\mathbf{x}_{t-1}) = \gamma^2 \mathbf{v}_{t-2} - \eta [\gamma \nabla f(\mathbf{x}_{t-2}) + \nabla f(\mathbf{x}_{t-1})]\\
&= \gamma^2 (\gamma \mathbf{v}_{t-3} - \eta \nabla f(\mathbf{x}_{t-3})) - \eta [\gamma \nabla f(\mathbf{x}_{t-2}) + \nabla f(\mathbf{x}_{t-1})] = \gamma^3 \mathbf{v}_{t-3} - \eta [\gamma^2\nabla f(\mathbf{x}_{t-3}) + \gamma \nabla f(\mathbf{x}_{t-2}) + \nabla f(\mathbf{x}_{t-1})]\\
&= \dotsb\\
&= \gamma^t \mathbf{v}_0 -\eta \sum_{k=0}^{t-1} \gamma^k \nabla f(\mathbf{x}_{t-1-k}).
\end{align}

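This shows that the current velocity $\mathbf{v}_{t}$ is an exponentially weighted sum of all past negative gradients, plus the initial velocity $\mathbf{v}_0$ decayed by $\gamma^t$. The newest gradient has weight $1$ (times $\eta$), and a gradient from $k$ steps earlier has weight $\gamma^k$, so older gradients play weaker and weaker roles. Up to normalization, this is an *exponential moving average*: the weights $\gamma^k$ sum to less than $1/(1-\gamma)$ and approach that value as $t$ grows, rather than summing to $1$.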
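The unrolled formula can be checked numerically. The sketch below uses plain NumPy; the objective $f(\mathbf{x}) = x_1^2 + 5x_2^2$, the nonzero $\mathbf{v}_0$, and the values of $\gamma$, $\eta$ and $T$ are assumptions chosen only for illustration. It runs the recursion $\mathbf{v}_t = \gamma\mathbf{v}_{t-1} - \eta\nabla f(\mathbf{x}_{t-1})$, $\mathbf{x}_t = \mathbf{x}_{t-1} + \mathbf{v}_t$ for $T$ steps and compares $\mathbf{v}_T$ with $\gamma^T\mathbf{v}_0 - \eta\sum_{k=0}^{T-1}\gamma^k\nabla f(\mathbf{x}_{T-1-k})$.

```python
import numpy as np

def grad_f(x):
    # Gradient of the assumed example f(x) = x_1^2 + 5 x_2^2
    return np.array([2 * x[0], 10 * x[1]])

gamma, eta, T = 0.9, 0.05, 20      # illustrative hyperparameters
x = np.array([-4.0, 1.0])          # x_0
v0 = np.array([0.3, -0.2])         # nonzero v_0, so the gamma^T v_0 term matters
v = v0.copy()
grads = []                         # grads[s] stores grad f(x_s)

for _ in range(T):
    g = grad_f(x)                  # gradient at the previous iterate x_{t-1}
    grads.append(g)
    v = gamma * v - eta * g        # v_t = gamma v_{t-1} - eta grad f(x_{t-1})
    x = x + v                      # x_t = x_{t-1} + v_t

# Closed form: v_T = gamma^T v_0 - eta * sum_{k=0}^{T-1} gamma^k grad f(x_{T-1-k})
closed = gamma ** T * v0 - eta * sum(gamma ** k * grads[T - 1 - k] for k in range(T))
print(np.allclose(v, closed))      # prints True
```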

### How momentum dampens oscillations

**[TODO: continue here]**

```python
# TODO: Implement momentum update function
```
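One possible shape for this function, as a sketch in plain NumPy rather than MXNet NDArrays (the names `momentum_step` and `run_momentum` are hypothetical): it applies $\mathbf{v}_t = \gamma\mathbf{v}_{t-1} - \eta\nabla f(\mathbf{x}_{t-1})$ followed by $\mathbf{x}_t = \mathbf{x}_{t-1} + \mathbf{v}_t$, starting from $\mathbf{v}_0 = \mathbf{0}$.

```python
import numpy as np

def momentum_step(x, v, grad, eta, gamma):
    """One momentum update given grad = grad f(x_{t-1}); returns (x_t, v_t)."""
    v_new = gamma * v - eta * grad   # v_t = gamma * v_{t-1} - eta * grad f(x_{t-1})
    x_new = x + v_new                # x_t = x_{t-1} + v_t
    return x_new, v_new

def run_momentum(grad_f, x0, eta, gamma, num_steps):
    """Iterate momentum_step from x0 with v_0 = 0; returns x_0, ..., x_T as rows."""
    x = np.array(x0, dtype=float)
    v = np.zeros_like(x)
    path = [x.copy()]
    for _ in range(num_steps):
        x, v = momentum_step(x, v, grad_f(x), eta, gamma)
        path.append(x.copy())
    return np.array(path)
```

Setting `gamma=0.0` reduces `run_momentum` to plain gradient descent, which is a convenient sanity check before comparing the two.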

```python
# TODO: Plot less oscillating arrows on the contour plot
```
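A minimal numerical comparison, assuming plain NumPy and hand-picked settings for this toy quadratic: the same $f(w_1, w_2) = w_1^2 + 5w_2^2$, start $[-4, 1]^\top$ and $\eta = 0.19$ as for plain gradient descent, with $\gamma = 0.25$ for momentum (a value tuned by hand for this problem, not a general recommendation). With these settings plain gradient descent shrinks the oscillating $w_2$ coordinate by only a factor of $0.9$ per step, while momentum drives the objective to a far smaller value in the same 30 steps.

```python
import numpy as np

def f(w):
    # Objective: f(w_1, w_2) = w_1^2 + 5 w_2^2
    return w[0] ** 2 + 5 * w[1] ** 2

def grad_f(w):
    # Gradient: (2 w_1, 10 w_2)
    return np.array([2 * w[0], 10 * w[1]])

def run(w0, eta, gamma, num_steps):
    """Momentum with v_0 = 0; gamma = 0 reduces to plain gradient descent."""
    w = np.array(w0, dtype=float)
    v = np.zeros_like(w)
    path = [w.copy()]
    for _ in range(num_steps):
        v = gamma * v - eta * grad_f(w)
        w = w + v
        path.append(w.copy())
    return np.array(path)

gd_path = run([-4.0, 1.0], eta=0.19, gamma=0.0, num_steps=30)
mom_path = run([-4.0, 1.0], eta=0.19, gamma=0.25, num_steps=30)
print("GD       final f:", f(gd_path[-1]))
print("Momentum final f:", f(mom_path[-1]))
```

Plotting `gd_path` and `mom_path` on the contour plot would show the $w_2$ zig-zag dying out much faster under momentum.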

**[TODO: Explain momentum (inertia) analogy]**

### Practicalities

**[TODO: expand]**

A simple choice is $\mathbf{v}_0 = \mathbf{0}$. For the momentum parameter, a common practical default is $\gamma = 0.9$, though it can be tuned like any other hyperparameter.
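To see what these defaults imply, assume for illustration that the gradient stays constant at a scalar $g$. With $v_0 = 0$ the unrolled velocity is $v_t = -\eta g (1-\gamma^t)/(1-\gamma)$, so the effective step grows toward $\eta g/(1-\gamma)$, ten times the plain gradient step when $\gamma = 0.9$, and the early velocities are biased toward zero (the kind of effect that bias-correction schemes address). The helper name below is hypothetical.

```python
def velocity_under_constant_gradient(g, eta, gamma, num_steps):
    """Velocities v_1, ..., v_T when every gradient equals the constant g and v_0 = 0."""
    v = 0.0
    history = []
    for _ in range(num_steps):
        v = gamma * v - eta * g  # v_t = gamma * v_{t-1} - eta * g
        history.append(v)
    return history

hist = velocity_under_constant_gradient(g=1.0, eta=0.1, gamma=0.9, num_steps=100)
print(hist[0], hist[9], hist[-1])  # magnitude grows from eta*g toward eta*g/(1 - gamma)
print(-0.1 * 1.0 / (1 - 0.9))      # the limiting velocity
```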

**[TODO: discuss bias correction?]**

## Nesterov momentum

**[TODO: discuss this?]**
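As a starting point for this discussion, here is a sketch of one common "look-ahead" formulation of Nesterov momentum (an assumption; the chapter may settle on a different but equivalent variant): the gradient is evaluated at $\mathbf{x}_{t-1} + \gamma\mathbf{v}_{t-1}$, where plain momentum would carry the parameters, instead of at $\mathbf{x}_{t-1}$. The test function and hyperparameters are illustrative, and the name `nesterov_step` is hypothetical.

```python
import numpy as np

def nesterov_step(x, v, grad_f, eta, gamma):
    """One look-ahead momentum update; returns the new (x, v)."""
    lookahead = x + gamma * v                    # where plain momentum would carry x
    v_new = gamma * v - eta * grad_f(lookahead)  # gradient taken at the look-ahead point
    return x + v_new, v_new

def grad_f(x):
    # Gradient of the assumed example f(x) = x_1^2 + 5 x_2^2
    return np.array([2 * x[0], 10 * x[1]])

x, v = np.array([-4.0, 1.0]), np.zeros(2)
for _ in range(200):
    x, v = nesterov_step(x, v, grad_f, eta=0.05, gamma=0.9)
print(x)  # close to the minimizer at the origin
```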

## Experiments

**[TODO: compare convergence rates between GD, Momentum, Nesterov momentum?]**

To demonstrate the gradient-based optimization algorithms above, we use the regression problem from the [linear regression chapter](./P02-C01-linear-regression-scratch.ipynb) as a case study.

First, we import the required libraries, generate the synthetic data, and construct the model:

```python
#TODO: Reproduce linear regression; perform gradient, momentum, etc. and compare convergence
```
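A minimal sketch of such a comparison, under several assumptions: it uses plain NumPy instead of MXNet/Gluon, full-batch updates, hypothetical data (two features, true weights $[2, -3.4]$, bias $4.2$, small Gaussian noise) rather than the exact data of the linear regression chapter, and hand-picked $\eta = 0.5$, $\gamma = 0.9$. The second feature is deliberately scaled down so the loss surface is elongated, which is the situation where momentum is expected to help.

```python
import numpy as np

# Hypothetical synthetic data (assumed values): the second feature is scaled by 0.1,
# which makes the least-squares problem ill-conditioned.
rng = np.random.default_rng(0)
n = 1000
X = rng.normal(size=(n, 2)) * np.array([1.0, 0.1])
true_w, true_b = np.array([2.0, -3.4]), 4.2
y = X @ true_w + true_b + 0.01 * rng.normal(size=n)

def loss_and_grad(params):
    """Squared loss L = ||X w + b - y||^2 / (2n) and its gradient w.r.t. (w, b)."""
    w, b = params[:2], params[2]
    r = X @ w + b - y
    loss = 0.5 * np.mean(r ** 2)
    grad = np.concatenate([X.T @ r / n, [np.mean(r)]])
    return loss, grad

def train(eta, gamma, num_steps):
    """Full-batch momentum from zero parameters; gamma = 0 gives plain gradient descent."""
    params = np.zeros(3)
    v = np.zeros(3)
    losses = []
    for _ in range(num_steps):
        loss, grad = loss_and_grad(params)
        losses.append(loss)
        v = gamma * v - eta * grad
        params = params + v
    losses.append(loss_and_grad(params)[0])
    return params, losses

gd_params, gd_losses = train(eta=0.5, gamma=0.0, num_steps=200)
mom_params, mom_losses = train(eta=0.5, gamma=0.9, num_steps=200)
print(f"GD:       final loss {gd_losses[-1]:.6f}, params {gd_params}")
print(f"Momentum: final loss {mom_losses[-1]:.6f}, params {mom_params}")
```

Plotting `gd_losses` and `mom_losses` against the step index gives the convergence comparison. The outcome depends on the problem: on well-conditioned data, $\gamma = 0.9$ can make momentum slower than plain gradient descent, so the hyperparameters should be tuned per method.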

**[TODO: discuss results]**