# Optimizers


Covers various topics on optimization in the context of Deep Learning.


In [19]:
import jax.numpy as jnp
import jax.random as random
import random as py_random
from jax import grad
import plotly.express as px
import pandas as pd
import plotly.io as pio
import plotly.graph_objects as go
import helpers as hp

In [2]:
# Use the "ggplot2" template as the default for Plotly figures.
pio.templates.default = "ggplot2"

In [3]:
# JAX PRNG key; it is passed explicitly to the random calls in later cells.
key = random.PRNGKey(577)

In [4]:
# Sample points from 0 (inclusive) to 10 (exclusive) in steps of 0.1.
x = jnp.arange(0, 10, 0.1)


def f(x):
    """Linear example function ``f(x) = 4x``."""
    scaled = 4 * x
    return scaled

In [5]:
# Derivative of f with respect to its argument, built by JAX autodiff.
f_dx = grad(f)

In [None]:
# Tabulate f and its derivative at every sample point so both can be plotted.
df = pd.DataFrame()
df["x"] = list(x)
df["f"] = list(f(x))
# f_dx is called once per sample point rather than on the whole array.
df["df"] = list(f_dx(_x) for _x in x)
df.head()

In [None]:
# Plot the function ("f") and its derivative ("df") against x on one figure.
fig = px.line(df, x="x", y=["f", "df"])
fig.show()

$$ \text{MSE} = \frac{1}{n} \sum_{i=1}^n \left(y_i - \hat{y}_i \right)^2 $$


If we naively define the MSE function it would look like this:


In [8]:
def mean_squared_error_naive(true, pred):
    """Return the mean of the squared differences between ``true`` and ``pred``."""
    residual = true - pred
    return jnp.mean(residual**2)

But if we want to take the derivative of this function using JAX's `grad` function, then we need to explicitly state all the variables in the function signature as such:


In [9]:
def mean_squared_error(theta, x, true):
    """Return the MSE between ``true`` and the prediction ``x.dot(theta)``.

    ``theta`` is an explicit argument so that ``grad`` can differentiate the
    loss with respect to it.
    """
    residual = true - x.dot(theta)
    return jnp.mean(residual**2)

Suppose $\hat{y} = 2*x$. It is typical to use $w$ or $\theta$ as the parameter of our function, we will stick with $\theta$ as it extends naturally to probability. In our function we only have one parameter, hence $\theta = 2$.

We are interested in perturbing this parameter to minimize the loss over our observed (or true) sample set. To do this we will measure the impact the parameter has on the loss function, seeking to minimize this impact. To measure this impact, let's take the derivative of the MSE using our parameterized function. First we will rewrite the MSE loss function and then take the derivative with respect to the parameter $\theta$:

$$
\begin{aligned}
\text{MSE} &= \frac{1}{n} \sum_{i=1}^n \left(y_i - \hat{y}_i \right)^2 = \frac{1}{n} \sum_{i=1}^n \left(y_i - \theta x_i \right)^2 \\
D_{\theta} \text{MSE} &= D_{\theta} \frac{1}{n} \sum_{i=1}^n \left(y_i - \theta x_i \right)^2 \\
&=\frac{1}{n} \sum_{i=1}^n 2 \left(y_i - \theta x_i \right) (-x_i) \\
&=-\frac{2}{n} \sum_{i=1}^n \left(y_i - \theta x_i \right) x_i
\end{aligned}
$$


In [10]:
def mean_squared_error_dx_man(theta, x, true):
    """Return the hand-derived derivative of the MSE with respect to ``theta``.

    Computes ``-2 * mean((true - theta * x) * x)`` for the model ``theta * x``.
    """
    residual = true - theta * x
    return -2 * jnp.mean(residual * x)

In [None]:
# Hand-derived d(MSE)/d(theta) at theta = 2, using f(x) as the targets.
# Despite the variable name, this value is the gradient of the loss, not the
# loss itself, so the printed label says so.
loss_theta_man = mean_squared_error_dx_man(2.0, x, f(x)).item()
print(f"dMSE/dtheta (theta = 2): {loss_theta_man:.3f}")

