# Optimizers


Covers various topics on optimization in the context of Deep Learning.


In [1]:
import jax.numpy as jnp
import jax.random as random
import random as py_random
from jax import grad
import plotly.express as px
import pandas as pd
import plotly.io as pio
import plotly.graph_objects as go
import helpers as hp
from optimizers import sgd, sgd_batched, RMSProp, Adam
from plotly.subplots import make_subplots


In [2]:
pio.templates.default = "ggplot2"

In [3]:
key = random.PRNGKey(577)

In [4]:
x = jnp.arange(0, 10, 0.1)


def f(x):
    return 4 * x

In [5]:
f_dx = grad(f)

In [6]:
df = pd.DataFrame()
df["x"] = list(x)
df["f"] = list(f(x))
df["df"] = list(f_dx(_x) for _x in x)
df.head()

Unnamed: 0,x,f,df
0,0.0,0.0,4.0
1,0.1,0.4,4.0
2,0.2,0.8,4.0
3,0.3,1.2,4.0
4,0.4,1.6,4.0


In [7]:
fig = px.line(df, x="x", y=["f", "df"])
fig.show()

$$ \text{MSE} = \frac{1}{n} \sum\_{i=1}^n \left(y_i - \hat{y}\_i \right)^2 $$


If we naively define the MSE function it would look like this:


In [8]:
def mean_squared_error_naive(true, pred):
    return jnp.mean((true - pred) ** 2)

But if we want to take the derivative of this function using JAX's `grad` function, then we need to explicitly state all the variables in the function signature as such:


In [9]:
def mean_squared_error(theta, x, true):
    pred = x.dot(theta)
    return jnp.mean((true - pred) ** 2)

Suppose $\hat{y} = 2*x$. It is typical to use $w$ or $\theta$ as the parameter of our function, we will stick with $\theta$ as it extends naturally to probability. In our function we only have one parameter, hence $\theta = 2$.

We are interested in perturbing this parameter to minimize the loss over our observed (or true) sample set. To do this we will measure the impact the parameter had on the loss function, seeking to minimize this impact. To measure this impact, lets take the derivate of MSE using our parameterized function. First we will re-write the MSE loss function and then take the derivate with respect to the parameter $\theta$:

$$
\begin{aligned}
\text{MSE} &= \frac{1}{n} \sum_{i=1}^n \left(y_i - \hat{y}_i \right)^2 = \frac{1}{n} \sum_{i=1}^n \left(y_i - \theta x \right)^2 \\
D_{\theta} \text{MSE} &= D_{\theta} \frac{1}{n} \sum_{i=1}^n \left(y_i - \theta x  \right)^2 \\
&=\frac{1}{n} \sum_{i=1}^n 2 \left(y_i - \theta x \right) * (-1)x \\
&=-\frac{2}{n} \sum_{i=1}^n \left(y_i - \theta x \right)x
\end{aligned}
$$


In [10]:
def mean_squared_error_dx_man(theta, x, true):
    return -2 * jnp.mean((true - theta * x) * x)

In [11]:
loss_theta_man = mean_squared_error_dx_man(2.0, x, f(x)).item()
print(f"Loss (theta = 2): {loss_theta_man:.3f}")

Loss (theta = 2): -131.340


In [12]:
mean_squared_error_d_theta = grad(mean_squared_error, argnums=0)

In [13]:
loss_theta = mean_squared_error_d_theta(2.0, x, f(x)).item()
print(f"Loss (theta = 2): {loss_theta:.3f}")

Loss (theta = 2): -131.340


### Stochastic Gradient Descent


$$ \theta*{t+1} = \theta_t - \alpha D*\theta \text{MSE} $$


In [14]:
alpha = 0.01
theta = random.uniform(key)
n_iterations = 20
fig = go.Figure()
steps = []
x = jnp.arange(0, 10, 0.1)


def f(x):
    return 4 * x**3 + 3


print(f"Initial theta: {theta}")

fig.add_trace(go.Scatter(visible=False, name="True Function", x=x, y=f(x)))

