# Optimization and Hyperparameters

#### Setup 

In [1]:
import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
import torch as t
from torch import Tensor, optim
import torch.nn.functional as F
from torchvision import datasets
from torch.utils.data import DataLoader, Subset
from typing import Callable, Iterable, Tuple, Optional
from jaxtyping import Float
from dataclasses import dataclass, replace
from tqdm.notebook import tqdm
from pathlib import Path
import numpy as np
from IPython.display import display, HTML

# Locate the chapter's `exercises` folder relative to the current working
# directory and make it importable (the project-local imports below rely on it).
chapter = r"chapter0_fundamentals"
exercises_dir = Path(
    f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises"
).resolve()
section_dir = exercises_dir / "part3_optimization"
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

from plotly_utils import bar, imshow, plot_train_loss_and_test_accuracy_from_trainer
from part2_cnns.solutions import IMAGENET_TRANSFORM, ResNet34
from part2_cnns.solutions_bonus import get_resnet_for_feature_extraction
from part3_optimization.utils import plot_fn, plot_fn_with_points
import part3_optimization.tests as tests

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if t.cuda.is_available():
    device = t.device("cuda")
elif t.backends.mps.is_available():
    device = t.device("mps")
else:
    device = t.device("cpu")

# True only when this file is run as the top-level script (not when imported).
MAIN = __name__ == "__main__"

## Optimizers

### Visualizing optimizations with pathological curvatures

In [2]:
def pathological_curve_loss(x: t.Tensor, y: t.Tensor):
    # Toy loss surface with awkward curvature: the x-term is tanh(x)^2 plus a
    # small |x| term, the y-term is sigmoid(y). Other pathological shapes are
    # possible, so feel free to experiment here!
    x_term = t.tanh(x).pow(2) + 0.01 * x.abs()
    y_term = t.sigmoid(y)
    return x_term + y_term


# Hand the loss function to the plotting helper imported from part3_optimization.utils.
plot_fn(pathological_curve_loss)

### Exercise - implement opt_fn_with_sgd

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to 15-20 minutes on this exercise.

In [22]:
def opt_fn_with_sgd(fn: Callable, xy: t.Tensor, lr=0.01, momentum=0.98, n_iters: int = 100):
    '''
    Minimise `fn` with torch.optim.SGD from a given starting point, recording the path.

    fn: called as fn(x, y); its result is backpropagated.
    xy: shape (2,). The (x, y) starting point; it is updated in place by the optimizer.
    n_iters: number of optimization steps.
    lr, momentum: parameters passed to the torch.optim.SGD optimizer.

    Return: (n_iters, 2). The (x, y) BEFORE each step, so out[0] is the starting point.
    '''
    sgd = t.optim.SGD([xy], lr=lr, momentum=momentum)
    trajectory = t.zeros(n_iters, 2)

    for step in range(n_iters):
        # Record the current position before this step moves it.
        trajectory[step] = xy.detach()
        x_coord, y_coord = xy[0], xy[1]
        loss = fn(x_coord, y_coord)
        loss.backward()
        sgd.step()
        sgd.zero_grad()
    return trajectory

In [23]:
# One (trajectory, optimizer class, hyperparameters) entry per run, for plotting.
points = []

# Plain SGD versus SGD with heavy momentum.
optimizer_list = [
    (optim.SGD, dict(lr=0.1, momentum=0.0)),
    (optim.SGD, dict(lr=0.02, momentum=0.99)),
]

for optimizer_class, params in optimizer_list:
    # Fresh starting point for every run, because the optimizer updates it in place.
    xy = t.tensor([2.5, 2.5], requires_grad=True)
    xys = opt_fn_with_sgd(pathological_curve_loss, xy=xy, **params)

    points.append((xys, optimizer_class, params))

plot_fn_with_points(pathological_curve_loss, points=points)

### Exercise - implement SGD

Difficulty: 🔴🔴🔴🔴🔴
Importance: 🔵🔵🔵⚪⚪

You should spend up to 25-35 minutes on this exercise.

This is the first of several exercises like it. The first will probably take the longest.

