# In [26]:
import torch
from torch.optim import Optimizer

class GradaGrad(Optimizer):
    """Coordinate-wise GradaGrad optimizer with iterate averaging.

    Per-parameter state (all tensors shaped like the parameter):

    * ``alpha``    -- running accumulator used as the squared denominator of
      the step size.
    * ``gamma_i``  -- per-coordinate numerator of the step size, initialised
      to the group's ``gamma``.
    * ``momentum`` -- ``A_inv * (p - z)`` from the previous step.
    * ``z``        -- the un-averaged iterate; the parameter itself is an
      exponential average of ``z`` with factor ``beta``.
    * ``step``     -- number of updates applied to this parameter.

    State is created lazily on the first ``step()`` that sees a gradient for
    a parameter, so parameters added later via ``add_param_group`` work and
    ``z`` starts from the parameter's value at that moment.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        gamma: initial value for every entry of ``gamma_i``; must be > 0.
        rho: weight of the ``grad * momentum`` correlation term in ``v_k``.
        beta: averaging factor, ``p <- beta * p + (1 - beta) * z``;
            must satisfy ``0 <= beta < 1``.
        G_inf: if given, ``G_inf ** 2`` (instead of ``grad ** 2``) is used as
            ``v_k`` on a parameter's first step; must be > 0.
        D_inf: if given, ``gamma_i`` is capped at ``D_inf`` and ``z`` is
            clamped to ``[-D_inf, D_inf]``; must be > 0.

    Raises:
        ValueError: if a hyperparameter is outside the ranges above.
    """

    # Additive guard used in the divisions below.
    _EPS = 1e-12

    def __init__(self, params, gamma=1.0, rho=2.0, beta=0.9, G_inf=None, D_inf=None):
        if not gamma > 0.0:
            raise ValueError(f"Invalid gamma: {gamma} (must be > 0)")
        if not 0.0 <= beta < 1.0:
            raise ValueError(f"Invalid beta: {beta} (must be in [0, 1))")
        if G_inf is not None and not G_inf > 0.0:
            raise ValueError(f"Invalid G_inf: {G_inf} (must be > 0 or None)")
        if D_inf is not None and not D_inf > 0.0:
            raise ValueError(f"Invalid D_inf: {D_inf} (must be > 0 or None)")

        defaults = dict(gamma=gamma, rho=rho, beta=beta, G_inf=G_inf, D_inf=D_inf)
        super().__init__(params, defaults)

    @staticmethod
    def _init_state(state, p, gamma):
        """Populate the empty ``state`` dict for parameter ``p``."""
        state['step'] = 0
        state['alpha'] = torch.zeros_like(p)
        state['momentum'] = torch.zeros_like(p)
        state['gamma_i'] = torch.full_like(p, gamma)
        # detach() keeps the stored iterate out of the autograd graph.
        state['z'] = p.detach().clone()

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        eps = self._EPS
        for group in self.param_groups:
            rho = group['rho']
            beta = group['beta']
            G_inf = group['G_inf']
            D_inf = group['D_inf']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("GradaGrad does not support sparse gradients")

                state = self.state[p]
                if len(state) == 0:
                    self._init_state(state, p, group['gamma'])
                state['step'] += 1

                alpha = state['alpha']
                gamma_i = state['gamma_i']
                m_prev = state['momentum']

                grad_sq = grad.pow(2)
                if state['step'] == 1:
                    # No momentum yet: seed with G_inf^2 if supplied,
                    # otherwise with the squared gradient.
                    if G_inf is not None:
                        v_k = torch.full_like(grad, G_inf ** 2)
                    else:
                        v_k = grad_sq
                else:
                    v_k = grad_sq - rho * grad * m_prev
                    if D_inf is not None:
                        # Coordinates whose gamma_i already sits at the cap
                        # use the plain squared gradient.
                        v_k = torch.where(gamma_i == D_inf, grad_sq, v_k)

                pos_mask = v_k >= 0
                neg_mask = ~pos_mask

                # v_k >= 0: grow the accumulator; gamma_i is left untouched.
                alpha[pos_mask] += v_k[pos_mask]

                # v_k < 0: leave alpha alone and scale gamma_i up instead.
                if neg_mask.any():
                    g_neg = grad[neg_mask]
                    alpha_neg = alpha[neg_mask]

                    r = ((rho * m_prev[neg_mask]) / (g_neg + eps)).pow(2) - 1
                    # Bound how negative v_k may be relative to alpha.
                    v_neg = torch.maximum(v_k[neg_mask], -r * alpha_neg)

                    gamma_new = gamma_i[neg_mask] * torch.sqrt(1 - v_neg / (alpha_neg + eps))
                    if D_inf is not None:
                        gamma_new = gamma_new.clamp(max=D_inf)
                    gamma_i[neg_mask] = gamma_new

                # Per-coordinate step size.
                A_inv = gamma_i / (alpha.sqrt() + eps)

                z_new = state['z'] - A_inv * grad
                if D_inf is not None:
                    # z_new is a fresh tensor, so clamping in place is safe.
                    z_new.clamp_(-D_inf, D_inf)

                # The parameter is an exponential average of the z iterates.
                p.mul_(beta).add_(z_new, alpha=1 - beta)

                state['momentum'] = A_inv * (p.detach() - z_new)
                state['z'] = z_new

        return loss