prev_loss = 0
for i in range(n_iterations):
    ###########################################################################
    theta -= alpha * mean_squared_error_d_theta(theta, x, f(x))
    loss = mean_squared_error(theta, x, f(x)).item()
    if jnp.abs(loss - prev_loss) < 1e-6:
        print(f"Converged at iteration {i + 1}")
        break
    prev_loss = loss
    ###########################################################################
    fig.add_trace(
        go.Scatter(visible=False, name=f"Iteration {(i + 1):4d}", x=x, y=theta * x)
    )
    step = dict(
        method="update",
        args=[
            {"visible": [False] * (n_iterations + 1)},
            {"title": "SGD Iteration: " + str(i + 1)},
        ],
    )
    step["args"][0]["visible"][0] = True
    step["args"][0]["visible"][i + 1] = True
    steps.append(step)
    ###########################################################################
    print(f"Iteration {(i + 1):4d}: y_pred = [{theta:.3f}][x1].T, loss = {loss:.3f}")

sliders = [
    dict(
        active=n_iterations,
        currentvalue={"prefix": "SGD: "},
        pad={"t": 50},
        steps=steps,
    )
]

fig.data[0].visible = True
fig.data[1].visible = True
fig.update_layout(sliders=sliders)
fig.show()

Initial theta: 0.15920352935791016
Iteration    1: y_pred = [156.378][x1].T, loss = 570793.625
Iteration    2: y_pred = [210.008][x1].T, loss = 377615.000
Iteration    3: y_pred = [228.420][x1].T, loss = 354848.000
Iteration    4: y_pred = [234.740][x1].T, loss = 352164.781
Iteration    5: y_pred = [236.910][x1].T, loss = 351848.594
Iteration    6: y_pred = [237.655][x1].T, loss = 351811.281
Iteration    7: y_pred = [237.911][x1].T, loss = 351806.875
Iteration    8: y_pred = [237.998][x1].T, loss = 351806.406
Iteration    9: y_pred = [238.029][x1].T, loss = 351806.312
Converged at iteration 10


In [15]:
def f_2d(x, y):
    return 3 * x**2 + 9 * y**2


def ackley_fn(x, y):
    return (
        -20 * jnp.exp(-0.2 * jnp.sqrt(0.5 * (x**2 + y**2)))
        - jnp.exp(0.5 * (jnp.cos(2 * jnp.pi * x) + jnp.cos(2 * jnp.pi * y)))
        + jnp.e
        + 20
    )

#### Ackley Function

https://en.wikipedia.org/wiki/Ackley_function


$$
\begin{aligned}
f(x, y) &= -20\exp{\left[-0.2\sqrt{0.5(x^2+y^2)} \right]} \\
        &-\exp{\left[0.5*(\cos(2 \pi x) + \cos(2 \pi y)) \right] + e + 20}
\end{aligned}
$$


In [16]:
def mse(theta1, theta2, f_pred, x, y, true):
    pred = f_pred(theta1, theta2, x, y)
    return jnp.mean((true - pred) ** 2)

#### SGD

https://en.wikipedia.org/wiki/Stochastic_gradient_descent


In [17]:
# lr = 1e-4
# max_n_iterations = 50
# convergence_criteria = 1e-2
# true_thetas = (3, 9)

# f_pred = lambda theta1, theta2, x, y: (  # noqa: E731
#     theta1 * x**2 + theta2 * y**2
# )

# # Generate examples from the true function
# n_examples = 50
# examples = hp.generate_examples(f_2d, n_examples, key)

# sgd_history = sgd(
#     f_pred,
#     lr,
#     max_n_iterations,
#     mse,
#     examples,
#     convergence_criteria,
#     key,
#     return_history=True,
# )

# fig = hp.create_optimizer_figure_2d(
#     f_true=f_2d,
#     f_pred=f_pred,
#     loss_fn=mse,
#     theta1s=sgd_history["theta1s"],
#     theta2s=sgd_history["theta2s"],
#     losses=sgd_history["losses"],
#     n_iterations=sgd_history["max_n_iterations"],
#     perf_profiling=False,
# )

# fig.show()

In [18]:
# lr = 1e-4
# max_n_iterations = 50
# convergence_criteria = 1e-2
# true_thetas = (3, 9)

# f_pred = lambda theta1, theta2, x, y: (  # noqa: E731
#     theta1 * x**2 + theta2 * y**2
# )

# # Generate examples from the true function
# n_examples = 50
# examples = hp.generate_examples(f_2d, n_examples, key)

# sgd_history = sgd_batched(
#     f_pred,
#     lr,
#     16,
#     max_n_iterations,
#     mse,
#     examples,
#     convergence_criteria,
#     key,
#     return_history=True,
# )