In [46]:
class SGD:
    def __init__(self, params: Iterable[t.nn.parameter.Parameter],
        lr: float, momentum: float = 0.0, weight_decay: float = 0.0):
        '''Implements SGD with momentum.

        Like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0
            https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD

        params: parameters to optimise (any iterable; it is consumed once here).
        lr: learning rate.
        momentum: coefficient applied to the previous update direction.
        weight_decay: coefficient of the `weight_decay * param` term added to the gradient.
        '''
        self.params = list(params)  # materialise: `params` might be a one-shot generator
        self.lr = lr
        self.mu = momentum
        self.lmda = weight_decay
        # One buffer per parameter holding the previous (modified) gradient.
        self.gs = [t.zeros_like(param) for param in self.params]

    def zero_grad(self) -> None:
        '''Resets the gradient of every parameter in `self.params` to None.
        '''
        for param in self.params:
            param.grad = None

    @t.inference_mode()     # similar to t.no_grad()
    def step(self) -> None:
        '''Performs a single optimization step of the SGD algorithm.

        For each parameter the gradient is adjusted by weight decay and momentum,
        then the parameter is updated in place. Parameters whose gradient is None
        are skipped. `param.grad` itself is never modified.
        '''
        for i, param in enumerate(self.params):
            if param.grad is None:
                continue
            # Work on a copy so that param.grad is left untouched and the stored
            # buffer never aliases it (which would corrupt the momentum term if
            # gradients are accumulated across steps without zero_grad()).
            g = param.grad.clone()
            if self.lmda > 0:
                g += self.lmda * param
            if self.mu > 0:
                g += self.mu * self.gs[i]
            param -= self.lr * g
            self.gs[i] = g

    def __repr__(self) -> str:
        return f"SGD(lr={self.lr}, momentum={self.mu}, weight_decay={self.lmda})"


# Check the hand-written SGD class with the test helper from part3_optimization.tests.
tests.test_sgd(SGD)

Testing configuration:  {'lr': 0.1, 'momentum': 0.0, 'weight_decay': 0.0}
Testing configuration:  {'lr': 0.1, 'momentum': 0.7, 'weight_decay': 0.0}
Testing configuration:  {'lr': 0.1, 'momentum': 0.5, 'weight_decay': 0.0}
Testing configuration:  {'lr': 0.1, 'momentum': 0.5, 'weight_decay': 0.05}
Testing configuration:  {'lr': 0.2, 'momentum': 0.8, 'weight_decay': 0.05}
All tests in `test_sgd` passed!


### Exercise - implement RMSProp

In [75]:
class RMSprop:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.01,
        alpha: float = 0.99,
        eps: float = 1e-08,
        weight_decay: float = 0.0,
        momentum: float = 0.0,
    ):
        '''Implements RMSprop.

        Like the PyTorch version, but assumes centered=False
            https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html

        params: parameters to optimise (any iterable; it is consumed once here).
        lr: learning rate.
        alpha: decay rate of the running average of squared gradients.
        eps: added to the denominator to avoid division by zero.
        weight_decay: coefficient of the `weight_decay * param` term added to the gradient.
        momentum: coefficient applied to the previous update.
        '''
        self.params = list(params)  # materialise: `params` might be a one-shot generator
        self.lr = lr
        self.alpha = alpha
        self.eps = eps
        self.lmda = weight_decay
        self.mu = momentum
        # Running average of squared gradients, one per parameter.
        self.vs = [t.zeros_like(p) for p in self.params]
        # Momentum buffer, one per parameter (only used when momentum != 0).
        self.buf = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        '''Resets the gradient of every parameter in `self.params` to None.
        '''
        for param in self.params:
            param.grad = None

    @t.inference_mode()
    def step(self) -> None:
        '''Performs a single RMSprop update on every parameter that has a gradient.

        Parameters whose gradient is None are skipped; `param.grad` is never modified.
        '''
        for i, param in enumerate(self.params):
            g = param.grad
            if g is None:
                continue
            if self.lmda != 0:
                # Out-of-place add: an in-place `+=` would overwrite param.grad.
                g = g + self.lmda * param
            self.vs[i] = self.alpha * self.vs[i] + (1 - self.alpha) * g**2
            denom = self.vs[i] ** 0.5 + self.eps
            if self.mu != 0:
                self.buf[i] = self.mu * self.buf[i] + g / denom
                param -= self.lr * self.buf[i]
            else:
                param -= self.lr * g / denom

    def __repr__(self) -> str:
        return f"RMSprop(lr={self.lr}, eps={self.eps}, momentum={self.mu}, weight_decay={self.lmda}, alpha={self.alpha})"


