# Optimizers

Covers various topics on optimization in the context of Deep Learning.

In [79]:
import jax.numpy as jnp
import jax.random as random
from jax import grad
import plotly.express as px
import pandas as pd
import plotly.io as pio
import plotly.graph_objects as go

In [15]:
# Make Plotly's built-in dark theme the default template for figures in this notebook.
pio.templates.default = "plotly_dark"

In [16]:
# Root PRNG key (seed 577). JAX random functions take a key explicitly, so every
# random draw in this notebook is derived from this value.
key = random.PRNGKey(577)

In [100]:
def f(x):
    """Ground-truth linear function ``y = 4x``.

    Works on a scalar or, element-wise, on an array.
    """
    slope = 4
    return slope * x


# Evaluation grid: 0.0, 0.1, ... up to (but excluding) 10.
x = jnp.arange(0, 10, step=0.1)

In [101]:
# `grad` returns a new function that evaluates df/dx. It differentiates
# scalar-valued functions, which is why it is applied one point at a time below.
f_dx = grad(f)

In [102]:
# Tabulate the function and its derivative over the grid; all three columns are
# assembled first and handed to the DataFrame constructor in a single call.
columns = {
    'x': list(x),
    'f': list(f(x)),
    # grad-derived function is evaluated point by point
    'df': [f_dx(point) for point in x],
}
df = pd.DataFrame(columns)
df.head()

Unnamed: 0,x,f,df
0,0.0,0.0,4.0
1,0.1,0.4,4.0
2,0.2,0.8,4.0
3,0.3,1.2,4.0
4,0.4,1.6,4.0


In [103]:
# Plot the function ("f") and its derivative ("df") against x on one set of axes.
plotted_columns = ["f", "df"]
fig = px.line(data_frame=df, x="x", y=plotted_columns)
fig.show()

$$ \text{MSE} = \frac{1}{n} \sum_{i=1}^n \left(y_i - \hat{y}_i \right)^2 $$

If we naively define the MSE function it would look like this:

In [104]:
def mean_squared_error_naive(true, pred):
    """Mean of the squared differences between targets and predictions.

    Args:
        true: observed target values.
        pred: predicted values, broadcastable against ``true``.

    Returns:
        The mean squared error as a scalar array.
    """
    residual = true - pred
    squared = residual ** 2
    return jnp.mean(squared)

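But JAX's `grad` differentiates a function with respect to its arguments (the first one by default). To take the derivative with respect to the parameter, the parameter must therefore appear explicitly in the function signature, so the prediction is computed inside the loss: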

In [105]:
def mean_squared_error(theta, x, true):
    """MSE of the linear model ``pred = x . theta`` against the targets.

    ``theta`` comes first so that ``grad`` (which differentiates with respect to
    the first argument by default) yields the derivative w.r.t. the parameter.

    Args:
        theta: model parameter.
        x: input values.
        true: observed target values.

    Returns:
        The mean squared error as a scalar array.
    """
    residual = true - jnp.dot(x, theta)
    return jnp.mean(residual ** 2)

Suppose $\hat{y} = 2x$. It is typical to use $w$ or $\theta$ to denote the parameter of our function; we will stick with $\theta$ as it extends naturally to probability. Our function has only one parameter, hence $\theta = 2$.

We are interested in perturbing this parameter to minimize the loss over our observed (or true) sample set. To do this we will measure the impact the parameter has on the loss function, seeking to minimize this impact. To measure this impact, let's take the derivative of the MSE using our parameterized function. First we will rewrite the MSE loss function and then take the derivative with respect to the parameter $\theta$:

$$
\begin{aligned}
\text{MSE} &= \frac{1}{n} \sum_{i=1}^n \left(y_i - \hat{y}_i \right)^2 = \frac{1}{n} \sum_{i=1}^n \left(y_i - \theta x_i \right)^2 \\
\frac{\partial}{\partial \theta} \text{MSE} &= \frac{\partial}{\partial \theta} \frac{1}{n} \sum_{i=1}^n \left(y_i - \theta x_i \right)^2 \\
&=\frac{1}{n} \sum_{i=1}^n 2 \left(y_i - \theta x_i \right) \cdot (-x_i) \\
&=-\frac{2}{n} \sum_{i=1}^n \left(y_i - \theta x_i \right) x_i
\end{aligned}
$$

In [106]:
def mean_squared_error_dx_man(theta, x, true):
    """Hand-derived derivative of the MSE with respect to ``theta``.

    Implements ``-(2/n) * sum((y_i - theta * x_i) * x_i)`` for the model
    ``pred = theta * x``. Despite the ``dx`` in the name, the derivative is
    taken with respect to the parameter, not the input.

    Args:
        theta: model parameter at which to evaluate the derivative.
        x: input values.
        true: observed target values.

    Returns:
        The derivative dMSE/dtheta as a scalar array.
    """
    residual = true - theta * x
    return -2 * jnp.mean(residual * x)

In [107]:
# Evaluate the hand-derived derivative at theta = 2 against the true targets f(x).
# The value is the gradient dMSE/dtheta (it can be negative), not the loss itself,
# so it is labelled as such. The variable name is kept for compatibility.
loss_theta_man = mean_squared_error_dx_man(2., x, f(x)).item()
print(f"dMSE/dtheta (theta = 2): {loss_theta_man:.3f}")

Loss (theta = 2): -131.340


In [108]:
# Automatic counterpart of the hand-derived gradient: argnums=0 differentiates
# mean_squared_error with respect to its first argument, theta.
mean_squared_error_d_theta = grad(mean_squared_error, argnums=0)

In [109]:
# Evaluate the autodiff derivative at theta = 2 against the true targets f(x).
# The value is the gradient dMSE/dtheta (it can be negative), not the loss itself,
# so it is labelled as such. The variable name is kept for compatibility.
loss_theta = mean_squared_error_d_theta(2., x, f(x)).item()
print(f"dMSE/dtheta (theta = 2): {loss_theta:.3f}")

Loss (theta = 2): -131.340


### Stochastic Gradient Descent

$$ \theta_{t+1} = \theta_t - \alpha \frac{\partial}{\partial \theta} \text{MSE}(\theta_t) $$

In [113]:
# --- Hyper-parameters ---------------------------------------------------------
alpha = 0.01         # learning rate
n_iterations = 20    # upper bound on the number of update steps
tolerance = 1e-6     # stop once the loss changes by less than this between steps

theta = random.uniform(key)
y_true = f(x)        # targets never change, so evaluate them once outside the loop

print(f"Initial theta: {theta}")

fig = go.Figure()
# Trace 0 is the reference line; it stays visible for every slider position.
fig.add_trace(
    go.Scatter(
        visible=True,
        name="True Function",
        x=x,
        y=y_true
    )
)

# Seed the convergence check with the loss at the initial parameter, so the
# test measures the *change* in loss rather than comparing against zero.
prev_loss = mean_squared_error(theta, x, y_true).item()
for i in range(n_iterations):
    ###########################################################################
    # Gradient step on the full data set.
    theta -= alpha * mean_squared_error_d_theta(theta, x, y_true)
    loss = mean_squared_error(theta, x, y_true).item()
    if abs(loss - prev_loss) < tolerance:
        print(f"Converged at iteration {i + 1}")
        break
    prev_loss = loss
    ###########################################################################
    # One hidden trace per completed iteration; the slider toggles visibility.
    fig.add_trace(
        go.Scatter(
            visible=False,
            name=f"Iteration {(i + 1):4d}",
            x=x,
            y=theta * x
        )
    )
    ###########################################################################
    print(f"Iteration {(i + 1):4d}: loss = {loss:.3f} | y_pred = {theta:.3f}x")

# Build the slider only after the loop: an early `break` means fewer traces than
# n_iterations, and each step's visibility mask must match the real trace count.
n_traces = len(fig.data)
steps = []
for trace_idx in range(1, n_traces):
    visible = [False] * n_traces
    visible[0] = True            # always show the true function
    visible[trace_idx] = True    # plus the fit after this iteration
    steps.append(dict(
        method="update",
        label=str(trace_idx),
        args=[
            {"visible": visible},
            {"title": "SGD Iteration: " + str(trace_idx)}
        ],
    ))

# Guard against the case where the loop stopped before recording any iteration.
if steps:
    sliders = [dict(
        active=0,  # matches the trace made visible below (first iteration)
        currentvalue={"prefix": "SGD: "},
        pad={"t": 50},
        steps=steps
    )]
    fig.data[1].visible = True
    fig.update_layout(
        sliders=sliders
    )
fig.show()

Initial theta: 0.15920352935791016
Iteration    1: loss = 57.086 | y_pred = 2.681x
Iteration    2: loss = 6.728 | y_pred = 3.547x
Iteration    3: loss = 0.793 | y_pred = 3.845x
Iteration    4: loss = 0.093 | y_pred = 3.947x
Iteration    5: loss = 0.011 | y_pred = 3.982x
Iteration    6: loss = 0.001 | y_pred = 3.994x
Iteration    7: loss = 0.000 | y_pred = 3.998x
Iteration    8: loss = 0.000 | y_pred = 3.999x
Iteration    9: loss = 0.000 | y_pred = 4.000x
Iteration   10: loss = 0.000 | y_pred = 4.000x
Converged at iteration 11


In [178]:
def f_2d(x, y):
    """Ground-truth surface ``z = 3x^2 + 9y^2``; works on scalars or arrays."""
    return 3 * x**2 + 9 * y**2


# Axis grids covering [-10, 10) in steps of 0.1.
x2 = jnp.arange(-10, 10, step=0.1)
y2 = jnp.arange(-10, 10, step=0.1)

# One row per x value: row i holds f_2d(x2[i], y) for every y in y2.
z = list(map(lambda x_value: f_2d(x_value, y2), x2))

In [116]:
def loss_2d(theta1, theta2, x, y, true):
    """MSE of the model ``theta1 * x^2 + theta2 * y^2`` against the targets.

    Args:
        theta1: coefficient of the ``x^2`` term.
        theta2: coefficient of the ``y^2`` term.
        x: input values for the first variable.
        y: input values for the second variable.
        true: observed target values.

    Returns:
        The mean squared error as a scalar array.
    """
    residual = true - (theta1 * x**2 + theta2 * y**2)
    return jnp.mean(residual ** 2)

In [117]:
# Partial derivatives of the loss with respect to each parameter:
# argnums 0 -> theta1, argnums 1 -> theta2.
loss_2d_th1, loss_2d_th2 = (
    grad(loss_2d, argnums=position) for position in (0, 1)
)

In [None]:
# --- Hyper-parameters ---------------------------------------------------------
alpha = 0.0001       # learning rate
n_iterations = 200   # upper bound on the number of SGD steps
n_samples = 100      # size of the synthetic training set
tolerance = 1e-2     # stop once the per-sample loss changes by less than this
# NOTE(review): inputs enter the model squared (up to 100), so single-sample
# gradients can be large; alpha may need tuning if the iterates oscillate.

# Draw each parameter from its own subkey. Passing the same key to two
# random.uniform calls returns the same number, making theta1 == theta2.
key, key_theta1, key_theta2 = random.split(key, 3)
theta1, theta2 = random.uniform(key_theta1), random.uniform(key_theta2)

x2 = jnp.arange(-10, 10, 0.1)
y2 = jnp.arange(-10, 10, 0.1)
# Plotly's Surface reads z[i][j] as the height at (x[j], y[i]). meshgrid's
# default "xy" indexing produces exactly that layout (rows follow y, columns x).
grid_x, grid_y = jnp.meshgrid(x2, y2)
_z = f_2d(grid_x, grid_y)

fig = go.Figure()
# Trace 0 is the true surface; it stays visible for every slider position.
fig.add_trace(go.Surface(z=_z, x=x2, y=y2, colorscale='Blues', showscale=False))
fig.update_layout(
    scene = dict(
        xaxis = dict(range=[-12, 12],),
        yaxis = dict(range=[-12, 12],),
        zaxis = dict(range=[0, 1700],)
    ,)
)

print(f"Initial theta1: {theta1}")
print(f"Initial theta2: {theta2}")

# Generate samples from the true function. Sample variables carry a `sample_`
# prefix so the global `x` used by the 1D cells is not overwritten.
samples = []
for _sample in range(n_samples):
    key, key_x, key_y = random.split(key, 3)
    sample_x = random.uniform(key_x, (1,), minval=-10, maxval=10)
    sample_y = random.uniform(key_y, (1,), minval=-10, maxval=10)
    samples.append((sample_x, sample_y, f_2d(sample_x, sample_y)))

# Start from +inf so the first iteration can never be mistaken for convergence
# (with a start of 0, any sample whose loss is below the tolerance would stop
# the loop immediately).
prev_loss = float("inf")
for i in range(n_iterations):
    ###########################################################################
    # Pick one training sample uniformly at random.
    key, subkey = random.split(key)
    idx = int(random.randint(subkey, (), 0, n_samples))
    sample_x, sample_y, sample_z = samples[idx]
    # Evaluate both partial derivatives at the current parameters before
    # changing either one, so the two updates form a single simultaneous step.
    grad_theta1 = loss_2d_th1(theta1, theta2, sample_x, sample_y, sample_z)
    grad_theta2 = loss_2d_th2(theta1, theta2, sample_x, sample_y, sample_z)
    theta1 -= alpha * grad_theta1
    theta2 -= alpha * grad_theta2
    loss = loss_2d(theta1, theta2, sample_x, sample_y, sample_z).item()
    # The loss is measured on a different random sample each step, so this
    # criterion is noisy; it is kept as the notebook's stopping rule.
    if abs(loss - prev_loss) < tolerance:
        print(f"Converged at iteration {i + 1}")
        break
    prev_loss = loss
    ###########################################################################
    # Current model surface on the same grid layout as the true surface.
    fig.add_trace(
        go.Surface(
            z=theta1 * grid_x**2 + theta2 * grid_y**2,
            x=x2,
            y=y2,
            visible=False,
            name=f"Iteration {(i + 1):4d}",
            showscale=False
        )
    )
    ###########################################################################
    print(f"Iteration {(i + 1):4d}: loss = {loss:.3f} | y_pred = {theta1:.3f}x**2 + {theta2:.3f}y**2")

# Build the slider only after the loop: an early `break` means fewer traces than
# n_iterations, and each step's visibility mask must match the real trace count.
n_traces = len(fig.data)
steps = []
for trace_idx in range(1, n_traces):
    visible = [False] * n_traces
    visible[0] = True            # always show the true surface
    visible[trace_idx] = True    # plus the model surface after this iteration
    steps.append(dict(
        method="update",
        label=str(trace_idx),
        args=[
            {"visible": visible},
            {"title": "SGD Iteration: " + str(trace_idx)}
        ],
    ))

layout_options = dict(
    title="SGD 2D",
    autosize=True,
    height=800,
    margin=dict(l=65, r=50, b=65, t=90)
)
# Guard against the case where the loop stopped before recording any iteration.
if steps:
    fig.data[1].visible = True
    layout_options["sliders"] = [dict(
        active=0,  # matches the trace made visible above (first iteration)
        currentvalue={"prefix": "SGD: "},
        pad={"t": 50},
        steps=steps
    )]
fig.update_layout(**layout_options)

fig.show()

In [150]:
# Split the key but discard the first half instead of rebinding `key`.
# The parent key is left unchanged, so repeating this split gives the same subkey.
_, subkey = random.split(key)
subkey

Array([1244123830, 3563417016], dtype=uint32)

In [151]:
# Same split on the same, still un-rebound `key`: the displayed subkey is
# identical, showing that JAX key splitting is deterministic.
_, subkey = random.split(key)
subkey

Array([1244123830, 3563417016], dtype=uint32)