# In [2]:
import torch
from torch.optim import Optimizer
import math

class GradaGradSimplified(Optimizer):
    """Simplified GradaGrad-style optimizer with one scalar step size per group.

    Every parameter group shares two scalars, ``alpha`` and ``gamma``, and
    each parameter remembers its own previous gradient.  With ``g`` the
    current gradients of the group and ``g_prev`` the previous ones, a step
    computes::

        v = ||g||^2 - rho * <g, g_prev>
        if v >= 0:  alpha <- alpha + v
        else:       gamma <- gamma * sqrt(1 - max(v, -r * alpha) / (alpha + eps))
        p <- p - (gamma + eps) / sqrt(alpha) * grad(p)

    State layout (``self.state[p]``):
        ``alpha`` / ``gamma``: 0-dim tensors, identical for every parameter
        of the group; ``prev_grad``: copy of the gradient of ``p`` seen on
        the last step in which ``p`` had a gradient.

    Args:
        params: Iterable of parameters or dicts defining parameter groups.
        lr: Initial value of ``gamma``.  It is only read when a group has no
            stored state yet; later changes to ``group['lr']`` are ignored.
        rho: Weight of the inner product with the previous gradient.
        r: Clipping factor; a negative ``v`` is clipped at ``-r * alpha``.
        eps: Small constant added to denominators / to ``gamma``.
        G_inf: Stored in the group options; not used by :meth:`step`.
        D_inf: Stored in the group options; not used by :meth:`step`.
    """

    def __init__(self, params, lr=1.0, rho=2.0, r=0.01, eps=1e-8, G_inf=1.0, D_inf=1.0):
        defaults = dict(lr=lr, rho=rho, r=r, eps=eps, G_inf=G_inf, D_inf=D_inf)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, but the closure normally builds a
            # graph and calls backward(), which needs grad mode enabled.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            active = [p for p in group['params'] if p.grad is not None]
            if active:
                self._update_group(group, active)

        return loss

    def _read_scalars(self, group):
        """Return the group's current ``(alpha, gamma)`` as Python floats."""
        for p in group['params']:
            state = self.state.get(p)
            if state and 'alpha' in state and 'gamma' in state:
                return float(state['alpha']), float(state['gamma'])
        # Nothing stored yet: start from alpha = 0 and gamma = lr.
        return 0.0, float(group['lr'])

    def _update_group(self, group, active):
        """Apply one update to the parameters in ``active`` (all have grads)."""
        rho, r, eps = group['rho'], group['r'], group['eps']
        alpha, gamma = self._read_scalars(group)

        # ||g||^2 and <g, g_prev> over the whole group, accumulated per
        # parameter so that nothing has to be concatenated and every
        # previous gradient keeps the shape of its own parameter.
        sq_terms = []
        ip_terms = []
        for p in active:
            if p.grad.is_sparse:
                raise RuntimeError('GradaGradSimplified does not support sparse gradients')
            flat = p.grad.reshape(-1)
            sq_terms.append(flat.dot(flat))
            prev = self.state[p].get('prev_grad')
            # A parameter without a previous gradient contributes zero.
            if prev is not None:
                ip_terms.append(flat.dot(prev.reshape(-1)))

        # Single host sync for the whole group.
        v = (sum(sq_terms) - rho * sum(ip_terms)).item()

        if v >= 0:
            alpha += v
        elif alpha > 0:
            # With alpha == 0 the clipped value is 0 and gamma would be
            # multiplied by 1, so that case is skipped outright.
            v_clipped = max(v, -r * alpha)
            gamma *= math.sqrt(1 - v_clipped / (alpha + eps))

        if alpha == 0:
            # Nothing has been accumulated in alpha yet, so there is no
            # scale to divide by; leave the parameters where they are.
            step_size = 0.0
        else:
            step_size = (gamma + eps) / math.sqrt(alpha)

        for p in active:
            if step_size != 0.0:
                p.add_(p.grad, alpha=-step_size)
            state = self.state[p]
            prev = state.get('prev_grad')
            if prev is None:
                state['prev_grad'] = p.grad.detach().clone()
            else:
                prev.copy_(p.grad)

        # Keep the shared scalars identical in every initialised state of the
        # group, including parameters that had no gradient on this step.
        device = active[0].device
        alpha_t = torch.tensor(alpha, device=device)
        gamma_t = torch.tensor(gamma, device=device)
        for p in group['params']:
            state = self.state.get(p)
            if state:
                state['alpha'] = alpha_t
                state['gamma'] = gamma_t