# Check the hand-written RMSprop class with the test helper from part3_optimization.tests.
tests.test_rmsprop(RMSprop)

Testing configuration:  {'lr': 0.1, 'alpha': 0.9, 'eps': 0.001, 'weight_decay': 0.0, 'momentum': 0.0}
Testing configuration:  {'lr': 0.1, 'alpha': 0.95, 'eps': 0.0001, 'weight_decay': 0.05, 'momentum': 0.0}
Testing configuration:  {'lr': 0.1, 'alpha': 0.95, 'eps': 0.0001, 'weight_decay': 0.05, 'momentum': 0.5}
Testing configuration:  {'lr': 0.1, 'alpha': 0.95, 'eps': 0.0001, 'weight_decay': 0.05, 'momentum': 0.0}
All tests in `test_rmsprop` passed!


### Exercise - implement Adam

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 15-20 minutes on this exercise.

In [80]:
class Adam:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.001,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
    ):
        '''Implements Adam.

        Like the PyTorch version, but assumes amsgrad=False and maximize=False
            https://pytorch.org/docs/stable/generated/torch.optim.Adam.html

        params: parameters to optimise (any iterable; it is consumed once here).
        lr: learning rate.
        betas: decay rates of the running averages of the gradient and of its square.
        eps: added to the denominator to avoid division by zero.
        weight_decay: coefficient of the `weight_decay * param` term added to the gradient.
        '''
        self.params = list(params)  # materialise: `params` might be a one-shot generator
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.lmda = weight_decay
        # Running averages of the gradient (mt) and squared gradient (vt), one per parameter.
        self.mt = [t.zeros_like(p) for p in self.params]
        self.vt = [t.zeros_like(p) for p in self.params]
        # Step counter used for bias correction; shared by all parameters, starts at 1.
        self.t = 1

    def zero_grad(self) -> None:
        '''Resets the gradient of every parameter in `self.params` to None.
        '''
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        '''Performs a single Adam update on every parameter that has a gradient.

        Parameters whose gradient is None are skipped; `param.grad` is never modified.
        '''
        for i, param in enumerate(self.params):
            g = param.grad
            if g is None:
                continue
            if self.lmda != 0:
                # Out-of-place add: an in-place `+=` would overwrite param.grad.
                g = g + self.lmda * param
            self.mt[i] = self.beta1 * self.mt[i] + (1 - self.beta1) * g
            self.vt[i] = self.beta2 * self.vt[i] + (1 - self.beta2) * g ** 2

            # Bias-corrected estimates: both averages start at zero.
            mt_hat = self.mt[i] / (1 - self.beta1 ** self.t)
            vt_hat = self.vt[i] / (1 - self.beta2 ** self.t)

            param -= self.lr * mt_hat / (vt_hat ** 0.5 + self.eps)

        self.t += 1

    def __repr__(self) -> str:
        return f"Adam(lr={self.lr}, beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}, weight_decay={self.lmda})"


# Check the hand-written Adam class with the test helper from part3_optimization.tests.
tests.test_adam(Adam)

Testing configuration:  {'lr': 0.1, 'betas': (0.8, 0.95), 'eps': 0.001, 'weight_decay': 0.0}
Testing configuration:  {'lr': 0.1, 'betas': (0.8, 0.9), 'eps': 0.001, 'weight_decay': 0.05}
Testing configuration:  {'lr': 0.2, 'betas': (0.9, 0.95), 'eps': 0.01, 'weight_decay': 0.08}
All tests in `test_adam` passed!


### Exercise - implement AdamW

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 10-15 minutes on this exercise.