In [12]:
# Autodiff derivative of mean_squared_error with respect to its first
# argument (theta).
mean_squared_error_d_theta = grad(mean_squared_error, argnums=0)

In [None]:
# Autodiff d(MSE)/d(theta) at theta = 2, using f(x) as the targets.
# Despite the variable name, this value is the gradient of the loss, not the
# loss itself, so the printed label says so.
loss_theta = mean_squared_error_d_theta(2.0, x, f(x)).item()
print(f"dMSE/dtheta (theta = 2): {loss_theta:.3f}")

### Stochastic Gradient Descent


$$ \theta_{t+1} = \theta_t - \alpha D_\theta \text{MSE} $$


In [None]:
# Fit the one-parameter model theta * x to f by gradient descent and record
# one Plotly trace per iteration so a slider can step through the fit.
alpha = 0.01
theta = random.uniform(key)
n_iterations = 20
fig = go.Figure()
steps = []
x = jnp.arange(0, 10, 0.1)


def f(x):
    """Cubic target ``4x^3 + 3`` that the model ``theta * x`` is fitted to."""
    return 4 * x**3 + 3


print(f"Initial theta: {theta}")

# Trace 0 is the target curve; every slider step keeps it visible.
fig.add_trace(go.Scatter(visible=False, name="True Function", x=x, y=f(x)))

prev_loss = 0
for i in range(n_iterations):
    ###########################################################################
    # Each update uses every sample in x.
    theta -= alpha * mean_squared_error_d_theta(theta, x, f(x))
    loss = mean_squared_error(theta, x, f(x)).item()
    # Stop once the loss no longer changes between consecutive iterations.
    if jnp.abs(loss - prev_loss) < 1e-6:
        print(f"Converged at iteration {i + 1}")
        break
    prev_loss = loss
    ###########################################################################
    # Trace i + 1 is the fitted line after this iteration.
    fig.add_trace(
        go.Scatter(visible=False, name=f"Iteration {(i + 1):4d}", x=x, y=theta * x)
    )
    step = dict(
        method="update",
        args=[
            {"visible": [False] * (n_iterations + 1)},
            {"title": "SGD Iteration: " + str(i + 1)},
        ],
    )
    # Show the target curve together with this iteration's fit.
    step["args"][0]["visible"][0] = True
    step["args"][0]["visible"][i + 1] = True
    steps.append(step)
    ###########################################################################
    print(f"Iteration {(i + 1):4d}: y_pred = [{theta:.3f}][x1].T, loss = {loss:.3f}")

sliders = [
    dict(
        # The traces made visible below belong to the first step, so the
        # handle starts there. The previous value (n_iterations) was past the
        # last valid step index, since there are at most n_iterations steps.
        active=0,
        currentvalue={"prefix": "SGD: "},
        pad={"t": 50},
        steps=steps,
    )
]

fig.data[0].visible = True
fig.data[1].visible = True
fig.update_layout(sliders=sliders)
fig.show()

In [15]:
# def f_2d(x, y):
#     return 3 * x**2 + 9 * y**2


def f_2d(x, y):
    """Evaluate the two-dimensional Ackley function at ``(x, y)``."""
    radial_term = -20 * jnp.exp(-0.2 * jnp.sqrt(0.5 * (x**2 + y**2)))
    cosine_term = jnp.exp(0.5 * (jnp.cos(2 * jnp.pi * x) + jnp.cos(2 * jnp.pi * y)))
    return radial_term - cosine_term + jnp.e + 20

#### Ackley Function

https://en.wikipedia.org/wiki/Ackley_function


$$
\begin{aligned}
f(x, y) &= -20\exp{\left[-0.2\sqrt{0.5(x^2+y^2)} \right]} \\
        &\quad -\exp{\left[0.5(\cos(2 \pi x) + \cos(2 \pi y)) \right]} + e + 20
\end{aligned}
$$


