In [1]:
import numpy as np

In [5]:
class NAdam:
    """Nesterov-accelerated Adam (NAdam) optimizer over a sequence of parameters.

    The first- and second-moment estimates are kept on the instance between
    calls to :meth:`update`. One instance should therefore drive one fixed set
    of parameters across all steps; call :meth:`reset` to start from scratch.

    Args:
        learning_rate: Step size multiplied into every update.
        beta1: Exponential decay rate of the first-moment (mean) estimate.
        beta2: Exponential decay rate of the second-moment (uncentered
            variance) estimate.
        epsilon: Small constant added to the denominator to avoid division
            by zero.
    """

    def __init__(self, learning_rate=0.002, beta1=0.9, beta2=0.999, epsilon=1e-8):
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon

        # Step counter used for bias correction; incremented once per update().
        self.t = 0
        # Per-parameter moment estimates. They are created lazily on the first
        # update() because parameter shapes are unknown until then.
        self.m = None
        self.v = None
        # Shapes of the parameters the moments were allocated for.
        self._shapes = None

    def reset(self):
        """Discard the moment estimates and the step counter."""
        self.t = 0
        self.m = None
        self.v = None
        self._shapes = None

    def update(self, params, grads):
        """Take one NAdam step and return the updated parameters.

        Args:
            params: Sequence of parameters (scalars or numpy arrays). Iterating
                a 1-D numpy array treats each element as its own parameter.
            grads: Gradients, matching ``params`` one-to-one.

        Returns:
            A new list of updated parameters. ``params`` itself is not modified.

        Raises:
            ValueError: If ``params`` and ``grads`` differ in length, or if the
                number or shapes of ``params`` differ from those seen on the
                first step since construction or the last :meth:`reset`.
        """
        params = list(params)
        grads = list(grads)
        if len(params) != len(grads):
            # zip() would silently drop the extra entries.
            raise ValueError(f"got {len(params)} params but {len(grads)} grads")

        shapes = [np.shape(param) for param in params]
        if self.m is None:
            # First step: both moment estimates start at zero.
            self.m = [np.zeros_like(param) for param in params]
            self.v = [np.zeros_like(param) for param in params]
            self._shapes = shapes
        elif shapes != self._shapes:
            raise ValueError(
                "parameter structure changed since the first update(); "
                "call reset() or use a new NAdam instance"
            )

        self.t += 1
        # The bias-correction denominators depend only on the step count.
        bias1 = 1 - self.beta1 ** self.t
        bias2 = 1 - self.beta2 ** self.t

        updated_params = []
        for i, (param, grad) in enumerate(zip(params, grads)):
            # Exponential moving averages, accumulated across calls.
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grad
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * (grad ** 2)

            m_hat = self.m[i] / bias1
            v_hat = self.v[i] / bias2

            # Nesterov look-ahead: blend the bias-corrected momentum with the
            # current gradient. NOTE: the gradient term is not bias-corrected
            # here; some NAdam formulations divide it by (1 - beta1**t).
            m_nesterov = (self.beta1 * m_hat) + ((1 - self.beta1) * grad)

            updated_params.append(
                param - self.learning_rate * m_nesterov / (np.sqrt(v_hat) + self.epsilon)
            )

        return updated_params

In [6]:
# Toy problem: one parameter starting at 0.0 whose gradient is -6.0.
params, grads = np.array([0.0]), np.array([-6.0])

In [7]:
# Build the optimizer with its default hyperparameters spelled out, then take one step;
# the returned list of updated parameters is the cell's displayed output.
nadam = NAdam(learning_rate=0.002, beta1=0.9, beta2=0.999, epsilon=1e-8)
nadam.update(params, grads)

[0.001999999996666667]