In [82]:
class AdamW:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.001,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
    ):
        '''Implements AdamW (Adam with weight decay applied directly to the parameters).

        Like the PyTorch version, but assumes amsgrad=False and maximize=False
            https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html

        params: parameters to optimise (any iterable; it is consumed once here).
        lr: learning rate.
        betas: decay rates of the running averages of the gradient and of its square.
        eps: added to the denominator to avoid division by zero.
        weight_decay: each step shrinks the parameter by `lr * weight_decay * param`.
        '''
        self.params = list(params)  # materialise: `params` might be a one-shot generator
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.lmda = weight_decay
        # Running averages of the gradient (mt) and squared gradient (vt), one per parameter.
        self.mt = [t.zeros_like(p) for p in self.params]
        self.vt = [t.zeros_like(p) for p in self.params]
        # Step counter used for bias correction; shared by all parameters, starts at 1.
        self.t = 1

    def zero_grad(self) -> None:
        '''Resets the gradient of every parameter in `self.params` to None.
        '''
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        '''Performs a single AdamW update on every parameter that has a gradient.

        Parameters whose gradient is None are skipped entirely (no weight decay either).
        '''
        for i, param in enumerate(self.params):
            g = param.grad
            if g is None:
                # Check before decaying, so a gradient-less parameter is left untouched.
                continue
            # Weight decay acts on the parameter directly rather than on the gradient.
            param -= self.lr * self.lmda * param
            self.mt[i] = self.beta1 * self.mt[i] + (1 - self.beta1) * g
            self.vt[i] = self.beta2 * self.vt[i] + (1 - self.beta2) * g ** 2

            # Bias-corrected estimates: both averages start at zero.
            mt_hat = self.mt[i] / (1 - self.beta1 ** self.t)
            vt_hat = self.vt[i] / (1 - self.beta2 ** self.t)
            param -= self.lr * mt_hat / (vt_hat ** 0.5 + self.eps)

        self.t += 1

    def __repr__(self) -> str:
        return f"AdamW(lr={self.lr}, beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}, weight_decay={self.lmda})"


# Check the hand-written AdamW class with the test helper from part3_optimization.tests.
tests.test_adamw(AdamW)

Testing configuration:  {'lr': 0.1, 'betas': (0.8, 0.95), 'eps': 0.001, 'weight_decay': 0.0}
Testing configuration:  {'lr': 0.1, 'betas': (0.8, 0.9), 'eps': 0.001, 'weight_decay': 0.05}
Testing configuration:  {'lr': 0.2, 'betas': (0.9, 0.95), 'eps': 0.01, 'weight_decay': 0.08}
All tests in `test_adamw` passed!


### Exercise - Plotting multiple optimisers

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 10-15 minutes on this exercise.

In [85]:
def opt_fn(fn: Callable, xy: t.Tensor, optimizer_class, optimizer_hyperparams: dict, n_iters: int = 100):
    '''Minimise `fn` from the starting point `xy`, recording the path taken.

    fn: called as fn(x, y); its result is backpropagated.
    xy: shape (2,). The (x, y) starting point; it is updated in place by the optimizer.
    optimizer_class: built as optimizer_class([xy], **optimizer_hyperparams); must
        provide step() and zero_grad() (e.g. SGD, RMSprop, Adam or AdamW).
    optimizer_hyperparams: keyword arguments passed to the optimizer (e.g. lr and weight_decay).
    n_iters: number of optimization steps.

    Return: (n_iters, 2). The (x, y) BEFORE each step, so out[0] is the starting point.
    '''
    opt = optimizer_class([xy], **optimizer_hyperparams)
    trajectory = t.zeros(n_iters, 2)
    for step in range(n_iters):
        # Record the current position before this step moves it.
        trajectory[step] = xy.detach()
        x_coord, y_coord = xy[0], xy[1]
        loss = fn(x_coord, y_coord)
        loss.backward()
        opt.step()
        opt.zero_grad()
    return trajectory

In [88]:
# One (trajectory, optimizer class, hyperparameters) entry per run, for plotting.
points = []

# The hand-written optimizers to compare. Adam and AdamW get identical
# hyperparameters, so their paths differ only in how weight decay is applied.
optimizer_list = [
    (SGD, {"lr": 0.03, "momentum": 0.99}),
    (RMSprop, {"lr": 0.02, "alpha": 0.99, "momentum": 0.8}),
    (Adam, {"lr": 0.3, "betas": (0.99, 0.99), "weight_decay": 0.005}),
    (AdamW, {"lr": 0.3, "betas": (0.99, 0.99), "weight_decay": 0.005}),
]

for optimizer_class, params in optimizer_list:
    # Fresh starting point for every run, because the optimizer updates it in place.
    xy = t.tensor([2.5, 2.5], requires_grad=True)
    xys = opt_fn(pathological_curve_loss, xy=xy, optimizer_class=optimizer_class, optimizer_hyperparams=params)
    points.append((xys, optimizer_class, params))

# Hand the loss function and the recorded trajectories to the plotting helper.
plot_fn_with_points(pathological_curve_loss, points=points)