# import

In [1]:
# export
from torch.optim import Optimizer

In [2]:
# export
import torch

In [3]:
# export
from functools import partial

In [4]:
# export
from collections import defaultdict

In [5]:
# export
from IPython.core import debugger as idb

# functions

## Adam ([2017](https://arxiv.org/abs/1412.6980v9))

### algorithm

$$\begin{align*}
\text{initial:}&\\
&m_0=0\\
&v_0=0\\
\\
\text{step:}&\\
&g_t\leftarrow\bigtriangledown_{\theta}\mathcal{L}(S_t,\theta_t)\\
\\
&m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot g_t\\
&v_t\leftarrow\beta_2\cdot v_{t-1}+(1-\beta_2)\cdot g_t^2\\
\\
&\hat{m}_t\leftarrow\frac{m_t}{1-\beta_1^t}\\
&\hat{v}_t\leftarrow\frac{v_t}{1-\beta_2^t}\\
\\
&\theta_{t+1}\leftarrow\theta_t-\eta\cdot\frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}
\end{align*}$$

### Adam

In [6]:
# export
class Adam(Optimizer):
    """Adam optimizer with bias-corrected first and second moment estimates.

    For every parameter ``p`` with gradient ``g``::

        m      <- beta1 * m + (1 - beta1) * g
        v      <- beta2 * v + (1 - beta2) * g^2
        update  = lr * m_hat / (sqrt(v_hat) + eps) + weight_decay * p
        p      <- p - update

    The weight-decay term is added to the update as-is; it is not multiplied
    by ``lr``.

    Per-parameter state keys: ``step``, ``exp_avg``, ``exp_avg_sq`` and
    ``update`` (the tensor most recently subtracted from the parameter).

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        betas: ``(beta1, beta2)`` decay rates of the two moving averages.
        eps: constant added to the denominator.
        weight_decay: coefficient of the weight-decay term (0 disables it).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
                 weight_decay=0):
        # create parameter groups
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)

        # eagerly allocate state for every parameter known at construction
        self.state = self._get_init_state()

    @staticmethod
    def _init_param_state(state_p, p):
        """Fill ``state_p`` with a zero step counter and zero buffers shaped like ``p``."""
        state_p['step'] = 0
        for key in ('exp_avg', 'exp_avg_sq', 'update'):
            state_p[key] = torch.zeros_like(p, memory_format=torch.preserve_format)

    def _get_init_state(self):
        """Build a fresh state mapping covering every parameter in every group."""
        state = defaultdict(dict)
        for group in self.param_groups:
            for p in group['params']:
                state_p = state[p]
                if len(state_p) == 0:
                    self._init_param_state(state_p, p)
        return state

    def _per_group(self, group, t=None):
        """Return the bias-correction terms ``(1 - beta1**t, 1 - beta2**t)``.

        When ``t`` is omitted, the legacy behaviour is used: ``t`` is the
        not-yet-incremented step count of the group's first parameter plus one.
        """
        if t is None:
            t = self.state[group['params'][0]]['step'] + 1
        beta1, beta2 = group['betas']
        return 1 - beta1**t, 1 - beta2**t

    def update_state(self, group, state, grad):
        """Advance the step counter and both moving averages in place.

        Returns:
            ``(exp_avg, exp_avg_sq)`` — the updated state tensors themselves.
        """
        state['step'] += 1

        beta1, beta2 = group['betas']
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1-beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1-beta2)

        return state['exp_avg'], state['exp_avg_sq']

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]
                if len(state) == 0:
                    # parameter was added after construction (add_param_group)
                    self._init_param_state(state, p)
                m, v = self.update_state(group, state, grad)

                # Debias with this parameter's own step count, so a parameter
                # that skipped steps (grad is None) cannot skew the others.
                bias1, bias2 = self._per_group(group, state['step'])
                m_hat = m / bias1
                v_hat = v / bias2

                # update value
                update = group['lr']*m_hat/(v_hat.sqrt().add(group['eps']))
                if group['weight_decay'] != 0:
                    update += group['weight_decay']*p.data

                # do update
                p.sub_(update)

                # remember the applied update
                state['update'] = update

        return loss

## RAdam ([2019](https://arxiv.org/abs/1908.03265))

### algorithm

$$\begin{align*}
\text{initial:}&\\
&\rho_{\infty}\leftarrow\frac{2}{1-\beta_2}-1\\
&m_0\leftarrow0\\
&v_0\leftarrow0\\
\\
\text{step:}&\\
&g_t\leftarrow\bigtriangledown_{\theta}\mathcal{L}(S_t,\theta_t)\\
\\
&m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot g_t\\
&v_t\leftarrow\beta_2\cdot v_{t-1}+(1-\beta_2)\cdot g_t^2\\
\\
&\hat{m}_t\leftarrow\frac{m_t}{1-\beta_1^t}\\
\\
&\rho_t\leftarrow\rho_{\infty}-\frac{2t\cdot\beta_2^t}{1-\beta_2^t}\\
&\text{if}\space\space\rho_t>4\space\space\text{:}\\
&\quad\quad\hat{v}_t\leftarrow\frac{v_t}{1-\beta_2^t}\\
&\quad\quad r_t\leftarrow\frac{(\rho_{\infty}-4)(\rho_{\infty}-2)\rho_t}{(\rho_t-4)(\rho_t-2)\rho_{\infty}}\\
&\quad\quad\theta_{t+1}\leftarrow\theta_t-\eta\cdot\frac{\hat{m}_t}{\sqrt{r_t\cdot\hat{v}_t}+\epsilon}\\
&\text{else:}\\
&\quad\quad\theta_{t+1}\leftarrow\theta_t-\eta\cdot\hat{m}_t
\end{align*}$$

### RAdam

In [7]:
# export
from matplotlib import pyplot as plt
import numpy as np
def plot_RAdam_rectifier(beta,length):
    '''
    Visualise how RAdam's rectifier term evolves over the iterations.

    The rectifier depends only on the squared-gradient momentum parameter
    (usually called beta2 in optimizers) and on the iteration index t.
    Two panels are drawn: the rectifier r_t on the left (-1 marks iterations
    where rho_t <= 4, i.e. no rectification is applied), and rho_t together
    with its limit rho_inf on the right.
    ------------------------------------
    Parameters:
    -- beta: squared-gradient momentum parameter (beta2).
    -- length: number of leading iterations to plot (t = 1 .. length).
    '''
    rho_inf = 2/(1-beta)-1

    def _cal_r(t):
        rho = rho_inf-(2*t*beta**t)/(1-beta**t)
        if rho>4: r = (rho_inf-4)*(rho_inf-2)*rho/((rho-4)*(rho-2)*(rho_inf))
        else: r = -1
        return rho,r

    # t starts at 1 (t = 0 would divide by zero) and includes `length` itself,
    # so exactly `length` iterations are shown.
    steps = list(range(1,length+1))
    rhos = []
    rs = []
    for t in steps:
        rho,r = _cal_r(t)
        rhos.append(rho)
        rs.append(r)

    _,axs = plt.subplots(1,2,figsize=(10,5))
    # plot against the real iteration index so the x axis reads as t
    axs[0].plot(steps,rs,marker='o')
    axs[1].plot(steps,rhos)
    axs[1].plot(steps,np.zeros(len(steps))+rho_inf)

In [8]:
# export
class RAdam(Optimizer):
    """Rectified Adam.

    While the approximated length ``rho_t`` of the moving average is at most 4
    the step is plain bias-corrected momentum, ``lr * m_hat``. Afterwards the
    second moment is rescaled by the rectifier ``r_t``::

        update = lr * m_hat / (sqrt(r_t * v_hat) + eps)

    ``weight_decay * p`` is then added to the update (not multiplied by
    ``lr``), the same way the other optimizers in this file apply it.

    Per-parameter state keys: ``step``, ``exp_avg``, ``exp_avg_sq``, ``update``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        betas: ``(beta1, beta2)`` decay rates of the two moving averages.
        eps: constant added to the denominator of the rectified step.
        weight_decay: coefficient of the weight-decay term (0 disables it).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
                 weight_decay=0):
        # create parameter groups
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super(RAdam, self).__init__(params, defaults)

        # eagerly allocate state for every parameter known at construction
        self.state = self._get_init_state()

        # rho_inf of each group at construction time. Kept as an attribute for
        # callers that read it; step() recomputes the value from the group's
        # current betas so that groups added later are handled as well.
        self.rho_infs = [2/(1-group['betas'][1])-1 for group in self.param_groups]

    @staticmethod
    def _init_param_state(state_p, p):
        """Fill ``state_p`` with a zero step counter and zero buffers shaped like ``p``."""
        state_p['step'] = 0
        for key in ('exp_avg', 'exp_avg_sq', 'update'):
            state_p[key] = torch.zeros_like(p, memory_format=torch.preserve_format)

    def _get_init_state(self):
        """Build a fresh state mapping covering every parameter in every group."""
        state = defaultdict(dict)
        for group in self.param_groups:
            for p in group['params']:
                state_p = state[p]
                if len(state_p) == 0:
                    self._init_param_state(state_p, p)
        return state

    def _per_group(self, group, rho_inf, t=None):
        """Return ``(bias1, bias2, r)`` for iteration ``t``.

        ``r`` is the rectifier, or ``-1`` when ``rho_t <= 4`` (no rectification).
        When ``t`` is omitted, the legacy behaviour is used: ``t`` is the
        not-yet-incremented step count of the group's first parameter plus one.
        """
        if t is None:
            t = self.state[group['params'][0]]['step'] + 1
        beta1, beta2 = group['betas']
        bias1 = 1-beta1**t
        bias2 = 1-beta2**t

        rho = rho_inf - (2*t*beta2**t)/(1-beta2**t)
        if rho>4:
            r = (rho_inf-4)*(rho_inf-2)*rho/((rho-4)*(rho-2)*rho_inf)
        else:
            r = -1

        return bias1,bias2,r

    def _update_state(self, group, state, grad):
        """Advance the step counter and both moving averages in place.

        Returns:
            ``(exp_avg, exp_avg_sq)`` — the updated state tensors themselves.
        """
        state['step'] += 1

        beta1, beta2 = group['betas']
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1-beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1-beta2)

        return state['exp_avg'], state['exp_avg_sq']

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            rho_inf = 2/(1-group['betas'][1])-1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]
                if len(state) == 0:
                    # parameter was added after construction (add_param_group)
                    self._init_param_state(state, p)
                m, v = self._update_state(group, state, grad)

                # use this parameter's own step count for debiasing/rectifying
                bias1, bias2, r = self._per_group(group, rho_inf, state['step'])
                m_hat = m / bias1

                # update value
                if r > 0:
                    v_hat = v / bias2
                    update = group['lr']*m_hat/(v_hat.mul(r).sqrt().add(group['eps']))
                else:
                    update = group['lr']*m_hat

                # Weight decay is applied exactly once, on the update. It used
                # to be folded into the gradient as well, decaying twice.
                if group['weight_decay'] != 0:
                    update += group['weight_decay']*p.data

                # do update
                p.sub_(update)

                # remember the applied update
                state['update'] = update

        return loss

