In [5]:
import torch
from torch.optim import Optimizer
import math

class GradaGradSimplified(Optimizer):
    """Simplified GradaGrad-style optimizer with a non-monotone step size.

    Each parameter group keeps a scalar accumulator ``alpha`` (starting at 0)
    and a scale ``gamma`` (initialised from ``lr``), and every parameter is
    updated as ``p <- p - gamma / (sqrt(alpha) + eps) * grad``.

    Per step, let ``g`` be the group's flattened gradient and ``g_prev`` the
    one stored from the previous step. Define ``v = ||g||^2 - rho * <g, g_prev>``.

    * If ``v >= 0``, it is added to ``alpha``, so steps shrink.
    * Otherwise ``v`` is clipped from below to ``-r * alpha``, ``gamma`` is
      multiplied by ``sqrt(1 - v / (alpha + eps))`` (a factor of at most
      ``sqrt(1 + r)``), and ``alpha`` is left unchanged, so steps grow.

    The per-group state (``alpha``, ``gamma``, ``prev_grad``) lives directly in
    the param-group dicts rather than in ``self.state``. ``gamma`` is read
    from ``lr`` only once, on the group's first step.

    Args:
        params: iterable of parameters or param-group dicts.
        lr: initial value of ``gamma``. Must be non-negative.
        rho: weight of the previous-gradient inner product in ``v``.
        r: bound on how far a negative ``v`` can grow ``gamma``. Must be
            non-negative.
        eps: numerical-stability constant. Must be non-negative.

    Raises:
        ValueError: if ``lr``, ``r`` or ``eps`` is negative.
    """

    def __init__(self, params, lr=1.0, rho=2.0, r=0.01, eps=1e-8):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        # A negative r would turn the clipped v positive in step(), which can
        # make the gamma update factor 1 - v / alpha negative (math domain error).
        if r < 0.0:
            raise ValueError(f"Invalid r value: {r}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        defaults = dict(lr=lr, rho=rho, r=r, eps=eps)
        super().__init__(params, defaults)

    @staticmethod
    def _flat_grad(params):
        """Concatenate the gradients of ``params`` into one flat vector.

        Parameters without a gradient contribute zeros. This keeps the layout
        fixed across steps, so the vector stays aligned with the stored
        previous gradient. ``reshape`` (not ``view``) also handles
        non-contiguous gradients.
        """
        return torch.cat([
            p.grad.reshape(-1) if p.grad is not None
            else torch.zeros(p.numel(), dtype=p.dtype, device=p.device)
            for p in params
        ])

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params = group['params']
            rho = group['rho']
            r = group['r']
            eps = group['eps']

            if 'alpha' not in group:
                device = params[0].device
                group['alpha'] = torch.tensor(0.0, device=device)
                group['gamma'] = torch.tensor(group['lr'], device=device)
                group['prev_grad'] = None

            if all(p.grad is None for p in params):
                continue

            alpha = group['alpha']
            gamma = group['gamma']
            prev_grad = group['prev_grad']

            grad_vec = self._flat_grad(params)

            grad_norm_sq = grad_vec.dot(grad_vec)
            # Only use the history when it has the same layout as this step's
            # gradient; otherwise treat it as absent.
            if prev_grad is not None and prev_grad.shape == grad_vec.shape:
                inner_prod = grad_vec.dot(prev_grad)
            else:
                inner_prod = torch.tensor(0.0, device=grad_vec.device)

            v_k = grad_norm_sq - rho * inner_prod

            if v_k >= 0:
                alpha_new = alpha + v_k
            else:
                # Clip the decrease to -r * alpha; this grows gamma by at most
                # sqrt(1 + r) instead of shrinking alpha.
                v_k = max(v_k.item(), -r * alpha.item())
                gamma = gamma * math.sqrt(1 - v_k / (alpha.item() + eps))
                alpha_new = alpha

            # eps guards sqrt(alpha), which is 0 until a non-zero gradient has
            # been accumulated. A zero denominator (only possible with eps == 0)
            # means every gradient so far was zero, so there is nothing to apply.
            denom = math.sqrt(alpha_new.item()) + eps
            if denom > 0:
                step_size = gamma.item() / denom
                for p in params:
                    if p.grad is not None:
                        p.add_(p.grad, alpha=-step_size)

            group['alpha'] = alpha_new
            group['gamma'] = gamma
            # torch.cat returned a fresh tensor built under no_grad, so it can be
            # stored without detach/clone.
            group['prev_grad'] = grad_vec

        return loss