In [16]:
def loss_2d(theta1, theta2, x, y, true):
    """Return the mean squared error of the two-parameter model against ``true``.

    The model is ``theta1 * x**2 + theta2 * y**2 + e + theta2``.
    """
    prediction = theta1 * x**2 + theta2 * y**2 + jnp.exp(1) + theta2
    residual = true - prediction
    return jnp.mean(residual**2)

#### SGD

https://en.wikipedia.org/wiki/Stochastic_gradient_descent


In [17]:
def sgd(
    lr: float,
    max_n_iterations: int,
    true_thetas: tuple[float | int, float | int],
    loss_fn,
    examples: list[tuple[float, float, float]],
    convergence_criteria: float,
    key,
    return_history: bool = False,
) -> dict | tuple[float, float]:
    r"""Perform Stochastic Gradient Descent.

    Args:
        lr: The learning rate.
        max_n_iterations: The maximum number of iterations.
        true_thetas: The true thetas.
        loss_fn: The loss function.
        examples: The examples.
        convergence_criteria: The convergence criteria.
        key: The random key.
        return_history: Whether to return the history.

    Returns:
        theta1s: The theta1 values recorded during optimization.
        theta2s: The theta2 values recorded during optimization.
        losses: The losses recorded during optimization.
        f_preds: The predicted functions.
        max_n_iterations: The maximum number of iterations

        or

        theta1: The final theta1 value.
        theta2: The final theta2 value
    """
    print(
        f"Running SGD with learning rate: {lr} and max iterations: {max_n_iterations}"
    )
    print(f"True thetas: {true_thetas}")

    loss_theta1 = grad(loss_fn, argnums=0)
    loss_theta2 = grad(loss_fn, argnums=1)
    theta1, theta2 = random.uniform(key), random.uniform(key)
    prev_loss = 0
    theta1s, theta2s, losses, f_preds = [], [], [], []
    for i in range(max_n_iterations):
        key, subkey = random.split(key)
        idx = random.randint(subkey, (1,), 0, len(examples))
        x, y, z = examples[idx[0]]
        theta1 -= lr * loss_theta1(theta1, theta2, x, y, z)
        theta2 -= lr * loss_theta2(theta1, theta2, x, y, z)
        loss = loss_2d(theta1, theta2, x, y, z).item()
        theta1s.append(theta1.item())
        theta2s.append(theta2.item())
        losses.append(loss)
        f_preds.append(
            lambda x, y, theta1, theta2: theta1 * x**2
            + theta2 * y**2
            + jnp.exp(1)
            + theta2
        )
        if jnp.abs(loss - prev_loss) < convergence_criteria:
            print(f"Converged at iteration {i + 1}!")
            print(
                f"y_pred = [{theta1:.3f}, {theta2:.3f}][x1^2, x2^2].T + theta2 + e, loss = {loss:.3f}"
            )
            max_n_iterations = i
            break
        prev_loss = loss
        print(
            f"Iteration {(i + 1):4d}: y_pred = [{theta1:.3f}, {theta2:.3f}][x1^2, x2^2].T + theta2 + e, loss = {loss:.3f}"
        )

    if return_history:
        return {
            "theta1s": theta1s,
            "theta2s": theta2s,
            "losses": losses,
            "f_preds": f_preds,
            "max_n_iterations": max_n_iterations,
        }
    else:
        return theta1s[-1], theta2s[-1]


In [None]:
# Hyperparameters for the single-example SGD run.
lr = 1e-5
max_n_iterations = 50
convergence_criteria = 1e-2
# NOTE(review): (3, 9) matches the coefficients of the commented-out quadratic
# f_2d, not the Ackley f_2d that is currently defined — confirm.
true_thetas = (3, 9)

# Generate examples from the true function
n_examples = 100
# NOTE(review): the same key is passed to generate_examples and to sgd below;
# consider splitting it so the two consumers do not share a key.
examples = hp.generate_examples(f_2d, n_examples, key)

sgd_history = sgd(
    lr,
    max_n_iterations,
    true_thetas,
    loss_2d,
    examples,
    convergence_criteria,
    key,
    return_history=True,
)

