In [3]:
import numpy as np

In [4]:
class AdamW:
    def __init__(self,
                 learning_rate=0.001,
                 beta1=0.9,
                 beta2=0.999,
                 epsilon=1e-8,
                 weight_decay=0.01,
                 num_iterations=1000):

        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.weight_decay = weight_decay
        self.num_iterations = num_iterations

    def loss_function(self, X, y):
        m = len(y)
        predictions = self.theta * X
        return (1 / m) * np.sum((predictions - y) ** 2)

    def grad(self, X, y):
        return 2 / len(X) * X.T.dot(X.dot(self.theta) - y)

    def fit(self, X, y):
        X = np.c_[np.ones((X.shape[0], 1)), X]

        self.theta = np.zeros(X.shape[1])

        m = np.zeros(X.shape[1])
        v = np.zeros(X.shape[1])

        for iteration in range(1, self.num_iterations + 1):
            gradients = self.grad(X, y)

            m = self.beta1 * m + (1 - self.beta1) * gradients
            v = self.beta2 * v + (1 - self.beta2) * gradients ** 2

            m_hat = m / (1 - self.beta1 ** iteration)
            v_hat = v / (1 - self.beta2 ** iteration)

            self.theta -= self.learning_rate * (m_hat / (np.sqrt(v_hat) + self.epsilon) + self.weight_decay * self.theta)

            if iteration % 100 == 0:
                loss = self.loss_function(X, y)

    def predict(self, X):
        X = np.c_[np.ones((X.shape[0], 1)), X]
        return X.dot(self.theta)

In [12]:
X = np.array([[1], [2]])
y = np.array([2, 4])

In [22]:
adamw = AdamW(learning_rate=0.1, weight_decay=0.01, num_iterations=1)
adamw.fit(X, y)

In [23]:
predictions = adamw.predict(X)
print(f"Predictions: {predictions}")

Predictions: [0.2 0.3]


In [24]:
final_loss = adamw.loss_function(np.c_[np.ones((X.shape[0], 1)), X], y)
print(f"Final Loss: {final_loss}")
print(f"Final Weights (Theta): {adamw.theta}")

Final Loss: 18.435000001783333
Final Weights (Theta): [0.1 0.1]