# fig = hp.create_optimizer_figure_2d(
#     f_true=f_2d,
#     f_pred=f_pred,
#     loss_fn=mse,
#     theta1s=sgd_history["theta1s"],
#     theta2s=sgd_history["theta2s"],
#     losses=sgd_history["losses"],
#     n_iterations=sgd_history["max_n_iterations"],
#     perf_profiling=False,
# )

# fig.show()

In [19]:
# lr = 1e-1
# max_n_iterations = 500
# convergence_criteria = 1e-2
# true_thetas = (3, 9)
# f_pred = lambda theta1, theta2, x, y: (  # noqa: E731
#     theta1 * x**2 + theta2 * y**2
# )

# # Generate examples from the true function
# n_examples = 100
# examples = hp.generate_examples(f_2d, n_examples, key)

# sgd_history = RMSProp(
#     f_pred=f_pred,
#     lr=lr,
#     weight_decay=0.1,
#     smoothing_constant=0.99,
#     momentum=0.9,
#     centered=False,
#     max_n_iterations=max_n_iterations,
#     loss_fn=mse,
#     examples=examples,
#     convergence_criteria=convergence_criteria,
#     key=key,
#     return_history=True,
# )

# fig = hp.create_optimizer_figure_2d(
#     f_true=f_2d,
#     f_pred=f_pred,
#     loss_fn=mse,
#     theta1s=sgd_history["theta1s"],
#     theta2s=sgd_history["theta2s"],
#     losses=sgd_history["losses"],
#     n_iterations=sgd_history["max_n_iterations"],
#     perf_profiling=False,
# )

# fig.show()

In [20]:
# lr = 1e-1
# max_n_iterations = 5000
# convergence_criteria = 1e-2
# true_thetas = (3, 9)

# f_pred = lambda theta1, theta2, x, y: (  # noqa: E731
#     theta1 * x**2 + theta2 * y**2
# )

# # Generate examples from the true function
# n_examples = 100
# examples = hp.generate_examples(f_2d, n_examples, key)

# sgd_history = Adam(
#     f_pred=f_pred,
#     lr=lr,
#     weight_decay=0,
#     betas=(0.9, 0.999),
#     max_n_iterations=max_n_iterations,
#     loss_fn=mse,
#     examples=examples,
#     convergence_criteria=convergence_criteria,
#     key=key,
#     return_history=True,
# )

# fig = hp.create_optimizer_figure_2d(
#     f_true=f_2d,
#     f_pred=f_pred,
#     loss_fn=mse,
#     theta1s=sgd_history["theta1s"],
#     theta2s=sgd_history["theta2s"],
#     losses=sgd_history["losses"],
#     n_iterations=sgd_history["max_n_iterations"],
#     perf_profiling=False,
# )

# fig.show()

In [21]:
xs = jnp.arange(-10, 10, 0.1)
ys = jnp.arange(-10, 10, 0.1)
zs = jnp.array([ackley_fn(x, ys) for x in xs])
true_function_surface = go.Surface(
    z=zs, x=xs, y=ys, colorscale="Blues", showscale=False
)

ack_fig = go.Figure()
ack_fig.update_layout(height=800)
ack_fig.add_trace(true_function_surface)
ack_fig.update_layout(title="Ackley Function")
ack_fig.show()

In [22]:
def has_converged(true_val: float, test_val: float, convergence_criterion: float = 1e-2) -> bool:
    """Test if the parameter is within the convergence criteria."""
    return jnp.abs(true_val - test_val) < convergence_criterion

In [23]:
# SGD
lr = 1e-1
max_n_iterations = 20
convergence_criteria = 1e-2

loss = 0
x_grad = grad(ackley_fn, argnums=0)
y_grad = grad(ackley_fn, argnums=1)
gradients = (x_grad, y_grad)
min_x, min_y = 0, 0
x, y = 4., 4.
x_path, y_path = [], []

for i in range(max_n_iterations):
    x -= lr * x_grad(x, y)
    y -= lr * y_grad(x, y)
    x_path.append(x)
    y_path.append(y)
    if i % 10 == 0:
        print(f"Iteration: {i:2d} | (x, y) = ({x:.3f}, {y:.3f})")

    if has_converged(min_x, x) and has_converged(min_y, y):
        print(f"Converged at iteration {i + 1}")
        max_n_iterations = i
        break

    

