# Tests Below Are Adapted From SciPy

In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [3]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [5]:
from line_search import *

# Test Line Functions

In [63]:
N = 20
A = np.random.randn(N, N)

In [125]:
def xsquare():
    def f(x):
        return np.dot(x,x)
    def df(x):
        return 2*x
    return f, df

def xAx_plus1():
    def f(x):
        return np.dot(x, np.dot(A, x)) + 1
    def df(x):
        return np.dot(A + A.T, x)
    return f, df

line_funcs = [xAx_plus1]

In [160]:
def test_line_search(line_search_algo, assert_algo_conditions_f, use_wolfe=True, use_strong_wolfe=True):
    for line_func in line_funcs:
        name = line_func.__name__
        f, df = line_func()
        for k in range(9):
            x = np.random.randn(N) # current point
            p = np.random.randn(N) # direction vector
            if np.dot(p, df(x)) >= 0:
                # skip bec. this is not a descent direction
                continue
        (alpha, fval, dfval), num_fcalls, num_dfcalls = line_search_algo(f, df, x, p, \
                                use_wolfe=use_wolfe, use_strong_wolfe=use_strong_wolfe)
        if not alpha:
            #print("Failed to converge, no alpha")
            return False
        assert np.allclose(fval, f(x + alpha*p))
        assert_algo_conditions_f(x, p, alpha, f, df, use_wolfe=use_wolfe, use_strong_wolfe=use_strong_wolfe)
        return True

# Algorithm Specific checks

In [164]:
def assert_strong_wolfe(x, p, alpha, f, df, c1=DEFAULT_C1, c2=DEFAULT_C2, use_wolfe=True, use_strong_wolfe=True):
    phik = f(x + alpha * p)
    phi0 = f(x)
    dphik = np.dot(df(x + alpha * p), p)
    dphi0 = np.dot(df(x), p)
    assert phik <= phi0 + c1 * alpha * dphik, "armijo condition"
    if use_wolfe:
        assert dphik >= c2 * dphi0, "curvature condition"
    if use_strong_wolfe:
        assert dphik <= -c2 * dphi0, "strong wolfe condition"

# Test Backtracking

In [172]:
num_converged = 0
iters = 100
for i in range(iters):
    if test_line_search(backtracking_linesearch, assert_strong_wolfe, use_wolfe=False, use_strong_wolfe=False):
        num_converged += 1
print("percent_converged: %.4f, %d, %d" % (float(num_converged)/iters, num_converged, iters))

percent_converged: 0.5900, 59, 100


In [105]:
def _line_func_1(x):
    f = np.dot(x, x)
    df = 2*x
    return f, df
_line_func_1(np.random.randn(20))

(25.913157921037453,
 array([-2.4586231 ,  1.93379156, -5.62204109, -2.40694865, -0.01400421,
        -2.68526324,  1.25033187, -2.46393245, -0.64465757, -0.1371565 ,
         2.60833294,  3.96536624, -0.95070811,  3.19910055,  1.44995208,
        -1.40252077,  0.36504162, -1.69806175,  0.65493609, -0.08559938]))

In [109]:
import scipy.optimize.linesearch as ls
from numpy.testing import assert_, assert_equal, \
     assert_array_almost_equal, assert_array_almost_equal_nulp, assert_warns

def assert_fp_equal(x, y, err_msg="", nulp=50):
    """Assert two arrays are equal, up to some floating-point rounding error"""
    try:
        assert_array_almost_equal_nulp(x, y, nulp)
    except AssertionError as e:
        raise AssertionError("%s\n%s" % (e, err_msg))


def _line_func_1(x):
    f = np.dot(x, x)
    df = 2*x
    return f, df

def _line_func_2(x):
    f = np.dot(x, np.dot(A, x)) + 1
    df = np.dot(A + A.T, x)
    return f, df

def bind_index(func, idx):
    # Remember Python's closure semantics!
    return lambda *a, **kw: func(*a, **kw)[idx]
line_funcs = [(bind_index(_line_func_1, 0), bind_index(_line_func_1, 1))]

def line_iter():
    for f, fprime in line_funcs:
        k = 0
        while k < 9:
            x = np.random.randn(N)
            p = np.random.randn(N)
            if np.dot(p, fprime(x)) >= 0:
                # always pick a descent direction
                continue
            k += 1
            old_fv = float(np.random.randn())
            yield f, fprime, x, p, old_fv
            
def assert_line_armijo(x, p, s, f, **kw):
    assert_armijo(s, phi=lambda sp: f(x + p*sp), **kw)
    
def assert_armijo(s, phi, c1=1e-4, err_msg=""):
    """
    Check that Armijo condition applies
    """
    phi1 = phi(s)
    phi0 = phi(0)
    msg = "s = %s; phi(0) = %s; phi(s) = %s; %s" % (s, phi0, phi1, err_msg)
    assert_(phi1 <= (1 - c1*s)*phi0, msg)
        
def test_line_search_armijo():
    c = 0
    for idx, (f, fprime, x, p, old_f) in enumerate(line_iter()):
        f0 = f(x)
        g0 = fprime(x)
        #print(g0, "g0 value")
        s, fc, fv = ls.line_search_armijo(f, x, p, g0, f0)
        c += 1
        assert_fp_equal(fv, f(x + s*p))
        assert_line_armijo(x, p, s, f, err_msg=idx)
    assert_(c >= 9)

In [113]:
test_line_search_armijo()