<a href="https://colab.research.google.com/github/tensorush/Machine-Learning-Notebooks/blob/master/Notebooks/Optimization-Methods/Gradient%20Descent%20Methods.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Import libraries

In [183]:
import jax
import sympy as sp
import jax.numpy as jnp
import plotly.graph_objects as go

Example of a symbolic gradient computation with SymPy (for reference only: the gradient descent methods below compute their gradients with JAX).

In [184]:
def symbolic_grad_func(func, vars):
    """Evaluate the gradient of a two-argument function symbolically with SymPy.

    ``func`` is called with two real SymPy symbols, differentiated with respect
    to each of them, and every partial derivative is evaluated numerically at
    the point ``(vars[0], vars[1])``.

    Args:
        func: Function of two arguments, built from SymPy-compatible operations.
        vars: Indexable point; only ``vars[0]`` and ``vars[1]`` are read.

    Returns:
        ``jnp.array([df/dx, df/dy])`` with ``dtype=float``.
    """
    sym_x, sym_y = sp.symbols('x y', real=True)
    substitutions = {sym_x: vars[0], sym_y: vars[1]}
    partials = [
        sp.diff(func(sym_x, sym_y), symbol).evalf(subs=substitutions)
        for symbol in (sym_x, sym_y)
    ]
    return jnp.array(partials, dtype=float)

Define functions for gradient descent methods

In [194]:
def naive_gradient_descent(func, vars, lr=1e-3, eps_norm=1e-6, max_iter=100_000):
    """Minimize ``func`` with plain (fixed step size) gradient descent.

    Update rule: ``v <- v - lr * grad f(v)``.

    Args:
        func: Scalar function taking one positional argument per coordinate,
            e.g. ``func(x, y)``; must be differentiable with ``jax.grad``.
        vars: 1-D starting point. It is copied, never modified in place.
        lr: Learning rate (step size).
        eps_norm: Stop once the Euclidean norm of the gradient is at most this.
        max_iter: Maximum number of update steps, so a diverging or stalling
            run still terminates.

    Returns:
        dict with keys ``'x'``, ``'y'``, ``'z'``: the first two coordinates of
        every visited point and ``func`` evaluated there (starting point
        first), as Python floats.
    """
    # jax.grad differentiates only w.r.t. the first positional argument by
    # default; wrapping func to take the whole coordinate vector yields the
    # full gradient rather than just df/dx.
    grad_func = jax.grad(lambda point: func(*point))
    point = jnp.asarray(vars, dtype=float)
    history = dict(x=[float(point[0])], y=[float(point[1])], z=[float(func(*point))])
    for _ in range(max_iter):
        grad = grad_func(point)
        grad_norm = jnp.linalg.norm(grad)
        # Written as "not >" so a NaN gradient also stops the loop.
        if not grad_norm > eps_norm:
            break
        # Rebind instead of "-=" so a caller's NumPy array is never mutated.
        point = point - lr * grad
        history['x'].append(float(point[0]))
        history['y'].append(float(point[1]))
        history['z'].append(float(func(*point)))
    return history


def adagrad(func, vars, lr=1e1, eps_norm=2e1, eps_dummy=1e-6, max_iter=100_000):
    """Minimize ``func`` with AdaGrad (per-coordinate adaptive step sizes).

    Update rule, element-wise:
    ``G <- G + g**2``;  ``v <- v - lr * g / (sqrt(G) + eps_dummy)``.

    Args:
        func: Scalar function taking one positional argument per coordinate,
            e.g. ``func(x, y)``; must be differentiable with ``jax.grad``.
        vars: 1-D starting point. It is copied, never modified in place.
        lr: Base learning rate.
        eps_norm: Stop once the Euclidean norm of the gradient is at most this.
        eps_dummy: Small constant added to the denominator to avoid dividing
            by zero.
        max_iter: Maximum number of update steps, so a non-converging run
            still terminates.

    Returns:
        dict with keys ``'x'``, ``'y'``, ``'z'``: the first two coordinates of
        every visited point and ``func`` evaluated there (starting point
        first), as Python floats.
    """
    # Differentiate w.r.t. the whole coordinate vector, not only the first
    # positional argument (jax.grad's default argnums=0).
    grad_func = jax.grad(lambda point: func(*point))
    point = jnp.asarray(vars, dtype=float)
    acc_grad = jnp.zeros_like(point)  # running sum of squared gradients, per coordinate
    history = dict(x=[float(point[0])], y=[float(point[1])], z=[float(func(*point))])
    for _ in range(max_iter):
        grad = grad_func(point)
        grad_norm = jnp.linalg.norm(grad)
        # Written as "not >" so a NaN gradient also stops the loop.
        if not grad_norm > eps_norm:
            break
        acc_grad = acc_grad + grad ** 2
        # Rebind instead of "-=" so a caller's NumPy array is never mutated.
        point = point - lr * grad / (jnp.sqrt(acc_grad) + eps_dummy)
        history['x'].append(float(point[0]))
        history['y'].append(float(point[1]))
        history['z'].append(float(func(*point)))
    return history


