In [2]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap, hessian, jacobian
from jax import random
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

import copy, timeit
%matplotlib

Using matplotlib backend: <object object at 0x11b005b10>


In [3]:
def Rastrigin(x, A=10): # define objective (map from n dim to 1 dim)
    f = A * len(x) + jnp.sum(jnp.array([xi ** 2 - A * jnp.cos(2 * xi * jnp.pi) for xi in x]))
    return f

RastriginGrad = grad(Rastrigin) # use jacobian() if function map from n to m. 
#Here grad is less computational expensive that jacobian
RastriginGradSec = jacobian(grad(Rastrigin))
RastriginHess = hessian(Rastrigin) 
RastriginJac = jacobian(Rastrigin)
initial_x = np.random.rand(20)

backtracking line search

In [4]:
def Backtracking(xi, func, p, grads, alpha_init=1., c=0.7, low= 0.7): # perform inexact line search
    a = alpha_init
    while True:
        if func(xi + a * p) <= func(xi) + c * a * jnp.dot(p, grads):
            return a
        else:
            a = low * a

Implement Steepest descent and preconditioning

The search direction is given by
$$p_k = - \frac{{\nabla}f(\textbf{x}_k)}{\parallel{\nabla}f(\textbf{x}_k)\parallel}$$
And general update rule is given by
$$x_{k+1} \leftarrow x_k + \alpha p_k$$

In [5]:
theshold = 0.1
X_k = initial_x #np.random.rand(20)
eval_first = Rastrigin(X_k)
alpha = 1.
while True:
    grad_k = RastriginGrad(X_k)
    grad_norm_k = jnp.linalg.norm(grad_k)
    direction_p = - grad_k / grad_norm_k
    alpha_k = Backtracking(X_k, Rastrigin, direction_p, grad_k)
    X_k = X_k + alpha_k * direction_p
    alpha_k = alpha
    print(alpha_k, grad_norm_k, Rastrigin(X_k))
    if grad_norm_k < theshold:
        break
print('frist evaluation', eval_first, 'final evaluation', Rastrigin(X_k))




1.0 196.72731 59.103943
1.0 132.1322 24.33731
1.0 81.202286 13.579773
1.0 43.10411 11.651001
1.0 23.263014 11.14447
1.0 12.5198555 11.008545
1.0 7.108055 10.958191
1.0 3.2850423 10.947052
1.0 1.4052193 10.945282
1.0 0.7597969 10.944672
1.0 0.30787194 10.94458
1.0 0.15281661 10.944565
1.0 0.09963487 10.94455
frist evaluation 137.48909 final evaluation 10.94455


Implement Newton Method 

The search direction is given by
$$p_k = - \frac{{\nabla}f(\textbf{x}_k)}{{\nabla}^2f(\textbf{x}_k)}$$

In [539]:
X_k = initial_x
eval_first = Rastrigin(X_k)
evalx = []
alpha = 1.

while True:
    evalx.append(X_k)
    G2 = np.diag(1 / np.diag(RastriginHess(X_k))) #np.linalg.pinv(RastriginHess(X_k))
    G1 = RastriginGrad(X_k) 
    direction_p =  - G2.dot(G1)
    alpha_k = Backtracking(X_k, Rastrigin, direction_p, G1, c=0.1, low= 0.8)
    X_k = X_k + alpha_k * direction_p
    alpha_k = alpha_k
    
    grad_norm_k = jnp.linalg.norm(G1)
    print(alpha_k, grad_norm_k, Rastrigin(X_k))
    if grad_norm_k < theshold*0.1:
        break
print('frist evaluation', eval_first, 'final evaluation', Rastrigin(X_k))

1.0 208.40935 283.82776
1.0 155.61899 276.5726
0.5120000000000001 125.34485 271.59433
1.0 61.638542 268.36316
1.0 9.004759 268.2616
1.0 0.036696058 268.2616
frist evaluation 280.4375 final evaluation 268.2616


Example on linear regression

