In [47]:
import tensorflow as tf
import numpy as np

class NAG:
    """Nesterov accelerated gradient (NAG) optimizer over a list of ``tf.Variable``s.

    Uses the look-ahead formulation: the gradient is evaluated at
    ``params + beta * momentum`` and then applied from the original point as

        momentum <- beta * momentum - lr * grad
        params   <- params + momentum

    Args:
        params: list of ``tf.Variable`` objects to optimize; updated in place.
        lr: learning rate.
        beta: momentum coefficient.
    """

    def __init__(self, params, lr=0.01, beta=0.9):
        self.lr = lr
        self.beta = beta
        self.params = params
        # One non-trainable momentum buffer per parameter, initialized to zero.
        self.momentum = [tf.Variable(tf.zeros_like(p), trainable=False) for p in params]

    def apply_gradients(self, grads):
        """Apply one momentum step using ``grads`` (one per parameter, same order).

        Raises:
            ValueError: if ``grads`` does not have one entry per parameter, or an
                entry is ``None`` (the loss does not depend on that parameter).
        """
        if len(grads) != len(self.params):
            raise ValueError(f"expected {len(self.params)} gradients, got {len(grads)}")
        for i, (param, velocity, grad) in enumerate(zip(self.params, self.momentum, grads)):
            if grad is None:
                raise ValueError(
                    f"gradient for parameter {i} is None; the loss does not depend on it"
                )
            velocity.assign(self.beta * velocity - self.lr * grad)
            param.assign_add(velocity)

    def train(self, loss_fn, n_epochs=100, log_every=10, param_names=None):
        """Run ``n_epochs`` NAG steps minimizing ``loss_fn``.

        Args:
            loss_fn: zero-argument callable returning the loss; it must read the
                variables in ``self.params`` so the gradient tape can track them.
            n_epochs: number of optimization steps.
            log_every: print progress every ``log_every`` epochs; ``0``/``None``
                disables logging.
            param_names: optional display names for ``self.params`` in the log
                line; defaults to ``p0, p1, ...``.

        The logged loss is the value at the look-ahead point (where the gradient
        is taken), computed before that epoch's update is applied.
        """
        if param_names is None:
            param_names = [f"p{i}" for i in range(len(self.params))]

        for epoch in range(n_epochs):
            # Snapshot the current values so they can be restored exactly;
            # recomputing them as ``lookahead - beta * m`` accumulates float error.
            originals = [tf.identity(p) for p in self.params]

            # Move to the look-ahead point and take the gradient there.
            for param, velocity in zip(self.params, self.momentum):
                param.assign_add(self.beta * velocity)
            with tf.GradientTape() as tape:
                # Explicit watch so non-trainable variables are tracked too.
                tape.watch(self.params)
                loss = loss_fn()
            grads = tape.gradient(loss, self.params)

            # Restore the original parameters, then step from them.
            for param, saved in zip(self.params, originals):
                param.assign(saved)
            self.apply_gradients(grads)

            if log_every and epoch % log_every == 0:
                values = ", ".join(
                    f"{name} = {p.numpy()}" for name, p in zip(param_names, self.params)
                )
                # tape.gradient sums a non-scalar target, so report the summed loss
                # (also avoids float() on an ndim>0 array).
                print(f"Epoch {epoch}: Loss = {float(tf.reduce_sum(loss)):.4f}, {values}")

In [49]:
# simple loss: (w - 3)^2 + (b - 1)^2
def loss_fn():
    """Quadratic bowl whose minimum (value 0) is at w = 3, b = 1.

    Reads the notebook globals ``w`` and ``b`` at call time (they are created in
    the setup cell below), so it must only be called after those exist.
    """
    w_err = w - 3
    b_err = b - 1
    return w_err**2 + b_err**2

In [51]:
# Start at (w, b) = (5, 2), away from the loss minimum at (3, 1).
w = tf.Variable(initial_value=[5.0], dtype=tf.float32)
b = tf.Variable(initial_value=[2.0], dtype=tf.float32)

optimizer = NAG([w, b], lr=0.1, beta=0.9)
optimizer.train(loss_fn=loss_fn, n_epochs=100)

Epoch 0: Loss = 5.0000, w = [4.6], b = [1.8]
Epoch 10: Loss = 0.0132, w = [3.0821772], b = [1.0410886]
Epoch 20: Loss = 0.0078, w = [2.936804], b = [0.9684021]
Epoch 30: Loss = 0.0000, w = [2.999456], b = [0.99972796]
Epoch 40: Loss = 0.0000, w = [3.002388], b = [1.001194]
Epoch 50: Loss = 0.0000, w = [2.9999247], b = [0.99996233]
Epoch 60: Loss = 0.0000, w = [2.9999137], b = [0.99995685]
Epoch 70: Loss = 0.0000, w = [3.0000062], b = [1.0000031]
Epoch 80: Loss = 0.0000, w = [3.0000029], b = [1.0000015]
Epoch 90: Loss = 0.0000, w = [2.9999995], b = [0.9999998]