def gradient_descent_with_momentum(func, vars, lr=1e-1, dr=9e-1, eps_norm=1e-6, eps_dummy=1e-6, max_iter=100_000):
    """Minimize ``func`` with gradient descent plus (heavy-ball) momentum.

    Update rule: ``m <- dr * m - lr * grad f(v)``;  ``v <- v + m``.

    Args:
        func: Scalar function taking one positional argument per coordinate,
            e.g. ``func(x, y)``; must be differentiable with ``jax.grad``.
        vars: 1-D starting point. It is copied, never modified in place.
        lr: Learning rate.
        dr: Decay rate applied to the previous momentum each step.
        eps_norm: Stop once the Euclidean norm of the gradient is at most this.
        eps_dummy: Unused; kept so existing calls that pass it keep working.
        max_iter: Maximum number of update steps, so a non-converging run
            still terminates.

    Returns:
        dict with keys ``'x'``, ``'y'``, ``'z'``: the first two coordinates of
        every visited point and ``func`` evaluated there (starting point
        first), as Python floats.
    """
    # Differentiate w.r.t. the whole coordinate vector, not only the first
    # positional argument (jax.grad's default argnums=0).
    grad_func = jax.grad(lambda point: func(*point))
    point = jnp.asarray(vars, dtype=float)
    momentum = jnp.zeros_like(point)
    history = dict(x=[float(point[0])], y=[float(point[1])], z=[float(func(*point))])
    for _ in range(max_iter):
        grad = grad_func(point)
        grad_norm = jnp.linalg.norm(grad)
        # Written as "not >" so a NaN gradient also stops the loop.
        if not grad_norm > eps_norm:
            break
        momentum = dr * momentum - lr * grad
        # Rebind instead of "+=" so a caller's NumPy array is never mutated.
        point = point + momentum
        history['x'].append(float(point[0]))
        history['y'].append(float(point[1]))
        history['z'].append(float(func(*point)))
    return history

Define a function for plotting the path a gradient descent method takes over a surface

In [195]:
def plot_method(title, method, func, vars, bounds=(-100.0, 100.0), resolution=100, **method_kwargs):
    """Run an optimizer on ``func`` and plot its path over the function's surface.

    Args:
        title: Figure title.
        method: Optimizer, called as ``method(func, vars, **method_kwargs)``; it
            must return a dict with ``'x'``, ``'y'`` and ``'z'`` sequences.
        func: Scalar function of two arguments, ``func(x, y)``; it is also
            evaluated element-wise on the plotting grid.
        vars: Starting point passed to ``method``.
        bounds: ``(low, high)`` range plotted on both the x and the y axis.
        resolution: Number of grid samples per axis for the surface.
        **method_kwargs: Extra hyperparameters forwarded to ``method``
            (for example ``lr`` or ``eps_norm``).
    """
    history = method(func, vars, **method_kwargs)
    low, high = bounds
    axis = jnp.linspace(start=low, stop=high, num=resolution)
    x, y = jnp.meshgrid(axis, axis)
    figure = go.Figure(data=[go.Surface(x=x, y=y, z=func(x, y), colorscale='aggrnyl', opacity=0.7)])
    # The optimizer's trajectory is drawn as a red line on top of the surface.
    figure.add_trace(go.Scatter3d(x=history['x'], y=history['y'], z=history['z'], mode='lines', line=dict(color='red', width=10)))
    figure.update_layout(title=title, autosize=True, width=500, height=500, margin=dict(l=20, r=20, b=10, t=40))
    figure.show()

Define the test surfaces (standard optimization benchmark functions of two variables)

In [196]:
def paraboloid(x, y):
    """Paraboloid x^2 + y^2; global minimum f(0, 0) = 0."""
    return x ** 2 + y ** 2


def matyas(x, y):
    """Matyas function; global minimum f(0, 0) = 0."""
    return 0.26 * (x ** 2 + y ** 2) - 0.48 * x * y


def booth(x, y):
    """Booth function; global minimum f(1, 3) = 0."""
    return (x + 2 * y - 7) ** 2 + (2 * x + y - 5) ** 2