In [6]:
def least_squares(A, b, x): #GRAD : A.T.dot(A.dot(x)-b)/len(A), HESS : #A.T.dot(A)/len(A)
    return (0.5/len(A)) * jnp.linalg.norm(jnp.dot(A, x)-b)**2

ghk = jacobian(grad(least_squares, 2), 2)
gk = grad(least_squares ,2)

Ns = 100
Axxk = np.random.rand(1000, Ns)
xxk = np.random.rand(Ns, )
b = Axxk.dot(xxk)

xs = np.random.rand(Ns, ) 
print('first eval', least_squares(Axxk, b, xs))
for step in [1, 1]:
    Hinv = np.linalg.pinv(ghk(Axxk, b, xs).reshape(Ns, Ns))
    xs = xs - step * Hinv.dot(gk(Axxk, b, xs))
print('last eval', least_squares(Axxk, b, xs))

first eval 0.8685655
last eval 7.566996e-13


Newton's with by solving system of linear equation
$$p_k = - \mathbf{H}^{-1}_k{\nabla}f(\textbf{x}_k)$$
$$\mathbf{H}_k p_k = g_k$$

In [13]:
EPS = 0.00000001
def conjugate_gradient(Ax, b, max_iters=100, tol=0.001, CG_time_step=0):
    'A is (S++)'
    x = np.zeros_like(b)
    r = copy.deepcopy(b)  
    p = copy.deepcopy(r)
    r_dot_old = np.dot(r, r)

    for _ in range(max_iters):
        z = np.dot(Ax, p) # for hessian vector product use Ax(p) as a function 
        alpha = r_dot_old /( np.dot(p, z) + EPS)
        x += alpha * p
        r -= alpha * z
        CG_time_step+=1
        if tol >= np.linalg.norm(r):
            return x #CG_time_step
        
        r_dot_new = np.dot(r, r)
        p = r + (r_dot_new / (r_dot_old + EPS)) * p
        r_dot_old = r_dot_new

    return x #CG_time_step

In [14]:
grad_lg = grad(least_squares, 2)
hess_lg = lambda as_, bs, xs: jacobian(grad(least_squares, 2), 2)(as_, bs, xs).reshape(Ns, Ns)
Hessian_vector_prod = lambda f

# def hvp(f, x, v):
#     return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

def Newton_ls(func, gradf, hessf, x_init, solver=conjugate_gradient, max_iters=20):
    x = x_init
    alpha = 0.1 # use fix step size
    First_eval = func(Axxk, b, x)
    start = timeit.default_timer()
    
    for i in range(max_iters):
        G2 = hessf(Axxk, b, x)
        G1 = gradf(Axxk, b, x)
        assert np.all(np.linalg.eigvals(G2) > 0), 'hessian is not possitive-definite'
        direction_p = - solver(G2, G1)
        x = x + alpha * direction_p
    execution_time = timeit.default_timer() - start
    print('frist_eval : ', First_eval, 'last_eval', func(Axxk, b, x), 'time :', execution_time)
    return x

In [32]:
xs = np.random.rand(Ns, ) 
solver0 = lambda A, b: np.linalg.lstsq(A, b, rcond=None)[0]
#solver1 = grad least square
x_solved = Newton_ls(least_squares, grad_lg, hess_lg, xs, solver=conjugate_gradient) # using conjugate gradient.
x_solved = Newton_ls(least_squares, grad_lg, hess_lg, xs, solver=solver0) # using least squre.

frist_eval :  0.6601352 last_eval 0.00975992 time : 0.40430069499961974
frist_eval :  0.6601352 last_eval 0.009757383 time : 0.4022046070003853


Approximate hessian newton method

