# AdaBound 
This is the TensorFlow implementation of the AdaBound optimizer used in all experiments.
It is based on the TensorFlow implementation of Adam (https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/python/keras/optimizer_v2/adam.py#L32-L274), with some implementation details inspired by GitHub user huyu398's AdaBound implementation (https://github.com/taki0112/AdaBound-Tensorflow/blob/master/AdaBound.py).

It works with both dense and sparse gradients.


In [None]:
# Implementation of AdaBound from:
# Adaptive Gradient Methods with Dynamic Bound of Learning Rate:
# https://openreview.net/forum?id=Bkg3g2R9FX

# Modified Version of the Tensorflow Adam Implementation 
# https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/python/keras/optimizer_v2/adam.py#L32-L274

# also inspired by GitHub user huyu398's AdaBound implementation 
# https://github.com/taki0112/AdaBound-Tensorflow/blob/master/AdaBound.py

from tensorflow.python.framework import ops 
from tensorflow.python.keras import backend_config 
from tensorflow.python.keras.optimizer_v2 import optimizer_v2 
from tensorflow.python.ops import array_ops 
from tensorflow.python.ops import control_flow_ops 
from tensorflow.python.ops import math_ops 
from tensorflow.python.ops import state_ops 
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import training_ops 
from tensorflow.python.util.tf_export import keras_export
from tensorflow import clip_by_value

class AdaBound(optimizer_v2.OptimizerV2):
    """AdaBound optimizer, adapted from the TF 2.1 Keras Adam implementation.

    AdaBound computes Adam's per-element step size ``lr / (sqrt(v) + eps)``
    and clips it into ``[lower_bound, upper_bound]``. At (1-based) step ``t``
    the bounds are:

        lower_bound = final_lr * (1 - 1 / (gamma * t + 1))
        upper_bound = final_lr * (1 + 1 / (gamma * t))

    Both bounds converge to ``final_lr`` as ``t`` grows.

    Args:
        learning_rate: Initial learning rate. If the legacy ``lr`` keyword is
            present in ``kwargs`` it takes precedence.
        beta_1: Decay rate of the first-moment (``m``) estimate.
        beta_2: Decay rate of the second-moment (``v``) estimate.
        final_lr: Value both bounds converge to. Each step it is rescaled by
            ``lr_t / base_lr``, the ratio of the current to the initial
            learning rate.
        gamma: Controls how fast the bounds converge to ``final_lr``.
        epsilon: Added to ``sqrt(v)``. Falls back to
            ``backend_config.epsilon()`` when falsy.
        amsbound: If True, normalise by the running maximum of ``v``
            (AMSBound variant) and keep an extra ``vhat`` slot.
        name: Optimizer name passed to ``OptimizerV2``.
        **kwargs: Forwarded to ``OptimizerV2.__init__``; ``lr`` is also read
            here.

    NOTE(review): ``base_lr`` goes through ``ops.convert_to_tensor``, so a
    ``LearningRateSchedule`` passed as ``learning_rate`` is presumably not
    supported -- confirm before using schedules.
    """

    def __init__(self,
                 learning_rate=1e-3,
                 beta_1=0.9,
                 beta_2=0.999,
                 final_lr=0.1,
                 gamma=1e-3,
                 epsilon=1e-8,
                 amsbound=False,
                 name='AdaBound',
                 **kwargs):
        super(AdaBound, self).__init__(name, **kwargs)
        # Use one source for both the hyperparameter and base_lr. Previously
        # base_lr ignored a legacy ``lr`` kwarg, which rescaled final_lr by
        # the wrong ratio.
        initial_lr = kwargs.get('lr', learning_rate)
        self._set_hyper('learning_rate', initial_lr)
        self._set_hyper('decay', self._initial_decay)
        self._set_hyper('beta_1', beta_1)
        self._set_hyper('beta_2', beta_2)
        self._set_hyper('final_lr', final_lr)
        self._set_hyper('gamma', gamma)
        self.epsilon = epsilon or backend_config.epsilon()
        self.amsbound = amsbound
        # Reference learning rate: final_lr is scaled by lr_t / base_lr.
        self.base_lr = initial_lr

    def _create_slots(self, var_list):
        """Create the ``m`` and ``v`` slots (plus ``vhat`` for AMSBound).

        Slots are added in groups (all ``m``, then all ``v``, then all
        ``vhat``). This order is kept unchanged because it determines the
        layout of the optimizer weights.
        """
        for var in var_list:
            self.add_slot(var, 'm')
        for var in var_list:
            self.add_slot(var, 'v')
        if self.amsbound:
            for var in var_list:
                self.add_slot(var, 'vhat')

    def _prepare_local(self, var_device, var_dtype, apply_state):
        """Precompute the per-(device, dtype) coefficients used by the apply ops."""
        super(AdaBound, self)._prepare_local(var_device, var_dtype, apply_state)
        coefficients = apply_state[(var_device, var_dtype)]
        lr_t = coefficients['lr_t']

        # 1-based step index of the update being built.
        local_step = math_ops.cast(self.iterations + 1, var_dtype)
        beta_1_t = array_ops.identity(self._get_hyper('beta_1', var_dtype))
        beta_2_t = array_ops.identity(self._get_hyper('beta_2', var_dtype))
        beta_1_power = math_ops.pow(beta_1_t, local_step)
        beta_2_power = math_ops.pow(beta_2_t, local_step)
        # The bounds depend on gamma * t, as in the AdaBound reference code.
        # The previous gamma ** t drove the bounds to [0, inf) after a few
        # steps, so the optimizer silently degenerated into plain Adam.
        gamma_step = self._get_hyper('gamma', var_dtype) * local_step
        # Bias-corrected Adam step size.
        lr = lr_t * (math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))
        # Scale final_lr by how far the current lr has moved from the initial lr.
        final_lr = math_ops.multiply(
            self._get_hyper('final_lr', var_dtype),
            math_ops.divide(lr_t, ops.convert_to_tensor(self.base_lr, var_dtype)))
        coefficients.update(dict(
            lr=lr,
            epsilon=ops.convert_to_tensor(self.epsilon, var_dtype),
            gamma_step=gamma_step,
            final_lr=final_lr,
            beta_1_t=beta_1_t,
            one_minus_beta_1_t=1 - beta_1_t,
            beta_2_t=beta_2_t,
            beta_2_power=beta_2_power,
            one_minus_beta_2_t=1 - beta_2_t
        ))

    def set_weights(self, weights):
        """Set optimizer weights, truncating a list that has one extra slot set.

        Appears to mirror the compatibility shim in TF's ``Adam.set_weights``:
        it derives ``num_vars`` assuming this optimizer holds
        ``2 * num_vars + 1`` weights. If ``weights`` has ``3 * num_vars + 1``
        entries, it is cut down to ``len(self.weights)`` before delegating.
        NOTE(review): with ``amsbound=True`` this optimizer itself holds three
        slots per variable, so the ``num_vars`` estimate is off -- presumably
        harmless, but confirm.
        """
        params = self.weights

        num_vars = int((len(params) - 1) / 2)
        if len(weights) == 3 * num_vars + 1:
            weights = weights[:len(params)]
        super(AdaBound, self).set_weights(weights)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Apply one AdaBound update to ``var`` from a dense gradient."""
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = ((apply_state or {}).get((var_device, var_dtype))
                        or self._fallback_apply_state(var_device, var_dtype))

        # m_t = beta_1 * m + (1 - beta_1) * g
        m = self.get_slot(var, 'm')
        m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
        m_t = state_ops.assign(m, m * coefficients['beta_1_t'] + m_scaled_g_values,
                               use_locking=self._use_locking)

        # v_t = beta_2 * v + (1 - beta_2) * g^2
        v = self.get_slot(var, 'v')
        v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
        v_t = state_ops.assign(v, v * coefficients['beta_2_t'] + v_scaled_g_values,
                               use_locking=self._use_locking)

        return self._apply_bounded_step(var, m_t, v_t, coefficients)

    def _resource_apply_sparse(self, grad, var, indcs, apply_state=None):
        """Apply one AdaBound update to ``var`` from a sparse gradient.

        ``grad`` holds the gradient values for the rows in ``indcs``. Both
        moments are first decayed in full, then the scaled gradient
        contributions are scatter-added at ``indcs``. The parameter step
        itself is applied to the whole variable, as in the dense path.
        """
        var_device, var_dtype = var.device, var.dtype.base_dtype
        coefficients = ((apply_state or {}).get((var_device, var_dtype))
                        or self._fallback_apply_state(var_device, var_dtype))

        # m_t = beta_1 * m, then add (1 - beta_1) * g at the touched rows.
        m = self.get_slot(var, 'm')
        m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
        m_t = state_ops.assign(m, m * coefficients['beta_1_t'],
                               use_locking=self._use_locking, name='assign_m_t')
        with ops.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indcs, m_scaled_g_values)

        # v_t = beta_2 * v, then add (1 - beta_2) * g^2 at the touched rows.
        v = self.get_slot(var, 'v')
        v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
        v_t = state_ops.assign(v, v * coefficients['beta_2_t'],
                               use_locking=self._use_locking, name='assign_v_t')
        with ops.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indcs, v_scaled_g_values)

        return self._apply_bounded_step(var, m_t, v_t, coefficients)

    def _apply_bounded_step(self, var, m_t, v_t, coefficients):
        """Clip the Adam step size into the AdaBound bounds and update ``var``.

        Shared by the dense and sparse paths. Returns an op that groups the
        variable update with the slot updates (``vhat`` included when
        ``amsbound`` is set).
        """
        final_lr = coefficients['final_lr']
        gamma_step = coefficients['gamma_step']
        # Both bounds converge to final_lr as gamma * t grows.
        lower_bound = final_lr * (1. - 1. / (gamma_step + 1.))
        upper_bound = final_lr * (1. + 1. / gamma_step)

        slot_updates = [m_t, v_t]
        if self.amsbound:
            # AMSBound: normalise by the running maximum of v instead of v.
            vhat = self.get_slot(var, 'vhat')
            vhat_t = state_ops.assign(vhat, math_ops.maximum(v_t, vhat),
                                      use_locking=self._use_locking)
            v_sqrt = math_ops.sqrt(vhat_t)
            slot_updates.append(vhat_t)
        else:
            v_sqrt = math_ops.sqrt(v_t)

        step_size = clip_by_value(coefficients['lr'] / (v_sqrt + coefficients['epsilon']),
                                  lower_bound, upper_bound)
        var_update = state_ops.assign_sub(var, m_t * step_size,
                                          use_locking=self._use_locking)
        return control_flow_ops.group(var_update, *slot_updates)

    def get_config(self):
        """Return the serialisable configuration of this optimizer."""
        config = super(AdaBound, self).get_config()
        config.update({
            'learning_rate': self._serialize_hyperparameter('learning_rate'),
            'decay': self._serialize_hyperparameter('decay'),
            'beta_1': self._serialize_hyperparameter('beta_1'),
            'beta_2': self._serialize_hyperparameter('beta_2'),
            'gamma': self._serialize_hyperparameter('gamma'),
            'final_lr': self._serialize_hyperparameter('final_lr'),
            'epsilon': self.epsilon,
            'amsbound': self.amsbound,
        })
        return config
        