sure. here's the clean mathematical breakdown, step by step:

---

### **1. Momentum**

exponential moving average of gradient — smooths direction:

$$
\begin{aligned}
m_t &= \beta m_{t-1} + (1 - \beta) g_t \\
\theta_t &= \theta_{t-1} - \eta \cdot m_t
\end{aligned}
$$

* $\beta \in [0,1)$: momentum coefficient (typically 0.9)
* $\eta$: learning rate
* $g_t = \nabla_\theta L(\theta_{t-1})$: gradient at step $t$, evaluated at the current parameters $\theta_{t-1}$
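a minimal sketch of the momentum update above, in plain Python on scalars so the arithmetic stays visible. the test problem $L(\theta) = \theta^2$ and the helper name `momentum_minimize` are illustrative assumptions, not from any library.

```python
def momentum_minimize(grad, theta0, lr=0.1, beta=0.9, steps=200):
    """EMA-style momentum: m_t = beta*m_{t-1} + (1-beta)*g_t, theta_t = theta_{t-1} - lr*m_t."""
    theta, m = theta0, 0.0  # m_0 = 0
    for _ in range(steps):
        g = grad(theta)                  # g_t, evaluated at theta_{t-1}
        m = beta * m + (1 - beta) * g    # smoothed direction
        theta = theta - lr * m           # step along the smoothed direction
    return theta

# hypothetical test problem: L(theta) = theta**2, so dL/dtheta = 2*theta
theta_final = momentum_minimize(lambda th: 2.0 * th, theta0=5.0)
print(theta_final)  # should be very close to 0
```

the undamped form $m_t = \beta m_{t-1} + g_t$ is also common; it is equivalent to the EMA form above up to rescaling $\eta$ by $1-\beta$.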

---

### **2. RMSprop**

divides each step by the square root of an EMA of squared gradients, giving every parameter its own adaptive step size:

$$
\begin{aligned}
v_t &= \beta v_{t-1} + (1 - \beta) g_t^2 \\
\theta_t &= \theta_{t-1} - \eta \cdot \frac{g_t}{\sqrt{v_t} + \epsilon}
\end{aligned}
$$

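* $\epsilon$: small constant (e.g. $10^{-8}$) to avoid division by zero
* $g_t^2$, the square root, and the division are all elementwise, so each parameter gets its own scale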
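a minimal sketch of one RMSprop step in plain Python, applied elementwise to lists of floats. the badly-scaled test function $L(x, y) = x^2 + 1000\,y^2$ and the names `rmsprop_step` / `grad` are illustrative assumptions. it shows the per-parameter normalization: two coordinates whose gradients differ by 1000x move by the same amount on the first step.

```python
import math

def rmsprop_step(theta, v, g, lr=1e-2, beta=0.9, eps=1e-8):
    """One RMSprop update, applied elementwise to lists of floats."""
    new_v = [beta * vi + (1 - beta) * gi * gi for vi, gi in zip(v, g)]
    new_theta = [
        ti - lr * gi / (math.sqrt(vi) + eps)
        for ti, gi, vi in zip(theta, g, new_v)
    ]
    return new_theta, new_v

# hypothetical badly-scaled problem: L(x, y) = x**2 + 1000 * y**2
def grad(theta):
    x, y = theta
    return [2.0 * x, 2000.0 * y]

theta, v = [1.0, 1.0], [0.0, 0.0]
new_theta, v = rmsprop_step(theta, v, grad(theta))
steps = [abs(a - b) for a, b in zip(theta, new_theta)]
print(steps)  # both about lr / sqrt(1 - beta) ≈ 0.0316
```

on step 1, $v_1 = (1-\beta) g^2$, so each coordinate moves by about $\eta/\sqrt{1-\beta}$ regardless of $|g|$; plain gradient descent with the same $\eta$ would move them by 0.02 and 20.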

---

### **3. Adam**

momentum + RMSprop + **bias correction**:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1) g_t \quad &\text{(1st moment)} \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \quad &\text{(2nd moment)} \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t} \quad &\text{(bias-corrected)} \\
\hat{v}_t &= \frac{v_t}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1} - \eta \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
$$

* typical values (the defaults from the Adam paper):
  $\eta = 10^{-3},\ \beta_1 = 0.9,\ \beta_2 = 0.999,\ \epsilon = 10^{-8}$
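a minimal scalar sketch of the Adam step above in plain Python (the name `adam_step` is an illustrative assumption, not a library API). it checks what bias correction does on the very first step: with $m_0 = v_0 = 0$, $\hat{m}_1 = g$ and $\hat{v}_1 = g^2$, so the step is about $\eta$ in size whatever the gradient's magnitude.

```python
import math

def adam_step(theta, m, v, t, g, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One scalar Adam update; t is the 1-based step index."""
    m = b1 * m + (1 - b1) * g          # 1st moment (raw EMA)
    v = b2 * v + (1 - b2) * g * g      # 2nd moment (raw EMA)
    m_hat = m / (1 - b1 ** t)          # bias-corrected copies, used only for this step
    v_hat = v / (1 - b2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v                 # carry the raw m, v forward

# first step from zero-initialized moments, for gradients of very different size
for g in (0.01, 1.0, 100.0):
    theta1, _, _ = adam_step(0.0, 0.0, 0.0, 1, g)
    print(g, theta1)  # theta1 is about -1e-3 (i.e. -lr) in every case
```

without the correction, the first step would be $\eta\,(1-\beta_1)/\sqrt{1-\beta_2} \approx 3.16\,\eta$ with the default betas (ignoring $\epsilon$).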

---

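how they relate:

* **momentum** smooths the update direction.
* **RMSprop** adapts the step size per parameter.
* **Adam** does both, and its bias correction undoes the pull toward zero that the zero initialization $m_0 = v_0 = 0$ causes in the first few steps.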
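unrolling the first-moment recursion shows where the $1-\beta_1^t$ factor comes from. assuming, for simplicity, a constant gradient $g_i = g$:

$$
\begin{aligned}
m_t &= (1-\beta_1)\sum_{i=1}^{t} \beta_1^{\,t-i} g_i \qquad (m_0 = 0) \\
&= (1-\beta_1)\, g\, \frac{1-\beta_1^t}{1-\beta_1} = (1-\beta_1^t)\, g \\
\Rightarrow\quad \hat{m}_t &= \frac{m_t}{1-\beta_1^t} = g
\end{aligned}
$$

the same argument with $g^2$ in place of $g$ gives $\hat{v}_t = g^2$, so under this assumption the corrected estimates match the true values from step 1 instead of starting near zero.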

you can think of Adam as the fusion reactor built from the raw fuel of momentum and RMSprop.


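```python
import jax
import jax.numpy as jnp

# *gasp* Adam!
def adam_update_single(p, m, v, t, g, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # Adam for one leaf: p, m, v, g share a shape; t is the current (1-based) step
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * jnp.square(g)

    # bias-correct into temporaries; the raw m, v are what get carried forward
    m_hat = m / (1 - b1**t)
    v_hat = v / (1 - b2**t)

    p = p - lr * m_hat / (jnp.sqrt(v_hat) + eps)
    return p, m, v

@jax.jit
def adam_update(params, m, v, t, grads, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # t is one scalar shared by every leaf, so it is closed over rather than mapped
    t = t + 1
    out = jax.tree.map(
        lambda p, m_, v_, g: adam_update_single(p, m_, v_, t, g, lr, b1, b2, eps),
        params, m, v, grads,
    )
    # out mirrors params but holds a (p, m, v) tuple at each leaf; unzip it
    new_params = jax.tree.map(lambda _, o: o[0], params, out)
    new_m = jax.tree.map(lambda _, o: o[1], params, out)
    new_v = jax.tree.map(lambda _, o: o[2], params, out)
    return new_params, new_m, new_v, t
```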

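```python
import jax
import jax.numpy as jnp

# Adam! (pytree version that also computes the gradients)
# grad_loss(params, X, y) is assumed to be defined in an earlier cell (not shown)
@jax.jit
def adam_update(opt_state, X, y, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    grads = grad_loss(opt_state['params'], X, y)
    t = opt_state['t'] + 1

    # raw EMAs of the gradient and the squared gradient
    m = jax.tree_util.tree_map(lambda m_, g: b1 * m_ + (1 - b1) * g, opt_state['m'], grads)
    v = jax.tree_util.tree_map(lambda v_, g: b2 * v_ + (1 - b2) * (g * g), opt_state['v'], grads)

    # bias-corrected copies, used only for this step's parameter update
    m_hat = jax.tree_util.tree_map(lambda m_: m_ / (1 - b1**t), m)
    v_hat = jax.tree_util.tree_map(lambda v_: v_ / (1 - b2**t), v)

    params = jax.tree_util.tree_map(
        lambda p, mh, vh: p - lr * mh / (jnp.sqrt(vh) + eps),
        opt_state['params'], m_hat, v_hat
    )

    # store the raw m, v (not m_hat, v_hat) for the next step's EMA
    return {'params': params, 'm': m, 'v': v, 't': t}
```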
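the cell above steps `params` with `m_hat` / `v_hat` but returns the raw `m` and `v`. a quick scalar check shows why that matters, assuming a constant gradient of 1 (the helper `first_moment_trace` and its `write_back` flag are illustrative, not part of the notebook): feeding the corrected value back into the next EMA applies the correction again and again.

```python
def first_moment_trace(steps, b1=0.9, write_back=False):
    """Bias-corrected 1st moment over several steps for a constant gradient g = 1."""
    m, trace = 0.0, []
    for t in range(1, steps + 1):
        m = b1 * m + (1 - b1) * 1.0
        m_hat = m / (1 - b1 ** t)
        trace.append(m_hat)
        if write_back:
            m = m_hat  # the mistake: the corrected value leaks into the next EMA
    return trace

print(first_moment_trace(4))                   # stays at 1.0 (up to rounding)
print(first_moment_trace(4, write_back=True))  # grows every step, about 47 by step 4
```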