In [50]:
def Newton_bgfs(func, x, gradf, hessf, max_iters=20):
    First_eval = func(Axxk, b, x)
    
    I = np.eye(Ns) 
    G2_inv = I# initialize inverse H as identity
    alpha = 0.1 #fix search distance
    xss = []
    xss.append(x)
    
    for _ in range(max_iters):
        
        G2 = hessf(Axxk, b, x)
        G1 = gradf(Axxk, b, x)
        direction_p = - jnp.dot(G2_inv, G1)
        x = x + alpha * direction_p
        xss.append(x)
        
        s = xss[-1] - xss[-2]
        y = gradf(Axxk, b, xss[-1]) - gradf(Axxk, b, xss[-2])
        ys = jnp.dot(y, s)
        ssT = jnp.outer(s, s)
        
        G2 = G2 + np.outer(y, y)/ys
        G2 = G2 - G2.dot(ssT.dot(G2))/np.dot(s, G2.dot(s))

        
        left = I - np.outer(s, y)/ys
        right = I - np.outer(y, s)/ys
        G2_inv = (left.dot(G2_inv).dot(right) + ssT/ys)
        print('eval_function : ', func(Axxk, b, x), 'grad : ', np.linalg.norm(G1))
    last_eval = func(Axxk, b, x)
    return x, xss

In [51]:
xs = np.random.rand(Ns, ) 
xs_solved = Newton_bgfs(least_squares, xs, grad_lg, hess_lg, max_iters=20)

eval_function :  3.7706165 grad :  8.386166
eval_function :  3.160022 grad :  12.591924
eval_function :  2.64849 grad :  11.334605
eval_function :  2.2199438 grad :  10.186216
eval_function :  1.8609171 grad :  9.138845
eval_function :  1.5601273 grad :  8.185016
eval_function :  1.3081187 grad :  7.3176475
eval_function :  1.096974 grad :  6.5300703
eval_function :  0.92005783 grad :  5.816019
eval_function :  0.77181345 grad :  5.1696024
eval_function :  0.6475837 grad :  4.585316
eval_function :  0.54346925 grad :  4.058004
eval_function :  0.45620403 grad :  3.5828629
eval_function :  0.38305196 grad :  3.1554267
eval_function :  0.32172215 grad :  2.7715402
eval_function :  0.27029517 grad :  2.4273548
eval_function :  0.22716323 grad :  2.1193056
eval_function :  0.19098176 grad :  1.8440948
eval_function :  0.16062327 grad :  1.5986938
eval_function :  0.13514362 grad :  1.3803045


In [69]:
# for one dimensional visualization only !!!!!
#evalx = np.asarray(evalx).flatten()
#evaly = np.asarray([Rastrigin([x_i]) for x_i in evalx])
# fig = plt.figure()
# ax = plt.subplot(1, 1, 1)
# data_skip = 1
# def init_func():
#     ax.clear()
#     plt.title('optimizing Rastrigin function')
#     plt.xlabel('x')
#     plt.ylabel('RASTRIGIN')
#     k = np.linspace(-5.14, 5.15, 10000)
#     yy = np.asarray([Rastrigin([xi]) for xi in np.linspace(-5.14, 5.15, 10000)])
#     plt.plot(k, yy)
# def fram_plt(i):
#     #ax.plot(evalx[i:  i+data_skip], evaly[i:i+data_skip], color='k')
#     ax.scatter(evalx[i], evaly[i], marker='o', color='r')
# animation = FuncAnimation(fig, func=fram_plt, frames=np.arange(0, len(evalx), data_skip),
#                           init_func=init_func, interval=100)


In [57]:
xs_solved[1]
evalx =
evaly = np.asarray([least_squares(Axxk, b, xi) for xi in xs_solved[1]])

In [70]:
evaly

array([2.017699  , 3.7706165 , 3.160022  , 2.64849   , 2.2199438 ,
       1.8609171 , 1.5601273 , 1.3081187 , 1.096974  , 0.92005783,
       0.77181345, 0.6475837 , 0.54346925, 0.45620403, 0.38305196,
       0.32172215, 0.27029517, 0.22716323, 0.19098176, 0.16062327,
       0.13514362], dtype=float32)

In [73]:
plt.plot(evaly)
plt.show()