# In [10]:
import torch
from torch.optim import Optimizer
import math

class GradaGradSimplified(Optimizer):
    """Scalar-step adaptive optimizer with a non-monotone step size.

    Every parameter tensor carries its own pair of Python-float scalars:

    * ``alpha`` -- running accumulator that grows by ``v_k`` whenever
      ``v_k >= 0``.
    * ``gamma`` -- step-size numerator that is enlarged whenever ``v_k < 0``.

    where ``v_k = <g_k, g_k> - rho * <g_k, g_{k-1}>`` is computed from the
    current and the previous gradient of that tensor.  The update applied is
    ``p <- p - (gamma / sqrt(alpha)) * g_k``.

    NOTE(review): this appears to follow the scalar GradaGrad scheme, where
    ``A = sqrt(alpha) / gamma`` and the step uses ``A^{-1}`` -- confirm against
    the reference algorithm you are targeting.

    Args:
        params: Iterable of tensors or dicts defining parameter groups.
        gamma_0: Initial value of ``gamma`` for every parameter tensor.
        rho: Weight of the ``<g_k, g_{k-1}>`` correlation term in ``v_k``.
        r: Bound on how negative ``v_k`` may be relative to ``alpha``
            (``v_k`` is clipped to ``-r * alpha`` from below).
        eps: Small constant used to keep divisions and square roots
            well-defined.
    """

    def __init__(self, params, gamma_0=1.0, rho=2.0, r=1.0, eps=1e-8):
        defaults = dict(gamma_0=gamma_0, rho=rho, r=r, eps=eps)
        super().__init__(params, defaults)

        # Eager initialisation is kept so that ``optimizer.state[p]`` is
        # populated right after construction, as before.
        for group in self.param_groups:
            for p in group['params']:
                if p.requires_grad:
                    self._init_state(p, group)

    def _init_state(self, p, group):
        """Create and return the per-parameter state for ``p``."""
        state = self.state[p]
        state['alpha'] = 1e-6
        state['gamma'] = group['gamma_0']
        state['g_prev'] = torch.zeros_like(p.data)
        state['A_inv'] = 1.0
        return state

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            rho = group['rho']
            r = group['r']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]
                if 'g_prev' not in state:
                    # Parameter was added via ``add_param_group`` or only
                    # started requiring grad after construction.
                    state = self._init_state(p, group)

                g_prev = state['g_prev']
                alpha = state['alpha']
                gamma = state['gamma']

                # ``.item()`` pulls the scalars to the host so the branching
                # below can be done with plain Python floats.
                gk_dot_gk = grad.pow(2).sum().item()
                gk_dot_gprev = torch.dot(grad.flatten(), g_prev.flatten()).item()
                v_k = gk_dot_gk - rho * gk_dot_gprev

                if v_k >= 0:
                    # Grow the accumulator; numerator stays put.
                    alpha_new = alpha + v_k
                    gamma_new = gamma
                else:
                    # Consecutive gradients are positively correlated enough
                    # to make v_k negative: enlarge gamma instead of
                    # shrinking alpha.
                    v_k_clipped = max(v_k, -r * alpha)
                    denom = alpha + eps
                    ratio = v_k_clipped / denom
                    # Keeps the sqrt argument positive even for unusual
                    # hyper-parameters (e.g. a negative ``r``).
                    ratio = min(ratio, 1.0 - eps)
                    gamma_new = gamma * math.sqrt(1 - ratio)
                    alpha_new = alpha

                # Step multiplier is the inverse of A = sqrt(alpha) / gamma,
                # i.e. gamma / sqrt(alpha): it shrinks as alpha accumulates
                # and grows when gamma is enlarged.
                A_inv = gamma_new / math.sqrt(max(alpha_new, eps))
                p.add_(grad, alpha=-A_inv)

                state['alpha'] = alpha_new
                state['gamma'] = gamma_new
                # Reuse the existing buffer instead of allocating a clone.
                g_prev.copy_(grad)
                state['A_inv'] = A_inv

        return loss