# Optimizers in ML using JAX

> Optimization refers to the process of minimizing the loss function by systematically updating the network weights.

In this notebook, I implement a few popular optimizers from scratch for a simple model, i.e., linear regression on a dataset with 5 features. The goal is to understand how these optimizers work under the hood by writing a toy implementation of each one myself. I also use a bit of JAX magic to differentiate the loss function w.r.t. the weights and the bias, so I never have to write their derivatives as a separate function. This makes it easy to reuse the notebook with other loss functions as well.

The optimizers I have implemented are - 
* Batch Gradient Descent
* Batch Gradient Descent + Momentum
* Nesterov Accelerated Momentum
* Adagrad
* RMSprop
* Adam
* Adamax
* Nadam
* Adabelief

References -
* https://ruder.io/optimizing-gradient-descent/
* https://theaisummer.com/optimization/

## Libraries

```python
import jax
import jax.numpy as np
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
```

## Helper Functions

```python
def get_data():
    """Create a linear-regression dataset with 5 features and a fixed bias of 2.0.

    make_regression defaults to 100 noiseless samples. No random_state is set,
    so every call draws a new dataset and the true weights differ between sections.

    Returns:
        X: training features
        X_test: test features
        y: training targets
        y_test: test targets
        coef: true weights (coefficients) used to generate the data
    """
    # Create the dataset with a fixed bias of 2.0 and return the true weights (coef=True).
    X, y, coef = make_regression(n_features=5, coef=True, bias=2.0)
    # Default split: 75% train, 25% test.
    X, X_test, y, y_test = train_test_split(X, y)
    return (X, X_test, y, y_test, coef)

def J(X, w, b, y):
    """Mean squared error cost of a linear regression model (one forward pass).

    Args:
        X: feature matrix, shape (n_samples, n_features).
        w: weights, shape (n_features,).
        b: scalar bias.
        y: target vector, shape (n_samples,).

    Returns:
        scalar: the mean squared error of this solution.
    """
    y_hat = np.dot(X, w) + b  # Predict values.
    return ((y_hat - y)**2).mean()  # Return cost.
```
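
The JAX differentiation mentioned in the introduction can be checked on a small example. This is a minimal, self-contained sketch; the random toy data and the names `grad_J` and `analytic_grads` are illustrative assumptions, not part of the notebook's `optimizers` package. `jax.grad(J, argnums=(1, 2))` returns the gradients of `J` with respect to `w` and `b`, which should match the hand-derived MSE gradients `2/n * X.T @ r` and `2 * mean(r)`, where `r = X @ w + b - y`.

```python
import jax
import jax.numpy as jnp

# Assumed toy data: 100 samples, 5 features, bias 2.0, no noise.
X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
y = X @ jnp.array([1.0, -2.0, 3.0, 0.5, -1.5]) + 2.0

def J(X, w, b, y):
    """Mean squared error of the linear model X @ w + b."""
    return jnp.mean((X @ w + b - y) ** 2)

# Differentiate J w.r.t. its 2nd and 3rd arguments (w and b).
grad_J = jax.grad(J, argnums=(1, 2))

def analytic_grads(X, w, b, y):
    """Hand-derived MSE gradients, used only to check the JAX result."""
    r = X @ w + b - y
    return 2.0 / X.shape[0] * X.T @ r, 2.0 * jnp.mean(r)

w0, b0 = jnp.zeros(5), jnp.array(0.0)
dw, db = grad_J(X, w0, b0, y)
dw_ref, db_ref = analytic_grads(X, w0, b0, y)
print(jnp.allclose(dw, dw_ref, atol=1e-4), jnp.allclose(db, db_ref, atol=1e-4))
```

Swapping `J` for another differentiable loss only changes the function passed to `jax.grad`; no derivative code has to be rewritten.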

## Batch Gradient Descent

```python
from optimizers.batch_gradient_descent import *
X, X_test, y, y_test, coef = get_data()
params = batch_gradient_descent(J, X, y)

print("True weights =", coef)
print("Calculated weights =", params['w'])
print("True bias = 2.0\tCalculated bias = {:.6f}".format(params['b']))
print("Test loss: {:.9f}".format(J(X_test, params['w'], params['b'], y_test)))
```

```
True weights = [30.68836406  8.19425997 52.32139276 98.63300635 35.70556397]
Calculated weights = [30.688354  8.194252 52.321373 98.632965 35.705536]
True bias = 2.0	Calculated bias = 2.000007
Test loss: 0.000000006
```
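
Batch gradient descent applies `θ ← θ − η ∇J(θ)` to every parameter, using the gradient over the whole training set at each step. Since `optimizers/batch_gradient_descent.py` is not shown here, the following is only a minimal sketch of that update rule. It assumes toy data, a learning rate of `0.1`, 1000 steps, and a `params` dict with keys `'w'` and `'b'` that mirrors the `params` returned above. `gd_step` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

tree_map = jax.tree_util.tree_map

# Assumed toy problem standing in for get_data(): 100 samples, 5 features, bias 2.0.
X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
true_w = jnp.array([1.0, -2.0, 3.0, 0.5, -1.5])
y = X @ true_w + 2.0

def loss(params):
    # Same mean squared error as J(X, w, b, y), written over a {'w', 'b'} dict.
    return jnp.mean((X @ params['w'] + params['b'] - y) ** 2)

@jax.jit
def gd_step(params, lr=0.1):
    grads = jax.grad(loss)(params)  # dict with the same keys as params
    # Move every parameter downhill by lr times its gradient.
    return tree_map(lambda p, g: p - lr * g, params, grads)

params = {'w': jnp.zeros(5), 'b': jnp.array(0.0)}
for _ in range(1000):
    params = gd_step(params)
print(params['w'], params['b'])
```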


## Batch Gradient Descent + Momentum

```python
from optimizers.momentum import *
X, X_test, y, y_test, coef = get_data()
params = momentum(J, X, y)

print("True weights =", coef)
print("Calculated weights =", params['w'])
print("True bias = 2.0\tCalculated bias = {:.6f}".format(params['b']))
print("Test loss: {:.9f}".format(J(X_test, params['w'], params['b'], y_test)))
```

```
True weights = [79.89865004 69.54844035  5.00857317 17.07146713 95.33941246]
Calculated weights = [79.89865  69.54844   5.00857  17.071468 95.33941 ]
True bias = 2.0	Calculated bias = 1.999996
Test loss: 0.000000000
```
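
Momentum keeps a running velocity `v ← γv + η∇J(θ)` and updates `θ ← θ − v`, so steps grow along directions where the gradient keeps the same sign. A minimal sketch of that rule, assuming toy data, `γ = 0.9`, `η = 0.01` and a hypothetical `momentum_step`; the actual `optimizers/momentum.py` may be organised differently.

```python
import jax
import jax.numpy as jnp

tree_map = jax.tree_util.tree_map

# Assumed toy problem standing in for get_data().
X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
true_w = jnp.array([1.0, -2.0, 3.0, 0.5, -1.5])
y = X @ true_w + 2.0

def loss(params):
    return jnp.mean((X @ params['w'] + params['b'] - y) ** 2)

@jax.jit
def momentum_step(params, velocity, lr=0.01, gamma=0.9):
    grads = jax.grad(loss)(params)
    # v <- gamma * v + lr * grad: an exponentially decaying sum of past gradients.
    velocity = tree_map(lambda v, g: gamma * v + lr * g, velocity, grads)
    # theta <- theta - v
    params = tree_map(lambda p, v: p - v, params, velocity)
    return params, velocity

params = {'w': jnp.zeros(5), 'b': jnp.array(0.0)}
velocity = tree_map(jnp.zeros_like, params)
for _ in range(1000):
    params, velocity = momentum_step(params, velocity)
print(params['w'], params['b'])
```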


## Nesterov Accelerated Momentum

```python
from optimizers.nesterov_momentum import *
X, X_test, y, y_test, coef = get_data()
params = nesterov_momentum(J, X, y)

print("True weights =", coef)
print("Calculated weights =", params['w'])
print("True bias = 2.0\tCalculated bias = {:.6f}".format(params['b']))
print("Test loss: {:.9f}".format(J(X_test, params['w'], params['b'], y_test)))
```

```
True weights = [42.60473498 65.07277781 34.78011253 62.33332432 73.84550563]
Calculated weights = [42.604736 65.07278  34.780113 62.333324 73.845505]
True bias = 2.0	Calculated bias = 2.000000
Test loss: 0.000000000
```
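
Nesterov momentum differs from plain momentum only in where the gradient is taken: at the look-ahead point `θ − γv` instead of at `θ`, i.e. `v ← γv + η∇J(θ − γv)`, then `θ ← θ − v`. The sketch below assumes the same toy data and hyperparameters as a plain momentum setup (`γ = 0.9`, `η = 0.01`); `nesterov_step` is a hypothetical name, not the function in `optimizers/nesterov_momentum.py`.

```python
import jax
import jax.numpy as jnp

tree_map = jax.tree_util.tree_map

# Assumed toy problem standing in for get_data().
X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
true_w = jnp.array([1.0, -2.0, 3.0, 0.5, -1.5])
y = X @ true_w + 2.0

def loss(params):
    return jnp.mean((X @ params['w'] + params['b'] - y) ** 2)

@jax.jit
def nesterov_step(params, velocity, lr=0.01, gamma=0.9):
    # Evaluate the gradient at the look-ahead point theta - gamma * v.
    lookahead = tree_map(lambda p, v: p - gamma * v, params, velocity)
    grads = jax.grad(loss)(lookahead)
    velocity = tree_map(lambda v, g: gamma * v + lr * g, velocity, grads)
    params = tree_map(lambda p, v: p - v, params, velocity)
    return params, velocity

params = {'w': jnp.zeros(5), 'b': jnp.array(0.0)}
velocity = tree_map(jnp.zeros_like, params)
for _ in range(1000):
    params, velocity = nesterov_step(params, velocity)
print(params['w'], params['b'])
```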


## Adagrad

```python
from optimizers.adagrad import *
X, X_test, y, y_test, coef = get_data()
params = adagrad(J, X, y)

print("True weights =", coef)
print("Calculated weights =", params['w'])
print("True bias = 2.0\tCalculated bias = {:.6f}".format(params['b']))
print("Test loss: {:.9f}".format(J(X_test, params['w'], params['b'], y_test)))
```

```
True weights = [14.77439555 65.65138433 13.16405989 56.48558691 90.74394822]
Calculated weights = [14.773779 65.65003  13.163625 56.486057 90.73908 ]
True bias = 2.0	Calculated bias = 1.999916
Test loss: 0.000044874
```
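
Adagrad gives each parameter its own step size by dividing by the square root of that parameter's accumulated squared gradients: `G ← G + g²`, `θ ← θ − η g / √(G + ε)`. Because `G` only grows, the effective step size only shrinks. A minimal sketch, assuming toy data, `η = 0.5`, `ε = 1e-8`, 2000 steps and a hypothetical `adagrad_step`; it is not the code in `optimizers/adagrad.py`.

```python
import jax
import jax.numpy as jnp

tree_map = jax.tree_util.tree_map

# Assumed toy problem standing in for get_data().
X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
true_w = jnp.array([1.0, -2.0, 3.0, 0.5, -1.5])
y = X @ true_w + 2.0

def loss(params):
    return jnp.mean((X @ params['w'] + params['b'] - y) ** 2)

@jax.jit
def adagrad_step(params, G, lr=0.5, eps=1e-8):
    grads = jax.grad(loss)(params)
    # G accumulates each parameter's squared gradient over all steps.
    G = tree_map(lambda G_i, g: G_i + g ** 2, G, grads)
    # Per-parameter step size lr / sqrt(G), which can only decrease over time.
    params = tree_map(lambda p, g, G_i: p - lr * g / jnp.sqrt(G_i + eps), params, grads, G)
    return params, G

params = {'w': jnp.zeros(5), 'b': jnp.array(0.0)}
G = tree_map(jnp.zeros_like, params)
for _ in range(2000):
    params, G = adagrad_step(params, G)
print(params['w'], params['b'])
```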


## RMSprop

```python
from optimizers.rmsprop import *
X, X_test, y, y_test, coef = get_data()
params = rmsprop(J, X, y)

print("True weights =", coef)
print("Calculated weights =", params['w'])
print("True bias = 2.0\tCalculated bias = {:.6f}".format(params['b']))
print("Test loss: {:.9f}".format(J(X_test, params['w'], params['b'], y_test)))
```

```
True weights = [41.97869789 49.60297967 28.45822788 46.6347428  48.8575031 ]
Calculated weights = [41.9287   49.552982 28.40823  46.58474  48.807503]
True bias = 2.0	Calculated bias = 2.050000
Test loss: 0.013983894
```
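
RMSprop replaces Adagrad's running sum with an exponential moving average of squared gradients, `E[g²] ← ρE[g²] + (1 − ρ)g²`, so the step size `η / √(E[g²] + ε)` no longer decays toward zero. A minimal sketch, assuming toy data, `ρ = 0.9`, `η = 0.01`, 2000 steps and a hypothetical `rmsprop_step`; the notebook's `optimizers/rmsprop.py` may use different defaults.

```python
import jax
import jax.numpy as jnp

tree_map = jax.tree_util.tree_map

# Assumed toy problem standing in for get_data().
X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
true_w = jnp.array([1.0, -2.0, 3.0, 0.5, -1.5])
y = X @ true_w + 2.0

def loss(params):
    return jnp.mean((X @ params['w'] + params['b'] - y) ** 2)

@jax.jit
def rmsprop_step(params, avg_sq, lr=0.01, rho=0.9, eps=1e-8):
    grads = jax.grad(loss)(params)
    # Moving average of squared gradients instead of Adagrad's running sum.
    avg_sq = tree_map(lambda s, g: rho * s + (1 - rho) * g ** 2, avg_sq, grads)
    params = tree_map(lambda p, g, s: p - lr * g / jnp.sqrt(s + eps), params, grads, avg_sq)
    return params, avg_sq

params = {'w': jnp.zeros(5), 'b': jnp.array(0.0)}
avg_sq = tree_map(jnp.zeros_like, params)
for _ in range(2000):
    params, avg_sq = rmsprop_step(params, avg_sq)
print(params['w'], params['b'])
```

With a constant learning rate, the parameters tend to hover within a distance on the order of `η` of the optimum rather than settle exactly on it.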


## Adam

```python
from optimizers.adam import *
X, X_test, y, y_test, coef = get_data()
params = adam(J, X, y)

print("True weights =", coef)
print("Calculated weights =", params['w'])
print("True bias = 2.0\tCalculated bias = {:.6f}".format(params['b']))
print("Test loss: {:.9f}".format(J(X_test, params['w'], params['b'], y_test)))
```

```
True weights = [91.82405194  5.98911475 51.21966463 76.09491592  3.97291692]
Calculated weights = [91.82391    5.989142  51.21963   76.09506    3.9729562]
True bias = 2.0	Calculated bias = 2.000038
Test loss: 0.000000036
```
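
Adam combines a momentum-style first moment `m` with an RMSprop-style second moment `v`, and divides both by `1 − βᵗ` to correct for their zero initialisation: `θ ← θ − η m̂ / (√v̂ + ε)`. A minimal sketch, assuming toy data, `η = 0.05`, `β₁ = 0.9`, `β₂ = 0.999`, 2000 steps and a hypothetical `adam_step`; it is not the code in `optimizers/adam.py`.

```python
import jax
import jax.numpy as jnp

tree_map = jax.tree_util.tree_map

# Assumed toy problem standing in for get_data().
X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
true_w = jnp.array([1.0, -2.0, 3.0, 0.5, -1.5])
y = X @ true_w + 2.0

def loss(params):
    return jnp.mean((X @ params['w'] + params['b'] - y) ** 2)

@jax.jit
def adam_step(params, m, v, t, lr=0.05, b1=0.9, b2=0.999, eps=1e-8):
    grads = jax.grad(loss)(params)
    m = tree_map(lambda m_i, g: b1 * m_i + (1 - b1) * g, m, grads)       # first moment
    v = tree_map(lambda v_i, g: b2 * v_i + (1 - b2) * g ** 2, v, grads)  # second moment

    def update(p, m_i, v_i):
        # Bias-corrected moments (t starts at 1).
        m_hat = m_i / (1 - b1 ** t)
        v_hat = v_i / (1 - b2 ** t)
        return p - lr * m_hat / (jnp.sqrt(v_hat) + eps)

    return tree_map(update, params, m, v), m, v

params = {'w': jnp.zeros(5), 'b': jnp.array(0.0)}
m = tree_map(jnp.zeros_like, params)
v = tree_map(jnp.zeros_like, params)
for t in range(1, 2001):
    params, m, v = adam_step(params, m, v, t)
print(params['w'], params['b'])
```

Passing the step counter `t` as a traced argument lets `jax.jit` compile `adam_step` once instead of once per step.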


## Adamax

```python
from optimizers.adamax import *
X, X_test, y, y_test, coef = get_data()
params = adamax(J, X, y)

print("True weights =", coef)
print("Calculated weights =", params['w'])
print("True bias = 2.0\tCalculated bias = {:.6f}".format(params['b']))
print("Test loss: {:.9f}".format(J(X_test, params['w'], params['b'], y_test)))
```

```
True weights = [70.74526421 40.12974786  4.90109456 56.99685724 58.66588867]
Calculated weights = [70.745094  40.12968    4.9011374 56.99679   58.666    ]
True bias = 2.0	Calculated bias = 1.999996
Test loss: 0.000000061
```
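
Adamax swaps Adam's second moment for an infinity-norm version, `u ← max(β₂u, |g|)`, which needs no bias correction. Only `m` is corrected: `θ ← θ − (η / (1 − β₁ᵗ)) m / u`. A minimal sketch, assuming toy data, `η = 0.05`, 2000 steps and a hypothetical `adamax_step`. The small `ε` added to `u` is an extra safeguard against division by zero and is not part of the basic update rule.

```python
import jax
import jax.numpy as jnp

tree_map = jax.tree_util.tree_map

# Assumed toy problem standing in for get_data().
X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
true_w = jnp.array([1.0, -2.0, 3.0, 0.5, -1.5])
y = X @ true_w + 2.0

def loss(params):
    return jnp.mean((X @ params['w'] + params['b'] - y) ** 2)

@jax.jit
def adamax_step(params, m, u, t, lr=0.05, b1=0.9, b2=0.999, eps=1e-8):
    grads = jax.grad(loss)(params)
    m = tree_map(lambda m_i, g: b1 * m_i + (1 - b1) * g, m, grads)
    # Exponentially weighted maximum of |g| (the infinity-norm "second moment").
    u = tree_map(lambda u_i, g: jnp.maximum(b2 * u_i, jnp.abs(g)), u, grads)
    # Only m is bias-corrected; eps (an added assumption) guards against u == 0.
    params = tree_map(lambda p, m_i, u_i: p - lr / (1 - b1 ** t) * m_i / (u_i + eps),
                      params, m, u)
    return params, m, u

params = {'w': jnp.zeros(5), 'b': jnp.array(0.0)}
m = tree_map(jnp.zeros_like, params)
u = tree_map(jnp.zeros_like, params)
for t in range(1, 2001):
    params, m, u = adamax_step(params, m, u, t)
print(params['w'], params['b'])
```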


## Nadam

```python
from optimizers.nadam import *
X, X_test, y, y_test, coef = get_data()
params = nadam(J, X, y)

print("True weights =", coef)
print("Calculated weights =", params['w'])
print("True bias = 2.0\tCalculated bias = {:.6f}".format(params['b']))
print("Test loss: {:.9f}".format(J(X_test, params['w'], params['b'], y_test)))
```

```
True weights = [ 1.58186482 19.09535888 16.30502928 81.38680601 73.46240111]
Calculated weights = [ 1.581854 19.095356 16.305021 81.38672  73.46231 ]
True bias = 2.0	Calculated bias = 1.999992
Test loss: 0.000000011
```
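
Nadam brings Nesterov's look-ahead into Adam. In the simplified form given in the first reference, the bias-corrected momentum is mixed with the current gradient before the step: `θ ← θ − η / (√v̂ + ε) · (β₁m̂ + (1 − β₁)g / (1 − β₁ᵗ))`. A minimal sketch of that form, assuming toy data, `η = 0.05`, 2000 steps and a hypothetical `nadam_step`; `optimizers/nadam.py` may implement a different variant.

```python
import jax
import jax.numpy as jnp

tree_map = jax.tree_util.tree_map

# Assumed toy problem standing in for get_data().
X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
true_w = jnp.array([1.0, -2.0, 3.0, 0.5, -1.5])
y = X @ true_w + 2.0

def loss(params):
    return jnp.mean((X @ params['w'] + params['b'] - y) ** 2)

@jax.jit
def nadam_step(params, m, v, t, lr=0.05, b1=0.9, b2=0.999, eps=1e-8):
    grads = jax.grad(loss)(params)
    m = tree_map(lambda m_i, g: b1 * m_i + (1 - b1) * g, m, grads)
    v = tree_map(lambda v_i, g: b2 * v_i + (1 - b2) * g ** 2, v, grads)

    def update(p, g, m_i, v_i):
        m_hat = m_i / (1 - b1 ** t)
        v_hat = v_i / (1 - b2 ** t)
        # Nesterov look-ahead: blend the corrected momentum with the current gradient.
        nesterov_m = b1 * m_hat + (1 - b1) * g / (1 - b1 ** t)
        return p - lr * nesterov_m / (jnp.sqrt(v_hat) + eps)

    return tree_map(update, params, grads, m, v), m, v

params = {'w': jnp.zeros(5), 'b': jnp.array(0.0)}
m = tree_map(jnp.zeros_like, params)
v = tree_map(jnp.zeros_like, params)
for t in range(1, 2001):
    params, m, v = nadam_step(params, m, v, t)
print(params['w'], params['b'])
```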


## Adabelief

```python
from optimizers.adabelief import *
X, X_test, y, y_test, coef = get_data()
params = adabelief(J, X, y)

print("True weights =", coef)
print("Calculated weights =", params['w'])
print("True bias = 2.0\tCalculated bias = {:.6f}".format(params['b']))
print("Test loss: {:.9f}".format(J(X_test, params['w'], params['b'], y_test)))
```

```
True weights = [71.25219567 15.10311738  7.49360451 90.84284877 87.45135942]
Calculated weights = [71.25219  15.10311   7.493615 90.84281  87.45141 ]
True bias = 2.0	Calculated bias = 1.999991
Test loss: 0.000000003
```
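
AdaBelief keeps Adam's structure but replaces the second moment of `g` with the moving average of `(g − m)²`, i.e. how far the gradient deviates from its momentum "prediction": `s ← β₂s + (1 − β₂)(g − m)² + ε`. When gradients are consistent, `s` is small and the optimizer takes larger steps. A minimal sketch of that rule, assuming toy data, `η = 0.01`, `ε = 1e-16`, 2000 steps and a hypothetical `adabelief_step`; it is not the code in `optimizers/adabelief.py`.

```python
import jax
import jax.numpy as jnp

tree_map = jax.tree_util.tree_map

# Assumed toy problem standing in for get_data().
X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
true_w = jnp.array([1.0, -2.0, 3.0, 0.5, -1.5])
y = X @ true_w + 2.0

def loss(params):
    return jnp.mean((X @ params['w'] + params['b'] - y) ** 2)

@jax.jit
def adabelief_step(params, m, s, t, lr=0.01, b1=0.9, b2=0.999, eps=1e-16):
    grads = jax.grad(loss)(params)
    m = tree_map(lambda m_i, g: b1 * m_i + (1 - b1) * g, m, grads)
    # "Belief": moving average of (g - m)^2, the gradient's deviation from m.
    s = tree_map(lambda s_i, g, m_i: b2 * s_i + (1 - b2) * (g - m_i) ** 2 + eps,
                 s, grads, m)

    def update(p, m_i, s_i):
        m_hat = m_i / (1 - b1 ** t)
        s_hat = s_i / (1 - b2 ** t)
        return p - lr * m_hat / (jnp.sqrt(s_hat) + eps)

    return tree_map(update, params, m, s), m, s

params = {'w': jnp.zeros(5), 'b': jnp.array(0.0)}
m = tree_map(jnp.zeros_like, params)
s = tree_map(jnp.zeros_like, params)
for t in range(1, 2001):
    params, m, s = adabelief_step(params, m, s, t)
print(params['w'], params['b'])
```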