# Hand the recorded history to the plotting helper.
fig = hp.create_optimizer_figure_2d(
    f_2d,
    loss_2d,
    true_thetas,
    convergence_criteria,
    sgd_history["theta1s"],
    sgd_history["theta2s"],
    sgd_history["losses"],
    sgd_history["f_preds"],
    sgd_history["max_n_iterations"],
    perf_profiling=False,
)

fig.show()

In [28]:
def sgd_batched(
    lr: float,
    batch_size: int,
    max_n_iterations: int,
    true_thetas: tuple[float | int, float | int],
    loss_fn,
    examples: list[tuple[float, float, float]],
    convergence_criteria: float,
    key,
    return_history: bool = False,
) -> dict | tuple[float, float]:
    r"""Perform Stochastic Gradient Descent.

    Args:
        lr: The learning rate.
        max_n_iterations: The maximum number of iterations.
        true_thetas: The true thetas.
        loss_fn: The loss function.
        examples: The examples.
        convergence_criteria: The convergence criteria.
        key: The random key.
        return_history: Whether to return the history.

    Returns:
        theta1s: The theta1 values recorded during optimization.
        theta2s: The theta2 values recorded during optimization.
        losses: The losses recorded during optimization.
        f_preds: The predicted functions.
        max_n_iterations: The maximum number of iterations

        or

        theta1: The final theta1 value.
        theta2: The final theta2 value
    """
    print(
        f"Running SGD with learning rate: {lr} and max iterations: {max_n_iterations}"
    )
    print(f"True thetas: {true_thetas}")

    loss_theta1 = grad(loss_fn, argnums=0)
    loss_theta2 = grad(loss_fn, argnums=1)
    theta1, theta2 = random.uniform(key), random.uniform(key)
    prev_loss = 0
    theta1s, theta2s, losses, f_preds = [], [], [], []
    for i in range(max_n_iterations):
        theta1_sum_loss, theta2_sum_loss = 0, 0
        for _ in range(batch_size):
            x, y, z = py_random.choices(examples)[0]
            theta1_sum_loss += loss_theta1(theta1, theta2, x, y, z)
            theta2_sum_loss += loss_theta2(theta1, theta2, x, y, z)

        theta1_loss = theta1_sum_loss / batch_size
        theta2_loss = theta2_sum_loss / batch_size

        theta1 -= lr * theta1_loss
        theta2 -= lr * theta2_loss
        loss = loss_2d(theta1, theta2, x, y, z).item()
        theta1s.append(theta1.item())
        theta2s.append(theta2.item())
        losses.append(loss)
        f_preds.append(
            lambda x, y, theta1, theta2: theta1 * x**2
            + theta2 * y**2
            + jnp.exp(1)
            + theta2
        )
        if jnp.abs(loss - prev_loss) < convergence_criteria:
            print(f"Converged at iteration {i + 1}!")
            print(
                f"y_pred = [{theta1:.3f}, {theta2:.3f}][x1^2, x2^2].T + theta2 + e, loss = {loss:.3f}"
            )
            max_n_iterations = i
            break
        prev_loss = loss
        print(
            f"Iteration {(i + 1):4d}: y_pred = [{theta1:.3f}, {theta2:.3f}][x1^2, x2^2].T + theta2 + e, loss = {loss:.3f}"
        )

    if return_history:
        return {
            "theta1s": theta1s,
            "theta2s": theta2s,
            "losses": losses,
            "f_preds": f_preds,
            "max_n_iterations": max_n_iterations,
        }
    else:
        return theta1s[-1], theta2s[-1]


In [None]:
# Hyperparameters for the mini-batch SGD run.
lr = 1e-5
max_n_iterations = 50
convergence_criteria = 1e-2
# NOTE(review): (3, 9) matches the coefficients of the commented-out quadratic
# f_2d, not the Ackley f_2d that is currently defined — confirm.
true_thetas = (3, 9)

# Generate examples from the true function
n_examples = 100
# NOTE(review): the same key is passed to generate_examples and to
# sgd_batched below; consider splitting it.
examples = hp.generate_examples(f_2d, n_examples, key)

