https://daily-tech.hatenablog.com/entry/2017/03/21/063518

https://medium.com/@hirok4/python-implementation-of-levenberg-marquardt-algorithm-8ff8abdec0f5

J. J. Moré, “The Levenberg-Marquardt Algorithm: Implementation and Theory,” in Numerical Analysis, ed. G. A. Watson, Lecture Notes in Mathematics 630, Springer-Verlag, pp. 105-116, 1977.

The Levenberg-Marquardt (LM) method is a combination of gradient descent (GD) and the Newton-Raphson (NR) method:
$$x_{n+1} = x_n - (\nabla^2 f(x_n) + \lambda I)^{-1} \nabla f(x_n)$$
When $\lambda \to 0$, the LM method approaches the NR method, and when $\lambda \to \infty$, it approaches the GD method with a small step size (roughly $1/\lambda$).
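
To make the two limits concrete, here is a short derivation sketch. The shorthand $H$ for the Hessian of $f$ at $x_n$ and $g$ for the gradient of $f$ at $x_n$ is introduced here for convenience, and the first line assumes $H$ is invertible.

$$\lambda \to 0: \quad (H + \lambda I)^{-1} g \;\to\; H^{-1} g \quad \text{(the Newton-Raphson step)}$$

$$\lambda \to \infty: \quad (H + \lambda I)^{-1} g = \frac{1}{\lambda}\left(I + \frac{H}{\lambda}\right)^{-1} g \;\approx\; \frac{1}{\lambda}\, g \quad \text{(a gradient descent step with step size } 1/\lambda \text{)}$$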

When the gain $\lambda$ is sufficiently large, the matrix $H + \lambda I$ is positive definite even if the Hessian matrix $H$ is not. The update is then a descent direction, so a large enough $\lambda$ guarantees a reduction in the function's value.
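
A minimal sketch of this effect, using a made-up indefinite matrix in place of a real Hessian. The matrix `H`, the gradient `g`, the helper `damped_step` and the value $\lambda = 2$ are all illustrative assumptions, not part of the notes above. Adding $\lambda I$ shifts every eigenvalue by $\lambda$, and the sign of $g \cdot d$ tells whether the step $d$ points downhill.

```python
import jax.numpy as jnp

# Hypothetical indefinite "Hessian": its eigenvalues are -1 and 3.
H = jnp.array([[1.0, 2.0],
               [2.0, 1.0]])
# Hypothetical gradient, chosen along the negative-curvature direction.
g = jnp.array([1.0, -1.0])


def damped_step(H, g, lam):
    """Return the step d = -(H + lam * I)^{-1} g."""
    return -jnp.linalg.solve(H + lam * jnp.eye(H.shape[0]), g)


# Eigenvalues before and after damping with lam = 2.
eig_plain = jnp.linalg.eigvalsh(H)
eig_damped = jnp.linalg.eigvalsh(H + 2.0 * jnp.eye(2))

step_plain = damped_step(H, g, 0.0)   # pure Newton step
step_damped = damped_step(H, g, 2.0)  # damped step

# d is a descent direction when the directional derivative g . d is negative.
slope_plain = jnp.dot(g, step_plain)
slope_damped = jnp.dot(g, step_damped)

print("eigenvalues of H:       ", eig_plain)
print("eigenvalues of H + 2I:  ", eig_damped)
print("g . d (Newton step):    ", slope_plain)
print("g . d (damped step):    ", slope_damped)
```

In this toy case the undamped Newton step points uphill along the negative-curvature direction, while the damped step points downhill.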

If the cost function f(x) goes down after an update step (which implies that the curvature information is helping), we accept the step and reduce $\lambda$ (usually by a factor of 10) to reduce the influence of gradient descent. On the other hand, if the cost function goes up, we retract the step and increase $\lambda$ by a factor of 10 or some other significant factor.
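
The accept/reject rule described above can be sketched as follows. This is only an illustration under stated assumptions: the function name `lm_adaptive`, the initial `lam`, the bounds `lam_min` and `lam_max`, the iteration count and the starting guess are hypothetical choices. The model and toy data mirror the exponential fit used in the cells further down.

```python
import jax.numpy as jnp
from jax import grad, hessian, jit, random


def forward(params, x):
    a, b = params[0], params[1]
    return b * jnp.exp(a * x)


def cost_fn(params, x, y):
    return jnp.mean(jnp.square(forward(params, x) - y))


grad_fn = jit(grad(cost_fn))     # gradient w.r.t. params
hess_fn = jit(hessian(cost_fn))  # Hessian w.r.t. params


def lm_adaptive(params, x, y, lam=1.0, factor=10.0, n_iter=200,
                lam_min=1e-8, lam_max=1e8):
    """Damped Newton iteration with an adaptive lambda (hypothetical helper)."""
    cost = cost_fn(params, x, y)
    for _ in range(n_iter):
        g = grad_fn(params, x, y)
        H = hess_fn(params, x, y)
        trial = params - jnp.linalg.solve(H + lam * jnp.eye(H.shape[0]), g)
        trial_cost = cost_fn(trial, x, y)
        if trial_cost < cost:
            # Cost went down: accept the step and trust the curvature more.
            params, cost = trial, trial_cost
            lam = max(lam / factor, lam_min)
        else:
            # Cost went up (or is NaN): retract the step and damp harder.
            lam = min(lam * factor, lam_max)
    return params, cost, lam


# Toy data, same form as in the cells below.
x = jnp.linspace(1, 2, 20)
params_true = jnp.array([2.5, 0.8])
noise = random.normal(key=random.PRNGKey(123), shape=(len(x),)) * 0.5
y = forward(params_true, x) + noise

params_init = jnp.array([2.4, 0.5])  # assumed starting guess
init_cost = cost_fn(params_init, x, y)
final_params, final_cost, final_lam = lm_adaptive(params_init, x, y)
print(f"initial cost: {init_cost:.4f} | final cost: {final_cost:.4f}")
print(f"estimate: {final_params} | final lambda: {final_lam:g}")
```

Because a step is kept only when it lowers the cost, the cost never increases from one iteration to the next. The price is one extra cost evaluation per iteration.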

In terms of convergence speed: GD < LM < NR

In terms of convergence stability: GD > LM > NR

In [1]:
import jax.numpy as jnp 
from jax import grad, jit, vmap, jacfwd, hessian, random
from jax.tree_util import tree_map, tree_flatten
from jax.experimental.ode import odeint 

import matplotlib.pyplot as plt

### Newton-Raphson Algorithm

In [2]:
def forward(params, x):
    a, b = params[0], params[1]
    return b*jnp.exp(a * x)

# toy data
x = jnp.linspace(1, 2, 20)
params_true = jnp.array([
    2.5,
    0.8
])
y = forward(params_true, x) + random.normal(key=random.PRNGKey(123), shape=(len(x),)) * 0.5
# plt.plot(x, y)

def cost_fn(params, x, y):
    pred = forward(params, x)
    err  = pred - y
    return jnp.mean(jnp.square(err))

# print(cost_fn(params_true, x, y))

# def newton_raphson(cost_fn, params, x, y):
#     jac = jacfwd(cost_fn, argnums=0)(params, x, y)   # gradient of the scalar cost
#     hes = hessian(cost_fn, argnums=0)(params, x, y)  # Hessian of the scalar cost
#     return params - jnp.linalg.inv(hes) @ jac

params = jnp.array([
    2.4,  # poor convergence when the starting values are far from the final estimates
    0.5
])

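# build the derivative functions once, outside the loop
jacob = jacfwd(cost_fn, argnums=0)
hess = hessian(cost_fn, argnums=0)

for i in range(1000):
    jac = jacob(params, x, y)  # gradient of the scalar cost, shape (2,)
    hes = hess(params, x, y)   # Hessian of the scalar cost, shape (2, 2)

    # Newton-Raphson update: params <- params - H^{-1} g
    params -=  jnp.linalg.inv(hes) @ jac.T
    if i % 100 == 0:
        print(params)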


[9.507141   0.48427495]
[2.4896717  0.81633145]
[2.4896717  0.81633145]
[2.4896717  0.81633145]
[2.4896717  0.81633145]
[2.4896717  0.81633145]
[2.4896717  0.81633145]
[2.4896717  0.81633145]
[2.4896717  0.81633145]
[2.4896717  0.81633145]


### Levenberg-Marquardt Algorithm

In [3]:
def forward(params, x):
    a, b= params[0], params[1]
    return b*jnp.exp(a * x)

# toy data
x = jnp.linspace(1, 2, 20)
params_true = jnp.array([
    2.5,
    0.8
])
y = forward(params_true, x) + random.normal(key=random.PRNGKey(123), shape=(len(x),)) * 0.5
# plt.plot(x, y)

def cost_fn(params, x, y):
    pred = forward(params, x)
    err  = pred - y
    return jnp.mean(jnp.square(err))


params = jnp.array([
    10.2,
    1.
])
n_iter = 5000

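# build the derivative functions once, outside the loop
jacob = jacfwd(cost_fn, argnums=0)
hess = hessian(cost_fn, argnums=0)

for i in range(n_iter):
    jac = jacob(params, x, y)  # gradient of the scalar cost, shape (2,)
    hes = hess(params, x, y)   # Hessian of the scalar cost, shape (2, 2)

    # add a constant multiple of the identity (lambda = 1e3, held fixed here) to the Hessian,
    # so that hes + lambda*I is pushed towards being positive definite
    params -=  jnp.linalg.inv(hes + 1e3*jnp.identity(hes.shape[0])) @ jac.T 
    
    # alternative: damp with the diagonal of the Hessian and a decaying lambda = 1/(i+1);
    # when lambda -> 0, the algorithm approaches Newton-Raphson,
    # when lambda -> inf, it approaches gradient descent with small steps
    # (jnp.diag(hes) alone returns a vector, so it is wrapped again to get a diagonal matrix)
    # params -=  jnp.linalg.inv(hes + 1/(i+1)*jnp.diag(jnp.diag(hes))) @ jac.T 
    if i % (n_iter // 10) == 0: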
        print(f"Iter: {i:8} | Loss: {cost_fn(params, x, y):8.4f} | Est: {params}")

Iter:        0 | Loss: 14650948416700416.0000 | Est: [9.946395  1.0005617]
Iter:      500 | Loss:   0.2475 | Est: [2.489671  0.8163325]
Iter:     1000 | Loss:   0.2475 | Est: [2.489671  0.8163325]
Iter:     1500 | Loss:   0.2475 | Est: [2.489671  0.8163325]
Iter:     2000 | Loss:   0.2475 | Est: [2.489671  0.8163325]
Iter:     2500 | Loss:   0.2475 | Est: [2.489671  0.8163325]
Iter:     3000 | Loss:   0.2475 | Est: [2.489671  0.8163325]
Iter:     3500 | Loss:   0.2475 | Est: [2.489671  0.8163325]
Iter:     4000 | Loss:   0.2475 | Est: [2.489671  0.8163325]
Iter:     4500 | Loss:   0.2475 | Est: [2.489671  0.8163325]
