In [None]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim


## ðŸ“¦ Gradients - Toy example

### Stochastic gradient descent SGD

* Vanilla SGD
$$
\theta_{t+1} = \theta_t - \eta \cdot \nabla_\theta \mathcal{L}(\theta_t)
$$

* SGD with Momentum
$$
\begin{aligned}
v_{t+1} &= \mu \cdot v_t - \eta \cdot \nabla_\theta \mathcal{L}(\theta_t) \\\\
\theta_{t+1} &= \theta_t + v_{t+1}
\end{aligned}
$$

* Nesterov Accelerated Gradient (NAG)
$$
\begin{aligned}
v_{t+1} &= \mu \cdot v_t - \eta \cdot \nabla_\theta \mathcal{L}(\theta_t + \mu \cdot v_t) \\\\
\theta_{t+1} &= \theta_t + v_{t+1}
\end{aligned}
$$

* Weight Decay (L2 Regularization)
$$
\nabla_\theta \mathcal{L}_{\text{reg}} = \nabla_\theta \mathcal{L} + \lambda \cdot \theta
$$ 

In [187]:
# Data
x = torch.tensor([1.0, 2.0, 3.0, 4.0]).reshape(-1, 1)
y = torch.tensor([0.0, 0.0, 1.0, 1.0]).reshape(-1, 1)
cls_target = torch.cat([y, 1 - y], dim=1)  # One-hot target: [1,0] or [0,1]

# Model: Linear -> Softmax
layer = nn.Sequential(
    nn.Linear(1, 2, bias=False),
    nn.Softmax(dim=1)
)

# Hyperparameters
lr = 1e-1
momentum = 0.9
weight_decay = 1e-2
v = torch.zeros_like(layer[0].weight)

# Loss function: manual cross-entropy
def loss_fn(pred, target):
    return -(target * torch.log(pred)).sum(dim=1).mean()

# Training loop
for _ in range(100):
    # 1. Lookahead step
    lookahead = layer[0].weight + momentum * v
    with torch.no_grad():
        layer[0].weight.copy_(lookahead)

    # 2. Forward pass and loss
    pred = layer(x)
    loss = loss_fn(pred, cls_target)

    # 3. Backward pass
    loss.backward()

    # 4. Weight decay (L2 regularization)
    grad = layer[0].weight.grad + weight_decay * layer[0].weight

    # 5. Momentum update
    v = momentum * v - lr * grad

    # 6. Final weight update (Nesterov style)
    with torch.no_grad():
        layer[0].weight.copy_(lookahead - momentum * v + v)

    # 7. Reset gradient
    layer[0].weight.grad.zero_()

# Output
with torch.no_grad():
    print("Predictions after training:\n", layer(x).squeeze().numpy())
    print("Final weights:\n", layer[0].weight.data.numpy())

Predictions after training:
 [[0.5707908  0.42920917]
 [0.6387993  0.36120063]
 [0.7016642  0.29833582]
 [0.7577372  0.2422628 ]]
Final weights:
 [[0.34792542]
 [0.06284701]]


## Adam (Adaptive moment)

## Weight initilization #TODO

kaiming weight initialization makes sure vairance through layers stays stable