sgd_history = sgd_batched(
    lr,
    32,  # batch_size
    max_n_iterations,
    true_thetas,
    loss_2d,
    examples,
    convergence_criteria,
    key,
    return_history=True,
)

# Hand the recorded history to the plotting helper.
fig = hp.create_optimizer_figure_2d(
    f_2d,
    loss_2d,
    true_thetas,
    convergence_criteria,
    sgd_history["theta1s"],
    sgd_history["theta2s"],
    sgd_history["losses"],
    sgd_history["f_preds"],
    sgd_history["max_n_iterations"],
    perf_profiling=False,
)

fig.show()

In [31]:
def RMSProp(
    lr: float,
    weight_decay: float,
    smoothing_constant: float,
    momentum: float,
    centered: bool,
    max_n_iterations: int,
    true_thetas: tuple[float | int, float | int],
    loss_fn,
    examples: list[tuple[float, float, float]],
    convergence_criteria: float,
    key,
    return_history: bool = False,
) -> dict | tuple[float, float]:
    r"""Perform Stochastic Gradient Descent.

    Algorithm taken from:
    https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html

    Args:
        lr: The learning rate.
        max_n_iterations: The maximum number of iterations.
        true_thetas: The true thetas.
        loss_fn: The loss function.
        examples: The examples.
        convergence_criteria: The convergence criteria.
        key: The random key.
        return_history: Whether to return the history.

    Returns:
        theta1s: The theta1 values recorded during optimization.
        theta2s: The theta2 values recorded during optimization.
        losses: The losses recorded during optimization.
        f_preds: The predicted functions.
        max_n_iterations: The maximum number of iterations

        or

        theta1: The final theta1 value.
        theta2: The final theta2 value
    """
    print(
        f"Running SGD with learning rate: {lr} and max iterations: {max_n_iterations}"
    )
    print(f"True thetas: {true_thetas}")

    loss_theta1 = grad(loss_fn, argnums=0)
    loss_theta2 = grad(loss_fn, argnums=1)
    theta1, theta2 = random.uniform(key), random.uniform(key)
    prev_loss = 0
    theta1s, theta2s, losses, f_preds = [], [], [], []
    square_avg_1, square_avg_2 = 0, 0
    buffer_1, buffer_2 = 0, 0
    g_avg_1, g_avg_2 = 0, 0
    for i in range(max_n_iterations):
        x, y, z = py_random.choices(examples)[0]
        g_t_1 = loss_theta1(theta1, theta2, x, y, z)
        g_t_2 = loss_theta2(theta1, theta2, x, y, z)

        if weight_decay != 0:
            g_t_1 += weight_decay * theta1
            g_t_2 += weight_decay * theta2

        square_avg_1 = smoothing_constant * square_avg_1 + (1 - lr) * g_t_1**2
        square_avg_2 = smoothing_constant * square_avg_2 + (1 - lr) * g_t_2**2

        square_avg_1_backup = square_avg_1
        square_avg_2_backup = square_avg_2

        if centered:
            g_avg_1 = smoothing_constant * g_avg_1 + (1 - lr) * g_t_1
            g_avg_2 = smoothing_constant * g_avg_2 + (1 - lr) * g_t_2
            square_avg_1_backup -= g_avg_1**2
            square_avg_2_backup -= g_avg_2**2

        if momentum > 0:
            buffer_1 = momentum * buffer_1 + g_t_1 / (
                jnp.sqrt(square_avg_1_backup) + 1e-8
            )
            buffer_2 = momentum * buffer_2 + g_t_2 / (
                jnp.sqrt(square_avg_2_backup) + 1e-8
            )
            theta1 -= lr * buffer_1
            theta2 -= lr * buffer_2
        else:
            theta1 -= lr * g_t_1 / (jnp.sqrt(square_avg_1) + 1e-8)
            theta2 -= lr * g_t_2 / (jnp.sqrt(square_avg_2) + 1e-8)

        loss = loss_2d(theta1, theta2, x, y, z).item()
        theta1s.append(theta1.item())
        theta2s.append(theta2.item())
        losses.append(loss)
        f_preds.append(
            lambda x, y, theta1, theta2: theta1 * x**2
            + theta2 * y**2
            + jnp.exp(1)
            + theta2
        )
        if jnp.abs(loss - prev_loss) < convergence_criteria:
            print(f"Converged at iteration {i + 1}!")
            print(
                f"y_pred = [{theta1:.3f}, {theta2:.3f}][x1^2, x2^2].T + theta2 + e, loss = {loss:.3f}"
            )
            max_n_iterations = i
            break
        prev_loss = loss
        print(
            f"Iteration {(i + 1):4d}: y_pred = [{theta1:.3f}, {theta2:.3f}][x1^2, x2^2].T + theta2 + e, loss = {loss:.3f}"
        )

    if return_history:
        return {
            "theta1s": theta1s,
            "theta2s": theta2s,
            "losses": losses,
            "f_preds": f_preds,
            "max_n_iterations": max_n_iterations,
        }
    else:
        return theta1s[-1], theta2s[-1]