Iteration:  0 | (x, y) = (3.910, 3.908)
Iteration: 10 | (x, y) = (3.653, 3.985)


In [24]:
hp.create_optimizer_figure_true(
    fn=ackley_fn,
    title="Ackley Function (SGD)",
    x_path=x_path,
    y_path=y_path,
    n_iterations=max_n_iterations,
    perf_profiling=False,
)

22


In [49]:
# Adam
lr = 1e-2
weight_decay = 0.7
max_n_iterations = 1000
convergence_criteria = 1e-2

loss = 0
betas=(0.99, 0.999)
beta1, beta2 = betas
min_x, min_y = 0, 0
variables = [4., 4.]
x_path, y_path = [], []
x_path.append(variables[0])
y_path.append(variables[1])
n_params = 2
first_moments = [0, 0]
second_moments = [0, 0]
bias_corrected_first_moments = [0, 0]
bias_corrected_second_moments = [0, 0]


for i in range(1, max_n_iterations):

    for param_idx in range(n_params):
            gradient = grad(ackley_fn, argnums=param_idx)(*variables)

            v = variables[param_idx]
            fm = first_moments[param_idx]
            sm = second_moments[param_idx]
            bc_fm = bias_corrected_first_moments[param_idx]
            bc_sm = bias_corrected_second_moments[param_idx]

            # Update the parameters

            if weight_decay != 0:
                gradient += weight_decay * theta

            fm = beta1 * fm + (1 - beta1) * gradient
            sm = beta2 * sm + (1 - beta2) * gradient**2

            # bias corrections
            bc_fm = fm / (1 - beta1**i)
            bc_sm = sm / (1 - beta2**i)

            # 1e-8 is added to avoid division by zero
            v -= lr * bc_fm / (jnp.sqrt(bc_sm) + 1e-8)

            variables[param_idx] = v
            first_moments[param_idx] = fm
            second_moments[param_idx] = sm
            bias_corrected_first_moments[param_idx] = bc_fm
            bias_corrected_second_moments[param_idx] = bc_sm

    x_path.append(variables[0])
    y_path.append(variables[1])

    if i % 50 == 0:
        print(f"Iteration: {i:2d} | (x, y) = ({variables[0]:.3f}, {variables[1]:.3f})")

    if has_converged(min_x, variables[0]) and has_converged(min_y, variables[1]):
        print(f"Converged at iteration {i + 1}")
        max_n_iterations = i
        break

Iteration: 10 | (x, y) = (3.900, 3.900)
Iteration: 20 | (x, y) = (3.800, 3.800)
Iteration: 30 | (x, y) = (3.700, 3.700)
Iteration: 40 | (x, y) = (3.600, 3.600)
Iteration: 50 | (x, y) = (3.500, 3.500)
Iteration: 60 | (x, y) = (3.400, 3.400)
Iteration: 70 | (x, y) = (3.300, 3.300)
Iteration: 80 | (x, y) = (3.200, 3.200)
Iteration: 90 | (x, y) = (3.099, 3.099)
Iteration: 100 | (x, y) = (2.999, 2.999)
Iteration: 110 | (x, y) = (2.899, 2.899)
Iteration: 120 | (x, y) = (2.798, 2.798)
Iteration: 130 | (x, y) = (2.698, 2.698)
Iteration: 140 | (x, y) = (2.598, 2.598)
Iteration: 150 | (x, y) = (2.498, 2.498)
Iteration: 160 | (x, y) = (2.398, 2.398)
Iteration: 170 | (x, y) = (2.298, 2.298)
Iteration: 180 | (x, y) = (2.198, 2.198)
Iteration: 190 | (x, y) = (2.098, 2.098)
Iteration: 200 | (x, y) = (1.997, 1.997)
Iteration: 210 | (x, y) = (1.897, 1.897)
Iteration: 220 | (x, y) = (1.797, 1.797)
Iteration: 230 | (x, y) = (1.697, 1.697)
Iteration: 240 | (x, y) = (1.597, 1.597)
Iteration: 250 | (x, y) =

In [50]:
hp.create_optimizer_figure_true(
    fn=ackley_fn,
    title="Ackley Function (Adam)",
    x_path=x_path,
    y_path=y_path,
    n_iterations=max_n_iterations,
    perf_profiling=False,
)

401