## SAdam ([2019](https://arxiv.org/abs/1908.00700))

### algorithm

$$\begin{align*}
\text{initial:}&\\
&m_0=0\\
&v_0=0\\
\\
\text{step:}&\\
&g_t\leftarrow\bigtriangledown_{\theta}\mathcal{L}(S_t,\theta_t)\\
\\
&m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot g_t\\
&v_t\leftarrow\beta_2\cdot v_{t-1}+(1-\beta_2)\cdot g_t^2\\
\\
&\hat{m}_t\leftarrow\frac{m_t}{1-\beta_1^t}\\
&\hat{v}_t\leftarrow\frac{v_t}{1-\beta_2^t}\\
\\
&\theta_{t+1}\leftarrow\theta_t-\eta\cdot\frac{\hat{m}_t}{softplus(\sqrt{\hat{v}_t})}
\end{align*}$$

### SAdam

In [9]:
# export
class SAdam(Optimizer):
    """Adam whose denominator is ``softplus(sqrt(v_hat))`` instead of ``sqrt(v_hat) + eps``.

    For every parameter ``p`` with gradient ``g``::

        update = lr * m_hat / softplus(sqrt(v_hat)) + weight_decay * p
        p     <- p - update

    The weight-decay term is not multiplied by ``lr``.

    Per-parameter state keys: ``step``, ``exp_avg``, ``exp_avg_sq``, ``update``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        betas: ``(beta1, beta2)`` decay rates of the two moving averages.
        smooth: passed to ``torch.nn.Softplus`` as its ``beta`` argument.
        weight_decay: coefficient of the weight-decay term (0 disables it).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), smooth=50,
                 weight_decay=0):
        # create parameter groups; there is no eps hyper-parameter here because
        # the softplus keeps the denominator away from zero
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

        # eagerly allocate state for every parameter known at construction
        self.state = self._get_init_state()

        # softplus
        self.sp = torch.nn.Softplus(smooth)

    @staticmethod
    def _init_param_state(state_p, p):
        """Fill ``state_p`` with a zero step counter and zero buffers shaped like ``p``."""
        state_p['step'] = 0
        for key in ('exp_avg', 'exp_avg_sq', 'update'):
            state_p[key] = torch.zeros_like(p, memory_format=torch.preserve_format)

    def _get_init_state(self):
        """Build a fresh state mapping covering every parameter in every group."""
        state = defaultdict(dict)
        for group in self.param_groups:
            for p in group['params']:
                state_p = state[p]
                if len(state_p) == 0:
                    self._init_param_state(state_p, p)
        return state

    def _per_group(self, group, t=None):
        """Return the bias-correction terms ``(1 - beta1**t, 1 - beta2**t)``.

        When ``t`` is omitted, the legacy behaviour is used: ``t`` is the
        not-yet-incremented step count of the group's first parameter plus one.
        """
        if t is None:
            t = self.state[group['params'][0]]['step'] + 1
        beta1, beta2 = group['betas']
        return 1 - beta1**t, 1 - beta2**t

    def update_state(self, group, state, grad):
        """Advance the step counter and both moving averages in place.

        Returns:
            ``(exp_avg, exp_avg_sq)`` — the updated state tensors themselves.
        """
        state['step'] += 1

        beta1, beta2 = group['betas']
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1-beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1-beta2)

        return state['exp_avg'], state['exp_avg_sq']

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]
                if len(state) == 0:
                    # parameter was added after construction (add_param_group)
                    self._init_param_state(state, p)
                m, v = self.update_state(group, state, grad)

                # debias with this parameter's own step count
                bias1, bias2 = self._per_group(group, state['step'])
                m_hat = m / bias1
                v_hat = v / bias2

                # update value
                update = group['lr']*m_hat/(self.sp(v_hat.sqrt()))
                if group['weight_decay'] != 0:
                    update += group['weight_decay']*p.data

                # do update
                p.sub_(update)

                # remember the applied update
                state['update'] = update

        return loss

In [10]:
# export
class SAdam(Optimizer):
    """RAdam-style rectified step with a softplus denominator.

    While ``rho_t <= 4`` the step is ``lr * m_hat``; afterwards it is
    ``lr * m_hat / softplus(sqrt(r_t * v_hat))``. ``weight_decay * p`` is then
    added to the update (not multiplied by ``lr``).

    NOTE(review): this file defines another class named ``SAdam`` earlier;
    when both are in one module, this later definition is the one bound to the
    name. The body matches the SRAdam algorithm — confirm the intended name.

    Per-parameter state keys: ``step``, ``exp_avg``, ``exp_avg_sq``, ``update``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        betas: ``(beta1, beta2)`` decay rates of the two moving averages.
        weight_decay: coefficient of the weight-decay term (0 disables it).
        smooth: passed to ``torch.nn.Softplus`` as its ``beta`` argument.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0,
                 smooth=50):
        # create parameter groups
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

        # eagerly allocate state for every parameter known at construction
        self.state = self._get_init_state()

        # rho_inf = 2/(1-beta2) - 1 of each group at construction time. Kept as
        # an attribute; step() recomputes it from the group's current betas.
        self.rho_infs = [2/(1-group['betas'][1])-1 for group in self.param_groups]

        # softplus
        self.sp = torch.nn.Softplus(smooth)

    @staticmethod
    def _init_param_state(state_p, p):
        """Fill ``state_p`` with a zero step counter and zero buffers shaped like ``p``."""
        state_p['step'] = 0
        for key in ('exp_avg', 'exp_avg_sq', 'update'):
            state_p[key] = torch.zeros_like(p, memory_format=torch.preserve_format)

    def _get_init_state(self):
        """Build a fresh state mapping covering every parameter in every group."""
        state = defaultdict(dict)
        for group in self.param_groups:
            for p in group['params']:
                state_p = state[p]
                if len(state_p) == 0:
                    self._init_param_state(state_p, p)
        return state

    def _per_group(self, group, rho_inf, t=None):
        """Return ``(bias1, bias2, r)`` for iteration ``t``.

        ``r`` is the rectifier, or ``-1`` when ``rho_t <= 4`` (no rectification).
        When ``t`` is omitted, the legacy behaviour is used: ``t`` is the
        not-yet-incremented step count of the group's first parameter plus one.
        """
        if t is None:
            t = self.state[group['params'][0]]['step'] + 1
        beta1, beta2 = group['betas']
        bias1 = 1-beta1**t
        bias2 = 1-beta2**t

        rho = rho_inf - (2*t*beta2**t)/(1-beta2**t)
        if rho>4:
            r = (rho_inf-4)*(rho_inf-2)*rho/((rho-4)*(rho-2)*rho_inf)
        else:
            r = -1

        return bias1,bias2,r

    def update_state(self, group, state, grad):
        """Advance the step counter and both moving averages in place.

        Returns:
            ``(exp_avg, exp_avg_sq)`` — the updated state tensors themselves.
        """
        state['step'] += 1

        beta1, beta2 = group['betas']
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1-beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1-beta2)

        return state['exp_avg'], state['exp_avg_sq']

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            rho_inf = 2/(1-group['betas'][1])-1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]
                if len(state) == 0:
                    # parameter was added after construction (add_param_group)
                    self._init_param_state(state, p)
                m, v = self.update_state(group, state, grad)

                # use this parameter's own step count for debiasing/rectifying
                bias1, bias2, r = self._per_group(group, rho_inf, state['step'])
                m_hat = m / bias1

                # update value
                if r > 0:
                    v_hat = v / bias2
                    update = group['lr']*m_hat/self.sp(v_hat.mul(r).sqrt())
                else:
                    update = group['lr']*m_hat
                if group['weight_decay'] != 0:
                    update += group['weight_decay']*p.data

                # do update
                p.sub_(update)

                # remember the applied update
                state['update'] = update

        return loss