In [None]:
# Hyperparameters for the RMSProp run.
lr = 1e-2
max_n_iterations = 50
convergence_criteria = 1e-2
# NOTE(review): (3, 9) matches the coefficients of the commented-out quadratic
# f_2d, not the Ackley f_2d that is currently defined — confirm.
true_thetas = (3, 9)

# Generate examples from the true function
n_examples = 100
# NOTE(review): the same key is passed to generate_examples and to RMSProp
# below; consider splitting it.
examples = hp.generate_examples(f_2d, n_examples, key)

# The result keeps the name sgd_history even though the optimizer is RMSProp.
sgd_history = RMSProp(
    lr=lr,
    weight_decay=0.1,
    smoothing_constant=0.99,
    momentum=0.9,
    centered=False,
    max_n_iterations=max_n_iterations,
    true_thetas=true_thetas,
    loss_fn=loss_2d,
    examples=examples,
    convergence_criteria=convergence_criteria,
    key=key,
    return_history=True,
)

# Hand the recorded history to the plotting helper.
fig = hp.create_optimizer_figure_2d(
    f_2d,
    loss_2d,
    true_thetas,
    convergence_criteria,
    sgd_history["theta1s"],
    sgd_history["theta2s"],
    sgd_history["losses"],
    sgd_history["f_preds"],
    sgd_history["max_n_iterations"],
    perf_profiling=False,
)

fig.show()

