# 4. Advanced Learning Rate Optimization

## What types of hyperparameters are there?
- Learning rate
- Batch size
- Weight decay
- Dropout
- And more!

## Why the concentration on learning rate?
- It's one of the most important hyperparameters. Properly tuned, it gives 2 main benefits:
    - Better generalization (higher validation and test accuracy)
    - Faster convergence (less time spent on training)
    
## We've learnt 2 basic ways to optimize learning rate
1. Step-wise Decay 
2. Reduce on Loss Plateau Decay

## 2 new advanced learning rate optimization methods
- SGD restarts with snapshots
- SGD hypergradient descent

## 1. SGD with warm restarts (SGDR)
- **Problem**
    - Larger-capacity models (more layers, more complexity) have more local minima, and these minima differ in how well they generalize
    - Taking a snapshot of the model at different local minima and combining them gives us an ensemble of models for free (from a single training run)
- **Benefits**
    - Warm restart in optimization allows us to improve generalization
        - Allows us to escape bad local minima
        - Allows us to explore more of the loss surface
- **How**
    - Let the learning rate decay, then reset it to a high value once the model has converged to a local minimum
        - Use cosine annealing to decay the learning rate
        - $\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +\cos(\frac{T_{cur}}{T_{max}}\pi))$
            - $\eta_{min}$: Minimum learning rate
            - $\eta_{max}$: Maximum learning rate
            - $T_{max}$: Number of epochs in the current cycle (between two restarts)
            - $T_{cur}$: Number of epochs since the last restart
    - Before each reset, snapshot (save) the model's parameters via a checkpoint
    - When we've cycled through our snapshots, we can average the models' softmax outputs to obtain a final averaged categorical distribution (probability distribution over N possible outcomes)
        - Given:
            - $x$ is the input
            - $M$ is the total number of snapshots taken, so snapshot $M$ is the most recent one
            - $m$ is the number of most recent snapshots to average
            - $h_i (x)$ is the softmax output of snapshot $i$
        - Then the ensemble's output is $h_{ensemble}(x) = \frac{1}{m} \sum^{m-1}_{i=0} h_{M - i} (x)$, which is simply the average of the last $m$ snapshots, where we have to choose $m$
    - Instead of just choosing the last $m$ models, we can choose the $m$ best snapshots, which in practical experience gives better results than just the last $m$ models
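
The cosine annealing formula and the restart rule above can be sketched in plain Python, without PyTorch. This is a minimal illustration, not the scheduler used later in this notebook: the function names `cosine_annealing_lr` and `sgdr_schedule`, the `t_mult` growth factor, and all numeric values are assumptions chosen for the example.

```python
import math


def cosine_annealing_lr(t_cur, t_max, eta_max, eta_min=0.0):
    """Learning rate after t_cur epochs of a cycle that lasts t_max epochs."""
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t_cur / t_max))


def sgdr_schedule(n_epochs, t_max, eta_max, eta_min=0.0, t_mult=1):
    """Per-epoch learning rates with warm restarts.

    After each restart the cycle length is multiplied by t_mult
    (hypothetical parameter name, chosen for this sketch).
    """
    lrs = []
    t_cur, cycle_len = 0, t_max
    for _ in range(n_epochs):
        lrs.append(cosine_annealing_lr(t_cur, cycle_len, eta_max, eta_min))
        t_cur += 1
        if t_cur == cycle_len:
            # End of the cycle: warm restart back to eta_max
            t_cur = 0
            cycle_len = int(cycle_len * t_mult)
    return lrs


# First cycle lasts 3 epochs, second cycle lasts 6 epochs
schedule = sgdr_schedule(n_epochs=9, t_max=3, eta_max=0.1, eta_min=0.001, t_mult=2)
for epoch, lr in enumerate(schedule):
    print("epoch {:02d}: lr = {:.5f}".format(epoch, lr))
```

In this sketch the learning rate starts at `eta_max`, decays along the cosine curve, and jumps back to `eta_max` at epoch 3, which is where a snapshot would be saved.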

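The ensemble average over the last $m$ snapshots can be illustrated with made-up numbers. The sketch below assumes the softmax outputs of each snapshot for a single input have already been computed; the helper name `average_snapshots` and the probabilities are hypothetical.

```python
def average_snapshots(snapshot_probs, m):
    """Average the softmax outputs of the last m snapshots for one input.

    snapshot_probs: list of per-snapshot probability lists, oldest first.
    """
    if not 1 <= m <= len(snapshot_probs):
        raise ValueError("m must be between 1 and the number of snapshots")
    last_m = snapshot_probs[-m:]
    n_classes = len(last_m[0])
    return [sum(probs[c] for probs in last_m) / m for c in range(n_classes)]


# Made-up softmax outputs of 4 snapshots over 3 classes (oldest first)
snapshot_probs = [
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
    [0.1, 0.7, 0.2],
    [0.3, 0.6, 0.1],
]

ensemble = average_snapshots(snapshot_probs, m=3)
predicted_class = max(range(len(ensemble)), key=ensemble.__getitem__)
print(ensemble, predicted_class)
```

Because each snapshot's output sums to 1, the average is still a valid categorical distribution, and the predicted class is simply its argmax.
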
In [4]:
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim.lr_scheduler import _LRScheduler

In [5]:
class SGDR(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::

        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))

    When last_epoch=-1, sets initial lr as lr.    
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        T_mult (float): Increase T_max by a factor of T_mult after every restart to improve performance. Default: 1.
        model (Model): The model to save.
        save_dir (str): Directory to save snapshots. Default: '/'.
        save_model (bool): Saves the model after every restart. Default: True.
    """

    def __init__(self, optimizer, T_max, model, eta_min=0, last_epoch=-1, T_mult=1, save_dir='/', save_model=True):
        self.T_max = T_max
        self.T_mult = T_mult
        self.Te = self.T_max
        self.eta_min = eta_min
        self.current_epoch = last_epoch
        
        self.model = model
        self.save_dir = save_dir
        self.take_snapshot = save_model
        
        self.lr_history = []
        
        super(SGDR, self).__init__(optimizer, last_epoch)
    
    # Default function given by PyTorch
    def get_lr(self):
        new_lrs = [self.eta_min + (base_lr - self.eta_min) *
                (1 + math.cos(math.pi * self.current_epoch / self.Te)) / 2
                for base_lr in self.base_lrs]
        
        # Append learning rates to tracker so we can print the behavior
        self.lr_history.append(new_lrs)
        return new_lrs
    
    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        self.current_epoch += 1
        
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
        
        ## restart
        if self.current_epoch == self.Te:
            print("restart at epoch {:03d}".format(self.last_epoch + 1))
            
            if self.take_snapshot:
                torch.save({
                    'epoch': self.T_max,
                    'state_dict': self.model.state_dict()
                }, self.save_dir + "/" + 'snapshot_e_{:03d}.pth.tar'.format(self.T_max))
            
            ## reset epochs since the last reset
            self.current_epoch = 0
            
            ## reset the next goal
            self.Te = int(self.Te * self.T_mult)
            self.T_max = self.T_max + self.Te
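
For comparison, PyTorch also ships a built-in scheduler, `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts`, that applies the same cosine schedule with restarts but does not save snapshots. The sketch below assumes a PyTorch version that includes it; the tiny `nn.Linear` model and the hyperparameter values are placeholders chosen for illustration.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Placeholder model and optimizer, only used to drive the scheduler
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# First cycle lasts T_0 epochs; each later cycle is T_mult times longer
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=0.001)

lrs = []
for epoch in range(15):
    lrs.append(optimizer.param_groups[0]["lr"])
    # ... the training loop for one epoch would go here ...
    optimizer.step()
    scheduler.step()

for epoch, lr in enumerate(lrs):
    print("epoch {:02d}: lr = {:.5f}".format(epoch, lr))
```

With these values the learning rate returns to 0.1 at epoch 5, the point at which a snapshot-saving scheduler would write a checkpoint.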