## LAMB ([2020](https://arxiv.org/abs/1904.00962v5))

### algorithm

$$\begin{align*}
\text{initial:}&\\
&m_0=0\\
&v_0=0\\
\\
\text{step:}&\\
&g_t\leftarrow\bigtriangledown_{\theta}\mathcal{L}(S_t,\theta_t)\\
\\
&m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot g_t\\
&v_t\leftarrow\beta_2\cdot v_{t-1}+(1-\beta_2)\cdot g_t^2\\
\\
&\hat{m}_t\leftarrow\frac{m_t}{1-\beta_1^t}\\
&\hat{v}_t\leftarrow\frac{v_t}{1-\beta_2^t}\\
\\
&A_t\leftarrow\frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}+\lambda\cdot\theta_t\quad\quad\text{# adam step}\\
\\
&s_t^l\leftarrow\eta\cdot\lVert\theta_t^l\rVert_2\cdot\frac{A_t}{\lVert A_t\rVert_2}\\
&\theta_{t+1}^l\leftarrow\theta_t^l-s_t^l
\end{align*}$$

### LAMB

In [11]:
# export
class LAMB(Optimizer):
    """LAMB: an Adam step rescaled per parameter tensor by a trust ratio.

    For every parameter tensor ``p``::

        adam_step = m_hat / (sqrt(v_hat) + eps) + weight_decay * p
        trust     = rms(p) / rms(adam_step)
        p        <- p - lr * trust * adam_step

    Both scales are root-mean-squares over the same number of elements, so
    their ratio equals the ratio of the L2 norms. When either scale is zero
    (for example a zero-initialised tensor, or an all-zero Adam step) the
    trust ratio falls back to 1 so the parameter can still move and no NaN is
    produced.

    Per-parameter state keys: ``step``, ``exp_avg``, ``exp_avg_sq``, ``update``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        betas: ``(beta1, beta2)`` decay rates of the two moving averages.
        eps: constant added to the denominator of the Adam step.
        weight_decay: coefficient of the weight-decay term (0 disables it).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        # create parameter groups
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)

        # eagerly allocate state for every parameter known at construction
        self.state = self._get_init_state()

    @staticmethod
    def _init_param_state(state_p, p):
        """Fill ``state_p`` with a zero step counter and zero buffers shaped like ``p``."""
        state_p['step'] = 0
        for key in ('exp_avg', 'exp_avg_sq', 'update'):
            state_p[key] = torch.zeros_like(p, memory_format=torch.preserve_format)

    def _get_init_state(self):
        """Build a fresh state mapping covering every parameter in every group."""
        state = defaultdict(dict)
        for group in self.param_groups:
            for p in group['params']:
                state_p = state[p]
                if len(state_p) == 0:
                    self._init_param_state(state_p, p)
        return state

    def _per_group(self, group, t=None):
        """Return the bias-correction terms ``(1 - beta1**t, 1 - beta2**t)``.

        When ``t`` is omitted, the legacy behaviour is used: ``t`` is the
        not-yet-incremented step count of the group's first parameter plus one.
        """
        if t is None:
            t = self.state[group['params'][0]]['step'] + 1
        beta1, beta2 = group['betas']
        return 1 - beta1**t, 1 - beta2**t

    def update_state(self, group, state, grad):
        """Advance the step counter and both moving averages in place.

        Returns:
            ``(exp_avg, exp_avg_sq)`` — the updated state tensors themselves.
        """
        state['step'] += 1

        beta1, beta2 = group['betas']
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1-beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1-beta2)

        return state['exp_avg'], state['exp_avg_sq']

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]
                if len(state) == 0:
                    # parameter was added after construction (add_param_group)
                    self._init_param_state(state, p)
                m, v = self.update_state(group, state, grad)

                # debias with this parameter's own step count
                bias1, bias2 = self._per_group(group, state['step'])
                m_hat = m / bias1
                v_hat = v / bias2

                # adam step
                adam_step = m_hat/(v_hat.sqrt().add(group['eps']))
                if group['weight_decay'] != 0:
                    adam_step += group['weight_decay']*p.data

                # scales of the step and of the parameter tensor
                adam_step_scale = adam_step.pow(2).mean().sqrt()
                layer_scale = p.data.pow(2).mean().sqrt()

                # A zero scale on either side gives a ratio of 0 (parameter
                # frozen forever) or NaN; use a neutral ratio of 1 instead.
                usable = (layer_scale > 0) & (adam_step_scale > 0)
                trust_ratio = torch.where(usable,
                                          layer_scale/adam_step_scale,
                                          torch.ones_like(layer_scale))

                # update value
                update = group['lr']*trust_ratio*adam_step

                # do update
                p.sub_(update)

                # remember the applied update
                state['update'] = update

        return loss

## Adam_GCC ([2020](https://arxiv.org/abs/2004.01461))

GCC: Gradient Centralization for Conv weights

### algorithm

$$\begin{align*}
\text{initial:}&\\
&m_0=0\\
&v_0=0\\
\\
\text{step:}&\\
&g_t\leftarrow\bigtriangledown_{\theta}\mathcal{L}(S_t,\theta_t)\\
\\
&\text{if }\theta\text{ is conv.weight:}\\
&\quad\quad\tilde{g}_t\leftarrow g_t-E_{chout}[g_t]\quad\quad\text{# }E_{chout}\text{ means average on chout}\\
\\
&m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot\tilde{g}_t\\
&v_t\leftarrow\beta_2\cdot v_{t-1}+(1-\beta_2)\cdot\tilde{g}_t^2\\
\\
&\hat{m}_t\leftarrow\frac{m_t}{1-\beta_1^t}\\
&\hat{v}_t\leftarrow\frac{v_t}{1-\beta_2^t}\\
\\
&\theta_{t+1}\leftarrow\theta_t-\eta\cdot\frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}
\end{align*}$$

### Adam_GCC

In [12]:
# export
class Adam_GCC(Optimizer):
    """Adam with gradient centralization applied to 4-D (conv weight) gradients.

    Before the moment update, the gradient of every 4-D parameter has its
    mean over dims 1, 2 and 3 subtracted, separately for each index along
    dim 0. The subtraction is done in place, so ``p.grad`` itself holds the
    centralized gradient afterwards. Everything else follows Adam::

        update = lr * m_hat / (sqrt(v_hat) + eps) + weight_decay * p
        p     <- p - update

    The weight-decay term is not multiplied by ``lr``.

    Per-parameter state keys: ``step``, ``exp_avg``, ``exp_avg_sq``, ``update``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        betas: ``(beta1, beta2)`` decay rates of the two moving averages.
        eps: constant added to the denominator.
        weight_decay: coefficient of the weight-decay term (0 disables it).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
                 weight_decay=0):
        # create parameter groups
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)

        # eagerly allocate state for every parameter known at construction
        self.state = self._get_init_state()

    @staticmethod
    def _init_param_state(state_p, p):
        """Fill ``state_p`` with a zero step counter and zero buffers shaped like ``p``."""
        state_p['step'] = 0
        for key in ('exp_avg', 'exp_avg_sq', 'update'):
            state_p[key] = torch.zeros_like(p, memory_format=torch.preserve_format)

    def _get_init_state(self):
        """Build a fresh state mapping covering every parameter in every group."""
        state = defaultdict(dict)
        for group in self.param_groups:
            for p in group['params']:
                state_p = state[p]
                if len(state_p) == 0:
                    self._init_param_state(state_p, p)
        return state

    def _per_group(self, group, t=None):
        """Return the bias-correction terms ``(1 - beta1**t, 1 - beta2**t)``.

        When ``t`` is omitted, the legacy behaviour is used: ``t`` is the
        not-yet-incremented step count of the group's first parameter plus one.
        """
        if t is None:
            t = self.state[group['params'][0]]['step'] + 1
        beta1, beta2 = group['betas']
        return 1 - beta1**t, 1 - beta2**t

    def update_state(self, group, state, grad):
        """Advance the step counter and both moving averages in place.

        Returns:
            ``(exp_avg, exp_avg_sq)`` — the updated state tensors themselves.
        """
        state['step'] += 1

        beta1, beta2 = group['betas']
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1-beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1-beta2)

        return state['exp_avg'], state['exp_avg_sq']

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                # GCC: conv.weight.shape is (ch_out, ch_in, k1, k2); remove the
                # mean of each output channel's gradient (in place).
                if grad.dim() == 4:
                    grad.sub_(grad.mean(dim=[1,2,3], keepdim=True))

                state = self.state[p]
                if len(state) == 0:
                    # parameter was added after construction (add_param_group)
                    self._init_param_state(state, p)
                m, v = self.update_state(group, state, grad)

                # debias with this parameter's own step count
                bias1, bias2 = self._per_group(group, state['step'])
                m_hat = m / bias1
                v_hat = v / bias2

                # update value
                update = group['lr']*m_hat/(v_hat.sqrt().add(group['eps']))
                if group['weight_decay'] != 0:
                    update += group['weight_decay']*p.data

                # do update
                p.sub_(update)

                # remember the applied update
                state['update'] = update

        return loss