def rosenbrock(x, y):
    """Rosenbrock function (a = 1, b = 100); global minimum f(1, 1) = 0.

    The 100 * (y - x^2)^2 term must be added: with a negative coefficient
    the surface is unbounded below and descent diverges.
    """
    return 100 * (y - x ** 2) ** 2 + (1 - x) ** 2


def himmelblau(x, y):
    """Himmelblau function; four global minima with value 0, e.g. f(3, 2) = 0."""
    return (x ** 2 + y - 11) ** 2 + (x + y ** 2 - 7) ** 2


def three_hump_camel(x, y):
    """Three-hump camel function; global minimum f(0, 0) = 0."""
    return 2 * x ** 2 - 1.05 * x ** 4 + x ** 6 / 6 + x * y + y ** 2


def bulkin_6(x, y):
    """Bukin function N.6; global minimum f(-10, 1) = 0.

    Not differentiable along the curve y = 0.01 * x^2 (square root of an
    absolute value), so gradients there are undefined.
    """
    return 100 * jnp.sqrt(jnp.abs(y - 0.01 * x ** 2)) + 0.01 * jnp.abs(x + 10)


def easom(x, y):
    """Easom function; global minimum f(pi, pi) = -1, nearly flat elsewhere."""
    return -jnp.cos(x) * jnp.cos(y) * jnp.exp(-((x - jnp.pi) ** 2 + (y - jnp.pi) ** 2))


def beale(x, y):
    """Beale function; global minimum f(3, 0.5) = 0."""
    return (1.5 - x + x * y) ** 2 + (2.25 - x + x * y ** 2) ** 2 + (2.625 - x + x * y ** 3) ** 2


def rastrigin(x, y):
    """Rastrigin function (A = 10, n = 2); global minimum f(0, 0) = 0."""
    return 10 * 2 + x ** 2 - 10 * jnp.cos(2 * jnp.pi * x) + y ** 2 - 10 * jnp.cos(2 * jnp.pi * y)


def levi_13(x, y):
    """Lévi function N.13; global minimum f(1, 1) = 0."""
    # Each sine is taken of (k * pi * coordinate) and then squared.
    return (jnp.sin(3 * jnp.pi * x) ** 2
            + (x - 1) ** 2 * (1 + jnp.sin(3 * jnp.pi * y) ** 2)
            + (y - 1) ** 2 * (1 + jnp.sin(2 * jnp.pi * y) ** 2))


def ackley(x, y):
    """Ackley function; global minimum f(0, 0) = 0.

    The gradient of the square-root term is undefined exactly at (0, 0).
    """
    return (-20 * jnp.exp(-0.2 * jnp.sqrt(0.5 * (x ** 2 + y ** 2)))
            - jnp.exp(0.5 * (jnp.cos(2 * jnp.pi * x) + jnp.cos(2 * jnp.pi * y)))
            + jnp.e + 20)

Run and plot gradient descent methods

In [197]:
# All three optimizers start from the same corner of the plotted square.
paraboloid_start = jnp.array([100.0, 100.0])
paraboloid_methods = {
    'Naive Gradient Descent': naive_gradient_descent,
    'AdaGrad': adagrad,
    'Gradient Descent with Momentum': gradient_descent_with_momentum,
}
for method_name, method in paraboloid_methods.items():
    plot_method(f'{method_name} on Paraboloid', method, paraboloid, paraboloid_start)

# Further benchmark surfaces with suggested starting points, plotted with naive
# gradient descent. Disabled by default: the default hyperparameters were chosen
# for the paraboloid and may diverge or fail to converge on these surfaces.
# TODO: tune hyperparameters per surface before enabling.
PLOT_OTHER_SURFACES = False
other_surfaces = [
    ('Matyas function', matyas, jnp.array([-10.0, 10.0])),
    ('Booth function', booth, jnp.array([-10.0, -10.0])),
    ('Rosenbrock function', rosenbrock, jnp.array([2.0, 2.0])),
    ('Himmelblau function', himmelblau, jnp.array([-10.0, 10.0])),
    ('Three-hump camel function', three_hump_camel, jnp.array([-3.0, 3.0])),
    ('Bukin function №6', bulkin_6, jnp.array([-3.0, 3.0])),
    ('Easom function', easom, jnp.array([-3.0, 3.0])),
    ('Beale function', beale, jnp.array([-4.5, 4.5])),
    ('Rastrigin function', rastrigin, jnp.array([-5.0, 5.0])),
    ('Lévi function №13', levi_13, jnp.array([-10.0, -10.0])),
    ('Ackley function', ackley, jnp.array([-5.0, 5.0])),
]
if PLOT_OTHER_SURFACES:
    for title, surface_func, start in other_surfaces:
        plot_method(title, naive_gradient_descent, surface_func, start)