In [36]:
def Adam(
    lr: float,
    weight_decay: float,
    betas: tuple[float, float],
    momentum: float,
    max_n_iterations: int,
    true_thetas: tuple[float | int, float | int],
    loss_fn,
    examples: list[tuple[float, float, float]],
    convergence_criteria: float,
    key,
    return_history: bool = False,
) -> dict | tuple[float, float]:
    r"""Perform Stochastic Gradient Descent.

    Algorithm taken from:
    https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html

    Args:
        lr: The learning rate.
        max_n_iterations: The maximum number of iterations.
        true_thetas: The true thetas.
        loss_fn: The loss function.
        examples: The examples.
        convergence_criteria: The convergence criteria.
        key: The random key.
        return_history: Whether to return the history.

    Returns:
        theta1s: The theta1 values recorded during optimization.
        theta2s: The theta2 values recorded during optimization.
        losses: The losses recorded during optimization.
        f_preds: The predicted functions.
        max_n_iterations: The maximum number of iterations

        or

        theta1: The final theta1 value.
        theta2: The final theta2 value
    """
    print(
        f"Running SGD with learning rate: {lr} and max iterations: {max_n_iterations}"
    )
    print(f"True thetas: {true_thetas}")

    loss_theta1 = grad(loss_fn, argnums=0)
    loss_theta2 = grad(loss_fn, argnums=1)
    theta1, theta2 = random.uniform(key), random.uniform(key)
    prev_loss = 0
    theta1s, theta2s, losses, f_preds = [], [], [], []
    first_moment_1, first_moment_2 = 0, 0
    second_moment_1, second_moment_2 = 0, 0
    v_max_1, v_max_2 = 0, 0
    beta1, beta2 = betas
    for i in range(max_n_iterations):
        x, y, z = py_random.choices(examples)[0]
        g_t_1 = loss_theta1(theta1, theta2, x, y, z)
        g_t_2 = loss_theta2(theta1, theta2, x, y, z)

        if weight_decay != 0:
            g_t_1 += weight_decay * theta1
            g_t_2 += weight_decay * theta2

        first_moment_1 = beta1 * first_moment_1 + (1 - beta1) * g_t_1
        first_moment_2 = beta1 * first_moment_2 + (1 - beta1) * g_t_2

        second_moment_1 = beta2 * second_moment_1 + (1 - beta2) * g_t_1**2
        second_moment_2 = beta2 * second_moment_2 + (1 - beta2) * g_t_2**2

        first_moment_1_hat = first_moment_1 / (1 - beta1)
        first_moment_2_hat = first_moment_2 / (1 - beta1)

        second_moment_1_hat = second_moment_1 / (1 - beta2)
        second_moment_2_hat = second_moment_2 / (1 - beta2)

        theta1 -= lr * first_moment_1_hat / (jnp.sqrt(second_moment_1_hat) + 1e-8)
        theta2 -= lr * first_moment_2_hat / (jnp.sqrt(second_moment_2_hat) + 1e-8)

        loss = loss_2d(theta1, theta2, x, y, z).item()
        theta1s.append(theta1.item())
        theta2s.append(theta2.item())
        losses.append(loss)
        f_preds.append(
            lambda x, y, theta1, theta2: theta1 * x**2
            + theta2 * y**2
            + jnp.exp(1)
            + theta2
        )
        if jnp.abs(loss - prev_loss) < convergence_criteria:
            print(f"Converged at iteration {i + 1}!")
            print(
                f"y_pred = [{theta1:.3f}, {theta2:.3f}][x1^2, x2^2].T + theta2 + e, loss = {loss:.3f}"
            )
            max_n_iterations = i
            break
        prev_loss = loss
        print(
            f"Iteration {(i + 1):4d}: y_pred = [{theta1:.3f}, {theta2:.3f}][x1^2, x2^2].T + theta2 + e, loss = {loss:.3f}"
        )

    if return_history:
        return {
            "theta1s": theta1s,
            "theta2s": theta2s,
            "losses": losses,
            "f_preds": f_preds,
            "max_n_iterations": max_n_iterations,
        }
    else:
        return theta1s[-1], theta2s[-1]


In [None]:
# Hyperparameters for the Adam run.
lr = 1e-2
max_n_iterations = 50
convergence_criteria = 1e-2
# NOTE(review): (3, 9) matches the coefficients of the commented-out quadratic
# f_2d, not the Ackley f_2d that is currently defined — confirm.
true_thetas = (3, 9)

# Generate examples from the true function
n_examples = 100
# NOTE(review): the same key is passed to generate_examples and to Adam
# below; consider splitting it.
examples = hp.generate_examples(f_2d, n_examples, key)

# The result keeps the name sgd_history even though the optimizer is Adam.
sgd_history = Adam(
    lr=lr,
    weight_decay=0,
    betas=(0.9, 0.999),
    momentum=0.2,
    max_n_iterations=max_n_iterations,
    true_thetas=true_thetas,
    loss_fn=loss_2d,
    examples=examples,
    convergence_criteria=convergence_criteria,
    key=key,
    return_history=True,
)

# Hand the recorded history to the plotting helper.
fig = hp.create_optimizer_figure_2d(
    f_2d,
    loss_2d,
    true_thetas,
    convergence_criteria,
    sgd_history["theta1s"],
    sgd_history["theta2s"],
    sgd_history["losses"],
    sgd_history["f_preds"],
    sgd_history["max_n_iterations"],
    perf_profiling=False,
)

fig.show()