## Stable_Adam (xty)

想法：
- 训练开始时记录各（后跟bn的）层参数的scale，记M0；
- 在之后的各次迭代中，以当时的scale（记Mt）与M0的反比来调整step；
- stable的含义：原Adam中，后跟bn的卷积层参数的整体尺度会逐渐变大，导致有效学习率反比下降，该方法使有效学习率保持稳定，所以称stable.
- 注意：这里求scale并不是对一个weight内的所有元素作为整体求一个scale，而是按chout，一个chout内的元素作为整体求一个scale，总共求chout个scale. 
- 为什么按chout而不是按layer？对于activation，一个chout是一个特征图，一个特征图内的像素是在不同位置处的特征，但它们是同一个特征；而不同的特征图是不同的特征，不同的特征糅合在一起是没有意义的。在activation中如此，在weight中也是如此。

### algorithm

$$\begin{align*}
\text{initial:}&\\
&m_0=0\\
&v_0=0\\
\\
&\Theta=\{\text{all}\;\theta\;\text{that is followed by a batchnorm}\}\\
\\
\text{step:}&\\
&g_t\leftarrow\bigtriangledown_{\theta}\mathcal{L}(S_t,\theta_t)\\
\\
&m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot g_t\\
&v_t\leftarrow\beta_2\cdot v_{t-1}+(1-\beta_2)\cdot g_t^2\\
\\
&\hat{m}_t\leftarrow\frac{m_t}{1-\beta_1^t}\\
&\hat{v}_t\leftarrow\frac{v_t}{1-\beta_2^t}\\
\\
&s_t\leftarrow\frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}\\
\\
&\text{if}\;\theta\in\Theta\;:\\
&\quad\quad\text{for}\;\text{k}\;\text{in}\;range(chout)\;:\\
&\quad\quad\quad\quad r_t^k\leftarrow\frac{\lVert\theta_t^k\rVert_2}{\lVert\theta_0^k\rVert_2}\\
&\quad\quad\quad\quad\theta_{t+1}^k\leftarrow\theta_t^k-\eta\cdot r_t^k\cdot s_t\\
&\text{else}\;:\\
&\quad\quad\theta_{t+1}\leftarrow\theta_t-\eta\cdot s_t\\
\end{align*}$$

### Stable_Adam

In [13]:
# export
def conv_not_cinEq256_coutLt256(p):
    """Tell whether ``p`` is a conv-style weight outside the excluded shape family.

    ``p`` only needs a ``shape`` attribute. A weight qualifies when it has at
    least four dimensions, except when dim 1 (input channels) is exactly 256
    while dim 0 (output channels) is below 256.

    Returns:
        ``True`` if ``p`` qualifies, ``False`` otherwise.
    """
    shape = p.shape

    # fewer than four dims: not treated as a conv weight
    if len(shape) < 4:
        return False

    # excluded family: 256 input channels feeding fewer than 256 outputs
    excluded = shape[1] == 256 and shape[0] < 256
    return not excluded
    
class Stable_Adam(Optimizer):
    """Adam whose step on selected weights is scaled by their magnitude growth.

    For every parameter accepted by ``condition_func`` the RMS over dims 1, 2
    and 3 is recorded per index of dim 0 when its state is created
    (``init_mag``). At each step the Adam update of such a parameter is
    multiplied by ``r = current_rms / init_mag``, so the step grows with the
    weight scale. Other parameters use ``r = 1``. A slice whose initial RMS
    is exactly zero has no reference scale and is left unscaled (``r = 1``).

    ``weight_decay * p`` is added to the update and is not multiplied by
    ``lr``.

    Per-parameter state keys: ``step``, ``exp_avg``, ``exp_avg_sq``,
    ``update`` and, for selected parameters, ``init_mag``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        betas: ``(beta1, beta2)`` decay rates of the two moving averages.
        eps: constant added to the denominator.
        weight_decay: coefficient of the weight-decay term (0 disables it).
        condition_func: callable taking the parameter data and returning
            whether the magnitude scaling applies to it.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
                 weight_decay=0, condition_func=conv_not_cinEq256_coutLt256):
        # create parameter groups
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super().__init__(params, defaults)

        self.cond_func = condition_func

        # eagerly allocate state for every parameter known at construction
        self.state = self._get_init_state()

    @staticmethod
    def _channel_rms(w):
        """RMS of ``w`` over dims 1, 2, 3, kept as broadcastable singleton dims."""
        return w.detach().pow(2).mean(dim=[1,2,3], keepdim=True).sqrt()

    def _init_param_state(self, state_p, p):
        """Fill ``state_p`` with zero buffers and, if selected, the reference magnitude."""
        state_p['step'] = 0
        for key in ('exp_avg', 'exp_avg_sq', 'update'):
            state_p[key] = torch.zeros_like(p, memory_format=torch.preserve_format)

        if self.cond_func(p.data):
            state_p['init_mag'] = self._channel_rms(p.data)

    def _get_init_state(self):
        """Build a fresh state mapping covering every parameter in every group."""
        state = defaultdict(dict)
        for group in self.param_groups:
            for p in group['params']:
                state_p = state[p]
                if len(state_p) == 0:
                    self._init_param_state(state_p, p)
        return state

    def _per_group(self, group, t=None):
        """Return the bias-correction terms ``(1 - beta1**t, 1 - beta2**t)``.

        When ``t`` is omitted, the legacy behaviour is used: ``t`` is the
        not-yet-incremented step count of the group's first parameter plus one.
        """
        if t is None:
            t = self.state[group['params'][0]]['step'] + 1
        beta1, beta2 = group['betas']
        return 1 - beta1**t, 1 - beta2**t

    def update_state(self, group, state, grad):
        """Advance the step counter and both moving averages in place.

        Returns:
            ``(exp_avg, exp_avg_sq)`` — the updated state tensors themselves.
        """
        state['step'] += 1

        beta1, beta2 = group['betas']
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1-beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1-beta2)

        return state['exp_avg'], state['exp_avg_sq']

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]
                if len(state) == 0:
                    # Parameter was added after construction; its reference
                    # magnitude is taken from its value at this first step.
                    self._init_param_state(state, p)
                m, v = self.update_state(group, state, grad)

                # debias with this parameter's own step count
                bias1, bias2 = self._per_group(group, state['step'])
                m_hat = m / bias1
                v_hat = v / bias2

                # magnitude change relative to the recorded reference
                if 'init_mag' in state:
                    init_mag = state['init_mag']
                    ratio = self._channel_rms(p.data)/init_mag
                    # a zero reference would give inf/NaN; leave it unscaled
                    r = torch.where(init_mag > 0, ratio, torch.ones_like(ratio))
                else:
                    r = 1.

                # update value
                update = group['lr']*r*m_hat/(v_hat.sqrt().add(group['eps']))
                if group['weight_decay'] != 0:
                    update += group['weight_decay']*p.data

                # do update
                p.sub_(update)

                # remember the applied update
                state['update'] = update

        return loss

## Adam_AWD (xty)

Adam with Adaptive Weight Decay
    - 实验发现，使用Stable_Adam优化器时，weight scale增长的更快，这不是一个好现象
    - 在文章里(https://blog.janestreet.com/l2-regularization-and-batch-norm/)写道，使用batchnorm而无weight decay时，weight scale会随时间有变大的趋势，而设置合适的weight decay可以抑制这个趋势，使weight scale保持稳定。
    - 这个weight decay应该设多大恰好可以抵消这个趋势？不好把握。所以是否可以用“自适应weight decay”
    - 自适应wd：
        - 应用于后跟bn的卷积层，
        - 目标是使其weight scale保持稳定，
        - 方法：记录初始scale，在之后的时刻，完成更新后，将weight归一化到初始scale，即乘以 init_scale/current_scale

### algorithm

$$\begin{align*}
\text{initial:}&\\
&m_0=0\\
&v_0=0\\
\\
&\Theta=\{\text{all}\;\theta\;\text{that is followed by a batchnorm}\}\\
\\
\text{step:}&\\
&g_t\leftarrow\bigtriangledown_{\theta}\mathcal{L}(S_t,\theta_t)\\
\\
&m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot g_t\\
&v_t\leftarrow\beta_2\cdot v_{t-1}+(1-\beta_2)\cdot g_t^2\\
\\
&\hat{m}_t\leftarrow\frac{m_t}{1-\beta_1^t}\\
&\hat{v}_t\leftarrow\frac{v_t}{1-\beta_2^t}\\
\\
&\theta_{t+1}\leftarrow\theta_t-\eta\cdot\frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}\\
\\
&\text{if}\;\theta\in\Theta\;\text{:}\\
&\quad\quad\text{for}\;\text{k}\;\text{in}\;range(chout)\;:\\
&\quad\quad\quad\quad\theta_{t+1}^k\leftarrow\theta_{t+1}^k\cdot\frac{\lVert\theta_0^k\rVert_2}{\lVert\theta_{t+1}^k\rVert_2}\\
\end{align*}$$

### Adam_AWD

In [14]:
# export
class Adam_AWD(Optimizer):
    """Adam followed by a renormalisation of selected weights ("adaptive weight decay").

    For every parameter accepted by ``condition_func`` the RMS over dims 1, 2
    and 3 is recorded per index of dim 0 when its state is created
    (``init_mag``). After each Adam update such a parameter is rescaled so
    that every slice along dim 0 has its recorded RMS again. A slice whose
    current RMS is exactly zero cannot be rescaled and is left untouched.

    Per-parameter state keys: ``step``, ``exp_avg``, ``exp_avg_sq``,
    ``update`` (the Adam update, before renormalisation) and, for selected
    parameters, ``init_mag``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        betas: ``(beta1, beta2)`` decay rates of the two moving averages.
        eps: constant added to the denominator.
        condition_func: callable taking the parameter data and returning
            whether the renormalisation applies to it.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
                 condition_func=conv_not_cinEq256_coutLt256):
        # create parameter groups
        defaults = dict(lr=lr, betas=betas, eps=eps)
        super().__init__(params, defaults)

        self.cond_func = condition_func

        # eagerly allocate state for every parameter known at construction
        self.state = self._get_init_state()

    @staticmethod
    def _channel_rms(w):
        """RMS of ``w`` over dims 1, 2, 3, kept as broadcastable singleton dims."""
        return w.detach().pow(2).mean(dim=[1,2,3], keepdim=True).sqrt()

    def _init_param_state(self, state_p, p):
        """Fill ``state_p`` with zero buffers and, if selected, the reference magnitude."""
        state_p['step'] = 0
        for key in ('exp_avg', 'exp_avg_sq', 'update'):
            state_p[key] = torch.zeros_like(p, memory_format=torch.preserve_format)

        if self.cond_func(p.data):
            state_p['init_mag'] = self._channel_rms(p.data)

    def _get_init_state(self):
        """Build a fresh state mapping covering every parameter in every group."""
        state = defaultdict(dict)
        for group in self.param_groups:
            for p in group['params']:
                state_p = state[p]
                if len(state_p) == 0:
                    self._init_param_state(state_p, p)
        return state

    def _per_group(self, group, t=None):
        """Return the bias-correction terms ``(1 - beta1**t, 1 - beta2**t)``.

        When ``t`` is omitted, the legacy behaviour is used: ``t`` is the
        not-yet-incremented step count of the group's first parameter plus one.
        """
        if t is None:
            t = self.state[group['params'][0]]['step'] + 1
        beta1, beta2 = group['betas']
        return 1 - beta1**t, 1 - beta2**t

    def update_state(self, group, state, grad):
        """Advance the step counter and both moving averages in place.

        Returns:
            ``(exp_avg, exp_avg_sq)`` — the updated state tensors themselves.
        """
        state['step'] += 1

        beta1, beta2 = group['betas']
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1-beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1-beta2)

        return state['exp_avg'], state['exp_avg_sq']

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]
                if len(state) == 0:
                    # Parameter was added after construction; its reference
                    # magnitude is taken from its value at this first step.
                    self._init_param_state(state, p)
                m, v = self.update_state(group, state, grad)

                # debias with this parameter's own step count
                bias1, bias2 = self._per_group(group, state['step'])
                m_hat = m / bias1
                v_hat = v / bias2

                # update value
                update = group['lr']*m_hat/(v_hat.sqrt().add(group['eps']))

                # do update
                p.sub_(update)

                # adaptive weight decay: restore the recorded per-slice RMS
                if 'init_mag' in state:
                    current_mag = self._channel_rms(p.data)
                    # a zero current RMS would divide by zero; skip that slice
                    scale = torch.where(current_mag > 0,
                                        state['init_mag']/current_mag,
                                        torch.ones_like(current_mag))
                    p.mul_(scale)

                # remember the applied Adam update
                state['update'] = update

        return loss

