# Tests Below Are Adapted From SciPy

In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [3]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [174]:
from line_search import *

# Test Line Functions

In [175]:
N = 20
A = np.random.randn(N, N)

In [176]:
def xsquare():
    def f(x):
        return np.dot(x,x)
    def df(x):
        return 2*x
    return f, df

def xAx_plus1():
    def f(x):
        return np.dot(x, np.dot(A, x)) + 1
    def df(x):
        return np.dot(A + A.T, x)
    return f, df

line_funcs = [xAx_plus1]

In [177]:
def test_line_search(line_search_algo, assert_algo_conditions_f, use_wolfe=True, use_strong_wolfe=True):
    for line_func in line_funcs:
        name = line_func.__name__
        f, df = line_func()
        for k in range(9):
            x = np.random.randn(N) # current point
            p = np.random.randn(N) # direction vector
            if np.dot(p, df(x)) >= 0:
                # skip bec. this is not a descent direction
                continue
        (alpha, fval, dfval), num_fcalls, num_dfcalls = line_search_algo(f, df, x, p, \
                                use_wolfe=use_wolfe, use_strong_wolfe=use_strong_wolfe)
        if not alpha:
            #print("Failed to converge, no alpha")
            return False
        assert np.allclose(fval, f(x + alpha*p))
        assert_algo_conditions_f(x, p, alpha, f, df, use_wolfe=use_wolfe, use_strong_wolfe=use_strong_wolfe)
        return True

# Algorithm Specific checks

In [178]:
def assert_strong_wolfe(x, p, alpha, f, df, c1=DEFAULT_C1, c2=DEFAULT_C2, use_wolfe=True, use_strong_wolfe=True):
    phik = f(x + alpha * p)
    phi0 = f(x)
    dphik = np.dot(df(x + alpha * p), p)
    dphi0 = np.dot(df(x), p)
    assert phik <= phi0 + c1 * alpha * dphik, "armijo condition"
    if use_wolfe:
        assert dphik >= c2 * dphi0, "curvature condition"
    if use_strong_wolfe:
        assert dphik <= -c2 * dphi0, "strong wolfe condition"

# Test Backtracking

In [179]:
num_converged = 0
iters = 100
for i in range(iters):
    if test_line_search(backtracking_linesearch, assert_strong_wolfe, use_wolfe=False, use_strong_wolfe=False):
        num_converged += 1
print("percent_converged: %.4f, %d, %d" % (float(num_converged)/iters, num_converged, iters))

percent_converged: 0.6000, 60, 100


# Test Interpolating

In [185]:
num_converged = 0
iters = 100
for i in range(iters):
    if test_line_search(interpolating_line_search, assert_strong_wolfe, use_wolfe=False, use_strong_wolfe=False):
        num_converged += 1
print("percent_converged: %.4f, %d, %d" % (float(num_converged)/iters, num_converged, iters))

percent_converged: 1.0000, 100, 100
