In [1]:
import numpy as np
import torch
from torch.optim import Optimizer
from torch import nn 

In [43]:

class AdanOptimizer(Optimizer):
    """Adan (Adaptive Nesterov Momentum) optimizer.

    Args:
        params: iterable of parameters (or parameter-group dicts) to optimize.
        learning_rate: base step size, stored in each group under ``'lr'``.
        betas: ``(beta1, beta2, beta3)`` mixing weights for the gradient
            average, the gradient-difference average and the squared-update
            average respectively. Each new sample is weighted by ``beta`` and
            the running value by ``1 - beta``.
        epsilon: constant added to the square-root denominator, stored in each
            group under ``'eps'``.
        weight_decay: decay coefficient; weights are divided by
            ``1 + weight_decay * lr`` on every step.
        restart_cond: optional callable ``restart_cond(state) -> bool``. When
            it returns a truthy value for a parameter, that parameter's moments
            are reset from the current gradient and the update is re-applied.

    Raises:
        ValueError: if ``betas`` does not contain exactly three values.
    """

    def __init__(self, params, learning_rate=0.001, betas=(0.02, 0.08, 0.01), epsilon=1e-8, weight_decay = 0, 
                 restart_cond: callable = None):
        # Fail at construction time instead of at the first unpacking in step().
        if len(betas) != 3:
            raise ValueError(f"betas must contain exactly 3 values, got {len(betas)}")

        defaults = dict(
            lr = learning_rate,
            betas = betas,
            eps = epsilon,
            weight_decay = weight_decay,
            restart_cond = restart_cond
        )

        super().__init__(params, defaults)

    def step(self, closure=None):
        """Perform a single Adan update on every parameter that has a gradient.

        With g_k the current gradient and k the 1-based step count::

            m_k = (1 - beta1) * m_{k-1} + beta1 * g_k
            v_k = (1 - beta2) * v_{k-1} + beta2 * (g_k - g_{k-1})
            n_k = (1 - beta3) * n_{k-1} + beta3 * (g_k + (1 - beta2) * (g_k - g_{k-1}))^2
            c_i = 1 / (1 - (1 - beta_i)^k)                      # bias correction
            eta_k = lr / (sqrt(n_k * c_3) + eps)
            w_{k+1} = (w_k - eta_k * (m_k * c_1 + (1 - beta2) * v_k * c_2)) / (1 + weight_decay * lr)

        The moments are not updated on the very first call for a parameter
        (there is no previous gradient yet), so that call only applies the
        weight-decay division and records the gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is invoked once before the update.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:

            lr = group['lr']
            beta1, beta2, beta3 = group['betas']
            weight_decay = group['weight_decay']
            eps = group['eps']
            restart_cond = group['restart_cond']

            for p in group['params']:
                if p.grad is None:
                    continue

                data, grad = p.data, p.grad.data
                assert not grad.is_sparse

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['prev_grad'] = torch.zeros_like(grad)
                    state['m'] = torch.zeros_like(grad)
                    state['v'] = torch.zeros_like(grad)
                    state['n'] = torch.zeros_like(grad)

                step, m, v, n, prev_grad = state['step'], state['m'], state['v'], state['n'], state['prev_grad']

                # The moment updates need a previous gradient, so skip them on the first call.
                if step > 0:
                    m.mul_(1 - beta1).add_(grad, alpha = beta1)
                    grad_diff = grad - prev_grad
                    v.mul_(1 - beta2).add_(grad_diff, alpha = beta2)
                    next_n = (grad + (1 - beta2) * grad_diff) ** 2
                    n.mul_(1 - beta3).add_(next_n, alpha = beta3)

                step += 1

                # bias correction terms
                correct_m, correct_v, correct_n = (
                    1 / (1 - (1 - beta) ** step) for beta in (beta1, beta2, beta3)
                )

                def grad_step_(data, m, v, n):
                    """Apply the bias-corrected update and weight decay to ``data`` in place."""
                    weighted_step_size = lr / (n * correct_n).sqrt().add_(eps)

                    denom = 1 + weight_decay * lr

                    data.addcmul_(weighted_step_size, (m * correct_m + (1 - beta2) * v * correct_v), value = -1.).div_(denom)

                grad_step_(data, m, v, n)

                # restart condition: reset the moments from the current gradient and step again
                if restart_cond is not None and restart_cond(state):
                    m.copy_(grad)
                    v.zero_()
                    n.copy_(grad ** 2)

                    grad_step_(data, m, v, n)

                # remember this gradient and the incremented step for the next call
                prev_grad.copy_(grad)
                state['step'] = step

        return loss

In [44]:
from adan_pytorch import Adan

def _build_net():
    """Return a fresh Linear(16 -> 16) + GELU network with its own random init."""
    return torch.nn.Sequential(nn.Linear(16, 16), nn.GELU())


# Two separate networks, one per optimizer under comparison.
model = _build_net()
model2 = _build_net()

# Copy the first network's weights into the second so both start identically.
model2.load_state_dict(model.state_dict())

<All keys matched successfully>

In [45]:
# One random input vector of length 16 (the Linear layers' in_features), fed to both models.
# NOTE(review): no RNG seed is set in the visible cells, so these values differ between runs.
model_features = torch.randn(16)

In [46]:
# Hyperparameters shared by both optimizers so their results are directly comparable.
adan_kwargs = dict(
    betas = (0.02, 0.08, 0.01),  # beta 1-2-3 as described in paper - author says most sensitive to beta3 tuning
    weight_decay = 0.02,         # weight decay 0.02 is optimal per author
)

# learning rate 1e-3 for both (can be much higher than Adam, up to 5-10x)
optim = AdanOptimizer(model.parameters(), learning_rate = 1e-3, **adan_kwargs)
optim_2 = Adan(model2.parameters(), lr = 1e-3, **adan_kwargs)

# train: ten steps on `model` with the local optimizer, then ten on `model2` with the reference one
for net, opt in ((model, optim), (model2, optim_2)):
    for _ in range(10):
        loss = net(model_features).sum()
        loss.backward()
        opt.step()
        opt.zero_grad()

In [47]:
# Bias of the first layer after training with the local AdanOptimizer.
model[0].bias.detach().numpy()

array([-0.08049092,  0.2285945 , -0.07219817, -0.00345977,  0.16656336,
        0.00050445, -0.14645079,  0.04010449, -0.1872931 , -0.06345155,
       -0.09969974,  0.03333162,  0.18819708, -0.20207672, -0.2046853 ,
        0.23350479], dtype=float32)

In [48]:
# Bias of the first layer after training with the reference Adan optimizer.
model2[0].bias.detach().numpy()

array([-0.08049092,  0.2285945 , -0.07219817, -0.00345977,  0.16656336,
        0.00050445, -0.14645079,  0.04010449, -0.1872931 , -0.06345155,
       -0.09969974,  0.03333162,  0.18819708, -0.20207672, -0.2046853 ,
        0.23350479], dtype=float32)