In [1]:
import tensorflow as tf

class AdamOptimizer:
    """Minimal Adam optimizer over a list of ``tf.Variable`` objects.

    Keeps per-parameter first-moment (``m``) and second-moment (``v``)
    estimates and applies bias-corrected updates in place via ``assign_add``.

    Args:
        params: list of ``tf.Variable`` objects to optimize (updated in place).
        lr: step size.
        beta1: decay rate of the first-moment (mean) estimate.
        beta2: decay rate of the second-moment (uncentered variance) estimate.
        epsilon: small constant added to the denominator for numerical stability.
    """

    def __init__(self, params, lr=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.params = params

        # Moment estimates, one non-trainable slot per parameter, starting at zero.
        self.m = [tf.Variable(tf.zeros_like(p), trainable=False) for p in params]
        self.v = [tf.Variable(tf.zeros_like(p), trainable=False) for p in params]
        # Number of update steps applied so far. Used for bias correction when
        # the caller does not pass an explicit step, so repeated train() calls
        # continue from where they left off instead of restarting at t = 0.
        self.iterations = 0

    def apply_gradients(self, grads, t=None):
        """Apply one Adam update to every parameter.

        Args:
            grads: sequence of gradients aligned with ``self.params``. ``None``
                entries (parameters the loss does not depend on) are skipped.
            t: optional zero-based step index for bias correction (Python int
                or integer tensor). Defaults to the internal step counter.

        Raises:
            ValueError: if the number of gradients differs from the number of
                parameters.
        """
        if len(grads) != len(self.params):
            raise ValueError(
                f"expected {len(self.params)} gradients, got {len(grads)}"
            )

        step = self.iterations if t is None else t
        for param, m, v, grad in zip(self.params, self.m, self.v, grads):
            if grad is None:
                # No gradient path from the loss to this parameter; leave it untouched.
                continue

            m.assign(self.beta1 * m + (1 - self.beta1) * grad)
            v.assign(self.beta2 * v + (1 - self.beta2) * tf.square(grad))

            # Bias correction, computed in the parameter's own dtype so that
            # float64 (or float16) variables work as well as float32 ones.
            dtype = param.dtype
            exponent = tf.cast(step + 1, dtype)
            m_hat = m / (1 - tf.pow(tf.cast(self.beta1, dtype), exponent))
            v_hat = v / (1 - tf.pow(tf.cast(self.beta2, dtype), exponent))

            update = -self.lr * m_hat / (tf.sqrt(v_hat) + self.epsilon)
            param.assign_add(update)

        self.iterations += 1

    def train(self, loss_fn, n_epochs=100, log_every=10):
        """Minimize ``loss_fn`` with ``n_epochs`` Adam steps, logging periodically.

        Args:
            loss_fn: zero-argument callable returning the loss. Gradients are
                taken with respect to ``self.params``.
            n_epochs: number of update steps to run.
            log_every: print progress every ``log_every`` epochs. ``None`` or a
                non-positive value disables logging.
        """
        for epoch in range(n_epochs):
            with tf.GradientTape() as tape:
                loss = loss_fn()
            grads = tape.gradient(loss, self.params)
            # Rely on the internal step counter rather than `epoch`, so bias
            # correction stays right if train() is called more than once.
            self.apply_gradients(grads)

            if log_every and log_every > 0 and epoch % log_every == 0:
                # Report this optimizer's own parameters, not notebook globals.
                params_str = ", ".join(
                    f"param{i} = {p.numpy()}" for i, p in enumerate(self.params)
                )
                # tape.gradient differentiates the sum of a non-scalar loss,
                # so report that same scalar.
                print(f"Epoch {epoch}: Loss = {float(tf.reduce_sum(loss)):.4f}, {params_str}")

In [2]:
def loss_fn():
    """Toy quadratic loss (w - 3)^2 + (b - 1)^2, minimized at w = 3, b = 1.

    Takes no arguments: it reads the module-level variables ``w`` and ``b``
    at call time (they are created in a later cell). This lets it be passed
    directly as the zero-argument ``loss_fn`` of ``AdamOptimizer.train``.
    """
    w_residual = w - 3
    b_residual = b - 1
    return w_residual**2 + b_residual**2

In [3]:
# Hyperparameters for the toy problem (loss_fn is minimized at w = 3, b = 1).
LEARNING_RATE = 0.1
BETA1 = 0.9
BETA2 = 0.999
N_EPOCHS = 100

# Starting point away from the optimum. loss_fn reads these module-level
# names at call time, so they must stay named `w` and `b`.
w = tf.Variable([5.0], dtype=tf.float32)
b = tf.Variable([2.0], dtype=tf.float32)

optimizer = AdamOptimizer(params=[w, b], lr=LEARNING_RATE, beta1=BETA1, beta2=BETA2)
optimizer.train(loss_fn, n_epochs=N_EPOCHS)

# Sanity check so Restart & Run All fails loudly if training diverges.
final_loss = float(tf.reduce_sum(loss_fn()))
assert final_loss < 1e-2, f"Adam did not converge: final loss {final_loss:.4f}"

Epoch 0: Loss = 5.0000, w = [4.9000006], b = [1.9000007]
Epoch 10: Loss = 1.0556, w = [3.9331152], b = [1.0051373]
Epoch 20: Loss = 0.1230, w = [3.1601958], b = [0.73729897]
Epoch 30: Loss = 0.0433, w = [2.7759838], b = [0.9932178]
Epoch 40: Loss = 0.0648, w = [2.7780929], b = [1.0962484]
Epoch 50: Loss = 0.0058, w = [2.9404826], b = [0.9850265]
Epoch 60: Loss = 0.0030, w = [3.0483694], b = [0.9699673]
Epoch 70: Loss = 0.0028, w = [3.0472353], b = [1.0167819]
Epoch 80: Loss = 0.0001, w = [3.0041025], b = [1.0026567]
Epoch 90: Loss = 0.0003, w = [2.9844701], b = [0.9927366]
