## Using the Rosenbrock function to compare three Levenberg-Marquardt (LM) methods:
1. the classical LM method available in the SciPy package (using TensorFlow to supply the full Jacobian to SciPy)
2. the classical LM method implemented in this package
3. the LM method with the Generalized Gauss-Newton (GGN) matrix for a Newton-like approach,
   where the *prediction function* is just the identity transform and the entire Rosenbrock function is inside the *loss function*
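
The difference between methods 2 and 3 can be written out with the standard GGN definition. The block below is a sketch under that definition: the symbols $r$ (residuals), $f$ (prediction function), $L$ (loss), $J$ (Jacobian), $G$ (GGN matrix), $\mu$ (damping) and $p$ (step) are notation introduced here, and the exact damping used by the package may differ.

$$
\begin{aligned}
&G = J_f^\top \, \nabla_f^2 L \, J_f, \qquad (G + \mu I)\, p = -J_f^\top \nabla_f L \\
&\text{method 2:}\quad f(x) = r(x),\quad L(f) = \tfrac12 \lVert f \rVert^2 \;\Rightarrow\; G = J_r^\top J_r \\
&\text{method 3:}\quad f(x) = x,\quad L(f) = \tfrac12 \lVert r(f) \rVert^2 \;\Rightarrow\; G = \nabla_x^2 L = J_r^\top J_r + \sum_i r_i \, \nabla^2 r_i
\end{aligned}
$$

In method 2 the matrix is the Gauss-Newton approximation, which is always positive semidefinite. In method 3 it is the exact Hessian of the cost, which includes the second-order residual term and can be indefinite away from the minimum. At the minimum, where $r = 0$, the two coincide.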
   
**Other notes**:
- Because SciPy solves the damped linear system directly rather than iteratively, the matrix-free method needs a tightly converged conjugate gradient (CG) solve for the comparison to be valid.
- The SciPy optimizer records the number of Jacobian evaluations (`njev`). Each full *iteration* of the matrix-free method can be thought of as an equivalent operation, so the matrix-free methods are run for `njev` iterations.
- We want all the LM methods to have:
    - identical output variables
    - a similar number of Jacobian evaluations (or iterations).
- We perform the comparison with 10 different random initializations.
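
The first note above can be illustrated with a small NumPy sketch: a single damped Gauss-Newton step computed once with a dense solve and once with a matrix-free CG loop. This is not the package's implementation. The analytic Jacobian, the textbook CG routine, the damping value `mu`, and the tolerance are all assumptions made for illustration.

```python
import numpy as np

def residuals(x):
    # Same residuals as the Rosenbrock function used in this notebook.
    return np.concatenate((10.0 * (x[1:] - x[:-1] ** 2), 1.0 - x[:-1]))

def jacobian(x):
    # Analytic Jacobian of the residuals, shape (2 * (n - 1), n).
    m = x.size - 1
    jac = np.zeros((2 * m, x.size))
    rows = np.arange(m)
    jac[rows, rows] = -20.0 * x[:-1]   # d/dx[i] of 10 * (x[i+1] - x[i]**2)
    jac[rows, rows + 1] = 10.0         # d/dx[i+1] of the same residual
    jac[m + rows, rows] = -1.0         # d/dx[i] of (1 - x[i])
    return jac

def conjugate_gradient(matvec, b, tol=1e-10, max_iter=100):
    # Textbook CG for a symmetric positive definite system A p = b,
    # using only matrix-vector products v -> A v.
    p = np.zeros_like(b)
    r = b.copy()                       # residual b - A p for p = 0
    d = r.copy()
    rs_old = r @ r
    b_norm = np.linalg.norm(b)
    for _ in range(max_iter):
        if np.sqrt(rs_old) <= tol * b_norm:
            break
        Ad = matvec(d)
        alpha = rs_old / (d @ Ad)
        p = p + alpha * d
        r = r - alpha * Ad
        rs_new = r @ r
        d = r + (rs_new / rs_old) * d
        rs_old = rs_new
    return p

x = np.array([0.3, 0.7, 0.1, 0.9, 0.5])
mu = 1e-2                              # hypothetical damping value
J = jacobian(x)
rhs = -J.T @ residuals(x)

# Dense route: form J^T J + mu * I explicitly and solve.
step_direct = np.linalg.solve(J.T @ J + mu * np.eye(x.size), rhs)
# Matrix-free route: never form the matrix, only apply it to vectors.
step_cg = conjugate_gradient(lambda v: J.T @ (J @ v) + mu * v, rhs)

print(np.max(np.abs(step_direct - step_cg)))
```

With a tight tolerance the two steps agree closely. Loosening `tol` makes the CG step drift away from the dense one, which is why a precise CG result is needed for a like-for-like comparison.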

In [9]:
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from sopt.optimizers.tensorflow2 import LMA, NonLinearConjugateGradient
import scipy.optimize as spopt
from tqdm.notebook import tqdm
tf.get_logger().setLevel('ERROR')


In [10]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
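def fun_rosenbrock(x):
    # Residuals of the chained Rosenbrock function:
    # 10 * (x[i+1] - x[i]**2) and (1 - x[i]) for each adjacent pair.
    y1 = 10 * (x[1:] - x[:-1]**2)
    y2 = 1 - x[:-1]
    return tf.concat((y1, y2), axis=0)

def fun_loss(x):
    # Least-squares loss: half the sum of squared residuals.
    return 0.5 * tf.reduce_sum(x**2)

def ggn_preds_fn(v):
    # Identity "prediction": the variable is passed through unchanged.
    return v

def ggn_loss_fn(v):
    # The whole Rosenbrock cost lives in the loss function.
    return fun_loss(fun_rosenbrock(v))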
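
As a quick sanity check of the residual and loss definitions, the sketch below mirrors them in plain NumPy (the names `rosenbrock_residuals` and `half_sum_squares` are introduced here for illustration) and evaluates the cost at two easy points.

```python
import numpy as np

def rosenbrock_residuals(x):
    # NumPy mirror of fun_rosenbrock: 2 * (n - 1) stacked residuals.
    x = np.asarray(x, dtype=np.float64)
    y1 = 10.0 * (x[1:] - x[:-1] ** 2)
    y2 = 1.0 - x[:-1]
    return np.concatenate((y1, y2))

def half_sum_squares(r):
    # NumPy mirror of fun_loss.
    return 0.5 * np.sum(r ** 2)

# At x = (1, ..., 1) every residual vanishes, so the cost is zero.
r_min = rosenbrock_residuals(np.ones(5))
cost_min = half_sum_squares(r_min)

# At x = (0, ..., 0) the first block is zero and the second block is all ones.
cost_zero = half_sum_squares(rosenbrock_residuals(np.zeros(5)))

print(r_min.shape, cost_min, cost_zero)
```

For five variables there are eight residuals, and the global minimum of the cost is at the all-ones vector, which is the point every optimizer in this comparison should reach.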

In [12]:
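def scipy_fun_rosenbrock(x):
    # SciPy passes a NumPy array; evaluate the residuals in TensorFlow (float32).
    t = tf.constant(x, dtype='float32')
    out = fun_rosenbrock(t).numpy()
    return out

def scipy_jacobian_rosenbrock(x):
    # Full Jacobian of the residuals via automatic differentiation.
    t = tf.constant(x, dtype='float32')
    with tf.GradientTape() as gt:
        gt.watch(t)  # t is a constant, so it must be watched explicitly
        rsnbrck = fun_rosenbrock(t)
    jac = gt.jacobian(rsnbrck, t)
    return jac.numpy()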
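
The Jacobian that automatic differentiation returns here has a simple closed form, which makes it easy to cross-check. The sketch below is independent of TensorFlow: it writes the Jacobian analytically in NumPy and compares it against central finite differences. The helper names, the test point, and the step size are assumptions chosen for illustration.

```python
import numpy as np

def rosenbrock_residuals(x):
    x = np.asarray(x, dtype=np.float64)
    return np.concatenate((10.0 * (x[1:] - x[:-1] ** 2), 1.0 - x[:-1]))

def rosenbrock_jacobian(x):
    # Analytic Jacobian of the residuals, shape (2 * (n - 1), n).
    x = np.asarray(x, dtype=np.float64)
    m = x.size - 1
    jac = np.zeros((2 * m, x.size))
    rows = np.arange(m)
    jac[rows, rows] = -20.0 * x[:-1]   # d/dx[i] of 10 * (x[i+1] - x[i]**2)
    jac[rows, rows + 1] = 10.0         # d/dx[i+1] of the same residual
    jac[m + rows, rows] = -1.0         # d/dx[i] of (1 - x[i])
    return jac

def finite_difference_jacobian(fun, x, step=1e-6):
    # Central differences, one column per input variable.
    x = np.asarray(x, dtype=np.float64)
    cols = []
    for k in range(x.size):
        e = np.zeros_like(x)
        e[k] = step
        cols.append((fun(x + e) - fun(x - e)) / (2.0 * step))
    return np.stack(cols, axis=1)

x0 = np.array([0.3, 0.7, 0.1, 0.9, 0.5])
jac_exact = rosenbrock_jacobian(x0)
jac_fd = finite_difference_jacobian(rosenbrock_residuals, x0)
max_err = np.max(np.abs(jac_exact - jac_fd))

print(jac_exact.shape, max_err)
```

The same comparison could be applied to the output of `scipy_jacobian_rosenbrock`, keeping in mind that the float32 cast limits the agreement to roughly single precision.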

In [13]:
def runScipyLM(init):
    res = spopt.least_squares(scipy_fun_rosenbrock, init, jac=scipy_jacobian_rosenbrock, method='lm')
    return res

In [14]:
def runClassicLM(init, n_iterations, supply_diag_hessian=False):
    results = {}
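    diag_hessian_fn = None
    if supply_diag_hessian:
        # The Hessian of 0.5 * sum(r**2) with respect to the predictions r
        # is the identity, so its diagonal is the constant 1.
        diag_hessian_fn = lambda v: 1.0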
    v = tf.Variable(init, dtype='float32')
    lm = LMA(input_var=v, predictions_fn=fun_rosenbrock,
             loss_fn=fun_loss, diag_hessian_fn=diag_hessian_fn,
             name='lma', warm_start=True, 
             min_cg_tol=1e-5)
    for j in range(n_iterations):
        lm.minimize()
    results['x'] = v.numpy()
    results['cost'] = lm._loss_new.numpy()
    return results

In [15]:
def runGGNLM(init, n_iterations):
    results = {}
    v = tf.Variable(init, dtype='float32')
    lm = LMA(input_var=v, predictions_fn=ggn_preds_fn,
             loss_fn=ggn_loss_fn,
             name='lma', warm_start=True, 
             min_cg_tol=1e-5)
    for j in range(n_iterations):
        lm.minimize()
    results['x'] = v.numpy()
    results['cost'] = lm._loss_new.numpy()
    return results

In [30]:
scipy_results = []
lm_classic_results = []
lm_classic_with_supplied_hessian_results = [] # this should be identical to lm_classic_results
lm_ggn_results = []
for i in tqdm(range(10)):
    z_init = np.random.random(5).astype('float32')
    res = runScipyLM(z_init)
    scipy_results.append(res)
    lm_classic_results.append(runClassicLM(z_init, res.njev))
    lm_classic_with_supplied_hessian_results.append(runClassicLM(z_init, res.njev, True))
    lm_ggn_results.append(runGGNLM(z_init, res.njev))


In [31]:
for r in scipy_results:
    print('x', r.x, 'iterations', r.njev)

x [1.00000001 1.00000006 1.00000001 1.00000002 0.99999999] iterations 11
x [1.00000002 0.99999997 1.00000002 1.00000002 0.99999999] iterations 6
x [0.99999998 1.00000005 1.00000001 0.99999998 1.00000001] iterations 14
x [1.00000002 1.00000004 1.00000005 1.         0.99999999] iterations 12
x [0.99999997 1.00000003 1.00000004 0.99999998 1.00000003] iterations 15
x [1.00000004 1.00000001 1.00000001 0.99999998 1.        ] iterations 14
x [1.         0.99999998 0.99999999 1.00000001 1.00000002] iterations 15
x [1.00000002 1.00000001 1.00000002 1.00000002 1.00000001] iterations 14
x [1.00000002 1.00000001 1.00000005 0.99999999 1.00000001] iterations 13
x [0.99999999 1.00000001 1.00000001 1.00000003 1.        ] iterations 14


In [32]:
for l in lm_classic_results:
    print('x', l['x'], 'cost', l['cost'])

x [0.99898297 0.99796224 0.995917   0.9918258  0.9836622 ] cost 4.453459e-05
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0


In [33]:
for l in lm_classic_with_supplied_hessian_results:
    print('x', l['x'], 'cost', l['cost'])

x [0.99898297 0.99796224 0.995917   0.9918258  0.9836622 ] cost 4.453459e-05
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0
x [1. 1. 1. 1. 1.] cost 0.0


In [34]:
for l in lm_ggn_results:
    print('x', l['x'], 'cost', l['cost'])

x [4.8016688e-01 2.3040651e-01 5.7718989e-02 1.2542443e-02 1.5040496e-04] cost 1.3680493
x [0.9977753  0.9953664  0.9900318  0.9774746  0.94599736] cost 0.005179659
x [0.99369323 0.98739755 0.97488266 0.9502482  0.90268326] cost 0.0016579088
x [0.98016894 0.96062785 0.9225483  0.85058093 0.7227015 ] cost 0.015182137
x [0.99998474 0.9999649  0.9999118  0.99975204 0.99922055] cost 4.329831e-06
x [0.99999994 0.9999999  0.99999976 0.9999996  0.99999917] cost 3.0198066e-13
x [0.9975235  0.99504036 0.99007386 0.98018396 0.96068937] cost 0.00026147277
x [-0.9606199   0.93296415  0.8754156   0.7681596   0.5888228 ] cost 1.9655604
x [0.9992161 0.9984289 0.9968506 0.9936918 0.9873958] cost 2.6459566e-05
x [-0.9511946   0.9148273   0.8413313   0.70765203  0.49441844] cost 1.9705827
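
One way to compare the runs at a glance is to count how many reached a small final cost. The sketch below uses the costs printed above for the classical and GGN runs. The threshold `cost_tol` is a hypothetical choice, not a criterion used by the package.

```python
# Final costs copied from the printed results above.
classic_costs = [4.453459e-05] + [0.0] * 9
ggn_costs = [1.3680493, 0.005179659, 0.0016579088, 0.015182137, 4.329831e-06,
             3.0198066e-13, 0.00026147277, 1.9655604, 2.6459566e-05, 1.9705827]

def count_converged(costs, tol):
    # Number of runs whose final cost is below the tolerance.
    return sum(1 for c in costs if c < tol)

cost_tol = 1e-4  # hypothetical convergence threshold

n_classic = count_converged(classic_costs, cost_tol)
n_ggn = count_converged(ggn_costs, cost_tol)
worst_ggn = max(ggn_costs)

print('classic:', n_classic, 'of', len(classic_costs))
print('ggn:', n_ggn, 'of', len(ggn_costs), 'worst cost', worst_ggn)
```

Under this threshold and the same iteration budget, all ten classical runs count as converged and three of the ten GGN runs do.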