## SRAdam (xty)

### algorithm

$$\begin{align*}
\text{initial:}&\\
&\rho_{\infty}\leftarrow\frac{2}{1-\beta_2}-1\\
&m_0\leftarrow0\\
&v_0\leftarrow0\\
\\
\text{step:}&\\
&g_t\leftarrow\bigtriangledown_{\theta}\mathcal{L}(S_t,\theta_t)\\
\\
&m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot g_t\\
&v_t\leftarrow\beta_2\cdot v_{t-1}+(1-\beta_2)\cdot g_t^2\\
\\
&\hat{m}_t\leftarrow\frac{m_t}{1-\beta_1^t}\\
\\
&\rho_t\leftarrow\rho_{\infty}-\frac{2t\cdot\beta_2^t}{1-\beta_2^t}\\
&\text{if}\space\space\rho_t>4\space\space\text{:}\\
&\quad\quad\hat{v}_t\leftarrow\frac{v_t}{1-\beta_2^t}\\
&\quad\quad r_t\leftarrow\frac{(\rho_{\infty}-4)(\rho_{\infty}-2)\rho_t}{(\rho_t-4)(\rho_t-2)\rho_{\infty}}\\
&\quad\quad\theta_{t+1}\leftarrow\theta_t-\eta\cdot\frac{\hat{m}_t}{softplus\left(\sqrt{r_t\cdot\hat{v}_t}\right)}
\quad\quad\text{# the only difference from RAdam}\\
&\text{else:}\\
&\quad\quad\theta_{t+1}\leftarrow\theta_t-\eta\cdot\hat{m}_t
\end{align*}$$

### SRAdam

In [15]:
# export
class SRAdam(Optimizer):
    """RAdam-style optimizer whose rectified denominator goes through a softplus.

    For every parameter that has a gradient, each step:
      * updates `exp_avg` / `exp_avg_sq`, the exponential moving averages of
        the gradient and of its element-wise square;
      * while rho_t <= 4, uses the update `lr * m_hat`;
      * once rho_t > 4, uses `lr * m_hat / softplus(sqrt(r_t * v_hat))`;
      * adds `weight_decay * p` to the update when weight_decay is non-zero.

    Args:
        params: parameters or parameter-group dicts (forwarded to Optimizer).
        lr: learning rate.
        betas: (beta1, beta2) decay factors of the two moving averages.
        weight_decay: coefficient of the `weight_decay * p` term.
        smooth: first positional argument of torch.nn.Softplus (its `beta`).
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0,
                 smooth=50):
        # create parameter groups
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super(SRAdam, self).__init__(params, defaults)

        # initialise the per-parameter state
        self.state = self._get_init_state()

        # beta2 is treated as fixed, so rho_inf = 2/(1-beta2) - 1 is computed
        # once per group here. (The previous code divided by beta2 instead of
        # (1-beta2), which kept rho below 4 and disabled the rectified branch.)
        # NOTE(review): groups added later via add_param_group get no entry
        # here, and step() zips this list with param_groups.
        self.rho_infs = [2/(1-group['betas'][1])-1 for group in self.param_groups]

        # softplus applied to the rectified denominator
        self.sp = torch.nn.Softplus(smooth)


    def _get_init_state(self):
        """Build the state dict: step counter plus zero buffers for every parameter."""
        state = defaultdict(dict)
        for group in self.param_groups:
            for p in group['params']:
                state_p = state[p]
                if len(state_p) == 0:
                    state_p['step'] = 0
                    state_p['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state_p['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state_p['update'] = torch.zeros_like(p, memory_format=torch.preserve_format)
        return state

    def _per_group(self,group,rho_inf):
        """Return (bias1, bias2, r) for the step about to be taken.

        `r` is the rectification factor when rho_t > 4 and -1 otherwise;
        step() treats any non-positive `r` as "skip the variance term".
        """
        # NOTE(review): the group's step is read from its first parameter; that
        # counter only advances when this parameter has a gradient.
        one_param = group['params'][0]
        t = self.state[one_param]['step']+1 # step has not been incremented yet, hence +1
        beta1, beta2 = group['betas']
        bias1 = 1-beta1**t
        bias2 = 1-beta2**t

        rho = rho_inf - (2*t*beta2**t)/(1-beta2**t)
        if rho>4:
            r = (rho_inf-4)*(rho_inf-2)*rho/((rho-4)*(rho-2)*rho_inf)
        else:
            r = -1

        return bias1,bias2,r


    def update_state(self,group,state,grad):
        """Advance the step counter and both moving averages in place; return (exp_avg, exp_avg_sq)."""
        state['step'] += 1

        beta1, beta2 = group['betas']
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1-beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(grad, grad, value=1-beta2)

        return state['exp_avg'], state['exp_avg_sq']


    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that currently has a gradient."""
        for group,rho_inf in zip(self.param_groups,self.rho_infs):
            if len(group['params'])==0: continue

            bias1,bias2,r = self._per_group(group,rho_inf)

            for p in group['params']:
                if p.grad is None:
                    continue

                # grad
                grad = p.grad

                # state
                state = self.state[p]
                m,v = self.update_state(group,state,grad)

                # debias
                m_hat = m / bias1

                # update value
                if r > 0:
                    # rectified branch: softplus keeps the denominator away from zero
                    v_hat = v / bias2
                    update = group['lr']*m_hat/self.sp(v_hat.mul(r).sqrt())
                else:
                    # rho_t <= 4: use the bias-corrected momentum only
                    update = group['lr']*m_hat
                if group['weight_decay'] != 0: update += group['weight_decay']*p.data

                # do update
                p.sub_(update)

                # keep the applied update in the state
                state['update'] = update

## UM (xty)

unit momentum

### algorithm

$$\begin{align*}
\text{initial:}&\\
&m_0=0\\
\\
\text{step:}&\\
&g_t\leftarrow\bigtriangledown_{\theta}\mathcal{L}(S_t,\theta_t)\\
\\
&m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot g_t\\
&\hat{m}_t\leftarrow\frac{m_t}{1-\beta_1^t}\\
\\
&\theta_{t+1}\leftarrow\theta_t-\eta\cdot\frac{\hat{m}_t}{\lVert\hat{m}_t^{l}\rVert_2}
\end{align*}$$

### UM

In [16]:
# export
class UM(Optimizer):
    """Unit-momentum optimizer.

    Each step moves a parameter along its bias-corrected gradient momentum
    divided by the root-mean-square of that momentum over the whole tensor,
    scaled by `lr`. `weight_decay * p` is added to the update when
    weight_decay is non-zero.

    Args:
        params: parameters or parameter-group dicts (forwarded to Optimizer).
        lr: learning rate.
        beta: decay factor of the gradient moving average.
        weight_decay: coefficient of the `weight_decay * p` term.
    """
    def __init__(self, params, lr=1e-3, beta=0.9, weight_decay=0):
        # create parameter groups
        # Only one beta is used, but (per the original author's note) fastai's
        # learner requires a betas=(#,#) entry, so the value is stored twice.
        defaults = dict(lr=lr, betas=(beta,beta), weight_decay=weight_decay)
        super().__init__(params, defaults)

        # initialise the per-parameter state
        self.state = self._get_init_state()


    def _get_init_state(self):
        """Build the state dict: step counter plus zero buffers for every parameter."""
        state = defaultdict(dict)
        for group in self.param_groups:
            for p in group['params']:
                state_p = state[p]
                if len(state_p) == 0:
                    state_p['step'] = 0
                    state_p['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state_p['update'] = torch.zeros_like(p, memory_format=torch.preserve_format)
        return state

    def _per_group(self,group):
        """Return the bias-correction term 1 - beta**t for the step about to be taken."""
        # NOTE(review): the group's step is read from its first parameter; that
        # counter only advances when this parameter has a gradient.
        one_param = group['params'][0]
        t = self.state[one_param]['step']+1 # step has not been incremented yet, hence +1
        beta = group['betas'][0]
        bias = 1-beta**t
        return bias


    def update_state(self,group,state,grad):
        """Advance the step counter and the gradient moving average in place; return exp_avg."""
        state['step'] += 1

        beta = group['betas'][0]
        state['exp_avg'].mul_(beta).add_(grad, alpha=1-beta)

        return state['exp_avg']


    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that currently has a gradient."""
        for group in self.param_groups:
            if len(group['params'])==0: continue

            bias = self._per_group(group)

            for p in group['params']:
                if p.grad is None:
                    continue

                # grad
                grad = p.grad

                # state
                state = self.state[p]
                m = self.update_state(group,state,grad)

                # debias
                m_hat = m / bias

                # RMS of the momentum over the whole tensor. An all-zero
                # momentum would give 0/0 = NaN and poison the parameter, so the
                # denominator is floored at the smallest normal float; a zero
                # momentum then yields a zero update.
                scale = m_hat.pow(2).mean().sqrt()
                scale = scale.clamp(min=torch.finfo(scale.dtype).tiny)

                # update value
                update = group['lr']*m_hat/scale
                if group['weight_decay'] != 0: update += group['weight_decay']*p.data

                # do update
                p.sub_(update)

                # keep the applied update in the state
                state['update'] = update

## Pre_LAMB (xty)

LAMB解决的问题是：
- 由于卷积层后加batchnorm层，使卷积层参数的整体缩放不对网络输出造成影响
- 若卷积层参数整体缩放A倍，则梯度尺度变为1/A倍，相对梯度尺度变为1/A^2倍

LAMB解决该问题的思路是：
- Adam的更新步长除以自身的整体尺度，解决“梯度尺度变化”的问题
- 再乘以该层参数的整体尺度，解决“相对梯度尺度变化”的问题

Adam 本身的 m/v 形式有没有解决“梯度尺度变为1/A倍”的问题？
- Adam的更新步长是m/v
- 通常v使用的beta很大（>0.99，甚至取到0.999），v对梯度变化的跟随能力很弱，所以当梯度尺度随时间变化时，**v（相对于m）是稳定的**
- 相比v来说，m使用的beta较小（0.9），m对梯度变化的跟随能力更强，所以当梯度尺度随时间变化时，**m（相对于v）是变化的**
- 所以，如果认为v随时间是相对稳定的，则1/v起到的作用是解决“各层**初始时**梯度尺度不同”的问题，而对解决“梯度尺度随时间变化”没有作用
- 所以 m/v 的形式没有解决“梯度尺度变为1/A倍”的问题
- 所以，参数尺度变为A倍时，与梯度一样，m/v 也变为 1/A 倍

对Adam的思考：
- Adam 中的 m/v 形式的作用，通常理解有两个：
    1. 对每个梯度做单位化，使各梯度尺度相当
    2. 惩罚梯度噪声，若梯度随时间变化很大，则v会变大，1/v使更新步长变小
- 然而，实际上第二个作用是不存在的，因为在大量的实践中，v的beta都使用很大的值（>0.99，甚至0.999），这使v对梯度变化的跟随能力很弱，v是几乎不随时间变化的。
- 对于第一个作用，其实它有两个方面
    1. 使各层的梯度尺度相当，这个作用是正面的。
    2. 使一层内的各个参数的梯度尺度相当，这个作用是否positive值得考虑，因为这使梯度失去了方向信息。

Pre_LAMB
- 直接对 g 除以其自身的整体尺度，即 /||g||_2，解决“梯度尺度变为1/A”的问题
- 再对 g 乘以参数的整体尺度，即 \*||theta||_2，解决“相对梯度尺度变为1/A^2”的问题
- 在LAMB中是在Adam step后做自适应scale，而这里是对g做自适应scale，所以命名为pre lamb，意指把自适应scale提前了
- 由于对g做了自适应scale，所以m和v都是自适应scale的了，而如果再m/v，这个自适应scale的效果就相除抵消了，
- 所以不再做m/v，直接不用v了

### algorithm

$$\begin{align*}
\text{initial:}&\\
&m_0=0\\
&v_0=0\\
\\
\text{step:}&\\
&g_t\leftarrow\bigtriangledown_{\theta}\mathcal{L}(S_t,\theta_t)\\
&\hat{g}_t\leftarrow\frac{\lVert\theta_t^l\rVert_2}{\lVert g_t^l\rVert_2}\cdot g_t\quad\quad\text{# pre adaptive scale, compared to initial LAMB}\\
\\
&m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot \hat{g}_t\\
&\hat{m}_t\leftarrow\frac{m_t}{1-\beta_1^t}\\
\\
&s_t\leftarrow\eta\cdot\hat{m}_t+\lambda\cdot\theta_t\\
&\theta_{t+1}\leftarrow\theta_t-s_t
\end{align*}$$

### Pre_LAMB

In [17]:
# export
class Pre_LAMB(Optimizer):
    """Momentum optimizer with the LAMB-style layer scaling applied to the gradient.

    Each step rescales the gradient of a parameter by RMS(parameter) /
    RMS(gradient), feeds the rescaled gradient into an exponential moving
    average, and subtracts `lr * m_hat + weight_decay * p` from the parameter.
    No second-moment estimate is kept.

    Args:
        params: parameters or parameter-group dicts (forwarded to Optimizer).
        lr: learning rate.
        beta: decay factor of the moving average of the rescaled gradient.
        weight_decay: coefficient of the `weight_decay * p` term.
    """
    def __init__(self, params, lr=1e-3, beta=0.9, weight_decay=0):
        # create parameter groups
        # Only one beta is used, but (per the original author's note) fastai's
        # learner requires a betas=(#,#) entry, so the value is stored twice.
        defaults = dict(lr=lr, betas=(beta,beta), weight_decay=weight_decay)
        super().__init__(params, defaults)

        # initialise the per-parameter state
        self.state = self._get_init_state()


    def _get_init_state(self):
        """Build the state dict: step counter plus zero buffers for every parameter."""
        state = defaultdict(dict)
        for group in self.param_groups:
            for p in group['params']:
                state_p = state[p]
                if len(state_p) == 0:
                    state_p['step'] = 0
                    state_p['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state_p['update'] = torch.zeros_like(p, memory_format=torch.preserve_format)
        return state

    def _per_group(self,group):
        """Return the bias-correction term 1 - beta**t for the step about to be taken."""
        # NOTE(review): the group's step is read from its first parameter; that
        # counter only advances when this parameter has a gradient.
        one_param = group['params'][0]
        t = self.state[one_param]['step']+1 # step has not been incremented yet, hence +1
        beta = group['betas'][0]
        bias = 1-beta**t
        return bias


    def update_state(self,group,state,grad):
        """Advance the step counter and the moving average in place; return exp_avg."""
        state['step'] += 1

        beta = group['betas'][0]
        state['exp_avg'].mul_(beta).add_(grad, alpha=1-beta)

        return state['exp_avg']


    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that currently has a gradient."""
        for group in self.param_groups:
            if len(group['params'])==0: continue

            bias = self._per_group(group)

            for p in group['params']:
                if p.grad is None:
                    continue

                # grad, rescaled to the parameter's RMS. The gradient RMS is
                # floored at the smallest normal float so that an all-zero
                # gradient gives a zero rescaled gradient instead of 0/0 = NaN.
                grad = p.grad.data
                grad_scale = grad.pow(2).mean().sqrt()
                grad_scale = grad_scale.clamp(min=torch.finfo(grad_scale.dtype).tiny)
                grad_hat = grad*p.data.pow(2).mean().sqrt()/grad_scale

                # state
                state = self.state[p]
                m = self.update_state(group,state,grad_hat)

                # debias
                m_hat = m / bias

                # update value; weight decay was previously accepted but never
                # applied, it is now added the same way the sibling optimizers do
                update = group['lr']*m_hat
                if group['weight_decay'] != 0: update += group['weight_decay']*p.data

                # do update
                p.sub_(update)

                # keep the applied update in the state
                state['update'] = update

## Duv_Adam (xty)

Debiased and Unscaled Variance Adam

### algorithm

在原Adam（实际上从AdaGrad开始）中，v实际上是带偏的，这里把v设计为无偏的，并且v是无尺度的（unscaled，即不随梯度的整体尺度变化）。

$$\begin{align*}
\text{initial:}&\\
&m_0=0\\
&v_0=0\\
\\
\text{step:}&\\
&g_t\leftarrow\bigtriangledown_{\theta}\mathcal{L}(S_t,\theta_t)\\
\\
&r_t\leftarrow\frac{g_t}{m_{t-1}}\quad\text{if}\quad m_{t-1}\neq0\quad\text{else}\quad1\\
&m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot g_t\\
&v_t\leftarrow\beta_2\cdot v_{t-1}+(1-\beta_2)\cdot(r_t-1)^2\\
\\
&\hat{m}_t\leftarrow\frac{m_t}{1-\beta_1^t}\\
&\hat{v}_t\leftarrow\frac{v_t}{1-\beta_2^t}\\
\\
&\theta_{t+1}\leftarrow\theta_t-\eta\cdot\frac{\hat{m}_t}{1+\sqrt{\hat{v}_t}}
\end{align*}$$

### Duv_Adam

In [18]:
# export
class Duv_Adam(Optimizer):
    """Adam variant whose second moment tracks the relative deviation of the gradient.

    With r = grad / previous exp_avg (taken as 1 where the previous exp_avg is
    zero), `exp_avg_sq` is a moving average of (r - 1)**2. The update is
    `lr * m_hat / (1 + sqrt(v_hat))`, plus `weight_decay * p` when
    weight_decay is non-zero.

    Args:
        params: parameters or parameter-group dicts (forwarded to Optimizer).
        lr: learning rate.
        betas: (beta1, beta2) decay factors of the two moving averages.
        weight_decay: coefficient of the `weight_decay * p` term.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0):
        # create parameter groups
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

        # initialise the per-parameter state
        self.state = self._get_init_state()


    def _get_init_state(self):
        """Build the state dict: step counter plus zero buffers for every parameter."""
        state = defaultdict(dict)
        for group in self.param_groups:
            for p in group['params']:
                state_p = state[p]
                if len(state_p) == 0:
                    state_p['step'] = 0
                    state_p['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state_p['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state_p['update'] = torch.zeros_like(p, memory_format=torch.preserve_format)
        return state

    def _per_group(self,group):
        """Return the bias-correction terms (1 - beta1**t, 1 - beta2**t) for the step about to be taken."""
        # NOTE(review): the group's step is read from its first parameter; that
        # counter only advances when this parameter has a gradient.
        one_param = group['params'][0]
        t = self.state[one_param]['step']+1 # step has not been incremented yet, hence +1
        beta1, beta2 = group['betas']
        bias1 = 1-beta1**t
        bias2 = 1-beta2**t
        return bias1,bias2


    def update_state(self,group,state,grad):
        """Advance the step counter and both moving averages in place; return (exp_avg, exp_avg_sq)."""
        state['step'] += 1

        beta1, beta2 = group['betas']

        # Element-wise ratio of the gradient to the previous momentum, with
        # r = 1 (zero deviation) wherever the previous momentum is zero. The old
        # `a if tensor != 0 else 1` form used a tensor as a Python boolean,
        # which raises for any parameter with more than one element.
        m_prev = state['exp_avg']
        r = torch.where(m_prev==0, torch.ones_like(grad), grad/m_prev)
        deviation = r-1

        # r was computed above, before exp_avg is overwritten in place
        state['exp_avg'].mul_(beta1).add_(grad, alpha=1-beta1)
        state['exp_avg_sq'].mul_(beta2).addcmul_(deviation, deviation, value=1-beta2)

        return state['exp_avg'], state['exp_avg_sq']


    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that currently has a gradient."""
        for group in self.param_groups:
            if len(group['params'])==0: continue

            bias1,bias2 = self._per_group(group)

            for p in group['params']:
                if p.grad is None:
                    continue

                # grad
                grad = p.grad

                # state
                state = self.state[p]
                m,v = self.update_state(group,state,grad)

                # debias
                m_hat = m / bias1
                v_hat = v / bias2

                # update value
                update = group['lr']*m_hat/(v_hat.sqrt().add(1))
                if group['weight_decay'] != 0: update += group['weight_decay']*p.data

                # do update
                p.sub_(update)

                # keep the applied update in the state
                state['update'] = update

## Su_ADAM (xty)

Scale adaptive and Unit gradient Adam

### 说明

**基于几个想法：**
- 使用**单位梯度**
    - 认为梯度大小跟与最优点的距离之间没有关系，因此仅使用其方向
    - 不受小噪声的影响，只要噪声不改变梯度方向，则其不会改变单位梯度
- 使用动量
    - 单位梯度的动量
    - 偶尔的方向变化会使动量幅度下降
    - 频繁的方向变化会使动量幅度接近0，频繁的方向变化可能表示两种情况
        - 大量的大幅度的噪声
        - 在极小点左右震荡
- 除以方差动量
    - 单位动量符号变化越频繁，则方差动量越大，则更新步长越小
    - 方差动量与单位梯度动量的作用一致，加强了“方向变化时减小更新步长”的效果
- 更新步长**正比于参数模**
    - 参数模越大则更新幅度越大，成正比
    - 使用单个参数本身的模还是一批（比如同一个channel）参数的整体的模？
        - 使用一批参数的模的问题：不能体现出这一批内部的各个参数的大小的差异，尤其是可能有模较大的参数，这些参数的更新速度会被拖慢。
        - 使用单个参数的模的问题：参数非常小时更新非常慢，尤其是参数模为0时停止更新
        - 方法：以一批参数的模的平均值为下限，单个参数的模低于下限时使用该下限，单个参数的模高于下限时使用单个参数的模。

### algorithm

$$\begin{align*}
\text{initial:}&\\
&m_0=0\\
&v_0=0\\
\\
\text{step:}&\\
&g_t\leftarrow\bigtriangledown_{\theta}\mathcal{L}(S_t,\theta_t)\\
&s_t\leftarrow\text{sign}(g_t)\\
\\
&r_t\leftarrow\frac{s_t}{m_{t-1}}\quad\text{if}\quad m_{t-1}\neq0\quad\text{else}\quad1\\
&m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot s_t\\
&v_t\leftarrow\beta_2\cdot v_{t-1}+(1-\beta_2)\cdot\lvert r_t-1\rvert\\
\\
&\hat{m}_t\leftarrow\frac{m_t}{1-\beta_1^t}\\
&\hat{v}_t\leftarrow\frac{v_t}{1-\beta_2^t}\\
\\
&\epsilon_t\leftarrow E[\lvert\theta_{t}\rvert^{layer}]\\
\\
&\theta_{t+1}\leftarrow\theta_t-\eta\cdot\epsilon_t\cdot\frac{\hat{m}_t}{1+\hat{v}_t}
\end{align*}$$

### SU_Adam

In [19]:
# export
class SU_Adam(Optimizer):
    """Sign-gradient Adam variant whose step size is proportional to the layer scale.

    Per parameter the state holds `exp_avg_unit`, a moving average of
    sign(grad), and `exp_avg_ratio`, a moving average of |r - 1| where
    r = sign(grad) / previous exp_avg_unit (1 where that average is zero).
    The update is `lr * mean(|p|) * m_hat / (1 + v_hat)`, plus
    `weight_decay * p` when weight_decay is non-zero.

    Args:
        params: parameters or parameter-group dicts (forwarded to Optimizer).
        lr: learning rate.
        betas: (beta1, beta2) decay factors of the two moving averages.
        weight_decay: coefficient of the `weight_decay * p` term.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), weight_decay=0):
        # hand the hyper-parameters to Optimizer, which builds the param groups
        super().__init__(params, dict(lr=lr, betas=betas, weight_decay=weight_decay))

        # replace the base-class state with pre-allocated buffers
        self.state = self._get_init_state()

    def _get_init_state(self):
        """Build the state dict: step counter plus zero buffers for every parameter."""
        fresh = defaultdict(dict)
        for group in self.param_groups:
            for param in group['params']:
                slot = fresh[param]
                if slot:
                    # already initialised (parameter seen before)
                    continue
                slot['step'] = 0
                for key in ('exp_avg_unit', 'exp_avg_ratio', 'update'):
                    slot[key] = torch.zeros_like(param, memory_format=torch.preserve_format)
        return fresh

    def _per_group(self,group):
        """Return the bias-correction terms (1 - beta1**t, 1 - beta2**t) for the step about to be taken."""
        # the step counter of the group's first parameter stands in for the
        # whole group; it has not been incremented yet, hence the +1
        first_param = group['params'][0]
        t = self.state[first_param]['step'] + 1
        first_beta, second_beta = group['betas']
        return 1 - first_beta**t, 1 - second_beta**t


    def _update_state(self,group,state,grad):
        """Advance the step counter and both moving averages in place; return them."""
        state['step'] += 1

        first_beta, second_beta = group['betas']
        direction = grad.sign()

        # ratio of the new sign to the previous averaged sign, 1 where that average is zero
        one = torch.tensor(1.,device=grad.device)
        ratio = torch.where(state['exp_avg_unit']==0, one, direction/state['exp_avg_unit'])

        unit_avg = state['exp_avg_unit']
        unit_avg.mul_(first_beta).add_(direction, alpha=1-first_beta)

        ratio_avg = state['exp_avg_ratio']
        ratio_avg.mul_(second_beta).add_((ratio-1).abs(), alpha=1-second_beta)

        return unit_avg, ratio_avg


    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that currently has a gradient."""
        for group in self.param_groups:
            members = group['params']
            if not members:
                continue

            bias1, bias2 = self._per_group(group)

            for p in members:
                grad = p.grad
                if grad is None:
                    continue

                state = self.state[p]
                unit_avg, ratio_avg = self._update_state(group, state, grad)

                # bias-corrected averages
                m_hat = unit_avg / bias1
                v_hat = ratio_avg / bias2

                # mean absolute value of the whole tensor
                layer_scale = p.data.abs().mean()

                update = group['lr']*layer_scale*m_hat/(1+v_hat)
                decay = group['weight_decay']
                if decay != 0:
                    update += decay*p.data

                p.sub_(update)

                # keep the applied update in the state
                state['update'] = update

## B_ADAM (xty)

Binomial Adam

### algorithm

$$\begin{align*}
\text{initial:}&\\
&m_0=0\\
&v_1=0\\
\\
\text{step:}&\\
&g_t\leftarrow\bigtriangledown_{\theta}\mathcal{L}(S_t,\theta_t)\\
\\
&m_t\leftarrow\beta_1\cdot m_{t-1}+(1-\beta_1)\cdot g_t\\
&\hat{m}_t\leftarrow\frac{m_t}{1-\beta_1^t}\\
\\
&\epsilon_t\leftarrow E[\lvert\theta_{t}\rvert^{layer}]\\
\\
&\text{if}\quad t>1:\\
&\quad\quad\Phi_t\leftarrow\frac{g_t-g_{t-1}}{\text{softplus}(s_{t-1})}\\
&\quad\quad v_t\leftarrow\beta_2\cdot v_{t-1}+(1-\beta_2)\cdot\Phi_t\\
&\quad\quad\hat{v}_t\leftarrow\frac{v_t}{1-\beta_2^{t-1}}\\
\\
&\quad\quad s_t\leftarrow\eta\cdot\epsilon_t\cdot\frac{\hat{m}_t}{\lvert\hat{v}_{t-1}\rvert}\\
&\text{else :}\\
&\quad\quad s_t\leftarrow\eta\cdot\epsilon_t\cdot\hat{m}_t\\
\\
&\theta_{t+1}\leftarrow\theta_t-s_t\\
\end{align*}$$

### B_Adam

In [20]:
# export
def sign_clamp(v,minv):
    """Push every element of `v` away from zero to a magnitude of at least `minv`.

    Elements with |v| > minv are returned unchanged; the others become
    sign(v) * minv, and elements that are still exactly zero afterwards
    (for example v == 0) become +minv. If `minv` itself is zero, zeros stay zero.

    The debugger breakpoint that used to guard the result was removed: it
    forced a full `.min()` reduction on every call, raised on empty input and
    would have halted training inside library code.

    Args:
        v: tensor to clamp.
        minv: magnitude floor (the caller in this file passes a 0-dim tensor).

    Returns:
        A new tensor; `v` is not modified.
    """
    res = torch.where(v.abs()>minv,v,v.sign()*minv)
    # sign(0) is 0, so exact zeros need a second pass to receive +minv
    res = torch.where(res==0,minv,res)
    return res

In [21]:
# export
class B_Adam(Optimizer):
    """Adam variant whose denominator tracks gradient change per unit of previous update.

    Per parameter the state holds `exp_avg_1st` (moving average of the
    gradient), `exp_avg_2ed` (moving average of Phi, see below), and the
    previous `update` and `grad`. On a parameter's first step the update is
    `lr * mean(|p|) * m_hat`. Afterwards
    Phi = (grad - previous grad) / sign_clamp(previous update, mean |previous update|)
    and the update is `lr * mean(|p|) * m_hat / softplus(|v_hat|)`.
    `weight_decay * p` is added to the update when weight_decay is non-zero.

    Args:
        params: parameters or parameter-group dicts (forwarded to Optimizer).
        lr: learning rate.
        betas: (beta1, beta2) decay factors of the two moving averages.
        weight_decay: coefficient of the `weight_decay * p` term.
        smooth: first positional argument of torch.nn.Softplus (its `beta`).
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), weight_decay=0, smooth=50):
        # hand the hyper-parameters to Optimizer, which builds the param groups
        super().__init__(params, dict(lr=lr, betas=betas, weight_decay=weight_decay))

        # replace the base-class state with pre-allocated buffers
        self.state = self._get_init_state()

        # softplus applied to |v_hat| in the denominator
        self.sp = torch.nn.Softplus(smooth)

    def _get_init_state(self):
        """Build the state dict: step counter plus zero buffers for every parameter."""
        fresh = defaultdict(dict)
        for group in self.param_groups:
            for param in group['params']:
                slot = fresh[param]
                if slot:
                    # already initialised (parameter seen before)
                    continue
                slot['step'] = 0
                for key in ('exp_avg_1st', 'exp_avg_2ed', 'update', 'grad'):
                    slot[key] = torch.zeros_like(param, memory_format=torch.preserve_format)
        return fresh

    def _per_group(self,group):
        """Return (1 - beta1**t, 1 - beta2**(t-1)) for the step about to be taken.

        The second term is 0 when t == 1; step() does not use it on a
        parameter's first step.
        """
        # the step counter of the group's first parameter stands in for the
        # whole group; it has not been incremented yet, hence the +1
        first_param = group['params'][0]
        t = self.state[first_param]['step'] + 1
        first_beta, second_beta = group['betas']
        return 1 - first_beta**t, 1 - second_beta**(t-1)


    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that currently has a gradient."""
        for group in self.param_groups:
            members = group['params']
            if not members:
                continue

            bias1, bias2 = self._per_group(group)
            beta1, beta2 = group['betas']

            for p in members:
                grad = p.grad
                if grad is None:
                    continue

                state = self.state[p]
                state['step'] += 1

                # first moment, updated in place, then bias-corrected
                first_moment = state['exp_avg_1st']
                first_moment.mul_(beta1).add_(grad, alpha=1-beta1)
                m_hat = first_moment / bias1

                # mean absolute value of the whole tensor
                layer_scale = p.data.abs().mean()

                # numerator shared by both branches
                update = group['lr']*layer_scale*m_hat

                if state['step'] != 1:
                    # change of gradient relative to the previous update, whose
                    # magnitude is floored at its own mean absolute value
                    last_update = state['update']
                    floor = last_update.abs().mean()
                    phi = (grad-state['grad'])/sign_clamp(last_update, floor)

                    second_moment = state['exp_avg_2ed']
                    second_moment.mul_(beta2).add_(phi, alpha=1-beta2)
                    v_hat = second_moment/bias2

                    update = update/(self.sp(v_hat.abs()))

                decay = group['weight_decay']
                if decay != 0:
                    update += decay*p.data
                p.sub_(update)

                # remember this step's update and gradient for the next Phi
                state['update'] = update
                state['grad'] = grad.data.clone().detach()
                

# test

# export

In [1]:
!python ../../notebook2script.py --fname 'optimizer.ipynb' --outputDir '../exp/'

Converted optimizer.ipynb to ../exp/optimizer.py
